In [1]:
import os
# Step up one directory so the relative paths used by later cells
# ('./prompts/next_token/...' and output_dir) resolve against the parent folder.
# NOTE(review): not idempotent — re-running this cell without restarting the
# kernel moves up one more level each time.
os.chdir('..')

In [2]:
import pandas as pd
import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM
from dotenv import load_dotenv
from datasets import load_dataset, concatenate_datasets
from tqdm import tqdm
load_dotenv()

True

In [3]:
class BaseActivationSaver:
	"""
	Base class for saving activations from different models.

	Every activation is written to
	``<base_save_dir>/<task_id>/<data_split>/<model short name>/<prompt_id>/<lang>/<id>/<kind>/layer_<layer_id>.pt``
	where ``<kind>`` is ``last_token`` or ``average`` and the model short name is
	the part of ``model_name`` after its last ``/``.

	Subclasses implement ``hook_fn`` / ``pre_hook_fn``; callers must call
	``set_id`` and ``set_lang`` before any activation is saved or looked up.
	"""
	def __init__(self, base_save_dir: str, task_id: str, data_split: str, model_name: str, prompt_id: str):
		self.task_id = task_id
		self.data_split = data_split
		self.model_name = model_name
		self.prompt_id = prompt_id
		self.base_save_dir = base_save_dir
		# Per-instance state, updated by the caller before each forward pass.
		self.current_id = None
		self.current_lang = None

	def set_id(self, new_id):
		"""Set the id of the instance whose activations are saved next."""
		self.current_id = new_id

	def set_lang(self, new_lang):
		"""Set the language code of the instance whose activations are saved next."""
		self.current_lang = new_lang

	def hook_fn(self, module, input, output, layer_id):
		"""Forward hook; must be implemented by subclasses."""
		raise NotImplementedError("This method should be overridden by subclasses.")

	def pre_hook_fn(self, module, input, layer_id):
		"""Forward pre-hook; must be implemented by subclasses."""
		raise NotImplementedError("This method should be overridden by subclasses.")

	def _instance_dir(self, kind: str) -> str:
		"""
		Build the directory that holds the ``kind`` ("last_token" or "average")
		activations of the current instance.

		Raises:
			TypeError: if ``set_id`` or ``set_lang`` has not been called yet.
		"""
		if self.current_id is None or self.current_lang is None:
			raise TypeError("set_id() and set_lang() must be called before building a save path")
		return os.path.join(
			self.base_save_dir,
			self.task_id,
			self.data_split,
			self.model_name.split('/')[-1],
			self.prompt_id,
			# str() lets non-string ids / language codes (e.g. integer ids) be used as path parts.
			str(self.current_lang),
			str(self.current_id),
			kind,
		)

	# Check if activations for an instance already exist
	def check_exists(self):
		"""
		Return True when both the ``last_token`` and the ``average`` directory of
		the current instance exist and contain at least one entry.

		Only non-emptiness is checked, not whether every hooked layer was saved.
		"""
		for kind in ("last_token", "average"):
			path = self._instance_dir(kind)
			if not os.path.exists(path) or not os.listdir(path):
				return False
		return True

	def _save_activation_last_token(self, tensor, layer_id):
		"""Save ``tensor[0, -1, :]`` (first batch item, last position) as ``layer_<layer_id>.pt``."""
		path = self._instance_dir("last_token")
		os.makedirs(path, exist_ok=True)
		save_path = os.path.join(path, f"layer_{layer_id}.pt")
		torch.save(tensor[0, -1, :].detach().cpu(), save_path)

	def _save_activation_average(self, tensor, layer_id):
		"""Save the mean of ``tensor[0]`` over its first dimension as ``layer_<layer_id>.pt``."""
		path = self._instance_dir("average")
		os.makedirs(path, exist_ok=True)
		save_path = os.path.join(path, f"layer_{layer_id}.pt")
		torch.save(tensor[0].mean(dim=0).detach().cpu(), save_path)
	
	def _check_set_id_lang(self, layer_id):
		"""Return True if both id and language are set; otherwise print a warning and return False."""
		if self.current_id is None:
			print(f"Warning: ID not set for layer {layer_id}")
			return False

		if self.current_lang is None:
			print(f"Warning: Language not set for layer {layer_id}")
			return False
		
		return True
	
class GeneralActivationSaver(BaseActivationSaver): # Handle Gemma3, Qwen, Pythia, Llama, Aya101 (T5) models (models that activation are returned in form of a tuple)
	"""
	Saver for modules whose hook payload is either a tensor or a tuple whose
	first element is the tensor to store.

	Both hooks write the last-token slice and the sequence average of that
	tensor. Any exception raised while unpacking or saving is printed rather
	than propagated, so a failing hook does not abort the forward pass.
	"""

	def hook_fn(self, module, input, output, layer_id):
		"""Forward hook: store the module output under ``layer_id``."""
		if not self._check_set_id_lang(layer_id):
			return
		try:
			# Unpack tensor from the tuple when the module returns one
			hidden = output
			if isinstance(output, tuple):
				hidden = output[0]
			self._save_activation_last_token(tensor=hidden, layer_id=layer_id)
			self._save_activation_average(tensor=hidden, layer_id=layer_id)
		except Exception as e:
			print(f"Error in hook_fn for layer {layer_id}: {e}")
	
	def pre_hook_fn(self, module, input, layer_id):
		"""Forward pre-hook: store the module input under ``layer_id``."""
		if not self._check_set_id_lang(layer_id):
			return

		try:
			# Per the original author: captures the residual stream after attention
			# (precisely after the post-attention layer norm) when attached to the MLP.
			hidden = input
			if isinstance(input, tuple):
				hidden = input[0]
			self._save_activation_last_token(tensor=hidden, layer_id=layer_id)
			self._save_activation_average(tensor=hidden, layer_id=layer_id)
		except Exception as e:
			print(f"Error in pre_hook_fn for layer {layer_id}: {e}")

In [4]:
class BaseHookedModel:
	"""
	Base class for hooking into different models.

	Loads a causal LM with ``AutoModelForCausalLM`` and keeps a reference to the
	saver that the hooks registered by subclasses write through.
	"""
	def __init__(self, model_name: str, saver: BaseActivationSaver):
		"""
		Load ``model_name`` in eval mode.

		The model goes to CUDA when available, otherwise CPU. The dtype is
		float16 unless the CUDA device reports compute capability major >= 8,
		in which case bfloat16 is used.
		"""
		if torch.cuda.is_available():
			device = "cuda"
			major_capability = torch.cuda.get_device_capability()[0]
			# Use bfloat16 if supported
			model_dtype = torch.bfloat16 if major_capability >= 8 else torch.float16
		else:
			device = "cpu"
			model_dtype = torch.float16

		self.model_name = model_name
		self.saver = saver
		self.model = AutoModelForCausalLM.from_pretrained(
			model_name,
			torch_dtype=model_dtype,
			device_map=device,
			cache_dir=os.getenv("HF_CACHE_DIR"),
		)
		self.model.eval()

	def _setup_hooks(self):
		"""Register the architecture-specific hooks; implemented by subclasses."""
		raise NotImplementedError("This method should be overridden by subclasses.")

	def set_saver_id(self, new_id: int):
		"""Forward the current instance id to the saver."""
		self.saver.set_id(new_id)

	def set_saver_lang(self, new_lang: str):
		"""Forward the current language code to the saver."""
		self.saver.set_lang(new_lang)

	def generate(self, inputs):
		"""Generate exactly one new token without tracking gradients and return the output ids."""
		with torch.no_grad():
			return self.model.generate(
				**inputs,
				max_new_tokens=1,
			)

	# Clear hooks for debugging purposes
	def clear_hooks(self):
		"""
		Remove hooks by emptying the modules' private hook dictionaries.

		For model names containing 'bloom' only the forward hooks of
		``model.transformer.h`` layers are cleared. Otherwise the forward hooks
		of ``embed_tokens`` and ``norm`` plus the forward and forward-pre hooks
		of every ``model.model.layers`` entry are cleared.
		"""
		if 'bloom' in self.model_name:
			for bloom_layer in self.model.transformer.h:
				bloom_layer._forward_hooks.clear()
			return

		backbone = self.model.model
		backbone.embed_tokens._forward_hooks.clear()
		backbone.norm._forward_hooks.clear()
		for decoder_layer in backbone.layers:
			decoder_layer._forward_hooks.clear()
			decoder_layer._forward_pre_hooks.clear()
				


In [5]:
class BaseHookedSeq2SeqModel:
	"""
	Base class for hooking into different seq2seq models.

	Loads the model with ``AutoModelForSeq2SeqLM`` and keeps a reference to the
	saver that the hooks registered by subclasses write through.
	"""
	def __init__(self, model_name: str, saver: BaseActivationSaver):
		"""
		Load ``model_name`` in eval mode.

		The model goes to CUDA when available, otherwise CPU. The dtype is
		float16 unless the CUDA device reports compute capability major >= 8,
		in which case bfloat16 is used.
		"""
		device = "cpu"
		model_dtype = torch.float16
		if torch.cuda.is_available():
			device = "cuda"
			compute_capability = torch.cuda.get_device_capability()[0]

			# Use bfloat16 if supported
			if compute_capability >= 8:
				model_dtype = torch.bfloat16
		
		self.model_name = model_name
		self.saver = saver
		self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=model_dtype, device_map=device, cache_dir=os.getenv("HF_CACHE_DIR"))
		self.model.eval()

	def _setup_hooks(self):
		"""Register the architecture-specific hooks; implemented by subclasses."""
		raise NotImplementedError("This method should be overridden by subclasses.")
	
	def set_saver_id(self, new_id: int):
		"""Forward the current instance id to the saver."""
		self.saver.set_id(new_id)

	def set_saver_lang(self, new_lang: str):
		"""Forward the current language code to the saver."""
		self.saver.set_lang(new_lang)
		
	def generate(self, inputs):
		"""
		Generate exactly one new token without tracking gradients.

		Returns:
			The output of ``self.model.generate``, matching ``BaseHookedModel.generate``.
			The original ended in a bare ``return`` and silently discarded the result.
		"""
		with torch.no_grad():
			outputs = self.model.generate(
				**inputs,
				max_new_tokens=1,
			)
		return outputs

class T5HookedModel(BaseHookedSeq2SeqModel):
	"""
	Hooked model for T5 architecture.

	Forward hooks are attached to the decoder's embedding module and to the
	first three sub-layers of every decoder block; each hook forwards the
	module output to the saver under a descriptive layer id.
	"""
	def __init__(self, model_name: str, saver: BaseActivationSaver):
		super().__init__(model_name, saver)
		self._setup_hooks()

	def _setup_hooks(self):
		"""Register the embedding hook and the per-block sub-layer hooks on the decoder."""
		def forward_to_saver(layer_id):
			# Bind layer_id at creation time so each hook keeps its own label.
			return lambda module, input, output: self.saver.hook_fn(module, input, output, layer_id=layer_id)

		decoder = self.model.decoder

		# Hook embedding layer
		decoder.embed_tokens.register_forward_hook(forward_to_saver("embed_tokens"))

		# Position in block.layer -> label prefix:
		# 0 = self-attention residual, 1 = cross-attention residual, 2 = MLP residual
		sublayer_labels = (
			"residual-postselfattn",
			"residual-postcrossattn",
			"residual-postmlp",
		)
		for block_idx, block in enumerate(decoder.block):
			for position, label in enumerate(sublayer_labels):
				block.layer[position].register_forward_hook(forward_to_saver(f"{label}_{block_idx}"))
        


In [6]:
# Authenticate with the Hugging Face Hub.
# The token is read from the HF_TOKEN environment variable (None when unset)
# rather than being hardcoded in the notebook.
from huggingface_hub import login
login(token=os.getenv("HF_TOKEN"))

Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


In [7]:
# Checkpoint whose activations are extracted; the extraction cell only accepts
# names containing 'aya-101' and raises ValueError for anything else.
model_name = "CohereLabs/aya-101"
# Disabled causal-LM load, kept for reference; this notebook loads the checkpoint via AutoModelForSeq2SeqLM instead.
# model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="cuda", cache_dir=os.getenv("HF_CACHE_DIR"))

In [8]:
# Echo the selected checkpoint id as the cell output.
model_name

'CohereLabs/aya-101'

In [9]:
# Standalone load of the checkpoint for inspection.
# NOTE(review): the extraction cell builds its own copy through T5HookedModel,
# so this `model` is a second instance of the same weights in memory unless it
# is deleted first — confirm the GPU has room for both.
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, BitsAndBytesConfig

# # 1. (Disabled) 4-bit configuration, kept for reference
# bnb_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_quant_type="nf4",      # Normalized Float 4 (recommended for accuracy)
#     bnb_4bit_compute_dtype=torch.bfloat16, # Compute in bfloat16
#     bnb_4bit_use_double_quant=True, # Double quantization to save even more memory
# )

model_id = "CohereLabs/aya-101"

# 2. Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)

# 3. Load the model in bfloat16 (the 4-bit quantization config above is disabled)
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_id,
    # quantization_config=bnb_config,
	torch_dtype=torch.bfloat16,
    device_map="cuda", # Place the whole model on the CUDA device (not an automatic multi-device split)
    cache_dir=os.getenv("HF_CACHE_DIR")
)

Loading checkpoint shards:   0%|          | 0/11 [00:00<?, ?it/s]

In [10]:
# Display the module tree; useful for checking the attribute paths
# (e.g. block / layer indices) that the hook setup relies on.
model

T5ForConditionalGeneration(
  (shared): Embedding(250112, 4096)
  (encoder): T5Stack(
    (embed_tokens): Embedding(250112, 4096)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=4096, out_features=4096, bias=False)
              (k): Linear(in_features=4096, out_features=4096, bias=False)
              (v): Linear(in_features=4096, out_features=4096, bias=False)
              (o): Linear(in_features=4096, out_features=4096, bias=False)
              (relative_attention_bias): Embedding(32, 64)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseGatedActDense(
              (wi_0): Linear(in_features=4096, out_features=10240, bias=False)
              (wi_1): Linear(in_features=4096, out_features=10240, bias=False)
   

In [10]:
# Run configuration read by the extraction cell.

# Prompt setting: 'all' reads a template named after each language and saves under 'prompted';
# 'no_prompt' saves under 'raw'; any other value saves under 'prompt_<value>'.
# In the non-'all' cases the template file ./prompts/next_token/<prompt_lang>.txt is read.
prompt_lang = 'no_prompt'
# When True the prompt text is tokenized as-is, skipping chat-template formatting.
is_base_model = True
# Dataset split to load; 'all' concatenates the 'dev' and 'devtest' splits.
data_split = 'dev'
# Number of instances kept per language after a shuffle with seed 42; a falsy value keeps all.
sample_size = 5
# Root directory under which the saver writes activation files.
output_dir = 'output_temp'

In [11]:
# Extract activations for the next-token prediction task.
# For every language: load the FLORES+ split, optionally subsample it, build the
# prompt for each instance and run a single generation step so that the hooks
# registered by the hooked model write the activations through `saver`.

# Load Model
print(f'Load model: {model_name}')

# Map the prompt setting to the directory name used by the saver
if prompt_lang == 'all':
	prompt_id_saver = 'prompted'
elif prompt_lang == 'no_prompt':
	prompt_id_saver = 'raw'
else:
	prompt_id_saver = f'prompt_{prompt_lang}'

saver = GeneralActivationSaver(output_dir, task_id='next_token', data_split=data_split, model_name=model_name, prompt_id=prompt_id_saver)

if 'aya-101' in model_name.lower():
	hooked_model = T5HookedModel(model_name, saver=saver)
else:
	raise ValueError(f"Model {model_name} not supported in this script.")

# Use the same cache directory as the model loads
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=os.getenv("HF_CACHE_DIR"))

languages = ['ind_Latn', 'eng_Latn']

# Created once so it holds every processed language, not only the last one
datasets_per_lang = {}

# Feed Forward
for lang in languages:
	
	# Load Dataset
	if data_split == 'all':
		datasets_per_lang_temp = {}
		datasets_per_lang_temp[lang] = load_dataset("openlanguagedata/flores_plus", lang, cache_dir=os.getenv("HF_CACHE_DIR"))
		datasets_per_lang[lang] = concatenate_datasets([datasets_per_lang_temp[lang]['dev'], datasets_per_lang_temp[lang]['devtest']])
	else:
		datasets_per_lang[lang] = load_dataset("openlanguagedata/flores_plus", lang, split=data_split, cache_dir=os.getenv("HF_CACHE_DIR"))
	
	# Sample Dataset (clamped so a sample_size above the split size does not fail)
	if sample_size:
		n_samples = min(sample_size, len(datasets_per_lang[lang]))
		datasets_per_lang[lang] = datasets_per_lang[lang].shuffle(seed=42).select(range(n_samples))

	# Load Prompt Template: per-language file for "all", otherwise the file named by prompt_lang.
	# Explicit UTF-8 keeps multilingual templates independent of the platform default encoding.
	template_name = lang if prompt_lang == "all" else prompt_lang
	with open(f'./prompts/next_token/{template_name}.txt', encoding='utf-8') as f:
		prompt_template = f.read()
	
	# Iterate Through Each Instance
	for instance in tqdm(datasets_per_lang[lang], desc=f"Processing activation for next token prediction task ({lang})"):
		# Set ID and Language in Saver
		hooked_model.set_saver_id(str(instance['id']))
		hooked_model.set_saver_lang(lang)

		# Check if activations already exist
		if saver.check_exists():
			# tqdm.write keeps the progress bar intact while logging
			tqdm.write(f"Activations already exist for ID {instance['id']} in language {lang}. Skipping...")
			continue

		# Build Prompt Based on Template
		prompt = prompt_template.replace("{text}", instance['text'])

		# Inference
		if is_base_model or 'bloom' in model_name:
			text = prompt
		else:
			
			# Gemma2 does not support system message
			if 'google/gemma-2' in model_name.lower():
				messages = [
					{'role': 'user', 'content': prompt}
				]
			else:
				messages = [
					{'role': 'system', 'content': ''},
					{'role': 'user', 'content': prompt}
				]

			if 'meta-llama' in model_name.lower():
				user_prompt = messages[-1]['content']
				text = f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{user_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
			else:
				text = tokenizer.apply_chat_template(
					messages,
					tokenize=False,
					add_generation_prompt=True,
					enable_thinking=False 
				)
		
		inputs = tokenizer([text], return_tensors="pt").to(hooked_model.model.device)
		# The generated ids are not needed; the forward pass triggers the saving hooks
		_ = hooked_model.generate(inputs)

Load model: CohereLabs/aya-101


Loading checkpoint shards:   0%|          | 0/11 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/224 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/218 [00:00<?, ?it/s]

Processing activation for next token prediction task (ind_Latn): 100%|██████████| 5/5 [00:01<00:00,  3.75it/s]


Resolving data files:   0%|          | 0/224 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/218 [00:00<?, ?it/s]

Processing activation for next token prediction task (eng_Latn): 100%|██████████| 5/5 [00:00<00:00,  6.64it/s]
