# Imports

In [None]:
# Standard library
from ast import literal_eval
import functools
import json
import os
import random
import shutil

# Scientific packages
import numpy as np
import pandas as pd
import torch
import datasets
# This notebook only runs forward passes, so disable autograd globally.
torch.set_grad_enabled(False)

# Visuals
from matplotlib import pyplot as plt
import seaborn as sns
# Configure context, style and font overrides in ONE call: a later bare
# sns.set_theme(...) re-applies the "notebook" context and would silently
# reset these font sizes back to seaborn's defaults.
sns.set_theme(context="notebook",
              style='whitegrid',
              rc={"font.size": 16,
                  "axes.titlesize": 16,
                  "axes.labelsize": 16,
                  "xtick.labelsize": 16.0,
                  "ytick.labelsize": 16.0,
                  "legend.fontsize": 16.0})
palette_ = sns.color_palette("Set1")
palette = palette_[2:5] + palette_[7:]

# Utilities

from general_utils import (
  ModelAndTokenizer,
  make_inputs,
  decode_tokens,
  find_token_range,
  predict_from_input,
)

from patchscopes_utils import *

from tqdm import tqdm
tqdm.pandas()

In [None]:
# Checkpoint name -> function that installs the hidden-state patching hooks for
# that model family (the `_neox` / `_llama` / `_gptj` installers, presumably
# provided by `patchscopes_utils`). Used below as `mt.set_hs_patch_hooks`.
_llama_family_checkpoints = (
    "meta-llama/Llama-2-13b-hf",
    "lmsys/vicuna-7b-v1.5",
    "./stable-vicuna-13b",
    "CarperAI/stable-vicuna-13b-delta",
)
model_to_hook = {"EleutherAI/pythia-12b": set_hs_patch_hooks_neox}
model_to_hook.update(
    {checkpoint: set_hs_patch_hooks_llama for checkpoint in _llama_family_checkpoints}
)
model_to_hook["EleutherAI/gpt-j-6b"] = set_hs_patch_hooks_gptj


In [None]:
# Load model

model_name = "EleutherAI/pythia-12b"
# Set to True when the tokenizer prepends a start-of-sequence token
# (TODO confirm per model); the sampling cells then never pick position 0.
sos_tok = False

# 12B/13B checkpoints are loaded in half precision; anything else uses the
# loader's default dtype.
is_large_checkpoint = any(size_tag in model_name for size_tag in ("13b", "12b"))
torch_dtype = torch.float16 if is_large_checkpoint else None

mt = ModelAndTokenizer(
    model_name,
    low_cpu_mem_usage=False,
    torch_dtype=torch_dtype,
)
mt.set_hs_patch_hooks = model_to_hook[model_name]
mt.model.eval()

# Next token prediction

In [None]:
# load dataset

pile_dataset = datasets.load_from_disk('./the_pile_deduplicated')
print(len(pile_dataset))
# Keep only short documents (< 250 space-separated words and < 2000 characters),
# then shuffle deterministically.
pile_dataset = pile_dataset.filter(
    lambda x: len(x['text'].split(' ')) < 250 and len(x['text']) < 2000
).shuffle(seed=42)
print(len(pile_dataset))

trn_n = 10000
val_n = 2000
# Fail loudly instead of silently producing a smaller validation split.
assert len(pile_dataset) >= trn_n + val_n, (
    f"need {trn_n + val_n} documents after filtering, got {len(pile_dataset)}")
# Slice rows first, then take the column: `pile_dataset['text']` would load the
# text of every filtered document only to keep the first trn_n + val_n.
pile_trn = pile_dataset[:trn_n]['text']
pile_val = pile_dataset[trn_n:trn_n + val_n]['text']
sentences = [(x, 'train') for x in pile_trn] + [(x, 'validation') for x in pile_val]

In [None]:
# Across layer mappings
# For each document, sample one token position and record that position's
# hidden state after every layer. Keys are (text, position, split).

rng = random.Random(42)  # local seeded RNG: positions are reproducible on re-run
start_pos = 1 if sos_tok else 0  # never sample the start-of-sequence token

data = {}
for sentence, split in tqdm(sentences):
    inp = make_inputs(mt.tokenizer, [sentence], device=mt.model.device)
    n_tokens = len(inp['input_ids'][0])
    if n_tokens <= start_pos:
        # No sampleable position (randint would raise ValueError).
        continue
    position = rng.randint(start_pos, n_tokens - 1)

    if (sentence, position, split) not in data:
        output = mt.model(**inp, output_hidden_states=True)
        # Index layer + 1: hidden_states[0] is presumably the embedding output
        # (Hugging Face convention) — TODO confirm for every supported model.
        data[(sentence, position, split)] = [
            output["hidden_states"][layer + 1][0][position].detach().cpu().numpy()
            for layer in range(mt.num_layers)
        ]

df = pd.Series(data).reset_index()
df.columns = ['full_text', 'position', 'data_split', 'hidden_rep']

out_path = model_name + "_pile_trn_val.pkl"
# model_name may contain a directory part (e.g. "EleutherAI/..."); to_pickle
# does not create missing parent directories.
os.makedirs(os.path.dirname(out_path) or '.', exist_ok=True)
df.to_pickle(out_path)

In [None]:
# Prompt-id mappings
# For each document, store (a) the hidden states at a sampled position
# ("source") and (b) the hidden states at the last position of the identity
# prompt whose final token is replaced by the model's top-1 prediction at that
# sampled position ("target"). Source/target rows are inserted pairwise.

prompt_target = "cat -> cat\n1135 -> 1135\nhello -> hello\n?"
inp_target = make_inputs(mt.tokenizer, [prompt_target], device=mt.model.device)

rng = random.Random(42)  # local seeded RNG: positions are reproducible on re-run
start_pos = 1 if sos_tok else 0  # never sample the start-of-sequence token

data = {}
for sentence, split in tqdm(sentences):
    inp = make_inputs(mt.tokenizer, [sentence], device=mt.model.device)
    n_tokens = len(inp['input_ids'][0])
    if n_tokens - 2 < start_pos:
        # Too short to leave a following token after the sampled position
        # (randint would raise ValueError).
        continue
    position = rng.randint(start_pos, n_tokens - 2)

    if (sentence, position, split, "source") not in data:
        output = mt.model(**inp, output_hidden_states=True)
        # Top-1 prediction at the *sampled* position, so the target token matches
        # the stored source hidden state (argmax of logits == argmax of softmax).
        answer_t = output.logits[0, position, :].argmax()
        data[(sentence, position, split, "source")] = [
            output["hidden_states"][layer + 1][0][position].detach().cpu().numpy()
            for layer in range(mt.num_layers)
        ]

        # Overwrite the prompt's final token in place with the predicted token.
        inp_target['input_ids'][0][-1] = answer_t
        output = mt.model(**inp_target, output_hidden_states=True)
        data[(sentence, position, split, "target")] = [
            output["hidden_states"][layer + 1][0][-1].detach().cpu().numpy()
            for layer in range(mt.num_layers)
        ]

df = pd.Series(data).reset_index()
df.columns = ['full_text', 'position', 'data_split', 'prompt', 'hidden_rep']

# Own file name: reusing "_pile_trn_val.pkl" would overwrite the across-layer data.
out_path = model_name + "_pile_trn_val_prompt-id.pkl"
# model_name may contain a directory part; to_pickle does not create it.
os.makedirs(os.path.dirname(out_path) or '.', exist_ok=True)
df.to_pickle(out_path)


In [None]:
# Pad and unpad

def pad(x):
    """Append a column of ones to the 2-D array `x`.

    Fitting `pad(X) @ A = pad(Y)` by least squares then learns an affine map
    (linear part plus bias) instead of a purely linear one.
    """
    ones_column = np.ones((x.shape[0], 1))
    return np.hstack([x, ones_column])


def unpad(x):
    """Drop the last column of the 2-D array `x` (inverse of `pad`)."""
    return x[:, :-1]

In [None]:
# Across layer mappings
# For every layer l, fit an affine map from layer-l hidden states to
# last-layer hidden states at the same token position (training split only).

output_dir = f'{model_name}_mappings_pile'
os.makedirs(output_dir, exist_ok=True)

train_rows = df[df['data_split'] == 'train']
if 'prompt' in train_rows.columns:
    # `df` is the prompt-id table if that cell ran last; only its "source" rows
    # are document hidden states ("target" rows come from the identity prompt).
    train_rows = train_rows[train_rows['prompt'] == 'source']
df_trn = pd.DataFrame(train_rows['hidden_rep'].to_list(),
                      columns=list(range(mt.num_layers)))

target_layer = mt.num_layers - 1
Y = np.array(df_trn[target_layer].values.tolist())

mappings = []
for layer in range(mt.num_layers):
    X = np.array(df_trn[layer].values.tolist())

    # Solve the least-squares problem pad(X) @ A = pad(Y) for the affine map A.
    # rcond=None: machine-precision cutoff (NumPy's current default; avoids the
    # FutureWarning on older NumPy versions).
    A, res, rank, s = np.linalg.lstsq(pad(X), pad(Y), rcond=None)

    mappings.append(A)
    with open(f'{output_dir}/mapping_{layer}-{target_layer}.npy', 'wb') as fd:
        np.save(fd, A)

    train_error = np.abs(Y - unpad(pad(X) @ A)).max()
    print(layer, "max error on train:", train_error)

shutil.make_archive(output_dir, 'zip', output_dir)

In [None]:
# Prompt-id mappings
# For every layer l, fit an affine map from a document's layer-l hidden state
# ("source") to the identity prompt's layer-l hidden state ("target").

assert 'prompt' in df.columns, "run the prompt-id data-collection cell first"

output_dir = f'{model_name}_mappings_pile_prompt-id'
os.makedirs(output_dir, exist_ok=True)

train_rows = df[df['data_split'] == 'train']
df_trn_src = pd.DataFrame(train_rows[train_rows['prompt'] == 'source']['hidden_rep'].to_list(),
                          columns=list(range(mt.num_layers)))
df_trn_tgt = pd.DataFrame(train_rows[train_rows['prompt'] == 'target']['hidden_rep'].to_list(),
                          columns=list(range(mt.num_layers)))
# Source and target rows were inserted pairwise, so row i of each frame belongs
# to the same (document, position); their counts must agree.
assert len(df_trn_src) == len(df_trn_tgt)

mappings = []
for layer in range(mt.num_layers):
    X = np.array(df_trn_src[layer].values.tolist())
    Y = np.array(df_trn_tgt[layer].values.tolist())

    # Solve the least-squares problem pad(X) @ A = pad(Y) for the affine map A.
    # rcond=None: machine-precision cutoff (NumPy's current default).
    A, res, rank, s = np.linalg.lstsq(pad(X), pad(Y), rcond=None)

    mappings.append(A)
    with open(f'{output_dir}/mapping_{layer}.npy', 'wb') as fd:
        np.save(fd, A)

    train_error = np.abs(Y - unpad(pad(X) @ A)).max()
    print(layer, "max error on train:", train_error)

shutil.make_archive(output_dir, 'zip', output_dir)

In [None]:
# Reload the across-layer affine maps (layer l -> last layer) written by the
# across-layer fitting cell: one .npy file per source layer, in layer order.
last_layer = mt.num_layers - 1
mappings = [
    np.load(f'{model_name}_mappings_pile/mapping_{layer}-{last_layer}.npy')
    for layer in tqdm(range(mt.num_layers))
]

In [None]:
# Evaluate linear mappings on the validation set of the Pile
# (source layer `layer` patched into the last layer through the fitted map).
device = mt.model.device
target_layer = mt.num_layers - 1

val_rows = df[df['data_split'] == 'validation']
if 'prompt' in val_rows.columns:
    # Prompt-id table: "target" rows repeat the (full_text, position) of their
    # "source" row and would be evaluated twice.
    val_rows = val_rows[val_rows['prompt'] != 'target']

records = []
for layer in tqdm(range(mt.num_layers)):
    A = mappings[layer]

    def transform(x, A=A):
        """Apply the affine map `A` to one hidden-state tensor `x`.

        The map is fitted in float64, so the result is cast back to x's dtype
        and moved to the model device before being patched in.
        """
        mapped = unpad(pad(np.expand_dims(x.detach().cpu().numpy(), 0)) @ A)
        return torch.tensor(np.squeeze(mapped)).to(device, dtype=x.dtype)

    for _, row in val_rows.iterrows():
        prompt = row['full_text']
        position = row['position']
        prec_1, surprisal = evaluate_patch_next_token_prediction(
            mt, prompt, prompt, layer, target_layer,
            position, position, position_prediction=position, transform=transform)

        records.append({'layer': layer, 'prec_1': prec_1, 'surprisal': surprisal})


results = pd.DataFrame.from_records(records)
results.to_csv(f'{model_name}_mappings_pile_eval.csv')

In [None]:
# Evaluate identity mapping on the validation set of the Pile
# (source layer `layer` patched into the last layer with no transformation).

target_layer = mt.num_layers - 1

val_rows = df[df['data_split'] == 'validation']
if 'prompt' in val_rows.columns:
    # Prompt-id table: "target" rows repeat the (full_text, position) of their
    # "source" row and would be evaluated twice.
    val_rows = val_rows[val_rows['prompt'] != 'target']

records = []
for layer in tqdm(range(mt.num_layers)):
    for _, row in val_rows.iterrows():
        prompt = row['full_text']
        position = row['position']
        prec_1, surprisal = evaluate_patch_next_token_prediction(
            mt, prompt, prompt, layer, target_layer,
            position, position, position_prediction=position)

        records.append({'layer': layer, 'prec_1': prec_1, 'surprisal': surprisal})

results = pd.DataFrame.from_records(records)
results.to_csv(f'{model_name}_identity_pile_eval.csv')

In [None]:
# Evaluate the ID prompt on the validation set of the Pile (with/without mappings)
# Source layer `layer` is patched into the same layer at the last position of
# the identity prompt.

device = mt.model.device

prompt_target = "cat -> cat\n1135 -> 1135\nhello -> hello\n?"
position_target = -1
apply_mappings = True

val_rows = df[df['data_split'] == 'validation']
if 'prompt' in val_rows.columns:
    # Only document ("source") rows are patch sources.
    val_rows = val_rows[val_rows['prompt'] != 'target']

records = []
for layer in tqdm(range(mt.num_layers)):
    if apply_mappings:
        # Patching is layer -> same layer, so use the prompt-id map fitted for this
        # layer. The global `mappings` holds layer -> last-layer maps after the
        # reload cell and would be the wrong transformation here.
        with open(f'{model_name}_mappings_pile_prompt-id/mapping_{layer}.npy', 'rb') as fd:
            A = np.load(fd)

        def transform(x, A=A):
            """Apply the affine map `A` to one hidden-state tensor `x`, returning
            a tensor with x's dtype on the model device (A is float64)."""
            mapped = unpad(pad(np.expand_dims(x.detach().cpu().numpy(), 0)) @ A)
            return torch.tensor(np.squeeze(mapped)).to(device, dtype=x.dtype)
    else:
        transform = None

    for _, row in val_rows.iterrows():
        prompt_source = row['full_text']
        position_source = row['position']
        prec_1, surprisal = evaluate_patch_next_token_prediction(
            mt, prompt_source, prompt_target, layer, layer,
            position_source, position_target, position_prediction=position_target, transform=transform)

        records.append({'layer': layer, 'prec_1': prec_1, 'surprisal': surprisal})

results = pd.DataFrame.from_records(records)
if apply_mappings:
    results.to_csv(f'{model_name}_prompt-id-mapping_pile_eval.csv')
else:
    results.to_csv(f'{model_name}_prompt-id_pile_eval.csv')

In [None]:
# Compare per-layer next-token precision and surprisal across patching variants.
variant_files = {
    "identity": f'{model_name}_identity_pile_eval.csv',
    "affine mapping": f'{model_name}_mappings_pile_eval.csv',
    "prompt id": f'{model_name}_prompt-id_pile_eval.csv',
}
# Only present when the ID-prompt evaluation was run with apply_mappings=True.
prompt_id_mapping_file = f'{model_name}_prompt-id-mapping_pile_eval.csv'
if os.path.exists(prompt_id_mapping_file):
    variant_files["prompt id + mapping"] = prompt_id_mapping_file

variant_frames = []
for variant, path in variant_files.items():
    frame = pd.read_csv(path)
    frame["variant"] = variant
    variant_frames.append(frame)
results = pd.concat(variant_frames, ignore_index=True)

for metric in ['prec_1', 'surprisal']:
    # Explicit figure per metric; closing it avoids the stray empty figure that
    # plt.clf() after plt.show() creates under the inline backend.
    fig, ax = plt.subplots()
    sns.lineplot(data=results, x='layer', y=metric, hue="variant", ax=ax)
    ax.set_title(model_name.strip('./'))
    ax.legend_.set_title('')
    plt.show()
    plt.close(fig)