In [57]:
import torch
import pandas as pd
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer

In [58]:
one_grams_df = pd.read_csv('data/1grams.csv', delimiter=",", encoding="utf-8", quotechar='"')

one_grams = one_grams_df['gram'].tolist()

two_grams_df = pd.read_csv('data/2grams.csv', delimiter=",", encoding="utf-8", quotechar='"')

two_grams = two_grams_df['gram'].tolist()

In [59]:
two_grams

['black cat']

In [60]:
# Load the model and tokenizer
model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-70m-deduped")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m-deduped")
tokenizer.pad_token = tokenizer.eos_token


In [61]:
activation_list = []

def hook_fn(module, input, output):
    activation_list.append(output)


layer_to_hook = model.gpt_neox.layers[3].mlp
hook = layer_to_hook.register_forward_hook(hook_fn)

In [62]:
one_gram_activations = []
two_gram_activations = []

for one_gram in one_grams:
    input = tokenizer(one_gram, return_tensors="pt", padding=True, truncation=True)
    decoded_tokens = tokenizer.convert_ids_to_tokens(input['input_ids'][0])
    output = model(**input)
    one_gram_activations.append(activation_list[0][0].detach().numpy())
    activation_list.clear()


for two_gram in two_grams:
    input = tokenizer(two_gram, return_tensors="pt", padding=True, truncation=True)
    output = model(**input)
    two_gram_activations.append(activation_list[0][0].detach().numpy())
    activation_list.clear()

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


In [63]:
# Sanity Check => Assumes position is the same for both

if np.allclose(one_gram_activations[0][0],two_gram_activations[0][0],atol=1e-5):
    print("Match")

Match


In [75]:
nc_target = one_gram_activations[0][0] # black

nc_sample = one_gram_activations[1][0] # cat
c_sample = two_gram_activations[0][1] # cat with context "black"

sample_stack = np.stack((nc_sample, c_sample), axis=0).flatten()

class AutoEncoder(torch.nn.Module):
    def __init__(self, input_size, output_size):
        super(AutoEncoder, self).__init__()

        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(input_size, input_size//2),
            torch.nn.ReLU(),
            torch.nn.Linear(input_size//2, input_size//4),
            torch.nn.ReLU()
        )

        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(input_size//4, output_size*2),
            torch.nn.ReLU(),
            torch.nn.Linear(output_size*2, output_size),
            torch.nn.ReLU()
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x
    

# Train the autoencoder
autoencoder = AutoEncoder(sample_stack.shape[0], nc_target.shape[0])

criterion = torch.nn.MSELoss()

optimizer = torch.optim.Adam(autoencoder.parameters(), lr=0.001)

for epoch in range(1000):
    optimizer.zero_grad()
    outputs = autoencoder(torch.Tensor(sample_stack))
    loss = criterion(outputs, torch.Tensor(nc_target))
    loss.backward()
    optimizer.step()
    print(f'Epoch {epoch+1}, Loss: {loss.item()}')

Epoch 1, Loss: 1.2148545980453491
Epoch 2, Loss: 1.1801179647445679
Epoch 3, Loss: 1.1182854175567627
Epoch 4, Loss: 1.0560204982757568
Epoch 5, Loss: 1.0488924980163574
Epoch 6, Loss: 1.0091736316680908
Epoch 7, Loss: 0.9861205816268921
Epoch 8, Loss: 0.9879978895187378
Epoch 9, Loss: 0.9922462701797485
Epoch 10, Loss: 0.9903084635734558
Epoch 11, Loss: 0.9849807024002075
Epoch 12, Loss: 0.9769585132598877
Epoch 13, Loss: 0.9677664637565613
Epoch 14, Loss: 0.9586188793182373
Epoch 15, Loss: 0.9523312449455261
Epoch 16, Loss: 0.9487028121948242
Epoch 17, Loss: 0.9457762837409973
Epoch 18, Loss: 0.9427114725112915
Epoch 19, Loss: 0.9402496218681335
Epoch 20, Loss: 0.9384238123893738
Epoch 21, Loss: 0.9369004964828491
Epoch 22, Loss: 0.9360945224761963
Epoch 23, Loss: 0.9356386065483093
Epoch 24, Loss: 0.9348900318145752
Epoch 25, Loss: 0.9339846968650818
Epoch 26, Loss: 0.9334036111831665
Epoch 27, Loss: 0.9328954219818115
Epoch 28, Loss: 0.9321054220199585
Epoch 29, Loss: 0.93161916732

In [77]:
autoencoder.decoder.

AttributeError: 'Sequential' object has no attribute 'shape'