In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from utils.plotting import plot_parallel_coordinates, plot_top1_across_horizons
from preprocess_data import deduplicate_and_aggregate

### Preprocess

In [None]:
# Load the preprocessed-seeds CSV for this scan; PREFIX is also used by the
# cell that writes the "-agg" CSV further down.
PREFIX = "mass-scan-per-layer"
seeds_csv_path = f"data/{PREFIX}-base-momentum-1.0-preprocessed-seeds.csv"
df = pd.read_csv(seeds_csv_path)
df.columns

In [None]:
# Run the project helper (preprocess_data.deduplicate_and_aggregate) on the
# loaded frame and display the columns of the result.
agg = df.pipe(deduplicate_and_aggregate)
agg.columns

In [None]:
# Write the aggregated table alongside its source CSV (same PREFIX, "-agg" suffix), without the index column.
agg_csv_path = f"data/{PREFIX}-base-momentum-1.0-preprocessed-seeds-agg.csv"
agg.to_csv(path_or_buf=agg_csv_path, index=False)

### Plot

In [None]:
# Overlaid histograms of train_loss_std, one per batch size, restricted to runs
# that share the same bs * step product.
#
# Note: the filter compares bs * step against 2**33 // 4096 (= 2**21), NOT 2**33.
# Presumably 4096 is the sequence length, making this a 2**33-token horizon
# (matching HORIZON = 2**33 used later) — TODO confirm.
# NOTE(review): this uses the raw `df` loaded above, not `agg`; if `df` contains
# duplicate rows they are counted more than once here — confirm that is intended.
target_product = 2**33 // 4096  # = 2**21
filtered_df = df[df['bs'] * df['step'] == target_product]

bs_values = sorted(filtered_df['bs'].unique())
colors = plt.cm.Set1(np.linspace(0, 1, len(bs_values)))

# Shared bin edges so bar widths are comparable across batch sizes:
# 61 edges -> 60 equal-width bins spanning the full observed range.
min_val = filtered_df['train_loss_std'].min()
max_val = filtered_df['train_loss_std'].max()
bins = np.linspace(min_val, max_val, 61)

fig, ax = plt.subplots(figsize=(10, 6))

for color, bs in zip(colors, bs_values):
    subset = filtered_df[filtered_df['bs'] == bs]
    ax.hist(subset['train_loss_std'], bins=bins, alpha=0.6,
            label=f'bs = {bs}', color=color, edgecolor='black', linewidth=0.5)

    # Dashed vertical line at this batch size's median, in the same colour.
    median_val = subset['train_loss_std'].median()
    ax.axvline(median_val, color=color, linestyle='--', linewidth=2,
               alpha=0.8, label=f'bs = {bs} median: {median_val:.4f}')

# Log-scaled counts keep sparse tails visible next to the dense bulk.
ax.set_yscale('log')
ax.set_xlabel('train_loss_std')
ax.set_ylabel('Frequency')
ax.set_title(f'Distribution of train_loss_std by Batch Size (bs × step = {target_product:e})')
ax.legend()
ax.grid(True, alpha=0.3)
fig.tight_layout()
plt.show()

# Summary of how much data survived the bs * step filter.
print(f"Original data: {len(df)} rows")
print(f"Filtered data: {len(filtered_df)} rows")
print(f"Batch sizes in filtered data: {bs_values}")

In [None]:
# Switch to the "-agg" CSV of a (possibly different) scan for the plots below.
# Alternative: PREFIX = "mass-scan-per-layer" (the scan preprocessed above).
# NOTE: this rebinds both PREFIX and `df`, replacing the frame loaded at the top.
PREFIX = "lr-bs-scan-per-layer"
csv_path = "data/" + PREFIX + "-base-momentum-1.0-preprocessed-seeds-agg.csv"
df = pd.read_csv(csv_path)

In [None]:
# Plot configuration for the parallel-coordinates cell below.
# Previously three separate cells (bs=32 / 128 / 512) overwrote the same globals,
# so Restart & Run All always silently used the last one, and switching required
# running cells out of order. Select the batch size explicitly instead.
LR_EXP_RANGES = {
    # bs: (LR_EXP_MIN, LR_EXP_MAX)
    32: (-11, -1),
    128: (-9, 1),
    512: (-8, 2),
}

HORIZON = 2**33
BS = 512  # must be a key of LR_EXP_RANGES
LR_EXP_MIN, LR_EXP_MAX = LR_EXP_RANGES[BS]

In [None]:
# Parallel-coordinates plot via utils.plotting.plot_parallel_coordinates, using
# the HORIZON / BS / LR-exponent window from the config cell above. The meaning
# of top_k, top_quantile_pct and the opacity levels is defined by that helper.
pc_kwargs = dict(
    horizon=HORIZON,
    bs=BS,
    lr_exp_min=LR_EXP_MIN,
    lr_exp_max=LR_EXP_MAX,
    top_k=3,
    figsize=(10, 6),
    top_quantile_pct=10,
    topk_opacity_levels=(1.0, 0.5, 0.5),
    jitter_std=0.007,
)
fig, ax = plot_parallel_coordinates(df, **pc_kwargs)
# To export: fig.savefig(f'plots/lr-layout-bs-{BS}-horizon-{HORIZON}.pdf')

In [None]:
# Configuration for the across-horizons plot below: bs=512 at four horizons,
# 2**29, 2**31, 2**33, 2**35 (exponents evenly spaced from 29 to 35).
# NOTE: this overwrites BS / LR_EXP_MIN / LR_EXP_MAX used by the
# parallel-coordinates cell; re-run its config cell before re-plotting it.
HORIZONS = np.power(2.0, np.linspace(29, 35, num=4))
BS = 512
LR_EXP_MIN, LR_EXP_MAX = -7, 0

In [None]:
# Top-1 configuration across HORIZONS via utils.plotting.plot_top1_across_horizons.
# Four colours are given and HORIZONS has four entries; presumably one colour per
# horizon, so keep the two lists the same length.
top1_kwargs = dict(
    horizons=HORIZONS,
    bs=BS,
    lr_exp_min=LR_EXP_MIN,
    lr_exp_max=LR_EXP_MAX,
    jitter_std=0.017,
    jitter_nonbest_only=False,
    rng_seed=7,
    colors=["#440154", "#31688e", "#35b779", "#fde725"],
)
fig, ax = plot_top1_across_horizons(df, **top1_kwargs)
# To export: fig.savefig(f'plots/mass-scan-bs-{BS}.pdf')
# NOTE(review): the filename says "mass-scan" but `df` is currently loaded with
# PREFIX = "lr-bs-scan-per-layer"; confirm the intended name before saving.