In [57]:
import torch
import pandas as pd
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer

In [85]:
one_gram_samples_df = pd.read_csv('data/1gram_samples.csv', delimiter=",", encoding="utf-8", quotechar='"')

one_gram_samples = one_gram_samples_df['gram'].tolist()

one_gram_targets_df = pd.read_csv('data/1gram_targets.csv', delimiter=",", encoding="utf-8", quotechar='"')

one_gram_targets = one_gram_targets_df['gram'].tolist()

two_grams_df = pd.read_csv('data/2grams.csv', delimiter=",", encoding="utf-8", quotechar='"')

two_grams = two_grams_df['gram'].tolist()

In [86]:
two_grams

['black cat']

In [87]:
# Load the model and tokenizer
model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-70m-deduped")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m-deduped")
tokenizer.pad_token = tokenizer.eos_token


In [88]:
activation_list = []

def hook_fn(module, input, output):
    activation_list.append(output)


layer_to_hook = model.gpt_neox.layers[3].mlp
hook = layer_to_hook.register_forward_hook(hook_fn)

In [89]:
one_gram_samples_activations = []
one_gram_targets_activations = []
two_gram_activations = []

for one_gram in one_gram_samples:
    input = tokenizer(one_gram, return_tensors="pt", padding=True, truncation=True)
    decoded_tokens = tokenizer.convert_ids_to_tokens(input['input_ids'][0])
    output = model(**input)
    one_gram_samples_activations.append(activation_list[0][0].detach().numpy())
    activation_list.clear()


for one_gram in one_gram_targets:
    input = tokenizer(one_gram, return_tensors="pt", padding=True, truncation=True)
    output = model(**input)
    one_gram_targets_activations.append(activation_list[0][0].detach().numpy())
    activation_list.clear()


for two_gram in two_grams:
    input = tokenizer(two_gram, return_tensors="pt", padding=True, truncation=True)
    output = model(**input)
    two_gram_activations.append(activation_list[0][0].detach().numpy())
    activation_list.clear()

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


In [90]:
# Sanity Check => Assumes position is the same for both

if np.allclose(one_gram_samples_activations[0][0],two_gram_activations[0][0],atol=1e-5):
    print("Match")

Match


In [91]:
nc_target = one_gram_targets_activations[0][0] # black

nc_sample = one_gram_samples_activations[0][0] # cat
c_sample = two_gram_activations[0][1] # cat with context "black"

sample_stack = np.stack((nc_sample, c_sample), axis=0).flatten()

class AutoEncoder(torch.nn.Module):
    def __init__(self, input_size, output_size):
        super(AutoEncoder, self).__init__()

        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(input_size, input_size//2),
            torch.nn.ReLU(),
            torch.nn.Linear(input_size//2, input_size//4),
        )

        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(input_size//4, output_size//2),
            torch.nn.ReLU(),
            torch.nn.Linear(output_size//2, output_size),
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x
    

# Train the autoencoder
autoencoder = AutoEncoder(sample_stack.shape[0], nc_target.shape[0])

criterion = torch.nn.MSELoss()

optimizer = torch.optim.Adam(autoencoder.parameters(), lr=0.001)

for epoch in range(1000):
    optimizer.zero_grad()
    outputs = autoencoder(torch.Tensor(sample_stack))
    loss = criterion(outputs, torch.Tensor(nc_target))
    loss.backward()
    optimizer.step()
    print(f'Epoch {epoch+1}, Loss: {loss.item()}')

Epoch 1, Loss: 1.1443760395050049
Epoch 2, Loss: 1.0527876615524292
Epoch 3, Loss: 0.9041814208030701
Epoch 4, Loss: 0.7349306344985962
Epoch 5, Loss: 0.646043598651886
Epoch 6, Loss: 0.5118300914764404
Epoch 7, Loss: 0.3915354907512665
Epoch 8, Loss: 0.32034289836883545
Epoch 9, Loss: 0.2563043236732483
Epoch 10, Loss: 0.1997385025024414
Epoch 11, Loss: 0.1636376529932022
Epoch 12, Loss: 0.13400988280773163
Epoch 13, Loss: 0.10915946215391159
Epoch 14, Loss: 0.09528877586126328
Epoch 15, Loss: 0.0839952602982521
Epoch 16, Loss: 0.07151993364095688
Epoch 17, Loss: 0.06083247810602188
Epoch 18, Loss: 0.050812505185604095
Epoch 19, Loss: 0.041030097752809525
Epoch 20, Loss: 0.03427056968212128
Epoch 21, Loss: 0.030135534703731537
Epoch 22, Loss: 0.027492789551615715
Epoch 23, Loss: 0.025643987581133842
Epoch 24, Loss: 0.023664003238081932
Epoch 25, Loss: 0.0206773579120636
Epoch 26, Loss: 0.0178944431245327
Epoch 27, Loss: 0.01570933870971203
Epoch 28, Loss: 0.013844707980751991
Epoch 29

In [92]:
autoencoder.decoder[-1].weight.shape

torch.Size([512, 256])