In [1]:
import os
import sys

# Base directory under which the Finite-groups checkout lives; change if
# necessary. Falls back to the platform's home directory when $HOME is unset
# or empty, instead of failing with a KeyError.
HOME = os.environ.get('HOME') or os.path.expanduser('~')

# Make the project sources importable. The membership check keeps sys.path
# free of duplicate entries when this cell is executed more than once.
_src_dir = f'{HOME}/Finite-groups/src'
if _src_dir not in sys.path:
    sys.path.append(_src_dir)

In [2]:
import torch as t
import numpy as np
from matplotlib import pyplot as plt
import json
from itertools import product
from jaxtyping import Float
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
import plotly.graph_objects as go
import copy
import math
from itertools import product
import pandas as pd
from typing import Union
from einops import repeat
from huggingface_hub import snapshot_download
from huggingface_hub.utils import disable_progress_bars


from model import MLP3, MLP4, InstancedModule
from utils import *
from group_data import *
from model_utils import *
from group_utils import *
%load_ext autoreload
%autoreload 2

fo.g:73
  if not IsKernelExtensionAvailable("Browse", "ncurses") then
         ^^^^^^^^^^^^^^^^^^^^^^^^^^
.g:60
  if not IsKernelExtensionAvailable("EDIM","ediv") then
         ^^^^^^^^^^^^^^^^^^^^^^^^^^


In [36]:

# Prefer the GPU when PyTorch can see one; otherwise stay on the CPU.
device = t.device("cuda") if t.cuda.is_available() else t.device("cpu")

# Run directory to analyse. The commented-out name is an alternative run that
# is currently disabled.
#MODEL_DIR = '2024-09-20_20-20-22_MLP2_Z_53_'
MODEL_DIR = '2024-09-20_20-50-54_MLP2_Z_53_'
local_dir = f'{HOME}/models/{MODEL_DIR}'

# load_models yields the models plus the params object that GroupData is
# built from.
models, params = load_models(local_dir)
data = GroupData(params)

# First group held by the dataset.
group = data.groups[0]

Intersection size: 2809/2809 (1.00)
Added 2809 elements from intersection
Added 0 elements from group 0: Z(53)
Taking random subset: 1123/2809 (0.40)
Train set size: 1123/2809 (0.40)


  model.load_state_dict(t.load(model_path, map_location=device))


In [37]:
# Use the last entry of the loaded `models` as the model under study.
model = models[-1]

In [42]:
# Evaluate the last loaded model on `data`, moving it to `device` first.
loss_dict = test_loss(models[-1].to(device), data)

# Average every entry of the result. This is written inline rather than via the
# `mean` helper because that helper is only defined in a later cell, so calling
# it here raises NameError on a fresh top-to-bottom run of the notebook.
{k: v.mean() for k, v in loss_dict.items()}

{'G0_loss': tensor(0.0711), 'G0_acc': tensor(0.9955)}

In [39]:
# instance = loss_dict['G0_loss'].argmin().item()
# print(loss_dict[f'G0_loss'][instance], loss_dict[f'G0_acc'][instance], instance)
# model = models[-1][instance].to(device)

In [40]:
def mean(d):
    """Return a new dict mapping each key of ``d`` to ``value.mean()``.

    Args:
        d: Mapping whose values expose a ``.mean()`` method (in this notebook,
            the dict returned by ``test_loss``).

    Returns:
        A dict with the same keys as ``d`` and the averaged values.
    """
    return {k: v.mean() for k, v in d.items()}

### Swap left and right embeddings

In [44]:
# Exchange the left and right embedding tables on a private copy of the model.
# The swap uses the copy's own deep-copied parameters rather than re-wrapping
# `model`'s tensors, so model2 does not share parameter storage with `model`
# and an in-place edit to one cannot leak into the other.
model2 = copy.deepcopy(model)
model2.embedding_left, model2.embedding_right = (
    model2.embedding_right,
    model2.embedding_left,
)
mean(test_loss(model2, data))

{'G0_loss': tensor(4.1553), 'G0_acc': tensor(0.0385)}

### Change embedding signs

In [45]:
def _negate_embeddings(base, sides):
    '''Deep-copy ``base`` and replace each embedding named in ``sides`` by its negation.'''
    flipped = copy.deepcopy(base)
    for side in sides:
        setattr(flipped, side, nn.Parameter(-getattr(base, side)))
    return flipped


# model2: left embedding negated; model3: right embedding negated;
# model4: both embeddings negated.
model2 = _negate_embeddings(model, ["embedding_left"])
model3 = _negate_embeddings(model, ["embedding_right"])
model4 = _negate_embeddings(model, ["embedding_left", "embedding_right"])
tuple(mean(test_loss(flipped, data)) for flipped in (model2, model3, model4))

({'G0_loss': tensor(17.1501), 'G0_acc': tensor(0.)},
 {'G0_loss': tensor(17.1500), 'G0_acc': tensor(0.)},
 {'G0_loss': tensor(0.0663), 'G0_acc': tensor(0.9969)})

### Absolute value nonlinearity

In [46]:
class Abs(nn.Module):
    '''
    Activation module that returns the elementwise absolute value of its
    input multiplied by a constant ``scale`` (1 by default).
    '''

    def __init__(self, scale=1.):
        super().__init__()
        # Constant multiplier applied after taking the absolute value.
        self.scale = scale

    def forward(self, input: t.Tensor) -> t.Tensor:
        magnitude = input.abs()
        return magnitude * self.scale

# Swap the activation on a deep copy of the model for Abs (default scale),
# then show the averaged test_loss results for that modified copy.
model2 = copy.deepcopy(model)
model2.activation= Abs()
mean(test_loss(model2, data))

{'G0_loss': tensor(0.0045), 'G0_acc': tensor(0.9988)}

### Add noise

In [47]:
# Should've used transformerlens....
class MLP2Noise(InstancedModule):
    '''
    Copy of an existing two-input model whose hidden pre-activations are
    perturbed with Gaussian noise.

    The forward pass embeds both input tokens, projects each embedding to the
    hidden layer, sums the two projections, adds ``randn * std + mean`` and
    then applies the source model's activation and unembedding.
    '''
    def __init__(self, model, mean, std):
        super().__init__()
        # Copy first so the tensors wrapped below are not shared with the
        # caller's model.
        source = copy.deepcopy(model)
        self.params = source.params
        self.N = source.N

        # Re-wrap the copied weights as parameters of this module.
        for weight_name in (
            "embedding_left",
            "embedding_right",
            "linear_left",
            "linear_right",
            "unembedding",
        ):
            setattr(self, weight_name, nn.Parameter(getattr(source, weight_name)))

        # The unembedding bias may be absent on the source model.
        bias = source.unembed_bias
        self.unembed_bias = None if bias is None else nn.Parameter(bias)

        self.activation = source.activation
        # Offset and scale of the noise added to the hidden layer.
        self.mean = mean
        self.std = std

    def _forward(
        self, a: Int[t.Tensor, "batch_size entries"]
    ) -> Float[t.Tensor, "batch_size instances d_vocab"]:
        '''
        Compute the output for input pairs ``a``, drawing fresh noise for the
        hidden layer on every call.
        '''
        embed_pattern = "batch_size instances d_vocab, instances d_vocab embed_dim -> batch_size instances embed_dim"
        project_pattern = "batch_size instances embed_dim, instances embed_dim hidden -> batch_size instances hidden"

        # Repeat the inputs along a new axis, once per instance.
        per_instance = einops.repeat(
            a, " batch_size entries -> batch_size n entries", n=self.num_instances(),
        )  # batch_size instances entries

        def pre_activation(tokens, embedding, linear):
            # One-hot encode, embed, then project to the hidden layer.
            one_hot = F.one_hot(tokens, num_classes=self.N).float()
            embedded = einops.einsum(one_hot, embedding, embed_pattern)
            return einops.einsum(embedded, linear, project_pattern)

        left = pre_activation(per_instance[..., 0], self.embedding_left, self.linear_left)
        right = pre_activation(per_instance[..., 1], self.embedding_right, self.linear_right)
        hidden = left + right

        # Perturb the pre-activations before the nonlinearity is applied.
        noise = t.randn_like(hidden) * self.std + self.mean
        hidden = hidden + noise

        logits = einops.einsum(
            self.activation(hidden),
            self.unembedding,
            "batch_size instances hidden, instances hidden d_vocab-> batch_size instances d_vocab ",
        )
        if self.unembed_bias is None:
            return logits

        # Repeat the per-instance bias over the batch axis and add it.
        return logits + einops.repeat(
            self.unembed_bias,
            'instances d_vocab -> batch_size instances d_vocab',
            batch_size=logits.shape[0]
        )


In [48]:
# (mean, std) of the hidden-layer noise for model2 .. model5, in that order.
model2, model3, model4, model5 = (
    MLP2Noise(model, noise_mean, noise_std)
    for noise_mean, noise_std in [(0., 1.), (0., .1), (1., 1.), (-1., 1.)]
)
tuple(mean(test_loss(noisy, data)) for noisy in (model2, model3, model4, model5))

({'G0_loss': tensor(0.8285), 'G0_acc': tensor(0.7690)},
 {'G0_loss': tensor(0.0752), 'G0_acc': tensor(0.9951)},
 {'G0_loss': tensor(1.7851), 'G0_acc': tensor(0.4980)},
 {'G0_loss': tensor(0.7797), 'G0_acc': tensor(0.8317)})