Calculate Training Steps

In [4]:
import math

# ==========================================
# 1. GLOBAL CONSTANTS & ASSUMPTIONS
# ==========================================

# CIFAR-10 Training Set Size
N_TRAIN_SAMPLES = 50_000 

def calculate_steps_per_epoch(n_samples, batch_size):
    """Calculates optimizer steps per epoch given dataset size and batch size."""
    # Assuming drop_last=False (standard) -> usage of ceil
    return math.ceil(n_samples / batch_size)

# ==========================================
# 2. METHOD CALCULATORS
# ==========================================

def calc_standard_training(name, epochs, batch_size):
    """
    Calculates steps for 'Random Grid' or 'Base' (Single Model).
    Methodology: Simple sum of updates for one model.
    """
    steps_per_epoch = calculate_steps_per_epoch(N_TRAIN_SAMPLES, batch_size)
    total_steps = steps_per_epoch * epochs
    
    print(f"--- {name} ---")
    print(f"Config: {epochs} epochs, BS={batch_size}")
    print(f"Calculation: {steps_per_epoch} steps/epoch * {epochs} epochs")
    print(f"Total Steps: {total_steps:,}\n")
    return total_steps

def calc_pbt(name, epochs, batch_size, population_size):
    """
    Calculates steps for Population Based Training (PBT).
    Methodology: Sum of steps taken by ALL population members.
    Cost = (Steps per Member) * (Population Size)
    """
    steps_per_epoch = calculate_steps_per_epoch(N_TRAIN_SAMPLES, batch_size)
    steps_per_member = steps_per_epoch * epochs
    total_steps = steps_per_member * population_size
    
    print(f"--- {name} ---")
    print(f"Config: {epochs} epochs, BS={batch_size}, Pop={population_size}")
    print(f"Total Steps: {total_steps:,}\n")
    return total_steps

def calc_multistage_grid(name, stage_details, batch_size):
    """
    Calculates steps for a 3-Stage Grid Search (e.g., Successive Halving or Coarse-to-Fine).
    
    stage_details: List of tuples [(num_configs, epochs_run), ...]
    
    Example for Successive Halving:
    Stage 1: Start with 27 configs, run for 3 epochs.
    Stage 2: Keep top 9 configs, run for 9 MORE epochs.
    Stage 3: Keep top 3 configs, run for 27 MORE epochs.
    Input: [(27, 3), (9, 9), (3, 27)]
    """
    total_steps = 0
    print(f"--- {name} ---")
    steps_per_epoch = calculate_steps_per_epoch(N_TRAIN_SAMPLES, batch_size)
    
    for i, (num_configs, epochs) in enumerate(stage_details):
        # Calculate cost for this specific stage
        stage_cost = num_configs * epochs * steps_per_epoch
        total_steps += stage_cost
        
        print(f"  Stage {i+1}: {num_configs} configs x {epochs} eps x {steps_per_epoch} steps/ep")
        print(f"           = {stage_cost:,} steps")

    print(f"Total Steps: {total_steps:,}\n")
    return total_steps

def calc_hdg(name, epochs, batch_size, hypergrad_calc_every_k_steps=1):
    """
    Calculates steps for Hyperparameter Gradient (HDG).
    Methodology: Parameter Updates + Hyperparameter Updates.
    """
    steps_per_epoch = calculate_steps_per_epoch(N_TRAIN_SAMPLES, batch_size)
    param_updates = steps_per_epoch * epochs
    hypergrad_updates = math.floor(param_updates / hypergrad_calc_every_k_steps)
    total_steps = param_updates + hypergrad_updates
    
    print(f"--- {name} ---")
    print(f"Total Steps: {total_steps:,}\n")
    return total_steps

# ==========================================
# 3. EXECUTION WITH REPORTED VALUES
# ==========================================

# 1. RANDOM GRID / BASE (Single Config Cost)
steps_random = calc_standard_training(
    name="Random Grid (1 Config)", 
    epochs=70,      
    batch_size=256
)

# 3. 3-STAGE GRID SEARCH
# Modify the tuples below to match your exact experiment.
# Format: (number_of_configurations_running, number_of_epochs_they_run_for)
# NOTE: If Stage 2 continues previous training, put only the *additional* epochs here.
stage_schedule = [
    (27, 20),   # Stage 1: 27 configs run for 20 epochs
    (27,  20),   # Stage 2: 27 configs continue for 20 more epochs
    (27,  30)    # Stage 3: 27 configs continue for 30 more epochs (Total 70)
]

steps_3stage_grid = calc_multistage_grid(
    name="3-Stage Grid Search",
    stage_details=stage_schedule,
    batch_size=256
)

# 4. HDG
steps_hdg = calc_hdg(
    name="HDG", 
    epochs=70, 
    batch_size=256
)

# ==========================================
# 4. SUMMARY TABLE PRINT
# ==========================================
print("=== FINAL COMPARISON ===")
print(f"{'Method':<25} | {'Total Steps':<15} | {'Relative Cost':<15}")
print("-" * 60)
base_cost = steps_random # Using single run as base unit
for name, cost in [("Random Grid (1 run)", steps_random), ("PBT", steps_pbt), ("3-Stage Grid", steps_3stage_grid), ("HDG", steps_hdg)]:
    relative = f"{cost / base_cost:.1f}x"
    print(f"{name:<25} | {cost:<15,} | {relative:<15}")

--- Random Grid (1 Config) ---
Config: 70 epochs, BS=256
Calculation: 196 steps/epoch * 70 epochs
Total Steps: 13,720

--- 3-Stage Grid Search ---
  Stage 1: 27 configs x 20 eps x 196 steps/ep
           = 105,840 steps
  Stage 2: 27 configs x 20 eps x 196 steps/ep
           = 105,840 steps
  Stage 3: 27 configs x 30 eps x 196 steps/ep
           = 158,760 steps
Total Steps: 370,440

--- HDG ---
Total Steps: 27,440

=== FINAL COMPARISON ===
Method                    | Total Steps     | Relative Cost  
------------------------------------------------------------
Random Grid (1 run)       | 13,720          | 1.0x           
PBT                       | 136,850         | 10.0x          
3-Stage Grid              | 370,440         | 27.0x          
HDG                       | 27,440          | 2.0x           


In [7]:
import numpy as np
import scipy.stats as st

# ==========================================
# 1. HELPER FUNCTION
# ==========================================

def compute_stats(accuracies, confidence=0.95):
    """
    Calculates mean and the 95% Confidence Interval margin of error.
    
    Args:
        accuracies (list or np.array): List of validation accuracies (e.g., [85.1, 84.9, ...])
        confidence (float): The confidence level (default 0.95)
        
    Returns:
        dict: {'mean': float, 'error': float, 'formatted': str}
    """
    a = np.array(accuracies)
    n = len(a)
    
    # 1. Calculate Mean
    m = np.mean(a)
    
    # 2. Calculate Margin of Error
    if n < 2:
        # Cannot compute std dev with 1 sample
        h = 0.0
    else:
        # Standard Error of the Mean (SE)
        se = st.sem(a) 
        
        # Critical Value (T-statistic for small samples, converges to 1.96 for large N)
        # ppf is the "Percent Point Function" (inverse of CDF)
        h = se * st.t.ppf((1 + confidence) / 2., n-1)

    return {
        "mean": m,
        "error": h,
        "n": n,
        "formatted": f"{m:.2f}% ± {h:.2f}%"
    }

# ==========================================
# 2. INPUT DATA (Replace with your actual lists)
# ==========================================

# Example: Accuracies collected from multiple random seeds

data = {
    "Random Grid (1 Config)": [84.68, 85.05, 85.13, 85.13, 85.13],    # ~5 seeds
    "3-Stage Grid Search":    [80.29],           # ~1 seed
    "HDG":                    [69.46, 70.37, 69.61, 69.14, 69.54]     # ~5 seeds
}

# ==========================================
# 3. GENERATE TABLE
# ==========================================

print("=== ACCURACY RESULTS (95% CI) ===")
print(f"{'Method':<25} | {'N':<3} | {'Mean Acc':<10} | {'95% CI':<10} | {'Formatted'}")
print("-" * 75)

for method, acc_list in data.items():
    stats = compute_stats(acc_list)
    
    # Unpack for cleaner print statement
    n = stats['n']
    mean = stats['mean']
    err = stats['error']
    fmt = stats['formatted']
    
    print(f"{method:<25} | {n:<3} | {mean:<9.2f}% | ± {err:<8.2f} | {fmt}")

=== ACCURACY RESULTS (95% CI) ===
Method                    | N   | Mean Acc   | 95% CI     | Formatted
---------------------------------------------------------------------------
Random Grid (1 Config)    | 5   | 85.02    % | ± 0.24     | 85.02% ± 0.24%
3-Stage Grid Search       | 1   | 80.29    % | ± 0.00     | 80.29% ± 0.00%
HDG                       | 5   | 69.62    % | ± 0.56     | 69.62% ± 0.56%
