## This notebook derives data and MC corrections:
1. Trigger efficiencies
2. Lepton efficiencies

## TRIGGER EFFICIENCY 

### Only for data

In [1]:
import os
import sys
import time
import gc 
import psutil
import json
from pathlib import Path

import uproot
import awkward as ak
import numpy as np

import vector
vector.register_awkward()

import dask
from dask.distributed import Client

print("All imports added")

All imports added


In [2]:
# Connect to the Dask scheduler listening on this host over TLS.
client = Client(address="tls://localhost:8786")
# Bare expression so Jupyter renders the cluster summary.
client

0,1
Connection method: Direct,
Dashboard: /user/anujraghav.physics@gmail.com/proxy/8787/status,

0,1
Comm: tls://192.168.161.139:8786,Workers: 0
Dashboard: /user/anujraghav.physics@gmail.com/proxy/8787/status,Total threads: 0
Started: Just now,Total memory: 0 B


In [3]:
# Project layout. HOME is read from the environment, falling back to
# /home/cms-jovyan when it is unset.
HOME_DIR = Path(os.environ.get("HOME", "/home/cms-jovyan"))
PROJECT_NAME = "H-to-WW-NanoAOD-analysis"

PROJECT_DIR = HOME_DIR / PROJECT_NAME
DATASETS_DIR = PROJECT_DIR / "Datasets"
DATA_DIR = DATASETS_DIR / "DATA"           # *.txt listings of data file URLs
MC_DIR = DATASETS_DIR / "MC_samples"       # *.txt listings of MC file URLs
AUX_DIR = PROJECT_DIR / "Auxillary_files"  # (sic) must match the directory name on disk

# Certified-luminosity ("golden") JSON used to mask data events.
GOLDEN_JSON_PATH = AUX_DIR / "Cert_271036-284044_13TeV_Legacy2016_Collisions16_JSON.txt"

# Inclusive run-number ranges of the 2016 periods kept by the golden-JSON mask.
RUN_PERIODS_2016 = {
    "Run2016G": {"run_min": 278820, "run_max": 280385},
    "Run2016H": {"run_min": 280919, "run_max": 284044},
}

# Pad every label to the same width so the printed values line up
# (the hand-padded labels were off by one column on some lines).
for _label, _value in (
    ("HOME_DIR:", HOME_DIR),
    ("PROJECT_DIR:", PROJECT_DIR),
    ("DATA_DIR:", DATA_DIR),
    ("MC_DIR:", MC_DIR),
    ("AUX_DIR:", AUX_DIR),
    ("GOLDEN_JSON:", GOLDEN_JSON_PATH),
    ("JSON exists:", GOLDEN_JSON_PATH.exists()),
):
    print(f"{_label:<17}{_value}")
del _label, _value


HOME_DIR:         /home/cms-jovyan
PROJECT_DIR:     /home/cms-jovyan/H-to-WW-NanoAOD-analysis
DATA_DIR:        /home/cms-jovyan/H-to-WW-NanoAOD-analysis/Datasets/DATA
MC_DIR:          /home/cms-jovyan/H-to-WW-NanoAOD-analysis/Datasets/MC_samples
AUX_DIR:         /home/cms-jovyan/H-to-WW-NanoAOD-analysis/Auxillary_files
GOLDEN_JSON:      /home/cms-jovyan/H-to-WW-NanoAOD-analysis/Auxillary_files/Cert_271036-284044_13TeV_Legacy2016_Collisions16_JSON.txt
JSON exists:     True


In [4]:
# Lower-cased substring of a listing-file name -> sample label used in the
# analysis. Listing files matching no key are skipped by load_all_files.
SAMPLE_MAPPING = dict(
    data="Data",
    # dytoll="DY_to_Tau_Tau",
)

def load_urls_from_files(filepath, max_files = None):
    """Read ``root://`` URLs from a listing file, one URL per line.

    Args:
        filepath: path (str or os.PathLike) of the listing file.
        max_files: maximum number of URLs to return. ``None`` means no
            limit; a value <= 0 returns an empty list. (Previously 0 was
            silently treated as "no limit" because of a truthiness test.)

    Returns:
        List of URL strings in file order. Surrounding whitespace is
        stripped; blank lines and lines not starting with ``root://`` are
        skipped. Returns an empty list if *filepath* does not exist.
    """
    urls = []

    if not os.path.exists(filepath):
        return urls
    if max_files is not None and max_files <= 0:
        return urls

    with open(filepath, 'r', encoding='utf-8') as f:
        for raw_line in f:
            line = raw_line.strip()
            # An empty string never starts with 'root://', so blanks drop out too.
            if not line.startswith('root://'):
                continue
            urls.append(line)
            if max_files is not None and len(urls) >= max_files:
                break
    return urls

def load_all_files(data_dir, mc_dir, max_per_sample = None):
    """Collect file URLs for every known sample from the listing directories.

    Each ``*.txt`` file in ``data_dir`` and ``mc_dir`` gets a sample label
    from SAMPLE_MAPPING: the first key that is a substring of the
    lower-cased file name without ``.txt``. Files matching no key are
    reported and skipped. URLs from all files sharing a label are
    concatenated.

    Args:
        data_dir, mc_dir: directories holding the listing files; missing
            directories are ignored.
        max_per_sample: cap on the number of URLs per sample label, summed
            over all listing files mapped to that label. None means no cap.

    Returns:
        dict mapping sample label to its list of URLs. Labels that end up
        with no URLs are omitted.
    """
    files_dict = {}

    for directory in (data_dir, mc_dir):
        if not os.path.exists(directory):
            continue

        # os.listdir order is arbitrary; sort so the selection (and which
        # files survive the per-sample cap) is reproducible between runs.
        for filename in sorted(os.listdir(directory)):
            if not filename.endswith(".txt"):
                continue

            filename_lower = filename.lower().replace('.txt', '')
            label = next(
                (sample_label
                 for pattern, sample_label in SAMPLE_MAPPING.items()
                 if pattern in filename_lower),
                None,
            )
            if not label:
                print(f" unknown file: {filename}- skipping")
                continue

            # Enforce the cap per sample label, not per listing file.
            if max_per_sample is None:
                remaining = None
            else:
                remaining = max_per_sample - len(files_dict.get(label, []))
                if remaining <= 0:
                    continue

            urls = load_urls_from_files(os.path.join(directory, filename), remaining)
            if urls:
                files_dict.setdefault(label, []).extend(urls)

    return files_dict

# Quick test run: one file per sample. Use the commented call for the full set.
files = load_all_files(DATA_DIR, MC_DIR, max_per_sample=1)
# files = load_all_files(DATA_DIR, MC_DIR)

# Summary table of how many files each sample contributes.
rule = "=" * 70
print("\n" + rule)
print("FILES TO PROCESS")
print(rule)
for label, urls in files.items():
    print(f"{label:20s}: {len(urls):4d} files")
total = sum(len(url_list) for url_list in files.values())
print("_" * 70)
print(f"{'TOTAL':20s}: {total:4d} files")
print(rule)

 unknown file: VG.txt- skipping
 unknown file: Higgs.txt- skipping
 unknown file: WW.txt- skipping
 unknown file: Fakes.txt- skipping
 unknown file: VZ.txt- skipping
 unknown file: DYtoLL.txt- skipping
 unknown file: ggWW.txt- skipping
 unknown file: Top.txt- skipping

FILES TO PROCESS
Data                :    1 files
______________________________________________________________________
TOTAL               :    1 files


In [5]:
def load_golden_json(json_input, run_periods=None):
    """Load a golden JSON as ``{run: [(first_lumi, last_lumi), ...]}``.

    Args:
        json_input: path of the JSON file (``str`` or any ``os.PathLike``
            such as ``pathlib.Path``), or an already-parsed ``dict`` mapping
            run-number strings to lists of ``[first, last]`` lumi ranges.
        run_periods: optional mapping ``{name: {"run_min": int, "run_max": int}}``.
            When given, only runs inside at least one of these inclusive
            ranges are kept (an empty mapping therefore keeps nothing).

    Returns:
        dict mapping ``int`` run number to a list of ``(first, last)`` tuples.

    Raises:
        TypeError: if *json_input* is neither a path nor a dict.
    """
    # Path objects (e.g. GOLDEN_JSON_PATH) used to be rejected; accept any PathLike.
    if isinstance(json_input, (str, os.PathLike)):
        with open(json_input, 'r', encoding='utf-8') as f:
            golden_json = json.load(f)
    elif isinstance(json_input, dict):
        golden_json = json_input
    else:
        raise TypeError(f"Expected str, os.PathLike or dict, got {type(json_input)}")

    periods = None if run_periods is None else list(run_periods.values())

    valid_lumis = {}
    for run_str, lumi_ranges in golden_json.items():
        run = int(run_str)

        # Drop runs outside every requested period.
        if periods is not None and not any(
            period['run_min'] <= run <= period['run_max'] for period in periods
        ):
            continue

        valid_lumis[run] = [tuple(lr) for lr in lumi_ranges]

    return valid_lumis


def apply_json_mask(arrays, json_input, run_periods=None):
    """Return a per-event boolean mask keeping certified (run, lumi) pairs.

    Args:
        arrays: event record array with ``run`` and ``luminosityBlock`` fields.
        json_input: golden JSON path or parsed dict, forwarded to
            load_golden_json.
        run_periods: optional run ranges forwarded to load_golden_json; runs
            outside them are rejected.

    Returns:
        ak.Array of bool with one entry per event: True when the event's
        lumi section lies inside one of the certified (inclusive) ranges of
        its run.

    Note:
        The golden JSON is re-processed by load_golden_json on every call.
    """
    valid_lumis = load_golden_json(json_input, run_periods)

    runs = ak.to_numpy(arrays.run)
    lumis = ak.to_numpy(arrays.luminosityBlock)

    mask = np.zeros(len(runs), dtype=bool)

    # Visit only the runs actually present in this batch, instead of doing a
    # full-length `runs == run` comparison for every certified run.
    for run in np.unique(runs):
        lumi_ranges = valid_lumis.get(int(run))
        if not lumi_ranges:
            continue

        run_mask = runs == run
        run_lumis = lumis[run_mask]
        run_lumi_mask = np.zeros(len(run_lumis), dtype=bool)

        for lumi_start, lumi_end in lumi_ranges:
            run_lumi_mask |= (run_lumis >= lumi_start) & (run_lumis <= lumi_end)

        mask[run_mask] = run_lumi_mask

    return ak.Array(mask)

In [6]:
# #get name of the branch required for trigger efficiency 

# # DATA
# root_file_name = "root://eospublic.cern.ch//eos/opendata/cms/Run2016G/MuonEG/NANOAOD/UL2016_MiniAODv2_NanoAODv9-v1/120000/2ADBED61-A06A-D64B-BE90-E9B267D15700.root"

# #  MC 

# # file_url = "root://eospublic.cern.ch//eos/opendata/cms/mc/RunIISummer20UL16NanoAODv9/DYJetsToLL_M-50_TuneCP5_13TeV-madgraphMLM-pythia8/NANOAODSIM/106X_mcRun2_asymptotic_v17-v1/40000/14B6A8AE-C9FE-D744-80A4-DDE5D008C1CD.root"

# with uproot.open(root_file_name) as file:
#         # Access the Events tree
#         if "Events" not in file:
#             print("Error: 'Events' tree not found in file.")
#         else:
#             tree = file["Events"]
#             branches = tree.keys()
            
#             print(f"\nConnection Successful!")
#             print(f"Total Branches found: {len(branches)}")
#             print("=" * 60)
            
#             # Print all branches alphabetically
#             for branch in sorted(branches):
#                 if "HLT_Mu12_TrkIsoVVL_Ele23_CaloIdL_TrackIdL_IsoVL_DZ" in branch or "HLT_Mu23_TrkIsoVVL_Ele12_CaloIdL_TrackIdL_IsoVL_DZ"  in branch:
#                     print(branch)


In [7]:
Batch_size = 1_250_000  # default number of entries per chunk read by load_events

def load_events(file_url, batch_size=Batch_size, timeout=600, max_retries=3, retry_wait=10, is_data=False):
    """Stream the analysis branches of a NanoAOD file in chunks.

    Opens ``file_url`` with uproot and yields the requested branches of the
    ``Events`` tree, ``batch_size`` entries at a time. I/O-type errors
    (TimeoutError / OSError / ConnectionError) are retried up to
    ``max_retries`` attempts in total, sleeping ``retry_wait`` seconds in
    between. A retry resumes right after the last entry already yielded, so
    no event is ever yielded twice. Previously every retry restarted at
    entry 0, and the caller double-counted events read before the failure.

    Args:
        file_url: URL or path of the ROOT file.
        batch_size: number of entries per yielded chunk.
        timeout: timeout forwarded to ``uproot.open``.
        max_retries: total number of attempts.
        retry_wait: seconds to sleep between attempts.
        is_data: if True also read ``run`` and ``luminosityBlock``;
            otherwise read ``genWeight``.

    Yields:
        awkward Array holding the requested columns for one chunk.

    Raises:
        The last I/O error once all attempts are exhausted; any other
        exception is re-raised immediately.
    """
    columns = [
        "Electron_pt", "Electron_eta", "Electron_phi", "Electron_mass",
        "Electron_mvaFall17V2Iso_WP90", "Electron_charge",

        "Muon_pt", "Muon_eta", "Muon_phi", "Muon_mass",
        "Muon_tightId", "Muon_charge", "Muon_pfRelIso04_all",
        "PuppiMET_pt", "PuppiMET_phi",

        "Jet_pt", "Jet_eta", "Jet_phi", "Jet_mass",
        "Jet_btagDeepFlavB", "nJet", "Jet_jetId", "Jet_puId",

        "HLT_Mu12_TrkIsoVVL_Ele23_CaloIdL_TrackIdL_IsoVL_DZ",
        "HLT_Mu23_TrkIsoVVL_Ele12_CaloIdL_TrackIdL_IsoVL_DZ"
    ]

    if is_data:
        columns.extend(["run", "luminosityBlock"])
    else:
        columns.append("genWeight")

    # Number of entries already handed to the caller; a retry starts reading here.
    entries_done = 0

    for attempt in range(max_retries):
        try:
            with uproot.open(file_url, timeout=timeout) as f:
                tree = f['Events']

                for arrays in tree.iterate(columns, step_size=batch_size,
                                           entry_start=entries_done, library="ak"):
                    entries_done += len(arrays)
                    yield arrays

                return

        except (TimeoutError, OSError, IOError, ConnectionError) as e:
            error_type = type(e).__name__
            file_name = file_url.split('/')[-1]

            if attempt < max_retries - 1:
                print(f"      {error_type} on {file_name}")
                print(f"       Retry {attempt+1}/{max_retries-1} in {retry_wait}s...")
                time.sleep(retry_wait)
            else:
                print(f"     FAILED after {max_retries} attempts: {file_name}")
                print(f"       Error: {str(e)[:100]}")
                raise

        except Exception as e:
            file_name = file_url.split('/')[-1]
            print(f"     Unexpected error on {file_name}: {str(e)[:100]}")
            raise

## TRIGGER PART

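Trigger efficiency: $\varepsilon_{\text{trig}} = \dfrac{N_{\text{preselection}\,\wedge\,\text{HLT}}}{N_{\text{preselection}}}$, i.e. the number of events passing the preselection **and** the $e\mu$ trigger, divided by the number of events passing the preselection.

> Preselection includes:
> 1. Exactly 2 tight leptons forming an opposite-charge $e\mu$ pair
> 2. Lepton ID (electron & muon)
> 3. $|\eta| < 2.5$ for both leptons
> 4. $p_T$ requirement: leading > 25 GeV, subleading > 13 GeV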

In [8]:
def select_tight_leptons(arrays):
    """Merge each event's tight electrons and tight muons into one collection.

    Tight electron: ``Electron_mvaFall17V2Iso_WP90 == 1``.
    Tight muon: ``Muon_tightId == 1`` and ``Muon_pfRelIso04_all < 0.15``.

    Every lepton record carries pt, eta, phi, mass, charge and an int32
    ``flavor`` (11 for electrons, 13 for muons). Within an event the
    electrons come first, then the muons; no pT ordering is applied.
    """
    def _collection(prefix, keep, flavor_code):
        # Take branch "<prefix>_<var>" and keep only the selected objects.
        def branch(var):
            return getattr(arrays, f"{prefix}_{var}")[keep]

        pt = branch("pt")
        return ak.zip({
            "pt": pt,
            "eta": branch("eta"),
            "phi": branch("phi"),
            "mass": branch("mass"),
            "charge": branch("charge"),
            "flavor": ak.values_astype(ak.ones_like(pt) * flavor_code, "int32"),
        })

    electron_is_tight = arrays.Electron_mvaFall17V2Iso_WP90 == 1
    muon_is_tight = (arrays.Muon_tightId == 1) & (arrays.Muon_pfRelIso04_all < 0.15)

    return ak.concatenate(
        [
            _collection("Electron", electron_is_tight, 11),
            _collection("Muon", muon_is_tight, 13),
        ],
        axis=1,
    )
    

In [9]:
def select_emu_events(tight_leptons, arrays):
    """Apply the e-mu preselection, then the e-mu HLT requirement.

    Args:
        tight_leptons: per-event lepton collection with pt, eta, charge and
            flavor fields (e.g. the output of select_tight_leptons).
        arrays: event record array the leptons were built from; must carry
            the two HLT_Mu*_Ele* trigger branches.

    Returns:
        ``(n_emu, n_final, events_passing_all)``: the number of events
        passing the preselection, the number also firing either trigger,
        and the lepton pairs of the latter. Returns ``(0, 0, None)`` when no
        event has exactly two tight leptons.
    """
    ordered = tight_leptons[ak.argsort(tight_leptons.pt, ascending=False)]

    has_two = ak.num(ordered) == 2
    pairs = ordered[has_two]
    # Filter the event record with the same mask so trigger bits stay aligned with `pairs`.
    event_info = arrays[has_two]

    if len(pairs) == 0:
        return 0, 0, None

    lead, sublead = pairs[:, 0], pairs[:, 1]

    # One electron (11) and one muon (13), in either pT order.
    is_emu = (
        ((lead.flavor == 13) & (sublead.flavor == 11))
        | ((lead.flavor == 11) & (sublead.flavor == 13))
    )
    is_opposite_sign = lead.charge * sublead.charge < 0
    passes_pt = (lead.pt > 25) & (sublead.pt > 13)
    passes_eta = (abs(lead.eta) < 2.5) & (abs(sublead.eta) < 2.5)
    preselected = is_emu & is_opposite_sign & passes_pt & passes_eta

    fired = (
        (event_info.HLT_Mu12_TrkIsoVVL_Ele23_CaloIdL_TrackIdL_IsoVL_DZ == 1)
        | (event_info.HLT_Mu23_TrkIsoVVL_Ele12_CaloIdL_TrackIdL_IsoVL_DZ == 1)
    )

    triggered = pairs[preselected & fired]
    return len(pairs[preselected]), len(triggered), triggered

In [10]:
import time
import awkward as ak
import numpy as np

def make_processor(golden_json_data, run_periods):
    """Build the per-file worker for the trigger-efficiency measurement.

    The returned closure captures the golden JSON and run periods, so it can
    be passed directly to ``client.map``.

    Args:
        golden_json_data: parsed golden JSON dict, or None to skip the lumi
            mask; forwarded to apply_json_mask.
        run_periods: run-range mapping forwarded to apply_json_mask.

    Returns:
        ``processing_file(label, file_url, file_idx)``, which returns
        ``(label, n_preselected, n_preselected_and_hlt, error)``.
    """

    def processing_file(label, file_url, file_idx):
        """Count preselected and HLT-passing e-mu events in one file.

        ``file_idx`` is not used; it is accepted so the signature matches
        the argument lists given to ``client.map``.

        Returns ``(label, n_kinematics, n_hlt, None)`` on success and
        ``(label, 0, 0, message)`` on failure.
        """
        file_name = file_url.split('/')[-1]
        is_data = (label == 'Data')

        max_file_retries = 3

        for file_attempt in range(max_file_retries):
            # Reset on every attempt: a retry re-reads the file from the start,
            # so counts from a failed partial attempt must not be carried over.
            count_emu_kinematics = 0
            count_emu_hlt = 0

            try:
                for arrays in load_events(file_url, batch_size=1_250_000, is_data=is_data):

                    # Keep only certified lumi sections in data.
                    if is_data and golden_json_data is not None:
                        try:
                            json_mask = apply_json_mask(arrays, golden_json_data, run_periods=run_periods)
                            if np.sum(json_mask) == 0:
                                continue
                            arrays = arrays[json_mask]
                        except Exception as e:
                            # Best effort: the batch is skipped (not counted) and processing continues.
                            print(f"Warning: JSON mask failed for {file_name}: {e}")
                            continue

                    tight_leptons = select_tight_leptons(arrays)
                    n_emu, n_hlt, _ = select_emu_events(tight_leptons, arrays)

                    count_emu_kinematics += n_emu
                    count_emu_hlt += n_hlt

                return label, count_emu_kinematics, count_emu_hlt, None

            except (OSError, IOError, ValueError) as e:
                if file_attempt < max_file_retries - 1:
                    time.sleep(3)
                    continue
                return label, 0, 0, f"{file_name}: Failed after retries - {str(e)[:100]}"

            except Exception as e:
                return label, 0, 0, f"{file_name}: Unexpected error - {str(e)[:100]}"

        return label, 0, 0, "Unknown loop exit"

    return processing_file

In [11]:
# %%
# MAIN PROCESSING (Trigger Efficiency)

import time
import json
from collections import defaultdict
from dask.distributed import progress

print(f"\n{'='*70}\nTRIGGER EFFICIENCY PROCESSING START\n{'='*70}")

# Parse the golden JSON once here; make_processor's closure carries the dict
# to the workers.
golden_json_data = None
if GOLDEN_JSON_PATH.exists():
    with open(GOLDEN_JSON_PATH, 'r') as f:
        golden_json_data = json.load(f)
else:
    print(f"WARNING: Golden JSON not found at {GOLDEN_JSON_PATH}")

processing_task = make_processor(
    golden_json_data=golden_json_data,
    run_periods=RUN_PERIODS_2016
)

# Parallel argument lists for client.map: one (label, url, index) per file.
arg_labels = []
arg_urls = []
arg_indices = []

print("Preparing file lists...")

for label, urls in files.items():
    if label == 'Data' and golden_json_data is not None:
        print(f"  {label}: Validation enabled ({len(urls)} files)")

    for file_idx, file_url in enumerate(urls):
        arg_labels.append(label)
        arg_urls.append(str(file_url))
        arg_indices.append(file_idx)

start_time = time.perf_counter()

print(f"\nSubmitting {len(arg_urls)} files to the cluster...")

futures = client.map(
    processing_task,
    arg_labels,
    arg_urls,
    arg_indices,
    retries=1
)

progress(futures)
results = client.gather(futures)
elapsed = time.perf_counter() - start_time

# Per label: [denominator (preselected events), numerator (preselected + HLT), file count]
final_stats = defaultdict(lambda: [0, 0, 0])
errors = []

for label, n_kinematics, n_hlt, error in results:
    if error:
        errors.append((label, error))
    else:
        stats = final_stats[label]
        stats[0] += n_kinematics
        stats[1] += n_hlt
        stats[2] += 1

print(f"\n{'='*70}")
print(f"{'SAMPLE':<20} | {'FILES':<8} | {'KINEMATICS':>14} | {'HLT PASS':>12} | {'TRIG EFF':>10}")
print("="*70)

tot_kinematics = tot_hlt = tot_files = 0

for label, (n_kinematics, n_hlt, n_files) in sorted(final_stats.items()):
    # Efficiency = HLT pass / preselection
    eff = (n_hlt / n_kinematics * 100) if n_kinematics > 0 else 0.0

    print(f"{label:<20} | {n_files:<8} | {n_kinematics:>14,} | {n_hlt:>12,} | {eff:>9.2f}%")

    tot_kinematics += n_kinematics
    tot_hlt += n_hlt
    tot_files += n_files

print("_"*70)
tot_eff = (tot_hlt / tot_kinematics * 100) if tot_kinematics > 0 else 0.0
print(f"{'TOTAL':<20} | {tot_files:<8} | {tot_kinematics:>14,} | {tot_hlt:>12,} | {tot_eff:>9.2f}%")
print(f"{'='*70}")

if errors:
    print(f"\n[!] Encountered {len(errors)} errors:")
    for label, err in errors[:5]:
        print(f"  - {label}: {err}")
    if len(errors) > 5:
        print(f"  ... and {len(errors)-5} more.")

# Guard against ZeroDivisionError when no files were listed.
n_submitted = len(arg_urls)
sec_per_file = elapsed / n_submitted if n_submitted else 0.0
print(f"\nDone in {elapsed:.1f}s ({sec_per_file:.2f}s/file)")


TRIGGER EFFICIENCY PROCESSING START
Preparing file lists...
  Data: Validation enabled (1 files)

Submitting 1 files to the cluster...

SAMPLE               | FILES    |     KINEMATICS |     HLT PASS |   TRIG EFF
Data                 | 1        |          9,635 |        7,935 |     82.36%
______________________________________________________________________
TOTAL                | 1        |          9,635 |        7,935 |     82.36%

Done in 50.6s (50.61s/file)


## Lepton efficiency — tag-and-probe method

In [32]:
def lepton_array(arrays):
    """Build the electron and muon collections used by the tag-and-probe study.

    No ID cut is applied. Instead each lepton carries a boolean ``id_pass``
    with the same tight definition as select_tight_leptons:
    electron ``mvaFall17V2Iso_WP90 == 1``; muon ``tightId == 1`` and
    ``pfRelIso04_all < 0.15``.

    Returns:
        ``(electrons, muons)``: per-event record arrays with fields pt, eta,
        phi, mass, charge, id_pass and an int32 ``flavor`` (11 / 13).
    """
    electrons = ak.zip({
        "pt": arrays.Electron_pt,
        "eta": arrays.Electron_eta,
        "phi": arrays.Electron_phi,
        "mass": arrays.Electron_mass,
        "charge": arrays.Electron_charge,
        "id_pass": arrays.Electron_mvaFall17V2Iso_WP90 == 1,
        # Cast to int32 so flavor has the same type as in select_tight_leptons
        # (ones_like(pt) * 11 alone is a float).
        "flavor": ak.values_astype(ak.ones_like(arrays.Electron_pt) * 11, "int32"),
    })

    muons = ak.zip({
        "pt": arrays.Muon_pt,
        "eta": arrays.Muon_eta,
        "phi": arrays.Muon_phi,
        "mass": arrays.Muon_mass,
        "charge": arrays.Muon_charge,
        "id_pass": (arrays.Muon_tightId == 1) & (arrays.Muon_pfRelIso04_all < 0.15),
        "flavor": ak.values_astype(ak.ones_like(arrays.Muon_pt) * 13, "int32"),
    })

    return electrons, muons

In [36]:
def select_tag_probe_events(leptons, probe_pt_lower = 10, probe_pt_upper = 50):
    """Ordered tag-and-probe pair selection.

    Only events with exactly two leptons in ``leptons`` are used. The
    leading-pT lepton is the tag and must pass the ID (``id_pass``). The
    subleading lepton is the probe; its ID is not checked, so it can be
    measured.

    Pair requirements: opposite charge, tag pT > 35,
    ``probe_pt_lower < probe pT < probe_pt_upper``, and |eta| < 2.5 for
    both leptons.

    Args:
        leptons: per-event lepton collection with pt, eta, charge and
            id_pass fields (e.g. one collection returned by lepton_array).
        probe_pt_lower: exclusive lower bound on the probe pT.
        probe_pt_upper: exclusive upper bound on the probe pT.

    Returns:
        ``(tags, probes, probe_pt_lower, probe_pt_upper)``. ``tags`` and
        ``probes`` are aligned per-pair records of the selected pairs, or
        both None when no event has exactly two leptons. The pT bounds are
        returned unchanged.
    """
    sorted_leptons = leptons[ak.argsort(leptons.pt, ascending=False)]

    mask_2lep = ak.num(sorted_leptons) == 2
    events_2lep = sorted_leptons[mask_2lep]

    if len(events_2lep) == 0:
        # Keep the 4-tuple shape: callers unpack four values, and the former
        # 2-tuple made them raise ValueError on batches without a pair.
        return None, None, probe_pt_lower, probe_pt_upper

    tag_candidate = events_2lep[:, 0]    # leading
    probe_candidate = events_2lep[:, 1]  # subleading

    mask_charge = tag_candidate.charge * probe_candidate.charge < 0
    mask_pt = (
        (tag_candidate.pt > 35)
        & (probe_candidate.pt > probe_pt_lower)
        & (probe_candidate.pt < probe_pt_upper)
    )
    mask_eta = (abs(tag_candidate.eta) < 2.5) & (abs(probe_candidate.eta) < 2.5)

    # The tag must pass the ID; `== True` always yields a boolean mask.
    mask_tag_id = (tag_candidate.id_pass == True)  # noqa: E712

    final_mask = mask_charge & mask_pt & mask_eta & mask_tag_id

    return tag_candidate[final_mask], probe_candidate[final_mask], probe_pt_lower, probe_pt_upper

In [34]:
def create_lepton_vector(lepton):
    """Return a ``vector`` array built from a lepton record.

    ``lepton`` must expose ``pt``, ``eta``, ``phi`` and ``mass`` fields; they
    are passed to ``vector.array`` under those coordinate names.
    """
    coordinate_names = ("pt", "eta", "phi", "mass")
    return vector.array({name: getattr(lepton, name) for name in coordinate_names})

def calculate_mll(lepton_1, lepton_2):
    """Return the invariant mass of each lepton pair, ``(p1 + p2).mass``.

    Both arguments are lepton records accepted by create_lepton_vector,
    aligned entry by entry.
    """
    return (create_lepton_vector(lepton_1) + create_lepton_vector(lepton_2)).mass



In [37]:
import time
import awkward as ak
import numpy as np
import vector

vector.register_awkward() 

def tag_prob_process(golden_json_data, run_periods, target_flavor='electron'):
    """Build the per-file worker for the tag-and-probe ID-efficiency study.

    Args:
        golden_json_data: parsed golden JSON dict, or None to skip the lumi
            mask; forwarded to apply_json_mask.
        run_periods: run-range mapping forwarded to apply_json_mask.
        target_flavor: ``'electron'`` (default, the previously hard-coded
            choice) or ``'muon'``; which lepton collection is probed.

    Returns:
        ``tag_probe_processing(label, file_url, file_idx)``, which returns
        ``(label, numerator = passing probes, denominator = total probes, error)``.

    Raises:
        ValueError: if target_flavor is not 'electron' or 'muon'.
    """
    if target_flavor not in ('electron', 'muon'):
        raise ValueError(f"target_flavor must be 'electron' or 'muon', got {target_flavor!r}")

    def tag_probe_processing(label, file_url, file_idx):
        """Count probes (and probes passing the ID) in the 60-120 m_ll window.

        ``file_idx`` is not used; it is accepted so the signature matches
        the argument lists given to ``client.map``.
        """
        file_name = file_url.split('/')[-1]
        is_data = (label == 'Data')

        max_file_retries = 3

        for file_attempt in range(max_file_retries):
            # Reset on every attempt: a retry re-reads the file from the start.
            count_total_probes = 0    # denominator
            count_passing_probes = 0  # numerator

            try:
                for arrays in load_events(file_url, batch_size=1_250_000, is_data=is_data):

                    # Keep only certified lumi sections in data.
                    if is_data and golden_json_data is not None:
                        try:
                            json_mask = apply_json_mask(arrays, golden_json_data, run_periods=run_periods)
                            if np.sum(json_mask) == 0:
                                continue
                            arrays = arrays[json_mask]
                        except Exception as e:
                            # Best effort: the batch is skipped (not counted) and processing continues.
                            print(f"Warning: JSON mask failed for {file_name}: {e}")
                            continue

                    electrons, muons = lepton_array(arrays)
                    leptons = electrons if target_flavor == 'electron' else muons

                    # Only tags and probes are needed. Indexing instead of a 4-way
                    # unpack also tolerates a 2-tuple "no pairs" return.
                    selection = select_tag_probe_events(leptons)
                    tags, probes = selection[0], selection[1]

                    if tags is None or len(tags) == 0:
                        continue

                    # Invariant-mass window around the Z peak.
                    m_ll = calculate_mll(tags, probes)
                    z_mask = (m_ll > 60) & (m_ll < 120)
                    valid_probes = probes[z_mask]

                    if len(valid_probes) == 0:
                        continue

                    count_total_probes += len(valid_probes)
                    # Convert to a plain int so the returned counts are Python ints.
                    count_passing_probes += int(ak.sum(valid_probes.id_pass))

                return label, count_passing_probes, count_total_probes, None

            except (OSError, IOError, ValueError) as e:
                if file_attempt < max_file_retries - 1:
                    time.sleep(3)
                    continue
                return label, 0, 0, f"{file_name}: Retry limit - {str(e)[:100]}"

            except Exception as e:
                return label, 0, 0, f"{file_name}: Crash - {str(e)[:100]}"

        return label, 0, 0, "Unknown loop exit"

    return tag_probe_processing

In [38]:
# %%
# MAIN PROCESSING 

import time
import json
from collections import defaultdict
from dask.distributed import progress

print(f"\n{'='*70}\nTAG & PROBE PROCESSING START\n{'='*70}")

# Parse the golden JSON once here; tag_prob_process's closure carries the
# dict to the workers.
golden_json_data = None
if GOLDEN_JSON_PATH.exists():
    with open(GOLDEN_JSON_PATH, 'r') as f:
        golden_json_data = json.load(f)
else:
    print(f"WARNING: Golden JSON not found at {GOLDEN_JSON_PATH}")

processing_task = tag_prob_process(
    golden_json_data=golden_json_data,
    run_periods=RUN_PERIODS_2016
)

# Parallel argument lists for client.map: one (label, url, index) per file.
arg_labels = []
arg_urls = []
arg_indices = []

for label, urls in files.items():
    for file_idx, file_url in enumerate(urls):
        arg_labels.append(label)
        arg_urls.append(str(file_url))
        arg_indices.append(file_idx)

start_time = time.perf_counter()

futures = client.map(
    processing_task,
    arg_labels,
    arg_urls,
    arg_indices,
    retries=1
)

progress(futures)
results = client.gather(futures)
elapsed = time.perf_counter() - start_time

# Per label: [numerator (probes passing ID), denominator (all probes), file count]
final_stats = defaultdict(lambda: [0, 0, 0])
errors = []

for label, n_pass, n_total, error in results:
    if error:
        errors.append((label, error))
    else:
        stats = final_stats[label]
        stats[0] += n_pass
        stats[1] += n_total
        stats[2] += 1

print(f"\n{'='*90}")
print(f"{'SAMPLE':<20} | {'FILES':<8} | {'TOTAL PROBES':>14} | {'PASSING ID':>12} | {'EFFICIENCY':>10}")
print("="*90)

tot_pass = tot_total = tot_files = 0

for label, (n_pass, n_total, n_files) in sorted(final_stats.items()):
    eff = (n_pass / n_total * 100) if n_total > 0 else 0.0
    print(f"{label:<20} | {n_files:<8} | {n_total:>14,} | {n_pass:>12,} | {eff:>9.2f}%")
    tot_pass += n_pass
    tot_total += n_total
    tot_files += n_files

print("_"*90)
tot_eff = (tot_pass / tot_total * 100) if tot_total > 0 else 0.0
print(f"{'TOTAL':<20} | {tot_files:<8} | {tot_total:>14,} | {tot_pass:>12,} | {tot_eff:>9.2f}%")
print(f"{'='*90}")

if errors:
    print(f"\n[!] Encountered {len(errors)} errors:")
    for label, err in errors[:5]:
        print(f"  - {label}: {err}")
    if len(errors) > 5:
        print(f"  ... and {len(errors)-5} more.")

# Guard against ZeroDivisionError when no files were listed.
n_submitted = len(arg_urls)
sec_per_file = elapsed / n_submitted if n_submitted else 0.0
print(f"\nDone in {elapsed:.1f}s ({sec_per_file:.2f}s/file)")


TAG & PROBE PROCESSING START
Preparing file lists...

SAMPLE               | FILES    |   TOTAL PROBES |   PASSING ID | EFFICIENCY
Data                 | 48       |         51,046 |       25,154 |     49.28%
__________________________________________________________________________________________
TOTAL                | 48       |         51,046 |       25,154 |     49.28%

Done in 161.6s (3.37s/file)
