In [None]:
# Reload edited project modules automatically before each cell runs, so
# changes to imported .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import random
import pickle
import json
import math
from pathlib import Path
from functools import partial

import torch
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from lightning import fabric
from transformer_lens import (
    HookedTransformer,
    HookedTransformerConfig,
    FactoredMatrix,
    ActivationCache,
)
from transformer_lens import utils
from sklearn.linear_model import LogisticRegression, LinearRegression
from sklearn.multioutput import MultiOutputClassifier, MultiOutputRegressor
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.metrics import log_loss
from einops import einsum, rearrange, unpack, repeat
from matplotlib import cm, colors
from tqdm import tqdm

from tic_tac_gpt.data import TicTacToeDataset, TicTacToeState, tensor_to_state

In [None]:
# Inference-only notebook: switch autograd off globally.
torch.set_grad_enabled(False)

# Create new tensors on the GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    torch.set_default_device("cuda")
else:
    torch.set_default_device("cpu")

In [None]:
# Directory holding the pickled model config and the per-step checkpoints.
checkpoint_dir = Path("out/model/exp24")

# NOTE: unpickling can execute arbitrary code; only load config.pkl files
# that come from a trusted source.
with (checkpoint_dir / "config.pkl").open("rb") as f:
    config: HookedTransformerConfig = pickle.load(f)

# Fabric instance that `load_checkpoint` uses to read checkpoint files.
F = fabric.Fabric(precision="16-mixed")


def load_checkpoint(step: int):
    """Build a model from `config` and load the weights stored for `step`.

    Reads `model_{step}.pt` from `checkpoint_dir` through the Fabric
    instance `F`, loads it as the state dict of a freshly constructed
    `HookedTransformer`, and returns that model switched to eval mode.
    """
    weights = F.load(checkpoint_dir / f"model_{step}.pt")
    net = HookedTransformer(config)
    net.load_state_dict(weights)
    # `eval()` returns the module itself, so this hands back `net` in eval mode.
    return net.eval()

In [None]:
def weight_norm(model):
    """Return the L2 norm of all of `model`'s parameters taken as one vector.

    Every tensor yielded by `model.parameters()` is flattened, the pieces are
    concatenated, and the 2-norm of the result is returned as a Python float.

    Returns 0.0 when the model has no parameters.
    """
    # `reshape` (unlike `view`) also works for non-contiguous parameters such
    # as transposed weights; `detach` keeps autograd out of a pure measurement.
    flat_params = [p.detach().reshape(-1) for p in model.parameters()]
    if not flat_params:
        # `torch.cat` rejects an empty list; the norm of "nothing" is zero.
        return 0.0
    return torch.cat(flat_params).norm().item()


# Checkpoint steps to inspect: 1000, 2000, ..., 39000.
checkpoint_steps = range(1000, 40000, 1000)

# Overall parameter norm of the model at each of those checkpoints.
norms = [weight_norm(load_checkpoint(step)) for step in checkpoint_steps]

In [None]:
# Plot the parameter norm against checkpoint step; the x values must match
# the steps used to build `norms`.
ax = sns.lineplot(x=range(1000, 40000, 1000), y=norms, marker="o")
# Label the figure so it stands alone; the trailing `;` hides the repr echo.
ax.set(
    xlabel="Checkpoint step",
    ylabel="Parameter L2 norm",
    title="Weight norm across checkpoints",
);