Tasks:

* Get early attention blocks
* Use those blocks to predict the next token
* For loop to iterate over all attention block predictions
* Time how long it takes for each prediction
* Add some kind of benchmark; plot layer vs accuracy

### Setup & loading model

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm

In [2]:
model = AutoModelForCausalLM.from_pretrained("gpt2-medium", output_attentions=True, output_hidden_states=True)
tokenizer = AutoTokenizer.from_pretrained("gpt2-medium")

In [3]:
def predict_next_token(text, model, tokenizer):
    input_ids = tokenizer.encode(text, return_tensors="pt")

    with torch.no_grad():
        output = model(input_ids=input_ids)

    logits = output.logits[0, -1, :]
    next_token_id = torch.argmax(logits).item()
    next_token = tokenizer.decode(next_token_id)

    return next_token

In [4]:
text = "Once upon a time"
next_token = predict_next_token(text, model, tokenizer)
print("Next token prediction:", next_token)

Next token prediction: ,


### Get early attention blocks

In [5]:
def get_hiddenstates_attn(text, model, tokenizer):
  input_ids = tokenizer.encode(text, return_tensors="pt")
  with torch.no_grad():
    output = model(input_ids=input_ids)
  hidden_states = output.hidden_states
  attentions = output.attentions
  return hidden_states, attentions

In [6]:
hidden, attn = get_hiddenstates_attn(text, model, tokenizer)

### Use blocks to predict next token


In [7]:
# this always returns none; I am not sure why
def get_next_token_from_hidden(text, hidden_states, attentions, layer):
  logits = model.lm_head(hidden_states[layer])
  next_token_logits = logits[:, -1, :]
  probabilities = F.softmax(next_token_logits, dim=-1)
  predicted_token_id = probabilities.argmax(dim=-1)  # Take max prediction
  # Ariel: Where is the temperature added?
  predicted_token = tokenizer.decode(predicted_token_id.tolist())
  return predicted_token

In [8]:
# Ariel: We are making a prediction from hidden state i? Iterating over the blocks
# Ariel: So indeed it seems later blocks predict something different
for i in range(len(hidden)):
  print(get_next_token_from_hidden(text, hidden, attn, i))

 time
abwe
 Sponsor
mi
 lapse
 lapse
 lapse
 lapse
 CI
,
,
,
,
,
,
,
,
,
,
,
,
,
,
,
,


### Predicting $k$ next tokens

In [9]:
def full_model_prediction(model, input_text, num_tokens_to_generate=20):

    # get the input ids from the input text
    input_ids = tokenizer.encode(input_text, return_tensors='pt')

    for _ in range(num_tokens_to_generate):
        # Get model outputs
        outputs = model(input_ids)
        logits = outputs.logits

        # Only use the logits from the last token position
        next_token_logits = logits[:, -1, :]
        next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)
        # print(next_token_logits)

        # Append the predicted token ID to the input sequence
        input_ids = torch.cat([input_ids, next_token_id], dim=-1)

    # Decode the input IDs to a string
    generated_text = tokenizer.decode(input_ids[0])

    return generated_text

In [None]:
input_text = "Once upon a time"
num_tokens_to_generate = 20
print("Generated text: ", full_model_prediction(model, input_text, num_tokens_to_generate))

In [None]:
from matplotlib import pyplot as plt

def early_prediction(model, input_text, num_tokens_to_generate=20, exit_layer=5):
    def get_hiddenstates_attn(input_ids, model, tokenizer):
        with torch.no_grad():
            output = model(input_ids=input_ids)
        hidden_states = output.hidden_states
        attentions = output.attentions
        return hidden_states, attentions

    token_probabilities = []
    input_ids = tokenizer.encode(input_text, return_tensors='pt')

    for _ in range(num_tokens_to_generate):
        # outputs = model(input_ids)
        hidden_states, attns = get_hiddenstates_attn(input_ids, model, tokenizer)

        ln_f = nn.LayerNorm(model.config.n_embd, eps=model.config.layer_norm_epsilon)
        layer_norm = ln_f(hidden_states[exit_layer])

        # if apply_layer_norm:
        if exit_layer != len(hidden_states) - 1:
            # print("yes for ", str(exit_layer))
            ln_f = nn.LayerNorm(model.config.n_embd, eps=model.config.layer_norm_epsilon)
            layer_norm = ln_f(hidden_states[exit_layer])
            logits = model.lm_head(layer_norm)
        else:
            # print("no layernorm or lm_head for ", str(exit_layer))
            # logits = model.lm_head(hidden_states[exit_layer])
            logits = model.lm_head(hidden_states[exit_layer])

        # Only use the logits from the last token position
        next_token_logits = logits[:, -1, :]
        next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)
        
        k = 15
        top_k_tokens = torch.topk(next_token_logits, k, dim=-1)
        top_k_probabilities = F.softmax(top_k_tokens.values, dim=-1)
        
        plt.figure()
        plt.plot(top_k_probabilities[0].tolist())
        plt.title('Top-15 probabilities for early exit layer ' + str(exit_layer))
        
        # plt.legend()

        # print(next_token_logits)

        # Append the predicted token ID to the input sequence
        input_ids = torch.cat([input_ids, next_token_id], dim=-1)

        # get the probabilities
        probabilities = F.softmax(next_token_logits, dim=-1)
        top_proba = probabilities[0][next_token_id].item()
        # top_proba = next_token_logits[0][next_token_id].item()
        token_probabilities.append(top_proba)

    # print(input_ids)
    generated_text = tokenizer.decode(input_ids[0])

    return generated_text, token_probabilities

In [None]:
token_probabilities = []
# input_ids = tokenizer.encode(input_text, return_tensors='pt')
input_ids = tokenizer.encode(shakes_best[0]['Text'], return_tensors='pt')
exit_layer = 5

print("Input ID length: ", len(input_ids[0]))

def get_hiddenstates_attn(input_ids, model, tokenizer):
    with torch.no_grad():
        output = model(input_ids=input_ids)
    hidden_states = output.hidden_states
    attentions = output.attentions
    return hidden_states, attentions

# for _ in range(num_tokens_to_generate):
# outputs = model(input_ids)
hidden_states, attns = get_hiddenstates_attn(input_ids, model, tokenizer)

ln_f = nn.LayerNorm(model.config.n_embd, eps=model.config.layer_norm_epsilon)
layer_norm = ln_f(hidden_states[exit_layer])

# print(layer_norm.shape)
plt.imshow(layer_norm[0].detach().numpy())
plt.colorbar()


In [None]:
len(model(input_ids).hidden_states)

In [None]:
token_probabilities = []
# input_ids = tokenizer.encode(input_text, return_tensors='pt')
input_ids = tokenizer.encode(shakes_best[0]['Text'], return_tensors='pt')
exit_layer = 5

print("Input ID length: ", len(input_ids[0]))

def get_hiddenstates_attn(input_ids, model, tokenizer):
    with torch.no_grad():
        output = model(input_ids=input_ids)
    hidden_states = output.hidden_states
    attentions = output.attentions
    return hidden_states, attentions

# for _ in range(num_tokens_to_generate):
# outputs = model(input_ids)
hidden_states, attns = get_hiddenstates_attn(input_ids, model, tokenizer)

ln_f = nn.LayerNorm(model.config.n_embd, eps=model.config.layer_norm_epsilon)
layer_norm = ln_f(attns[exit_layer])

# print(layer_norm.shape)
plt.imshow(layer_norm[0].detach().numpy())
plt.colorbar()


In [None]:
from torchinfo import summary as ti_summary

ti_summary(model, depth=4)

In [None]:
dir()

In [None]:
token_probabilities = []
# input_ids = tokenizer.encode(input_text, return_tensors='pt')
input_ids = tokenizer.encode(shakes_best[0]['Text'], return_tensors='pt')
exit_layer = 5

def get_hiddenstates_attn(input_ids, model, tokenizer):
    with torch.no_grad():
        output = model(input_ids=input_ids)
    hidden_states = output.hidden_states
    attentions = output.attentions
    return hidden_states, attentions

# outputs = model(input_ids)
hidden_states, attns = get_hiddenstates_attn(input_ids, model, tokenizer)

ln_f = nn.LayerNorm(model.config.n_embd, eps=model.config.layer_norm_epsilon)
layer_norm = ln_f(hidden_states[exit_layer])

# print(layer_norm.shape)
plt.hist(layer_norm[0].detach().numpy())
# plt.colorbar()

In [None]:
from datasets import load_dataset
# ChatGPT - take 100 shortest entries from Tiny Shakespeare
def get_n_longest_tiny_shakespeare(n=25):
    # Load the dataset
    train_dataset = load_dataset("Trelis/tiny-shakespeare", split="train")

    # Define a function to get the length of each input
    def input_length(example):
        return len(example['Text'])

    # Add a new field to each example in the dataset containing the length of the input
    train_dataset = train_dataset.map(lambda x: {"length": input_length(x)})

    # Sort the dataset by the newly added 'length' field
    sorted_dataset = train_dataset.sort("length", reverse=True)

    # Get the top 100 shortest inputs
    shortest_100 = sorted_dataset.select(range(n))

    # Print the shortest 100 inputs (optional)
    # for example in shortest_100:
    #     print(len(example['Text']))

    return shortest_100

In [None]:
shakes_best = get_n_longest_tiny_shakespeare(n=25)
shakes_best

In [None]:
shakes_best[0]['Text']

In [None]:
len(shakes_best[0]['Text'])

In [None]:
input_text = "Once upon a time"
num_tokens_to_generate = 5
generated_text, token_probabilities = early_prediction(model, input_text, num_tokens_to_generate, exit_layer=24)
print("Generated text: ", generated_text)
print("Probabilities: ", str(token_probabilities))

#### Let's compare the prediction from the different early exit blocks

In [None]:
input_text = "Today I"
num_tokens_to_generate = 1
gens = []
probas = []
for exit_layer in tqdm(range(25)):  # Loop over the layers
    generated_text, token_probabilities = early_prediction(model, input_text, num_tokens_to_generate, exit_layer)
    gens.append(generated_text)
    probas.append(token_probabilities)

In [None]:
# printing out the different generations
for layer, gen in enumerate(gens):
    print("Generated text for layer ", str(layer), ": ")
    print("\t", gen)
    # print("\tAverage probability: ", str(sum(probas[int(layer)])/len(probas[int(layer)])))

In [None]:
average_probas = [sum(layer)/len(layer) for layer in probas]
first_token_probas = [layer[0] for layer in probas]

In [None]:
probas

In [None]:
first_token_probas

In [None]:
import matplotlib.pyplot as plt
import numpy as np

plt.plot(first_token_probas)

In [None]:
model.Model.summary()

In [None]:
from transformers import GPT2LMHeadModel

In [None]:
model.config

In [None]:
print(model)

In [None]:
dir(model)

### Prediction time

In [None]:
# code

### Benchmark & plotting

In [None]:
# code