In [1]:
import os
# Switch the working directory to the parent folder, presumably so that the
# relative paths used later (e.g. ./prompts/next_token/...) resolve — TODO confirm.
# NOTE(review): this is not idempotent; re-running the cell moves up one more level.
os.chdir('..')

In [2]:
import pandas as pd
import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from dotenv import load_dotenv
from datasets import load_dataset, concatenate_datasets
from tqdm import tqdm
# Load environment variables from a .env file; HF_CACHE_DIR and HF_TOKEN are read
# later through os.getenv.
load_dotenv()

True

In [3]:
class BaseActivationSaver:
	"""
	Base class for saving activations from different models.

	Every tensor is written to
	``<base_save_dir>/<task_id>/<data_split>/<model>/<prompt_id>/<lang>/<id>/<kind>/layer_<layer_id>.pt``
	where ``<model>`` is the part of ``model_name`` after the last ``/`` and
	``<kind>`` is ``last_token`` or ``average``. ``<lang>`` and ``<id>`` come from
	``set_lang`` / ``set_id`` and must be set before anything is saved.

	Subclasses implement ``hook_fn`` / ``pre_hook_fn``.
	"""
	def __init__(self, base_save_dir: str, task_id: str, data_split: str, model_name: str, prompt_id: str):
		self.task_id = task_id
		self.data_split = data_split
		self.model_name = model_name
		self.prompt_id = prompt_id
		self.base_save_dir = base_save_dir
		# Identify the instance currently being processed; both start unset
		self.current_id = None
		self.current_lang = None

	def set_id(self, new_id):
		"""Set the instance id used as a directory name for subsequent saves."""
		self.current_id = new_id

	def set_lang(self, new_lang):
		"""Set the language code used as a directory name for subsequent saves."""
		self.current_lang = new_lang

	def hook_fn(self, module, input, output, layer_id):
		"""Forward-hook callback; must be provided by subclasses."""
		raise NotImplementedError("This method should be overridden by subclasses.")

	def pre_hook_fn(self, module, input, layer_id):
		"""Forward-pre-hook callback; must be provided by subclasses."""
		raise NotImplementedError("This method should be overridden by subclasses.")

	def _instance_dir(self, kind: str) -> str:
		"""
		Return the directory holding the ``kind`` ("last_token" or "average")
		files of the current instance.

		Raises:
			ValueError: if the id or the language has not been set yet.
		"""
		if self.current_id is None or self.current_lang is None:
			raise ValueError("current_id and current_lang must be set before building a save path")
		return os.path.join(
			self.base_save_dir,
			self.task_id,
			self.data_split,
			self.model_name.split('/')[-1],
			self.prompt_id,
			str(self.current_lang),
			str(self.current_id),  # ids may be passed as int; path components must be str
			kind,
		)

	# Check if activations for an instance already exist
	def check_exists(self):
		"""
		Return True when both the ``last_token`` and the ``average`` directories
		of the current instance exist and contain at least one entry.

		Returns False when the id or language is still unset. Only the presence
		of files is checked, not that every layer was written.
		"""
		if self.current_id is None or self.current_lang is None:
			return False
		return all(
			os.path.isdir(directory) and bool(os.listdir(directory))
			for directory in (self._instance_dir("last_token"), self._instance_dir("average"))
		)

	def _save_tensor(self, tensor, kind, layer_id):
		"""Write ``tensor`` to ``<instance dir>/<kind>/layer_<layer_id>.pt``, creating directories as needed."""
		target_dir = self._instance_dir(kind)
		os.makedirs(target_dir, exist_ok=True)
		torch.save(tensor, os.path.join(target_dir, f"layer_{layer_id}.pt"))

	def _save_activation_last_token(self, tensor, layer_id):
		"""
		Save ``tensor[0, -1, :]`` (first entry of dim 0, last entry of dim 1).

		The slice is a view of ``tensor``; torch.save serialises the whole
		backing storage of a view, and ``.cpu()`` does not copy a tensor that is
		already on the CPU, so the slice is cloned to keep the file small.
		"""
		self._save_tensor(tensor[0, -1, :].detach().cpu().clone(), "last_token", layer_id)

	def _save_activation_average(self, tensor, layer_id):
		"""Save the mean over dim 0 of ``tensor[0]``."""
		self._save_tensor(tensor[0].mean(dim=0).detach().cpu(), "average", layer_id)

	def _check_set_id_lang(self, layer_id):
		"""Return True when both id and language are set; otherwise print a warning and return False."""
		if self.current_id is None:
			print(f"Warning: ID not set for layer {layer_id}")
			return False

		if self.current_lang is None:
			print(f"Warning: Language not set for layer {layer_id}")
			return False

		return True
	
class GeneralActivationSaver(BaseActivationSaver): # Handle Gemma3, Qwen, Pythia models (models that activation are returned in form of a tuple)
	"""
	Saver whose hooks store the hooked module's output (forward hook) or input
	(forward pre-hook). A tuple is reduced to its first element before saving;
	anything else is saved as received. Failures while saving are printed and
	do not propagate.
	"""

	def hook_fn(self, module, input, output, layer_id):
		"""Forward hook: save the last-token and average views of ``output`` under ``layer_id``."""
		ready = self._check_set_id_lang(layer_id)
		if ready is False:
			return
		try:
			# Unpack tensor from the tuple when the module returns one
			hidden = output[0] if isinstance(output, tuple) else output
			for persist in (self._save_activation_last_token, self._save_activation_average):
				persist(tensor=hidden, layer_id=layer_id)
		except Exception as e:
			print(f"Error in hook_fn for layer {layer_id}: {e}")

	def pre_hook_fn(self, module, input, layer_id):
		"""Forward pre-hook: save the last-token and average views of the module's ``input`` under ``layer_id``."""
		ready = self._check_set_id_lang(layer_id)
		if ready is False:
			return

		try:
			# The value entering the hooked module is what gets stored
			hidden = input[0] if isinstance(input, tuple) else input
			for persist in (self._save_activation_last_token, self._save_activation_average):
				persist(tensor=hidden, layer_id=layer_id)
		except Exception as e:
			print(f"Error in pre_hook_fn for layer {layer_id}: {e}")

class CohereDecoderActivationSaver(BaseActivationSaver):
	"""
	Saver for Cohere decoder models.

	Two hooks stash intermediate tensors (the input seen by one module and the
	output of another); ``hook_fn_final_output`` combines them with a third
	output, saves the results and clears the stash.
	"""
	def __init__(self, base_save_dir: str, task_id: str, data_split: str, model_name: str, prompt_id: str):
		super().__init__(base_save_dir, task_id, data_split, model_name, prompt_id)
		# Scratch slots filled by the per-layer hooks and consumed by hook_fn_final_output
		self.initial_residual = None
		self.attn_output = None

	@staticmethod
	def _unwrap(value):
		"""Return the first element of a tuple, or the value itself otherwise."""
		return value[0] if isinstance(value, tuple) else value

	def hook_fn_embed_tokens(self, module, input, output, layer_id):
		"""Forward hook: save the module output under ``layer_id``; save errors are printed."""
		if self._check_set_id_lang(layer_id) is False:
			return

		try:
			embedded = self._unwrap(output)
			self._save_activation_last_token(tensor=embedded, layer_id=layer_id)
			self._save_activation_average(tensor=embedded, layer_id=layer_id)
		except Exception as e:
			print(f"Error in hook_fn_embed_tokens for layer {layer_id}: {e}")

	def hook_fn_set_initial_residual(self, module, input, output, layer_id):
		"""Forward hook: remember the hooked module's input as ``initial_residual``."""
		if self._check_set_id_lang(layer_id) is False:
			return

		self.initial_residual = self._unwrap(input)

	def hook_fn_set_attn_output(self, module, input, output, layer_id):
		"""Forward hook: remember the hooked module's output as ``attn_output``."""
		if self._check_set_id_lang(layer_id) is False:
			return
		self.attn_output = self._unwrap(output)

	def hook_fn_final_output(self, module, input, output, layer_id):
		"""
		Forward hook: save ``attn_output`` under the 'residual-postattn' variant of
		``layer_id`` and ``initial_residual + attn_output + output`` under
		``layer_id``, then reset the stored tensors.

		Raises:
			ValueError: if either stored tensor is missing.
		"""
		if self._check_set_id_lang(layer_id) is False:
			return

		# Both stashed tensors are required to compute the residual post MLP
		stash_incomplete = self.initial_residual is None or self.attn_output is None
		if stash_incomplete:
			print(f"Warning: Missing stored tensors for layer {layer_id}")
			raise ValueError("Stored tensors are None")

		hooked_output = self._unwrap(output)
		# NOTE(review): if the hooked module's output already contains the residual
		# and attention terms, this sum counts them twice — confirm against the
		# module this hook is attached to.
		residual_post_mlp = self.initial_residual + self.attn_output + hooked_output

		attn_layer_id = layer_id.replace('residual-postmlp', 'residual-postattn')
		# Each target is saved independently so one failure does not block the other
		for activation, tag in ((self.attn_output, attn_layer_id), (residual_post_mlp, layer_id)):
			try:
				self._save_activation_last_token(tensor=activation, layer_id=tag)
				self._save_activation_average(tensor=activation, layer_id=tag)
			except Exception as e:
				print(f"Error in hook_fn_final_output for layer {tag}: {e}")

		# Reset stored tensors
		self.initial_residual = None
		self.attn_output = None

In [4]:
class BaseHookedModel:
	"""
	Base class for hooking into different models.

	Loads a causal LM in eval mode and keeps a reference to the activation saver
	that the hooks registered by subclasses report to.
	"""
	def __init__(self, model_name: str, saver: "BaseActivationSaver"):
		"""
		Args:
			model_name: identifier passed to AutoModelForCausalLM.from_pretrained.
			saver: activation saver used by the hooks.
		"""
		# Default to CPU / float16; switch to CUDA when available
		device = "cpu"
		model_dtype = torch.float16
		if torch.cuda.is_available():
			device = "cuda"
			compute_capability = torch.cuda.get_device_capability()[0]

			# Use bfloat16 if supported
			if compute_capability >= 8:
				model_dtype = torch.bfloat16

		self.model_name = model_name
		self.saver = saver
		self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=model_dtype, device_map=device, cache_dir=os.getenv("HF_CACHE_DIR"))
		self.model.eval()

	def _setup_hooks(self):
		"""Register the model-specific hooks; must be provided by subclasses."""
		raise NotImplementedError("This method should be overridden by subclasses.")

	def set_saver_id(self, new_id: str):
		"""Forward the current instance id to the saver (used as a directory name)."""
		self.saver.set_id(new_id)

	def set_saver_lang(self, new_lang: str):
		"""Forward the current language code to the saver."""
		self.saver.set_lang(new_lang)

	def generate(self, inputs):
		"""Run ``model.generate`` for a single new token without gradients and return its output."""
		with torch.no_grad():
			outputs = self.model.generate(
				**inputs,
				max_new_tokens=1,
			)
		return outputs

	# Clear hooks for debugging purposes
	def clear_hooks(self):
		"""
		Remove every forward hook and forward pre-hook from the model.

		All submodules are visited, so hooks attached to inner modules (layer
		norms, attention blocks, embeddings) are removed as well, and no
		particular attribute layout of the model is assumed.
		"""
		for module in self.model.modules():
			module._forward_hooks.clear()
			module._forward_pre_hooks.clear()
	
class Gemma3MultimodalHookedModel(BaseHookedModel): # For gemma-3 >=4b
	"""Hooked model that reads its layers from ``model.model.language_model``."""
	def __init__(self, model_name: str, saver: BaseActivationSaver):
		super().__init__(model_name, saver)
		self._setup_hooks()

	def _setup_hooks(self):
		"""Attach saver callbacks to the embedding module and to every decoder layer."""
		def forward_hook_for(tag):
			# Bind the layer tag now; the saver is looked up when the hook fires
			def _hook(module, input, output):
				return self.saver.hook_fn(module, input, output, tag)
			return _hook

		def pre_hook_for(tag):
			def _hook(module, input):
				return self.saver.pre_hook_fn(module, input, tag)
			return _hook

		text_model = self.model.model.language_model

		# Embedding output
		text_model.embed_tokens.register_forward_hook(forward_hook_for("embed_tokens"))

		for idx, block in enumerate(text_model.layers):
			# Output of the whole decoder layer
			block.register_forward_hook(forward_hook_for(f'residual-postmlp_{idx}'))

			# Input of the pre-feedforward layer norm (stored as residual post attention)
			block.pre_feedforward_layernorm.register_forward_pre_hook(pre_hook_for(f"residual-postattn_{idx}"))

class PythiaHookedModel(BaseHookedModel): # For pythia models
	"""Hooked model that reads its layers from ``model.gpt_neox``."""
	def __init__(self, model_name: str, saver: BaseActivationSaver):
		super().__init__(model_name, saver)
		self._setup_hooks()

	def _setup_hooks(self):
		"""Register the embedding hook plus two hooks per decoder layer."""
		def on_output(label):
			# Forward hook that reports the module output under `label`
			def _callback(module, input, output):
				return self.saver.hook_fn(module, input, output, label)
			return _callback

		def on_input(label):
			# Forward pre-hook that reports the module input under `label`
			def _callback(module, input):
				return self.saver.pre_hook_fn(module, input, label)
			return _callback

		backbone = self.model.gpt_neox

		# Embedding layer
		backbone.embed_in.register_forward_hook(on_output("embed_tokens"))

		# Decoder layers
		layer_no = 0
		for decoder_layer in backbone.layers:
			# Final output of decoder layer
			decoder_layer.register_forward_hook(on_output(f'residual-postmlp_{layer_no}'))

			# Input of the post-attention layer norm (residual post attention)
			decoder_layer.post_attention_layernorm.register_forward_pre_hook(on_input(f"residual-postattn_{layer_no}"))
			layer_no += 1

class CohereDecoderHookedModel(BaseHookedModel): # For cohere decoder models
	"""Hooked model wired to the three-stage callbacks of CohereDecoderActivationSaver."""
	def __init__(self, model_name: str, saver: CohereDecoderActivationSaver):
		super().__init__(model_name, saver)
		self._setup_hooks()

	def _setup_hooks(self):
		"""Register the embedding hook and, per decoder layer, the stash hooks and the final-output hook."""
		def bind(handler_name, tag):
			# Resolve the saver method when the hook fires, passing the fixed layer tag
			def _hook(module, input, output):
				return getattr(self.saver, handler_name)(module, input, output, tag)
			return _hook

		decoder = self.model.model

		# Embedding layer
		decoder.embed_tokens.register_forward_hook(bind("hook_fn_embed_tokens", "embed_tokens"))

		# Decoder layers
		for idx, block in enumerate(decoder.layers):
			# Input layer norm: its input is stashed as the initial residual
			block.input_layernorm.register_forward_hook(bind("hook_fn_set_initial_residual", f"residual-preattn_{idx}"))

			# Self-attention: its output is stashed as the attention output
			block.self_attn.register_forward_hook(bind("hook_fn_set_attn_output", f"residual-postattn_{idx}"))

			# Whole decoder layer: triggers the combined save
			block.register_forward_hook(bind("hook_fn_final_output", f'residual-postmlp_{idx}'))

class Qwen3HookedModel(BaseHookedModel): # For Qwen models
	"""Hooked model that reads ``embed_tokens`` and ``layers`` from ``model.model``."""
	def __init__(self, model_name: str, saver: BaseActivationSaver):
		super().__init__(model_name, saver)
		self._setup_hooks()

	def _setup_hooks(self):
		"""Attach the saver to the embedding output and to two points of every decoder layer."""
		def output_hook(name):
			def _fire(module, input, output):
				return self.saver.hook_fn(module, input, output, name)
			return _fire

		def input_hook(name):
			def _fire(module, input):
				return self.saver.pre_hook_fn(module, input, name)
			return _fire

		core = self.model.model

		# Embedding layer
		core.embed_tokens.register_forward_hook(output_hook("embed_tokens"))

		# Decoder layers
		for position, decoder_layer in enumerate(core.layers):
			post_attn_name = f"residual-postattn_{position}"
			post_mlp_name = f'residual-postmlp_{position}'

			# Value entering the post-attention layer norm (residual post attention)
			decoder_layer.post_attention_layernorm.register_forward_pre_hook(input_hook(post_attn_name))

			# Final output of the decoder layer
			decoder_layer.register_forward_hook(output_hook(post_mlp_name))

class LlamaHookedModel(BaseHookedModel): # For Llama 3 models
	"""Hooked model that reads ``embed_tokens`` and ``layers`` from ``model.model``."""
	def __init__(self, model_name: str, saver: BaseActivationSaver):
		super().__init__(model_name, saver)
		self._setup_hooks()

	def _make_forward_hook(self, layer_id):
		"""Build a forward hook that forwards to ``saver.hook_fn`` with a fixed ``layer_id``."""
		def _forward_hook(module, input, output):
			return self.saver.hook_fn(module, input, output, layer_id)
		return _forward_hook

	def _make_pre_hook(self, layer_id):
		"""Build a forward pre-hook that forwards to ``saver.pre_hook_fn`` with a fixed ``layer_id``."""
		def _pre_hook(module, input):
			return self.saver.pre_hook_fn(module, input, layer_id)
		return _pre_hook

	def _setup_hooks(self):
		"""Register the embedding hook and, for each decoder layer, a pre-hook and an output hook."""
		transformer = self.model.model

		# Embedding layer
		transformer.embed_tokens.register_forward_hook(self._make_forward_hook("embed_tokens"))

		# Decoder layers
		for n, layer in enumerate(transformer.layers):
			# Input of the post-attention layer norm (residual post attention)
			layer.post_attention_layernorm.register_forward_pre_hook(self._make_pre_hook(f"residual-postattn_{n}"))

			# Final output of decoder layer
			layer.register_forward_hook(self._make_forward_hook(f'residual-postmlp_{n}'))

In [5]:
# login huggingface
# Authenticate with the Hugging Face Hub using the HF_TOKEN environment variable
# (os.getenv returns None when the variable is not set).
from huggingface_hub import login
login(token=os.getenv("HF_TOKEN"))

Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


In [6]:
# Hugging Face model identifier. The extraction cell chooses the saver and the
# hooked-model class by substring match on the lower-cased name, and loads the
# model through the hooked-model class.
model_name = 'meta-llama/Llama-3.1-8B-Instruct'

In [8]:
# Prompt template selection: 'all' reads ./prompts/next_token/<lang>.txt for each
# language; any other value reads ./prompts/next_token/<prompt_lang>.txt.
# 'no_prompt' is stored under the 'raw' prompt id, 'all' under 'prompted'.
prompt_lang = 'no_prompt'
# When True the prompt text is tokenized as-is, without chat formatting.
# NOTE(review): model_name in this notebook is an Instruct checkpoint — confirm that
# skipping the chat format is intended.
is_base_model = True
# Dataset split passed to load_dataset; 'all' concatenates the 'dev' and 'devtest' splits.
data_split = 'dev'
# Number of instances kept per language after a seeded shuffle; a falsy value keeps everything.
sample_size = 5
# Root directory handed to the activation saver.
output_dir = 'output_temp'

In [9]:
# Load Model
print(f'Load model: {model_name}')

# Directory name under which the saver stores this prompt configuration
if prompt_lang == 'all':
	prompt_id_saver = 'prompted'
elif prompt_lang == 'no_prompt':
	prompt_id_saver = 'raw'
else:
	prompt_id_saver = f'prompt_{prompt_lang}'

# All model-family checks below are case-insensitive substring matches
model_name_lower = model_name.lower()

if 'cohere' in model_name_lower:
	saver = CohereDecoderActivationSaver(output_dir, task_id='next_token', data_split=data_split, model_name=model_name, prompt_id=prompt_id_saver)
else:
	saver = GeneralActivationSaver(output_dir, task_id='next_token', data_split=data_split, model_name=model_name, prompt_id=prompt_id_saver)

if 'gemma-3' in model_name_lower:
	hooked_model = Gemma3MultimodalHookedModel(model_name, saver=saver)
elif 'cohere' in model_name_lower:
	hooked_model = CohereDecoderHookedModel(model_name, saver=saver)
elif 'pythia' in model_name_lower:
	hooked_model = PythiaHookedModel(model_name, saver=saver)
elif 'qwen' in model_name_lower:
	hooked_model = Qwen3HookedModel(model_name, saver=saver)
elif 'meta-llama' in model_name_lower:
	hooked_model = LlamaHookedModel(model_name, saver=saver)
else:
	raise ValueError(f"Model {model_name} not supported in this script.")

# Use the same cache directory as the model and dataset downloads
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=os.getenv("HF_CACHE_DIR"))

languages = ['ind_Latn', 'eng_Latn']

# Filled once per language so the dict holds every processed language afterwards
datasets_per_lang = {}

# Feed Forward
for lang in languages:

	# Load Dataset
	if data_split == 'all':
		lang_splits = load_dataset("openlanguagedata/flores_plus", lang, cache_dir=os.getenv("HF_CACHE_DIR"))
		datasets_per_lang[lang] = concatenate_datasets([lang_splits['dev'], lang_splits['devtest']])
	else:
		datasets_per_lang[lang] = load_dataset("openlanguagedata/flores_plus", lang, split=data_split, cache_dir=os.getenv("HF_CACHE_DIR"))

	# Sample Dataset (never ask for more rows than the split contains)
	if sample_size:
		n_selected = min(sample_size, len(datasets_per_lang[lang]))
		datasets_per_lang[lang] = datasets_per_lang[lang].shuffle(seed=42).select(range(n_selected))

	# Load Prompt Template; read as UTF-8 regardless of the platform default encoding
	template_name = lang if prompt_lang == "all" else prompt_lang
	with open(f'./prompts/next_token/{template_name}.txt', encoding='utf-8') as f:
		prompt_template = f.read()

	# Iterate Through Each Instance
	for instance in tqdm(datasets_per_lang[lang], desc=f"Processing activation for next token prediction task ({lang})"):
		# Set ID and Language in Saver
		hooked_model.set_saver_id(str(instance['id']))
		hooked_model.set_saver_lang(lang)

		# Check if activations already exist
		if saver.check_exists():
			print(f"Activations already exist for ID {instance['id']} in language {lang}. Skipping...")
			continue

		# Build Prompt Based on Template
		prompt = prompt_template.replace("{text}", instance['text'])

		# Inference
		if is_base_model or 'bloom' in model_name_lower:
			text = prompt
			# Plain text: let the tokenizer add its own special tokens
			add_special_tokens = True
		else:

			# Gemma2 does not support system message
			if 'google/gemma-2' in model_name_lower:
				messages = [
					{'role': 'user', 'content': prompt}
				]
			else:
				messages = [
					{'role': 'system', 'content': ''},
					{'role': 'user', 'content': prompt}
				]

			if 'meta-llama' in model_name_lower:
				user_prompt = messages[-1]['content']
				# NOTE(review): this hand-built prompt ends right after the assistant header,
				# with no trailing blank line — confirm this matches the intended chat format.
				text = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{user_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
			else:
				text = tokenizer.apply_chat_template(
					messages,
					tokenize=False,
					add_generation_prompt=True,
					enable_thinking=False
				)

			# Chat-formatted text already carries its special tokens; adding them
			# again at tokenization time would duplicate e.g. the BOS token
			add_special_tokens = False

		inputs = tokenizer([text], return_tensors="pt", add_special_tokens=add_special_tokens).to(hooked_model.model.device)
		_ = hooked_model.generate(inputs)

Load model: meta-llama/Llama-3.1-8B-Instruct


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

tokenizer_config.json:   0%|          | 0.00/55.4k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

Resolving data files:   0%|          | 0/223 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/217 [00:00<?, ?it/s]

Processing activation for next token prediction task (ind_Latn):   0%|          | 0/5 [00:00<?, ?it/s]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Processing activation for next token prediction task (ind_Latn):  20%|██        | 1/5 [00:00<00:02,  1.50it/s]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Processing activation for next token prediction task (ind_Latn):  40%|████      | 2/5 [00:00<00:01,  2.97it/s]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Processing activation for next token prediction task (ind_Latn):  60%|██████    | 3/5 [00:00<00:00,  4.36it/s]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Processing activation for next token prediction task (ind_Latn):  80%|████████  | 4/5 [00:00<00:00,  5.58it/s]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Processing activation for next token prediction task (ind_Latn): 100%|██████████| 5/5 [00:01<

Resolving data files:   0%|          | 0/223 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/217 [00:00<?, ?it/s]

Processing activation for next token prediction task (eng_Latn):   0%|          | 0/5 [00:00<?, ?it/s]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Processing activation for next token prediction task (eng_Latn):  20%|██        | 1/5 [00:00<00:00,  9.64it/s]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Processing activation for next token prediction task (eng_Latn):  40%|████      | 2/5 [00:00<00:00,  9.55it/s]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Processing activation for next token prediction task (eng_Latn):  60%|██████    | 3/5 [00:00<00:00,  9.53it/s]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Processing activation for next token prediction task (eng_Latn):  80%|████████  | 4/5 [00:00<00:00,  9.54it/s]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Processing activation for next token prediction task (eng_Latn): 100%|██████████| 5/5 [00:00<