In [None]:
import torch
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import os

# --- Configuration ---
start_num = 20     # first row of each eval tensor to plot (inclusive)
end_num = 105      # last row of each eval tensor to plot (exclusive)
plot_part = "val"  # "train" or "val"; selects eval_{plot_part}_tensor.pt
mode = "tiedw"     # experiment family; selects folders named out_config_{mode}_<model_key>

# Discover run folders programmatically. Sorted so the plotting order (and thus
# matplotlib's default color cycle assignment) is reproducible: os.listdir
# returns entries in arbitrary order.
base_dir = "."
folder_prefix = f"out_config_{mode}_"
folders = sorted(
    folder
    for folder in os.listdir(base_dir)
    if folder.startswith(folder_prefix) and os.path.isdir(os.path.join(base_dir, folder))
)

# model_key -> pd.Series of plotted values (RangeIndex 0..n-1). Filled by the
# loop below and consumed by the relative-difference cell.
folder_dict = {}

# --- 1. Concise Names ---
name_map = {
    "original": "Original (GPT-3 Small)",
    "identity": "W_Q = Identity",
    "identity_larger_mlp": "W_Q = Identity (Larger MLP)",
    "original_smaller_mlp": "Original (Smaller MLP)",
    "original_smaller_overall": "Original (Smaller Model Dim)",
}

# --- 2. Parameter Counts (millions; non-embedding/LM-head params, per the legend title) ---
param_counts = {
    "original": 84.95,
    "identity": 77.88,
    "identity_larger_mlp": 84.95,
    "original_smaller_mlp": 77.88,
    "original_smaller_overall": 79.73,
}

plt.figure(figsize=(12, 7))
for folder in folders:
    file_path = os.path.join(base_dir, folder, f"eval_{plot_part}_tensor.pt")
    if not os.path.exists(file_path):
        # This run has not written the requested eval split; leave it out of the plot.
        continue

    # torch.load unpickles the file, so only point it at trusted run outputs.
    # map_location="cpu" lets tensors saved on a GPU load on a CPU-only machine;
    # weights_only=True restricts unpickling to tensors and plain containers.
    eval_tensor = torch.load(file_path, map_location="cpu", weights_only=True)
    # Rows start_num:end_num of column 1; column 1 is plotted as LogLoss
    # (see the y-label) — TODO confirm the column layout of the saved tensor.
    series = pd.Series(eval_tensor[start_num:end_num, 1].cpu().numpy())

    # Strip the known prefix instead of splitting on f"{mode}_", which would
    # truncate keys that themselves contain that substring.
    model_key = folder[len(folder_prefix):]
    folder_dict[model_key] = series

    # --- 3. Construct Cleaner Label ---
    readable_name = name_map.get(model_key, f"{mode}_{model_key}")
    p_count = param_counts.get(model_key, "?")
    # Just the name and the parameter count
    final_label = f"{readable_name} ({p_count}M)"

    # Offset by start_num so each point sits at its original row number.
    x_values = np.arange(start_num, start_num + len(series))
    plt.plot(x_values, series.tolist(), label=final_label)


# --- 4. Explanation in Legend Title ---
plt.legend(title="Model Configuration\n(Non-embedding/LMHead params)")

plt.title(f"Tied weights Configurations - {plot_part.capitalize()} Loss")
plt.xlabel("Gradient Step (x 1000)")
plt.ylabel("LogLoss")
plt.grid(True, alpha=0.3)
plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt


# Pairwise loss differences, expressed as a percentage of the "original" run's
# loss at the same eval row: 100 * (a - b) / original.
# Each entry is (model_key_a, model_key_b, line_color).
comparisons = [
    ("identity_larger_mlp", "original_smaller_mlp", "red"),
    ("identity", "original_smaller_overall", "blue"),
    ("original", "original_smaller_mlp", "green"),
    ("identity_larger_mlp", "original", "black"),
    ("original", "identity", "orange"),
]

# Dashed horizontal guide lines, in percent. 0 marks "no difference"; the other
# levels are fixed visual references — TODO(review): note why these values.
reference_levels = [0, -0.1, 0.1, -0.35]

fig, ax = plt.subplots(figsize=(12, 6))

for key_a, key_b, color in comparisons:
    missing = [key for key in (key_a, key_b, "original") if key not in folder_dict]
    if missing:
        # A run without eval data was skipped by the loading cell; report it
        # rather than failing the whole figure with a bare KeyError.
        print(f"Skipping {key_a} vs {key_b}: no data for {missing}")
        continue

    rel_diff = 100 * (folder_dict[key_a] - folder_dict[key_b]) / folder_dict["original"]
    # pandas aligns the series on their RangeIndex, so runs with different
    # numbers of eval rows yield NaN (a gap in the line) instead of a length
    # mismatch. Offset by start_num so x matches the loading cell's x axis.
    x_values = start_num + rel_diff.index.to_numpy()
    ax.plot(x_values, rel_diff.to_numpy(), color=color, label=f"{key_a} vs {key_b}")

# Show the axis from 0 even though the data starts at start_num.
ax.set_xlim(left=0)

for level in reference_levels:
    ax.axhline(y=level, color='gray', linestyle='--', linewidth=1)

ax.set_xlabel("Gradient Step (x 1000)")
ax.set_ylabel("Percentage Difference Relative to the Original model's loss(%)")
# Matplotlib does not render markdown, so no "**" emphasis in the title.
ax.set_title(f"Relative {plot_part.capitalize()} LogLoss Differences")
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')  # legend outside the axes for clarity
ax.grid(True, which='both', linestyle='--', alpha=0.5)
fig.tight_layout()

plt.show()
