In [1]:
import os
import sys

# Change to parent directory
os.chdir('..')
# Add current directory to path
sys.path.append('.')

In [7]:
import pandas as pd
import numpy as np
from src.utils.const import LANGCODE2LANGNAME, LANGNAME2LANGCODE, MODEL2HIDDEN_SIZE, MODEL2NUM_LAYERS, EXP2_CONFIG
from src.utils.metrics import AlignmentMetrics
import glob
import torch
from tqdm import tqdm

In [8]:
model_name = 'gemma-3-4b-it' # 1152 size for 1b, 2560 for 4b
model_to_num_layers = {
    'gemma-3-1b-it': 26,
	'gemma-3-4b-it': 34,
    'gemma-3-270m-it': 18,
    'gemma-2-9b-it': 42
}
model_to_hidden_size = {
	'gemma-3-1b-it': 1152,
	'gemma-3-4b-it': 2560,
	'gemma-3-270m-it': 640,
	'gemma-2-9b-it': 3584
}
num_layers = model_to_num_layers[model_name]
extraction_mode = 'raw'
token_position = 'last_token'
languages = glob.glob('outputs/topic_classification/gemma-3-4b-it/raw/*')
languages = sorted([lang.split('/')[-1] for lang in languages])
text_ids = glob.glob('outputs/topic_classification/gemma-3-4b-it/raw/ace_Latn/*')
text_ids = [text_id.split('/')[-1].split('.')[0] for text_id in text_ids]

In [9]:
# # Extract language codes from the comment and add eng_Latn
# languages = [
# 	'vie_Latn',  # Vietnamese
# 	'ind_Latn',  # Indonesian
# 	'tha_Thai',  # Thai
# 	'zsm_Latn',  # Malay
# 	'mya_Mymr',  # Burmese -> 0 problem
# 	'tgl_Latn',  # Tagalog
# 	'khm_Khmr',  # Khmer
# 	'ceb_Latn',  # Cebuano
# 	'lao_Laoo',  # Lao
# 	'jav_Latn',  # Javanese
# 	'war_Latn',  # Waray
# 	'sun_Latn',  # Sundanese
# 	'ilo_Latn',  # Ilocano
# 	'tam_Taml',  # Tamil
# 	'zho_Hans',  # Chinese
# 	'eng_Latn'   # English
# ]

languages = []
for family, langs in EXP2_CONFIG['languages'].items():
	languages.extend(langs)

languages = [LANGNAME2LANGCODE[lang] for lang in languages]
languages

['eng_Latn',
 'deu_Latn',
 'spa_Latn',
 'fra_Latn',
 'arb_Arab',
 'heb_Hebr',
 'rus_Cyrl',
 'slk_Latn',
 'tha_Thai',
 'lao_Laoo',
 'khm_Khmr',
 'vie_Latn',
 'swh_Latn',
 'xho_Latn',
 'zul_Latn',
 'urd_Arab',
 'hin_Deva',
 'kan_Knda',
 'tel_Telu',
 'jpn_Jpan',
 'kor_Hang',
 'tur_Latn',
 'azb_Arab',
 'azj_Latn',
 'tgl_Latn',
 'ceb_Latn',
 'ilo_Latn',
 'war_Latn',
 'yue_Hant',
 'zho_Hans',
 'ind_Latn',
 'zsm_Latn',
 'min_Latn',
 'min_Arab',
 'bjn_Latn',
 'bjn_Arab',
 'jav_Latn',
 'sun_Latn']

In [16]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Initialize empty torch tensor to hold all activations [text_id, layer_id, language, hidden_size]
activation_per_lang = torch.zeros((len(text_ids), num_layers + 1, len(languages), MODEL2HIDDEN_SIZE[model_name])).to(device)
print(f'Activation tensor shape: {activation_per_lang.shape}')

# # Reshape to [text_id * language, layer_id, hidden_size]
# activation_per_lang = activation_per_lang.view(-1, num_layers + 1, MODEL2HIDDEN_SIZE[model_name])
# print(f'Reshaped activation tensor shape: {activation_per_lang.shape}')

# # Initialize empty torch tensor to hold all labels [text_id, language]
# labels = torch.zeros((len(text_ids) * len(languages),), dtype=torch.long).to('cuda')
# print(f'Labels tensor shape: {labels.shape}')

Activation tensor shape: torch.Size([204, 35, 38, 2560])


In [11]:
activation_per_lang.shape

torch.Size([204, 35, 38, 2560])

In [None]:
# Load activations for all languages
id2langcode = {idx: lang for idx, lang in enumerate(languages)}
langcode2id = {lang: idx for idx, lang in enumerate(languages)}
text_idx = 0
for lang_idx, lang in tqdm(enumerate(languages), total=len(languages), desc='Loading activations for all languages'):
	for layer_id in range(-1, num_layers):
		if layer_id == -1:
			paths = sorted(glob.glob(f'outputs/topic_classification/{model_name}/{extraction_mode}/{lang}/*/{token_position}/layer_embed_tokens.pt'))
		else:
			paths = sorted(glob.glob(f'outputs/topic_classification/{model_name}/{extraction_mode}/{lang}/*/{token_position}/layer_{layer_id}.pt'))
		if len(paths) != len(text_ids):
			print(f"Warning: Expected {len(text_ids)} files for language '{lang}' at layer {layer_id}, but found {len(paths)} files.")
			break
		for path in paths:
			activation = torch.load(path)
			activation_per_lang[text_idx, layer_id + 1, lang_idx, :] = activation.to(device)
			text_idx += 1
		text_idx = 0

Loading activations for all languages: 100%|██████████| 38/38 [02:38<00:00,  4.17s/it]


In [13]:
import torch.nn.functional as F

feats_A = torch.randn(64, 8192)
feats_B = torch.randn(64, 8192)
feats_A = F.normalize(feats_A, dim=-1)
feats_B = F.normalize(feats_B, dim=-1)

# measure score
score = AlignmentMetrics.measure('cknna', feats_A, feats_B, topk=10)

# alternative
score = AlignmentMetrics.cknna(feats_A, feats_B, topk=10)

In [14]:
score

0.1209156326036149

In [35]:
# 4. Calculate the silhouette score on the GPU
# The function takes the data and the predicted labels as input.
cka_matrix = torch.zeros((num_layers + 1, len(languages), len(languages))).to(device)
for layer_id in tqdm(range(-1, num_layers)):
	# Calculate pairwise silhouette scores per language
	for lang_idx1 in range(len(languages)):
		for lang_idx2 in range(len(languages)):
			if lang_idx1 != lang_idx2:
				# Get activations for both languages
				activations_lang1 = activation_per_lang[:, layer_id + 1, lang_idx1, :]
				activations_lang2 = activation_per_lang[:, layer_id + 1, lang_idx2, :]

				# Normalize activations
				activations_lang1 = F.normalize(activations_lang1, dim=-1)
				activations_lang2 = F.normalize(activations_lang2, dim=-1)

				# Store score in matrix
				cka_matrix[layer_id + 1, lang_idx1, lang_idx2] = AlignmentMetrics.cka(activations_lang1, activations_lang2)
				# cka_matrix[layer_id + 1, lang_idx1, lang_idx2] = AlignmentMetrics.cknna(activations_lang1, activations_lang2, topk=10)

				# print(f"Wasserstein Distance between {id2langcode[lang_idx1]} and {id2langcode[lang_idx2]}: {score:.4f}")

# For this well-separated data, you should see a high score (e.g., > 0.8)
# A high score indicates that the clusters are dense and well-separated.

100%|██████████| 35/35 [00:34<00:00,  1.03it/s]


In [36]:
cka_matrix.shape

torch.Size([35, 38, 38])

In [37]:
# Get top k maximum distance values in each layer with their language pair identifiers

# Set k as a variable
k = 50  # You can change this value as needed

# Initialize lists to store results
all_max_distances = []
all_max_distance_pairs = []

# For each layer, find the top k maximum distances and corresponding language pairs
for layer_id in range(-1, num_layers):
	layer_distances = cka_matrix[layer_id + 1]
	
	# Create a mask to exclude diagonal elements and upper triangle (to avoid duplicates)
	mask = torch.tril(torch.ones_like(layer_distances, dtype=torch.bool), diagonal=-1)
	
	# Get all distances excluding diagonal and upper triangle, then flatten
	lower_triangle_distances = layer_distances[mask]
	
	# Get top k maximum distances
	topk_max_values, topk_indices = torch.topk(lower_triangle_distances, k=k, largest=True)
	
	# Convert flat indices back to 2D coordinates
	layer_max_distances = []
	layer_max_pairs = []
	
	# Get indices of lower triangle elements
	lower_i, lower_j = torch.where(mask)
	
	for idx in range(k):
		flat_idx = topk_indices[idx].item()
		lang_i_idx = lower_i[flat_idx].item()
		lang_j_idx = lower_j[flat_idx].item()
		
		layer_max_distances.append(topk_max_values[idx].item())
		layer_max_pairs.append((languages[lang_i_idx], languages[lang_j_idx]))
	
	all_max_distances.append(layer_max_distances)
	all_max_distance_pairs.append(layer_max_pairs)

# Create a detailed DataFrame
detailed_results = []
for layer_id in range(-1, num_layers):
	for rank in range(k):
		lang1, lang2 = all_max_distance_pairs[layer_id + 1][rank]
		detailed_results.append({
			'Layer': layer_id,
			'Rank': rank + 1,
			'Max_Distance': all_max_distances[layer_id + 1][rank],
			'Language_Pair': f"{LANGCODE2LANGNAME[lang1]} ({lang1}) - {LANGCODE2LANGNAME[lang2]} ({lang2})"
		})

results_df = pd.DataFrame(detailed_results)

print(f"Top {k} maximum distances per layer:")
print(results_df.to_string(index=False))

# Also print a summary view grouped by layer
print("\nSummary by layer:")
for layer_id in range(-1, num_layers):
	print(f"\nLayer {layer_id:2d}:")
	for rank in range(k):
		lang1, lang2 = all_max_distance_pairs[layer_id + 1][rank]
		distance = all_max_distances[layer_id + 1][rank]
		print(f"  {rank+1}. {distance:.6f} - {LANGCODE2LANGNAME[lang1]} ({lang1}) & {LANGCODE2LANGNAME[lang2]} ({lang2})")

Top 50 maximum distances per layer:
 Layer  Rank  Max_Distance                                                        Language_Pair
    -1     1      0.004707                               German (deu_Latn) - English (eng_Latn)
    -1     2      0.004707                              Spanish (spa_Latn) - English (eng_Latn)
    -1     3      0.004707                               Spanish (spa_Latn) - German (deu_Latn)
    -1     4      0.004707                               French (fra_Latn) - English (eng_Latn)
    -1     5      0.004707                                French (fra_Latn) - German (deu_Latn)
    -1     6      0.004707                               French (fra_Latn) - Spanish (spa_Latn)
    -1     7      0.004707                         MSA (Arabic) (arb_Arab) - English (eng_Latn)
    -1     8      0.004707                          MSA (Arabic) (arb_Arab) - German (deu_Latn)
    -1     9      0.004707                         MSA (Arabic) (arb_Arab) - Spanish (spa_Latn)
    

In [38]:
# Calculate mean and std deviation of the distances for each layer
layer_means = cka_matrix.mean(dim=(1, 2)).cpu().numpy()
layer_stds = cka_matrix.std(dim=(1, 2)).cpu().numpy()

# Create a DataFrame to display the statistics
layer_stats_df = pd.DataFrame({
	'Layer': range(-1, num_layers),
	'Mean_Distance': layer_means,
	'Std_Distance': layer_stds
})

print("Layer-wise Distance Statistics:")
print(layer_stats_df.to_string(index=False))

Layer-wise Distance Statistics:
 Layer  Mean_Distance  Std_Distance
    -1       0.004583      0.000754
     0       0.512486      0.135642
     1       0.454054      0.118735
     2       0.247477      0.109502
     3       0.277518      0.104709
     4       0.363834      0.117647
     5       0.449188      0.110776
     6       0.553071      0.146932
     7       0.559505      0.143927
     8       0.604105      0.150326
     9       0.603123      0.158799
    10       0.582137      0.157613
    11       0.539116      0.156634
    12       0.503287      0.161996
    13       0.470749      0.158534
    14       0.458775      0.156725
    15       0.469961      0.164584
    16       0.449368      0.171718
    17       0.455619      0.176292
    18       0.432363      0.179169
    19       0.427790      0.192119
    20       0.412117      0.197841
    21       0.403720      0.204201
    22       0.387681      0.200641
    23       0.379895      0.201768
    24       0.374747      0.204

In [40]:
lengths_per_family = {}
for family, langs in EXP2_CONFIG['languages'].items():
    lengths_per_family[family] = len(langs)

In [42]:
# Heatmap per layer of the silhouette scores between languages, make a subplot for each layer with 4 columns
import seaborn as sns
import matplotlib.pyplot as plt

# Find global min and max across all layers for consistent scale
global_min = float('inf')
global_max = float('-inf')
for layer_id in range(-1, num_layers):
	layer_distances = cka_matrix[layer_id + 1]
	global_min = min(global_min, layer_distances.min().item())
	global_max = max(global_max, layer_distances.max().item())

fig, axes = plt.subplots(nrows=(num_layers + 1) // 4 + 1, ncols=4, figsize=(20, 5 * ((num_layers + 1) // 4 + 1)))
axes = axes.flatten()
for layer_id in tqdm(range(0, num_layers)):
	layer_distances = cka_matrix[layer_id + 1]
	sns.heatmap(layer_distances.cpu().numpy(), 
		xticklabels=[LANGCODE2LANGNAME[lang] for lang in languages], 
		yticklabels=[LANGCODE2LANGNAME[lang] for lang in languages], 
		ax=axes[layer_id], 
		cmap='viridis',
		vmin=global_min,
		vmax=global_max,
		# annot=True,
		# fmt='.2f',
	)
	axes[layer_id].set_title(f'Layer {layer_id} Silhouette Scores')
	plt.setp(axes[layer_id].get_xticklabels(), rotation=90, ha='right', rotation_mode='anchor')

	# Setup fontsize of the x and y label
	axes[layer_id].tick_params(axis='x', labelsize=5)
	axes[layer_id].tick_params(axis='y', labelsize=5)
	
	# Add bolded grid lines to separate language families
	# Calculate cumulative positions for family boundaries
	cumulative_langs = 0
	for family, num_langs in lengths_per_family.items():
		cumulative_langs += num_langs
		# Draw horizontal and vertical lines at family boundaries
		axes[layer_id].axhline(cumulative_langs, color='black', linewidth=1.5)
		axes[layer_id].axvline(cumulative_langs, color='black', linewidth=1.5)

	# if layer_id + 1 == 3:
	# 	break
	
plt.tight_layout()
plt.savefig(f'cka_heatmap_{model_name}.png', dpi=300, bbox_inches='tight')
plt.close()

100%|██████████| 34/34 [01:22<00:00,  2.41s/it]


In [None]:
import seaborn as sns

# Make 5 bins of the silhouette scores and color the heatmap accordingly
import matplotlib.pyplot as plt

# Find global min and max across all layers for consistent scale
global_min = float('inf')
global_max = float('-inf')
for layer_id in range(-1, num_layers):
	layer_distances = cka_matrix[layer_id + 1]
	# Exclude diagonal elements (which are 0)
	mask_no_diag = ~torch.eye(len(languages), dtype=torch.bool, device='cuda')
	non_diag_values = layer_distances[mask_no_diag]
	global_min = min(global_min, non_diag_values.min().item())
	global_max = max(global_max, non_diag_values.max().item())

# Create 3 bins
n_bins = 3
bin_edges = np.linspace(global_min, global_max, n_bins + 1)

# Create subplots
fig, axes = plt.subplots(nrows=(num_layers + 1) // 4 + 1, ncols=4, figsize=(20, 5 * ((num_layers + 1) // 4 + 1)))
axes = axes.flatten()

for layer_id in tqdm(range(0, num_layers)):
	layer_distances = cka_matrix[layer_id + 1].cpu().numpy()
	
	# Discretize the distances into bins
	binned_distances = np.digitize(layer_distances, bin_edges) - 1
	binned_distances = np.clip(binned_distances, 0, n_bins - 1)
	
	sns.heatmap(binned_distances, 
		xticklabels=[LANGCODE2LANGNAME[lang] for lang in languages], 
		yticklabels=[LANGCODE2LANGNAME[lang] for lang in languages], 
		ax=axes[layer_id], 
		cmap='viridis',
		vmin=0,
		vmax=n_bins - 1,
		cbar_kws={'label': 'Bin', 'ticks': range(n_bins)}
	)
	axes[layer_id].set_title(f'Layer {layer_id} Binned Silhouette Scores')
	plt.setp(axes[layer_id].get_xticklabels(), rotation=90, ha='right', rotation_mode='anchor')

	# Setup fontsize of the x and y label
	axes[layer_id].tick_params(axis='x', labelsize=5)
	axes[layer_id].tick_params(axis='y', labelsize=5)
	
	# Add bolded grid lines to separate language families
	cumulative_langs = 0
	for family, num_langs in lengths_per_family.items():
		cumulative_langs += num_langs
		axes[layer_id].axhline(cumulative_langs, color='black', linewidth=1.5)
		axes[layer_id].axvline(cumulative_langs, color='black', linewidth=1.5)
	
plt.tight_layout()
plt.savefig(f'silhouette_scores_binned{n_bins}_heatmap_{model_name}.png', dpi=300, bbox_inches='tight')
plt.close()

print(f"Bin edges: {bin_edges}")

In [None]:
import numpy as np
from sklearn.cluster import AgglomerativeClustering
import matplotlib.pyplot as plt
from scipy.cluster.hierarchy import dendrogram, linkage

# Create subplots for dendrograms across all layers
fig, axes = plt.subplots(nrows=9, ncols=5, figsize=(25, 80))
fig.suptitle('Hierarchical Clustering Dendrograms Across All Layers', fontsize=16)

# Flatten axes for easier indexing
axes = axes.flatten()

# First pass: find the maximum distance across all layers
max_distance = 0
linkage_results = []
for layer_id in range(-1, num_layers):
	linked = linkage(silhouette_score_matrix[layer_id + 1].cpu().numpy(), 'complete')
	linkage_results.append(linked)
	max_distance = max(max_distance, linked[:, 2].max())

# Second pass: create dendrograms with consistent scale
for layer_id in range(-1, num_layers):
	ax_idx = layer_id + 1
	
	# Create dendrogram in the corresponding subplot
	dendrogram(linkage_results[ax_idx], 
			   labels=[LANGCODE2LANGNAME[lang] for lang in languages], 
			   orientation='right',
			   ax=axes[ax_idx])
	
	axes[ax_idx].set_title(f"Layer {layer_id}", fontsize=12)
	axes[ax_idx].tick_params(axis='y', labelsize=8)
	axes[ax_idx].tick_params(axis='x', labelsize=8)
	axes[ax_idx].set_xlim([0, max_distance])  # Set consistent x-axis limits

plt.tight_layout()
plt.show()


In [None]:
import numpy as np
from sklearn.cluster import AgglomerativeClustering
import matplotlib.pyplot as plt
from scipy.cluster.hierarchy import dendrogram, linkage
# Create subplots for dendrograms across all layers
fig, axes = plt.subplots(nrows=9, ncols=5, figsize=(25, 80))
fig.suptitle('Hierarchical Clustering Dendrograms Across All Layers', fontsize=16)

# Flatten axes for easier indexing
axes = axes.flatten()

for layer_id in range(-1, num_layers):
	ax_idx = layer_id + 1
	
	# Perform clustering for this layer
	linked = linkage(silhouette_score_matrix[layer_id + 1].cpu().numpy(), 'complete')
	
	# Create dendrogram in the corresponding subplot
	dendrogram(linked, 
			   labels=[LANGCODE2LANGNAME[lang] for lang in languages], 
			   orientation='right',
			   ax=axes[ax_idx])
	
	axes[ax_idx].set_title(f"Layer {layer_id}", fontsize=12)
	axes[ax_idx].tick_params(axis='y', labelsize=8)
	axes[ax_idx].tick_params(axis='x', labelsize=8)

plt.tight_layout()
plt.show()
