In [None]:
import wandb
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# W&B API client; later cells call api.run(<run path>).history() to download logged metrics.
api = wandb.Api()


In [None]:
def plot_cls_runs(loss_data, acc_data, keys, figsize=(12, 5)):
    """
    Plot train/loss (left) and val/accuracy (right) for the given run keys.

    Parameters
    ----------
    loss_data: dict[key] -> list of train/loss values (per step)
    acc_data: dict[key] -> list of val/accuracy values (per epoch)
    keys: list of key names to plot
    figsize: figure size forwarded to plt.subplots

    Returns
    -------
    (fig, (ax_loss, ax_acc)) so the caller can tweak or save the figure.

    Notes
    -----
    A key that is absent from loss_data or acc_data is skipped, and a
    warning names the key and the dict(s) it is missing from, so a
    misspelled key does not silently disappear from the figure.
    """
    import warnings

    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=figsize)

    plotted = 0
    for key in keys:
        missing_from = [
            name
            for name, table in (("loss_data", loss_data), ("acc_data", acc_data))
            if key not in table
        ]
        if missing_from:
            warnings.warn(
                f"plot_cls_runs: skipping {key!r} (not in {', '.join(missing_from)})",
                stacklevel=2,
            )
            continue

        losses = loss_data[key]
        accs = acc_data[key]
        # Left: train loss (x = position in the list, starting at 0)
        ax_loss.plot(range(len(losses)), losses, label=key)
        # Right: val accuracy (x = epoch 1, 2, ...)
        epochs = list(range(1, len(accs) + 1))
        ax_acc.plot(epochs, accs, "o-", label=key)
        # Write each accuracy value just above its marker.
        for ep, val in zip(epochs, accs):
            ax_acc.annotate(f"{val:.1f}", (ep, val), textcoords="offset points", xytext=(0, 6), ha="center", fontsize=8)
        plotted += 1

    ax_loss.set_xlabel("Step")
    ax_loss.set_ylabel("train/loss")
    ax_loss.set_title("Train Loss")
    ax_loss.grid(True, alpha=0.3)

    ax_acc.set_xlabel("Epoch")
    ax_acc.set_ylabel("val/accuracy")
    ax_acc.set_title("Val Accuracy")
    ax_acc.grid(True, alpha=0.3)

    # A legend is only meaningful when at least one labelled line exists.
    if plotted:
        ax_loss.legend()
        ax_acc.legend()

    plt.tight_layout()
    return fig, (ax_loss, ax_acc)

In [None]:
# Experiment name -> W&B run path handed to api.run().
cls_runs = {
    "Baseline": "chentianyi453/CSE256_PA2_CLS/rdol824l",
    "RoPE": "chentianyi453/CSE256_PA2_CLS/78g1y84e",
    "DisentangledAttn": "chentianyi453/CSE256_PA2_CLS/17qoo6hl",
    "[CLS] Token": "chentianyi453/CSE256_PA2_CLS/xtpz3nln"
}
# history() returns one table for all logged metrics, so rows where train/loss
# was not logged hold NaN in that column; drop them so the loss curve is drawn
# as one continuous line instead of being broken into gaps.
loss_data = {
    exp_name: api.run(run).history()["train/loss"].dropna().to_list()
    for exp_name, run in cls_runs.items()
}
# Show how many loss values each run has rather than printing every value.
{exp_name: len(losses) for exp_name, losses in loss_data.items()}

In [None]:
# Per-epoch validation accuracy, read from a CSV with one column per experiment.
data = pd.read_csv("/home/tianyichen/cse256/UCSD-CSE256-2026WI-PA2/cls_val_acc.csv")
# Experiment key (same keys as loss_data) -> column name in the CSV.
acc_columns = {
    "Baseline": "Baseline - val/accuracy",
    "RoPE": "RoPE - val/accuracy",
    "DisentangledAttn": "DisentangledAttn - val/accuracy",
    "[CLS] Token": "[CLS] Token - val/accuracy",
}
acc_data = {exp_name: data[column].tolist() for exp_name, column in acc_columns.items()}
acc_data

In [None]:
# LM: train loss (left panel) + train perplexity (right panel; 5 points at 100,200,300,400,500)
train_ppl_df = pd.read_csv("/home/tianyichen/cse256/UCSD-CSE256-2026WI-PA2/train_ppl.csv")
steps_ppl = train_ppl_df["Step"].astype(int).tolist()
train_ppl = train_ppl_df["Baseline - train/perplexity"].tolist()
# train_loss: from wandb. Rows of the history table where train/loss was not
# logged hold NaN, which would break the plotted line into gaps, so drop them
# first and then keep at most the first 500 values.
runs_lm = ["chentianyi453/CSE256_PA2_LM/iubeig97"]
hist = api.run(runs_lm[0]).history()
train_loss = hist["train/loss"].dropna().tolist()[:500]

fig, (ax_loss, ax_ppl) = plt.subplots(1, 2, figsize=(12, 5))
# Left: loss against its position in the list (0, 1, 2, ...).
ax_loss.plot(range(len(train_loss)), train_loss)
ax_loss.set_xlabel("Step")
ax_loss.set_ylabel("train/loss")
ax_loss.set_title("Train Loss")
ax_loss.grid(True, alpha=0.3)

# Right: perplexity against the step values from the CSV, each point labelled with its value.
ax_ppl.plot(steps_ppl, train_ppl, "o-")
for s, v in zip(steps_ppl, train_ppl):
    ax_ppl.annotate(f"{v:.1f}", (s, v), textcoords="offset points", xytext=(0, 8), ha="center", fontsize=9)
ax_ppl.set_xlabel("Step")
ax_ppl.set_ylabel("train/perplexity")
ax_ppl.set_title("Train Perplexity")
ax_ppl.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

In [None]:
# LM: the three perplexity CSVs (wbush, obama, hbush) — plotted and annotated further down in this cell
base = "/home/tianyichen/cse256/UCSD-CSE256-2026WI-PA2"
# Read the files in the same order as the names on the left-hand side.
wbush, obama, hbush = (
    pd.read_csv(csv_path)
    for csv_path in (f"{base}/wbush_ppl.csv", f"{base}/obama_ppl.csv", f"{base}/hbush_ppl.csv")
)

def plot_ppl_5(ax, steps, values, title):
    """
    Draw one perplexity curve on `ax` and label every point with its value.

    ax: axes object to draw on
    steps: x values (one per point)
    values: perplexity values, paired with `steps` by position
    title: title placed on the axes
    """
    # Shared styling for the numeric labels drawn just above each marker.
    label_style = dict(textcoords="offset points", xytext=(0, 8), ha="center", fontsize=9)

    ax.plot(steps, values, "o-")
    for xy in zip(steps, values):
        _, v = xy
        ax.annotate(f"{v:.1f}", xy, **label_style)

    ax.set_title(title)
    ax.set_xlabel("Step")
    ax.set_ylabel("Perplexity")
    ax.grid(True, alpha=0.3)

fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(14, 5))
# All three panels use the x values read from the wbush table.
steps = wbush["train/iter"].astype(int).tolist()
# (axes, table, column to plot, panel title) for each of the three panels.
panels = (
    (ax1, wbush, "Baseline - val/perplexity_wbush", "Val Perplexity (W. Bush)"),
    (ax2, obama, "Baseline - val/perplexity_obama", "Val Perplexity (Obama)"),
    (ax3, hbush, "Baseline - val/perplexity_hbush", "Val Perplexity (H. Bush)"),
)
for panel_ax, frame, column, panel_title in panels:
    plot_ppl_5(panel_ax, steps, frame[column].tolist(), panel_title)
plt.tight_layout()
plt.show()

In [None]:
# Compare selected classifier runs; loss_data and acc_data must already be loaded by the cells above.
# To plot every run instead, use:
# keys = ["Baseline", "RoPE", "DisentangledAttn", "[CLS] Token"]
keys = ["Baseline", "RoPE"]
plot_cls_runs(loss_data=loss_data, acc_data=acc_data, keys=keys)
plt.show()