In [1]:
%%capture
!pip install -r requirements.txt

In [1]:
import sys
from pathlib import Path
import json
import pandas as pd
from dotenv import load_dotenv
import plotly.express as px
import torch as t
import pandas as pd
from tools.globals import load_country_globals

from translate import Translator

from nnsight import LanguageModel
from transformers import AutoTokenizer

load_country_globals()
translator = Translator(from_lang="autodetect",to_lang="en")

device = t.device(
    "mps" if t.backends.mps.is_available() else "cuda" if t.cuda.is_available() else "cpu"
)
load_dotenv()
t.set_grad_enabled(False)

t.manual_seed(42)
if t.cuda.is_available():
    t.cuda.manual_seed_all(42)

%load_ext autoreload
%autoreload 2

In [2]:
t.cuda.empty_cache()

In [2]:
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
nnmodel = LanguageModel('/dlabscratch1/public/llm_weights/gemma_hf/gemma-2-9b', 
                        device_map='cuda:0', 
                        dispatch=True, 
                        torch_dtype=t.bfloat16)

Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

In [3]:
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B")
tokenizer.pad_token = tokenizer.eos_token
nnmodel = LanguageModel('/dlabscratch1/public/llm_weights/llama3.1_hf/Meta-Llama-3.1-8B', 
                        device_map='cuda:0', 
                        dispatch=True, 
                        torch_dtype=t.bfloat16)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [3]:
from tools.nnsight_utils import collect_residuals, visualize_top_tokens
from tools.patchscope import patch_scope_gen

with open('data/associations.jsonl') as f:
    associations = json.load(f)

In [4]:
prompt_template_en = "grass -> green\nsnow -> white\nblood -> red\nocean -> blue\norange -> orange\n"
prompt_template_jp = "くさ -> みどり\nゆき -> しろ\nち -> あか\nうみ -> あお\n"
prompt_template_cn = "草 -> 绿\n雪 -> 白\n血 -> 红\n海 -> 蓝\n"

In [5]:
association = associations[0]
association


{'property': 'color',
 'concept': 'pumpkin',
 'values': {'English': {'text': 'pumpkin',
   'value': 'orange',
   'en_value': 'orange'},
  'Japanese': {'text': 'カボチャ', 'value': '緑', 'en_value': 'green'}}}

In [6]:
prompt_template = prompt_template_jp

source_prompt = association["values"]["English"]["text"]
print(source_prompt)

target_prompt = prompt_template + association["values"]["Japanese"]["text"] + " ->"
print(target_prompt)

pumpkin
くさ -> みどり
ゆき -> しろ
ち -> あか
うみ -> あお
カボチャ ->


In [56]:
prompt_template = prompt_template_cn

source_prompt = association["values"]["English"]["text"]
print(source_prompt)

target_prompt = prompt_template + association["values"]["Chinese"]["text"] + " ->"
print(target_prompt)

bride
草 -> 绿
雪 -> 白
血 -> 红
海 -> 蓝
新娘 ->


In [114]:
prompt_template = prompt_template_en
source_prompt = association["values"]["Chinese"]["text"]
print(source_prompt)

target_prompt = prompt_template + association["values"]["English"]["text"] + " ->"
print(target_prompt)

股价上涨
grass -> green
snow -> white
blood -> red
ocean -> blue
stock price increase ->


In [115]:
prompt_template = prompt_template_cn
source_prompt = prompt_template_jp + association["values"]["Japanese"]["text"] + " ->"
print(source_prompt)

target_prompt = prompt_template + association["values"]["English"]["text"] + " ->"
print(target_prompt)

KeyError: 'Japanese'

In [116]:
prompt_template = prompt_template_en
source_prompt = prompt_template_jp + association["values"]["Chinese"]["text"] + " ->"
print(source_prompt)

target_prompt = prompt_template + association["values"]["English"]["text"] + " ->"
print(target_prompt)

草 -> 緑
雪 -> 白
血 -> 赤
海 -> 青
股价上涨 ->
grass -> green
snow -> white
blood -> red
ocean -> blue
stock price increase ->


In [11]:
tokenizer.tokenize(source_prompt)

['pumpkin']

In [7]:
tokenizer.tokenize(target_prompt)

['く',
 'さ',
 '▁->',
 '▁み',
 'どり',
 '\n',
 'ゆき',
 '▁->',
 '▁し',
 'ろ',
 '\n',
 'ち',
 '▁->',
 '▁あ',
 'か',
 '\n',
 'う',
 'み',
 '▁->',
 '▁あ',
 'お',
 '\n',
 'カ',
 'ボ',
 'チャ',
 '▁->']

In [9]:
model_out = collect_residuals(nnmodel, target_prompt, calculate_probs=True)
visualize_top_tokens(nnmodel, model_out["probs"], token_index=-1, k=10)

  styled_df = df.style.applymap(color_background)
  styled_df = styled_df.applymap(lambda x: 'color: black')


Unnamed: 0,Top 1,Top 2,Top 3,Top 4,Top 5,Top 6,Top 7,Top 8,Top 9,Top 10
Layer 0,-> (1.0000),–> (0.0000),→ (0.0000),--> (0.0000),-> (0.0000),---> (0.0000),=> (0.0000),(0.0000),(0.0000),(0.0000)
Layer 1,-> (1.0000),–> (0.0000),---> (0.0000),--> (0.0000),→ (0.0000),-> (0.0000),->$ (0.0000),)-> (0.0000),==> (0.0000),->{ (0.0000)
Layer 2,-> (1.0000),–> (0.0000),---> (0.0000),--> (0.0000),→ (0.0000),-> (0.0000),==> (0.0000),–> (0.0000),)-> (0.0000),rightarrow (0.0000)
Layer 3,-> (1.0000),–> (0.0000),-> (0.0000),→ (0.0000),--> (0.0000),---> (0.0000),)-> (0.0000),=> (0.0000),rightarrow (0.0000),]-> (0.0000)
Layer 4,-> (1.0000),-> (0.0000),--> (0.0000),---> (0.0000),–> (0.0000),→ (0.0000),)-> (0.0000),=> (0.0000),rightarrow (0.0000),--> (0.0000)
Layer 5,-> (1.0000),--> (0.0000),=> (0.0000),-> (0.0000),rightarrow (0.0000),---> (0.0000),→ (0.0000),–> (0.0000),--> (0.0000),→ (0.0000)
Layer 6,-> (1.0000),=> (0.0000),-> (0.0000),--> (0.0000),rightarrow (0.0000),→ (0.0000),)-> (0.0000),---> (0.0000),]-> (0.0000),--> (0.0000)
Layer 7,-> (1.0000),-> (0.0000),=> (0.0000),--> (0.0000),rightarrow (0.0000),→ (0.0000),> (0.0000),→ (0.0000),--> (0.0000),(0.0000)
Layer 8,-> (1.0000),=> (0.0000),-> (0.0000),--> (0.0000),rightarrow (0.0000),> (0.0000),→ (0.0000),--> (0.0000),into (0.0000),=> (0.0000)
Layer 9,-> (1.0000),=> (0.0000),-> (0.0000),rightarrow (0.0000),--> (0.0000),<=> (0.0000),= (0.0000),into (0.0000),> (0.0000),== (0.0000)


In [25]:
prefix = "tr"

res_tr = t.load(f"residuals/gemma2_9b_it/{prefix}_en_tr_hint_bal.pt")
res_en = t.load(f"residuals/gemma2_9b_it/{prefix}_en_no_hint_bal.pt")

steering_vec = (res_tr - res_en).mean(dim=0)
steering_vec = steering_vec.unsqueeze(1)

  res_tr = t.load(f"residuals/gemma2_9b_it/{prefix}_en_tr_hint_bal.pt")
  res_en = t.load(f"residuals/gemma2_9b_it/{prefix}_en_no_hint_bal.pt")


In [17]:
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b-it")
nnmodel = LanguageModel('/dlabscratch1/public/llm_weights/gemma_hf/gemma-2-9b-it', 
                        device_map='cuda:0', 
                        dispatch=True, 
                        torch_dtype=t.bfloat16)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [26]:
# Iterate through each layer of steering_vec
probs = []
alpha = 5
for i, layer in enumerate(steering_vec):
    # Forward the layer through the lm_head
    logits = nnmodel.lm_head(nnmodel.model.norm(layer))
    # Convert logits to probabilities using softmax
    prob = t.nn.functional.softmax(logits, dim=-1)
    probs.append(prob)
probs = t.stack(probs, dim=0)
    

visualize_top_tokens(nnmodel, probs, token_index=-1, k=10)

Unnamed: 0,Top 1,Top 2,Top 3,Top 4,Top 5,Top 6,Top 7,Top 8,Top 9,Top 10
Layer 0,autorytatywna (0.9961),ویکی‌پدی (0.0041),OGND (0.0009),تقاوى (0.0001),Personensuche (0.0000),Roskov (0.0000),parsedMessage (0.0000),uxxxx (0.0000),RegressionTest (0.0000),nonUne (0.0000)
Layer 1,parsedMessage (0.9961),tvguidetime (0.0032),StoryboardSegue (0.0005),للاسماء (0.0001),queſta (0.0000),transQ (0.0000),ſta (0.0000),Infórmanos (0.0000),featureID (0.0000),myſelf (0.0000)
Layer 2,parsedMessage (0.9062),queſta (0.0452),للاسماء (0.0452),StoryboardSegue (0.0014),tvguidetime (0.0014),protoimpl (0.0006),featureID (0.0005),PMailer (0.0002),ویکی‌پدی (0.0001),تقاوى (0.0000)
Layer 3,tvguidetime (0.4238),StoryboardSegue (0.2578),ویکی‌پدی (0.2002),featureID (0.0737),parsedMessage (0.0271),queſta (0.0128),MigrationBuilder (0.0047),protoimpl (0.0011),للاسماء (0.0001),FunctionFlags (0.0000)
Layer 4,Roskov (0.2520),queſta (0.2520),੝ (0.1196),SharedDtor (0.0723),CppMethod (0.0723),MigrationBuilder (0.0266),StoryboardSegue (0.0266),tvguidetime (0.0126),imagui (0.0126),__*/ (0.0098)
Layer 5,Roskov (0.8828),MessageOf (0.0723),Personensuche (0.0439),MigrationBuilder (0.0006),'\\;' (0.0002),StoryboardSegue (0.0001),дописавши (0.0001),Personendaten (0.0001),__': (0.0001),CppMethod (0.0000)
Layer 6,MessageOf (0.6016),Personensuche (0.2207),CreateTagHelper (0.1719),Roskov (0.0025),CppMethod (0.0012),Personendaten (0.0012),Numerade (0.0003),дописавши (0.0003),'\\;' (0.0001),الحياه (0.0001)
Layer 7,CreateTagHelper (0.9844),MessageOf (0.0085),principalColumn (0.0052),MainAxisSize (0.0002),ProtoMessage (0.0001),يتيمه (0.0000),Personensuche (0.0000),ValueStyle (0.0000),tvguidetime (0.0000),дописавши (0.0000)
Layer 8,principalColumn (0.9961),setVerticalGroup (0.0017),XmlAccessType (0.0002),complexContent (0.0001),iprot (0.0001),CreateTagHelper (0.0001),yarnpkg (0.0000), (0.0000),httphttps (0.0000),+#+# (0.0000)
Layer 9,principalColumn (0.5039),transférez (0.3066), (0.0879),Bronnen (0.0471),XmlAccessType (0.0173),CppMethod (0.0153),AssemblyProduct (0.0044),Aiheesta (0.0044),IsContent (0.0030),setVerticalGroup (0.0030)


In [120]:
model_out = collect_residuals(nnmodel, source_prompt, calculate_probs=False)
res = patch_scope_gen(nnmodel, tokenizer, model_out, target_prompt=target_prompt, target_token_idx=-2, n_new_tokens=3)

In [122]:
model_out["residuals"].shape

torch.Size([42, 1, 21, 3584])

In [121]:
res[0]

{0: ['\nstock price'],
 1: ['\nstock price'],
 2: [' green\nstock'],
 3: [' green\nstock'],
 4: [' green\nstock'],
 5: [' green\nstock'],
 6: [' green\nstock'],
 7: [' green\nstock'],
 8: [' green\nstock'],
 9: [' green\nstock'],
 10: [' green\nstock'],
 11: [' green\nstock'],
 12: [' green\nstock'],
 13: [' green\nstock'],
 14: [' green\nstock'],
 15: [' green\nstock'],
 16: [' green\nstock'],
 17: [' green\nstock'],
 18: [' green\nstock'],
 19: [' green\nstock'],
 20: [' green\nstock'],
 21: [' green\nstock'],
 22: [' green\nstock'],
 23: [' green\nstock'],
 24: [' green\nstock'],
 25: [' green\nstock'],
 26: [' green\nstock'],
 27: [' green\nstock'],
 28: [' green\nstock'],
 29: [' green\nstock'],
 30: [' green\nstock'],
 31: [' green\nstock'],
 32: [' green\nstock'],
 33: [' green\nstock'],
 34: [' green\nstock'],
 35: [' green\nstock'],
 36: [' green\nstock'],
 37: [' green\nstock'],
 38: [' green\nstock'],
 39: [' green\nstock'],
 40: [' green\nstock'],
 41: [' green\nstock']}

In [112]:
visualize_top_tokens(nnmodel, res[1], token_index=0, k=10)

Unnamed: 0,Top 1,Top 2,Top 3,Top 4,Top 5,Top 6,Top 7,Top 8,Top 9,Top 10
Layer 0,(0.9922),-> (0.0052),blue (0.0003),(0.0003),arrow (0.0002),(0.0001),to (0.0001),green (0.0001),> (0.0001),black (0.0000)
Layer 1,(0.8984),arrow (0.0447),blue (0.0211),to (0.0099),green (0.0099),black (0.0047),(0.0017),(0.0014),-> (0.0014),red (0.0014)
Layer 2,(0.7773),arrow (0.1348),blue (0.0302),green (0.0183),to (0.0142),black (0.0111),red (0.0025),(0.0019),go (0.0015),yellow (0.0012)
Layer 3,black (0.6055),blue (0.1729),green (0.0820),yellow (0.0496),purple (0.0386),(0.0302),arrow (0.0110),red (0.0052),orange (0.0019),white (0.0009)
Layer 4,blue (0.3105),black (0.2412),green (0.1885),(0.1143),arrow (0.0540),yellow (0.0540),purple (0.0198),red (0.0093),orange (0.0027),-> (0.0013)
Layer 5,black (0.5312),yellow (0.1523),blue (0.1187),purple (0.0923),green (0.0718),arrow (0.0125),brown (0.0059),(0.0059),red (0.0036),orange (0.0028)
Layer 6,purple (0.8516),black (0.0898),blue (0.0425),yellow (0.0095),brown (0.0021),green (0.0008),red (0.0006),pink (0.0004),orange (0.0003),(0.0001)
Layer 7,purple (0.8516),black (0.0898),blue (0.0425),yellow (0.0095),green (0.0010),brown (0.0010),red (0.0008),pink (0.0005),orange (0.0003),grey (0.0002)
Layer 8,purple (0.5781),blue (0.3496),yellow (0.0369),brown (0.0136),black (0.0136),green (0.0050),red (0.0014),orange (0.0005),white (0.0004),pink (0.0003)
Layer 9,purple (0.5312),blue (0.3223),yellow (0.0718),brown (0.0437),black (0.0206),green (0.0059),red (0.0013),orange (0.0008),white (0.0004),pink (0.0003)


In [98]:
from tools.nnsight_utils import get_text_generations

gen = get_text_generations(nnmodel, tokenizer, [target_prompt], device=device, max_new_tokens=3)
print(gen)

print(translator.translate(gen[0]))

[' 黄色\n木']
Yellow
Thu


In [35]:
res

({0: [' 黑\n草'],
  1: [' 黑\n黑'],
  2: [' 黑\n黑'],
  3: [' 黑\n黑'],
  4: [' 黑\n黑'],
  5: [' 黑\n黑'],
  6: [' 黑\n黑'],
  7: [' 黑\n黑'],
  8: [' 黑\n黑'],
  9: [' 黑\n黑'],
  10: [' 黑\n黑'],
  11: [' 黑\n黑'],
  12: [' 黑\n黑'],
  13: [' 黑\n黑'],
  14: [' 黑\n黑'],
  15: [' 黑\n黑'],
  16: [' 黑\n黑'],
  17: [' 黑\n黑'],
  18: [' 黑\n黑'],
  19: [' 黑\n黑'],
  20: [' 黑\n黑'],
  21: [' 黑\n黑'],
  22: [' 黑\n黑'],
  23: [' 黑\n黑'],
  24: [' 黑\n黑'],
  25: [' 黑\n黑'],
  26: [' 黑\n黑'],
  27: [' 黑\n黑'],
  28: [' 黑\n黑'],
  29: [' 黑\n黑'],
  30: [' 黑\n黑'],
  31: [' 黑\n黑']},
 tensor([[[5.1036e-07, 9.1791e-06, 7.8082e-06,  ..., 1.5734e-10,
           1.5734e-10, 1.5734e-10],
          [1.6242e-06, 5.3346e-06, 1.8597e-05,  ..., 1.7189e-10,
           1.7189e-10, 1.7189e-10],
          [4.5538e-05, 1.1063e-03, 1.9379e-03,  ..., 3.0122e-09,
           3.0122e-09, 3.0122e-09]],
 
         [[2.1607e-07, 7.2718e-06, 5.0068e-06,  ..., 1.6098e-10,
           1.6098e-10, 1.6098e-10],
          [1.0356e-06, 3.6210e-06, 1.6212e-05,  ..., 1.864

In [20]:
visualize_top_tokens(nnmodel, res[1], token_index=-3, k=5)

  styled_df = df.style.applymap(color_background)
  styled_df = styled_df.applymap(lambda x: 'color: black')


Unnamed: 0,Top 1,Top 2,Top 3,Top 4,Top 5
Layer 0,纱 (0.3867),结婚 (0.2344),(0.2061),(0.0359),洁 (0.0247)
Layer 1,(0.5078),结婚 (0.2109),(0.0684),1 (0.0603),2 (0.0286)
Layer 2,(0.4336),结婚 (0.2637),1 (0.0664),(0.0664),2 (0.0356)
Layer 3,(0.5430),结婚 (0.1553),(0.0835),1 (0.0649),2 (0.0347)
Layer 4,(0.6484),(0.0879),结婚 (0.0776),1 (0.0532),(0.0251)
Layer 5,(0.6836),(0.0928),结婚 (0.0562),1 (0.0437),(0.0266)
Layer 6,(0.7227),(0.0864),1 (0.0408),结婚 (0.0359),(0.0248)
Layer 7,(0.6914),(0.1060),结婚 (0.0500),1 (0.0388),(0.0236)
Layer 8,(0.6758),(0.1035),1 (0.0488),结婚 (0.0488),2 (0.0261)
Layer 9,(0.6914),(0.1060),1 (0.0442),结婚 (0.0391),(0.0269)


In [33]:
translator.translate("结婚")

'Marriage'

In [27]:
from tools.apis import OpenAIWrapper
import os

openai = OpenAIWrapper(os.getenv("OPENAI_API_KEY"))

target_prompt =  "In Chinese\n" + target_prompt
res = openai.text_gen(target_prompt, model_name="gpt-4o")



In [21]:
target_prompt

'In Japan\n日本語の文脈では\n草 -> 緑\n雪 -> 白\n血 -> 赤\n海 -> 青\nカボチャ ->'

In [23]:
res

'オレンジ\n\n日本語の文脈における色の連想では、カボチャは通常オレンジ色と関連付けられます。'

In [12]:
res

'オレンジ'

In [28]:
translator.translate(res)

'Wedding dress - > White'

In [15]:
translator.translate("In Japanese context,", to_lang="jp")

TypeError: Translator.translate() got an unexpected keyword argument 'to_lang'