In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from tuned_lens.nn.lenses import TunedLens
import numpy as np
from torch.utils.data import DataLoader
from datasets import load_dataset
from typing import Counter
from tqdm import tqdm
import random
import sklearn.metrics as metrics



In [2]:
from lenses_experiments.utils.logger import get_logger

logger = get_logger(name = __name__)
logger.info("Initializing logger...")

2024-04-14 13:06:03,991 - __main__ - INFO - Initializing logger...


In [3]:
model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-410m-deduped")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-410m-deduped")
tuned_lens = TunedLens.from_model_and_pretrained(model)
tokenizer.pad_token = tokenizer.eos_token

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [4]:
dataset = load_dataset("allenai/ai2_arc", "ARC-Easy")

In [5]:
from lenses_experiments.utils.datasets import Loader

In [7]:
loader = Loader(dataset_name="allenai/ai2_arc", dataset=dataset, tokenizer=tokenizer, device=model.device)
val_dataloader = loader.loader_normal
wrong_loader = loader.loader_injected

In [8]:
model.eval()
tokenizer.pad_token = tokenizer.eos_token

n_layers = model.config.num_hidden_layers + 1
n_hidden = model.config.hidden_size
n_vocab = model.config.vocab_size

correct_sum = 0
n_batches = len(val_dataloader)

features = []

for batch in tqdm(val_dataloader):
    pr, encoded_prompt, encoded_labels, correct_encoded_labels = batch

    n_batch = len(encoded_prompt.input_ids)
    encoded_labels = torch.tensor(encoded_labels).squeeze() #batch_size x n_choices
    correct_encoded_labels = torch.tensor(correct_encoded_labels).squeeze() #batch_size
    correct_encoded_labels = torch.tensor([torch.nonzero(encoded_labels[idx] == correct_encoded_labels[idx]).squeeze().item() for idx in range(correct_encoded_labels.shape[0])])
    inputs = encoded_prompt.input_ids
    attention_mask = encoded_prompt.attention_mask
    last_nonzero_tok_index = attention_mask.sum(dim=1) - 1  #which indexes to take for each prompt: the last token position

    with torch.no_grad():
        outputs = model(inputs, attention_mask=attention_mask, output_hidden_states=True)
        traj_log_probs = torch.empty(n_batch, n_layers, n_vocab)

        #getting logits at last token position
        logits_last_token = outputs.logits[torch.arange(n_batch), last_nonzero_tok_index, :].squeeze() #shape: (batchsize, vocab_size)
        traj_log_probs[:,-1,:] = logits_last_token.log_softmax(dim=-1) #shape: (batchsize, vocab_size)
    

        #getting hidden states at last token position
        hidden_states_last_token = torch.empty(n_batch, n_layers-1, n_hidden)
        hidden_states = outputs.hidden_states #shape: (layers, batchsize, seq_len, hidden_size)
        
        for layer_index in range(len(hidden_states[:-1])):
            for elem_index in range(len(hidden_states[layer_index])):
                hidden_states_last_token[elem_index][layer_index] = hidden_states[layer_index][elem_index][last_nonzero_tok_index[elem_index],:]

        #getting lens outputs from hidden states last token
        for layer_index in range(n_layers-1):
            traj_log_probs[:,layer_index,:] = tuned_lens.forward(hidden_states_last_token[:,layer_index,:], layer_index).log_softmax(dim=-1)

        print(f"Top answers given by last lenses index: {[tokenizer.decode(el) for el in traj_log_probs[:,-2].argmax(dim=-1).squeeze().tolist()]}")
        print(f"Top answers given by first lenses index: {[tokenizer.decode(el) for el in traj_log_probs[:,0].argmax(dim=-1).squeeze().tolist()]}")
        #transform the trajectory into the trajectory for the encoded labels
        traj_log_probs_filtered = torch.empty(n_batch, n_layers, len(encoded_labels[0]))

        correct_batch = 0
        for i in range(n_batch):
            traj_log_probs_filtered[i] = traj_log_probs[i,:,encoded_labels[i]]
            if(traj_log_probs_filtered[i,-1,correct_encoded_labels[i]].squeeze() == traj_log_probs_filtered[i,-1,:].max().squeeze()):
                correct_batch += 1
        #print("correct answer ratio for this batch: ", correct_batch/n_batch)
        correct_sum += correct_batch/n_batch

        traj_log_probs_flatten = traj_log_probs_filtered.flatten(start_dim=1)
        print(traj_log_probs_flatten.shape)
        features.append(traj_log_probs_flatten)

        #print(traj_log_probs_filtered.shape)
features = torch.cat(features, dim=0)

#print(correct_sum/n_batches)

  6%|▌         | 1/18 [00:16<04:40, 16.50s/it]

Top answers given by last lenses index: ['A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'C', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', '1', 'A', '1', 'A', 'A', 'A', 'A']
Top answers given by first lenses index: ['\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n']
torch.Size([32, 100])


 11%|█         | 2/18 [00:29<03:49, 14.34s/it]

Top answers given by last lenses index: ['A', 'A', 'A', 'A', 'A', 'C', 'A', 'A', 'A', 'A', 'A', 'A', 'A', '1', 'A', 'A', 'A', 'A', '1', 'A', 'A', 'A', 'A', 'A', 'A', '1', 'A', 'A', 'A', 'A', 'A', 'A']
Top answers given by first lenses index: ['\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n']
torch.Size([32, 100])


 17%|█▋        | 3/18 [00:43<03:35, 14.36s/it]

Top answers given by last lenses index: ['A', 'D', 'A', '1', 'A', '1', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'B', 'A', 'A', 'A', 'A', 'A']
Top answers given by first lenses index: ['\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n']
torch.Size([32, 100])


 22%|██▏       | 4/18 [00:59<03:27, 14.82s/it]

Top answers given by last lenses index: ['A', '1', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'B', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'C', 'A', 'A']
Top answers given by first lenses index: ['\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n']
torch.Size([32, 100])


 28%|██▊       | 5/18 [01:10<02:55, 13.48s/it]

Top answers given by last lenses index: ['A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', '1', 'A', 'C', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', '\n', 'A', 'A', 'B']
Top answers given by first lenses index: ['\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n']
torch.Size([32, 100])


 33%|███▎      | 6/18 [01:20<02:27, 12.25s/it]

Top answers given by last lenses index: ['A', 'A', 'A', '1', 'A', 'A', 'A', 'A', 'A', 'A', 'C', '1', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'B', 'A', 'A', '1', 'A', 'A', 'A', 'A']
Top answers given by first lenses index: ['\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n']
torch.Size([32, 100])


 39%|███▉      | 7/18 [01:31<02:11, 11.95s/it]

Top answers given by last lenses index: ['A', 'A', 'A', 'A', 'A', 'C', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', '\n', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A']
Top answers given by first lenses index: ['\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n']
torch.Size([32, 100])


 44%|████▍     | 8/18 [01:44<02:03, 12.39s/it]

Top answers given by last lenses index: ['A', 'A', 'A', 'C', 'A', 'A', 'A', 'C', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A']
Top answers given by first lenses index: ['\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n']
torch.Size([32, 100])


 50%|█████     | 9/18 [01:55<01:46, 11.78s/it]

Top answers given by last lenses index: ['A', 'A', 'A', 'A', 'A', '\n', 'A', 'A', 'A', 'A', 'C', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', '\n']
Top answers given by first lenses index: ['\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n']
torch.Size([32, 100])


 56%|█████▌    | 10/18 [02:05<01:30, 11.33s/it]

Top answers given by last lenses index: ['A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', '1', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'C', 'A']
Top answers given by first lenses index: ['\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n']
torch.Size([32, 100])


 61%|██████    | 11/18 [02:18<01:21, 11.65s/it]

Top answers given by last lenses index: ['A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'D', 'A', 'A', 'A', 'A', 'C', 'A', 'A', 'A', 'B', 'C', 'B', '1', '1', 'A']
Top answers given by first lenses index: ['\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n']
torch.Size([32, 100])


 67%|██████▋   | 12/18 [02:28<01:08, 11.35s/it]

Top answers given by last lenses index: ['A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'C', 'A', 'A', 'A', 'A', 'A', 'A', 'A', '\n', 'A', 'A', 'A', 'A', 'A', 'C', '1', 'A', 'A', 'A']
Top answers given by first lenses index: ['\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n']
torch.Size([32, 100])


 72%|███████▏  | 13/18 [02:44<01:03, 12.64s/it]

Top answers given by last lenses index: ['A', 'A', 'A', 'A', 'A', 'D', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', '1', '\n', 'A', 'B', 'A', 'A', 'A']
Top answers given by first lenses index: ['\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n']
torch.Size([32, 100])


 78%|███████▊  | 14/18 [02:54<00:47, 11.82s/it]

Top answers given by last lenses index: ['A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'C', 'A', 'A', 'A', 'A', 'A', '\n', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'C', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A']
Top answers given by first lenses index: ['\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n']
torch.Size([31, 100])


 83%|████████▎ | 15/18 [03:04<00:34, 11.34s/it]

Top answers given by last lenses index: ['A', 'A', 'A', 'A', 'A', 'A', '1', 'A', 'A', 'A', 'A', 'C', 'A', 'A', 'A', 'A', 'A', 'A', 'C', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A']
Top answers given by first lenses index: ['\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n']
torch.Size([31, 100])


 89%|████████▉ | 16/18 [03:17<00:23, 11.96s/it]

Top answers given by last lenses index: ['A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'C', 'B', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A']
Top answers given by first lenses index: ['\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n']
torch.Size([32, 100])


 94%|█████████▍| 17/18 [03:33<00:13, 13.11s/it]

Top answers given by last lenses index: ['A', 'B', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'C', 'A', 'A', 'A', 'A', 'A', '1', 'A', 'A', 'A', 'A', 'A', 'A', 'A']
Top answers given by first lenses index: ['\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n']
torch.Size([32, 100])


100%|██████████| 18/18 [03:41<00:00, 12.32s/it]

Top answers given by last lenses index: ['1', 'A', 'A', 'A', 'A', 'B', 'A', 'A', '1', 'A', 'A', 'A', 'A', 'A', 'A', 'A', '\n', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A']
Top answers given by first lenses index: ['\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n', '\n']
torch.Size([25, 100])



