# Calculate TIME_FOR_WORK_UNIT_SEC

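This notebook estimates a conservative value for the `TIME_FOR_WORK_UNIT_SEC` parameter (seconds per work unit) from timing measurements on a sample of real input files. `submit_im_jobs.sh` uses this value to determine how many batch jobs are needed; the recommended value is also written to `../execute_as_batch_jobs/recommended_time_per_work_unit.txt`.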

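A **work unit** is 1 file × 1 particle combination × 1 final state, i.e. processing a single combination within a single final state of a single ROOT file.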


In [40]:
import sys
import os
sys.path.insert(0, os.path.abspath('..'))

import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import yaml
from pathlib import Path
import awkward as ak
from tqdm import tqdm
import gc
import random

from src.parse_atlas import parser
from src.calculations import combinatorics
from src.im_calculator.im_calculator import IMCalculator

%matplotlib inline
plt.style.use('default')


## Load Configuration


In [42]:
# Read the pipeline YAML; only its `mass_calculate` section is used by this notebook
config_path = 'configs/pipeline_config.yaml'
with open(config_path, 'r') as config_file:
    config = yaml.safe_load(config_file)

im_config = config['mass_calculate']
input_dir = im_config['input_dir']

# Echo the settings that determine how many work units exist
summary_lines = [
    "Configuration loaded:",
    f"  Input directory: {input_dir}",
    f"  Objects to calculate: {im_config['objects_to_calculate']}",
    f"  Min/Max particles: {im_config['min_particles']}-{im_config['max_particles']}",
    f"  Min/Max count: {im_config['min_count']}-{im_config['max_count']}",
]
print("\n".join(summary_lines))

# Enumerate every particle combination allowed by the configured bounds
bound_keys = ("min_particles", "max_particles", "min_count", "max_count")
all_combinations = combinatorics.get_all_combinations(
    im_config["objects_to_calculate"],
    limit=im_config.get("limit_combinations"),
    **{key: im_config[key] for key in bound_keys},
)

print(f"\nTotal combinations: {len(all_combinations)}")


Configuration loaded:
  Input directory: /storage/agrp/netalev/data/root_files/
  Objects to calculate: ['Electrons', 'Muons', 'Jets', 'Photons']
  Min/Max particles: 2-4
  Min/Max count: 2-4

Total combinations: 28512


## Sample Files for Measurement


In [44]:
# Get ROOT files. Sorting makes the listing independent of os.listdir order,
# which is arbitrary and would otherwise defeat the fixed seed below.
root_files = sorted(f for f in os.listdir(input_dir) if f.endswith(".root"))
print(f"Found {len(root_files)} ROOT files")

# Fixed seed so Restart & Run All times the same files every time.
# A private Random instance leaves the global `random` state untouched.
SAMPLING_SEED = 42
sampling_rng = random.Random(SAMPLING_SEED)

# Number of files to time: 10, or every file if there are fewer than that
sample_size = min(10, len(root_files))

# (filename, size in bytes), largest first, so the list can be cut into size terciles
file_sizes = [(f, os.path.getsize(os.path.join(input_dir, f))) for f in root_files]
file_sizes.sort(key=lambda x: x[1], reverse=True)

if len(file_sizes) > sample_size:
    # More files than we want to time: stratify by size so the timings cover
    # large, medium and small files alike.
    n_files = len(file_sizes)
    large = file_sizes[:n_files // 3]
    medium = file_sizes[n_files // 3:2 * n_files // 3]
    small = file_sizes[2 * n_files // 3:]

    sample_files = (
        sampling_rng.sample(large, min(sample_size // 3, len(large)))
        + sampling_rng.sample(medium, min(sample_size // 3, len(medium)))
        + sampling_rng.sample(small, min(sample_size - 2 * (sample_size // 3), len(small)))
    )
else:
    # No more files than the sample size: time all of them. Stratifying here
    # could silently drop files, because the size terciles are uneven.
    sample_files = file_sizes

sample_filenames = [f[0] for f in sample_files]
print(f"\nSampling {len(sample_filenames)} files for measurement:")
for filename, size_bytes in sample_files[:5]:
    print(f"  {filename}: {size_bytes/(1024**2):.1f} MB")
if len(sample_files) > 5:
    print(f"  ... and {len(sample_files)-5} more")


Found 14 ROOT files

Sampling 10 files for measurement:
  2024r-pp_0f75d28651db83f5.root: 388.5 MB
  2024r-pp_1da876f9c54974c0.root: 515.7 MB
  2024r-pp_0fced08cf06527b7.root: 86.1 MB
  2024r-pp_bc43f0e26da96200.root: 5.0 MB
  2024r-pp_mc_5742f42bd168c3c5.root: 5.7 MB
  ... and 5 more


## Measure Time per Work Unit


In [None]:
def measure_work_unit_time(file_path, filename, all_combinations, config):
    """
    Time every work unit (file × combination × final state) of one ROOT file.

    The file is parsed once and its events are grouped by final state. For
    every combination accepted by
    ``calculator.does_final_state_contain_combination``, the filter → slice →
    invariant-mass steps are timed together as one work unit.

    Parameters
    ----------
    file_path : str
        Path of the ROOT file to parse.
    filename : str
        Name stored in the ``'filename'`` field of each record.
    all_combinations : iterable
        Particle combinations to try against every final state.
    config : dict
        The ``mass_calculate`` configuration section. Only
        ``"field_to_slice_by"`` is read (default ``"pt"``).

    Returns
    -------
    list of dict
        One record per recorded work unit. Each record has these keys:
        ``filename``, ``final_state``, ``combination`` (as ``str``),
        ``parse_time_sec`` (parse time of the whole file, repeated in every
        record), ``work_time_sec``, ``num_events`` and ``num_mass_values``.
        The list is empty if the parser returns ``None`` or an empty result.
        If processing raises, the error is printed and the records collected
        so far are returned.

    Notes
    -----
    Work units that stop early are not recorded and do not contribute to the
    statistics. This happens when nothing is left after filtering or slicing,
    or when ``ak.any(inv_mass)`` is false.
    """
    measurements = []
    # Pre-bind so the finally clause can drop them on every exit path
    particle_arrays = calculator = None

    try:
        # Loop-invariant: read once instead of once per work unit
        field_to_slice_by = config.get("field_to_slice_by", "pt")

        # perf_counter is monotonic and high-resolution, unlike time.time().
        # This matters for work units that only take milliseconds.
        parse_start = time.perf_counter()
        particle_arrays = parser.AtlasOpenParser.parse_root_file(file_path)
        parse_time = time.perf_counter() - parse_start

        if particle_arrays is None or len(particle_arrays) == 0:
            return measurements

        calculator = IMCalculator(particle_arrays)

        for final_state, fs_events in calculator.group_by_final_state():
            if len(fs_events) == 0:
                continue

            for combination in all_combinations:
                if not calculator.does_final_state_contain_combination(final_state, combination):
                    continue

                # Timing covers filter + slice + invariant-mass calculation
                work_start = time.perf_counter()

                filtered = calculator.filter_by_particle_counts(
                    fs_events, combination, is_exact_count=True
                )
                if len(filtered) == 0:
                    continue

                sliced = calculator.slice_by_field(
                    filtered, combination, field_to_slice_by
                )
                if len(sliced) == 0:
                    continue

                inv_mass = calculator.calculate_invariant_mass(sliced)
                if not ak.any(inv_mass):
                    continue

                work_time = time.perf_counter() - work_start

                measurements.append({
                    'filename': filename,
                    'final_state': final_state,
                    'combination': str(combination),
                    'parse_time_sec': parse_time,
                    'work_time_sec': work_time,
                    'num_events': len(fs_events),
                    'num_mass_values': len(inv_mass)
                })

    except Exception as e:
        # Best effort: one unreadable file should not abort the whole sweep
        print(f"Error processing {filename}: {e}")
    finally:
        # Drop the per-file arrays before collecting, so memory is reclaimed
        # before the next file is parsed. This also runs when processing
        # failed part-way.
        particle_arrays = calculator = None
        gc.collect()

    return measurements

# Time every work unit of each sampled file and pool the per-unit records
print("Measuring work unit times...")
all_measurements = []

for filename in tqdm(sample_filenames, desc="Processing files"):
    file_records = measure_work_unit_time(
        os.path.join(input_dir, filename), filename, all_combinations, im_config
    )
    all_measurements.extend(file_records)

    # Files that yielded nothing get no per-file summary line
    if not file_records:
        continue
    per_unit_times = [record['work_time_sec'] for record in file_records]
    print(f"  {filename}: {len(file_records)} work units, avg {np.mean(per_unit_times):.4f}s per unit")

if all_measurements:
    df = pd.DataFrame(all_measurements)
    print(f"\n✓ Collected {len(df)} work unit measurements")
else:
    print("⚠️  No measurements collected!")


Measuring work unit times...


Processing files:   0%|                      | 0/10 [00:00<?, ?it/s]

## Calculate Statistics and Recommended TIME_FOR_WORK_UNIT_SEC


In [None]:
# Summarise the measured per-unit times, derive the recommended
# TIME_FOR_WORK_UNIT_SEC and write it next to the batch-submission scripts.
if len(all_measurements) > 0:
    df = pd.DataFrame(all_measurements)
    
    # Calculate statistics
    work_times = df['work_time_sec'].values
    
    stats = {
        'mean': np.mean(work_times),
        'median': np.median(work_times),
        'std': np.std(work_times),
        'min': np.min(work_times),
        'max': np.max(work_times),
        'p25': np.percentile(work_times, 25),
        'p75': np.percentile(work_times, 75),
        'p95': np.percentile(work_times, 95),
        'p99': np.percentile(work_times, 99)
    }
    
    print("="*80)
    print("WORK UNIT TIME STATISTICS")
    print("="*80)
    print(f"Number of measurements: {len(work_times)}")
    print(f"\nTime per work unit (seconds):")
    print(f"  Mean:   {stats['mean']:.4f}")
    print(f"  Median: {stats['median']:.4f}")
    print(f"  Std:    {stats['std']:.4f}")
    print(f"  Min:    {stats['min']:.4f}")
    print(f"  Max:    {stats['max']:.4f}")
    print(f"\nPercentiles:")
    print(f"  25th:   {stats['p25']:.4f}")
    print(f"  75th:   {stats['p75']:.4f}")
    print(f"  95th:   {stats['p95']:.4f}")
    print(f"  99th:   {stats['p99']:.4f}")
    
    # Recommendations
    print("\n" + "="*80)
    print("RECOMMENDATIONS")
    print("="*80)
    
    # Conservative estimate (p95 to handle outliers)
    recommended_p95 = stats['p95']
    recommended_mean = stats['mean']
    recommended_median = stats['median']
    
    print(f"\n1. Conservative estimate (95th percentile): {recommended_p95:.4f} seconds")
    print(f"   → Use this for safety margin: handles 95% of cases")
    
    print(f"\n2. Mean estimate: {recommended_mean:.4f} seconds")
    print(f"   → Use this for average case estimation")
    
    print(f"\n3. Median estimate: {recommended_median:.4f} seconds")
    print(f"   → Use this if data has outliers")
    
    # Amortised parse cost per work unit. Stays 0 when no parse timings are
    # available, so the final recommendation below is always defined.
    parse_overhead_per_unit = 0.0
    if 'parse_time_sec' in df.columns:
        # parse_time_sec is repeated in every record of a file; take it once per file
        avg_parse_time = df.groupby('filename')['parse_time_sec'].first().mean()
        print(f"\n4. Average parse time per file: {avg_parse_time:.2f} seconds")
        print(f"   → This is one-time cost per file, not per work unit")
        
        # Parse time is amortized across all (recorded) work units in a file
        avg_work_units_per_file = df.groupby('filename').size().mean()
        parse_overhead_per_unit = avg_parse_time / avg_work_units_per_file if avg_work_units_per_file > 0 else 0
        
        print(f"\n5. Parse overhead per work unit: {parse_overhead_per_unit:.4f} seconds")
        print(f"   → Average work units per file: {avg_work_units_per_file:.1f}")
        
        total_recommended = recommended_p95 + parse_overhead_per_unit
        print(f"\n6. Total recommended (work + parse overhead): {total_recommended:.4f} seconds")
    
    # Final recommendation
    print("\n" + "="*80)
    print("FINAL RECOMMENDATION")
    print("="*80)
    
    # Padded work time: 20% above p95 or 50% above the mean, whichever is larger
    padded_work_time = max(recommended_p95 * 1.2, recommended_mean * 1.5)
    # The summary below (and step 6 above) states that parse overhead is
    # included, so actually add the amortised per-file parse cost.
    final_recommendation = padded_work_time + parse_overhead_per_unit
    
    print(f"\nRecommended TIME_FOR_WORK_UNIT_SEC: {final_recommendation:.4f} seconds")
    print(f"\nThis value accounts for:")
    print(f"  - 95th percentile work unit time: {recommended_p95:.4f}s")
    print(f"  - Safety margin (20-50%): {padded_work_time - recommended_p95:.4f}s")
    if 'parse_time_sec' in df.columns:
        print(f"  - Parse overhead: {parse_overhead_per_unit:.4f}s per work unit")
    
    print(f"\nTo use in submit_im_jobs.sh:")
    print(f"  --time_for_work_unit_sec {final_recommendation:.4f}")
    
    # Save to file; create the target directory first so the write cannot fail on it
    output_file = '../execute_as_batch_jobs/recommended_time_per_work_unit.txt'
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    with open(output_file, 'w') as f:
        f.write(f"{final_recommendation:.4f}\n")
    print(f"\n✓ Saved recommendation to: {output_file}")
else:
    print("⚠️  No measurements to analyze!")


## Visualize Distribution


In [None]:
# Plot the distribution of per-unit times and show a per-file breakdown
if len(all_measurements) > 0:
    df = pd.DataFrame(all_measurements)
    work_times = df['work_time_sec'].values
    mean_time = np.mean(work_times)
    median_time = np.median(work_times)
    p95_time = np.percentile(work_times, 95)
    
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    # Left panel: linear histogram with the summary statistics marked
    axes[0].hist(work_times, bins=50, edgecolor='black', alpha=0.7)
    axes[0].axvline(mean_time, color='red', linestyle='--', label=f'Mean: {mean_time:.4f}s')
    axes[0].axvline(median_time, color='green', linestyle='--', label=f'Median: {median_time:.4f}s')
    axes[0].axvline(p95_time, color='orange', linestyle='--', label=f'95th: {p95_time:.4f}s')
    axes[0].set_xlabel('Time per Work Unit (seconds)')
    axes[0].set_ylabel('Frequency')
    axes[0].set_title('Distribution of Work Unit Times')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)
    
    # Right panel: a log-scale histogram when the times span more than two
    # decades, otherwise a box plot. Only strictly positive times fit on a log
    # axis. If there are none (e.g. every unit measured as 0 s), fall back to
    # the box plot instead of calling np.min on an empty array.
    positive_times = work_times[work_times > 0]
    spans_decades = (
        positive_times.size > 0
        and positive_times.max() / positive_times.min() > 100
    )
    if spans_decades:
        # Log-spaced bin edges give equal-width bars on the log axis
        log_bins = np.logspace(
            np.log10(positive_times.min()), np.log10(positive_times.max()), 51
        )
        axes[1].hist(positive_times, bins=log_bins, edgecolor='black', alpha=0.7)
        axes[1].set_xscale('log')
        axes[1].set_xlabel('Time per Work Unit (seconds, log scale)')
        axes[1].set_ylabel('Frequency')
        axes[1].set_title('Distribution (Log Scale)')
        axes[1].grid(True, alpha=0.3)
    else:
        # Box plot (vertical is matplotlib's default orientation)
        axes[1].boxplot(work_times)
        axes[1].set_ylabel('Time per Work Unit (seconds)')
        axes[1].set_title('Box Plot of Work Unit Times')
        axes[1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    # Show breakdown by file
    if 'filename' in df.columns:
        print("\n" + "="*80)
        print("BREAKDOWN BY FILE")
        print("="*80)
        file_stats = df.groupby('filename')['work_time_sec'].agg(['mean', 'std', 'count'])
        file_stats.columns = ['Avg Time (s)', 'Std Dev (s)', 'Work Units']
        print(file_stats.to_string())
