In [1]:
%%capture
!pip install -r requirements.txt

In [None]:
import sys
from pathlib import Path
import json
import pandas as pd
from dotenv import load_dotenv
import plotly.express as px
import torch as t
import pandas as pd
from tools.globals import load_country_globals
from tools.nnsight_utils import collect_residuals
from tqdm import tqdm
from nnsight import LanguageModel
from transformers import AutoTokenizer

load_country_globals()

device = t.device(
    "mps" if t.backends.mps.is_available() else "cuda" if t.cuda.is_available() else "cpu"
)
load_dotenv()
t.set_grad_enabled(False)

t.manual_seed(42)
if t.cuda.is_available():
    t.cuda.manual_seed_all(42)

%load_ext autoreload
%autoreload 2

In [3]:
# Per-language text that opens the model's answer; each one ends with "**".
prompt_suffix = dict(
    English="My guess is **",
    Turkish="Tahminim **",
    French="Ma supposition est **",
    Russian="Моё предположение **",
    Bengali="আমার অনুমান হলো **",
)

# Source dataset name -> subtask name (presumably the values of the `subtask`
# column queried later in this notebook - TODO confirm).
subtask_map = dict(
    synth_names="names",
    synth_cities="cities",
    culturebench="culturebench",
)

def eval_or_skip(txt):
    try:
        return eval(txt)
    except:
        return None

# Country name -> short code used as the prefix of residual / vector file names.
country_to_suffix = {
    "Turkey": "tr",
    "France": "fr",
    "Russia": "ru",
    "Bangladesh": "bn",
    "United States": "us",
}

### Load Model

In [9]:
# The tokenizer is resolved by its hub id; the nnsight wrapper below loads the
# weights from a local checkpoint directory instead.
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b-it")

# NOTE(review): the checkpoint path is an absolute, machine-specific path and the
# device is hard-coded to "cuda:0"; the `device` picked earlier is not used here.
nnmodel = LanguageModel(
    "/dlabscratch1/public/llm_weights/gemma_hf/gemma-2-9b-it",
    torch_dtype=t.bfloat16,
    dispatch=True,
    device_map="cuda:0",
)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

### Load dataset

In [None]:
# Table of prompts; the columns used in this notebook are country, lang, subtask,
# hint, ans_type, question_id and input.
steering_df = pd.read_csv("model_gen/gemma2_9b_it_guess_suffix_fixed.csv")

# A row is flagged "translated" when it belongs to the United States or when its
# language is not English. Computed column-wise instead of with a row-wise
# apply(axis=1): same values, no per-row Python call, and still a boolean Series
# when the frame is empty.
steering_df["translated"] = (steering_df["country"] == "United States") | (
    steering_df["lang"] != "English"
)

# Sanity check: show one fully formatted prompt.
print(steering_df.iloc[29]["input"])

<bos><start_of_turn>user
Question:You must select one option and answer. First, state the selected option in full, then explain your guess.
What is a common living arrangement for children after they reach adulthood?
Options:
Children go to live with their distant relatives for better education or job opportunities.
Children often continue to live with their parents, or have their parents move into their homes to take care of them<end_of_turn>
<start_of_turn>model
My guess is **


### Calculate residuals

In [None]:
# Explicit (English) residuals
#
# For each (country, subtask) pair, collect residuals for English prompts:
#   positive: hint == True  and ans_type == 'local'
#   negative: hint == False and ans_type == 'west'
# Rows are matched on question_id; the stacked tensors are saved per pair.
countries = ["Turkey", "France", "Russia", "Bangladesh", "United States"]

# Create the output directory up front: t.save does not create missing parent
# directories, and failing only after the collection pass would waste the work.
exp_en_dir = Path("residuals/gemma2_9b_it_exp_en")
exp_en_dir.mkdir(parents=True, exist_ok=True)

for country in countries:
    print(country)
    for subtask in ["names", "cities", "culturedistil", "culturebench"]:
        print(subtask)
        pos_examples = steering_df.query("country == @country and subtask == @subtask and lang=='English' and hint==True and ans_type == 'local'")
        neg_examples = steering_df.query("country == @country and subtask == @subtask and lang=='English' and hint==False and ans_type == 'west'")

        # Keep only the questions that appear on both sides so rows can be paired.
        common_question_ids = set(pos_examples["question_id"].unique()) & set(neg_examples["question_id"].unique())

        pos_examples = pos_examples.query("question_id in @common_question_ids").sort_values("question_id").copy()
        neg_examples = neg_examples.query("question_id in @common_question_ids").sort_values("question_id").copy()

        # NOTE(review): the zip below pairs rows by position and stops at the
        # shorter frame, so the two shapes printed here should be identical.
        print(pos_examples.shape, neg_examples.shape)

        pos_res_list = []
        neg_res_list = []

        for pos, neg in tqdm(zip(pos_examples["input"], neg_examples["input"]), total=len(pos_examples)):
            # [:, 0, -1, :] takes index 0 on dim 1 and the last index on dim 2 -
            # presumably (layer, batch, position, hidden), i.e. the last-token
            # residual of every layer; TODO confirm in tools.nnsight_utils.
            pos_res = collect_residuals(nnmodel, pos, calculate_probs=False)["residuals"][:,0,-1,:]
            neg_res = collect_residuals(nnmodel, neg, calculate_probs=False)["residuals"][:,0,-1,:]
            pos_res_list.append(pos_res)
            neg_res_list.append(neg_res)
        pos_res = t.stack(pos_res_list)
        neg_res = t.stack(neg_res_list)

        t.save(pos_res, exp_en_dir / f"{country_to_suffix[country]}_{subtask}_pos.pt")
        t.save(neg_res, exp_en_dir / f"{country_to_suffix[country]}_{subtask}_neg.pt")

In [None]:
# Explicit (Translated) residuals
#
# For each (country, subtask) pair, collect residuals for rows flagged `translated`:
#   positive: hint == True  and ans_type == 'local'
#   negative: hint == False and ans_type == 'west'
# Rows are matched on question_id; the stacked tensors are saved per pair.
countries = ["Turkey", "France", "Russia", "Bangladesh", "United States"]

# Create the output directory up front: t.save does not create missing parent
# directories, and failing only after the collection pass would waste the work.
exp_trans_dir = Path("residuals/gemma2_9b_it_exp_trans")
exp_trans_dir.mkdir(parents=True, exist_ok=True)

for country in countries:
    print(country)
    for subtask in ["names", "cities", "culturedistil", "culturebench"]:
        print(subtask)
        pos_examples = steering_df.query("country == @country and subtask == @subtask and translated and hint==True and ans_type == 'local'")
        neg_examples = steering_df.query("country == @country and subtask == @subtask and translated and hint==False and ans_type == 'west'")

        # Keep only the questions that appear on both sides so rows can be paired.
        common_question_ids = set(pos_examples["question_id"].unique()) & set(neg_examples["question_id"].unique())

        pos_examples = pos_examples.query("question_id in @common_question_ids").sort_values("question_id").copy()
        neg_examples = neg_examples.query("question_id in @common_question_ids").sort_values("question_id").copy()

        # NOTE(review): the zip below pairs rows by position and stops at the
        # shorter frame, so the two shapes printed here should be identical.
        print(pos_examples.shape, neg_examples.shape)

        pos_res_list = []
        neg_res_list = []

        for pos, neg in tqdm(zip(pos_examples["input"], neg_examples["input"]), total=len(pos_examples)):
            # [:, 0, -1, :] takes index 0 on dim 1 and the last index on dim 2 -
            # presumably (layer, batch, position, hidden), i.e. the last-token
            # residual of every layer; TODO confirm in tools.nnsight_utils.
            pos_res = collect_residuals(nnmodel, pos, calculate_probs=False)["residuals"][:,0,-1,:]
            neg_res = collect_residuals(nnmodel, neg, calculate_probs=False)["residuals"][:,0,-1,:]
            pos_res_list.append(pos_res)
            neg_res_list.append(neg_res)
        pos_res = t.stack(pos_res_list)
        neg_res = t.stack(neg_res_list)

        t.save(pos_res, exp_trans_dir / f"{country_to_suffix[country]}_{subtask}_pos.pt")
        t.save(neg_res, exp_trans_dir / f"{country_to_suffix[country]}_{subtask}_neg.pt")

In [None]:
# Implicit residuals
#
# For each (country, subtask) pair, collect residuals without hints on either side:
#   positive: lang != 'English', hint == False, ans_type == 'local'
#   negative: lang == 'English', hint == False, ans_type == 'west'
# Rows are matched on question_id; the stacked tensors are saved per pair.
countries = ["Turkey", "France", "Russia", "Bangladesh"]

# Create the output directory up front: t.save does not create missing parent
# directories, and failing only after the collection pass would waste the work.
imp_dir = Path("residuals/gemma2_9b_it_imp")
imp_dir.mkdir(parents=True, exist_ok=True)

for country in countries:
    print(country)
    for subtask in ["names", "cities", "culturedistil", "culturebench"]:
        print(subtask)
        pos_examples = steering_df.query("country == @country and subtask == @subtask and lang!='English' and hint==False and ans_type == 'local'")
        neg_examples = steering_df.query("country == @country and subtask == @subtask and lang=='English' and hint==False and ans_type == 'west'")

        # Keep only the questions that appear on both sides so rows can be paired.
        common_question_ids = set(pos_examples["question_id"].unique()) & set(neg_examples["question_id"].unique())

        pos_examples = pos_examples.query("question_id in @common_question_ids").sort_values("question_id").copy()
        neg_examples = neg_examples.query("question_id in @common_question_ids").sort_values("question_id").copy()

        # NOTE(review): the zip below pairs rows by position and stops at the
        # shorter frame, so the two shapes printed here should be identical.
        print(pos_examples.shape, neg_examples.shape)

        pos_res_list = []
        neg_res_list = []

        for pos, neg in tqdm(zip(pos_examples["input"], neg_examples["input"]), total=len(pos_examples)):
            # [:, 0, -1, :] takes index 0 on dim 1 and the last index on dim 2 -
            # presumably (layer, batch, position, hidden), i.e. the last-token
            # residual of every layer; TODO confirm in tools.nnsight_utils.
            pos_res = collect_residuals(nnmodel, pos, calculate_probs=False)["residuals"][:,0,-1,:]
            neg_res = collect_residuals(nnmodel, neg, calculate_probs=False)["residuals"][:,0,-1,:]
            pos_res_list.append(pos_res)
            neg_res_list.append(neg_res)
        pos_res = t.stack(pos_res_list)
        neg_res = t.stack(neg_res_list)

        t.save(pos_res, imp_dir / f"{country_to_suffix[country]}_{subtask}_pos.pt")
        t.save(neg_res, imp_dir / f"{country_to_suffix[country]}_{subtask}_neg.pt")


### Calculate steering vectors

In [None]:
# Implicit, per culture steering vectors
#
# For each culture, pool the per-example (pos - neg) residual differences of all
# tasks and average them over the example dimension.
tasks = ["names", "cities", "culturedistil", "culturebench"]

# t.save does not create missing parent directories.
Path("vectors/gemma2_9b_it/implicit").mkdir(parents=True, exist_ok=True)

for prefix in ["tr", "fr", "ru", "bn"]:
    print(prefix)
    steering_vecs = []
    for task in tasks:
        # Read from the directory the "Implicit residuals" cell of this notebook
        # writes to (residuals/gemma2_9b_it_imp). The previous path pointed at
        # residuals/gemma2_9b_it_v3_imp, which nothing in this notebook produces,
        # so a fresh run would fail or silently use stale files.
        res_tr = t.load(f"residuals/gemma2_9b_it_imp/{prefix}_{task}_pos.pt", weights_only=True)
        res_en = t.load(f"residuals/gemma2_9b_it_imp/{prefix}_{task}_neg.pt", weights_only=True)

        # Per-example difference between the positive and negative residuals.
        steering_vec = (res_tr - res_en)
        steering_vecs.append(steering_vec)

    print(len(steering_vecs))
    # Concatenate along the example dimension, then average over it.
    steering_vec = t.cat(steering_vecs, dim=0)
    print(steering_vec.shape)
    steering_vec = steering_vec.mean(dim=0)
    t.save(steering_vec,f"vectors/gemma2_9b_it/implicit/{prefix}_avg_all_tasks.pt")

tr
4
torch.Size([201, 42, 3584])
fr
4
torch.Size([169, 42, 3584])
ru
4
torch.Size([108, 42, 3584])
bn
4
torch.Size([220, 42, 3584])


In [None]:
# Explicit, per culture (Translated) steering vectors
#
# One vector per culture: the (pos - neg) residual differences of every task are
# concatenated along dim 0 and then averaged over that dimension.
tasks = ["names", "cities", "culturedistil", "culturebench"]

for prefix in ["tr", "fr", "ru", "bn", "us"]:
    print(prefix)
    steering_vecs = []
    for task in tasks:
        pos_path = f"residuals/gemma2_9b_it_exp_trans/{prefix}_{task}_pos.pt"
        neg_path = f"residuals/gemma2_9b_it_exp_trans/{prefix}_{task}_neg.pt"
        pos_residuals = t.load(pos_path, weights_only=True)
        neg_residuals = t.load(neg_path, weights_only=True)
        steering_vecs.append(pos_residuals - neg_residuals)

    print(len(steering_vecs))
    pooled_diffs = t.cat(steering_vecs, dim=0)
    print(pooled_diffs.shape)
    steering_vec = pooled_diffs.mean(dim=0)
    t.save(steering_vec, f"vectors/gemma2_9b_it/per_culture/{prefix}_trans_avg_all_tasks.pt")


tr
4
torch.Size([232, 42, 3584])
fr
4
torch.Size([165, 42, 3584])
ru
4
torch.Size([267, 42, 3584])
bn
4
torch.Size([117, 42, 3584])
us
4
torch.Size([206, 42, 3584])


In [None]:
# Explicit, per culture (English) steering vectors
#
# One vector per culture: the (pos - neg) residual differences of every task are
# concatenated along dim 0 and then averaged over that dimension.
tasks = ["names", "cities", "culturedistil", "culturebench"]

for prefix in ["tr", "fr", "ru", "bn", "us"]:
    print(prefix)
    steering_vecs = []
    for task in tasks:
        pos_path = f"residuals/gemma2_9b_it_exp_en/{prefix}_{task}_pos.pt"
        neg_path = f"residuals/gemma2_9b_it_exp_en/{prefix}_{task}_neg.pt"
        pos_residuals = t.load(pos_path, weights_only=True)
        neg_residuals = t.load(neg_path, weights_only=True)
        steering_vecs.append(pos_residuals - neg_residuals)

    print(len(steering_vecs))
    pooled_diffs = t.cat(steering_vecs, dim=0)
    print(pooled_diffs.shape)
    steering_vec = pooled_diffs.mean(dim=0)
    t.save(steering_vec, f"vectors/gemma2_9b_it/per_culture/{prefix}_en_avg_all_tasks.pt")

In [None]:
# Leave-one-out Universal, (Translated) steering vectors
#
# For each held-out culture, pool the (pos - neg) residual differences of every
# task of all *other* cultures and average them over the example dimension.
tasks = ["names", "cities", "culturedistil", "culturebench"]

# t.save does not create missing parent directories.
Path("vectors/gemma2_9b_it/universal").mkdir(parents=True, exist_ok=True)

for pref_to_skip in ["tr", "fr", "ru", "bn", "us"]:
    # Reset once per held-out culture. Resetting inside the inner loop (as before)
    # kept only the last remaining culture's differences and left the list empty
    # when "us" was held out, which makes t.cat fail.
    steering_vecs = []
    for prefix in ["tr", "fr", "ru", "bn", "us"]:
        if prefix == pref_to_skip:
            continue
        for task in tasks:
            res_tr = t.load(f"residuals/gemma2_9b_it_exp_trans/{prefix}_{task}_pos.pt", weights_only=True)
            res_en = t.load(f"residuals/gemma2_9b_it_exp_trans/{prefix}_{task}_neg.pt", weights_only=True)

            # Per-example difference between the positive and negative residuals.
            steering_vec = (res_tr - res_en)
            steering_vecs.append(steering_vec)

    # Expected: (number of cultures - 1) * number of tasks tensors.
    print(len(steering_vecs))

    steering_vec = t.cat(steering_vecs, dim=0)
    print(steering_vec.shape)

    steering_vec = steering_vec.mean(dim=0)
    t.save(steering_vec,f"vectors/gemma2_9b_it/universal/trans_universal_{pref_to_skip}_out.pt")

In [None]:
# Leave-one-out Universal, (English) steering vectors
#
# For each held-out culture, pool the (pos - neg) residual differences of every
# task of all *other* cultures and average them over the example dimension.
tasks = ["names", "cities", "culturedistil", "culturebench"]

# t.save does not create missing parent directories.
Path("vectors/gemma2_9b_it/universal").mkdir(parents=True, exist_ok=True)

for pref_to_skip in ["tr", "fr", "ru", "bn", "us"]:
    # Reset once per held-out culture. Resetting inside the inner loop (as before)
    # kept only the last remaining culture's differences and left the list empty
    # when "us" was held out, which makes t.cat fail.
    steering_vecs = []
    for prefix in ["tr", "fr", "ru", "bn", "us"]:
        if prefix == pref_to_skip:
            continue
        for task in tasks:
            res_tr = t.load(f"residuals/gemma2_9b_it_exp_en/{prefix}_{task}_pos.pt", weights_only=True)
            res_en = t.load(f"residuals/gemma2_9b_it_exp_en/{prefix}_{task}_neg.pt", weights_only=True)

            # Per-example difference between the positive and negative residuals.
            steering_vec = (res_tr - res_en)
            steering_vecs.append(steering_vec)

    # Expected: (number of cultures - 1) * number of tasks tensors.
    print(len(steering_vecs))

    steering_vec = t.cat(steering_vecs, dim=0)
    print(steering_vec.shape)

    steering_vec = steering_vec.mean(dim=0)
    t.save(steering_vec,f"vectors/gemma2_9b_it/universal/en_universal_{pref_to_skip}_out.pt")

In [None]:
# Per task-culture, (English) steering vectors
#
# One vector per (culture, task): the (pos - neg) residual differences of that
# single task, averaged over dim 0.
tasks = ["names", "cities", "culturedistil", "culturebench"]

for prefix in ["tr", "fr", "ru", "bn", "us"]:
    print(prefix)
    for task in tasks:
        pos_path = f"residuals/gemma2_9b_it_exp_en/{prefix}_{task}_pos.pt"
        neg_path = f"residuals/gemma2_9b_it_exp_en/{prefix}_{task}_neg.pt"
        pos_residuals = t.load(pos_path, weights_only=True)
        neg_residuals = t.load(neg_path, weights_only=True)

        diffs = pos_residuals - neg_residuals
        steering_vec = diffs.mean(dim=0)

        print(steering_vec.shape)
        t.save(steering_vec, f"vectors/gemma2_9b_it/per_task/{prefix}_{task}_en.pt")

tr
torch.Size([42, 3584])
torch.Size([42, 3584])
torch.Size([42, 3584])
torch.Size([42, 3584])
fr
torch.Size([42, 3584])
torch.Size([42, 3584])
torch.Size([42, 3584])
torch.Size([42, 3584])
ru
torch.Size([42, 3584])
torch.Size([42, 3584])
torch.Size([42, 3584])
torch.Size([42, 3584])
bn
torch.Size([42, 3584])
torch.Size([42, 3584])
torch.Size([42, 3584])
torch.Size([42, 3584])
us
torch.Size([42, 3584])
torch.Size([42, 3584])
torch.Size([42, 3584])
torch.Size([42, 3584])
