In [None]:
import os
import sys

# Work from the repository root so that `src.*` imports and the relative `outputs/...` paths resolve.
# Assumes the notebook lives one directory below the root (the root is the directory containing
# `src`, which the `from src.utils.const import ...` cell requires). The guard makes the cell
# idempotent: re-running it no longer keeps climbing the directory tree.
if not os.path.isdir('src'):
	os.chdir('..')
# Make the repository root importable exactly once; an absolute path keeps working even if the
# working directory changes later in the session.
repo_root = os.getcwd()
if repo_root not in sys.path:
	sys.path.append(repo_root)

In [None]:
import pandas as pd
import numpy as np
from src.utils.const import LANGCODE2LANGNAME, LANGNAME2LANGCODE, MODEL2HIDDEN_SIZE, MODEL2NUM_LAYERS, EXP2_CONFIG, EXP3_CONFIG
import glob
import torch
from tqdm import tqdm

In [None]:
# Model under analysis and how its activations were extracted.
model_name = 'gemma-3-4b-it'  # hidden size 1152 for 1b, 2560 for 4b
model_to_num_layers = {
	'gemma-3-1b-it': 26,
	'gemma-3-4b-it': 34,
	'gemma-3-270m-it': 18,
	'gemma-2-9b-it': 42,
}
model_to_hidden_size = {
	'gemma-3-1b-it': 1152,
	'gemma-3-4b-it': 2560,
	'gemma-3-270m-it': 640,
	'gemma-2-9b-it': 3584,
}
num_layers = model_to_num_layers[model_name]
extraction_mode = 'raw'
token_position = 'last_token'

# Root of this model's activation dump. It is derived from model_name / extraction_mode so that
# switching the model above also switches which languages and texts are discovered; it must match
# the paths the loading cell reads from.
activations_root = f'outputs/topic_classification/{model_name}/{extraction_mode}'
# One sub-directory per language code.
languages = sorted(os.path.basename(path) for path in glob.glob(f'{activations_root}/*'))
# Text ids as found under the English directory (anything after the first '.' is stripped).
text_ids = sorted(
	os.path.basename(path).split('.')[0]
	for path in glob.glob(f'{activations_root}/eng_Latn/*')
)

In [None]:
# Ad-hoc sanity check: inspect the shape of a single saved activation file.
# NOTE(review): this path points at a different model/dataset (Qwen3-8B, flores_plus) than the rest
# of the notebook; confirm this is intentional. The expected shape is not established here (TODO).
activation_path = 'outputs_flores_plus/next_token/dev/Qwen3-8B/raw/amh_Ethi/2/average/layer_residual-postattn_6.pt'
# Guarded so that Restart & Run All does not crash on machines without this file.
if os.path.exists(activation_path):
	activation = torch.load(activation_path)
	print(f'Activation shape: {activation.shape}')
else:
	print(f'Skipping sanity check, file not found: {activation_path}')

In [None]:
# # Extract language codes from the comment and add eng_Latn
# languages = [
# 	'vie_Latn',  # Vietnamese
# 	'ind_Latn',  # Indonesian
# 	'tha_Thai',  # Thai
# 	'zsm_Latn',  # Malay
# 	'mya_Mymr',  # Burmese -> 0 problem
# 	'tgl_Latn',  # Tagalog
# 	'khm_Khmr',  # Khmer
# 	'ceb_Latn',  # Cebuano
# 	'lao_Laoo',  # Lao
# 	'jav_Latn',  # Javanese
# 	'war_Latn',  # Waray
# 	'sun_Latn',  # Sundanese
# 	'ilo_Latn',  # Ilocano
# 	'tam_Taml',  # Tamil
# 	'zho_Hans',  # Chinese
# 	'eng_Latn'   # English
# ]

# Replace the directory-derived language list with the EXP3 language set, flattened family by
# family in EXP3_CONFIG order (the heatmaps' family-boundary lines rely on this ordering) and
# mapped from language names to language codes.
languages = [
	LANGNAME2LANGCODE[lang_name]
	for family_lang_names in EXP3_CONFIG['languages'].values()
	for lang_name in family_lang_names
]
languages

In [None]:
# Summary of the analysis dimensions. The hidden size is read from MODEL2HIDDEN_SIZE, the same table
# used to allocate the activation tensor, so the printed value matches the tensor actually built.
print(f"Text IDs: {len(text_ids)}, Num layers: {num_layers}, Number of languages: {len(languages)}, Hidden size: {MODEL2HIDDEN_SIZE[model_name]}")

In [None]:
# Flat activation store with shape [len(text_ids) * len(languages), num_layers + 1, hidden_size].
# Row `text_position * len(languages) + lang_idx` holds one (text, language) pair. Slot 0 on dim 1 is
# the embedding output and slot `layer_id + 1` is transformer layer `layer_id`.
# The tensor is allocated directly in this layout on the GPU. The previous version built a 4-D tensor
# and `view`ed it, a reinterpretation that did not match the row layout above, and it also
# materialised the zeros on the host before copying them.
num_pairs = len(text_ids) * len(languages)
hidden_size = MODEL2HIDDEN_SIZE[model_name]
activation_per_lang = torch.zeros((num_pairs, num_layers + 1, hidden_size), device='cuda')
print(f'Activation tensor shape: {activation_per_lang.shape}')

# Language id of each row. Initialised to -1 ("not loaded") rather than 0, so rows whose files are
# missing are never mistaken for language 0 by the `labels == lang_idx` selections downstream.
labels = torch.full((num_pairs,), -1, dtype=torch.long, device='cuda')
print(f'Labels tensor shape: {labels.shape}')

In [None]:
activation_per_lang.size()  # torch.Size of the flat activation store

In [None]:
# NOTE(review): superseded by the next cell, kept for reference only. This earlier version indexes
# `activation_per_lang` with four indices [text, layer, language, hidden], but the tensor has already
# been reshaped to 3-D [text * language, layer, hidden], so it would raise if uncommented.
# # Load activations for all languages
# id2langcode = {idx: lang for idx, lang in enumerate(languages)}
# langcode2id = {lang: idx for idx, lang in enumerate(languages)}
# text_idx = 0
# for lang_idx, lang in tqdm(enumerate(languages), total=len(languages), desc='Loading activations for all languages'):
# 	for layer_id in range(-1, num_layers):
# 		if layer_id == -1:
# 			paths = sorted(glob.glob(f'outputs/topic_classification/{model_name}/{extraction_mode}/{lang}/*/{token_position}/layer_embed_tokens.pt'))
# 		else:
# 			paths = sorted(glob.glob(f'outputs/topic_classification/{model_name}/{extraction_mode}/{lang}/*/{token_position}/layer_{layer_id}.pt'))
# 		if len(paths) != len(text_ids):
# 			print(f"Warning: Expected {len(text_ids)} files for language '{lang}' at layer {layer_id}, but found {len(paths)} files.")
# 			break
# 		for path in paths:
# 			activation = torch.load(path)
# 			activation_per_lang[text_idx, layer_id + 1, lang_idx, :] = activation.to('cuda')
# 			text_idx += 1
# 		text_idx = 0

In [None]:
# Load activations for all languages into the flat store.
# Row layout: flat_idx = text_position * len(languages) + lang_idx, where text_position is the
# index in the sorted file list. Slot 0 on dim 1 holds the embedding output (layer_id == -1) and
# slot layer_id + 1 holds transformer layer `layer_id`.
id2langcode = {idx: lang for idx, lang in enumerate(languages)}
langcode2id = {lang: idx for idx, lang in enumerate(languages)}
num_langs = len(languages)
run_root = f'outputs/topic_classification/{model_name}/{extraction_mode}'

for lang_idx, lang in tqdm(enumerate(languages), total=num_langs, desc='Loading activations for all languages'):
	# Every row of the flat store that belongs to this language.
	lang_rows = slice(lang_idx, None, num_langs)
	for layer_id in range(-1, num_layers):
		layer_file = 'layer_embed_tokens.pt' if layer_id == -1 else f'layer_{layer_id}.pt'
		paths = sorted(glob.glob(f'{run_root}/{lang}/*/{token_position}/{layer_file}'))
		if len(paths) != len(text_ids):
			print(f"Warning: Expected {len(text_ids)} files for language '{lang}' at layer {layer_id}, but found {len(paths)} files.")
			# Drop this language entirely. Otherwise its rows keep zero activations for this and all
			# later layers while still carrying a label (language 0 if this happens before the
			# embedding pass writes labels), and they would be scored as real samples.
			labels[lang_rows] = -1
			break
		for text_position, path in enumerate(paths):
			flat_idx = text_position * num_langs + lang_idx
			activation = torch.load(path)
			activation_per_lang[flat_idx, layer_id + 1, :] = activation.to('cuda')
			# The label only needs writing once per (text, language) pair.
			if layer_id == -1:
				labels[flat_idx] = lang_idx


In [None]:
activation_per_lang.size()  # unchanged by loading: the store is filled in place

In [None]:
labels.size()  # one language id per flat row

In [None]:
# Shape of a single layer slice. Every slot has the same shape, so the last slot (transformer layer
# num_layers - 1) is indexed explicitly instead of relying on `layer_id` leaked from an earlier loop.
activation_per_lang[:, num_layers, :].shape

In [None]:
import cudf
import cupy as cp
from cuml.cluster import KMeans
from cuml.metrics.cluster.silhouette_score import cython_silhouette_score
from sklearn.datasets import make_blobs

In [None]:
# Take all index of labels that have value 1
indexes_lang1 = ((labels == 1) | (labels == 2)).nonzero(as_tuple=True)[0]
# Take activations for those indexes
activations_lang1 = activation_per_lang[indexes_lang1, layer_id + 1, :]
labels_lang1 = labels[indexes_lang1]

In [None]:


# Pairwise silhouette scores on the GPU. For every layer slot and every pair of languages, score how
# well the two languages' activations split into two clusters (higher = more separable).
# Entry [layer_id + 1, i, j] holds the score for languages i and j. The diagonal is never scored
# and stays 0.
num_langs = len(languages)
silhouette_score_matrix = torch.zeros((num_layers + 1, num_langs, num_langs), device='cuda')
for layer_id in tqdm(range(-1, num_layers)):
	layer_acts = activation_per_lang[:, layer_id + 1, :]
	# The silhouette score of a two-cluster split does not depend on which cluster is called which,
	# so each unordered pair is computed once and mirrored (halves the work).
	for lang_idx1 in range(num_langs):
		for lang_idx2 in range(lang_idx1 + 1, num_langs):
			pair_rows = ((labels == lang_idx1) | (labels == lang_idx2)).nonzero(as_tuple=True)[0]
			pair_acts = layer_acts[pair_rows]
			# Relabel to contiguous cluster ids {0, 1}. The score is invariant to label names, and this
			# avoids depending on the backend accepting arbitrary (e.g. 57 and 63) label values.
			pair_labels = (labels[pair_rows] == lang_idx2).int()
			score = cython_silhouette_score(pair_acts, pair_labels)
			silhouette_score_matrix[layer_id + 1, lang_idx1, lang_idx2] = score
			silhouette_score_matrix[layer_id + 1, lang_idx2, lang_idx1] = score
			# print(f"Silhouette Score between {id2langcode[lang_idx1]} and {id2langcode[lang_idx2]}: {score:.4f}")

In [None]:
silhouette_score_matrix[1]  # slot 1 = transformer layer 0 (slot 0 is the embedding output)

In [None]:
# Store silhouette score matrix
# Persist the score matrix so the tables and plots below can be regenerated without recomputing it.
os.makedirs('outputs/silhouette_scores', exist_ok=True)
torch.save(
	silhouette_score_matrix,
	'outputs/silhouette_scores/' + f'{model_name}_{extraction_mode}_{token_position}_silhouette_scores.pt',
)

In [None]:
# Least-separable language pairs per layer: the k smallest pairwise silhouette scores (a low score
# means the two languages' activations are hard to tell apart). The "distance" wording in variable
# names and column headers is historical; the values are silhouette scores.

# Number of pairs to report per layer. It is clamped below so torch.topk never asks for more pairs
# than exist.
k = 50

# Strict lower triangle (i > j) gives each unordered pair once with the diagonal excluded. It is
# identical for every layer, so the mask and its coordinates are built once outside the loop.
num_langs = len(languages)
mask = torch.tril(
	torch.ones((num_langs, num_langs), dtype=torch.bool, device=silhouette_score_matrix.device),
	diagonal=-1,
)
lower_i, lower_j = torch.where(mask)
k = min(k, lower_i.numel())

all_min_distances = []       # per layer: the k smallest scores, ascending
all_min_distance_pairs = []  # per layer: matching (lang_code_i, lang_code_j) tuples
for layer_id in range(-1, num_layers):
	lower_triangle_distances = silhouette_score_matrix[layer_id + 1][mask]
	topk_min_values, topk_indices = torch.topk(lower_triangle_distances, k=k, largest=False)
	# Map positions in the flattened lower triangle back to (row, column) language indices.
	pair_rows = lower_i[topk_indices].tolist()
	pair_cols = lower_j[topk_indices].tolist()
	all_min_distances.append(topk_min_values.tolist())
	all_min_distance_pairs.append([(languages[i], languages[j]) for i, j in zip(pair_rows, pair_cols)])

# Detailed table: one row per (layer, rank).
detailed_results = []
for layer_id in range(-1, num_layers):
	for rank in range(k):
		lang1, lang2 = all_min_distance_pairs[layer_id + 1][rank]
		detailed_results.append({
			'Layer': layer_id,
			'Rank': rank + 1,
			'Min_Distance': all_min_distances[layer_id + 1][rank],
			'Language_Pair': f"{LANGCODE2LANGNAME[lang1]} ({lang1}) - {LANGCODE2LANGNAME[lang2]} ({lang2})"
		})

results_df = pd.DataFrame(detailed_results)

print(f"Top {k} minimum distances per layer:")
print(results_df.to_string(index=False))

# Same information grouped by layer.
print("\nSummary by layer:")
for layer_id in range(-1, num_layers):
	print(f"\nLayer {layer_id:2d}:")
	for rank in range(k):
		lang1, lang2 = all_min_distance_pairs[layer_id + 1][rank]
		distance = all_min_distances[layer_id + 1][rank]
		print(f"  {rank+1}. {distance:.6f} - {LANGCODE2LANGNAME[lang1]} ({lang1}) & {LANGCODE2LANGNAME[lang2]} ({lang2})")

In [None]:
# Per-layer mean and (unbiased) standard deviation of the pairwise silhouette scores.
# Only the strict lower triangle is used:
#   - the diagonal is an unscored 0 placeholder that would bias the mean toward 0 and inflate the std;
#   - the upper triangle repeats the same pairs, since the score is symmetric in the pair.
num_langs = silhouette_score_matrix.shape[1]
pair_mask = torch.tril(
	torch.ones((num_langs, num_langs), dtype=torch.bool, device=silhouette_score_matrix.device),
	diagonal=-1,
)
pair_scores = silhouette_score_matrix[:, pair_mask]  # (num_layers + 1, number of unique pairs)
layer_means = pair_scores.mean(dim=1).cpu().numpy()
layer_stds = pair_scores.std(dim=1).cpu().numpy()

layer_stats_df = pd.DataFrame({
	'Layer': range(-1, num_layers),
	'Mean_Distance': layer_means,
	'Std_Distance': layer_stds
})

print("Layer-wise Distance Statistics:")
print(layer_stats_df.to_string(index=False))

In [None]:
# Number of languages in each family, in EXP3_CONFIG order (used to place family boundaries on the heatmaps).
lengths_per_family = {family: len(family_langs) for family, family_langs in EXP3_CONFIG['languages'].items()}


In [None]:
dict(lengths_per_family)  # family -> number of languages

In [None]:
# Heatmap per layer of the silhouette scores between languages, make a subplot for each layer with 4 columns
import seaborn as sns
import matplotlib.pyplot as plt

# One heatmap of pairwise silhouette scores per transformer layer (0 .. num_layers - 1), 4 panels per
# row, with language-family boundaries drawn as black lines. The embedding slot is not plotted.
lang_names = [LANGCODE2LANGNAME[lang] for lang in languages]
num_langs = len(languages)

# Shared colour scale across all layer slots, taken from off-diagonal entries only. The diagonal is
# an unscored 0 placeholder that would otherwise stretch the scale; the binned heatmap uses the same
# convention.
off_diag = ~torch.eye(num_langs, dtype=torch.bool, device=silhouette_score_matrix.device)
off_diag_scores = silhouette_score_matrix[:, off_diag]
global_min = off_diag_scores.min().item()
global_max = off_diag_scores.max().item()

n_rows = (num_layers + 1) // 4 + 1
fig, axes = plt.subplots(nrows=n_rows, ncols=4, figsize=(20, 5 * n_rows))
axes = axes.flatten()
for layer_id in tqdm(range(0, num_layers)):
	ax = axes[layer_id]
	sns.heatmap(
		silhouette_score_matrix[layer_id + 1].cpu().numpy(),
		xticklabels=lang_names,
		yticklabels=lang_names,
		ax=ax,
		cmap='viridis_r',
		vmin=global_min,
		vmax=global_max,
	)
	ax.set_title(f'Layer {layer_id} Silhouette Scores')
	plt.setp(ax.get_xticklabels(), rotation=90, ha='right', rotation_mode='anchor')
	ax.tick_params(axis='x', labelsize=5)
	ax.tick_params(axis='y', labelsize=5)

	# Family boundaries: languages are ordered family by family, so the cumulative counts give the
	# row/column positions where a new family starts.
	boundary = 0
	for family_size in lengths_per_family.values():
		boundary += family_size
		ax.axhline(boundary, color='black', linewidth=1.5)
		ax.axvline(boundary, color='black', linewidth=1.5)

# Hide grid cells that received no layer.
for ax in axes[num_layers:]:
	ax.set_visible(False)

plt.tight_layout()
plt.savefig(f'silhouette_scores_heatmap_{model_name}_testing.png', dpi=300, bbox_inches='tight')
plt.close(fig)

In [None]:
import seaborn as sns

# Make 5 bins of the silhouette scores and color the heatmap accordingly
import matplotlib.pyplot as plt

# Discretised variant of the per-layer heatmaps. Pairwise silhouette scores are split into `n_bins`
# equal-width bins spanning the off-diagonal range over all layer slots, so bands of similar
# separability stand out.
lang_names = [LANGCODE2LANGNAME[lang] for lang in languages]

# Off-diagonal range over all layer slots; the diagonal is an unscored 0 placeholder. The mask is the
# same for every layer, so it is built once.
mask_no_diag = ~torch.eye(len(languages), dtype=torch.bool, device=silhouette_score_matrix.device)
off_diag_scores = silhouette_score_matrix[:, mask_no_diag]
global_min = off_diag_scores.min().item()
global_max = off_diag_scores.max().item()

# 5 equal-width bins -> n_bins + 1 edges.
n_bins = 5
bin_edges = np.linspace(global_min, global_max, n_bins + 1)

n_rows = (num_layers + 1) // 4 + 1
fig, axes = plt.subplots(nrows=n_rows, ncols=4, figsize=(20, 5 * n_rows))
axes = axes.flatten()

for layer_id in tqdm(range(0, num_layers)):
	ax = axes[layer_id]
	layer_distances = silhouette_score_matrix[layer_id + 1].cpu().numpy()

	# np.digitize returns 1..n_bins for in-range values and n_bins + 1 at the top edge. Shift to
	# 0-based and clip, so the maximum and any value below the range (e.g. the 0 diagonal when all
	# scores are positive) land in a valid bin.
	binned_distances = np.clip(np.digitize(layer_distances, bin_edges) - 1, 0, n_bins - 1)

	sns.heatmap(
		binned_distances,
		xticklabels=lang_names,
		yticklabels=lang_names,
		ax=ax,
		cmap='viridis_r',
		vmin=0,
		vmax=n_bins - 1,
		cbar_kws={'label': 'Bin', 'ticks': range(n_bins)}
	)
	ax.set_title(f'Layer {layer_id} Binned Silhouette Scores')
	plt.setp(ax.get_xticklabels(), rotation=90, ha='right', rotation_mode='anchor')
	ax.tick_params(axis='x', labelsize=5)
	ax.tick_params(axis='y', labelsize=5)

	# Family boundaries at the cumulative language counts (languages are ordered family by family).
	boundary = 0
	for family_size in lengths_per_family.values():
		boundary += family_size
		ax.axhline(boundary, color='black', linewidth=1.5)
		ax.axvline(boundary, color='black', linewidth=1.5)

# Hide grid cells that received no layer.
for ax in axes[num_layers:]:
	ax.set_visible(False)

plt.tight_layout()
plt.savefig(f'silhouette_scores_binned{n_bins}_heatmap_{model_name}.png', dpi=300, bbox_inches='tight')
plt.close(fig)

print(f"Bin edges: {bin_edges}")

In [None]:
import numpy as np
from sklearn.cluster import AgglomerativeClustering
import matplotlib.pyplot as plt
from scipy.cluster.hierarchy import dendrogram, linkage

from scipy.spatial.distance import squareform


def _silhouette_linkage(score_matrix):
	"""Complete-linkage tree over languages, using pairwise silhouette scores as distances.

	`scipy.cluster.hierarchy.linkage` treats a 2-D array as one observation vector per row, so a
	square score matrix must first be converted to condensed form to be used as a distance matrix.
	Negative scores (languages whose activations overlap) are clipped to distance 0.
	"""
	distances = np.asarray(score_matrix, dtype=np.float64)
	# Average (i, j) with (j, i) to absorb any small asymmetry, then clip negatives.
	distances = np.clip((distances + distances.T) / 2, 0.0, None)
	np.fill_diagonal(distances, 0.0)
	return linkage(squareform(distances, checks=False), 'complete')


# One dendrogram per transformer layer (0 .. num_layers - 1); the grid is sized from num_layers.
lang_names = [LANGCODE2LANGNAME[lang] for lang in languages]
n_cols = 5
n_rows = -(-num_layers // n_cols)  # ceiling division
fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(35, 10 * n_rows), squeeze=False)
fig.suptitle('Hierarchical Clustering Dendrograms Across All Layers', fontsize=16)
axes = axes.flatten()

# First pass: cluster every layer slot (index = layer_id + 1) and find the largest merge height,
# so all panels can share one x-axis scale.
linkage_results = [
	_silhouette_linkage(silhouette_score_matrix[layer_id + 1].cpu().numpy())
	for layer_id in range(-1, num_layers)
]
max_distance = max(linked[:, 2].max() for linked in linkage_results)

# Second pass: draw the dendrograms with the shared scale.
for layer_id in tqdm(range(0, num_layers)):
	ax = axes[layer_id]
	dendrogram(linkage_results[layer_id + 1],
			   labels=lang_names,
			   orientation='right',
			   ax=ax)
	ax.set_title(f"Layer {layer_id}", fontsize=12)
	ax.tick_params(axis='y', labelsize=6)
	ax.tick_params(axis='x', labelsize=6)
	ax.set_xlim([0, max_distance])

# Hide grid cells that received no layer.
for ax in axes[num_layers:]:
	ax.set_visible(False)

image_path = f'plot/silhouette_dendrograms_{model_name}_{len(languages)}lang.png'
os.makedirs('plot', exist_ok=True)
plt.savefig(image_path, dpi=300, bbox_inches='tight')
plt.close(fig)


In [None]:
import numpy as np
from sklearn.cluster import AgglomerativeClustering
import matplotlib.pyplot as plt
from scipy.cluster.hierarchy import dendrogram, linkage
from scipy.spatial.distance import squareform

# Inline (not saved) dendrogram view that also includes the embedding slot (layer -1); the grid is
# sized from num_layers.
lang_names = [LANGCODE2LANGNAME[lang] for lang in languages]
n_panels = num_layers + 1
n_cols = 5
n_rows = -(-n_panels // n_cols)  # ceiling division
fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(25, 9 * n_rows), squeeze=False)
fig.suptitle('Hierarchical Clustering Dendrograms Across All Layers', fontsize=16)
axes = axes.flatten()

for layer_id in range(-1, num_layers):
	ax = axes[layer_id + 1]

	# `linkage` treats a 2-D array as observation vectors, so convert the square score matrix to
	# condensed distance form. Symmetrise, clip negative scores to distance 0, zero the diagonal.
	scores = silhouette_score_matrix[layer_id + 1].cpu().numpy().astype(np.float64)
	distances = np.clip((scores + scores.T) / 2, 0.0, None)
	np.fill_diagonal(distances, 0.0)
	linked = linkage(squareform(distances, checks=False), 'complete')

	dendrogram(linked,
			   labels=lang_names,
			   orientation='right',
			   ax=ax)
	ax.set_title(f"Layer {layer_id}", fontsize=12)
	ax.tick_params(axis='y', labelsize=8)
	ax.tick_params(axis='x', labelsize=8)

# Hide grid cells that received no layer.
for ax in axes[n_panels:]:
	ax.set_visible(False)

plt.tight_layout()
plt.show()
