In [3]:
import torch
import math

from transformers import GPT2Tokenizer, GPT2LMHeadModel, set_seed, AutoModel, AutoConfig
from datasets import load_dataset

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

from transformers import BatchEncoding

from minicons import scorer
from torch.utils.data import DataLoader

import json
from datasets import load_dataset, load_from_disk
import os
from tqdm import tqdm

import pickle
import gc
from modular_transformers.straightening.straightening_utils import compute_model_activations, compute_model_curvature
from modular_transformers.models.gpt2.configuration_gpt2 import GPT2Config

from modular_transformers.models import components
from transformer_xray.perturb_utils import register_pertubation_hooks

from torchviz import make_dot

from functools import partial

# ---------------------------------------------------------------------------
# Experiment configuration. These module-level names are read as globals by
# the functions defined below.
# ---------------------------------------------------------------------------
# Number of token columns in a finished sentence; used as the width of the
# token/activation buffers allocated by the generation and recording helpers.
max_len = 25
# Number of resid_post hook points recorded per forward pass (one per block).
layer_num = 48
# Size of the last (feature) dimension of the recorded activations.
embedding_size = 1600
# Number of leading tokens of every sentence kept as the prompt.
first_sequence_len = 4

# Passed to perturb_input as `principal_dimensions`; perturb_input currently
# accepts but does not use that argument.
principal_dimensions_for_curved = 10
principal_dimensions_for_straight = 1600

# Fix the random seed so runs are repeatable.
set_seed(42)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Tokenizer used to encode the input sentences and decode generated tokens.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# TransformerLens utilities used for hooking into the residual stream.
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, FactoredMatrix
# Disable autograd globally: everything below is inference only.
torch.set_grad_enabled(False)


def perturb_input(input, hook, perturbations, principal_dimensions):
    """Forward hook that shifts the activation at the last position.

    Args:
        input: Activation tensor indexed as ``[batch, position, feature]``.
            It is modified in place.
        hook: Hook point handed in by the hooking framework (unused).
        perturbations: Tensor added to ``input[:, -1, :]``; it must be
            broadcastable to that slice.
        principal_dimensions: Accepted for interface compatibility; the
            current implementation does not use it.

    Returns:
        The same ``input`` tensor, with its last position shifted.
    """
    perturb_idx = input.shape[1] - 1

    # Follow the device of the tensor being hooked rather than a module-level
    # global, so the hook works wherever the activation actually lives.
    shift = perturbations.to(input.device)
    input[:, perturb_idx, :] = input[:, perturb_idx, :] + shift

    return input

def get_curvature(P1, P2, P3):
    """Return the turning angle (radians) between segments P1->P2 and P2->P3.

    The computation runs over the last dimension, so any leading dimensions
    are treated as batch dimensions (subject to normal broadcasting).

    Args:
        P1, P2, P3: Tensors of consecutive points along a trajectory.

    Returns:
        Tensor of angles in ``[0, pi]`` with the last dimension reduced.
        A zero-length segment still yields NaN, because its direction is
        undefined.
    """
    v1 = P2 - P1
    v2 = P3 - P2
    v1 = v1 / v1.norm(dim=-1, keepdim=True)
    v2 = v2 / v2.norm(dim=-1, keepdim=True)
    # Rounding can push the cosine of (anti)parallel unit vectors slightly
    # outside [-1, 1], where acos returns NaN; clamp before taking the angle.
    cosine = torch.sum(v1 * v2, dim=-1).clamp(-1.0, 1.0)
    return torch.acos(cosine)

# Placeholder mapping whose values are all None. It is not referenced anywhere
# else in the visible source. The keys look like perturbation sizes (compare
# `size` in launch) -- TODO confirm, or remove if unused.
curve_difs = {0.1: None, 0.25: None, 0.5: None, 0.75: None, 1: None}

def get_perturbations(model, data, perturb_location, size):
    """Sample random last-token perturbations and keep the extreme ones.

    One forward pass is run with a hook at ``perturb_location``. For every
    sequence the hook draws ``num_samples`` random directions, scales them
    relative to the distance between the last two positions (times ``size``
    and 0.4), and measures the curvature of the last three points once the
    final point is replaced by the perturbed, norm-preserving candidate.

    Args:
        model: Hooked model exposing ``reset_hooks`` and ``run_with_hooks``.
        data: Batch of token ids passed to the model.
        perturb_location: Name of the hook point to inspect.
        size: Scale factor for the perturbation magnitude.

    Returns:
        ``(straightest_perturbations, curviest_perturbations)``: two tensors
        of shape ``(len(data), embedding_size)`` holding, per sequence, the
        sampled perturbation with the smallest / largest curvature.
    """
    model.reset_hooks()

    num_samples = 10
    n_rows = len(data)

    # Output buffers filled in by the hook below.
    straightest_perturbations = torch.zeros(n_rows, embedding_size)
    curviest_perturbations = torch.zeros(n_rows, embedding_size)

    def record_perturbation_directions(input, hook):
        last = input.shape[1] - 1
        batch = input.shape[0]

        random_directions = torch.randn(num_samples, batch, input.shape[2]).to(device)
        acts = input.clone()

        # Perturbation length is tied to the size of the most recent step.
        step_norm = (acts[:, last, :] - acts[:, last - 1, :]).norm(dim=-1, keepdim=True)
        unit_directions = random_directions / random_directions.norm(dim=-1, keepdim=True)
        random_perturbations = unit_directions * step_norm.view(1, -1, 1) * size * 0.4

        # Candidates are rescaled back to the norm of the unperturbed point.
        point_norm = acts[:, last, :].norm(dim=-1, keepdim=True)
        candidates = acts[:, last, :].clone().unsqueeze(0) + random_perturbations
        candidates = candidates / candidates.norm(dim=-1, keepdim=True) * point_norm.view(1, -1, 1)

        perturbations_curves = torch.zeros(num_samples, batch).to(device)
        for sample_idx, candidate in enumerate(candidates):
            perturbations_curves[sample_idx] = get_curvature(
                acts[:, last - 2, :], acts[:, last - 1, :], candidate
            )

        # Pick, for every sequence, the sample with the extreme curvature.
        min_indices = torch.argmin(perturbations_curves, dim=0)
        max_indices = torch.argmax(perturbations_curves, dim=0)
        rows = torch.arange(batch, device=min_indices.device)
        straightest_perturbations[:batch].copy_(random_perturbations[min_indices, rows])
        curviest_perturbations[:batch].copy_(random_perturbations[max_indices, rows])

    model.run_with_hooks(
        data,
        return_type=None,
        fwd_hooks=[(perturb_location, record_perturbation_directions)],
    )

    model.reset_hooks()

    return straightest_perturbations, curviest_perturbations

def generate_perturbed_token(model, data, perturb_function):
    """Run one hooked forward pass and return the greedy next token.

    ``perturb_function`` is installed at the module-level ``perturb_location``
    hook point (a global that must be defined before this is called), and the
    resid_post activation of every block is recorded.

    Args:
        model: Hooked model exposing ``run_with_hooks``.
        data: Batch of token ids.
        perturb_function: Hook callable taking ``(input, hook)``.

    Returns:
        ``(new_token, post_activations)``: the argmax token at the last
        position for each sequence, and a tensor of shape
        ``(len(data), layer_num, data.shape[1], embedding_size)`` holding the
        recorded activations.
    """
    n_positions = data.shape[1]
    post_activations = torch.zeros((len(data), layer_num, n_positions, embedding_size))

    def record_post_activations(input, hook, layer):
        post_activations[:, layer, :, :] = input

    # The perturbation hook comes first, followed by one recorder per block.
    recorders = [
        (utils.get_act_name("resid_post", idx), partial(record_post_activations, layer=idx))
        for idx in range(layer_num)
    ]
    fwd_hooks = [(perturb_location, perturb_function), *recorders]

    logits = model.run_with_hooks(data, return_type="logits", fwd_hooks=fwd_hooks)

    new_token = logits.argmax(dim=-1)[:, -1]
    return new_token, post_activations

def continued_gen_normal(model, data, length):
    """Greedily extend every sequence in ``data`` by ``length`` tokens.

    Sequences are processed in batches of 500 rows. Every batch is handled,
    including a final partial one, so inputs smaller than one batch are
    generated rather than being returned as zeros.

    Args:
        model: Callable returning an object with a ``.logits`` attribute
            indexed as ``[batch, position, vocab]`` when given token ids.
        data: 2-D integer tensor of token ids, one sequence per row.
        length: Number of tokens to append to each sequence.

    Returns:
        ``int64`` tensor of shape ``(len(data), data.shape[1] + length)`` on
        the same device as ``data``.
    """
    batch_size = 500
    n_rows = len(data)

    # Size the output from the inputs themselves, so the result always fits
    # the generated sequences instead of relying on a module-level width.
    final_data = torch.zeros(
        (n_rows, data.shape[1] + length), dtype=torch.int64, device=data.device
    )

    for start in range(0, n_rows, batch_size):
        stop = min(start + batch_size, n_rows)
        batch = data[start:stop]

        for _ in range(length):
            logits = model(batch).logits
            new_token = logits.argmax(dim=-1)[:, -1]
            batch = torch.cat([batch, new_token.unsqueeze(1)], dim=1)
            torch.cuda.empty_cache()

        final_data[start:stop, :] = batch.type(torch.int64)

    return final_data

def generate_normal_sentences(data):
    """Generate unperturbed continuations for the prompts in ``data``.

    The first new token is produced with the hooked model, using a no-op hook
    so that the per-block activations are still recorded. The remaining
    tokens are produced greedily with the plain HuggingFace model.

    Relies on the module-level globals ``device``, ``max_len``,
    ``first_sequence_len`` and ``num_perturbations``, and (through
    ``generate_perturbed_token``) ``perturb_location``.

    Args:
        data: Tensor of prompt token ids, one prompt per row.

    Returns:
        ``(normal_data, activations)``: the completed token sequences, and the
        activations recorded while generating the first new token.
    """
    model = HookedTransformer.from_pretrained("gpt2-xl", device=device)

    normal_data = data.clone()

    new_token, activations = generate_perturbed_token(model, normal_data, perturb_function = lambda input, hook: None)
    normal_data = torch.cat([normal_data, new_token.unsqueeze(1)], dim=1)

    # Drop the reference BEFORE collecting, otherwise the hooked model is
    # still alive while the second copy of gpt2-xl is being loaded.
    del model
    gc.collect()
    torch.cuda.empty_cache()

    model = GPT2LMHeadModel.from_pretrained("gpt2-xl").to(device)
    # NOTE(review): the purpose of this attribute is not visible in this file.
    model.output_logits = True

    length = max_len - first_sequence_len - num_perturbations
    normal_data = continued_gen_normal(model, normal_data, length)

    return normal_data, activations

def continued_gen_perturbed(model, data, activations, perturb_location, length):
    """Greedily extend ``data`` while replaying recorded prompt activations.

    On every forward pass the activations of the first
    ``first_sequence_len`` positions at ``perturb_location`` are overwritten
    with a slice of ``activations``, so generation continues from the
    perturbed state.

    Args:
        model: Hooked model exposing ``run_with_hooks``.
        data: Tensor of prompt token ids, one prompt per row.
        activations: Recorded activations indexed as
            ``[sentence, layer, position, feature]``.
        perturb_location: Hook name of the form
            ``"blocks.<N>.hook_resid_post"``.
        length: Number of tokens to append.

    Returns:
        ``data`` extended by ``length`` greedily chosen tokens per row.

    Raises:
        ValueError: If ``perturb_location`` is not a resid_post hook name.
    """
    parts = perturb_location.split(".")
    if (
        len(parts) != 3
        or parts[0] != "blocks"
        or not parts[1].isdigit()
        or parts[2] != "hook_resid_post"
    ):
        raise ValueError(
            f"unsupported perturb_location {perturb_location!r}; "
            "expected 'blocks.<N>.hook_resid_post'"
        )
    # NOTE(review): the original code used block index + 1 for the hard-coded
    # locations (5, 15, 30); that offset is preserved here. The recordings
    # made by generate_perturbed_token are indexed by the block number itself,
    # so confirm the + 1 is intended.
    layer = int(parts[1]) + 1

    def replace_activations(input, hook):
        input[:, 0:first_sequence_len, :] = activations[:, layer, :, :]
        return input

    for _ in range(length):
        logits = model.run_with_hooks(
            data,
            return_type="logits",
            fwd_hooks=[(perturb_location, replace_activations)]
        )
        new_token = logits.argmax(dim=-1)[:, -1]
        data = torch.cat([data, new_token.unsqueeze(1)], dim=1)
        torch.cuda.empty_cache()

    return data


def generate_sentences(data, perturb_location, size):
    """Generate continuations after a straightening and a curving perturbation.

    Two independent rounds of random perturbations are sampled. The
    straightest sample of the first round and the curviest sample of the
    second round are each applied to the last prompt token, and the
    resulting activations are replayed while the rest of the sentence is
    generated.

    Args:
        data: Tensor of prompt token ids, one prompt per row.
        perturb_location: Name of the hook point to perturb.
        size: Scale factor for the perturbation magnitude.

    Returns:
        ``(straighter_data, curved_data, straighter_activations,
        curved_activations)``.
    """
    def _release_memory():
        gc.collect()
        torch.cuda.empty_cache()

    model = HookedTransformer.from_pretrained("gpt2-xl", device=device)

    straighter_data, curved_data = data.clone(), data.clone()

    # Round 1: keep only the straightest perturbation.
    straight_shift, _ = get_perturbations(model, straighter_data, perturb_location, size)
    straight_hook = partial(
        perturb_input,
        perturbations=straight_shift,
        principal_dimensions=principal_dimensions_for_straight,
    )
    _, straighter_activations = generate_perturbed_token(model, straighter_data, straight_hook)

    # Round 2: a fresh random draw, keeping only the curviest perturbation.
    _, curved_shift = get_perturbations(model, curved_data, perturb_location, size)
    curved_hook = partial(
        perturb_input,
        perturbations=curved_shift,
        principal_dimensions=principal_dimensions_for_curved,
    )
    _, curved_activations = generate_perturbed_token(model, curved_data, curved_hook)

    _release_memory()

    model.reset_hooks()

    length = max_len - first_sequence_len
    straighter_data = continued_gen_perturbed(
        model, straighter_data, straighter_activations, perturb_location, length
    )
    _release_memory()
    curved_data = continued_gen_perturbed(
        model, curved_data, curved_activations, perturb_location, length
    )
    _release_memory()

    return straighter_data, curved_data, straighter_activations, curved_activations


def record_activations(model, data):
    """Record the resid_post activation of every block for ``data``.

    Args:
        model: Hooked model exposing ``run_with_hooks`` and ``reset_hooks``.
        data: Batch of token ids; the recording buffer is allocated with
            ``max_len`` positions per sequence.

    Returns:
        Tensor of shape ``(len(data), layer_num, max_len, embedding_size)``.
    """
    post_activations = torch.zeros((len(data), layer_num, max_len, embedding_size))

    def record_post_activations(input, hook, layer):
        post_activations[:, layer, :, :] = input

    # One recorder per block, each bound to its own layer index.
    fwd_hooks = [
        (utils.get_act_name("resid_post", idx), partial(record_post_activations, layer=idx))
        for idx in range(layer_num)
    ]

    model.run_with_hooks(data, return_type=None, fwd_hooks=fwd_hooks)
    model.reset_hooks()

    return post_activations


def run_perturbed(gen_data, gen_activations):
    """Compute curvature and surprisal statistics for generated sentences.

    Relies on the module-level globals ``num_perturbations``, ``device`` and
    ``tokenizer``.

    Args:
        gen_data: Token ids of the generated sentences, one sentence per row.
        gen_activations: Activations recorded during generation, shaped
            ``(num_sentences, num_layers, num_tokens, hidden_size)``.

    Returns:
        Dict with keys ``"surprisals"`` (1-D tensor, one value per sentence),
        ``"sentences"`` (decoded strings), ``"curvatures"`` (curvature of the
        re-recorded activations under key ``"post"``) and ``"gen_curvatures"``
        (list whose first entry is the curvature of ``gen_activations``).
    """
    # Independent dicts per slot; `[{}] * n` would alias one shared dict.
    gen_curvatures = [{} for _ in range(num_perturbations)]

    gen_curvatures[0] = compute_model_curvature(gen_activations)

    # Curvature of the finished sentences, from a fresh forward pass.
    model = HookedTransformer.from_pretrained("gpt2-xl", device=device)
    data_activations = record_activations(model, gen_data)
    # Release the hooked model before loading the scoring model, so two
    # copies of gpt2-xl are never resident at the same time.
    del model
    gc.collect()
    torch.cuda.empty_cache()

    data_curvature = {"post": compute_model_curvature(data_activations)}

    # Surprisal of the finished sentences.
    data_decoded = [tokenizer.decode(sentence) for sentence in gen_data]

    lm = GPT2LMHeadModel.from_pretrained("gpt2-xl").to(device)
    lm_scorer = scorer.IncrementalLMScorer(lm, tokenizer=tokenizer, device=device)
    batch_size = 1000
    surprisal_chunks = []
    for start in range(0, len(data_decoded), batch_size):
        data_decoded_batch = data_decoded[start:start + batch_size]
        # Surprisal is the negated sum of the per-token scores.
        surprisal_chunks.append(
            torch.tensor(lm_scorer.sequence_score(data_decoded_batch, reduction = lambda x: -x.sum(0)))
        )

    # Concatenate once; an empty input yields an empty tensor rather than an
    # unbound variable.
    if surprisal_chunks:
        surprisals_all = torch.cat(surprisal_chunks, dim=0)
    else:
        surprisals_all = torch.empty(0)

    return_dict = {
        "surprisals": surprisals_all,
        "sentences": data_decoded,
        "curvatures": data_curvature,
        "gen_curvatures": gen_curvatures
    }
    return return_dict

def launch(data, perturb_location):
    """Run the full experiment for one perturbation location and save it.

    Uses the first 50 sentences of ``data``, truncated to
    ``first_sequence_len`` prompt tokens. It generates and analyses the
    unperturbed continuations, then does the same for each perturbation
    size. The results dict is pickled after every size.

    Args:
        data: Tensor of token ids, one sentence per row.
        perturb_location: Name of the hook point to perturb; also used in the
            output file name.
    """
    path_to_dict = f"/om2/user/jackking/modular_transformers/scripts/adding_straightness/perturb_straight_byact_results_{perturb_location}.pkl"
    # Resuming from a previous run is currently disabled:
    # if os.path.exists(path_to_dict):
    #     new_surprisals = pickle.load(open(path_to_dict, "rb"))
    # else:
    new_surprisals = {}

    data = data[:50]
    cut_data = data[:, :first_sequence_len].to(device)

    normal_data, normal_activations = generate_normal_sentences(cut_data)
    normal_results = run_perturbed(normal_data, normal_activations)
    print("normal analyzed")

    new_surprisals["normal"] = normal_results

    for size in [0.1]:
        # Only relevant when results are reloaded from disk (see above).
        if size in new_surprisals:
            continue
        print(size)
        straighter_data, curved_data, straighter_activations, curved_activations = generate_sentences(cut_data, perturb_location, size)
        print("sentences generated")

        straighter_results = run_perturbed(straighter_data, straighter_activations)
        print("straighter analyzed")
        curved_results = run_perturbed(curved_data, curved_activations)
        print("curved analyzed")

        new_surprisals[size] = {
            "straighter": straighter_results,
            "curved": curved_results
        }

        # Index by the current size instead of a hard-coded 0.1, so changing
        # the size list cannot raise KeyError or print stale results.
        print(new_surprisals[size]["straighter"]["sentences"])

        # Checkpoint after every size so partial progress survives a crash.
        with open(path_to_dict, 'wb') as f:
            pickle.dump(new_surprisals, f)
    

# Script entry point: load the sentences, tokenize them, and run the
# experiment for one perturbation location.
if __name__ == "__main__":
    data_dir = "/rdma/vast-rdma/vast/evlab/ehoseini/MyData/sent_sampling/analysis/straightening/generation/sentences_ud_sentencez_token_filter_v3_textNoPeriod_cntx_3_cont_7.pkl"
    # NOTE: pickle.load executes arbitrary code from the file; only load
    # pickles from a trusted source.
    with open(data_dir, 'rb') as f:
        data = pickle.load(f)
    # The tokenizer gets its EOS token assigned as the padding token so that
    # padding='longest' below has a pad token to use.
    tokenizer.pad_token = tokenizer.eos_token
    data = tokenizer.batch_encode_plus(data, add_special_tokens=True, padding='longest', return_tensors="pt")["input_ids"]

    # Module-level global read by generate_normal_sentences and run_perturbed.
    num_perturbations = 1

    # Module-level global: besides being passed to launch, it is read
    # directly by generate_perturbed_token.
    perturb_location = "blocks.15.hook_resid_post"
    launch(data, perturb_location)
    print("15 post done")

    # Alternative perturbation locations, currently disabled:
    # perturb_location = "blocks.5.hook_resid_post"
    # launch(data, perturb_location)
    # print("5 post done")

    # perturb_location = "blocks.30.hook_resid_post"
    # launch(data, perturb_location)
    # print("30 post done")


Loaded pretrained model gpt2-xl into HookedTransformer


50it [00:00, 158.45it/s]


Loaded pretrained model gpt2-xl into HookedTransformer


50it [00:01, 32.86it/s]


normal analyzed
0.1
Loaded pretrained model gpt2-xl into HookedTransformer
sentences generated


50it [00:00, 155.04it/s]


Loaded pretrained model gpt2-xl into HookedTransformer


50it [00:01, 32.71it/s]


straighter analyzed


50it [00:00, 160.85it/s]


Loaded pretrained model gpt2-xl into HookedTransformer


50it [00:01, 29.87it/s]


curved analyzed
['and we don\'t want to be the ones to tell you that you\'re wrong.\n\nThe problem is that the "', 'Tolstoi, who was born in Russia, was a member of the Russian Academy of Sciences and a member of the', "Being well-informed about the issues is a good start. But it's not enough. We need to be able to communicate", "Of course, you can't just go out and buy a new car. You have to get a loan. And that's", 'The clans didn\'t know what to do. They were all in shock.\n\n"I\'m sorry, but I can', 'I am half way through the book and I am loving it. I am also a huge fan of the series and I am', "She still has some of the same problems she had before, but she's not as bad as she was before. She's", 'Neither me nor my wife have ever been to the hospital. I have never been to the hospital. I have never been to', "But the words did not come from the mouth of a man who was in the least degree of the Lord's favor.\n", 'Unfortunately, at its core, the problem is that the government is no