# Imports

In [1]:
from ast import literal_eval
import functools
import json
import os
import random
import shutil

# Scientific packages
import numpy as np
import pandas as pd
import torch
import datasets
from torch import cuda
torch.set_grad_enabled(False)

# Visuals
from matplotlib import pyplot as plt
import seaborn as sns
# Configure seaborn/matplotlib in ONE call. `sns.set` is an alias of `sns.set_theme`,
# and every `set_theme` call re-applies the plotting context, so a second
# `set_theme(style=...)` call would reset these font sizes back to the defaults.
sns.set_theme(context="notebook",
              style="whitegrid",
              rc={"font.size": 16,
                  "axes.titlesize": 16,
                  "axes.labelsize": 16,
                  "xtick.labelsize": 16.0,
                  "ytick.labelsize": 16.0,
                  "legend.fontsize": 16.0})
palette_ = sns.color_palette("Set1")
palette = palette_[2:5] + palette_[7:]

# Utilities

from general_utils import (
  ModelAndTokenizer,
  make_inputs,
  decode_tokens,
  find_token_range,
  predict_from_input,
)

from patchscopes_utils import *

from tqdm import tqdm
tqdm.pandas()

In [3]:
# Checkpoint name -> hook-installing helper (from patchscopes_utils), grouped by helper.
# The helpers presumably set up hidden-state patching for each architecture family
# (NeoX / LLaMA / GPT-J, judging by their names). Keys must match the `model_name_*`
# strings used when loading models below.
model_to_hook = {
    **dict.fromkeys(
        ["EleutherAI/pythia-6.9b",
         "EleutherAI/pythia-12b"],
        set_hs_patch_hooks_neox,
    ),
    **dict.fromkeys(
        ["meta-llama/Llama-2-13b-hf",
         "lmsys/vicuna-7b-v1.5",
         "./stable-vicuna-13b",
         "CarperAI/stable-vicuna-13b-delta"],
        set_hs_patch_hooks_llama,
    ),
    "EleutherAI/gpt-j-6b": set_hs_patch_hooks_gptj,
}

In [None]:
# Load model 1 (the source model: its hidden states are mapped/patched into model 2)

model_name_1 = "lmsys/vicuna-7b-v1.5"
sos_tok_1 = False  # int(sos_tok_1) is added to sampled positions when indexing model-1 tokens

# fp16 only for the larger 12b/13b checkpoints; otherwise None is passed through to the loader.
torch_dtype = torch.float16 if any(tag in model_name_1 for tag in ("13b", "12b")) else None

mt_1 = ModelAndTokenizer(
    model_name_1,
    device="cuda:1",
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=False,
)
mt_1.set_hs_patch_hooks = model_to_hook[model_name_1]
mt_1.model.eval()
mt_1.model.to(mt_1.device)

In [None]:
# Load model 2 (the target model: mapped model-1 hidden states are patched into it)

model_name_2 = "./stable-vicuna-13b"
# Name used in output file/directory names. Remove only a literal "./" prefix:
# `str.strip('./')` would treat its argument as a character SET and strip any
# leading/trailing '.' and '/' characters (e.g. "../ckpt" -> "ckpt", "ckpt-v2." -> "ckpt-v2").
model_name_2_ = model_name_2[len('./'):] if model_name_2.startswith('./') else model_name_2
sos_tok_2 = False  # int(sos_tok_2) is added to sampled positions when indexing model-2 tokens

# fp16 only for the larger 12b/13b checkpoints; otherwise None is passed through to the loader.
if "13b" in model_name_2 or "12b" in model_name_2:
    torch_dtype = torch.float16
else:
    torch_dtype = None

mt_2 = ModelAndTokenizer(
    model_name_2,
    low_cpu_mem_usage=False,
    torch_dtype=torch_dtype,
    device="cuda:0"
)
mt_2.set_hs_patch_hooks = model_to_hook[model_name_2]
mt_2.model.eval()
mt_2.model.to(mt_2.device)

# Next token prediction

In [None]:
# Load a local copy of the deduplicated Pile, shuffle it with a fixed seed, and take
# disjoint train / validation samples of raw text.
pile_dataset = datasets.load_from_disk('./the_pile_deduplicated')
pile_dataset = pile_dataset.shuffle(seed=42)
print(len(pile_dataset))

trn_n = 100000
val_n = 2000
assert len(pile_dataset) >= trn_n + val_n, "dataset too small for the requested train/validation split"

# Select the rows first, then read the 'text' column. Indexing the whole dataset with
# ['text'] reads that column for every document (in `datasets` versions where column
# access returns a Python list), only to keep the first ~100k.
pile_trn = pile_dataset.select(range(trn_n))['text']
pile_val = pile_dataset.select(range(trn_n, trn_n + val_n))['text']
sentences = [(x, 'train') for x in pile_trn] + [(x, 'validation') for x in pile_val]

In [None]:
# For every text, sample one token position at which both tokenizers agree. Then run
# both models on the shared (decoded) prefix and keep the hidden state at that
# position for every layer. `data` keys are (prefix, position, split, model_name).
max_len = 256          # positions are sampled among the first `max_len` tokens
n_position_tries = 10  # resampling attempts before a text is skipped
prefix_buffer = 5      # tokens kept after the sampled position so re-tokenizing the prefix does not alter it
random.seed(42)        # reproducible position sampling (the dataset shuffle is seeded as well)

data = {}
for sentence, split in tqdm(sentences):

    inp_1_ = make_inputs(mt_1.tokenizer, [sentence], device=mt_1.device)
    inp_2_ = make_inputs(mt_2.tokenizer, [sentence], device=mt_2.device)
    max_position = min(max_len - 1,
                       len(inp_1_['input_ids'][0]) - 1,
                       len(inp_2_['input_ids'][0]) - 1)
    if max_position < 0:
        # empty tokenization: there is no position to sample (randint(0, -1) would raise)
        continue

    position = None
    for attempt in range(n_position_tries):
        position_tmp = random.randint(0, max_position)
        # cut the tokenized input at the sampled position and turn it back into a string.
        # add some buffer at the end such that the tokenization is not modified around the sampled position.
        prefix_1 = mt_1.tokenizer.decode(inp_1_['input_ids'][0][:position_tmp + int(sos_tok_1) + prefix_buffer])
        prefix_2 = mt_2.tokenizer.decode(inp_2_['input_ids'][0][:position_tmp + int(sos_tok_2) + prefix_buffer])

        # check that the selected position corresponds to the same part of the string by
        # comparing the prefixes. also make sure the re-tokenized prefixes still contain the
        # sos-shifted index (position + int(sos_tok)), since that is the index read out below.
        inp_1 = make_inputs(mt_1.tokenizer, [prefix_1], device=mt_1.device)
        inp_2 = make_inputs(mt_2.tokenizer, [prefix_2], device=mt_2.device)
        if (prefix_1 == prefix_2
                and position_tmp + int(sos_tok_1) < len(inp_1['input_ids'][0])
                and position_tmp + int(sos_tok_2) < len(inp_2['input_ids'][0])):
            position = position_tmp
            break
    if position is None:
        continue

    # Model 1's entry is inserted before model 2's for each prefix. Later cells pair the
    # two models' rows by this insertion order.
    for mt, model_name, inp, sos_tok in zip(
        [mt_1, mt_2],
        [model_name_1, model_name_2],
        [inp_1, inp_2],
        [sos_tok_1, sos_tok_2]
    ):
        position_ = position + int(sos_tok)
        key = (prefix_1, position_, split, model_name)
        if key not in data:
            output = mt.model(**inp, output_hidden_states=True)
            # hidden_states[0] is presumably the embedding output, so layer i is read at index i + 1
            data[key] = [
                output["hidden_states"][layer + 1][0][position_].detach().cpu().numpy()
                for layer in range(mt.num_layers)
            ]

# Build the frame straight from the (key, value) pairs; this also yields a correctly
# labelled (empty) frame when no position could be sampled at all.
df = pd.DataFrame(
    [(*key, reps) for key, reps in data.items()],
    columns=['full_text', 'position', 'data_split', 'model_name', 'hidden_rep'],
)

for model_name in [model_name_1, model_name_2]:
    out_path = f"{model_name}_pile_trn_val.pkl"
    # Hub ids such as "lmsys/vicuna-7b-v1.5" contain "/", so the pickle lands in a
    # subdirectory that must exist before writing.
    os.makedirs(os.path.dirname(out_path) or '.', exist_ok=True)
    df[df['model_name'] == model_name].to_pickle(out_path)

In [29]:
# Homogeneous-coordinate helpers for fitting/applying affine maps as a single matrix.

def pad(x):
    """Return `x` (n rows) with an extra trailing column of ones appended."""
    ones_column = np.ones((x.shape[0], 1))
    return np.hstack([x, ones_column])


def unpad(x):
    """Return `x` without its last column (inverse of `pad`)."""
    return x[:, :-1]

In [30]:
# Every 5th layer: source layers are taken from model 1, target layers from model 2.
layer_sources = list(range(0, mt_1.num_layers, 5))
layer_targets = list(range(0, mt_2.num_layers, 5))

In [None]:
# Fit one affine map per (source layer, target layer) pair by least squares on the
# training split. Each map is saved as .npy, and the directory is zipped at the end.
output_dir = f'{model_name_1}_{model_name_2_}_mappings_pile'
os.makedirs(output_dir, exist_ok=True)

df_trn_1 = pd.DataFrame(df[(df['data_split'] == 'train') &
                           (df['model_name'] == model_name_1)]['hidden_rep'].to_list(),
                        columns=list(range(mt_1.num_layers)))
df_trn_2 = pd.DataFrame(df[(df['data_split'] == 'train') &
                           (df['model_name'] == model_name_2)]['hidden_rep'].to_list(),
                        columns=list(range(mt_2.num_layers)))
# Rows of the two frames are paired purely by order: the extraction loop inserts
# model 1's and then model 2's entry for each sampled prefix, so both must line up 1:1.
assert len(df_trn_1) == len(df_trn_2), "train hidden states of the two models are not aligned"

layer_sources = list(range(0, mt_1.num_layers, 5))
layer_targets = list(range(0, mt_2.num_layers, 5))

mappings = {}
for layer_source in tqdm(layer_sources):
    # X depends only on the source layer: build it (and its padded form) once per source layer.
    X = np.array(df_trn_1[layer_source].values.tolist())
    X_pad = pad(X)
    for layer_target in layer_targets:
        Y = np.array(df_trn_2[layer_target].values.tolist())

        # Solve the least squares problem pad(X) @ A = pad(Y) for the affine map A.
        # rcond=None selects NumPy's current cutoff (machine precision * max(M, N))
        # and silences the FutureWarning raised by older NumPy versions.
        A = np.linalg.lstsq(X_pad, pad(Y), rcond=None)[0]

        mappings[(layer_source, layer_target)] = A
        with open(os.path.join(output_dir, f'mapping_{layer_source}-{layer_target}.npy'), 'wb') as fd:
            np.save(fd, A)

        print(layer_source, layer_target, "max error on train:", np.abs(Y - unpad(X_pad @ A)).max())

shutil.make_archive(output_dir, 'zip', output_dir)

In [None]:
# Reload the fitted affine maps from disk, so the evaluation cells can run without refitting.
mappings = {}
for layer_source in tqdm(layer_sources):
    for layer_target in layer_targets:
        mapping_path = (f'{model_name_1}_{model_name_2_}_mappings_pile/'
                        f'mapping_{layer_source}-{layer_target}.npy')
        A = np.load(mapping_path)
        mappings[(layer_source, layer_target)] = A

In [None]:
# Re-organize validation set: collapse each prompt's per-model rows into a single row.
# Each per-model field becomes `<field>_1` (model 1) and `<field>_2` (model 2). This
# relies on groupby keeping the within-group row order (model 1 is inserted first).

df_val = (
    df.loc[df['data_split'] == 'validation']
    .groupby(['full_text', 'data_split'])
    .agg(pd.Series.tolist)
    .reset_index()
)
cols = ['position', 'model_name', 'hidden_rep']
for field in cols:
    df_val[[f'{field}_1', f'{field}_2']] = df_val[field].to_list()

df_val = df_val.drop(columns=cols)

In [None]:
# Evaluate the fitted affine mappings on the validation split (a held-out sample from the Pile).
# For each layer pair, the model-1 hidden state is passed through A before patching, and the
# prec_1 / surprisal returned by evaluate_patch_next_token_prediction_x_model are recorded.

records = []
for layer_source in tqdm(layer_sources):
    for layer_target in tqdm(layer_targets):
        A = mappings[(layer_source, layer_target)]

        def transform(x):
            # torch tensor -> numpy row -> affine map A -> tensor on model 2's device
            as_row = pad(np.expand_dims(x.detach().cpu().numpy(), 0))
            mapped = np.squeeze(unpad(np.dot(as_row, A)))
            return torch.tensor(mapped).to(mt_2.device)

        for idx, row in df_val.iterrows():
            prompt = row['full_text']
            position_source, position_target = row['position_1'], row['position_2']
            prec_1, surprisal = evaluate_patch_next_token_prediction_x_model(
                mt_1, mt_2, prompt, prompt, layer_source, layer_target,
                position_source, position_target, position_prediction=position_target, transform=transform)

            records.append(dict(
                layer_source=layer_source,
                layer_target=layer_target,
                position_source=position_source,
                position_target=position_target,
                prec_1=prec_1,
                surprisal=surprisal,
            ))

results = pd.DataFrame.from_records(records)
results.to_csv(f'{model_name_1}_{model_name_2_}_mappings_pile_eval.csv')

In [None]:
# Heatmap of the mean metric for every (source layer, target layer) pair of the latest `results`.
metric = 'prec_1'
tmp = (
    results.groupby(['layer_source', 'layer_target'])[metric]
    .mean()
    .reset_index()
    .pivot(index='layer_source', columns='layer_target', values=metric)
)

sns.heatmap(tmp, annot=True, fmt=".1f")

In [None]:
# Evaluate the identity mapping (no transform argument) on the validation split, a sample from the Pile.
# NOTE(review): without a transform, the model-1 hidden state is presumably patched into model 2
# unchanged. That only works if both models share a hidden size — confirm for the loaded pair.

records = []
for layer_source in tqdm(layer_sources):
    for layer_target in tqdm(layer_targets):
        for idx, row in df_val.iterrows():
            prompt = row['full_text']
            position_source, position_target = row['position_1'], row['position_2']
            prec_1, surprisal = evaluate_patch_next_token_prediction_x_model(
                mt_1, mt_2, prompt, prompt, layer_source, layer_target,
                position_source, position_target, position_prediction=position_target)

            records.append(dict(
                layer_source=layer_source,
                layer_target=layer_target,
                position_source=position_source,
                position_target=position_target,
                prec_1=prec_1,
                surprisal=surprisal,
            ))

results = pd.DataFrame.from_records(records)
results.to_csv(f'{model_name_1}_{model_name_2_}_identity_pile_eval.csv')

In [None]:
# Evaluate the few-shot "identity" target prompt on the validation split (a sample from the Pile).
# Each source hidden state is patched into the last token ("?") of `prompt_target` in model 2.
# NOTE(review): transform=None presumably patches the raw model-1 state. That assumes both models
# share a hidden size — confirm for the loaded pair.

prompt_target = "cat -> cat\n1135 -> 1135\nhello -> hello\n?"
position_target = -1

records = []
for layer_source in tqdm(layer_sources):
    for layer_target in tqdm(layer_targets):
        for idx, row in df_val.iterrows():
            prompt_source, position_source = row['full_text'], row['position_1']
            prec_1, surprisal = evaluate_patch_next_token_prediction_x_model(
                mt_1, mt_2, prompt_source, prompt_target, layer_source, layer_target,
                position_source, position_target, position_prediction=position_target, transform=None)

            records.append(dict(
                layer_source=layer_source,
                layer_target=layer_target,
                position_source=position_source,
                position_target=position_target,
                prec_1=prec_1,
                surprisal=surprisal,
            ))

results = pd.DataFrame.from_records(records)
results.to_csv(f'{model_name_1}_{model_name_2_}_prompt-id_pile_eval.csv')

In [None]:
# Compare the three patching variants (identity, affine mapping, identity prompt) across layers.
results1 = pd.read_csv(f'{model_name_1}_{model_name_2_}_identity_pile_eval.csv')
results1["variant"] = "identity"
results2 = pd.read_csv(f'{model_name_1}_{model_name_2_}_mappings_pile_eval.csv')
results2["variant"] = "affine mapping"
results3 = pd.read_csv(f'{model_name_1}_{model_name_2_}_prompt-id_pile_eval.csv')
results3["variant"] = "prompt id"

results = pd.concat([results1, results2, results3], ignore_index=True)

# The evaluation CSVs have 'layer_source'/'layer_target' columns but no 'layer' column, so
# plot against the source layer. lineplot aggregates (mean + CI) over all target layers.
# Only a literal "./" prefix is removed for the title (str.strip('./') strips a character set).
source_title = model_name_1[len('./'):] if model_name_1.startswith('./') else model_name_1
for metric in ['prec_1', 'surprisal']:
    fig, ax = plt.subplots()
    sns.lineplot(data=results, x='layer_source', y=metric, hue="variant", ax=ax)
    ax.set_title(f"{source_title} --> {model_name_2_}")
    ax.set_xlabel("source layer (mean over target layers)")
    ax.legend_.set_title('')
    plt.show()
    plt.close(fig)