### <span style="color:#db7d60">Setup</span>

Import dependencies

In [12]:
import json
import math
import pickle
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pUtil
import textwrap
from IPython.display import display
import matplotlib.pyplot as plt
from matplotlib.table import Table

Compare training

In [13]:
def plot_model_train_info(model_name, ax=None, use_epochs=True, y_lim=None, use_log_scale=True):
    training_log_filename = pUtil.get_training_dir(model_name) / "train_log_1.jsonl"
    meta_filename = pUtil.get_model_meta_filename(model_name)
    config_filename = pUtil.get_model_config_filename(model_name)
    
    with open(meta_filename, 'rb') as meta_file:
        meta_data = pickle.load(meta_file)
        max_sequence_length = meta_data['max_sequence_length']
        num_train_tokens = meta_data['num_train_tokens']

    with open(config_filename, 'r') as config_file:
        config = json.load(config_file)
        training_config = config.get('training_config', {})
        batch_size = training_config.get('batch_size', -1)
        block_size = training_config.get('block_size', -1)
        context_events = training_config.get('context_events', -1)
        
        if block_size == -1:
            block_size = context_events * max_sequence_length
        
        iterations_per_epoch = num_train_tokens // (batch_size * block_size)

    running_data, saved_data = {}, {}
    with open(training_log_filename) as training_log_file:
        for jline in training_log_file:
            jdata = json.loads(jline)
            if jdata.get("message") == "Training progress" and "iter" in jdata:
                current_epochs_trained = 0 if jdata['iter'] == 0 else (jdata['iter'] / iterations_per_epoch)
                running_data[jdata["iter"]] = current_epochs_trained, jdata["train_loss"], jdata["val_loss"]
            elif jdata.get("message") == "Training progress: checking checkpoint conditions":
                current_epochs_trained = 0 if jdata['step'] == 0 else (jdata['step'] / iterations_per_epoch)
                saved_data[jdata["step"]] = current_epochs_trained, jdata["train_loss"], jdata["val_loss"]

    sorted_running_iters = sorted(running_data)
    running_epochs_trained, running_train_loss, running_val_loss = zip(*(running_data[i] for i in sorted_running_iters))

    saved_saved_iters = sorted(saved_data)
    saved_epochs_trained, saved_train_loss, saved_val_loss = zip(*(saved_data[i] for i in saved_saved_iters))

    # Find min loss points
    min_val_loss_iter_idx = saved_val_loss.index(min(saved_val_loss))
    min_saved_losses_iter = saved_saved_iters[min_val_loss_iter_idx]
    min_saved_val_loss_epoch, min_saved_train_loss, min_saved_val_loss = saved_epochs_trained[min_val_loss_iter_idx], saved_train_loss[min_val_loss_iter_idx], saved_val_loss[min_val_loss_iter_idx]
    
    ax = ax or plt
    if use_epochs:
        train_plot_line, = ax.plot(running_epochs_trained, running_train_loss, label=f'Training Loss ({model_name})')
        val_plot_line, = ax.plot(running_epochs_trained, running_val_loss, label=f'Validation Loss ({model_name})', color=train_plot_line.get_color(), linestyle='--')
        ax.scatter(min_saved_val_loss_epoch, min_saved_train_loss, label=f'Min Train Loss ({model_name}; ${min_saved_train_loss:.4f}$ @ epoch {min_saved_val_loss_epoch:.5f})', color=train_plot_line.get_color(), edgecolors='black')
        ax.scatter(min_saved_val_loss_epoch, min_saved_val_loss, label=f'Min Val Loss ({model_name}; ${min_saved_val_loss:.4f}$ @ epoch {min_saved_val_loss_epoch:.5f})', color=train_plot_line.get_color(), edgecolors='black')
        if ax is not plt:
            ax.set_xlim([0, max(running_epochs_trained)])
            if y_lim is not None:
                ax.set_ylim(y_lim)
            ax.set_xlabel('Epochs')
            ax.set_ylabel('Loss')
            if use_log_scale:
                ax.set_yscale('log', base=10)
            ax.set_title(f'{model_name}')
            ax.legend()
    else:
        train_plot_line, = ax.plot(sorted_running_iters, running_train_loss, label=f'Training Loss ({model_name})', linewidth=0.5)
        val_plot_line, = ax.plot(sorted_running_iters, running_val_loss, label=f'Validation Loss ({model_name})', color=train_plot_line.get_color(), linestyle='--', linewidth=0.5)
        ax.scatter(min_saved_losses_iter, min_saved_val_loss, label=f'Min Val Loss ({model_name}; ${min_saved_val_loss:.4f}$ @ iter {min_saved_losses_iter})', color=train_plot_line.get_color(), edgecolors='black', marker='s', s=50, )
        ax.scatter(min_saved_losses_iter, min_saved_train_loss, label=f'Min Train Loss ({model_name}; ${min_saved_train_loss:.4f}$ @ iter {min_saved_losses_iter})', color=train_plot_line.get_color(), edgecolors='black')
        if ax is not plt:
            if y_lim is not None:
                ax.set_ylim(y_lim)
            ax.set_xlabel('Iteration')
            ax.set_ylabel('Loss')
            if use_log_scale:
                ax.set_yscale('log', base=10)
            ax.set_title(f'{model_name}')
            ax.legend()
    
    x_offset = 0.005 * sorted_running_iters[-1]
    y_offset = 0.02
    # ax.annotate(model_name, xy=(sorted_running_iters[-1], running_train_loss[-1]), xytext=(sorted_running_iters[-1] + x_offset, running_train_loss[-1] + y_offset), fontsize=9, color=train_plot_line.get_color())
    ax.annotate(model_name, xy=(sorted_running_iters[-1], running_val_loss[-1]), xytext=(sorted_running_iters[-1] + x_offset, running_val_loss[-1] - y_offset), fontsize=9, color=val_plot_line.get_color())
    
def plot_train_graphs(models_to_compare, juxtaposed=True, use_epochs=True, y_lim=None, use_log_scale=True):
    if juxtaposed:
        num_horizontal, num_vertical = min(len(models_to_compare), 3), (math.ceil(len(models_to_compare) / 3))
        figure, axes = plt.subplots(num_vertical, num_horizontal, figsize=(8 * num_horizontal, 6 * num_vertical), sharex=False, sharey=True)
        if len(models_to_compare) == 1:
            axes = [axes]
        axes = np.atleast_1d(axes).flatten()
        for model_name, ax in zip(models_to_compare, axes):
            plot_model_train_info(model_name, ax=ax, use_epochs=use_epochs, y_lim=y_lim, use_log_scale=use_log_scale)
        figure.suptitle(f'Training Progress for {", ".join(models_to_compare)}')
        plt.tight_layout()
        plt.grid(axis="y", linestyle="--", alpha=0.7)
        plt.show()
    else:
        plt.figure(figsize=(15, 6))
        for model_name in models_to_compare:
            plot_model_train_info(model_name, use_epochs=use_epochs, y_lim=y_lim, use_log_scale=use_log_scale)
        wrapped_title = "\n".join(textwrap.wrap(f'Training Progress for {", ".join(models_to_compare)}', width=60))
        plt.title(wrapped_title)
        if use_log_scale:
            plt.yscale('log', base=10)
        plt.xlabel('Iteration')
        plt.ylabel('Loss')
        if y_lim is not None:
            plt.ylim(y_lim)
        plt.legend(loc='center left', bbox_to_anchor=(1.02, 0.5), borderaxespad=0.)
        plt.grid(axis="y", linestyle="--", alpha=0.7)

Compare distributions

In [14]:
columns = ["num_particles", "pdgid", "e", "px", "py", "pz", "eta", "theta", "phi"]

def get_common_data(model_name):
    dictionary_filename = pUtil.get_model_preparation_dir(model_name) / 'dictionary.json'
    real_leading_test_particles_filename = pUtil.get_model_preparation_dir(model_name) / 'real_leading_test_particles.csv'
    sampled_leading_particles_filename = pUtil.get_latest_sampling_dir(model_name) / 'sampled_leading_particles.csv'

    with open(dictionary_filename) as dictionary_file:
        dictionary = json.load(dictionary_file)

    # Convenience dictionary definitions
    p_bin_count = (dictionary["e_bin_data"]["max"] - dictionary["e_bin_data"]["min"]) // 1000
    e_bin_count = (dictionary["e_bin_data"]["max"] - dictionary["e_bin_data"]["min"]) // dictionary["e_bin_data"]["step_size"]
    eta_bin_count = int((dictionary["eta_bin_data"]["max"] - dictionary["eta_bin_data"]["min"]) // dictionary["eta_bin_data"]["step_size"])
            
    bin_settings = {
        "num_particles": { "min": 0,                                 "max": 50,                                "bins": 50 },
        "e":             { "min": dictionary["e_bin_data"]["min"],   "max": dictionary["e_bin_data"]["max"],   "bins": e_bin_count },
        "px":            { "min": dictionary["e_bin_data"]["min"],   "max": dictionary["e_bin_data"]["max"],   "bins": p_bin_count },
        "py":            { "min": dictionary["e_bin_data"]["min"],   "max": dictionary["e_bin_data"]["max"],   "bins": p_bin_count },
        "pz":            { "min": dictionary["e_bin_data"]["min"],   "max": dictionary["e_bin_data"]["max"],   "bins": p_bin_count },
        "eta":           { "min": dictionary["eta_bin_data"]["min"], "max": dictionary["eta_bin_data"]["max"], "bins": eta_bin_count },
        "theta":         { "min": -2 * np.pi,                        "max": 2 * np.pi,                         "bins": int((4 * np.pi) // dictionary["theta_bin_data"]["step_size"]) },
        "phi":           { "min": -2 * np.pi,                        "max": 2 * np.pi,                         "bins": int((4 * np.pi) // dictionary["phi_bin_data"]["step_size"]) },
    }

    df1 = pd.read_csv(real_leading_test_particles_filename, sep=" ", names=columns, engine="c", header=None)
    df2 = pd.read_csv(sampled_leading_particles_filename, sep=" ", names=columns, engine="c", header=None)
    return bin_settings, df1, df2

def generate_distributions(model_name, column_name, ax=None):
    bin_settings, df1, df2 = get_common_data(model_name)
    
    min_val = bin_settings[column_name]['min']
    max_val = bin_settings[column_name]['max']
    bins = bin_settings[column_name]['bins']
    
    df1_weights = np.ones_like(df1[column_name]) / len(df1[column_name])
    df2_weights = np.ones_like(df2[column_name]) / len(df2[column_name])
    
    ax = ax or plt
    ax.hist(df1[column_name], bins=bins, weights=df1_weights, range=(min_val, max_val), edgecolor="black", alpha=0.7, color="blue", label=f'Input ({model_name})')
    ax.hist(df2[column_name], bins=bins, weights=df2_weights, range=(min_val, max_val), edgecolor="black", alpha=0.7, color="orange", label=f'Sampled ({model_name})')
    if ax is not plt:
        ax.set_xlabel(column_name)
        ax.set_ylabel('Frequency (Normalized)')
        ax.set_title(f'{model_name}')
        ax.legend()

def compare_distributions(models_to_compare, column_name, juxtaposed=True, dists_per_row=3):
    if juxtaposed:
        num_horizontal, num_vertical = min(len(models_to_compare), dists_per_row), (math.ceil(len(models_to_compare) / dists_per_row))
        figure, axes = plt.subplots(num_vertical, num_horizontal, figsize=(8 * num_horizontal, 6 * num_vertical), sharex=False, sharey=True)
        if len(models_to_compare) == 1:
            axes = [axes]
        axes = np.atleast_1d(axes).flatten()
        for model_name, ax in zip(models_to_compare, axes):
            generate_distributions(model_name, column_name=column_name, ax=ax)
        figure.suptitle(f'Training Progress for {", ".join(models_to_compare)}')
        plt.tight_layout()
        plt.grid(axis="y", linestyle="--", alpha=0.7)
        plt.show()
    else:
        plt.figure(figsize=(15, 6))
        for model_name in models_to_compare:
            generate_distributions(model_name, column_name=column_name)
        plt.title(f'Training Progress for {", ".join(models_to_compare)}')
        plt.xlabel('Iteration')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(axis="y", linestyle="--", alpha=0.7)

### <span style="color:#db7d60">Comparisions</span>

In [None]:
def get_train_data_as_row(model_name):
    training_log_filename = pUtil.get_training_dir(model_name) / "train_log_1.jsonl"
    meta_filename = pUtil.get_model_meta_filename(model_name)
    config_filename = pUtil.get_model_config_filename(model_name)
    
    with open(meta_filename, 'rb') as meta_file:
        meta_data = pickle.load(meta_file)
        max_sequence_length = meta_data['max_sequence_length']
        num_train_tokens = meta_data['num_train_tokens']

    with open(config_filename, 'r') as config_file:
        config = json.load(config_file)
        training_config = config.get('training_config', {})
        batch_size = training_config.get('batch_size', -1)
        block_size = training_config.get('block_size', -1)
        context_events = training_config.get('context_events', -1)
        learning_rate = training_config.get('learning_rate', -1)
        lr_decay_iters = training_config.get('lr_decay_iters', -1)
        min_lr = training_config.get('min_lr', -1)
        n_layer = training_config.get('n_layer', -1)
        n_head = training_config.get('n_head', -1)
        n_embd = training_config.get('n_embd', -1)
        
        if block_size == -1:
            block_size = context_events * max_sequence_length
        
        iterations_per_epoch = num_train_tokens // (batch_size * block_size)
        
    running_data, saved_data = [], []
    with open(training_log_filename) as training_log_file:
        for jline in training_log_file:
            jdata = json.loads(jline)
            if jdata.get("message") == "Training progress" and "iter" in jdata:
                current_epochs_trained = 0 if jdata['iter'] == 0 else (jdata['iter'] / iterations_per_epoch)
                running_data.append({'iter': jdata["iter"], 'epoch': current_epochs_trained, 'train_loss': jdata["train_loss"], 'val_loss': jdata["val_loss"]})
            elif jdata.get("message") == "Training progress: checking checkpoint conditions":
                current_epochs_trained = 0 if jdata['step'] == 0 else (jdata['step'] / iterations_per_epoch)
                saved_data.append({'iter': jdata["step"], 'epoch': current_epochs_trained, 'train_loss': jdata["train_loss"], 'val_loss': jdata["val_loss"]})
                
    running_df = pd.DataFrame(running_data)
    saved_df = pd.DataFrame(saved_data)
    
    iters_trained = running_df['iter'].max()
    min_saved_val_loss_row = saved_df.loc[saved_df['val_loss'].idxmin()]
    
    return {
        'model_name': model_name,
        "iters_trained": iters_trained,
        'learning_rate': learning_rate,
        'min_lr': min_lr,
        'lr_decay_iters': lr_decay_iters,
        'n_layer': n_layer,
        'n_head': n_head,
        'n_embd': n_embd,
        'min_saved_train_loss': min_saved_val_loss_row['train_loss'],
        'min_saved_val_loss': min_saved_val_loss_row['val_loss']
    }
    
models_to_compare = [
    'model_5_3_1',
    'model_5_3_2', 
    'model_5_3_3',
    'model_5_3_4',
    'model_5_3_5',
    'model_5_3_6',
    'model_5_3_7',
    'model_5_3_8', 
    'model_5_3_9',
    'model_5_3_10',
    'model_5_3_11',
    'model_5_3_12',
    'model_5_3_13', 
    'model_5_3_14',
    'model_5_3_15',
    'model_5_3_16',
    'model_5_3_17',
    'model_5_3_18',
    'model_5_3_19',
    'model_5_3_20',
    'model_5_3_21'
]

columns = ["model_name", "iters_trained", "learning_rate", "min_lr", "lr_decay_iters", "n_layer", "n_head", "n_embd", "min_saved_train_loss", "min_saved_val_loss"]
model_data_list = []
for model_name in models_to_compare:
    row = get_train_data_as_row(model_name)
    model_data_list.append(row)
    
model_data_df = pd.DataFrame(model_data_list, columns=columns)
model_data_df = model_data_df.applymap(lambda x: f"{x:,}" if isinstance(x, int) else x)

df_sorted_by_min_val_loss = model_data_df.sort_values(by="min_saved_val_loss", ascending=True)

display(df_sorted_by_min_val_loss)

##### Comparing different learning rates and their effect on validation loss and the model distributions

In [None]:
def get_train_data_for_model(model_name, ax=None, use_epochs=True):
    training_log_filename = pUtil.get_training_dir(model_name) / "train_log_1.jsonl"
    meta_filename = pUtil.get_model_meta_filename(model_name)
    config_filename = pUtil.get_model_config_filename(model_name)
    
    with open(meta_filename, 'rb') as meta_file:
        meta_data = pickle.load(meta_file)
        max_sequence_length = meta_data['max_sequence_length']
        num_train_tokens = meta_data['num_train_tokens']

    with open(config_filename, 'r') as config_file:
        config = json.load(config_file)
        training_config = config.get('training_config', {})
        batch_size = training_config.get('batch_size', -1)
        block_size = training_config.get('block_size', -1)
        context_events = training_config.get('context_events', -1)
        learning_rate = training_config.get('learning_rate', -1)
        lr_decay_iters = training_config.get('lr_decay_iters', -1)
        min_lr = training_config.get('min_lr', -1)
        
        if block_size == -1:
            block_size = context_events * max_sequence_length
        
        iterations_per_epoch = num_train_tokens // (batch_size * block_size)

    running_data, saved_data = {}, {}
    with open(training_log_filename) as training_log_file:
        for jline in training_log_file:
            jdata = json.loads(jline)
            if jdata.get("message") == "Training progress" and "iter" in jdata:
                current_epochs_trained = 0 if jdata['iter'] == 0 else (jdata['iter'] / iterations_per_epoch)
                running_data[jdata["iter"]] = current_epochs_trained, jdata["train_loss"], jdata["val_loss"]
            elif jdata.get("message") == "Training progress: checking checkpoint conditions":
                current_epochs_trained = 0 if jdata['step'] == 0 else (jdata['step'] / iterations_per_epoch)
                saved_data[jdata["step"]] = current_epochs_trained, jdata["train_loss"], jdata["val_loss"]

    sorted_running_iters = sorted(running_data)
    running_epochs_trained, running_train_loss, running_val_loss = zip(*(running_data[i] for i in sorted_running_iters))

    saved_saved_iters = sorted(saved_data)
    saved_epochs_trained, saved_train_loss, saved_val_loss = zip(*(saved_data[i] for i in saved_saved_iters))

    # Find min loss points
    min_val_loss_iter_idx = saved_val_loss.index(min(saved_val_loss))
    min_saved_losses_iter = saved_saved_iters[min_val_loss_iter_idx]
    min_saved_val_loss_epoch, min_saved_train_loss, min_saved_val_loss = saved_epochs_trained[min_val_loss_iter_idx], saved_train_loss[min_val_loss_iter_idx], saved_val_loss[min_val_loss_iter_idx]

    return learning_rate, min_lr, lr_decay_iters, min_saved_train_loss, min_saved_val_loss

def plot_learning_rate_effects(model_information_dict):
    fig = plt.figure(figsize=(10, 7))
    ax = fig.add_subplot(111, projection='3d')

    learning_rates = []
    min_lrs = []
    val_losses = []
    labels = []
    for model_name, info in model_information_dict.items():
        learning_rates.append(info['learning_rate'])
        min_lrs.append(info['min_lr'])
        val_losses.append(info['min_saved_val_loss'])
        labels.append(model_name)
    
     # Identify the best (lowest) original validation loss
    raw_val_losses = [info['min_saved_val_loss'] for info in model_information_dict.values() if info['learning_rate'] > 0 and info['min_lr'] > 0 and info['min_saved_val_loss'] > 0]
    min_idx = np.argmin(raw_val_losses)
    
    best_x = learning_rates[min_idx]
    best_y = min_lrs[min_idx]
    best_z = val_losses[min_idx]
    best_label = labels[min_idx]

    scatter = ax.scatter(learning_rates, min_lrs, val_losses, c=val_losses, cmap='viridis', s=60)
    ax.scatter([best_x], [best_y], [best_z], color='red', s=100, label=f'Best Model: {best_label}')
    
    for lr, mlr, vl, label in zip(learning_rates, min_lrs, val_losses, labels):
        ax.text(lr, mlr, vl, label, size=8)

    ax.set_xlabel("Learning Rate")
    ax.set_ylabel("Min LR")
    ax.set_zlabel("Min Validation Loss")
    ax.set_title("Effect of Learning Rates on Min Validation Loss")
    fig.colorbar(scatter, ax=ax, label='Min Val Loss')
    ax.legend()
    plt.tight_layout()
    plt.show()

def plot_learning_rate_effects_log10(model_information_dict):
    fig = plt.figure(figsize=(10, 7))
    ax = fig.add_subplot(111, projection='3d')

    learning_rates = []
    min_lrs = []
    val_losses = []
    labels = []
    for model_name, info in model_information_dict.items():
        lr = info['learning_rate']
        mlr = info['min_lr']
        vl = info['min_saved_val_loss']

        # Skip invalid or non-positive values
        if lr > 0 and mlr > 0 and vl > 0:
            learning_rates.append(np.log10(lr))
            min_lrs.append(np.log10(mlr))
            val_losses.append(np.log10(vl))
            labels.append(model_name)

    # Identify the best (lowest) original validation loss
    raw_val_losses = [info['min_saved_val_loss'] for info in model_information_dict.values() if info['learning_rate'] > 0 and info['min_lr'] > 0 and info['min_saved_val_loss'] > 0]
    min_idx = np.argmin(raw_val_losses)
    
    best_x = learning_rates[min_idx]
    best_y = min_lrs[min_idx]
    best_z = val_losses[min_idx]
    best_label = labels[min_idx]

    scatter = ax.scatter(learning_rates, min_lrs, val_losses, c=val_losses, cmap='viridis', s=60)

    # Highlight the best point
    ax.scatter([best_x], [best_y], [best_z], color='red', s=100, label=f'Best Model: {best_label}')

    for x, y, z, label in zip(learning_rates, min_lrs, val_losses, labels):
        ax.text(x, y, z, label, size=8)

    # Axis labels in log10 terms
    ax.set_xlabel("log10(Learning Rate)")
    ax.set_ylabel("log10(Min LR)")
    ax.set_zlabel("log10(Min Validation Loss)")
    ax.set_title("Effect of Learning Rates on Min Validation Loss")
    fig.colorbar(scatter, ax=ax, label='log10(Min Val Loss)')
    ax.legend()
    plt.tight_layout()
    plt.show()

def create_model_info_table(model_information_dict):
    # Sort model data by min_saved_val_loss
    sorted_items = sorted(
        model_information_dict.items(),
        key=lambda x: x[1]['min_saved_val_loss']
    )

    # Prepare table data
    columns = ["Model Name", "Learning Rate", "Min LR", "Decay Iters", "Avg Learning Rate", "Min Val Loss"]
    cell_data = [
        [
            model_name,
            f"{info['learning_rate']:.2e}",
            f"{info['min_lr']:.2e}",
            f"{info['lr_decay_iters']:.5f}",
            f"{(info['learning_rate'] - info['min_lr']) / info['lr_decay_iters']:.2e}",
            f"{info['min_saved_val_loss']:.5f}"
        ]
        for model_name, info in sorted_items
    ]

    fig, ax = plt.subplots(figsize=(10, len(cell_data) * 0.4 + 1))
    ax.axis('off')

    table = ax.table(
        cellText=cell_data,
        colLabels=columns,
        cellLoc='center',
        loc='center'
    )

    table.auto_set_font_size(False)
    table.set_fontsize(10)
    table.scale(1.2, 1.2)
    plt.title("Model Training Summary (Sorted by Min Val Loss)", fontsize=14, pad=20)
    plt.tight_layout()
    plt.show()

models_to_compare = ['model_5_3_1', 'model_5_3_2', 'model_5_3_3', 'model_5_3_4', 'model_5_3_8', 
                     'model_5_3_9', 'model_5_3_10', 'model_5_3_11', 'model_5_3_12', 'model_5_3_13', 
                     'model_5_3_14', 'model_5_3_15', 'model_5_3_16', 'model_5_3_17', 'model_5_3_18']
model_information_dict = {}
for model_name in models_to_compare:
    learning_rate, min_lr, lr_decay_iters, min_saved_train_loss, min_saved_val_loss = get_train_data_for_model(model_name)
    model_information_dict[model_name] = {
        "learning_rate": learning_rate,
        "min_lr": min_lr,
        "lr_decay_iters": lr_decay_iters,
        "min_saved_train_loss": min_saved_train_loss,
        "min_saved_val_loss": min_saved_val_loss,
    }

summary_stats = {}
value_keys = next(iter(model_information_dict.values())).keys()
for key in value_keys:
    all_values = [info[key] for info in model_information_dict.values()]
    summary_stats[key] = {
        'min': min(all_values),
        'max': max(all_values)
    }

for key, stats in summary_stats.items():
    print(f"{key}:")
    print(f"  Min: {stats['min']}")
    print(f"  Max: {stats['max']}")

models_to_compare = ['model_5_3_1', 'model_5_3_2', 'model_5_3_3', 'model_5_3_4', 'model_5_3_5', 'model_5_3_6', 'model_5_3_7', 'model_5_3_8', 'model_5_3_9', 'model_5_3_10', 'model_5_3_11', 'model_5_3_12',]
plot_train_graphs(models_to_compare, juxtaposed=False, use_epochs=False, use_log_scale=True)
models_to_compare = ['model_5_3_9', 'model_5_3_10', 'model_5_3_12']
plot_train_graphs(models_to_compare, juxtaposed=False, use_epochs=False, y_lim=[0.78, 0.85], use_log_scale=True)

print(json.dumps(model_information_dict, indent=4))
create_model_info_table(model_information_dict)
plot_learning_rate_effects(model_information_dict)
plot_learning_rate_effects_log10(model_information_dict)

In [None]:
models_to_compare = ['model_5_2_10', 'model_5_2_8', 'model_5_2_7', 'model_5_2_6', 'model_5_2_5', 'model_5_2_4', 'model_5_2_3', 'model_5_2_2', 'model_5_2_1']
for column in columns:
    if column == 'pdgid':
        continue
    compare_distributions(models_to_compare, column_name=column, juxtaposed=True, dists_per_row=5)

##### Comparing different number of layers and their effect on validation loss and model distributions

In [None]:
def get_train_data_for_model(model_name, ax=None, use_epochs=True):
    training_log_filename = pUtil.get_training_dir(model_name) / "train_log_1.jsonl"
    meta_filename = pUtil.get_model_meta_filename(model_name)
    config_filename = pUtil.get_model_config_filename(model_name)
    
    with open(meta_filename, 'rb') as meta_file:
        meta_data = pickle.load(meta_file)
        max_sequence_length = meta_data['max_sequence_length']
        num_train_tokens = meta_data['num_train_tokens']

    with open(config_filename, 'r') as config_file:
        config = json.load(config_file)
        training_config = config.get('training_config', {})
        batch_size = training_config.get('batch_size', -1)
        block_size = training_config.get('block_size', -1)
        context_events = training_config.get('context_events', -1)
        learning_rate = training_config.get('learning_rate', -1)
        lr_decay_iters = training_config.get('lr_decay_iters', -1)
        num_layers = training_config.get('n_layer', -1)
        min_lr = training_config.get('min_lr', -1)
        
        if block_size == -1:
            block_size = context_events * max_sequence_length
        
        iterations_per_epoch = num_train_tokens // (batch_size * block_size)

    running_data, saved_data = {}, {}
    with open(training_log_filename) as training_log_file:
        for jline in training_log_file:
            jdata = json.loads(jline)
            if jdata.get("message") == "Training progress" and "iter" in jdata:
                current_epochs_trained = 0 if jdata['iter'] == 0 else (jdata['iter'] / iterations_per_epoch)
                running_data[jdata["iter"]] = current_epochs_trained, jdata["train_loss"], jdata["val_loss"]
            elif jdata.get("message") == "Training progress: checking checkpoint conditions":
                current_epochs_trained = 0 if jdata['step'] == 0 else (jdata['step'] / iterations_per_epoch)
                saved_data[jdata["step"]] = current_epochs_trained, jdata["train_loss"], jdata["val_loss"]

    sorted_running_iters = sorted(running_data)
    running_epochs_trained, running_train_loss, running_val_loss = zip(*(running_data[i] for i in sorted_running_iters))

    saved_saved_iters = sorted(saved_data)
    saved_epochs_trained, saved_train_loss, saved_val_loss = zip(*(saved_data[i] for i in saved_saved_iters))

    # Find min loss points
    min_val_loss_iter_idx = saved_val_loss.index(min(saved_val_loss))
    min_saved_losses_iter = saved_saved_iters[min_val_loss_iter_idx]
    min_saved_val_loss_epoch, min_saved_train_loss, min_saved_val_loss = saved_epochs_trained[min_val_loss_iter_idx], saved_train_loss[min_val_loss_iter_idx], saved_val_loss[min_val_loss_iter_idx]

    num_iters = len(sorted_running_iters)
    return num_iters, learning_rate, min_lr, lr_decay_iters, min_saved_train_loss, min_saved_val_loss, num_layers

def plot_learning_rate_effects(model_information_dict):
    fig = plt.figure(figsize=(10, 7))
    ax = fig.add_subplot(111, projection='3d')

    learning_rates = []
    min_lrs = []
    val_losses = []
    labels = []
    for model_name, info in model_information_dict.items():
        learning_rates.append(info['learning_rate'])
        min_lrs.append(info['min_lr'])
        val_losses.append(info['min_saved_val_loss'])
        labels.append(model_name)
    
     # Identify the best (lowest) original validation loss
    raw_val_losses = [info['min_saved_val_loss'] for info in model_information_dict.values() if info['learning_rate'] > 0 and info['min_lr'] > 0 and info['min_saved_val_loss'] > 0]
    min_idx = np.argmin(raw_val_losses)
    
    best_x = learning_rates[min_idx]
    best_y = min_lrs[min_idx]
    best_z = val_losses[min_idx]
    best_label = labels[min_idx]

    scatter = ax.scatter(learning_rates, min_lrs, val_losses, c=val_losses, cmap='viridis', s=60)
    ax.scatter([best_x], [best_y], [best_z], color='red', s=100, label=f'Best Model: {best_label}')
    
    for lr, mlr, vl, label in zip(learning_rates, min_lrs, val_losses, labels):
        ax.text(lr, mlr, vl, label, size=8)

    ax.set_xlabel("Learning Rate")
    ax.set_ylabel("Min LR")
    ax.set_zlabel("Min Validation Loss")
    ax.set_title("Effect of Learning Rates on Min Validation Loss")
    fig.colorbar(scatter, ax=ax, label='Min Val Loss')
    ax.legend()
    plt.tight_layout()
    plt.show()

def plot_learning_rate_effects_log10(model_information_dict):
    fig = plt.figure(figsize=(10, 7))
    ax = fig.add_subplot(111, projection='3d')

    learning_rates = []
    min_lrs = []
    val_losses = []
    labels = []
    for model_name, info in model_information_dict.items():
        lr = info['learning_rate']
        mlr = info['min_lr']
        vl = info['min_saved_val_loss']

        # Skip invalid or non-positive values
        if lr > 0 and mlr > 0 and vl > 0:
            learning_rates.append(np.log10(lr))
            min_lrs.append(np.log10(mlr))
            val_losses.append(np.log10(vl))
            labels.append(model_name)

    # Identify the best (lowest) original validation loss
    raw_val_losses = [info['min_saved_val_loss'] for info in model_information_dict.values() if info['learning_rate'] > 0 and info['min_lr'] > 0 and info['min_saved_val_loss'] > 0]
    min_idx = np.argmin(raw_val_losses)
    
    best_x = learning_rates[min_idx]
    best_y = min_lrs[min_idx]
    best_z = val_losses[min_idx]
    best_label = labels[min_idx]

    scatter = ax.scatter(learning_rates, min_lrs, val_losses, c=val_losses, cmap='viridis', s=60)

    # Highlight the best point
    ax.scatter([best_x], [best_y], [best_z], color='red', s=100, label=f'Best Model: {best_label}')

    for x, y, z, label in zip(learning_rates, min_lrs, val_losses, labels):
        ax.text(x, y, z, label, size=8)

    # Axis labels in log10 terms
    ax.set_xlabel("log10(Learning Rate)")
    ax.set_ylabel("log10(Min LR)")
    ax.set_zlabel("log10(Min Validation Loss)")
    ax.set_title("Effect of Learning Rates on Min Validation Loss")
    fig.colorbar(scatter, ax=ax, label='log10(Min Val Loss)')
    ax.legend()
    plt.tight_layout()
    plt.show()

def create_model_info_table(model_information_dict):
    # Sort model data by min_saved_val_loss
    sorted_items = sorted(
        model_information_dict.items(),
        key=lambda x: x[1]['min_saved_val_loss']
    )

    # Prepare table data
    columns = ["Model Name", "Num Iters", "Learning Rate", "Min LR", "Decay Iters", "Avg Learning Rate", "Num Layers", "Min Val Loss"]
    cell_data = [
        [
            model_name,
            f"{info['num_iters']}",
            f"{info['learning_rate']:.2e}",
            f"{info['min_lr']:.2e}",
            f"{info['lr_decay_iters']:.5f}",
            f"{(info['learning_rate'] - info['min_lr']) / info['lr_decay_iters']:.2e}",
            f"{info['num_layers']}",
            f"{info['min_saved_val_loss']:.5f}"
        ]
        for model_name, info in sorted_items
    ]

    fig, ax = plt.subplots(figsize=(10, len(cell_data) * 0.4 + 1))
    ax.axis('off')

    table = ax.table(
        cellText=cell_data,
        colLabels=columns,
        cellLoc='center',
        loc='center'
    )

    table.auto_set_font_size(False)
    table.set_fontsize(10)
    table.scale(1.2, 1.2)
    plt.title("Model Training Summary (Sorted by Min Val Loss)", fontsize=14, pad=20)
    plt.tight_layout()
    plt.show()

models_to_compare = ['model_5_3_1', 'model_5_3_2', 'model_5_3_3', 'model_5_3_4', 'model_5_3_8', 
                     'model_5_3_9', 'model_5_3_10', 'model_5_3_11', 'model_5_3_12', 'model_5_3_13', 
                     'model_5_3_14', 'model_5_3_15', 'model_5_3_16', 'model_5_3_17', 'model_5_3_18']
model_information_dict = {}
for model_name in models_to_compare:
    num_iters, learning_rate, min_lr, lr_decay_iters, min_saved_train_loss, min_saved_val_loss, num_layers = get_train_data_for_model(model_name)
    model_information_dict[model_name] = {
        "num_iters": num_iters,
        "learning_rate": learning_rate,
        "min_lr": min_lr,
        "lr_decay_iters": lr_decay_iters,
        "min_saved_train_loss": min_saved_train_loss,
        "min_saved_val_loss": min_saved_val_loss,
        "num_layers": num_layers
    }

summary_stats = {}
value_keys = next(iter(model_information_dict.values())).keys()
for key in value_keys:
    all_values = [info[key] for info in model_information_dict.values()]
    summary_stats[key] = {
        'min': min(all_values),
        'max': max(all_values)
    }

for key, stats in summary_stats.items():
    print(f"{key}:")
    print(f"  Min: {stats['min']}")
    print(f"  Max: {stats['max']}")

# models_to_compare = ['model_5_3_1', 'model_5_3_2', 'model_5_3_3', 'model_5_3_4', 'model_5_3_5', 'model_5_3_6', 'model_5_3_7', 'model_5_3_8', 'model_5_3_9', 'model_5_3_10', 'model_5_3_11', 'model_5_3_12',]
# plot_train_graphs(models_to_compare, juxtaposed=False, use_epochs=False, use_log_scale=True)
# models_to_compare = ['model_5_3_9', 'model_5_3_10', 'model_5_3_12']
# plot_train_graphs(models_to_compare, juxtaposed=False, use_epochs=False, y_lim=[0.78, 0.85], use_log_scale=True)

print(json.dumps(model_information_dict, indent=4))
create_model_info_table(model_information_dict)
plot_learning_rate_effects(model_information_dict)
plot_learning_rate_effects_log10(model_information_dict)

##### Investigating the broken eta distribution

In [None]:
import csv

columns = ["num_particles", "pdgid", "e", "px", "py", "pz", "eta", "theta", "phi"]

def get_common_data(model_name):
    dictionary_filename = pUtil.get_model_preparation_dir(model_name) / 'dictionary.json'
    # real_leading_test_particles_filename = pUtil.get_model_preparation_dir(model_name) / 'real_leading_test_particles.csv'
    real_leading_test_particles_filename = pUtil.get_temp_dir() / 'real_leading_test_particles.txt'
    sampled_leading_particles_filename = pUtil.get_latest_sampling_dir(model_name) / 'sampled_leading_particles.csv'

    with open(dictionary_filename) as dictionary_file:
        dictionary = json.load(dictionary_file)

    # Convenience dictionary definitions
    p_bin_count = (dictionary["e_bin_data"]["max"] - dictionary["e_bin_data"]["min"]) // 1000
    e_bin_count = (dictionary["e_bin_data"]["max"] - dictionary["e_bin_data"]["min"]) // dictionary["e_bin_data"]["step_size"]
    eta_bin_count = int((dictionary["eta_bin_data"]["max"] - dictionary["eta_bin_data"]["min"]) // dictionary["eta_bin_data"]["step_size"])
            
    bin_settings = {
        "num_particles": { "min": 0,                                 "max": 50,                                "bins": 50 },
        "e":             { "min": dictionary["e_bin_data"]["min"],   "max": dictionary["e_bin_data"]["max"],   "bins": e_bin_count },
        "px":            { "min": dictionary["e_bin_data"]["min"],   "max": dictionary["e_bin_data"]["max"],   "bins": p_bin_count },
        "py":            { "min": dictionary["e_bin_data"]["min"],   "max": dictionary["e_bin_data"]["max"],   "bins": p_bin_count },
        "pz":            { "min": dictionary["e_bin_data"]["min"],   "max": dictionary["e_bin_data"]["max"],   "bins": p_bin_count },
        "eta":           { "min": dictionary["eta_bin_data"]["min"], "max": dictionary["eta_bin_data"]["max"], "bins": eta_bin_count },
        "theta":         { "min": -2 * np.pi,                        "max": 2 * np.pi,                         "bins": int((4 * np.pi) // dictionary["theta_bin_data"]["step_size"]) },
        "phi":           { "min": -2 * np.pi,                        "max": 2 * np.pi,                         "bins": int((4 * np.pi) // dictionary["phi_bin_data"]["step_size"]) },
    }

    df1 = pd.read_csv(real_leading_test_particles_filename, sep=" ", names=columns, engine="c", header=None)
    df2 = pd.read_csv(sampled_leading_particles_filename, sep=" ", names=columns, engine="c", header=None)
    return bin_settings, df1, df2

def generate_distributions(model_name, column_name, ax=None):
    bin_settings, df1, df2 = get_common_data(model_name)
    
    min_val = bin_settings[column_name]['min']
    max_val = bin_settings[column_name]['max']
    bins = bin_settings[column_name]['bins']
    
    df1_weights = np.ones_like(df1[column_name]) / len(df1[column_name])
    df2_weights = np.ones_like(df2[column_name]) / len(df2[column_name])
    
    ax = ax or plt
    ax.hist(df1[column_name], bins=bins, weights=df1_weights, range=(min_val, max_val), edgecolor="black", alpha=0.7, color="blue", label=f'Input ({model_name})')
    ax.hist(df2[column_name], bins=bins, weights=df2_weights, range=(min_val, max_val), edgecolor="black", alpha=0.7, color="orange", label=f'Sampled ({model_name})')
    if ax is not plt:
        ax.set_xlabel(column_name)
        ax.set_ylabel('Frequency (Normalized)')
        ax.set_title(f'{model_name}')
        ax.legend()

def compare_distributions(models_to_compare, column_name, juxtaposed=True, dists_per_row=3):
    if juxtaposed:
        num_horizontal, num_vertical = min(len(models_to_compare), dists_per_row), (math.ceil(len(models_to_compare) / dists_per_row))
        figure, axes = plt.subplots(num_vertical, num_horizontal, figsize=(8 * num_horizontal, 6 * num_vertical), sharex=False, sharey=True)
        if len(models_to_compare) == 1:
            axes = [axes]
        axes = np.atleast_1d(axes).flatten()
        for model_name, ax in zip(models_to_compare, axes):
            generate_distributions(model_name, column_name=column_name, ax=ax)
        figure.suptitle(f'Training Progress for {", ".join(models_to_compare)}')
        plt.tight_layout()
        plt.grid(axis="y", linestyle="--", alpha=0.7)
        plt.show()
    else:
        plt.figure(figsize=(15, 6))
        for model_name in models_to_compare:
            generate_distributions(model_name, column_name=column_name)
        plt.title(f'Training Progress for {", ".join(models_to_compare)}')
        plt.xlabel('Iteration')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(axis="y", linestyle="--", alpha=0.7)
        
models_to_compare = ['model_5_2_1']
for column in columns:
    if column == 'pdgid':
        continue
    compare_distributions(models_to_compare, column_name=column, juxtaposed=True, dists_per_row=5)

The eta distribution seems to be fine. There is something about the tokenizer that creates the brokenness in the distributions.
Furthermore this shows we should probably start comparing tokenized to tokenized not real to tokenized.

### <span style="color:#db7d60">Keeping track of things</span>
These are just general plots I've made to analyze data.

In [None]:
models_to_compare = ['model_5_3_13', 'model_5_3_14', 'model_5_3_15', 'model_5_3_16', 'model_5_3_18', 'model_5c1_1_1']
plot_train_graphs(models_to_compare, juxtaposed=False, use_epochs=False, use_log_scale=True, y_lim=[0.78, 0.85])
plot_train_graphs(models_to_compare, juxtaposed=False, use_epochs=False, use_log_scale=True)

models_to_compare = ['model_5_2_10']
compare_distributions(models_to_compare, column_name='eta', juxtaposed=True, dists_per_row=5)

In [None]:
models_to_compare = ['model_5c1_1_1', 'model_5c1_1_2']
plot_train_graphs(models_to_compare, juxtaposed=False, use_epochs=False, use_log_scale=True, y_lim=[1.05, 1.15])

Create a train graph for all the models.

In [None]:
models_to_compare = [
    'model_5_3_1',
    'model_5_3_2',
    'model_5_3_3',
    'model_5_3_4',
    'model_5_3_5',
    'model_5_3_6',
    'model_5_3_7',
    'model_5_3_8',
    'model_5_3_9',
    'model_5_3_10',
    'model_5_3_11',
    'model_5_3_12',
    'model_5_3_13',
    'model_5_3_14',
    'model_5_3_15',
    'model_5_3_16',
    'model_5_3_17',
    'model_5_3_18',
    'model_5_3_19',
    'model_5_3_20',
    'model_5_3_21'
]

plot_train_graphs(models_to_compare, juxtaposed=False, use_epochs=False, use_log_scale=True)
plot_train_graphs(models_to_compare, juxtaposed=False, use_epochs=False, use_log_scale=True, y_lim=[0.7, 1])

Tracking models without particle boundaries.

In [None]:
models_to_compare = [
    'model_5_5_1',
    'model_5_5_2',
    'model_5_5_3'
]

plot_train_graphs(models_to_compare, juxtaposed=False, use_epochs=False, use_log_scale=True)

ValueError: not enough values to unpack (expected 3, got 0)

<Figure size 1500x600 with 0 Axes>