In [1]:
import torch
import math

from transformers import GPT2Tokenizer, GPT2LMHeadModel, set_seed
from datasets import load_dataset

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

from minicons import scorer
from torch.utils.data import DataLoader

import json
from datasets import load_dataset, load_from_disk
import os
from tqdm import tqdm

import pickle
import gc
from modular_transformers.straightening.straightening_utils import compute_model_activations, compute_model_curvature

from modular_transformers.models import components
from transformer_xray.perturb_utils import register_pertubation_hooks

# Run on the GPU when one is visible; IncrementalLMScorer moves each model here.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Activation and token perturbations below are random; seed Python, NumPy and
# torch (transformers.set_seed covers all three) so Restart & Run All reproduces results.
SEED = 42
set_seed(SEED)

# Every model in this notebook is scored with the stock GPT-2 tokenizer.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# Checkpoint directory names under the model root used by compute_surprisals_with_context.
fine_tuned_models = [
    "finetuned-1-l2_curvature-768x12",
    "finetuned-1-l1_curvature-768x12",
    "huggingface-pretrained-768x12",
    "finetuned-1-l0_curvature-768x12",
]
fulltrained_models = [
    'warmup5-lr0.0006-0.1-l0_curvature-768x12',
    'multi-warmup5-lr0.0006-1-l2_curvature-768x12',
    'warmup5-lr0.0006-1-l2_curvature-768x12',
    '768x12_test',
]

model_names = fine_tuned_models + fulltrained_models

  from .autonotebook import tqdm as notebook_tqdm
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  return torch._C._cuda_getDeviceCount() > 0


In [2]:
def generate_random_pertubation(shape, pertubation_size, device=None, dtype=None):
    """Draw a Gaussian random direction and rescale it to a fixed L2 norm.

    Args:
        shape: Shape of the tensor to create.
        pertubation_size: Target L2 norm of the result, taken over all elements
            (a Python/NumPy scalar or a 0-d tensor).
        device: Optional device to allocate on (default: torch's default device).
        dtype: Optional floating dtype (default: torch's default float dtype).

    Returns:
        A tensor of ``shape`` whose flattened L2 norm equals ``pertubation_size``.
    """
    direction = torch.randn(shape, device=device, dtype=dtype)
    # Normalise with torch instead of np.linalg.norm: the numpy round-trip cannot
    # read CUDA tensors and forces a device->host copy.
    norm = torch.linalg.vector_norm(direction)
    return direction / norm * pertubation_size

def random_perturbation_function(input, layer, token):
    """Perturbation callback: random noise whose L2 norm is 10% of the activation's.

    Used as the callback in the hook spec built by ``compute_surprisals_with_context``
    and handed to ``register_pertubation_hooks``. ``layer`` and ``token`` are
    accepted to match that callback signature and are not used here.

    Args:
        input: The activation being perturbed (assumed to be a floating-point
            torch tensor — TODO confirm against register_pertubation_hooks).
        layer: Unused.
        token: Unused.

    Returns:
        A tensor with the same shape, device and dtype as ``input`` whose L2 norm
        (over all elements) is one tenth of ``input``'s.
    """
    # Compute the norm in torch and pull it out as a Python float: np.linalg.norm
    # cannot read CUDA tensors or tensors that require grad. float32 avoids
    # overflow in the sum of squares for half-precision activations.
    size = torch.linalg.vector_norm(input.detach().float()).item() / 10
    noise = generate_random_pertubation(input.shape, size)
    # The noise is drawn on the default device/dtype; place it next to the
    # activation so the two can be combined directly.
    return noise.to(device=input.device, dtype=input.dtype)

_SURPRISAL_DICT_PATH = "/om2/user/jackking/modular_transformers/modular_transformers/train/surprisal_dict.pt"
_MODEL_ROOT = "/om2/user/jackking/MyData/mt/miniberta_100M"


def _load_surprisal_dict(path):
    """Load the nested surprisal cache from ``path``, or start an empty one if it is missing."""
    if not os.path.exists(path):
        print(f"No surprisal cache found at {path}; starting a new one.")
        return {}
    # The cache holds numpy arrays, which need full (non-weights-only) unpickling.
    # Only load cache files this notebook wrote itself.
    return torch.load(path, map_location=torch.device('cpu'), weights_only=False)


def _save_surprisal_dict(surprisal_dict, path):
    """Write the cache via a temp file + atomic rename so an interrupted save cannot corrupt it."""
    tmp_path = f"{path}.tmp"
    torch.save(surprisal_dict, tmp_path)
    os.replace(tmp_path, path)


def _score_pairs(lm_scorer, prefixes, queries, total=None):
    """Return one summed surprisal per (prefix, query) pair as a numpy array."""
    all_surprisals = []
    for prefix, query in tqdm(zip(prefixes, queries), total=total):
        # Negate the sum of the per-token scores (log-probabilities in minicons)
        # over the query tokens -> surprisal of the whole query given the prefix.
        surprisals = lm_scorer.conditional_score(prefix, query, reduction=lambda x: -x.sum(0))
        all_surprisals.extend(surprisals)
    return np.array(all_surprisals)


def compute_surprisals_with_context(model_names, prefixes, queries, perturbation_type, surprisal_type,
                                    surprisal_path=_SURPRISAL_DICT_PATH, model_root=_MODEL_ROOT,
                                    seed=None):
    """Score each model's surprisal of ``queries`` given ``prefixes`` and cache the results.

    Results live in a nested dict persisted with ``torch.save`` at ``surprisal_path``::

        surprisal_dict[perturbation_type][surprisal_type][model_name] -> np.ndarray

    Models already present under that key are skipped. The cache is re-saved after
    every newly scored model, so an interrupted run resumes where it stopped.

    Args:
        model_names: Checkpoint directory names; each is loaded from
            ``<model_root>/<model_name>/checkpoint_final``.
        prefixes: Context strings, one per pair.
        queries: Continuation strings to score, aligned with ``prefixes``.
        perturbation_type: Cache key. ``"activation"`` additionally registers
            ``random_perturbation_function`` at hook point ``"before_attn"`` of
            layer index 0 (semantics defined by ``register_pertubation_hooks``).
            Any other value scores the unperturbed model.
        surprisal_type: Cache key only; it does not change how scores are computed.
        surprisal_path: Location of the cache file.
        model_root: Directory holding the model checkpoints.
        seed: If given, ``set_seed(seed)`` is called before each model so every
            model is scored starting from the same RNG state.

    Raises:
        ValueError: If ``prefixes`` and ``queries`` have different lengths
            (``zip`` would otherwise silently drop the extra items).
    """
    if hasattr(prefixes, "__len__") and hasattr(queries, "__len__"):
        if len(prefixes) != len(queries):
            raise ValueError(f"got {len(prefixes)} prefixes but {len(queries)} queries")
        n_pairs = len(prefixes)
    else:
        n_pairs = None

    surprisal_dict = _load_surprisal_dict(surprisal_path)
    cached = surprisal_dict.setdefault(perturbation_type, {}).setdefault(surprisal_type, {})

    for model_name in tqdm(model_names):
        print(model_name)
        if model_name in cached:
            continue
        if seed is not None:
            set_seed(seed)

        model = components.LM.from_pretrained(os.path.join(model_root, model_name, 'checkpoint_final'))

        # NOTE(review): token-level types ("swap"/"remove"/"replace", see perturb_inputs)
        # are not applied here; they would be scored unperturbed under their own key.
        if perturbation_type == "activation":
            perturbation_hooks = {0: [("before_attn", "all", random_perturbation_function)]}
            register_pertubation_hooks(model, perturbation_hooks, device)

        lm_scorer = scorer.IncrementalLMScorer(model, tokenizer=tokenizer, device=device)
        cached[model_name] = _score_pairs(lm_scorer, prefixes, queries, total=n_pairs)

        # Persist after each model so finished work survives a crash or kernel restart.
        _save_surprisal_dict(surprisal_dict, surprisal_path)

        # Release this model before the next one is loaded.
        del lm_scorer, model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

def perturb_inputs(input_ids, perturbation_type, perturbation_fraction=0.1, rng=None,
                   pad_token_id=None, vocab_size=None):
    """Return a corrupted copy of a 1-D token-id sequence.

    With ``k = ceil(perturbation_fraction * len(input_ids))``:

    * ``"swap"``    -- perform ``k`` swaps, each of two distinct random positions
      (a later swap may move tokens an earlier one already moved);
    * ``"remove"``  -- overwrite ``k`` distinct positions with a pad id
      (the sequence length is preserved);
    * ``"replace"`` -- set ``k`` distinct positions to ids drawn uniformly from
      ``range(vocab_size)``.

    Args:
        input_ids: Token ids as a list, numpy array or torch tensor. Not modified.
        perturbation_type: One of ``"swap"``, ``"remove"``, ``"replace"``.
        perturbation_fraction: Fraction of positions to perturb, in ``[0, 1]``.
        rng: Optional ``np.random.Generator`` or ``RandomState``. Defaults to the
            global ``np.random`` state, so ``np.random.seed``/``set_seed`` apply.
        pad_token_id: Id written by ``"remove"``. Defaults to the notebook
            tokenizer's pad id, falling back to its EOS id (the stock GPT-2
            tokenizer has no pad token, so ``pad_token_id`` is ``None``).
        vocab_size: Exclusive upper bound for ``"replace"``; defaults to
            ``len(tokenizer)``.

    Returns:
        A perturbed copy of the same container type (list, ndarray with the same
        dtype, or tensor on the same device/dtype).

    Raises:
        ValueError: On an unknown ``perturbation_type``, an out-of-range
            ``perturbation_fraction``, or when no pad id is available.
    """
    if perturbation_type not in ("swap", "remove", "replace"):
        raise ValueError(f"perturbation_type {perturbation_type} not recognized")
    if not 0 <= perturbation_fraction <= 1:
        raise ValueError(f"perturbation_fraction must be in [0, 1], got {perturbation_fraction}")
    rng = np.random if rng is None else rng

    # Work on a plain Python list. This copies the input (callers' data is left
    # intact). It also avoids tensor-view aliasing: on a tensor, `t[i], t[j] = t[j], t[i]`
    # writes the same value into both slots.
    is_tensor = torch.is_tensor(input_ids)
    is_array = isinstance(input_ids, np.ndarray)
    ids = input_ids.tolist() if (is_tensor or is_array) else list(input_ids)

    n_tokens = len(ids)
    n_perturb = min(math.ceil(n_tokens * perturbation_fraction), n_tokens)

    if perturbation_type == "swap":
        if n_tokens >= 2:
            for _ in range(n_perturb):
                # replace=False guarantees two distinct positions (no no-op swaps).
                i, j = rng.choice(n_tokens, 2, replace=False)
                ids[i], ids[j] = ids[j], ids[i]
    elif n_perturb > 0:
        positions = rng.choice(n_tokens, n_perturb, replace=False)
        if perturbation_type == "remove":
            if pad_token_id is None:
                pad_token_id = tokenizer.pad_token_id
                if pad_token_id is None:
                    pad_token_id = tokenizer.eos_token_id
            if pad_token_id is None:
                raise ValueError("no pad_token_id given and the tokenizer defines neither pad nor eos")
            for pos in positions:
                ids[pos] = pad_token_id
        else:  # "replace"
            if vocab_size is None:
                vocab_size = len(tokenizer)
            for pos in positions:
                ids[pos] = int(rng.choice(vocab_size))

    if is_tensor:
        return torch.tensor(ids, dtype=input_ids.dtype, device=input_ids.device)
    if is_array:
        return np.asarray(ids, dtype=input_ids.dtype)
    return ids

In [3]:
# Build prefix/continuation pairs from the first n_samples pre-tokenised training
# contexts of the miniBERTa split, then score the fine-tuned models under
# activation perturbation.
path = '/om/weka/evlab/ehoseini/MyData/miniBERTa_v2/'
data_size = "10M"
context_len = 512   # context length baked into the crunched dataset directory name
n_samples = 500     # number of (prefix, query) pairs
prefix_len = 20     # tokens of context shown to the model
query_len = 20      # tokens whose surprisal is scored

data = load_from_disk(
    os.path.join(path, f'miniBERTa-{data_size}-crunched',
                 f'train_context_len_{context_len}'))

# Slice the rows first, then take the column: `data["input_ids"]` can materialise the
# entire column before the slice (and the original expression did it twice).
sample_ids = data[:n_samples]["input_ids"]
prefixes = [tokenizer.decode(ids[:prefix_len]) for ids in sample_ids]
queries = [tokenizer.decode(ids[prefix_len:prefix_len + query_len]) for ids in sample_ids]

compute_surprisals_with_context(fine_tuned_models, prefixes, queries, "activation", "continuation")

  0%|          | 0/4 [00:00<?, ?it/s]

finetuned-1-l2_curvature-768x12


500it [05:28,  1.52it/s]
 25%|██▌       | 1/4 [05:32<16:37, 332.41s/it]

finetuned-1-l1_curvature-768x12


500it [04:42,  1.77it/s]
 50%|█████     | 2/4 [10:26<10:19, 309.83s/it]

huggingface-pretrained-768x12


500it [05:00,  1.67it/s]
 75%|███████▌  | 3/4 [15:38<05:10, 310.78s/it]

finetuned-1-l0_curvature-768x12


500it [04:44,  1.76it/s]
100%|██████████| 4/4 [20:34<00:00, 308.67s/it]


In [4]:
# Reload the surprisal cache and list the models that already have
# activation-perturbed continuation scores.
cache_file = "/om2/user/jackking/modular_transformers/modular_transformers/train/surprisal_dict.pt"
surprisal_dict = torch.load(cache_file, map_location="cpu")
surprisal_dict["activation"]["continuation"].keys()

dict_keys(['finetuned-1-l2_curvature-768x12', 'finetuned-1-l1_curvature-768x12', 'huggingface-pretrained-768x12', 'finetuned-1-l0_curvature-768x12'])