# Overview of what we want to achieve:

* develop replacement model
* develop way to sample from the replacement model
* need to make a dataset and dataloader
* write the evaluation + calculate the accuracy
* need to put this as a Validation Metric to track in our training loop

## What are the different implementation options?

1. Combine gpt2 + CLT; use hooks
* we load gpt2 in memory
* we load (or have) the CLT in memory
* we use nnsight to put a lot of hooks into gpt2
* each hook runs part of the clt

2. write a replacement model class from scratch
* basically like a transformer but replace the MLPs with they weights from the clt
* read in attention weights, W_enc, W_pos, W_unemb from gpt2
* read in CLT weights from the CLT

### Things to consider for decision:
* easy to read and understand
* whether it's okay to ie during validation offload the CLT and gpt2 to load this model


### local replacement model
* run gpt2 on some prompts, store mlp_in, mlp_out
* run through the CLT model using activations from gpt2
* calculate the error terms recons - mlp_out

### What is the replacement model used for?
* just for evaluation

## Conclusion:

Since the replacement model will only be used as a validation metric and has no use cases downstream, the best option to implement this as a hook function model with nnsight because we won't need to load in weights, we can simply use the model

In [2]:
import os
# use cuda 0
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import torch
from einops import einsum


class ReplacementModel(torch.nn.Module):
    def __init__(self, gpt2, clt):
        super(ReplacementModel, self).__init__()
        self.gpt2 = gpt2
        self.clt = clt
        self.n_layers = self.gpt2.config.n_layer
        assert self.n_layers == self.clt.config.n_layers
        self.n_features = self.clt.config.d_features

    def forward(self, tokens):
        with self.gpt2.trace(tokens) as tracer:
            # features: batch_size x seq_len x n_layers x n_features
            features = torch.full((tokens.shape[0], tokens.shape[1], self.n_layers, self.n_features), float('nan'))

            for layer in range(self.n_layers):
                mlp_in = self.gpt2.transformer.h[layer].ln_2.input.save()  # batch_size x seq_len x d_resid
                # TODO: get features by running clt encoder on this layer
                mlp_in_norm = (mlp_in - mlp_in.mean(dim=-1, keepdim=True)) / mlp_in.std(dim=-1, keepdim=True)
                features[..., layer, :] = einsum(mlp_in_norm, self.clt.W_enc[layer], 'batch seq d_acts, d_acts d_features -> batch seq d_features')

                recons = einsum(features[:layer], self.W_dec[:layer, layer], 'batch seq n_layers d_features, n_layers d_features d_acts -> batch seq d_acts')

                mlp_out = self.gpt2.transformer.h[layer].mlp.output.save()
                # TODO: get reconstructions by running clt decoder on features of all lower layers
                self.gpt2.transformer.h[layer].mlp.output = recons
                logits = self.gpt2.lm_head.output.save()
        return logits
    

In [3]:
import nnsight

gpt2 = nnsight.LanguageModel('openai-community/gpt2', device_map='auto', dispatch=True)
gpt2.requires_grad_(False)

  from .autonotebook import tqdm as notebook_tqdm
Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
  (generator): Generator(
    (streamer): Streamer()
  

In [4]:
from utils import get_webtext_dataloader

loader = get_webtext_dataloader(gpt2)

Downloading data: 100%|██████████| 21/21 [02:12<00:00,  6.33s/files]
Generating train split: 100%|██████████| 8013769/8013769 [21:35<00:00, 6186.44 examples/s]


Token indices sequence length is longer than the specified maximum sequence length for this model (1217 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1174 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (2459 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (2027 > 1024). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1561 > 1024). Running this sequence through the model will result in indexing errors


In [4]:
gpt2.lm_head.weight

Parameter containing:
tensor([[-0.1101, -0.0393,  0.0331,  ..., -0.1364,  0.0151,  0.0453],
        [ 0.0403, -0.0486,  0.0462,  ...,  0.0861,  0.0025,  0.0432],
        [-0.1275,  0.0479,  0.1841,  ...,  0.0899, -0.1297, -0.0879],
        ...,
        [-0.0445, -0.0548,  0.0123,  ...,  0.1044,  0.0978, -0.0695],
        [ 0.1860,  0.0167,  0.0461,  ..., -0.0963,  0.0785, -0.0225],
        [ 0.0514, -0.0277,  0.0499,  ...,  0.0070,  0.1552,  0.1207]],
       device='cuda:0')

In [17]:
for batch in loader:
    bos = torch.zeros((batch.shape[0], 1), dtype=torch.long, device=batch.device) + gpt2.config.bos_token_id
    batch = torch.cat([bos, batch], dim=1)
    batch = batch.to('cuda:0')
    output = gpt2(batch)
    print(output)
    logits_orig = output.logits

    with gpt2.trace(batch) as tracer:
        mlp_ins = []
        mlp_outs = []
        for i in range(12):
            mlp_in = gpt2.transformer.h[i].ln_2.input.save()
            mlp_ins.append(mlp_in)
            mlp_out = gpt2.transformer.h[i].mlp.output.save()
            gpt2.transformer.h[i].mlp.output = torch.zeros_like(mlp_out)
            mlp_outs.append(mlp_out)
        logits_corrupted = gpt2.lm_head.output.save()
    print(mlp_in.shape)

    break

You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


CausalLMOutputWithCrossAttentions(loss=None, logits=tensor([[[ -43.4316,  -39.8364,  -43.0659,  ...,  -54.0877,  -54.3451,
           -42.3644],
         [ -73.7466,  -74.9203,  -76.8675,  ...,  -78.6061,  -80.6910,
           -73.7790],
         [ -84.0888,  -85.0945,  -88.5724,  ...,  -90.5709,  -91.0832,
           -84.1004],
         ...,
         [ -98.8388,  -97.5960, -101.3842,  ..., -105.7438, -104.3169,
           -99.2655],
         [ -98.2660,  -96.1375,  -98.6375,  ...,  -94.2466, -101.6711,
           -95.7393],
         [ -58.3637,  -57.8836,  -60.0584,  ...,  -65.7926,  -65.8141,
           -57.9844]],

        [[ -43.4316,  -39.8364,  -43.0659,  ...,  -54.0877,  -54.3451,
           -42.3644],
         [ -80.3972,  -81.3332,  -84.8448,  ...,  -90.2444,  -87.5910,
           -79.9321],
         [-116.6863, -118.2129, -118.9647,  ..., -122.3632, -120.9058,
          -113.9914],
         ...,
         [ -91.7356,  -92.1710,  -92.8851,  ..., -101.4698,  -96.9163,
          

In [6]:
for batch in loader:
    bos = torch.zeros((batch.shape[0], 1), dtype=torch.long, device=batch.device) + gpt2.config.bos_token_id
    batch = torch.cat([bos, batch], dim=1)
    batch = batch.to('cuda:0')
    batch = batch[:5]
    output = gpt2(batch)
    logits_orig = output.logits

    with gpt2.trace(batch) as tracer:
        mlp_ins = []
        mlp_outs = []
        for i in range(12):
            mlp_in = gpt2.transformer.h[i].ln_2.input.save()
            mlp_ins.append(mlp_in)
        mlp_ins = torch.stack(mlp_ins)
        mlp_ins *= 2.
        for i in range(12):
            mlp_out = gpt2.transformer.h[i].mlp.output.save()
            gpt2.transformer.h[i].mlp.output = mlp_ins[11 - i]
            mlp_outs.append(mlp_out)
        logits_corrupted = gpt2.lm_head.output.save()
    print(mlp_in.shape)

    break

torch.Size([5, 1024, 768])


In [10]:
for batch in loader:
    bos = torch.zeros((batch.shape[0], 1), dtype=torch.long, device=batch.device) + gpt2.config.bos_token_id
    batch = torch.cat([bos, batch], dim=1)
    batch = batch.to('cuda:0')
    batch = batch[:5]
    output = gpt2(batch)
    logits_orig = output.logits

    with gpt2.trace(batch) as tracer:
        mlp_out_8 = gpt2.transformer.h[8].mlp.output.save()
        mlp_out_2 = gpt2.transformer.h[2].mlp.output.save()
        gpt2.transformer.h[2].mlp.output = mlp_out_8
        mlp_out_2_corrupted = gpt2.transformer.h[2].mlp.output.save()
        #logits_corrupted = gpt2.lm_head.output.save()
    
    print('mlp_in_8', mlp_in_8)
    print('mlp_out_2', mlp_out_2)
    print('mlp_out_2_corrupted', mlp_out_2_corrupted)

    break

mlp_in_8 tensor([[[-0.2317, -0.1742,  0.4076,  ...,  1.2094,  0.9779,  0.3468],
         [ 0.0132,  5.9815, -1.3030,  ..., -2.3134,  3.5621,  0.9069],
         [-1.1271,  2.6131,  3.5580,  ..., -1.8958,  5.2259,  2.8670],
         ...,
         [ 0.4287, -0.4732,  0.4213,  ...,  0.6913,  1.9424,  2.2117],
         [-1.4544,  1.4217, -3.0940,  ...,  1.3569,  2.2228,  1.6743],
         [-2.6305, -1.8654, -0.4224,  ..., -0.9321, -2.3973, -0.3056]],

        [[-0.2317, -0.1742,  0.4076,  ...,  1.2094,  0.9779,  0.3468],
         [-0.7623,  1.7911,  0.1114,  ...,  0.8085, -2.6471,  0.6905],
         [ 0.3247,  0.2966, -0.4103,  ...,  0.7561, -1.7757, -0.6157],
         ...,
         [ 1.3467,  2.2840, -1.1671,  ...,  1.5021,  2.0084,  0.2160],
         [ 4.6676, -0.1384,  1.7488,  ...,  5.2829, -0.7914, -0.9944],
         [ 1.6990, -0.5843,  2.4302,  ...,  5.1885,  1.4176,  0.1654]],

        [[-0.2317, -0.1742,  0.4076,  ...,  1.2094,  0.9779,  0.3468],
         [-0.9271,  1.1020, -1.0148,

In [9]:
for batch in loader:
    bos = torch.zeros((batch.shape[0], 1), dtype=torch.long, device=batch.device) + gpt2.config.bos_token_id
    batch = torch.cat([bos, batch], dim=1)
    batch = batch.to('cuda:0')
    batch = batch[:5]
    output = gpt2(batch)
    logits_orig = output.logits

    with gpt2.trace(batch) as tracer:
        mlp_in_2 = gpt2.transformer.h[2].ln_2.input.save()
        mlp_out_8 = gpt2.transformer.h[8].mlp.output.save()
        gpt2.transformer.h[8].mlp.output = mlp_in_2
        mlp_out_8_corrupted = gpt2.transformer.h[8].mlp.output.save()
        #logits_corrupted = gpt2.lm_head.output.save()
    
    print('mlp_in_2', mlp_in_2)
    print('mlp_out_8', mlp_out_8)
    print('mlp_out_8_corrupted', mlp_out_8_corrupted)

    break

mlp_in_2 tensor([[[ 1.4882e-01, -7.8068e-02,  5.4677e-01,  ...,  1.6689e+00,
           1.4880e+00,  7.6614e-01],
         [ 1.7404e-01,  3.6589e+00, -1.4495e+00,  ..., -8.9506e-01,
          -1.6451e+00, -5.5786e-01],
         [-2.1732e+00, -3.9464e-01,  3.1175e+00,  ..., -8.9543e-01,
           3.3316e-01, -1.9379e+00],
         ...,
         [-3.6008e+00,  7.6212e-02, -2.9871e-01,  ..., -5.6556e-02,
          -1.4686e+00, -2.8403e-01],
         [ 6.0625e-01,  6.7959e-01,  4.2723e-01,  ..., -1.7862e-01,
           4.7519e-01,  1.1427e+00],
         [-3.5875e+00,  1.3844e+00,  2.0117e-03,  ...,  3.8480e-01,
           2.0642e+00,  3.3392e-01]],

        [[ 1.4882e-01, -7.8068e-02,  5.4677e-01,  ...,  1.6689e+00,
           1.4880e+00,  7.6614e-01],
         [-9.0029e-01, -4.9958e-01, -9.2506e-01,  ..., -1.5241e+00,
          -3.2433e-01, -7.3334e-01],
         [ 1.7569e+00, -6.3580e-02,  1.4569e-01,  ..., -8.2524e-01,
           6.4750e-01, -8.1157e-01],
         ...,
         [ 1.568