In [None]:
import itertools
import os
import pickle
import random
from collections import Counter

import matplotlib.pyplot as plt
import numpy as np
import torch as torch
from datasets import load_dataset, load_from_disk
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import GPT2Tokenizer, GPT2LMHeadModel


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Root directory under which the generated datasets are read and written
path = "/om2/user/jackking/modular_transformers/scripts/input_statistics/data"
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

In [None]:
class LMDataset(Dataset):
    """Map-style torch Dataset over pre-tokenized sequences.

    Each item is a dict with 'input_ids' and, when they were provided, 'attention_mask'
    and 'labels' (keys always appear in that order).
    """

    def __init__(self, inputs, attn_mask=None, labels=None):
        """
        Args:
            inputs: token ids; a tensor or anything `torch.tensor` accepts
                (e.g. a list of equal-length lists).
            attn_mask: optional attention mask aligned with `inputs`.
            labels: optional labels aligned with `inputs` (LM targets or class ids).
        """
        # None is passed through; torch.tensor(None) would raise.
        self.inputs = self._as_tensor(inputs)
        self.attn_mask = self._as_tensor(attn_mask)
        self.labels = self._as_tensor(labels)

    @staticmethod
    def _as_tensor(values):
        """Convert `values` to a tensor unless it already is one; None stays None."""
        if values is None or torch.is_tensor(values):
            return values
        return torch.tensor(values)

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        item = {'input_ids': self.inputs[idx]}
        if self.attn_mask is not None:
            item['attention_mask'] = self.attn_mask[idx]
        if self.labels is not None:
            item['labels'] = self.labels[idx]
        return item

def make_autoregressive_dataset(data):
    """Tokenize raw strings into a right-padded causal-LM dataset.

    Args:
        data: list of strings, encoded with the module-level GPT-2 `tokenizer`.

    Returns:
        (LMDataset, context_len): a dataset with input_ids / attention_mask / labels, and
        the padded sequence length (the longest encoded string).

    Side effect: sets `tokenizer.pad_token = tokenizer.eos_token` on the shared tokenizer.
    """
    tokenizer.pad_token = tokenizer.eos_token
    dataset = tokenizer.batch_encode_plus(data, add_special_tokens=True, padding='longest', return_tensors="pt")
    inputs = dataset["input_ids"]
    attn_mask = dataset["attention_mask"]
    labels = inputs.clone()
    # Padding reuses the EOS id; -100 makes the Hugging Face LM loss ignore those positions
    # instead of training the model to predict padding.
    labels[attn_mask == 0] = -100
    context_len = inputs.size(1)
    return LMDataset(inputs, attn_mask, labels), context_len

def make_classification_dataset(data1, data2):
    """Tokenize two groups of strings into a binary classification dataset.

    Strings from `data1` get label 0 and strings from `data2` get label 1; all are padded to
    the longest encoded string with the module-level GPT-2 `tokenizer` (pad = EOS, set on the
    shared tokenizer).

    Returns:
        (LMDataset, context_len): dataset with input_ids / attention_mask / labels and the
        padded sequence length.
    """
    tokenizer.pad_token = tokenizer.eos_token
    texts = data1 + data2
    class_ids = [0] * len(data1) + [1] * len(data2)
    encoded = tokenizer.batch_encode_plus(texts, add_special_tokens=True, padding='longest', return_tensors="pt")
    token_ids = encoded["input_ids"]
    return LMDataset(token_ids, encoded["attention_mask"], torch.tensor(class_ids)), token_ids.size(1)

# Experiment 1

## Bigrams

In [None]:
def get_bigram_model_to_sample_from(data):
    """Map each token to the list of tokens that directly follow it anywhere in `data`.

    Repeated successors are kept, so sampling uniformly from a list reproduces the empirical
    bigram frequencies. Tokens that only ever appear last in a sequence get no entry.
    """
    successors = {}
    for seq in data:
        for prev_tok, next_tok in zip(seq[:-1], seq[1:]):
            successors.setdefault(prev_tok, []).append(next_tok)
    return successors

def sample_from_bigram_model(bigram_model, num_samples, string_len):
    """Draw `num_samples` token sequences from a successor-list bigram model.

    Uses Python's `random` module. The first token is a uniformly chosen context token. Each
    following token is drawn from the previous token's successor list (repeats act as weights),
    or from the context tokens when the previous token has no successors. Every sequence has
    max(string_len, 1) tokens.
    """
    start_tokens = list(bigram_model.keys())
    samples = []
    for _ in tqdm(range(num_samples)):
        seq = [random.choice(start_tokens)]
        while len(seq) < string_len:
            prev = seq[-1]
            candidates = bigram_model[prev] if prev in bigram_model else start_tokens
            seq.append(random.choice(candidates))
        samples.append(seq)
    return samples

#### Note: with string_len=128, only an insignificant number of samples have a relevant attention mask

In [None]:
string_len = 128
train_set_size = 20000
valid_set_size = 5000
datatype = "experiment_1"

data_path = '/om/weka/evlab/ehoseini/MyData/miniBERTa_v2/'
# For each split (train first, then valid, so the np.random draws keep their order): sample
# `split_size` sequences without replacement, truncate to `string_len` tokens, and pickle them
# as <split>_data_A.pkl.
for split_name, split_size in (("train", train_set_size), ("valid", valid_set_size)):
    grouped_pad_train = load_from_disk(
        os.path.join(data_path, 'miniBERTa-10M-crunched', f'{split_name}_context_len_512'))
    subset_idxs = np.random.choice(len(grouped_pad_train), split_size, replace=False)
    subset = [seq[:string_len] for seq in grouped_pad_train.select(subset_idxs)["input_ids"]]

    with open(f"{path}/{datatype}/{split_name}_data_A.pkl", 'wb') as f:
        pickle.dump(subset, f)

In [None]:
# Experiment 1 sizes, re-declared so the cells below can run without re-sampling the data above
string_len, train_set_size, valid_set_size = 128, 20000, 5000
datatype = "experiment_1"

In [None]:
# Natural (A) train+valid sequences are the source statistics for the bigram models
with open(f"{path}/{datatype}/train_data_A.pkl", 'rb') as f:
    train_data = pickle.load(f)
with open(f"{path}/{datatype}/valid_data_A.pkl", 'rb') as f:
    valid_data = pickle.load(f)
data = train_data + valid_data

# B1: sequences sampled from the fitted bigram model, shuffled, then split train/valid
bigram_model = get_bigram_model_to_sample_from(data)
data_B1 = sample_from_bigram_model(bigram_model, train_set_size//2 + valid_set_size//2, string_len = string_len)
np.random.shuffle(data_B1)
train_data_B1 = data_B1[:train_set_size//2]
valid_data_B1 = data_B1[train_set_size//2:]

# B2: "switched" bigram model. Each context token is given the successor list of another
# (randomly permuted) context token, so the transitions no longer match the natural data.
all_tokens = list(bigram_model.keys())
np.random.shuffle(all_tokens)
new_bigram_model = {
    first_token: bigram_model[source_token]
    for first_token, source_token in zip(tqdm(bigram_model), all_tokens)
}

data_B2 = sample_from_bigram_model(new_bigram_model, train_set_size//2 + valid_set_size//2, string_len = string_len)
np.random.shuffle(data_B2)
train_data_B2 = data_B2[:train_set_size//2]
valid_data_B2 = data_B2[train_set_size//2:]

# NOTE(review): B1 is regenerated here but its dumps are disabled, while later cells read
# train_data_EB1.pkl — re-enable these if the EB1 files should match this run.
# with open(f"{path}/{datatype}/train_data_EB1.pkl", 'wb') as f:
#     pickle.dump(train_data_B1, f)

# with open(f"{path}/{datatype}/valid_data_EB1.pkl", 'wb') as f:
#     pickle.dump(valid_data_B1, f)

with open(f"{path}/{datatype}/train_data_EB2.pkl", 'wb') as f:
    pickle.dump(train_data_B2, f)

with open(f"{path}/{datatype}/valid_data_EB2.pkl", 'wb') as f:
    pickle.dump(valid_data_B2, f)
    


## Trigrams

In [None]:
def get_trigram_model_to_sample_from(data):
    """Map each 2-token context (a tuple) to the list of tokens that follow it in `data`.

    Repeated continuations are kept so they act as empirical weights when sampling.
    """
    trigram_model = {}
    for sentence in data:
        for first, second, third in zip(sentence, sentence[1:], sentence[2:]):
            trigram_model.setdefault((first, second), []).append(third)
    return trigram_model

def sample_from_trigram_model(trigram_model, num_samples, string_len):
    """Draw `num_samples` sequences of `string_len` tokens from a trigram model.

    Uses Python's `random` module. Each sequence starts with a uniformly chosen 2-token context.
    Next tokens are drawn from the continuation list of the last two tokens. When that context
    was never observed, a fresh random context (2 tokens) is appended instead. The result is
    truncated to `string_len`.

    Args:
        trigram_model: dict mapping (tok, tok) -> list of observed next tokens.
        num_samples: number of sequences to generate.
        string_len: target sequence length.

    Returns:
        list of token lists.
    """
    trigram_model_keys = list(trigram_model.keys())
    samples = []
    for _ in tqdm(range(num_samples)):
        sample = list(random.choice(trigram_model_keys))
        # Loop on the actual length: a restart appends two tokens at once, which a fixed-count
        # for-loop cannot account for (reassigning its index has no effect).
        while len(sample) < string_len:
            first_two = (sample[-2], sample[-1])
            if first_two in trigram_model:
                sample.append(random.choice(trigram_model[first_two]))
            else:
                sample.extend(random.choice(trigram_model_keys))
        samples.append(sample[:string_len])
    return samples
        

In [None]:
# Natural (A) train+valid sequences are the source statistics for the trigram models
with open(f"{path}/{datatype}/train_data_A.pkl", 'rb') as f:
    train_data = pickle.load(f)
with open(f"{path}/{datatype}/valid_data_A.pkl", 'rb') as f:
    valid_data = pickle.load(f)
data = train_data + valid_data

# T1: sequences sampled from the fitted trigram model, shuffled, then split train/valid
trigram_model = get_trigram_model_to_sample_from(data)
data_T1 = sample_from_trigram_model(trigram_model, train_set_size//2 + valid_set_size//2, string_len)
np.random.shuffle(data_T1)
train_data_T1 = data_T1[:train_set_size//2]
valid_data_T1 = data_T1[train_set_size//2:]

with open(f"{path}/{datatype}/train_data_ET1.pkl", 'wb') as f:
    pickle.dump(train_data_T1, f)

with open(f"{path}/{datatype}/valid_data_ET1.pkl", 'wb') as f:
    pickle.dump(valid_data_T1, f)

# T2: "switched" trigram model. Each 2-token context is given the continuation list of another
# (randomly permuted) context.
all_tokens = list(trigram_model.keys())  # the 2-token contexts
np.random.shuffle(all_tokens)
new_trigram_model = {
    first_two_tokens: trigram_model[source_context]
    for first_two_tokens, source_context in zip(tqdm(trigram_model), all_tokens)
}

data_T2 = sample_from_trigram_model(new_trigram_model, train_set_size//2 + valid_set_size//2, string_len)
np.random.shuffle(data_T2)
train_data_T2 = data_T2[:train_set_size//2]
valid_data_T2 = data_T2[train_set_size//2:]

with open(f"{path}/{datatype}/train_data_ET2.pkl", 'wb') as f:
    pickle.dump(train_data_T2, f)

with open(f"{path}/{datatype}/valid_data_ET2.pkl", 'wb') as f:
    pickle.dump(valid_data_T2, f)
    


## Fourgrams

In [None]:
def get_fourgram_model_to_sample_from(data):
    """Map each 3-token context (a tuple) to the list of tokens that follow it in `data`.

    Repeated continuations are kept so they act as empirical weights when sampling.
    """
    fourgram_model = {}
    for sentence in data:
        for t0, t1, t2, t3 in zip(sentence, sentence[1:], sentence[2:], sentence[3:]):
            fourgram_model.setdefault((t0, t1, t2), []).append(t3)
    return fourgram_model

def sample_from_fourgram_model(fourgram_model, num_samples, string_len):
    """Draw `num_samples` sequences of `string_len` tokens from a fourgram model.

    Uses Python's `random` module. Each sequence starts with a uniformly chosen 3-token context.
    Next tokens are drawn from the continuation list of the last three tokens. When that context
    was never observed, a fresh random context (3 tokens) is appended instead. The result is
    truncated to `string_len`.

    Args:
        fourgram_model: dict mapping (tok, tok, tok) -> list of observed next tokens.
        num_samples: number of sequences to generate.
        string_len: target sequence length.

    Returns:
        list of token lists.
    """
    fourgram_model_keys = list(fourgram_model.keys())
    samples = []
    for _ in tqdm(range(num_samples)):
        sample = list(random.choice(fourgram_model_keys))
        # Loop on the actual length: a restart appends three tokens at once, which a fixed-count
        # for-loop cannot account for (reassigning its index has no effect).
        while len(sample) < string_len:
            first_three = (sample[-3], sample[-2], sample[-1])
            if first_three in fourgram_model:
                sample.append(random.choice(fourgram_model[first_three]))
            else:
                sample.extend(random.choice(fourgram_model_keys))
        samples.append(sample[:string_len])
    return samples

In [None]:
# Natural (A) train+valid sequences are the source statistics for the fourgram models
with open(f"{path}/{datatype}/train_data_A.pkl", 'rb') as f:
    train_data = pickle.load(f)
with open(f"{path}/{datatype}/valid_data_A.pkl", 'rb') as f:
    valid_data = pickle.load(f)
data = train_data + valid_data

# F1: sequences sampled from the fitted fourgram model, shuffled, then split train/valid
fourgram_model = get_fourgram_model_to_sample_from(data)
data_F1 = sample_from_fourgram_model(fourgram_model, train_set_size//2 + valid_set_size//2, string_len)
np.random.shuffle(data_F1)
train_data_F1 = data_F1[:train_set_size//2]
valid_data_F1 = data_F1[train_set_size//2:]

with open(f"{path}/{datatype}/train_data_EF1.pkl", 'wb') as f:
    pickle.dump(train_data_F1, f)

with open(f"{path}/{datatype}/valid_data_EF1.pkl", 'wb') as f:
    pickle.dump(valid_data_F1, f)

# F2: "switched" fourgram model. Each 3-token context is given the continuation list of another
# (randomly permuted) context.
all_tokens = list(fourgram_model.keys())  # the 3-token contexts
np.random.shuffle(all_tokens)
new_fourgram_model = {
    first_three_tokens: fourgram_model[source_context]
    for first_three_tokens, source_context in zip(tqdm(fourgram_model), all_tokens)
}

data_F2 = sample_from_fourgram_model(new_fourgram_model, train_set_size//2 + valid_set_size//2, string_len)
np.random.shuffle(data_F2)
train_data_F2 = data_F2[:train_set_size//2]
valid_data_F2 = data_F2[train_set_size//2:]

with open(f"{path}/{datatype}/train_data_EF2.pkl", 'wb') as f:
    pickle.dump(train_data_F2, f)

with open(f"{path}/{datatype}/valid_data_EF2.pkl", 'wb') as f:
    pickle.dump(valid_data_F2, f)

In [None]:
# Rebuild the fourgram model from the natural (A) train+valid sequences for inspection below
with open(f"{path}/{datatype}/train_data_A.pkl", 'rb') as f:
    train_data = pickle.load(f)
with open(f"{path}/{datatype}/valid_data_A.pkl", 'rb') as f:
    valid_data = pickle.load(f)
data = train_data + valid_data

fourgram_model = get_fourgram_model_to_sample_from(data)

In [None]:
# Number of distinct 3-token contexts and a few examples (the full key view can be very large)
len(fourgram_model), list(itertools.islice(fourgram_model, 5))

In [None]:
# Number of observed continuations per 3-token context, largest first, keeping the top TOP_K_CONTEXTS
TOP_K_CONTEXTS = 1000
lengths = sorted((len(continuations) for continuations in fourgram_model.values()), reverse=True)[:TOP_K_CONTEXTS]

In [None]:
# Number of entries kept in `lengths` (the previous cell caps it at its 1000 largest values)
len(lengths)

In [None]:
# Context with the most observed continuations, as (count, context). Kept under its own name so
# `lengths` (plotted by the histogram cell) is not overwritten with tuples.
most_continuations = max((len(continuations), context) for context, continuations in fourgram_model.items())
print(most_continuations)

In [None]:
# Distribution of the number of observed continuations per 3-token context
continuation_counts = [len(continuations) for continuations in fourgram_model.values()]
fig, ax = plt.subplots()
ax.hist(continuation_counts, bins=50)
ax.set(xlabel="continuations per 3-token context", ylabel="number of contexts (log)",
       yscale="log", title="Fourgram model branching")
plt.show()

## Evening out Entropy

In [None]:
def calculate_entropy(model, dataloader):
    """Mean per-token predictive entropy (in nats) of `model` over `dataloader`.

    At every position the entropy H = -sum_v p_v ln p_v of the model's next-token distribution
    is computed. It is then averaged over all counted positions in the whole dataset, so a
    smaller final batch gets proportionally less weight.

    Args:
        model: callable as model(input_ids[, attention_mask=...]) returning an object with
            `.logits` of shape (batch, seq, vocab), e.g. GPT2LMHeadModel.
        dataloader: iterable of dict batches with an "input_ids" tensor and optionally an
            "attention_mask" tensor; positions where the mask is 0 are excluded.

    Returns:
        float: mean token entropy in nats (nan if no positions were counted).
    """
    total_entropy = 0.0
    total_tokens = 0
    for batch in tqdm(dataloader):
        input_ids = batch["input_ids"].to(device)
        attn_mask = batch.get("attention_mask")
        with torch.no_grad():
            if attn_mask is not None:
                attn_mask = attn_mask.to(device)
                logits = model(input_ids, attention_mask=attn_mask).logits
            else:
                logits = model(input_ids).logits

        probs = torch.softmax(logits.float(), dim=-1)
        # entr(p) = -p*ln(p) with entr(0) = 0; p*log(p) gives 0*(-inf) = nan when p underflows
        token_entropies = torch.special.entr(probs).sum(dim=-1)  # (batch, seq)

        if attn_mask is not None:
            mask = attn_mask.to(token_entropies.dtype)
            total_entropy += (token_entropies * mask).sum().item()
            total_tokens += mask.sum().item()
        else:
            total_entropy += token_entropies.sum().item()
            total_tokens += token_entropies.numel()

    if total_tokens == 0:
        return float("nan")
    return total_entropy / total_tokens

In [None]:
model_name = "experiment_1/M1_128_12/faithful-elevator-344/epoch_80"
model_path = f'{path}/{model_name}'
# NOTE(review): entropies below use pretrained "gpt2", not the checkpoint at model_path —
# swap these two lines to score the trained model instead.
# model = GPT2LMHeadModel.from_pretrained(model_path)
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.to(device)
model.eval()
datatype = "experiment_1"

# Natural (A) training sequences
with open(f"{path}/{datatype}/train_data_A.pkl", 'rb') as f:
    natural_data = pickle.load(f)
dataset = LMDataset(natural_data, labels = natural_data)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=128, shuffle=False)
dataset_entropy = calculate_entropy(model, dataloader)
print(f"Dataset entropy: {dataset_entropy}")

# Bigram-sampled training sequences (EB1 + EB2); labels are the same sequences as the inputs
with open(f"{path}/{datatype}/train_data_EB1.pkl", 'rb') as f:
    bigram_data = pickle.load(f)
with open(f"{path}/{datatype}/train_data_EB2.pkl", 'rb') as f:
    bigram_data = bigram_data + pickle.load(f)
dataset = LMDataset(bigram_data, labels = bigram_data)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=128, shuffle=False)
dataset_entropy = calculate_entropy(model, dataloader)
print(f"Bigram dataset entropy: {dataset_entropy}")

In [None]:
# Bigram-sampled training sequences (EB1 + EB2); labels are the same sequences as the inputs
with open(f"{path}/{datatype}/train_data_EB1.pkl", 'rb') as f:
    bigram_data = pickle.load(f)
with open(f"{path}/{datatype}/train_data_EB2.pkl", 'rb') as f:
    bigram_data = bigram_data + pickle.load(f)
dataset = LMDataset(bigram_data, labels = bigram_data)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=128, shuffle=False)
dataset_entropy = calculate_entropy(model, dataloader)
print(f"Bigram dataset entropy: {dataset_entropy}")

In [None]:
import nltk
from nltk import bigrams, FreqDist, ConditionalFreqDist
import numpy as np

# No NLTK data download is needed: bigrams / FreqDist / ConditionalFreqDist operate on
# already-tokenized sequences and do not use the punkt tokenizer models.

def calculate_bigram_entropy(texts):
    """Mean of -p*log2(p) over the distinct bigram types in `texts`, with p = P(next | prev).

    P(next | prev) = count(prev, next) / count(prev followed by any token). Occurrences of
    `prev` at the end of a text have no successor and are therefore not counted.

    Args:
        texts: iterable of token sequences (any hashable tokens).

    Returns:
        float: the mean term in bits, or nan when `texts` contains no bigrams.
    """
    bigram_counts = Counter()
    for text in texts:
        seq = list(text)
        bigram_counts.update(zip(seq, seq[1:]))

    prev_totals = Counter()
    for (prev_tok, _), count in bigram_counts.items():
        prev_totals[prev_tok] += count

    entropies = []
    for (prev_tok, _), count in bigram_counts.items():
        prob = count / prev_totals[prev_tok]
        entropies.append(-prob * np.log2(prob))

    if not entropies:
        return float("nan")
    return np.mean(entropies)

# Train split of both B1 and B2, consistent with the later bigram-entropy cell (this previously
# paired train_data_B1 with valid_data_B2).
# NOTE(review): no visible cell writes train_data_B1/B2.pkl (the bigram cell writes EB* files) — confirm.
with open(f"{path}/{datatype}/train_data_B1.pkl", 'rb') as f:
    bigram_data = pickle.load(f)
with open(f"{path}/{datatype}/train_data_B2.pkl", 'rb') as f:
    bigram_data = bigram_data + pickle.load(f)
calculate_bigram_entropy(bigram_data)

In [None]:
def calculate_entropy_of_bigram_model(bigram_model, unigram_frequencies):
    """Average per-type entropy term of the second token of a bigram model, in bits.

        P(w1)      = unigram_frequencies[w1] / sum(unigram_frequencies.values())
        P(w2 | w1) = share of w1's successor list equal to w2
        P(w2)      = sum_w1 P(w2 | w1) * P(w1)

    Returns sum_w2 -P(w2)*log2(P(w2)) divided by the number of distinct second tokens
    (0.0 when the model has no bigrams).

    Args:
        bigram_model: dict token -> list of observed next tokens (with repeats).
        unigram_frequencies: dict token -> count; must contain every key of `bigram_model`.
    """
    # Normalize by the total count; dividing by the number of distinct words is not a
    # probability and can push P(w2) above 1.
    total_count = sum(unigram_frequencies.values())

    # second token -> {first token: number of (first, second) occurrences}
    pair_counts = {}
    for first_word, second_words in bigram_model.items():
        for second_word, count in Counter(second_words).items():
            pair_counts.setdefault(second_word, {})[first_word] = count

    if not pair_counts:
        return 0.0

    entropy = 0
    for second_word, first_counts in tqdm(pair_counts.items()):
        prob_second_word = sum(
            count / len(bigram_model[first_word]) * unigram_frequencies[first_word] / total_count
            for first_word, count in first_counts.items()
        )
        entropy += -prob_second_word * np.log2(prob_second_word)

    return entropy / len(pair_counts)


def get_bigram_model_with_correct_entropy(data, entropy_target):
    """Build a successor-list bigram model from `data` and print its entropy.

    Alongside the model, counts how often each token occurs in a position that has a successor.
    Those counts are passed to `calculate_entropy_of_bigram_model` for the printed value.

    NOTE(review): `entropy_target` is currently unused; an earlier early-stopping check against
    it (every 200 sentences after 15000) was commented out.

    Returns:
        dict token -> list of observed next tokens (with repeats).
    """
    bigram_model = {}
    unigram_frequencies = {}
    for sentence in data:
        for cur_tok, nxt_tok in zip(sentence[:-1], sentence[1:]):
            unigram_frequencies[cur_tok] = unigram_frequencies.get(cur_tok, 0) + 1
            bigram_model.setdefault(cur_tok, []).append(nxt_tok)

    model_entropy = calculate_entropy_of_bigram_model(bigram_model, unigram_frequencies)
    print(f"Bigram model entropy: {model_entropy}")
    return bigram_model

def sample_from_bigram_model(bigram_model, num_samples, string_len = None, lengths=None):
    """Draw token sequences from a successor-list bigram model using numpy's global RNG.

    NOTE(review): this redefines the earlier `random`-based `sample_from_bigram_model`. Once
    this cell has run, every later call (including re-runs of earlier cells) uses this version.

    Args:
        bigram_model: dict token -> list of observed next tokens (repeats act as weights).
        num_samples: number of sequences to generate.
        string_len: length of every sequence; required despite the None default.
        lengths: accepted for call compatibility but currently unused.

    Returns:
        list of `num_samples` token lists, each of length max(string_len, 1).

    Raises:
        ValueError: if `string_len` is None.
    """
    if string_len is None:
        raise ValueError("string_len must be provided")
    # Hoisted: rebuilding the key list on every draw cost O(vocab) per generated token
    start_tokens = np.asarray(list(bigram_model.keys()))
    samples = []
    for _ in tqdm(range(num_samples)):
        sample = [np.random.choice(start_tokens)]
        for _ in range(string_len - 1):
            if sample[-1] not in bigram_model:
                sample.append(np.random.choice(start_tokens))
            else:
                sample.append(np.random.choice(bigram_model[sample[-1]]))
        samples.append(sample)
    return samples

In [None]:
# Fit a bigram model on `data` (presumably the natural train+valid sequences loaded in an earlier
# cell); the entropy target is passed through but not used by the current implementation.
bigram_model = get_bigram_model_with_correct_entropy(data, entropy_target=dataset_entropy)

In [None]:
def calculate_bigram_entropy_of_dataset(data, bigram_model):
    """Average -p*log2(p) per token of `data`, with p = P(prev | cur) estimated from `bigram_model`.

    P(prev | cur) is the fraction of occurrences of `cur` as a successor in `bigram_model` that
    were preceded by `prev`. For every adjacent pair (prev, cur) in `data` the term -p*log2(p)
    is accumulated. The total is divided by the number of tokens in `data`, which equals
    len(data) * len(data[0]) for equal-length sequences.

    Pairs never seen in `bigram_model` contribute 0. They are counted and reported in a single
    summary line rather than printed one by one.

    Returns:
        float: the averaged value (0.0 for empty `data`).
    """
    # predecessor_counts[cur][prev] = number of (prev, cur) occurrences in the model
    predecessor_counts = {}
    for prev_tok, next_toks in bigram_model.items():
        for next_tok, count in Counter(next_toks).items():
            predecessor_counts.setdefault(next_tok, {})[prev_tok] = count

    predecessor_probs = {}
    for cur_tok, counts in predecessor_counts.items():
        total = sum(counts.values())
        predecessor_probs[cur_tok] = {prev_tok: count / total for prev_tok, count in counts.items()}

    entropy = 0.0
    n_missing = 0
    for sentence in tqdm(data):
        for prev_tok, cur_tok in zip(sentence[:-1], sentence[1:]):
            prob = predecessor_probs.get(cur_tok, {}).get(prev_tok)
            if prob is None:
                n_missing += 1
                continue
            entropy += -prob * np.log2(prob)

    if n_missing:
        print(f"{n_missing} bigrams in data were not found in bigram_model and were skipped")

    n_tokens = sum(len(sentence) for sentence in data)
    if n_tokens == 0:
        return 0.0
    return entropy / n_tokens
                

In [None]:
# NOTE(review): no visible cell writes train_data_B1/B2.pkl (the bigram cell writes EB* files) — confirm.
with open(f"{path}/{datatype}/train_data_B1.pkl", 'rb') as f:
    bigram_data = pickle.load(f)
with open(f"{path}/{datatype}/train_data_B2.pkl", 'rb') as f:
    bigram_data = bigram_data + pickle.load(f)
print(calculate_bigram_entropy_of_dataset(bigram_data, bigram_model))

with open(f"{path}/{datatype}/train_data_A.pkl", 'rb') as f:
    natural_data = pickle.load(f)
print(calculate_bigram_entropy_of_dataset(natural_data, bigram_model))

# Experiment 2

In [None]:
import numpy as np

# Monte-Carlo estimate of the angle at B between (A - B) and (C + noise - B), for random points in
# a high-dimensional space, with isotropic Gaussian noise of std `sigma` added to C.
dim = 3000
sigma = 1
num_samples = 20000

# Base points, drawn uniformly from [0, 1)^dim
A = np.random.rand(dim)
B = np.random.rand(dim)
C = np.random.rand(dim)


def calculate_angle(u, v):
    """Angle between vectors u and v in radians (cosine clipped to [-1, 1] before arccos)."""
    cos_theta = np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v))
    return np.arccos(np.clip(cos_theta, -1.0, 1.0))


angles = [
    calculate_angle(A - B, (C + np.random.normal(0, sigma, dim)) - B)
    for _ in range(num_samples)
]

average_angle_radians = np.mean(angles)
average_angle_degrees = np.degrees(average_angle_radians)

average_angle_radians, average_angle_degrees


In [None]:
# Experiment 2 configuration. Lines tagged "leave" are only used when a leave probability
# (jumping to a random cycle) is enabled.
vocabulary = np.arange(1, 5000, 1)
context_length = 10
cycle_len = 10
leave_probability = 0.03
num_samples = 50000
num_cycles = 200

# Step sizes 1..cycle_len-1 around a cycle, with probability decaying as exp(-step / 5)
token_picker = np.arange(1, cycle_len)
# total_probability = 1 - leave_probability   # leave
total_probability = 1
step_weights = np.exp(-np.arange(1, cycle_len, 1) / 5)
probability_model = total_probability * (step_weights / np.sum(step_weights))
# probability_model = np.concatenate((probability_model, [leave_probability]))   # leave
# token_picker = np.concatenate((token_picker, [-100]))   # leave
# token_picker = np.concatenate((token_picker, [-100]))

In [None]:
# Display the normalized step-size probabilities (one per step size in token_picker)
probability_model

In [None]:
# Draw `num_cycles` cycles, each an ordered array of `cycle_len` distinct vocabulary tokens
cycles = [np.random.choice(vocabulary, size=cycle_len, replace=False) for _ in range(num_cycles)]

In [None]:
# Each sample is a walk of `context_length` tokens. It starts at a random position of a uniformly
# chosen cycle and advances by steps drawn from `token_picker` with `probability_model`,
# wrapping around the cycle. The label is the index of the cycle the walk started in.
samples = []
cycle_indexes = []  # labels
for _ in tqdm(range(num_samples)):
    cycle_index = np.random.choice(num_cycles)
    cycle_indexes.append(cycle_index)
    cycle = cycles[cycle_index]

    token_index = np.random.choice(cycle_len)
    walk = [cycle[token_index]]

    while len(walk) < context_length:
        step = np.random.choice(token_picker, p=probability_model)
        if step == -100:
            # "leave" step (only possible when the leave probability is enabled): jump to a random cycle
            cycle_index = np.random.choice(num_cycles)
            cycle = cycles[cycle_index]
            token_index = np.random.choice(cycle_len)
        else:
            token_index = int((token_index + step) % cycle_len)
        walk.append(cycle[token_index])

    samples.append(walk)
        

In [None]:
# Lengths of maximal runs of consecutive samples that share the same cycle label. A label that
# differs from both neighbours is a run of length 1, and the final run is included.
lengths = [sum(1 for _ in run) for _, run in itertools.groupby(cycle_indexes)]

fig, ax = plt.subplots()
ax.hist(lengths, bins=50, alpha=0.5, label='generated data')
ax.set(xlabel="run length (consecutive samples with the same cycle label)", ylabel="count")
ax.legend(loc='upper right')
plt.show()

In [None]:
# 80/20 train/validation split of the Experiment 2 walks, saved as {"inputs", "labels"} dicts
data_path = f"{path}/experiment_2"
cutoff = int(num_samples*(4/5))
train_data = {"inputs": samples[:cutoff], "labels": cycle_indexes[:cutoff]}
valid_data = {"inputs": samples[cutoff:], "labels": cycle_indexes[cutoff:]}

for split_name, split_data in (("train", train_data), ("valid", valid_data)):
    with open(f"{data_path}/{split_name}_data_D.pkl", 'wb') as f:
        pickle.dump(split_data, f)

# Experiment 2S

In [None]:
# Experiment 2S: `num_cycles` cycles of `cycle_len` distinct tokens from a 499-token vocabulary.
# Each cycle position gets its own advance probability, drawn uniformly from [0.65, 0.85).
vocabulary = np.arange(1, 500, 1)

context_length = 128
cycle_len = 24
num_samples = 20000
num_cycles = 200

cycles, cycle_probs = [], []
for _ in range(num_cycles):
    cycles.append(np.random.choice(vocabulary, size=cycle_len, replace=False))
    cycle_probs.append(np.random.uniform(0.65, 0.85, size=cycle_len))

In [None]:
def break_integer(n, cycle_len):
    """Split `n` into random positive pieces that sum to `n`.

    Each piece is drawn (Python `random`) uniformly from
    [int(cycle_len * 4/8), int(cycle_len * 7/8)]. The last piece is clipped to the remainder,
    so it may be smaller. Returns [] when n <= 0.

    Raises:
        ValueError: if int(cycle_len * 4/8) < 1 (e.g. cycle_len < 2). The piece range would then
            include 0, producing empty fragments or, for cycle_len < 2, an endless loop.
    """
    low_end = int(cycle_len * (4/8))
    high_end = int(cycle_len * (7/8))
    if low_end < 1:
        raise ValueError(f"cycle_len={cycle_len} is too small to produce positive pieces")
    pieces = []
    while n > 0:
        piece = min(random.randint(low_end, high_end), n)
        pieces.append(piece)
        n -= piece
    return pieces

In [None]:
#2S

# Each sample is a walk on cycle (sample_idx % num_cycles) starting at a random position. At each
# step it advances one position (wrapping around) with the current position's probability in
# `cycle_probs`, otherwise repeats the current token.
samples = []
cycle_indexes = []  # labels
for sample_idx in tqdm(range(num_samples)):
    cycle_index = sample_idx % num_cycles
    cycle_indexes.append(cycle_index)
    cycle, advance_probs = cycles[cycle_index], cycle_probs[cycle_index]

    token_index = np.random.choice(cycle_len)
    walk = [cycle[token_index]]
    for _ in range(context_length - 1):
        step = 1 if np.random.uniform() < advance_probs[token_index] else 0
        token_index = int((token_index + step) % cycle_len)
        walk.append(cycle[token_index])

    samples.append(walk)

# 75/25 train/validation split, saved as {"inputs", "labels"} dicts
data_path = f"{path}/experiment_2S"
cutoff = int(num_samples*(3/4))
train_data = {"inputs": samples[:cutoff], "labels": cycle_indexes[:cutoff]}
valid_data = {"inputs": samples[cutoff:], "labels": cycle_indexes[cutoff:]}

for split_name, split_data in (("train", train_data), ("valid", valid_data)):
    with open(f"{data_path}/{split_name}_data_C.pkl", 'wb') as f:
        pickle.dump(split_data, f)
        

In [None]:
#2F

# Each sample concatenates fragments of cycle (sample_idx % num_cycles). Fragment sizes come from
# break_integer and sum to context_length. A fragment starts at a random position. For each further
# token it advances one position (no wrap-around) with the current position's probability in
# `cycle_probs`, otherwise repeats the current token.
samples = []
cycle_indexes = []  # labels
for sample_idx in tqdm(range(num_samples)):
    sample = []

    cycle_index = sample_idx % num_cycles
    cycle_indexes.append(cycle_index)
    cycle = cycles[cycle_index]
    fragment_sizes = break_integer(context_length, cycle_len)

    for size in fragment_sizes:
        # A fragment advances at most size-1 positions, so any start in [0, cycle_len - size]
        # stays in bounds, and this upper bound keeps the last token of the cycle reachable.
        token_index = random.randint(0, cycle_len - size)
        sample.append(cycle[token_index])

        for _ in range(size - 1):
            if np.random.uniform() < cycle_probs[cycle_index][token_index]:
                token_index += 1
            sample.append(cycle[token_index])

    samples.append(sample)

# 80/20 train/validation split, saved as {"inputs", "labels"} dicts
data_path = f"{path}/experiment_2F"
cutoff = int(num_samples*(4/5))
train_data = {"inputs": samples[:cutoff], "labels": cycle_indexes[:cutoff]}
valid_data = {"inputs": samples[cutoff:], "labels": cycle_indexes[cutoff:]}

with open(f"{data_path}/train_data_G.pkl", 'wb') as f:
    pickle.dump(train_data, f)

with open(f"{data_path}/valid_data_G.pkl", 'wb') as f:
    pickle.dump(valid_data, f)
        

# Random Data

In [None]:
# Random baseline: i.i.d. tokens from 1..43747 with P(token = k) proportional to exp(-100 * k / 43748)
vocabulary = np.arange(1, 43748, 1)
prob_model = np.exp(-np.arange(1, 43748, 1)/43748 * 100)
prob_model = prob_model/np.sum(prob_model)

context_length = 128
num_samples = 20000

# One draw of shape (num_samples, context_length) instead of num_samples separate calls. The legacy
# global RNG fills the uniform draws row by row, so each row matches the corresponding per-sample
# call, without re-validating p and rebuilding its 43747-entry CDF num_samples times.
samples = list(np.random.choice(vocabulary, size=(num_samples, context_length), p=prob_model, replace=True))

# 75/25 train/validation split
data_path = f"{path}/random"
cutoff = int(num_samples*(3/4))
train_data = samples[:cutoff]
valid_data = samples[cutoff:]

with open(f"{data_path}/train_data.pkl", 'wb') as f:
    pickle.dump(train_data, f)

with open(f"{data_path}/valid_data.pkl", 'wb') as f:
    pickle.dump(valid_data, f)

In [None]:
import numpy as np
import pickle
import matplotlib.pyplot as plt
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2Config
import torch as torch
import wandb
from modular_transformers.straightening.straightening_utils import compute_model_activations, compute_model_curvature


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# NOTE(review): unlike the first setup cell, this root has no trailing "/data", while the dataset
# pickles above were written under ".../input_statistics/data/<experiment>" — confirm which root
# the cells below expect.
path = "/om2/user/jackking/modular_transformers/scripts/input_statistics"
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

In [None]:
def get_curvature(model_name, data, ctx_len=1024):
    """Curvature (degrees) of a GPT-2 model's activations on `data`, nan-averaged over axis 1.

    Args:
        model_name: "untrained" for a freshly initialised GPT-2 (12 layers, 4 heads, 128-dim);
            "gpt2" for the pretrained Hugging Face checkpoint; anything else is a checkpoint
            directory relative to the module-level `path`.
        data: inputs passed to `compute_model_activations` (format defined there).
        ctx_len: maximum sequence length (n_positions) of the untrained model; ignored for the
            other branches. Defaults to GPT-2's 1024.

    Returns:
        np.ndarray: np.nanmean over axis 1 of curvature["curve"], converted from radians to
        degrees — presumably one value per layer (TODO confirm axes in straightening_utils).
    """
    if model_name == "untrained":
        embedding_dim = 128
        n_layer = 12
        # n_positions is GPT2Config's context-length field
        model_config = GPT2Config(n_layer=n_layer, n_head=4, n_embd=embedding_dim, n_positions=ctx_len)
        model = GPT2LMHeadModel(model_config)
    elif model_name == "gpt2":
        model = GPT2LMHeadModel.from_pretrained("gpt2")
    else:
        model_path = f'{path}/{model_name}'
        model = GPT2LMHeadModel.from_pretrained(model_path)

    model.to(device)
    # Disable dropout so activations are deterministic for every branch, including the untrained model
    model.eval()
    activations = compute_model_activations(model, data, device)
    curvature = compute_model_curvature(activations)
    curve = 180 / np.pi * curvature["curve"]
    return np.nanmean(curve, axis=1)

In [None]:
name = "experiment_1/M2_B1_128_12/soft-plant-402/epoch_12"
with open(f"{path}/experiment_1/valid_data_B1.pkl", 'rb') as f:
    val_data_B = pickle.load(f)
curve = get_curvature(name, val_data_B)

fig, ax = plt.subplots()
ax.plot(curve)
ax.set(ylabel="mean curvature (degrees)", title=name)
plt.show()