In [None]:
# Nagel-Schreckenberg Traffic Model

# Imports
import numpy as np
import matplotlib.pyplot as plt
import os

# Make sure the "images" output folder is present before any figure is saved
# into it; exist_ok=True turns the call into a no-op when it already exists.
os.makedirs("images", exist_ok=True)

# Core model logic
def init_road(L, N, rng):
    """Build a road of L cells holding N stationary cars at random cells.

    Empty cells are marked with -1; occupied cells store the car's
    velocity, which starts at 0. Cells are drawn without replacement
    from ``rng`` so no two cars share a cell.
    """
    lane = np.full(L, -1, dtype=int)
    occupied_cells = rng.choice(L, size=N, replace=False)
    lane[occupied_cells] = 0
    return lane

def step(road, v_max, p, rng):
    """
    Advance every car on the ring road by one update of the model.

    Each car, visited in order of increasing position, accelerates by one
    (up to v_max), brakes to stay behind the car ahead, brakes by one more
    with probability p, and then moves. All cars read the *old* positions,
    so the update is parallel.

    Returns:
        new_road: road array after the update (-1 marks an empty cell)
        total_distance: sum of the distances moved by all cars this step
    """
    n_cells = len(road)
    # Indices come back in ascending order, so car k+1 is the leader of car k
    occupied = np.flatnonzero(road >= 0)
    n_cars = len(occupied)
    updated = np.full(n_cells, -1, dtype=int)

    if n_cars == 0:
        return updated, 0

    speeds = road[occupied].copy()
    moved = 0

    for idx, cell in enumerate(occupied):
        leader_cell = occupied[(idx + 1) % n_cars]

        # Distance to the car ahead, measured around the ring; a lone car
        # is its own leader and sees the whole circumference
        gap = (leader_cell - cell) % n_cells
        if gap == 0:
            gap = n_cells

        # 1. Acceleration
        speed = speeds[idx]
        speed = speed + 1 if speed < v_max else speed

        # 2. Deceleration: never reach the leader's current cell
        speed = min(speed, gap - 1)

        # 3. Randomization (a random number is drawn only for moving cars)
        if speed > 0 and rng.random() < p:
            speed -= 1

        # 4. Movement
        updated[(cell + speed) % n_cells] = speed
        moved += speed

    return updated, moved
# Simulation Wrapper
def run_spacetime(L, N, v_max, p, steps, seed=0, warmup=0):
    """Simulate the road and record its state at every step.

    A generator seeded with ``seed`` places the cars and drives all random
    braking. The first ``warmup`` updates are discarded; the following
    ``steps`` states are stored, one per row, in the returned
    (steps, L) integer array (row t is the state *before* update t).
    """
    rng = np.random.default_rng(seed)
    lane = init_road(L, N, rng)

    # Throw away the transient
    for _ in range(warmup):
        lane = step(lane, v_max, p, rng)[0]

    snapshots = np.zeros((steps, L), dtype=int)
    for row in snapshots:
        # Each row is a view into the 2-D array, so this writes in place
        row[:] = lane
        lane = step(lane, v_max, p, rng)[0]

    return snapshots
# Investigation 1 - Space-Time Diagram

# Baseline scenario: 80 cars on a 200-cell ring (density N/L = 0.4).
# These module-level names are reused by later sections of the file.
L = 200
N = 80
v_max = 5
p = 0.25

# Record 300 consecutive road states (run_spacetime defaults: seed=0, warmup=0)
history = run_spacetime(L, N, v_max, p, steps=300)

# Rows are time steps, columns are cells. Empty cells hold -1, so vmin=-1
# gives them their own colour at the end of the (reversed) colormap.
plt.figure(figsize=(10, 6))
plt.imshow(history, aspect="auto", cmap="viridis_r", vmin=-1, vmax=v_max)
plt.xlabel("Position")
plt.ylabel("Time step")
plt.title(f"Space–time diagram (N={N}, p={p})")
plt.colorbar(label="Velocity v")
plt.show()
# Fundamental diagram function
def fundamental_diagram(L, v_max, p, N_values, warmup=200, measure_steps=500):
    """Measure flow as a function of density from one run per car count.

    For every N in N_values the road is re-initialised, advanced for
    ``warmup`` unrecorded steps and then for ``measure_steps`` recorded
    steps. Flow is the total distance travelled divided by
    (L * measure_steps). A single generator (seed 1) is shared by all
    densities.

    Returns:
        (densities, flows) as two numpy arrays ordered like N_values.
    """
    rng = np.random.default_rng(1)
    densities, flows = [], []

    for N in N_values:
        lane = init_road(L, N, rng)

        # Relax towards the steady state without measuring
        for _ in range(warmup):
            lane = step(lane, v_max, p, rng)[0]

        # Accumulate the distance covered during the measurement window
        travelled = 0
        for _ in range(measure_steps):
            lane, moved = step(lane, v_max, p, rng)
            travelled += moved

        densities.append(N / L)
        flows.append(travelled / (L * measure_steps))

    return np.array(densities), np.array(flows)
# Error Analysis Function - runs multiple simulations to compute error bars
def fundamental_diagram_with_errors(L, v_max, p, N_values, warmup=200, measure_steps=500, n_runs=20):
    """
    Estimate flow versus density, with uncertainties, from repeated runs.

    For each car count the simulation is repeated n_runs times. Run k
    always uses seed 1000 + k, so the same seeds are reused for every
    density and for every call of this function.

    Parameters:
    - L: road length
    - v_max: maximum velocity
    - p: randomization probability
    - N_values: iterable of car numbers to test
    - warmup: number of unrecorded steps before measuring
    - measure_steps: number of steps over which distance is accumulated
    - n_runs: number of independent runs per density (default 20)

    Returns:
    - densities: array of N / L values
    - mean_flows: mean flow over the runs at each density
    - std_flows: sample standard deviation (ddof=1) at each density
    - sem_flows: std_flows / sqrt(n_runs) at each density
    """
    densities, mean_flows, std_flows, sem_flows = [], [], [], []

    for N in N_values:
        samples = []

        for run in range(n_runs):
            # Fresh generator per run so the repetitions are independent
            rng = np.random.default_rng(1000 + run)
            lane = init_road(L, N, rng)

            # Relax towards the steady state without measuring
            for _ in range(warmup):
                lane = step(lane, v_max, p, rng)[0]

            # Accumulate the distance covered during the measurement window
            travelled = 0
            for _ in range(measure_steps):
                lane, moved = step(lane, v_max, p, rng)
                travelled += moved

            samples.append(travelled / (L * measure_steps))

        # Summarise the repetitions for this density
        spread = np.std(samples, ddof=1)
        densities.append(N / L)
        mean_flows.append(np.mean(samples))
        std_flows.append(spread)
        sem_flows.append(spread / np.sqrt(n_runs))

    return np.array(densities), np.array(mean_flows), np.array(std_flows), np.array(sem_flows)
# Investigation 1 (Part 2) - Plot Fundamental diagram WITH ERROR BARS
print("\n--- Running Investigation 1: Fundamental diagram (with error bars) ---")
print("This may take a few minutes...")

# 19 car counts: 5, 15, ..., 185. N_values and n_runs are module-level names
# that the later investigations reuse.
N_values = range(5, 195, 10)
n_runs = 20  # Number of independent runs per density

# Run error analysis for Investigation 1
# (L is the module-level road length assigned earlier in the file)
rho, J, std_J, sem_J = fundamental_diagram_with_errors(
    L, v_max=5, p=0.25, N_values=N_values, n_runs=n_runs
)

# Find peak flow and corresponding density
# (the peak is taken over the sampled densities only, not interpolated)
peak_idx = np.argmax(J)
peak_rho = rho[peak_idx]
peak_J = J[peak_idx]

# Plot with error bars (bars show the standard error of the mean, not the std)
fig, ax = plt.subplots(figsize=(10, 6))
ax.errorbar(rho, J, yerr=sem_J, fmt='o-', color='blue', 
            capsize=5, capthick=2, markersize=6, label='Flow with SEM')
ax.axvline(peak_rho, color='red', linestyle='--', linewidth=2, 
           label=f'Peak at ρ={peak_rho:.3f}')
ax.set_xlabel(r"Density $\rho$", fontsize=12)
ax.set_ylabel(r"Flow $J$", fontsize=12)
ax.set_title(r"Investigation 1: Fundamental Diagram with Error Bars" + "\n" +
             r"($v_{\max}=5$, $p=0.25$)", fontsize=14)
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)
plt.tight_layout()
# Save before show() so the written file contains the finished figure
plt.savefig('images/investigation_1_fundamental_diagram_with_errors.png', dpi=300, bbox_inches='tight')
plt.show()

print(f"✓ Investigation 1 complete (n={n_runs} runs per density)")
print(f"  Peak flow: J={peak_J:.3f} at density ρ={peak_rho:.3f}")
# Investigation 2 - Effect of randomness (p) WITH ERROR BARS
print("\n--- Running Investigation 2: Effect of Randomness (with error bars) ---")

# Rebinds the module-level N; L is still the value assigned earlier in the file
N = 50
# First, show space-time diagrams
# Both runs use run_spacetime's default seed, so they start from the same
# initial car placement and differ only in the braking probability p.
print("Generating space-time diagrams...")
history_p0 = run_spacetime(L, N, v_max=5, p=0.0, steps=120, warmup=200)
history_p5 = run_spacetime(L, N, v_max=5, p=0.5, steps=120, warmup=200)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Identical colour limits on both panels so velocities are comparable
axes[0].imshow(history_p0, aspect="auto", cmap="viridis_r", vmin=-1, vmax=5)
axes[0].set_title("p = 0 (Deterministic)")

axes[1].imshow(history_p5, aspect="auto", cmap="viridis_r", vmin=-1, vmax=5)
axes[1].set_title("p = 0.5 (Stochastic)")

for ax in axes:
    ax.set_xlabel("Position")
axes[0].set_ylabel("Time step")

plt.suptitle("Effect of random braking")
plt.savefig("images/investigation_2_spacetime_p0_vs_p05.png", dpi=300, bbox_inches='tight')
plt.show()

# Fundamental diagram comparison WITH ERROR BARS
print("Calculating fundamental diagrams with error analysis...")
print("This may take a few minutes (running 20 simulations per density point)...")

n_runs = 20  # Number of independent runs per density

# Run for p=0.0
# (N_values is the module-level range defined in the Investigation 1 section)
print("  Running p=0.0...")
densities_p0, flows_p0, std_p0, sem_p0 = fundamental_diagram_with_errors(
    L, v_max=5, p=0.0, N_values=N_values, n_runs=n_runs
)

# Run for p=0.5
print("  Running p=0.5...")
densities_p5, flows_p5, std_p5, sem_p5 = fundamental_diagram_with_errors(
    L, v_max=5, p=0.5, N_values=N_values, n_runs=n_runs
)

# Plot with error bars (bars show the standard error of the mean)
fig, ax = plt.subplots(figsize=(10, 6))
ax.errorbar(densities_p0, flows_p0, yerr=sem_p0, fmt='o-', color='blue',
            label='p=0.0 (deterministic)', capsize=5, capthick=2, markersize=6)
ax.errorbar(densities_p5, flows_p5, yerr=sem_p5, fmt='s-', color='red',
            label='p=0.5 (stochastic)', capsize=5, capthick=2, markersize=6)
ax.set_xlabel(r"Density $\rho = N/L$", fontsize=12)
ax.set_ylabel(r"Flow $J$", fontsize=12)
ax.set_title(r"Investigation 2: Effect of Randomization" + "\n" + 
             r"Fundamental Diagram with Error Bars ($v_{\max}=5$)", fontsize=14)
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig("images/investigation_2_fundamental_diagram_with_errors.png", dpi=300, bbox_inches='tight')
plt.show()

# NOTE(review): the two "Note:" lines below are fixed text, not derived from
# sem_p0 / sem_p5 — confirm they still hold if the parameters are changed.
print(f"✓ Investigation 2 complete (n={n_runs} runs per density)")
print(f"  Note: p=0.0 has very small error bars (deterministic behavior)")
print(f"  Note: p=0.5 shows larger variability due to stochasticity")
# Investigation 3 - Effect of Speed Limit WITH ERROR BARS
print("\n--- Running Investigation 3: Effect of Speed Limit (with error bars) ---")

# Braking probability held fixed while v_max is varied
p_fixed = 0.5
print(f"Comparing v_max=5 vs v_max=2 at p={p_fixed}...")
print("This may take a few minutes...")

# Run for v_max=5
# NOTE: fundamental_diagram_with_errors seeds run k with 1000 + k, so with
# unchanged L, N_values and n_runs this call repeats the v_max=5, p=0.5
# computation already done in Investigation 2 and yields the same numbers.
print("  Running v_max=5...")
densities_v5, flows_v5, std_v5, sem_v5 = fundamental_diagram_with_errors(
    L, v_max=5, p=p_fixed, N_values=N_values, n_runs=n_runs
)

# Run for v_max=2
print("  Running v_max=2...")
densities_v2, flows_v2, std_v2, sem_v2 = fundamental_diagram_with_errors(
    L, v_max=2, p=p_fixed, N_values=N_values, n_runs=n_runs
)

# Plot with error bars (bars show the standard error of the mean)
fig, ax = plt.subplots(figsize=(10, 6))
ax.errorbar(densities_v5, flows_v5, yerr=sem_v5, fmt='o-', color='blue',
            label=r'$v_{\max}=5$', capsize=5, capthick=2, markersize=6)
ax.errorbar(densities_v2, flows_v2, yerr=sem_v2, fmt='s-', color='red',
            label=r'$v_{\max}=2$', capsize=5, capthick=2, markersize=6)
ax.set_xlabel(r"Density $\rho = N/L$", fontsize=12)
ax.set_ylabel(r"Flow $J$", fontsize=12)
ax.set_title(f"Investigation 3: Effect of Maximum Speed\n" + 
             f"Fundamental Diagram with Error Bars (p={p_fixed})", fontsize=14)
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig("images/investigation_3_fundamental_diagram_with_errors.png", dpi=300, bbox_inches='tight')
plt.show()

# The two "Analysis:" lines are reminders for the reader; nothing is computed here
print(f"✓ Investigation 3 complete (n={n_runs} runs per density)")
print("  Analysis: Check if curves are separated beyond uncertainty at low density")
print("  Analysis: Check if curves overlap within uncertainty at high density")

# Steady-State Analysis
# Follows one realization from its random initial condition and checks
# whether the warm-up length used by the fundamental-diagram functions
# (warmup=200) is long enough for velocity and flow to settle.

print("Analyzing convergence to steady-state...\n")

# Test parameters
L = 200
n_cars = 50
v_max = 5
p = 0.25
max_iterations = 500

# Analysis settings, named once so plots, slices and messages stay in sync
steady_state_iter = 200   # assumed end of the transient
window = 50               # length of the sliding window for the stability test
stable_threshold = 0.05   # a window counts as stable when its flow std is below this

# Track average velocity over time
rng = np.random.default_rng(42)
road = init_road(L, n_cars, rng)
avg_velocities = []
avg_flows = []

print("Running simulation for", max_iterations, "iterations...")
for t in range(max_iterations):
    road, total_dist = step(road, v_max, p, rng)

    # Calculate average velocity (0 when the road holds no cars)
    velocities = road[road >= 0]
    avg_v = np.mean(velocities) if velocities.size > 0 else 0
    avg_velocities.append(avg_v)

    # Flow for this single step: distance covered by all cars per cell
    flow = total_dist / L
    avg_flows.append(flow)

# Plot convergence
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8))
steady_label = f'Assumed steady-state ({steady_state_iter} iter)'

# Plot average velocity over time
ax1.plot(avg_velocities, color='blue', linewidth=0.8)
ax1.axvline(x=steady_state_iter, color='red', linestyle='--', linewidth=2, label=steady_label)
ax1.set_xlabel("Iteration")
ax1.set_ylabel("Average velocity")
# Title is built from the parameters above so it cannot go stale
ax1.set_title("Convergence to steady-state: average velocity\n"
              f"(N={n_cars}, ρ={n_cars / L:.2f}, p={p})")
ax1.grid(True, alpha=0.3)
ax1.legend()

# Plot flow over time
ax2.plot(avg_flows, color='green', linewidth=0.8)
ax2.axvline(x=steady_state_iter, color='red', linestyle='--', linewidth=2, label=steady_label)
ax2.set_xlabel("Iteration")
ax2.set_ylabel("Flow (J)")
ax2.set_title("Convergence to steady-state: flow")
ax2.grid(True, alpha=0.3)
ax2.legend()

plt.tight_layout()
plt.savefig('images/steadystate_convergence_analysis.png', dpi=300, bbox_inches='tight')
plt.show()

# Compare the flow before and after the assumed end of the transient.
# The variable names are kept for any later code that refers to them.
before_200 = avg_flows[:steady_state_iter]
after_200 = avg_flows[steady_state_iter:]

print("\nStatistics:")
print(f"  First {steady_state_iter} iterations - Mean flow: {np.mean(before_200):.4f}, Std: {np.std(before_200):.4f}")
print(f"  After {steady_state_iter} iterations - Mean flow: {np.mean(after_200):.4f}, Std: {np.std(after_200):.4f}")
print(f"\nDifference in mean: {abs(np.mean(before_200) - np.mean(after_200)):.4f}")

# Find the first sliding window whose flow fluctuation is below the threshold.
# The upper bound is len + 1 so the final window (ending at the last sample)
# is examined too.
for i in range(window, len(avg_flows) + 1):
    window_std = np.std(avg_flows[i - window:i])
    if window_std < stable_threshold:  # Threshold for "stable"
        print(f"\nSystem appears stable after ~{i} iterations (std < {stable_threshold} over {window}-step window)")
        break
else:
    # Reached only when no window satisfied the criterion
    print(f"\nNo {window}-step window with std < {stable_threshold} found in {len(avg_flows)} iterations")

In [None]:
# Traffic jam dispersal time analysis

# METHOD A

# Road geometry and driver behaviour for the single-run (Method A) study
L, N = 200, 80
v_max, p = 5, 0.25

# Persistence filters: number of consecutive observations needed to declare
# a jam formed (M_form) or cleared (K_clear)
M_form = K_clear = 10

# Number of unrecorded warm-up steps, then number of observed steps
T_warm, T_obs = 500, 5000

for summary_line in (
    f"\nParameters: L={L}, N={N}, p={p}, ρ={N/L:.2f}",
    f"Jam detection: min_cluster={10} cars, M_form={M_form}, K_clear={K_clear}",
    f"Duration: T_warm={T_warm}, T_obs={T_obs}",
):
    print(summary_line)

# Helper function
def jam_present(road, min_cluster_size=10):
    """Report whether the road holds a jam cluster.

    A jam cluster is a run of at least ``min_cluster_size`` adjacent cells
    that are all occupied by slow cars (velocity 0 or 1). The road is a
    ring, so a run that continues from the last cell into the first cell
    counts as one cluster.

    Parameters:
        road: 1-D numpy array; -1 marks an empty cell, otherwise the value
            is the velocity of the car in that cell.
        min_cluster_size: minimum run length that qualifies as a jam.

    Returns:
        True if such a run exists, otherwise False.
    """
    slow = (road >= 0) & (road <= 1)

    if not slow.any():
        return False
    if slow.all():
        # The whole ring is one cluster of length len(road)
        return bool(slow.size >= min_cluster_size)

    # Rotate the ring so it starts on a non-slow cell; after that no run of
    # slow cells can straddle the array boundary.
    first_break = int(np.argmin(slow))
    unrolled = np.roll(slow, -first_break)

    # +1 marks the step into a run, -1 the step out of it. The array starts
    # on a non-slow cell and a False is appended, so starts and ends pair up.
    edges = np.diff(np.concatenate((unrolled, [False])).astype(np.int8))
    run_starts = np.flatnonzero(edges == 1)
    run_ends = np.flatnonzero(edges == -1)

    return bool((run_ends - run_starts).max() >= min_cluster_size)

# Cluster-size threshold for jam detection, named once so the detector call
# and the summary table below always report the same value
min_cluster_size = 10

# Initialize
print("\nInitializing...")
rng = np.random.default_rng(42)
road = init_road(L, N, rng)
print("✓ Road initialized")

# Warm-up: advance the road without recording anything
print(f"\nWarm-up ({T_warm} steps)...")
for t in range(T_warm):
    road, _ = step(road, v_max, p, rng)
    if (t + 1) % 100 == 0:
        print(f"  {t + 1}/{T_warm}")
print("✓ Warm-up complete")

# Observation
# A jam episode opens after M_form consecutive "jam" observations and closes
# after K_clear consecutive "no jam" observations; its lifetime is the
# difference between those two observation indices.
print(f"\nObservation ({T_obs} steps)...")
jam_on_streak = 0
jam_off_streak = 0
in_jam = False
t_form = None
jam_lifetimes = []

for t in range(T_obs):
    is_jam_present = jam_present(road, min_cluster_size=min_cluster_size)

    if is_jam_present:
        jam_on_streak += 1
        jam_off_streak = 0
    else:
        jam_off_streak += 1
        jam_on_streak = 0

    if (not in_jam) and (jam_on_streak == M_form):
        in_jam = True
        t_form = t

    if in_jam and (jam_off_streak == K_clear):
        t_clear = t
        tau = t_clear - t_form
        jam_lifetimes.append(tau)
        in_jam = False
        t_form = None

    road, _ = step(road, v_max, p, rng)

    if (t + 1) % 250 == 0:
        print(f"  {t + 1}/{T_obs} - jams: {len(jam_lifetimes)}, in_jam: {in_jam}, on_streak: {jam_on_streak}, off_streak: {jam_off_streak}")

# An episode still open when the loop ends is not counted in jam_lifetimes
print("\n" + "="*70)
print(f"RESULT: {len(jam_lifetimes)} jams detected")
if jam_lifetimes:
    print(f"Mean lifetime: {np.mean(jam_lifetimes):.1f} steps")
    print(f"Median lifetime: {np.median(jam_lifetimes):.1f} steps")
    print(f"Std deviation: {np.std(jam_lifetimes):.1f} steps")
    print(f"Min/Max: {np.min(jam_lifetimes)} / {np.max(jam_lifetimes)} steps")
print("="*70)

# Visualization (skipped entirely when no episode closed)
if jam_lifetimes:
    # Figure 1: Histogram of jam lifetimes
    fig1, ax1 = plt.subplots(figsize=(10, 6))
    ax1.hist(jam_lifetimes, bins=15, edgecolor='black', alpha=0.7, color='steelblue')
    ax1.axvline(np.mean(jam_lifetimes), color='red', linestyle='--', linewidth=2, label=f'Mean = {np.mean(jam_lifetimes):.1f}')
    ax1.axvline(np.median(jam_lifetimes), color='orange', linestyle='--', linewidth=2, label=f'Median = {np.median(jam_lifetimes):.1f}')
    ax1.set_xlabel('Jam Lifetime (time steps)', fontsize=12)
    ax1.set_ylabel('Frequency', fontsize=12)
    ax1.set_title('Distribution of Traffic Jam Lifetimes', fontsize=14, fontweight='bold')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig('images/investigation_4_jam_lifetime_histogram.png', dpi=150, bbox_inches='tight')
    plt.show()
    print("✓ Histogram saved: images/investigation_4_jam_lifetime_histogram.png")

    # Figure 2: Summary statistics table
    fig2, ax2 = plt.subplots(figsize=(8, 6))
    section_label = 'Jam detection criteria'
    stats_data = [
        ['Total jams detected', f'{len(jam_lifetimes)}'],
        ['Mean lifetime', f'{np.mean(jam_lifetimes):.1f} steps'],
        ['Median lifetime', f'{np.median(jam_lifetimes):.1f} steps'],
        ['Std deviation', f'{np.std(jam_lifetimes):.1f} steps'],
        ['Min lifetime', f'{np.min(jam_lifetimes)} steps'],
        ['Max lifetime', f'{np.max(jam_lifetimes)} steps'],
        ['Observation period', f'{T_obs} steps'],
        ['Jam frequency', f'{len(jam_lifetimes)/T_obs*1000:.2f} jams/1000 steps'],
        ['', ''],
        [section_label, ''],
        ['Min cluster size', f'{min_cluster_size} consecutive slow cars'],
        ['Formation persistence (M)', f'{M_form} steps'],
        ['Clearance persistence (K)', f'{K_clear} steps']
    ]

    ax2.axis('tight')
    ax2.axis('off')
    table = ax2.table(cellText=stats_data, colLabels=['Metric', 'Value'],
                      cellLoc='left', loc='center', colWidths=[0.6, 0.4])
    table.auto_set_font_size(False)
    table.set_fontsize(11)
    table.scale(1, 2.5)

    # Table row 0 is the column-label header, so stats_data[k] is drawn in
    # table row k + 1. Locate the special rows by content instead of by a
    # hand-counted index.
    section_row = next(
        row_idx for row_idx, row in enumerate(stats_data, start=1)
        if row[0] == section_label
    )
    blank_rows = {
        row_idx for row_idx, row in enumerate(stats_data, start=1)
        if not any(row)
    }

    # Style header
    for j in range(2):
        table[(0, j)].set_facecolor('#4472C4')
        table[(0, j)].set_text_props(weight='bold', color='white')

    # Style the section header row
    for j in range(2):
        table[(section_row, j)].set_facecolor('#8FAADC')
        table[(section_row, j)].set_text_props(weight='bold')

    # Shade every even data row, leaving separator and section header alone
    for i in range(1, len(stats_data) + 1):
        if i in blank_rows or i == section_row or i % 2 != 0:
            continue
        for j in range(2):
            table[(i, j)].set_facecolor('#E7E6E6')

    ax2.set_title('Jam dispersal analysis - summary statistics', fontsize=14, fontweight='bold', pad=20)
    plt.tight_layout()
    plt.savefig('images/investigation_4_jam_statistics_table.png', dpi=150, bbox_inches='tight')
    plt.show()
    print("✓ Table saved: images/investigation_4_jam_statistics_table.png")


In [None]:

# Jam dispersal analysis - Method B - Monte Carlo analysis

# Model settings, matching the Method A cell
L, N = 200, 80
v_max, p = 5, 0.25

# Jam definition: minimum cluster length plus the two persistence filters
min_cluster_size = M_form = K_clear = 10

# Monte Carlo schedule
T_warm = 500   # unrecorded steps at the start of every run
T_max = 5000   # Maximum observation time per run
R = 30         # Number of runs

for summary_line in (
    f"\nParameters: L={L}, N={N}, p={p}, ρ={N/L:.2f}",
    f"Jam detection: min_cluster={min_cluster_size} cars, M_form={M_form}, K_clear={K_clear}",
    f"Monte Carlo: R={R} runs, T_warm={T_warm}, T_max={T_max}",
):
    print(summary_line)

# Helper function (same as Method A)
def jam_present(road, min_cluster_size=10):
    """Report whether the road holds a jam cluster.

    A jam cluster is a run of at least ``min_cluster_size`` adjacent cells
    that are all occupied by slow cars (velocity 0 or 1). The road is a
    ring, so a run that continues from the last cell into the first cell
    counts as one cluster.

    Parameters:
        road: 1-D numpy array; -1 marks an empty cell, otherwise the value
            is the velocity of the car in that cell.
        min_cluster_size: minimum run length that qualifies as a jam.

    Returns:
        True if such a run exists, otherwise False.
    """
    slow = (road >= 0) & (road <= 1)

    if not slow.any():
        return False
    if slow.all():
        # The whole ring is one cluster of length len(road)
        return bool(slow.size >= min_cluster_size)

    # Rotate the ring so it starts on a non-slow cell; after that no run of
    # slow cells can straddle the array boundary.
    first_break = int(np.argmin(slow))
    unrolled = np.roll(slow, -first_break)

    # +1 marks the step into a run, -1 the step out of it. The array starts
    # on a non-slow cell and a False is appended, so starts and ends pair up.
    edges = np.diff(np.concatenate((unrolled, [False])).astype(np.int8))
    run_starts = np.flatnonzero(edges == 1)
    run_ends = np.flatnonzero(edges == -1)

    return bool((run_ends - run_starts).max() >= min_cluster_size)

def simulate_one_run(seed, L, N, v_max, p, min_cluster_size, M_form, K_clear, T_warm, T_max):
    """Simulate one seeded realization and time every jam episode in it.

    The road is advanced for T_warm unrecorded steps and then observed for
    T_max steps. An episode opens at the observation where jam_present()
    has been true M_form times in a row, and closes at the observation
    where it has been false K_clear times in a row; the lifetime is the
    difference between those two observation indices.

    Returns:
        (jam_lifetimes, persistent_flag): the lifetimes of all episodes
        that closed, and True when an episode was still open after the
        last observation.
    """
    rng = np.random.default_rng(seed)
    lane = init_road(L, N, rng)

    # Warm-up
    for _ in range(T_warm):
        lane = step(lane, v_max, p, rng)[0]

    # Observation state: streak counters plus the opening time of the
    # current episode (None while no episode is open)
    lifetimes = []
    on_run = off_run = 0
    opened_at = None

    for tick in range(T_max):
        if jam_present(lane, min_cluster_size):
            on_run, off_run = on_run + 1, 0
        else:
            on_run, off_run = 0, off_run + 1

        # The two checks are deliberately independent and in this order
        if opened_at is None and on_run == M_form:
            opened_at = tick

        if opened_at is not None and off_run == K_clear:
            lifetimes.append(tick - opened_at)
            opened_at = None

        lane = step(lane, v_max, p, rng)[0]

    return lifetimes, opened_at is not None

def run_monte_carlo(R, L, N, v_max, p, min_cluster_size, M_form, K_clear, T_warm, T_max):
    """Repeat simulate_one_run for R seeds (1000, 1001, ...) and pool the results.

    Returns:
        (all_lifetimes, taus_per_run, persistent_count): every closed-episode
        lifetime from every run, the mean lifetime of each run that closed at
        least one episode, and the number of runs that ended with an episode
        still open. Progress is printed after every tenth run.
    """
    pooled = []       # lifetimes from all runs together
    run_means = []    # one entry per run that closed at least one episode
    still_open = 0

    print("\nRunning Monte Carlo simulations...")
    for r in range(R):
        lifetimes, persistent = simulate_one_run(
            1000 + r, L, N, v_max, p, min_cluster_size, M_form, K_clear, T_warm, T_max
        )

        still_open += 1 if persistent else 0

        if lifetimes:
            pooled.extend(lifetimes)
            run_means.append(np.mean(lifetimes))

        if (r + 1) % 10 == 0:
            print(f"  Completed {r + 1}/{R} runs")

    return pooled, run_means, still_open

# Execute the ensemble of runs
all_lifetimes, taus_per_run, persistent_count = run_monte_carlo(
    R, L, N, v_max, p, min_cluster_size, M_form, K_clear, T_warm, T_max
)

# Derived run-level quantities
cleared_runs = len(taus_per_run)
persistent_fraction = persistent_count / R

# Report
print("\n" + "="*70)
print("MONTE CARLO RESULTS")
print("="*70)
print(f"Total runs: {R}")
print(f"Runs with cleared jams: {cleared_runs}")
print(f"Runs with persistent jams: {persistent_count} ({persistent_fraction*100:.1f}%)")
print(f"\nTotal jam episodes detected: {len(all_lifetimes)}")

# Episode-level statistics only exist when at least one episode closed
if all_lifetimes:
    print(f"\nMean jam lifetime (all episodes): {np.mean(all_lifetimes):.1f} steps")
    print(f"Median jam lifetime: {np.median(all_lifetimes):.1f} steps")
    print(f"Std deviation: {np.std(all_lifetimes):.1f} steps")
    print(f"Min/Max: {np.min(all_lifetimes)} / {np.max(all_lifetimes)} steps")

if taus_per_run:
    print(f"\nMean of per-run averages: {np.mean(taus_per_run):.1f} steps")
    print(f"Std of per-run averages: {np.std(taus_per_run):.1f} steps")

print("="*70)

# Keep the Method B summary values for the comparison that follows;
# they fall back to 0 when no episode was recorded
if all_lifetimes:
    method_b_mean = np.mean(all_lifetimes)
    method_b_median = np.median(all_lifetimes)
    method_b_std = np.std(all_lifetimes)
else:
    method_b_mean = method_b_median = method_b_std = 0
method_b_episodes = len(all_lifetimes)
method_b_persistent = persistent_fraction * 100

# Method A reference values, collected in ONE place so the printed table and
# the rendered table cannot disagree.
# NOTE(review): these are literals, not read from the Method A variables —
# they presumably come from an earlier execution of the Method A cell; update
# them whenever Method A is re-run with different settings or code.
method_a_reference = {
    'episodes': '16',
    'mean': '46.4',
    'median': '45.0',
    'std': 'N/A',
    'persistent': '0.0',
}

# Create comparison table (text version)
print("\n" + "="*70)
print("METHOD COMPARISON: Time-Average (A) vs Monte Carlo (B)")
print("="*70)
print(f"{'Metric':<35} {'Method A':<15} {'Method B':<15}")
print("-" * 70)
print(f"{'Approach':<35} {'Single run':<15} {f'{R} runs':<15}")
print(f"{'Total jam episodes':<35} {method_a_reference['episodes']:<15} {method_b_episodes:<15}")
print(f"{'Mean lifetime (steps)':<35} {method_a_reference['mean']:<15} {method_b_mean:<15.1f}")
print(f"{'Median lifetime (steps)':<35} {method_a_reference['median']:<15} {method_b_median:<15.1f}")
print(f"{'Std deviation (steps)':<35} {method_a_reference['std']:<15} {method_b_std:<15.1f}")
print(f"{'Persistent jam fraction (%)':<35} {method_a_reference['persistent']:<15} {method_b_persistent:<15.1f}")
print("="*70)
print("Note: Method A values are from the single time-average run.")
print("      Method B provides ensemble statistics over multiple realizations.")
print("="*70)

# Create visual comparison table and save as PNG
fig_comp, ax_comp = plt.subplots(figsize=(10, 5))
comparison_data = [
    ['Approach', 'Single long run', f'{R} independent runs'],
    ['Total Jam Episodes', method_a_reference['episodes'], f'{method_b_episodes}'],
    ['Mean Lifetime', f"{method_a_reference['mean']} steps", f'{method_b_mean:.1f} steps'],
    ['Median Lifetime', f"{method_a_reference['median']} steps", f'{method_b_median:.1f} steps'],
    ['Std Deviation', method_a_reference['std'], f'{method_b_std:.1f} steps'],
    ['Persistent Jam Fraction', f"{method_a_reference['persistent']}%", f'{method_b_persistent:.1f}%'],
    ['', '', ''],
    ['Observation Period', f'{T_max} steps', f'{R} × {T_max} steps'],
    ['Statistical Confidence', 'Single realization', 'Ensemble average']
]

ax_comp.axis('tight')
ax_comp.axis('off')
table_comp = ax_comp.table(cellText=comparison_data, 
                            colLabels=['Metric', 'Method A\n(Time-Average)', 'Method B\n(Monte Carlo)'],
                            cellLoc='left', loc='center', colWidths=[0.4, 0.3, 0.3])
table_comp.auto_set_font_size(False)
table_comp.set_fontsize(11)
table_comp.scale(1, 2.5)

# Table row 0 is the column-label header, so comparison_data[k] is drawn in
# table row k + 1. Find the blank separator by content, not by counting.
blank_rows = {
    row_idx for row_idx, row in enumerate(comparison_data, start=1)
    if not any(row)
}

# Style header
for j in range(3):
    table_comp[(0, j)].set_facecolor('#4472C4')
    table_comp[(0, j)].set_text_props(weight='bold', color='white')

# Shade every even data row, leaving the blank separator untouched
for i in range(1, len(comparison_data) + 1):
    if i in blank_rows or i % 2 != 0:
        continue
    for j in range(3):
        table_comp[(i, j)].set_facecolor('#E7E6E6')

# Highlight the first row of the section that follows the separator
for i in blank_rows:
    if i + 1 <= len(comparison_data):
        for j in range(3):
            table_comp[(i + 1, j)].set_facecolor('#D9E1F2')

ax_comp.set_title('Method Comparison: Time-Average vs Monte Carlo', 
                  fontsize=14, fontweight='bold', pad=20)
plt.tight_layout()
plt.savefig('images/investigation_4_method_comparison_table.png', dpi=150, bbox_inches='tight')
plt.show()
print("\n✓ Comparison table saved: images/investigation_4_method_comparison_table.png")

# Visualization
# Everything below needs at least one closed jam episode; with none, no
# figure is produced. taus_per_run gets an entry exactly when a run
# contributes lifetimes, so it is non-empty whenever all_lifetimes is.
if len(all_lifetimes) > 0:
    # 2x2 overview: two histograms on top, statistics table and text below
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    # 1. Histogram of all jam lifetimes
    ax1 = axes[0, 0]
    ax1.hist(all_lifetimes, bins=20, edgecolor='black', alpha=0.7, color='steelblue')
    ax1.axvline(np.mean(all_lifetimes), color='red', linestyle='--', linewidth=2,
                label=f'Mean = {np.mean(all_lifetimes):.1f}')
    ax1.axvline(np.median(all_lifetimes), color='orange', linestyle='--', linewidth=2,
                label=f'Median = {np.median(all_lifetimes):.1f}')
    ax1.set_xlabel('Jam Lifetime (steps)', fontsize=11)
    ax1.set_ylabel('Frequency', fontsize=11)
    ax1.set_title('Distribution of all jam lifetimes\n(Monte Carlo)', fontsize=12, fontweight='bold')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # 2. Per-run mean lifetimes (one value per run that closed an episode)
    ax2 = axes[0, 1]
    if len(taus_per_run) > 0:
        ax2.hist(taus_per_run, bins=15, edgecolor='black', alpha=0.7, color='seagreen')
        ax2.axvline(np.mean(taus_per_run), color='red', linestyle='--', linewidth=2,
                    label=f'Mean = {np.mean(taus_per_run):.1f}')
        ax2.set_xlabel('Mean jam lifetime per run (steps)', fontsize=11)
        ax2.set_ylabel('Frequency', fontsize=11)
        ax2.set_title(f'Distribution of per-run averages\n({cleared_runs} runs with cleared jams)',
                      fontsize=12, fontweight='bold')
        ax2.legend()
        ax2.grid(True, alpha=0.3)

    # 3. Summary statistics table
    ax3 = axes[1, 0]
    stats_data = [
        ['Monte Carlo runs (R)', f'{R}'],
        ['Runs with cleared jams', f'{cleared_runs}'],
        ['Persistent jam fraction', f'{persistent_fraction*100:.1f}%'],
        ['Total jam episodes', f'{len(all_lifetimes)}'],
        ['', ''],
        ['Mean lifetime (all)', f'{np.mean(all_lifetimes):.1f} steps'],
        ['Median lifetime', f'{np.median(all_lifetimes):.1f} steps'],
        ['Std deviation', f'{np.std(all_lifetimes):.1f} steps'],
    ]

    # Per-run rows are appended only when per-run averages exist
    if len(taus_per_run) > 0:
        stats_data.extend([
            ['', ''],
            ['Mean of run averages', f'{np.mean(taus_per_run):.1f} steps'],
            ['Std of run averages', f'{np.std(taus_per_run):.1f} steps']
        ])

    ax3.axis('tight')
    ax3.axis('off')
    table = ax3.table(cellText=stats_data, colLabels=['Metric', 'Value'],
                      cellLoc='left', loc='center', colWidths=[0.65, 0.35])
    table.auto_set_font_size(False)
    table.set_fontsize(10)
    table.scale(1, 2)

    # Row 0 of the table is the column-label header
    for i in range(2):
        table[(0, i)].set_facecolor('#4472C4')
        table[(0, i)].set_text_props(weight='bold', color='white')

    # Shade every even table row (blank separator rows are not skipped here)
    for i in range(1, len(stats_data) + 1):
        for j in range(2):
            if i % 2 == 0:
                table[(i, j)].set_facecolor('#E7E6E6')

    ax3.set_title('Monte Carlo statistics', fontsize=12, fontweight='bold', pad=20)

    # 4. Comparison text box
    ax4 = axes[1, 1]
    ax4.axis('off')

    # The literal below starts at column 0 on purpose: any indentation inside
    # the triple-quoted string would appear in the rendered text.
    comparison_text = f"""Method Comparison

Method A (time-average):
• Single long run in steady state
• Mean jam lifetime from one realization

Method B (Monte Carlo):
• {R} independent runs
• Mean across {len(all_lifetimes)} jam episodes
• {persistent_fraction*100:.1f}% of runs had persistent jams

Conclusion:
At ρ={N/L:.2f} and p={p}, jam episodes have
an average dispersal time of {np.mean(all_lifetimes):.1f} steps.
Across independent realizations, the mean
dispersal time is {np.mean(taus_per_run):.1f} ± {np.std(taus_per_run):.1f} steps,
and {persistent_fraction*100:.1f}% of runs remain jammed
for the full observation window.
"""

    ax4.text(0.1, 0.5, comparison_text, fontsize=10, verticalalignment='center',
             bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.3))

    # Save combined figure
    plt.tight_layout()
    plt.savefig('images/investigation_4_monte_carlo_analysis.png', dpi=150, bbox_inches='tight')
    plt.show()
    print("\n✓ Combined figure saved: images/investigation_4_monte_carlo_analysis.png")

    # Save individual subplots as separate files
    # These standalone figures are written to disk and then closed with
    # plt.close() rather than shown.
    # 1. All lifetimes histogram
    fig1, ax = plt.subplots(figsize=(10, 6))
    ax.hist(all_lifetimes, bins=20, edgecolor='black', alpha=0.7, color='steelblue')
    ax.axvline(np.mean(all_lifetimes), color='red', linestyle='--', linewidth=2,
                label=f'Mean = {np.mean(all_lifetimes):.1f}')
    ax.axvline(np.median(all_lifetimes), color='orange', linestyle='--', linewidth=2,
                label=f'Median = {np.median(all_lifetimes):.1f}')
    ax.set_xlabel('Jam lifetime (steps)', fontsize=12)
    ax.set_ylabel('Frequency', fontsize=12)
    ax.set_title('Distribution of all jam lifetimes (Monte Carlo)', fontsize=14, fontweight='bold')
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig('images/investigation_4_mc_all_lifetimes.png', dpi=150, bbox_inches='tight')
    plt.close()
    print("✓ Saved: images/investigation_4_mc_all_lifetimes.png")

    # 2. Per-run averages histogram
    if len(taus_per_run) > 0:
        fig2, ax = plt.subplots(figsize=(10, 6))
        ax.hist(taus_per_run, bins=15, edgecolor='black', alpha=0.7, color='seagreen')
        ax.axvline(np.mean(taus_per_run), color='red', linestyle='--', linewidth=2,
                    label=f'Mean = {np.mean(taus_per_run):.1f}')
        ax.set_xlabel('Mean jam lifetime per run (steps)', fontsize=12)
        ax.set_ylabel('Frequency', fontsize=12)
        ax.set_title(f'Distribution of per-run averages ({cleared_runs} runs with cleared jams)',
                      fontsize=14, fontweight='bold')
        ax.legend()
        ax.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.savefig('images/investigation_4_mc_per_run_averages.png', dpi=150, bbox_inches='tight')
        plt.close()
        print("✓ Saved: images/investigation_4_mc_per_run_averages.png")


# Investigation 4C: Sensitivity Analysis of Jam Detection Parameters

**Question:** How sensitive are the Monte Carlo results to the choice of jam detection parameters?

Test four parameter combinations to assess robustness:
- Baseline: cluster=10, M=10, K=10
- Variant 1: cluster=8, M=10, K=10 (looser spatial criterion)
- Variant 2: cluster=12, M=10, K=10 (stricter spatial criterion)
- Variant 3: cluster=10, M=5, K=5 (shorter persistence requirements)

In [None]:
# ============================================================================
# CELL: Sensitivity Analysis - Jam Detection Parameters
# ============================================================================

print("\n" + "="*70)
print("SENSITIVITY ANALYSIS: JAM DETECTION PARAMETERS")
print("="*70)

# Model and schedule settings held constant across all variants
L, N = 200, 80
v_max, p = 5, 0.25
T_warm, T_max = 500, 5000
R = 30

# Detection settings under test: 'cluster' is the minimum jam length in
# cells, 'M' and 'K' are the formation / clearance persistence counts
variants = [
    dict(name='Baseline', cluster=10, M=10, K=10),
    dict(name='Variant 1 (looser spatial)', cluster=8, M=10, K=10),
    dict(name='Variant 2 (stricter spatial)', cluster=12, M=10, K=10),
    dict(name='Variant 3 (shorter persistence)', cluster=10, M=5, K=5),
]

# Helper function for this analysis
def jam_present_param(road, min_cluster_size):
    """
    Report whether the road holds a jam of at least min_cluster_size cells.

    A jam is a run of adjacent cells that each contain a car with velocity
    0 or 1. The road is a ring (cars that pass the last cell re-enter at
    cell 0), so a run that continues from the last cell into the first cell
    is counted as a single cluster rather than two shorter ones.

    Args:
        road: 1-D array of cell states; -1 marks an empty cell and any
            value >= 0 is the velocity of the car occupying that cell.
        min_cluster_size: Smallest run length that counts as a jam.

    Returns:
        True if some run of slow cars reaches min_cluster_size, else False.
        At least one slow car is always required, whatever the threshold.
    """
    cells = np.asarray(road)
    slow_mask = (cells >= 0) & (cells <= 1)
    n_cells = slow_mask.size

    if n_cells == 0 or not slow_mask.any():
        return False

    # A run needs at least one slow car even for thresholds below 1.
    threshold = max(min_cluster_size, 1)

    # Every cell is slow: the whole ring is one cluster of length n_cells.
    if slow_mask.all():
        return n_cells >= threshold

    # Rotate the ring so that it starts on a non-slow cell. After this no
    # cluster can straddle the array boundary, so a linear scan is exact.
    first_gap = int(np.argmin(slow_mask))
    rotated = np.roll(slow_mask, -first_gap)

    run_length = 0
    for is_slow in rotated:
        if is_slow:
            run_length += 1
            if run_length >= threshold:
                return True
        else:
            run_length = 0
    return False

def simulate_one_run_param(seed, L, N, v_max, p, min_cluster_size, M_form, K_clear, T_warm, T_max):
    """
    Simulate one road realisation and time every jam that forms and clears.

    After T_warm discarded warm-up steps, the road is checked for a jam once
    per step for T_max steps (the check happens before the road is advanced).
    A jam is registered as formed at the step where the jam has been present
    for M_form consecutive checks, and as cleared at the step where it has
    been absent for K_clear consecutive checks. Its lifetime is the number of
    steps between those two registrations.

    Args:
        seed: Seed for the run's random number generator.
        L, N, v_max, p: Road length, car count, maximum velocity and
            random-braking probability passed to init_road / step.
        min_cluster_size: Cluster-size threshold passed to jam_present_param.
        M_form: Consecutive jam-present checks needed to register a jam.
        K_clear: Consecutive jam-absent checks needed to register clearance.
        T_warm: Number of warm-up steps before observation starts.
        T_max: Number of observed steps.

    Returns:
        (lifetimes, persistent): the list of lifetimes of all jams that
        cleared, and True if a registered jam was still uncleared at the end.
    """
    rng = np.random.default_rng(seed)
    road = init_road(L, N, rng)

    # Warm-up: advance the road without recording anything.
    for _ in range(T_warm):
        road = step(road, v_max, p, rng)[0]

    lifetimes = []
    present_run = 0     # consecutive checks with a jam present
    absent_run = 0      # consecutive checks with no jam present
    formed_at = None    # step at which the current jam was registered, if any

    for now in range(T_max):
        if jam_present_param(road, min_cluster_size):
            present_run, absent_run = present_run + 1, 0
        else:
            present_run, absent_run = 0, absent_run + 1

        # Register a new jam only when none is currently being tracked.
        if formed_at is None and present_run == M_form:
            formed_at = now

        # Close the tracked jam once it has stayed away long enough.
        if formed_at is not None and absent_run == K_clear:
            lifetimes.append(now - formed_at)
            formed_at = None

        road = step(road, v_max, p, rng)[0]

    return lifetimes, formed_at is not None

# Run sensitivity analysis
# For each detection variant, repeat the simulation over R seeds, pool the
# lifetimes of all cleared jams, and store one summary record in `results`.
results = []

for variant in variants:
    print(f"\nTesting {variant['name']}...")
    print(f"  Parameters: cluster={variant['cluster']}, M={variant['M']}, K={variant['K']}")

    all_lifetimes, taus_per_run, persistent_count = [], [], 0

    # Seeds 1000 .. 1000+R-1 are reused for every variant, so all variants
    # are evaluated on the same random streams.
    for r, seed in enumerate(range(1000, 1000 + R)):
        jam_lifetimes, persistent = simulate_one_run_param(
            seed, L, N, v_max, p,
            min_cluster_size=variant['cluster'],
            M_form=variant['M'],
            K_clear=variant['K'],
            T_warm=T_warm,
            T_max=T_max,
        )

        persistent_count += 1 if persistent else 0

        # Runs in which no jam cleared contribute nothing to the pools.
        if jam_lifetimes:
            all_lifetimes.extend(jam_lifetimes)
            taus_per_run.append(np.mean(jam_lifetimes))

    # Store results (lifetime statistics default to 0 when no jam cleared)
    result = {
        'name': variant['name'],
        'params': f"c={variant['cluster']}, M={variant['M']}, K={variant['K']}",
        'total_jams': len(all_lifetimes),
        'jams_per_run': len(all_lifetimes) / R,
        'mean_lifetime': 0,
        'median_lifetime': 0,
        'std_lifetime': 0,
        'persistent_fraction': persistent_count / R * 100
    }
    if all_lifetimes:
        result.update(
            mean_lifetime=np.mean(all_lifetimes),
            median_lifetime=np.median(all_lifetimes),
            std_lifetime=np.std(all_lifetimes),
        )
    results.append(result)

    print(f"  Total jams: {result['total_jams']} ({result['jams_per_run']:.1f} per run)")
    print(f"  Mean lifetime: {result['mean_lifetime']:.1f} steps")
    print(f"  Persistent: {result['persistent_fraction']:.1f}%")

# Create comparison table
# Plain-text summary: banner, fixed-width column header, then one row per
# variant in the order the variants were run.
print(
    "\n" + "="*70,
    "SENSITIVITY ANALYSIS SUMMARY",
    "="*70,
    f"{'Variant':<30} {'Jams/Run':<12} {'Mean τ':<12} {'Persistent %':<12}",
    "-" * 70,
    sep="\n",
)
for r in results:
    print(f"{r['name']:<30} {r['jams_per_run']:<12.1f} {r['mean_lifetime']:<12.1f} {r['persistent_fraction']:<12.1f}")
print("="*70)

# Visual comparison table
# Renders `results` as a styled matplotlib table and saves it as a PNG.
# Row and column counts are taken from the data, so the styling keeps working
# if variants are added to or removed from the analysis.
sens_columns = ['Variant', 'Parameters', 'Jams/Run', 'Mean τ (steps)',
                'Median τ (steps)', 'Persistent %']

fig_sens, ax_sens = plt.subplots(figsize=(12, 5))
sens_data = [
    [r['name'], r['params'], f"{r['jams_per_run']:.1f}", f"{r['mean_lifetime']:.1f}",
     f"{r['median_lifetime']:.1f}", f"{r['persistent_fraction']:.1f}%"]
    for r in results
]

ax_sens.axis('tight')
ax_sens.axis('off')
table_sens = ax_sens.table(cellText=sens_data,
                            colLabels=sens_columns,
                            cellLoc='center', loc='center', colWidths=[0.25, 0.2, 0.12, 0.15, 0.15, 0.13])
table_sens.auto_set_font_size(False)
table_sens.set_fontsize(10)
table_sens.scale(1, 2.5)

# Style header (table row 0 is the column-label row)
for j in range(len(sens_columns)):
    table_sens[(0, j)].set_facecolor('#4472C4')
    table_sens[(0, j)].set_text_props(weight='bold', color='white')

# Highlight baseline (table row 1 is the first entry of `results`)
for j in range(len(sens_columns)):
    table_sens[(1, j)].set_facecolor('#D9E1F2')
    table_sens[(1, j)].set_text_props(weight='bold')

# Alternate row colors for variants: shade every even-numbered table row
# after the baseline; odd rows keep the default background.
for i in range(2, len(sens_data) + 1):
    if i % 2 == 0:
        for j in range(len(sens_columns)):
            table_sens[(i, j)].set_facecolor('#E7E6E6')

ax_sens.set_title('Sensitivity Analysis: Impact of Jam Detection Parameters', 
                  fontsize=14, fontweight='bold', pad=20)
plt.tight_layout()
plt.savefig('images/investigation_4_sensitivity_analysis.png', dpi=150, bbox_inches='tight')
plt.show()
print("\n✓ Sensitivity analysis table saved: images/investigation_4_sensitivity_analysis.png")

# Analysis
# Compares the spread of the headline metrics across all variants against the
# baseline (first entry of `results`) and prints a data-driven verdict.
print("\n" + "="*70)
print("INTERPRETATION")
print("="*70)
baseline = results[0]
print(f"Baseline results: {baseline['jams_per_run']:.1f} jams/run, mean τ = {baseline['mean_lifetime']:.1f} steps")
print("\nVariability across parameter choices:")
jams_range = max(r['jams_per_run'] for r in results) - min(r['jams_per_run'] for r in results)
tau_range = max(r['mean_lifetime'] for r in results) - min(r['mean_lifetime'] for r in results)

# Express each spread as a percentage of the baseline value. The baseline's
# jams/run and mean lifetime are 0 when it detected no cleared jams; report
# NaN in that case instead of dividing by zero.
jams_pct = jams_range / baseline['jams_per_run'] * 100 if baseline['jams_per_run'] > 0 else float('nan')
tau_pct = tau_range / baseline['mean_lifetime'] * 100 if baseline['mean_lifetime'] > 0 else float('nan')
print(f"  Jams/run range: {jams_range:.1f} ({jams_pct:.1f}% of baseline)")
print(f"  Mean τ range: {tau_range:.1f} steps ({tau_pct:.1f}% of baseline)")

# Derive the verdict from the larger of the two relative spreads.
# NOTE: the 10% / 50% cut-offs are heuristic labels, not a statistical test;
# adjust them if a different convention is wanted.
if np.isnan(jams_pct) or np.isnan(tau_pct):
    print("\nConclusion: Sensitivity cannot be assessed because the baseline has no cleared jams to compare against.")
else:
    worst_pct = max(jams_pct, tau_pct)
    if worst_pct < 10:
        sensitivity = "low"
    elif worst_pct < 50:
        sensitivity = "moderate"
    else:
        sensitivity = "high"
    print(f"\nConclusion: Results show {sensitivity} sensitivity to detection parameters.")
print("="*70)
