### Setup

In [1]:
import os, sys
# HOME = os.environ['HOME']  # change if necessary
HOME = '/workspace/'
sys.path.append(f'{HOME}/wilson/Finite-groups/src')

In [22]:
import torch as t
import numpy as np
from matplotlib import pyplot as plt
import json
from itertools import product
from jaxtyping import Float
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
import plotly.graph_objects as go
import copy
import math
from itertools import product
import pandas as pd
from typing import Union
from einops import repeat
from huggingface_hub import snapshot_download
from huggingface_hub.utils import disable_progress_bars
from hmmlearn import hmm


from model import MLP3, MLP4, InstancedModule
from utils import *
from group_data import *
from model_utils import *
from group_utils import *
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
t.set_grad_enabled(False)
device = t.device("cuda" if t.cuda.is_available() else "cpu")

In [42]:
# MODEL_DIR = f'{HOME}/models/2025-01-08_03-15-59_S4_A4x2_MLP2_32_ubias_wd2e-4_BIG_hmmmetrics'
# MODEL_DIR = f'{HOME}/models/2025-01-08_03-15-59_S4_A4x2_MLP2_32_ubias_wd2e-4_BIG_hmmmetrics'
MODEL_DIR = f'{HOME}/models/2025-01-08_22-01-30_S5_A5x2_MLP2_128_ubias_wd2e-5_hmmmetrics'
losses = load_loss_trajectory(MODEL_DIR)
hmm_metrics = load_hmm_trajectory(MODEL_DIR)
hmm_keys = list(hmm_metrics.keys())
hmm_metrics = t.stack([hmm_metrics[k] for k in hmm_keys], dim=1)  # instance hmm_metric epoch

  losses = [t.load(f) for f in loss_files]
  losses = [t.load(f) for f in loss_files]


### Fit HMM
See Hu et al. "Latent state models of training dynamics"

In [43]:
losses['G0_acc'].shape

torch.Size([100, 2001])

In [44]:
ACC_THRESH = 0.99
# DATA_SIZE = 30
DATA_SIZE = 100
g0_grokked = losses['G0_acc'][:,-1] > ACC_THRESH
g1_grokked = losses['G1_acc'][:,-1] > ACC_THRESH
none_grokked = ~(g0_grokked | g1_grokked)
g0_hmm_metrics = hmm_metrics[g0_grokked]
g1_hmm_metrics = hmm_metrics[g1_grokked]
none_hmm_metrics = hmm_metrics[none_grokked]
balanced_hmm_metrics = t.concat([g0_hmm_metrics[:DATA_SIZE], g1_hmm_metrics[:DATA_SIZE], none_hmm_metrics[:DATA_SIZE]], axis=0)
balanced_hmm_metrics = balanced_hmm_metrics[t.randperm(balanced_hmm_metrics.shape[0])]

In [45]:
train_size = int(0.8 * balanced_hmm_metrics.shape[0])
train_data = balanced_hmm_metrics[:train_size].cpu().numpy()
train_lengths = [train_data.shape[-1]] * train_data.shape[0]
train_data = einops.rearrange(train_data, 'instance metric epoch -> (instance epoch) metric')
train_data = (train_data - train_data.mean(axis=0)) / train_data.std(axis=0)

test_data = balanced_hmm_metrics[train_size:].cpu().numpy()
test_lengths = [test_data.shape[-1]] * test_data.shape[0]
test_data = einops.rearrange(test_data, 'instance metric epoch -> (instance epoch) metric')
test_data = (test_data - train_data.mean(axis=0)) / train_data.std(axis=0)

In [46]:
scores = []
bics = []
# cov_type = 'full'
cov_type = 'diag'
for n_components in range(1, 30):
    print('n_components', n_components)
    # best_bic = float('inf')
    best_score = float('-inf')
    for seed in range(5):
        model = hmm.GaussianHMM(n_components=n_components, covariance_type=cov_type, n_iter=1000, random_state=seed)
        model.fit(train_data, lengths=train_lengths)
        score = model.score(test_data, lengths=test_lengths)
        # bic = model.bic(test_data, lengths=test_lengths)   # bic only makes sense to evaluate on train data...
        best_score = max(best_score, score)
        # best_bic = min(best_bic, bic)
    print(f'score: {int(best_score):,}')
    # print('bic', best_bic)
    scores.append(best_score)
    # bics.append(best_bic)

n_components 1


  0%|          | 2/1000 [00:00<00:50, 19.86it/s]
  0%|          | 2/1000 [00:00<00:49, 20.12it/s]
  0%|          | 2/1000 [00:00<00:49, 20.32it/s]
  0%|          | 2/1000 [00:00<00:48, 20.49it/s]
  0%|          | 2/1000 [00:00<00:49, 20.01it/s]


score: -215,053,228,291
n_components 2


  2%|▏         | 22/1000 [00:01<01:26, 11.32it/s]
  2%|▏         | 22/1000 [00:01<01:27, 11.16it/s]
  2%|▏         | 22/1000 [00:01<01:26, 11.26it/s]
  2%|▏         | 22/1000 [00:01<01:26, 11.33it/s]
  2%|▏         | 22/1000 [00:01<01:28, 11.06it/s]


score: -226,407,291,751
n_components 3


  2%|▎         | 25/1000 [00:03<02:18,  7.06it/s]
  2%|▏         | 21/1000 [00:03<02:25,  6.74it/s]
  2%|▏         | 20/1000 [00:02<02:20,  7.00it/s]
  2%|▏         | 21/1000 [00:03<02:21,  6.93it/s]
  2%|▎         | 25/1000 [00:03<02:18,  7.03it/s]


score: -438,905,741,430
n_components 4


  5%|▍         | 46/1000 [00:08<03:02,  5.22it/s]
  5%|▌         | 53/1000 [00:10<03:11,  4.95it/s]
  4%|▍         | 45/1000 [00:09<03:13,  4.94it/s]
  4%|▍         | 43/1000 [00:08<03:10,  5.01it/s]
  5%|▍         | 46/1000 [00:09<03:09,  5.04it/s]


score: -445,301,741,327
n_components 5


  6%|▋         | 63/1000 [00:16<04:12,  3.71it/s]
  3%|▎         | 30/1000 [00:08<04:33,  3.54it/s]
  4%|▎         | 36/1000 [00:09<04:24,  3.65it/s]
  3%|▎         | 34/1000 [00:09<04:30,  3.56it/s]
  3%|▎         | 31/1000 [00:08<04:36,  3.51it/s]


score: -460,160,282,262
n_components 6


  4%|▍         | 44/1000 [00:15<05:37,  2.84it/s]
  4%|▎         | 36/1000 [00:12<05:30,  2.91it/s]
 12%|█▏        | 116/1000 [00:35<04:32,  3.25it/s]
  5%|▍         | 48/1000 [00:16<05:35,  2.83it/s]
 12%|█▏        | 115/1000 [00:37<04:50,  3.05it/s]


score: -460,529,648,192
n_components 7


  9%|▉         | 94/1000 [00:37<06:00,  2.51it/s]
  8%|▊         | 83/1000 [00:33<06:09,  2.48it/s]
  9%|▉         | 91/1000 [00:37<06:13,  2.43it/s]
  7%|▋         | 74/1000 [00:29<06:03,  2.55it/s]
 10%|▉         | 96/1000 [00:36<05:45,  2.62it/s]


score: -638,794,650,372
n_components 8


 12%|█▏        | 124/1000 [00:56<06:41,  2.18it/s]
 13%|█▎        | 126/1000 [00:59<06:51,  2.13it/s]
 10%|█         | 100/1000 [00:48<07:13,  2.08it/s]
 12%|█▏        | 119/1000 [00:54<06:43,  2.18it/s]
 22%|██▏       | 218/1000 [01:33<05:33,  2.34it/s]


score: -764,796,565,289
n_components 9


 14%|█▎        | 137/1000 [01:12<07:35,  1.90it/s]
  7%|▋         | 66/1000 [00:35<08:27,  1.84it/s]
  9%|▉         | 90/1000 [00:49<08:16,  1.83it/s]


KeyboardInterrupt: 

In [33]:
np.argmax(np.array(scores))

np.int64(0)

In [25]:
model.score(test_data, lengths=test_lengths)

-6558513275.564452

In [26]:
model.bic(test_data, lengths=test_lengths)

np.float64(13117029247.089188)