In [None]:
# Code borrowed from https://github.com/MinhZou/selective-copying-mamba

In [2]:
from transformers import GPT2Config
from scipy.stats import entropy
import torch

from model_load import model_loader
from dataset import generate_dataset
from visualization import model_viz_data
from trainer import eval, train

from captum.attr import LayerActivation

In [3]:
# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [4]:
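# Uncomment to retrain and/or re-evaluate. These calls use `my_model`, `dataset_config`
# and `training_config`, which are only defined in later cells (TODO: confirm whether
# `train` also needs `device`, as `eval` does below).
# train(my_model, dataset_config, training_config)
# eval(my_model, dataset_config, training_config, device)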

In [5]:
# Configuration for training
training_config = {
    "batch_size": 1000,
    "learning_rate": 0.0003,
    "num_steps": 10000
}

# Configuration for dataset
dataset_config = {
    "span_length": 4,
    "num_spans": 3,
    "copying_ratio": .25,
    "n_tokens": 10,  # alphabet size
    "lag": False,
    "variable": True,  # Randomly distribute memorization tokens throughout sequence instead of frontloading them
    "variable_length": False,  # Randomize number of tokens to memorize
    "one_hot": False,
    "reverse": False,
    "static": False,
}

# The model architecture (`custom_config`) and `device` are each defined in their own cell;
# they are intentionally not (re)built here so there is a single source of truth for each.

In [6]:
all_values_buffer = []

# 2-layer, single-head GPT-2 with 64-dim embeddings. Token id `n_tokens` doubles as
# BOS and EOS, which is why the vocabulary is one larger than the data alphabet.
custom_config = GPT2Config(
    vocab_size=dataset_config['n_tokens'] + 1,
    n_layer=2,
    n_head=1,
    n_embd=64,
    bos_token_id=dataset_config['n_tokens'],
    eos_token_id=dataset_config['n_tokens'],
)

In [7]:
# NOTE(review): load_model=True presumably restores previously saved weights — confirm in model_load.
my_model = model_loader(
    custom_config,
    device,
    load_model=True,
    dataset_config=dataset_config,
)
# `eval` here is trainer.eval (imported at the top), which shadows Python's builtin eval.
eval(my_model, dataset_config, training_config, device)

100.0


In [8]:
# Attention data for one example; passing x=None / y=None leaves the choice of input to
# model_viz_data, which hands back the (x, y) pair it used.
head, split = True, True
attention, x, y = model_viz_data(
    my_model,
    dataset_config,
    training_config,
    device,
    head=head,
    split=split,
    x=None,
    y=None,
)

tensor([[6, 8, 8, 8, 8, 8, 8, 8, 8, 8, 6, 6, 6, 3, 3, 1, 1, 1, 1, 1, 1, 2, 6, 6,
         6, 1, 3, 3, 3, 3, 3, 2, 3, 1, 5, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2,
         6, 2, 1, 3, 3, 1, 2, 2, 5, 1, 5, 3]], device='cuda:0')


<IPython.core.display.Javascript object>

In [None]:
def capture_intermediates(dataset_config, training_config, my_model, device, x=None, y=None):
    """Record every block's attention-projection and MLP outputs for one input.

    Captum ``LayerActivation`` hooks are placed on ``block.attn.c_proj`` and
    ``block.mlp`` of every block in ``my_model.base_model.h``, and the model is
    run once on ``x``.

    Args:
        dataset_config, training_config: forwarded to ``generate_dataset`` when a
            fresh example must be sampled.
        my_model: model whose ``base_model.h`` blocks expose ``attn.c_proj`` and
            ``mlp``. It is switched to eval mode as a side effect.
        device: device the input/target tensors are moved to.
        x, y: optional input / target tensors. When both are ``None``, a single
            example (``batch_size_override=1``) is generated.

    Returns:
        ``(layer_activations, x, y)``. ``layer_activations[i]`` is a two-element
        list ``[attn_c_proj_output, mlp_output]`` for block ``i``.
    """
    my_model.eval()

    if x is None and y is None:
        x, y = generate_dataset(dataset_config, training_config, batch_size_override=1)
    # Move caller-supplied tensors to the device too, not just freshly generated ones.
    if x is not None:
        x = x.to(device)
    if y is not None:
        y = y.to(device)

    # Hook all blocks at once and capture everything in a single forward pass,
    # instead of re-running the whole model once per block.
    hooked_layers = []
    for block in my_model.base_model.h:
        hooked_layers.extend([block.attn.c_proj, block.mlp])
    flat_activations = LayerActivation(my_model, hooked_layers).attribute(x)

    # Regroup the flat per-module list into per-block [c_proj, mlp] pairs.
    layer_activations = [
        list(flat_activations[i:i + 2]) for i in range(0, len(flat_activations), 2)
    ]

    return layer_activations, x, y

In [None]:
def verify_intermediates(dataset_config, training_config, my_model, device, y, layer_activations):
    """Score how well block 0's MLP output alone predicts the target tokens.

    The steps are:

    1. Take block 0's MLP activations (``layer_activations[0][1]``) for the first
       sequence, keeping only the last ``y.shape[1]`` positions.
    2. Project them through ``my_model.linear.weight``. Only the weight is used;
       a bias, if the layer has one, is not added.
    3. Compare the arg-max token at each position with ``y``.

    Args:
        dataset_config, training_config, device: unused; kept so the signature
            mirrors ``capture_intermediates``.
        my_model: object exposing ``linear.weight`` of shape (vocab, hidden).
        y: target tokens indexed as (batch, out_len). Batch size 1 is assumed.
        layer_activations: nested list whose ``[0][1]`` entry is indexed as
            (batch, seq, hidden).

    Returns:
        0-dim tensor: fraction of target positions predicted correctly.
    """
    out_sz = y.shape[1]

    with torch.no_grad():
        # Use the real target length rather than a hard-coded 12, so this works for any
        # span_length / num_spans / variable_length setting.
        mlp_out = layer_activations[0][1][0, -out_sz:, :]
        logits = torch.matmul(mlp_out, my_model.linear.weight.T)
        preds = torch.argmax(logits, dim=1)

    return torch.sum(preds == y) / out_sz

In [None]:
# Capture block activations for a newly generated example; this replaces the `x`, `y`
# returned by the visualisation cell. x and y default to None inside the helper.
layer_activations, x, y = capture_intermediates(
    dataset_config, training_config, my_model, device
)

In [None]:
# Average the accuracy reported by verify_intermediates over `cont` freshly generated
# single-example inputs.
total = 0.0
cont = 1000
for _ in range(cont):
    layer_activations, x, y = capture_intermediates(dataset_config, training_config,
                                                    my_model, device, x=None, y=None)
    # .item() keeps the running sum a plain Python float rather than a device tensor,
    # so the result prints as a number instead of a `tensor(..., device=...)` repr.
    total += verify_intermediates(dataset_config, training_config, my_model, device,
                                  y, layer_activations).item()

print(total / cont)