In [13]:
import numpy as np
import pickle
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config, AutoTokenizer, AutoModel, AutoConfig
import torch
from modular_transformers.straightening.straightening_utils import compute_model_activations, compute_model_curvature

from matplotlib import pyplot as plt

from datasets import load_dataset
from nltk.tokenize import sent_tokenize
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.nn.functional as F

import os

from tqdm import tqdm

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
path = "/om2/user/jackking/modular_transformers/scripts/attention_interpretability"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x154fd8161d60>

# Split By Curvature

In [11]:
datasource = "ud"
sentence_length = 10
datatype = f"{datasource}/{sentence_length}_word"
model_name = "gpt2-xl"

data_dir = "/rdma/vast-rdma/vast/evlab/ehoseini/MyData/sent_sampling/analysis/straightening/generation/sentences_ud_sentencez_token_filter_v3_textNoPeriod_cntx_3_cont_7.pkl"
with open(data_dir, 'rb') as f:
    data = pickle.load(f)

sentences = []
for i, raw_sentence in enumerate(data):
    sentence = tokenizer.encode(raw_sentence)
    if len(sentence) < sentence_length:
        pass
    sentence = sentence[:sentence_length]
    sentences.append(sentence)

model = GPT2LMHeadModel.from_pretrained(model_name)
model.to(device)

activations = compute_model_activations(model, sentences, device)
curvatures = compute_model_curvature(activations)

full_path = f"{path}/data/{datatype}"

if not os.path.exists(full_path):
    os.makedirs(full_path)

with open(f"{full_path}/{model_name}_activations.pkl", "wb") as f:
    pickle.dump(activations, f)

with open(f"{full_path}/{model_name}_curvatures.pkl", "wb") as f:
    pickle.dump(curvatures, f)

with open(f"{full_path}/sentences.pkl", "wb") as f:
    pickle.dump(sentences, f)

# Split By Sentence Statistics

## Collect All Sentences

In [14]:
fw = load_dataset("HuggingFaceFW/fineweb", name="sample-10BT", split="train", streaming=True)
fw_text = fw.select_columns("text")

In [18]:
two_word_sets = []
for i, sample in enumerate(fw_text):
    sample = tokenizer.encode(sample["text"])
    #randomly extract 10 pairs of words
    for j in range(10):
        start = np.random.randint(0, len(sample)-2)
        two_word_sets.append(sample[start:start+2])
    if i % 1000 == 0:
        print(i)
    if len(two_word_sets) >= 100000:
        break

Token indices sequence length is longer than the specified maximum sequence length for this model (1048 > 1024). Running this sequence through the model will result in indexing errors


0
1000
2000
3000
4000
5000
6000
7000
8000
9000


In [7]:
sentence_length = 10
datasource = "fineweb"
datatype = f"{datasource}/{sentence_length}_word"
num_sentences = 10000

sentences = []
for i, sample in enumerate(fw_text):
    #pull out any sentences are of length 10
    sample = sent_tokenize(sample["text"])
    sents = [sentence for sentence in sample if len(tokenizer.encode(sentence)) == sentence_length]
    if sents:
        sentences.extend(sents)
    if len(sentences) > num_sentences:
        break

full_path = f"{path}/data/{datatype}"

if not os.path.exists(full_path):
    os.makedirs(full_path)

with open(f"{full_path}/all_sentences.pkl", "wb") as f:
    pickle.dump(sentences, f)

## Calculate full context and two token suprisal for each sentence

In [8]:
sentence_length = 10
datasource = "fineweb"
datatype = f"{datasource}/{sentence_length}_word"
full_path = f"{path}/data/{datatype}"

with open(f"{full_path}/all_sentences.pkl", "rb") as f:
    sentences = pickle.load(f)

In [9]:
batch_size = 64

class SentenceDataset(Dataset):
    def __init__(self, inputs):
        self.inputs = inputs

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        inputs = self.inputs[idx]
        return inputs

inputs = tokenizer(sentences, return_tensors="pt", max_length=sentence_length, truncation=True)["input_ids"]
dataset = SentenceDataset(inputs)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

# finding the suprisal at each token given the full previous context

def get_whole_context_surprisals(model, dataloader, device):
    model.eval()
    surprisals = []
    for batch in tqdm(dataloader):
        inputs = batch.to(device)
        with torch.no_grad():
            outputs = model(inputs)
        logits = outputs.logits
        log_probs = -F.log_softmax(logits, dim=-1)

        for i in range(inputs.shape[0]):
            #find the surprisal of each place in the context predicting the next token
            token_indices = inputs[i, 1:].cpu().numpy()
            token_log_probs = log_probs[i, torch.arange(len(token_indices)), token_indices]
            surprisals.append(token_log_probs.cpu().numpy())

    return surprisals

whole_context_surprisals = get_whole_context_surprisals(model, dataloader, device)

with open(f"{full_path}/whole_context_surprisals.pkl", "wb") as f:
    pickle.dump(whole_context_surprisals, f)

# finding the surprisal at each token given only the previous two tokens

def get_two_token_context_surprisals(model, dataloader, device):
    model.eval()
    surprisals = []
    for batch in tqdm(dataloader):
        #get every three token pair
        # shape: (batch_size, max_length-2, 3)
        batch = np.stack([batch[:, i:i+3] for i in range(sentence_length - 2)], axis=1)
        for sample in batch:
            inputs = torch.tensor(sample).to(device)
            with torch.no_grad():
                outputs = model(inputs)
            logits = outputs.logits
            log_probs = -F.log_softmax(logits, dim=-1)

            one_sentence = []

            for i in range(inputs.shape[0]):
                #for each three token string
                token_index = inputs[i, -1].cpu().numpy() #get the last token
                token_log_prob = log_probs[i, -2, token_index] #prob of the second to last token predicting the last token
                one_sentence.append(token_log_prob.item())

            surprisals.append(one_sentence)

    return surprisals

two_token_context_surprisals = get_two_token_context_surprisals(model, dataloader, device)

with open(f"{full_path}/two_token_context_surprisals.pkl", "wb") as f:
    pickle.dump(two_token_context_surprisals, f)


100%|██████████| 157/157 [00:23<00:00,  6.72it/s]
100%|██████████| 157/157 [05:07<00:00,  1.96s/it]


## Find sentences with correct statistics 

In [None]:
sentence_length = 10
datasource = "fineweb"
datatype = f"{datasource}/{sentence_length}_word"
full_path = f"{path}/data/{datatype}"

whole_context_surprisals = pickle.load(open(f"{full_path}/whole_context_surprisals.pkl", "rb"))
two_token_surprisals = pickle.load(open(f"{full_path}/two_token_context_surprisals.pkl", "rb"))
sentences = pickle.load(open(f"{full_path}/all_sentences.pkl", "rb"))

whole_context_surprisals = np.array(whole_context_surprisals)
two_token_surprisals = np.array(two_token_surprisals)

In [None]:
slow_sentences = []
whole_slow_surprisals = []
two_token_slow_surprisals = []
fast_sentences = []
whole_fast_surprisals = []
two_token_fast_surprisals = []

all_indivd_suprisals = []
for surprisal in whole_context_surprisals:
    all_indivd_suprisals.extend(surprisal)
all_indivd_suprisals = np.array(all_indivd_suprisals)
whole_suprisal_mean = np.mean(all_indivd_suprisals)
whole_surprisal_std = np.std(all_indivd_suprisals)

all_indivd_suprisals = []
for surprisal in two_token_surprisals:
    all_indivd_suprisals.extend(surprisal)
all_indivd_suprisals = np.array(all_indivd_suprisals)
two_token_suprisal_mean = np.mean(all_indivd_suprisals)
two_token_surprisal_std = np.std(all_indivd_suprisals)


def check_for_fast(whole_context_surprisal, two_token_surprisal):
    #whole context surprisal is high and limited (two token) context is low
    clause1 = np.all(whole_context_surprisal > whole_suprisal_mean - whole_surprisal_std*4/5) 
    clause2 = np.all(two_token_surprisal < (two_token_suprisal_mean + two_token_surprisal_std / 2))

    return clause1 and clause2
                                                                
def check_for_slow(whole_context_surprisal, two_token_surprisal):
    # the last two tokens are low surprisal with full context
    clause1 = np.all(whole_context_surprisal[-2:] < whole_suprisal_mean - whole_surprisal_std / 3)

    # the third to last token is also low surprisal with full context
    clause3 = np.all(whole_context_surprisal[-4:-2] < whole_suprisal_mean)

    #in general the model does not have fast statistics
    clause2 = np.all(two_token_surprisal > two_token_suprisal_mean - two_token_surprisal_std*5/8) 

    #this is especially true for the last few tokens, because their prediction needs to come from slow context
    clause4 = np.all(two_token_surprisal[-3:] > two_token_suprisal_mean - two_token_surprisal_std*2/3)

    return clause1 and clause2 and clause3 and clause4

for sentence, whole_context_surprisal, two_token_surprisal in zip(sentences, whole_context_surprisals, two_token_surprisals):

    if check_for_slow(whole_context_surprisal, two_token_surprisal):
        slow_sentences.append(sentence)
        whole_slow_surprisals.append(whole_context_surprisal)
        two_token_slow_surprisals.append(two_token_surprisal)

    elif check_for_fast(whole_context_surprisal, two_token_surprisal):
        fast_sentences.append(sentence)
        whole_fast_surprisals.append(whole_context_surprisal)
        two_token_fast_surprisals.append(two_token_surprisal)

whole_fast_surprisals = np.array(whole_fast_surprisals)
two_token_fast_surprisals = np.array(two_token_fast_surprisals)
whole_slow_surprisals = np.array(whole_slow_surprisals)
two_token_slow_surprisals = np.array(two_token_slow_surprisals)

with open(f"{full_path}/slow_sentences.pkl", "wb") as f:
    pickle.dump(slow_sentences, f)

with open(f"{full_path}/whole_slow_surprisals.pkl", "wb") as f:
    pickle.dump(whole_slow_surprisals, f)

with open(f"{full_path}/two_token_slow_surprisals.pkl", "wb") as f:
    pickle.dump(two_token_slow_surprisals, f)

with open(f"{full_path}/fast_sentences.pkl", "wb") as f:
    pickle.dump(fast_sentences, f)

with open(f"{full_path}/whole_fast_surprisals.pkl", "wb") as f:
    pickle.dump(whole_fast_surprisals, f)

with open(f"{full_path}/two_token_fast_surprisals.pkl", "wb") as f:
    pickle.dump(two_token_fast_surprisals, f)

In [None]:
np.mean(whole_fast_surprisals), np.mean(two_token_fast_surprisals), np.mean(whole_slow_surprisals), np.mean(two_token_slow_surprisals)
np.mean(whole_fast_surprisals, axis=0), np.mean(whole_slow_surprisals, axis=0)
np.mean(two_token_fast_surprisals, axis=0), np.mean(two_token_slow_surprisals, axis=0)

## Compare Curvature

In [None]:
sentence_length = 10
datasource = "fineweb"
datatype = f"{datasource}/{sentence_length}_word"
full_path = f"{path}/data/{datatype}"

slow_sentences = pickle.load(open(f"{full_path}/slow_sentences.pkl", "rb"))
fast_sentences = pickle.load(open(f"{full_path}/fast_sentences.pkl", "rb"))

model = GPT2LMHeadModel.from_pretrained("gpt2-xl")
model.to(device)

fast_sentences = tokenizer(fast_sentences, return_tensors="pt", max_length=sentence_length, truncation=True)["input_ids"]
activations = compute_model_activations(model, fast_sentences, device)
fast_curvature = compute_model_curvature(activations)

slow_sentences = tokenizer(slow_sentences, return_tensors="pt", max_length=sentence_length, truncation=True)["input_ids"]
activations = compute_model_activations(model, slow_sentences, device)
slow_curvature = compute_model_curvature(activations)

In [None]:
fast_curve = np.mean(fast_curvature["curve"], axis=1) / np.pi * 180
slow_curve = np.mean(slow_curvature["curve"], axis=1) / np.pi * 180
fast_curve_std = np.std(fast_curvature["curve"], axis=1) / np.pi * 180
slow_curve_std = np.std(slow_curvature["curve"], axis=1) / np.pi * 180
plt.plot(fast_curve, label="fast")
plt.plot(slow_curve, label="slow")
plt.fill_between(np.arange(len(fast_curve)), fast_curve - fast_curve_std, fast_curve + fast_curve_std, alpha=0.5)
plt.fill_between(np.arange(len(slow_curve)), slow_curve - slow_curve_std, slow_curve + slow_curve_std, alpha=0.5)
plt.legend()
plt.show()

### Curvature over last four tokens

In [None]:
fast_curve = np.mean(np.mean(fast_curvature["all_layer_curve_all"], axis=0)[:, -4:], axis = -1) / np.pi * 180
slow_curve = np.mean(np.mean(slow_curvature["all_layer_curve_all"], axis=0)[:, -4:], axis = -1) / np.pi * 180
fast_curve_std = np.std(fast_curvature["curve"], axis=1) / np.pi * 180
slow_curve_std = np.std(slow_curvature["curve"], axis=1) / np.pi * 180
plt.plot(fast_curve, label="fast")
plt.plot(slow_curve, label="slow")
plt.fill_between(np.arange(len(fast_curve)), fast_curve - fast_curve_std, fast_curve + fast_curve_std, alpha=0.5)
plt.fill_between(np.arange(len(slow_curve)), slow_curve - slow_curve_std, slow_curve + slow_curve_std, alpha=0.5)
plt.legend()
plt.show()

### Curvature at Each token

In [None]:
fig, axs = plt.subplots(7, 1, figsize=(8, 20))
for i in range(1, 8):
    curve = np.mean(np.mean(slow_curvature["all_layer_curve_all"], axis=0)[:, i-1:i], axis = -1) / np.pi * 180
    axs[i-1].plot(curve, label="slow")
    curve = np.mean(np.mean(fast_curvature["all_layer_curve_all"], axis=0)[:, i-1:i], axis = -1) / np.pi * 180
    axs[i-1].plot(curve, label="fast")
    axs[i-1].set_title(f"Curve {i}")

plt.legend()
plt.show()