In [None]:
import os
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import torch
import pandas as pd
from tqdm import tqdm
import numpy as np
from copy import deepcopy
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import math
import pickle

In [None]:
# Scratch check of the prompt template's file name (prompt_path is not read by later cells).
prompt_path = 'prompts/mt/prompt_en.txt'
prompt_file_name = os.path.basename(prompt_path)
prompt_file_name

In [None]:
# Target languages: Indonesian (3), English (5)
# Source languages: Hindi in Latin script (4), French (5), Javanese (1), Sundanese (1), Turkish (4), Welsh (1)

# FLORES+ codes for the languages above ("-" = no code used, so Hindi is not evaluated):
# Target languages: ind_Latn, eng_Latn
# Source languages: -, fra_Latn, jav_Latn, sun_Latn, tur_Latn, cym_Latn

In [None]:
os.environ.update(CUDA_VISIBLE_DEVICES='0')  # expose only GPU 0; only effective if set before CUDA is initialised in this process

In [None]:
# Display name for every FLORES+ language code used in this notebook.
languages_name = {
    'ind_Latn': 'Indonesian',
    'eng_Latn': 'English',
    'fra_Latn': 'French',
    'jav_Latn': 'Javanese',
    'sun_Latn': 'Sundanese',
    'tur_Latn': 'Turkish',
    'cym_Latn': 'Welsh'
}
# All codes, in declaration order; the first two are translation targets, the rest are sources.
langs = list(languages_name)
target_langs = ['ind_Latn', 'eng_Latn']
source_langs = [code for code in langs if code not in target_langs]

In [None]:
# Load the FLORES+ devtest split for every language (load_dataset comes from the imports cell).
datasets_per_lang = {}
for lang in langs:
    datasets_per_lang[lang] = load_dataset("openlanguagedata/flores_plus", lang, split="devtest")

# The plotting code reads instance ids from a single language's split and reuses them
# for every language pair, so check once that all languages share the same ids.
reference_ids = list(datasets_per_lang[langs[0]]['id'])
for lang in langs[1:]:
    assert list(datasets_per_lang[lang]['id']) == reference_ids, f"{lang} instance ids differ from {langs[0]}"

In [None]:
# Peek at one Indonesian record to see the dataset schema.
first_ind_record = datasets_per_lang['ind_Latn'][0]
first_ind_record

In [None]:
# English instruction template; {source_lang}, {target_lang} and {text} are substituted per instance.
prompt_en = "\n".join([
    "Translate the following text from {source_lang} to {target_lang}.",
    "Text: {text}",
    "Translated Text:",
])

## Inference

In [None]:
# Models considered (the one actually loaded is set via model_name in the next cell):
# Llama 3.2 1B
# Gemma 3 1B
# Sahabat-AI/gemma2-9b-cpt-sahabatai-v1-instruct
# Sahabat-AI/llama3-8b-cpt-sahabatai-v1-instruct
# bigscience/bloom-7b1
# sail/Sailor2-8B-Chat
# Qwen/Qwen3-8B
# CohereLabs/aya-expanse-8b

In [None]:
# Checkpoint under study; its short name (after the last '/') becomes the output directory name.
model_name = 'CohereLabs/aya-expanse-8b'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="cuda",
    torch_dtype="auto",
)
# Put the model in evaluation mode; the returned module is the cell's displayed output.
model.eval()

In [None]:
def _decoder_layers(model, model_name: str):
    """Return the decoder blocks whose outputs are hooked (BLOOM keeps them in ``transformer.h``)."""
    if 'bloom' in model_name:
        return model.transformer.h
    return model.model.layers


def _build_model_input_text(raw_text: str, source_lang: str, target_lang: str, languages_name: dict, tokenizer, model_name: str, initial_prompt: str = None, is_base_model: bool = False) -> str:
    """Build the exact string passed to the tokenizer for one dataset instance.

    With ``initial_prompt`` the language placeholders are filled first and the
    instance text last, so literal ``{source_lang}``/``{target_lang}`` occurring
    inside the text are left untouched. Without it the raw text is used.
    Base models and BLOOM receive the string verbatim; other models get it
    wrapped as a single user turn through the tokenizer's chat template.
    """
    if initial_prompt:
        prompt = (
            initial_prompt
            .replace("{source_lang}", languages_name[source_lang])
            .replace("{target_lang}", languages_name[target_lang])
            .replace("{text}", raw_text)
        )
    else:
        prompt = raw_text
    if is_base_model or 'bloom' in model_name:
        return prompt
    messages = [{"role": "user", "content": prompt}]
    return tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,  # Switches between thinking and non-thinking modes (default True) for templates that read it.
    )


def inference_hooked_model(initial_dataset, source_lang: str, target_lang: str, languages_name: dict, model, tokenizer, model_name: str, initial_prompt: str = None, prompt_lang: str = None, save_results: bool = False, is_base_model: bool = False):
    """Generate one token per instance and save each decoder layer's last-position hidden state.

    For every instance, the output of decoder layer ``k`` at the last input
    position is written with ``torch.save`` to
    ``outputs_1token_mt/{model_name}/prompt_{prompt_lang}/{source_lang}-{target_lang}/{id}/{k}.pt``
    (``prompt_raw`` replaces ``prompt_{prompt_lang}`` when no ``initial_prompt`` is given).

    Args:
        initial_dataset: iterable of records with ``'id'`` and ``'text'`` keys; not modified.
        source_lang, target_lang: language codes, looked up in ``languages_name``.
        languages_name: code -> display name used to fill the prompt placeholders.
        model, tokenizer: a causal LM and its tokenizer.
        model_name: short model name; used in the output path and to detect BLOOM.
        initial_prompt: template with ``{source_lang}``, ``{target_lang}``, ``{text}``
            placeholders; when falsy the raw instance text is used.
        prompt_lang: tag used in the output directory name.
        save_results: when True, return the decoded generations.
        is_base_model: feed the prompt verbatim instead of via the chat template.

    Returns:
        ``{instance id: decoded first generated token}`` if ``save_results`` else ``None``.
    """
    results = {}
    prompt_dir = f"prompt_{prompt_lang}" if initial_prompt else "prompt_raw"
    save_root = f"outputs_1token_mt/{model_name}/{prompt_dir}/{source_lang}-{target_lang}"
    print(f"Saving to {save_root}")

    # Directory for the instance currently being processed; read by the hooks so they
    # can be registered once for the whole dataset instead of once per instance.
    current = {"save_dir": None}

    def hook_fn(module, inputs, output, layer_id):
        # Depending on the transformers version a decoder layer returns either the hidden
        # states tensor directly or a tuple whose first item is that tensor.
        hidden = output[0] if isinstance(output, (tuple, list)) else output
        # hidden: [batch_size, sequence_length, hidden_dimension]; batch_size is 1 here.
        torch.save(hidden[0, -1, :].detach().cpu(), os.path.join(current["save_dir"], f"{layer_id}.pt"))

    handles = [
        layer.register_forward_hook(lambda m, i, o, layer_id=layer_id: hook_fn(m, i, o, layer_id))
        for layer_id, layer in enumerate(_decoder_layers(model, model_name))
    ]
    try:
        for test_instance in tqdm(initial_dataset):
            instance_id = test_instance['id']
            current["save_dir"] = f"{save_root}/{instance_id}"
            os.makedirs(current["save_dir"], exist_ok=True)

            text = _build_model_input_text(
                test_instance['text'], source_lang, target_lang, languages_name,
                tokenizer, model_name, initial_prompt=initial_prompt, is_base_model=is_base_model,
            )
            model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

            # A single new token means a single forward pass, so each hook fires once per instance.
            generated_ids = model.generate(**model_inputs, max_new_tokens=1)
            output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
            results[instance_id] = tokenizer.decode(output_ids, skip_special_tokens=True)
    finally:
        # Remove exactly the hooks registered above, even if generation fails midway,
        # so later calls never run stale hooks.
        for handle in handles:
            handle.remove()

    if save_results:
        return results

In [None]:
# Run every (source, target) pair with the English prompt. Keep each pair's decoded
# first tokens in results_per_pair instead of overwriting them on every iteration.
print('Processing with english prompt.')
results_per_pair = {}
for target_lang in target_langs:
    for source_lang in source_langs:
        print("Processing source language:", source_lang, "and target language:", target_lang)
        results = inference_hooked_model(
            initial_dataset=datasets_per_lang[source_lang],
            source_lang=source_lang,
            target_lang=target_lang,
            languages_name=languages_name,
            model=model,
            tokenizer=tokenizer,
            model_name=model_name.split('/')[-1],
            initial_prompt=prompt_en,
            prompt_lang='en',
            save_results=True,
            is_base_model=False
        )
        results_per_pair[f"{source_lang}-{target_lang}"] = results

## Plot

In [None]:
# Every "source-target" pair, grouped by target language (targets outer, sources inner).
language_pairs = [
	f"{source_lang}-{target_lang}"
	for target_lang in target_langs
	for source_lang in source_langs
]

In [None]:
# Preview a few French rows; converting only these rows avoids materialising the whole split.
datasets_per_lang['fra_Latn'].select(range(5)).to_pandas()

In [None]:
def _load_layer_activations(outputs_dir: str, model_name: str, prompt_lang: str, language_pairs: list, sample_ids: list, layer: int):
	"""Load one layer's saved activations for every (instance id, language pair).

	Rows are ordered instance-major: all pairs of the first id, then the next id, ...
	Returns a float32 array with one row per file (assumes each file holds a 1-D
	vector) and the language pair of each row.
	"""
	rows, row_pairs = [], []
	for sample_id in sample_ids:
		for lang_pair in language_pairs:
			activation_path = f"{outputs_dir}/{model_name}/prompt_{prompt_lang}/{lang_pair}/{sample_id}/{layer}.pt"
			activation = torch.load(activation_path, map_location="cpu")
			rows.append(activation.float().cpu().numpy())
			row_pairs.append(lang_pair)
	return np.array(rows), row_pairs


def _tsne_paths(tsne_dir: str, layer: int, save_plot_indicator: str):
	"""Return the (coordinates .npy, metadata .pkl) cache paths for one layer."""
	return (
		os.path.join(tsne_dir, f"layer-{layer}_tsne_{save_plot_indicator}.npy"),
		os.path.join(tsne_dir, f"layer-{layer}_metadata_{save_plot_indicator}.pkl"),
	)


def _load_cached_tsne(tsne_dir: str, layer: int, save_plot_indicator: str):
	"""Read previously saved t-SNE coordinates and per-point labels for one layer.

	Raises:
		FileNotFoundError: if either cache file is missing.
	"""
	tsne_path, metadata_path = _tsne_paths(tsne_dir, layer, save_plot_indicator)
	if not os.path.exists(tsne_path):
		raise FileNotFoundError(f"Precomputed t-SNE results not found at {tsne_path}. Set calculate_tsne to True to compute t-SNE.")
	if not os.path.exists(metadata_path):
		raise FileNotFoundError(f"Metadata not found at {metadata_path}. Set calculate_tsne to True to compute t-SNE.")
	activation_2d = np.load(tsne_path)
	# pickle can execute arbitrary code on load: only read metadata files this notebook wrote itself.
	with open(metadata_path, "rb") as f:
		labels_per_point = pickle.load(f)
	return activation_2d, labels_per_point


def _save_tsne(tsne_dir: str, layer: int, save_plot_indicator: str, activation_2d, labels_per_point: list):
	"""Cache one layer's t-SNE coordinates (.npy) and per-point labels (.pkl)."""
	os.makedirs(tsne_dir, exist_ok=True)
	tsne_path, metadata_path = _tsne_paths(tsne_dir, layer, save_plot_indicator)
	np.save(tsne_path, activation_2d)
	with open(metadata_path, "wb") as f:
		pickle.dump(labels_per_point, f)
	print(f"Saved t-SNE results for layer {layer} to {tsne_dir}")


def _scatter_layer(ax, layer: int, activation_2d, labels_per_point: list, labels: list, color_map: dict):
	"""Draw one layer's 2-D points on ``ax``, one coloured series per label."""
	ax.set_title(f'Layer {layer + 1}')
	for label in labels:
		mask = np.array([point_label == label for point_label in labels_per_point], dtype=bool)
		if mask.any():
			# `color=` rather than `c=`: an RGBA tuple passed as `c` is read as per-point
			# values whenever a series happens to contain exactly 3 or 4 points.
			ax.scatter(activation_2d[mask, 0], activation_2d[mask, 1],
					color=color_map[label], s=10, alpha=0.5, label=label)
	ax.set_xlabel('t-SNE Component 1')
	ax.set_ylabel('t-SNE Component 2')
	if ax.get_legend_handles_labels()[1]:
		ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)


def plot_by_category(dataset_dict: dict, model_name: str, num_layers: int, labels: list, language_pairs: list, outputs_dir: str = 'outputs_1token_mt', prompt_lang: str = 'en', save_plot: bool = False, save_plot_indicator: str = 'topics', show_plot: bool = True, save_tsne: bool = False, calculate_tsne: bool = True, reference_lang: str = 'fra_Latn', n_cols: int = 7):
	"""Plot a 2-D t-SNE of the saved last-token activations for every layer, coloured by language pair.

	Args:
		dataset_dict: language code -> dataset; only ``dataset_dict[reference_lang]``'s ids are read.
		model_name: short model name used in the activation/output paths.
		num_layers: number of layers to plot (one subplot each).
		labels: language pairs to draw, in legend/colour order.
		language_pairs: pairs whose activations are loaded when computing t-SNE.
		outputs_dir, prompt_lang: locate activations and where to write caches/plots.
		save_plot: write the figure to ``{outputs_dir}/{model_name}/plots/prompt_{prompt_lang}/``.
		save_plot_indicator: suffix for cache/plot file names; must contain ``'lang-pairs'``
			when ``calculate_tsne`` is True.
		show_plot: display the figure.
		save_tsne: cache per-layer coordinates and labels.
		calculate_tsne: compute t-SNE from activations; otherwise read the cached results.
		reference_lang: language whose instance ids are used for every pair.
		n_cols: subplots per row.

	Raises:
		ValueError: on ``num_layers < 1`` or an unsupported ``save_plot_indicator``.
		FileNotFoundError: when ``calculate_tsne`` is False and a cache file is missing.
	"""
	if num_layers < 1:
		raise ValueError(f"num_layers must be >= 1, got {num_layers}")
	if calculate_tsne and 'lang-pairs' not in save_plot_indicator:
		raise ValueError("Invalid save_plot_indicator. Use 'lang-pairs'.")

	# NOTE(review): ids come from reference_lang's split and are reused for every pair,
	# which assumes all languages share instance ids — confirm for new datasets.
	sample_ids = dataset_dict[reference_lang].to_pandas()['id'].tolist()
	tsne_dir = f'{outputs_dir}/{model_name}/tsne/prompt_{prompt_lang}'

	cmap = plt.get_cmap('tab10')
	# Cycle through the palette rather than clamping every extra label to its last colour.
	color_map = {label: cmap(i % cmap.N) for i, label in enumerate(labels)}
	n_rows = math.ceil(num_layers / n_cols)
	fig, axes = plt.subplots(n_rows, n_cols, figsize=(50 * n_cols / 7, 5 * n_rows), squeeze=False)
	axes = axes.flatten()

	for layer in range(num_layers):
		if calculate_tsne:
			activation_np, labels_per_point = _load_layer_activations(
				outputs_dir, model_name, prompt_lang, language_pairs, sample_ids, layer
			)
			tsne = TSNE(n_components=2, random_state=42, n_jobs=-1)
			activation_2d = tsne.fit_transform(activation_np)
		else:
			activation_2d, labels_per_point = _load_cached_tsne(tsne_dir, layer, save_plot_indicator)

		if save_tsne:
			_save_tsne(tsne_dir, layer, save_plot_indicator, activation_2d, labels_per_point)

		_scatter_layer(axes[layer], layer, activation_2d, labels_per_point, labels, color_map)

	# Hide the unused grid cells when num_layers is not a multiple of n_cols.
	for ax in axes[num_layers:]:
		ax.axis('off')

	# Lay out before saving so the saved PNG matches what is shown.
	fig.tight_layout()
	if save_plot:
		plot_dir = f'{outputs_dir}/{model_name}/plots/prompt_{prompt_lang}'
		os.makedirs(plot_dir, exist_ok=True)
		fig.savefig(f'{plot_dir}/tsne_{save_plot_indicator}.png', bbox_inches='tight')

	if show_plot:
		plt.show()
	# Release the figure: this is called in loops with show_plot=False and figures would otherwise pile up.
	plt.close(fig)

In [None]:
# Show the target-language codes that the per-target plotting cells iterate over.
target_langs

In [None]:
# Plot per target lang: one figure per target language with only the pairs translating into it.
short_model_name = model_name.split('/')[-1]
n_decoder_layers = len(model.transformer.h) if 'bloom' in model_name else len(model.model.layers)
for target_lang in target_langs:
	# Pairs are "<source>-<target>"; match the target side only, since a plain substring
	# test would also pick up pairs whose *source* is target_lang.
	target_pairs = [lang_pair for lang_pair in language_pairs if lang_pair.endswith(f"-{target_lang}")]
	plot_by_category(
		dataset_dict=datasets_per_lang,
		model_name=short_model_name,
		num_layers=n_decoder_layers,
		labels=target_pairs,
		language_pairs=target_pairs,
		outputs_dir='outputs_1token_mt',
		prompt_lang='en',
		save_plot=True,
		save_plot_indicator=f'lang-pairs_{target_lang}',
		show_plot=False,
		save_tsne=True,
		calculate_tsne=True
	)

In [None]:
# Plot every language pair (both target languages) together in a single figure.
all_pairs = list(language_pairs)
plot_by_category(
	dataset_dict=datasets_per_lang,
	model_name=model_name.rsplit('/', 1)[-1],
	num_layers=len(model.transformer.h if 'bloom' in model_name else model.model.layers),
	labels=all_pairs,
	language_pairs=all_pairs,
	outputs_dir='outputs_1token_mt',
	prompt_lang='en',
	save_plot=True,
	save_plot_indicator='lang-pairs_merged',
	show_plot=False,
	save_tsne=True,
	calculate_tsne=True
)

In [None]:
# Re-render the figure for Indonesian as the target language on its own.
target_lang = 'ind_Latn'
# Pairs are "<source>-<target>"; match the target side only, not any occurrence of the code.
target_pairs = [lang_pair for lang_pair in language_pairs if lang_pair.endswith(f"-{target_lang}")]
plot_by_category(
	dataset_dict=datasets_per_lang,
	model_name=model_name.split('/')[-1],
	num_layers=len(model.transformer.h) if 'bloom' in model_name else len(model.model.layers),
	labels=target_pairs,
	language_pairs=target_pairs,
	outputs_dir='outputs_1token_mt',
	prompt_lang='en',
	save_plot=True,
	save_plot_indicator=f'lang-pairs_{target_lang}',
	show_plot=False,
	save_tsne=True,
	calculate_tsne=True
)

In [None]:
# Re-render the figure for English as the target language on its own.
target_lang = 'eng_Latn'
# Pairs are "<source>-<target>"; match the target side only, not any occurrence of the code.
target_pairs = [lang_pair for lang_pair in language_pairs if lang_pair.endswith(f"-{target_lang}")]
plot_by_category(
	dataset_dict=datasets_per_lang,
	model_name=model_name.split('/')[-1],
	num_layers=len(model.transformer.h) if 'bloom' in model_name else len(model.model.layers),
	labels=target_pairs,
	language_pairs=target_pairs,
	outputs_dir='outputs_1token_mt',
	prompt_lang='en',
	save_plot=True,
	save_plot_indicator=f'lang-pairs_{target_lang}',
	show_plot=False,
	save_tsne=True,
	calculate_tsne=True
)