In [1]:
import os
import sys

# Change to parent directory
os.chdir('..')
# Add current directory to path
sys.path.append('.')

In [2]:
import pandas as pd
import numpy as np
from src.utils.const import LANGCODE2LANGNAME, LANGNAME2LANGCODE, MODEL2HIDDEN_SIZE, MODEL2NUM_LAYERS, EXP2_CONFIG
import glob
import torch
from tqdm import tqdm

In [3]:
model_name = 'gemma-3-4b-it' # 1152 size for 1b, 2560 for 4b
model_to_num_layers = {
    'gemma-3-1b-it': 26,
	'gemma-3-4b-it': 34,
    'gemma-3-270m-it': 18,
    'gemma-2-9b-it': 42
}
model_to_hidden_size = {
	'gemma-3-1b-it': 1152,
	'gemma-3-4b-it': 2560,
	'gemma-3-270m-it': 640,
	'gemma-2-9b-it': 3584
}
num_layers = model_to_num_layers[model_name]
extraction_mode = 'raw'
token_position = 'last_token'
languages = glob.glob('outputs/topic_classification/gemma-3-4b-it/raw/*')
languages = sorted([lang.split('/')[-1] for lang in languages])
text_ids = glob.glob('outputs/topic_classification/gemma-3-4b-it/raw/ace_Latn/*')
text_ids = [text_id.split('/')[-1].split('.')[0] for text_id in text_ids]

In [4]:
# # Extract language codes from the comment and add eng_Latn
# languages = [
# 	'vie_Latn',  # Vietnamese
# 	'ind_Latn',  # Indonesian
# 	'tha_Thai',  # Thai
# 	'zsm_Latn',  # Malay
# 	'mya_Mymr',  # Burmese -> 0 problem
# 	'tgl_Latn',  # Tagalog
# 	'khm_Khmr',  # Khmer
# 	'ceb_Latn',  # Cebuano
# 	'lao_Laoo',  # Lao
# 	'jav_Latn',  # Javanese
# 	'war_Latn',  # Waray
# 	'sun_Latn',  # Sundanese
# 	'ilo_Latn',  # Ilocano
# 	'tam_Taml',  # Tamil
# 	'zho_Hans',  # Chinese
# 	'eng_Latn'   # English
# ]

languages = []
for family, langs in EXP2_CONFIG['languages'].items():
	languages.extend(langs)

languages = [LANGNAME2LANGCODE[lang] for lang in languages]
languages

['eng_Latn',
 'deu_Latn',
 'spa_Latn',
 'fra_Latn',
 'arb_Arab',
 'heb_Hebr',
 'rus_Cyrl',
 'slk_Latn',
 'tha_Thai',
 'lao_Laoo',
 'khm_Khmr',
 'vie_Latn',
 'swh_Latn',
 'xho_Latn',
 'zul_Latn',
 'urd_Arab',
 'hin_Deva',
 'kan_Knda',
 'tel_Telu',
 'jpn_Jpan',
 'kor_Hang',
 'tur_Latn',
 'azb_Arab',
 'azj_Latn',
 'tgl_Latn',
 'ceb_Latn',
 'ilo_Latn',
 'war_Latn',
 'yue_Hant',
 'zho_Hans',
 'ind_Latn',
 'zsm_Latn',
 'min_Latn',
 'min_Arab',
 'bjn_Latn',
 'bjn_Arab',
 'jav_Latn',
 'sun_Latn']

In [5]:
# Check if a GPU is available and set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [6]:
# Initialize empty torch tensor to hold all activations [text_id, layer_id, language, hidden_size]
activation_per_lang = torch.zeros((len(text_ids), num_layers + 1, len(languages), MODEL2HIDDEN_SIZE[model_name])).to(device)
print(f'Activation tensor shape: {activation_per_lang.shape}')

# Reshape to [text_id * language, layer_id, hidden_size]
activation_per_lang = activation_per_lang.view(-1, num_layers + 1, MODEL2HIDDEN_SIZE[model_name])
print(f'Reshaped activation tensor shape: {activation_per_lang.shape}')

# Initialize empty torch tensor to hold all labels [text_id, language]
labels = torch.zeros((len(text_ids) * len(languages),), dtype=torch.long).to(device)
print(f'Labels tensor shape: {labels.shape}')

Activation tensor shape: torch.Size([204, 35, 38, 2560])
Reshaped activation tensor shape: torch.Size([7752, 35, 2560])
Labels tensor shape: torch.Size([7752])


In [7]:
activation_per_lang.shape

torch.Size([7752, 35, 2560])

In [8]:
# # Load activations for all languages
# id2langcode = {idx: lang for idx, lang in enumerate(languages)}
# langcode2id = {lang: idx for idx, lang in enumerate(languages)}
# text_idx = 0
# for lang_idx, lang in tqdm(enumerate(languages), total=len(languages), desc='Loading activations for all languages'):
# 	for layer_id in range(-1, num_layers):
# 		if layer_id == -1:
# 			paths = sorted(glob.glob(f'outputs/topic_classification/{model_name}/{extraction_mode}/{lang}/*/{token_position}/layer_embed_tokens.pt'))
# 		else:
# 			paths = sorted(glob.glob(f'outputs/topic_classification/{model_name}/{extraction_mode}/{lang}/*/{token_position}/layer_{layer_id}.pt'))
# 		if len(paths) != len(text_ids):
# 			print(f"Warning: Expected {len(text_ids)} files for language '{lang}' at layer {layer_id}, but found {len(paths)} files.")
# 			break
# 		for path in paths:
# 			activation = torch.load(path)
# 			activation_per_lang[text_idx, layer_id + 1, lang_idx, :] = activation.to('cuda')
# 			text_idx += 1
# 		text_idx = 0

In [9]:
# Load activations for all languages
text_idx = 0
id2langcode = {idx: lang for idx, lang in enumerate(languages)}
langcode2id = {lang: idx for idx, lang in enumerate(languages)}

for lang_idx, lang in tqdm(enumerate(languages), total=len(languages), desc='Loading activations for all languages'):
	for layer_id in range(-1, num_layers):
		if layer_id == -1:
			paths = sorted(glob.glob(f'outputs/topic_classification/{model_name}/{extraction_mode}/{lang}/*/{token_position}/layer_embed_tokens.pt'))
		else:
			paths = sorted(glob.glob(f'outputs/topic_classification/{model_name}/{extraction_mode}/{lang}/*/{token_position}/layer_{layer_id}.pt'))
		if len(paths) != len(text_ids):
			print(f"Warning: Expected {len(text_ids)} files for language '{lang}' at layer {layer_id}, but found {len(paths)} files.")
			break
		for text_idx_inner, path in enumerate(paths):
			activation = torch.load(path)
			# Calculate the flat index for the reshaped tensor
			flat_idx = text_idx_inner * len(languages) + lang_idx
			activation_per_lang[flat_idx, layer_id + 1, :] = activation.to(device)
			# Populate labels with language_id (only need to do this once per text-language pair)
			if layer_id == -1:
				labels[flat_idx] = lang_idx


Loading activations for all languages: 100%|██████████| 38/38 [02:04<00:00,  3.28s/it]


In [10]:
activation_per_lang.shape

torch.Size([7752, 35, 2560])

In [11]:
labels.shape

torch.Size([7752])

In [12]:
activation_per_lang[:, layer_id + 1, :].shape

torch.Size([7752, 2560])

### JS Div

In [None]:
import torch
import torch.nn.functional as F

def js_divergence_pytorch(p, q, reduction='sum'):
    """
    Calculates the Jensen-Shannon Divergence between two probability distributions
    using PyTorch.
    
    Args:
        p (torch.Tensor): The first probability distribution.
        q (torch.Tensor): The second probability distribution.
        reduction (str): Specifies the reduction to apply to the output: 
                         'none' | 'batchmean' | 'sum' | 'mean'.
                         Default: 'sum'.
    """
    # Calculate the midpoint distribution
    m = 0.5 * (p + q)
    
    # Calculate the KL divergences. F.kl_div expects log-probabilities for the input.
    # The target (second argument) should be probabilities.
    kl_p_m = F.kl_div(m.log(), p, reduction=reduction)
    kl_q_m = F.kl_div(m.log(), q, reduction=reduction)
    
    # Calculate the JS Divergence
    jsd = 0.5 * (kl_p_m + kl_q_m)
    
    return jsd



# You can also compute it for a batch of distributions
p_batch = torch.tensor([[0.1, 0.9], [0.8, 0.2]], device=device)
q_batch = torch.tensor([[0.2, 0.8], [0.7, 0.3]], device=device)

# 'batchmean' averages the divergence over the batch dimension
jsd_batch = js_divergence_pytorch(p_batch, q_batch, reduction='batchmean')
print(f"JS Divergence for batch (PyTorch, GPU): {jsd_batch.item():.4f}")

In [None]:
# 4. Calculate the silhouette score on the GPU
# The function takes the data and the predicted labels as input.
js_matrix = torch.zeros((num_layers + 1, len(languages), len(languages))).to(device)
for layer_id in tqdm(range(-1, num_layers)):
	# Calculate pairwise silhouette scores per language
	for lang_idx1 in range(len(languages)):
		for lang_idx2 in range(len(languages)):
			if lang_idx1 != lang_idx2:
				# Take all index of labels that have value lang_idx1 and lang_idx2
				indexes_lang1 = (labels == lang_idx1).nonzero(as_tuple=True)[0]
				indexes_lang2 = (labels == lang_idx2).nonzero(as_tuple=True)[0]

				# Take activations for those indexes
				activations_lang1 = activation_per_lang[indexes_lang1, layer_id + 1, :]
				activations_lang2 = activation_per_lang[indexes_lang2, layer_id + 1, :]
				labels_lang1 = labels[indexes_lang1]
				labels_lang2 = labels[indexes_lang2]

				# Calculate JS Div between the two sets of activations
				score = js_divergence_pytorch(activations_lang1, activations_lang2, reduction='batchmean')
				
				# Store score in matrix
				js_matrix[layer_id + 1, lang_idx1, lang_idx2] = score

In [None]:
js_matrix[1, :, :]

In [None]:
# Get top k minimum distance values in each layer with their language pair identifiers

# Set k as a variable
k = 50  # You can change this value as needed

# Initialize lists to store results
all_min_distances = []
all_min_distance_pairs = []

# For each layer, find the top k minimum distances and corresponding language pairs
for layer_id in range(-1, num_layers):
	layer_distances = js_matrix[layer_id + 1]
	
	# Create a mask to exclude diagonal elements and upper triangle (to avoid duplicates)
	mask = torch.tril(torch.ones_like(layer_distances, dtype=torch.bool), diagonal=-1)
	
	# Get all distances excluding diagonal and upper triangle, then flatten
	lower_triangle_distances = layer_distances[mask]
	
	# Get top k minimum distances
	topk_min_values, topk_indices = torch.topk(lower_triangle_distances, k=k, largest=False)
	
	# Convert flat indices back to 2D coordinates
	layer_min_distances = []
	layer_min_pairs = []
	
	# Get indices of lower triangle elements
	lower_i, lower_j = torch.where(mask)
	
	for idx in range(k):
		flat_idx = topk_indices[idx].item()
		lang_i_idx = lower_i[flat_idx].item()
		lang_j_idx = lower_j[flat_idx].item()
		
		layer_min_distances.append(topk_min_values[idx].item())
		layer_min_pairs.append((languages[lang_i_idx], languages[lang_j_idx]))
	
	all_min_distances.append(layer_min_distances)
	all_min_distance_pairs.append(layer_min_pairs)

# Create a detailed DataFrame
detailed_results = []
for layer_id in range(-1, num_layers):
	for rank in range(k):
		lang1, lang2 = all_min_distance_pairs[layer_id + 1][rank]
		detailed_results.append({
			'Layer': layer_id,
			'Rank': rank + 1,
			'Min_Distance': all_min_distances[layer_id + 1][rank],
			'Language_Pair': f"{LANGCODE2LANGNAME[lang1]} ({lang1}) - {LANGCODE2LANGNAME[lang2]} ({lang2})"
		})

results_df = pd.DataFrame(detailed_results)

print(f"Top {k} minimum distances per layer:")
print(results_df.to_string(index=False))

# Also print a summary view grouped by layer
print("\nSummary by layer:")
for layer_id in range(-1, num_layers):
	print(f"\nLayer {layer_id:2d}:")
	for rank in range(k):
		lang1, lang2 = all_min_distance_pairs[layer_id + 1][rank]
		distance = all_min_distances[layer_id + 1][rank]
		print(f"  {rank+1}. {distance:.6f} - {LANGCODE2LANGNAME[lang1]} ({lang1}) & {LANGCODE2LANGNAME[lang2]} ({lang2})")

### Wasserstein

In [13]:
import numpy as np
from scipy.stats import wasserstein_distance

# These arrays represent the *values* at which the probability masses are located.
# Think of these as the positions of the "dirt piles".
values_p = np.array([0, 1, 3])
values_q = np.array([5, 6, 8])

# These arrays represent the *weights* or probabilities at each value.
# They must sum to 1.
weights_p = np.array([0.2, 0.5, 0.3])
weights_q = np.array([0.2, 0.5, 0.3])

# Calculate the 1st Wasserstein distance (EMD)
# It calculates the cost to transform distribution p into distribution q
dist = wasserstein_distance(values_p, values_q)

print(f"Wasserstein Distance (SciPy): {dist:.4f}")
# The output is 5.0, because each chunk of "dirt" has to move 5 units to the right.
# (0.2 * (5-0)) + (0.5 * (6-1)) + (0.3 * (8-3)) = 1.0 + 2.5 + 1.5 = 5.0

Wasserstein Distance (SciPy): 5.0000


In [14]:
import ot
# 4. Calculate the silhouette score on the GPU
# The function takes the data and the predicted labels as input.
wasserstein_distance_matrix = torch.zeros((num_layers + 1, len(languages), len(languages))).to(device)
for layer_id in tqdm(range(-1, num_layers)):
	# Calculate pairwise silhouette scores per language
	for lang_idx1 in range(len(languages)):
		for lang_idx2 in range(len(languages)):
			if lang_idx1 != lang_idx2:
				# Take all index of labels that have value lang_idx1 and lang_idx2
				indexes_lang1 = (labels == lang_idx1).nonzero(as_tuple=True)[0]
				indexes_lang2 = (labels == lang_idx2).nonzero(as_tuple=True)[0]

				# Take activations for those indexes
				activations_lang1 = activation_per_lang[indexes_lang1, layer_id + 1, :]
				activations_lang2 = activation_per_lang[indexes_lang2, layer_id + 1, :]
				labels_lang1 = labels[indexes_lang1]
				labels_lang2 = labels[indexes_lang2]

				# Uniform weights for both distributions
				activation1_weights = torch.ones(activations_lang1.shape[0],).to(device) / activations_lang1.shape[0]
				activation2_weights = torch.ones(activations_lang2.shape[0],).to(device) / activations_lang2.shape[0]

				# Cost matrix is also computed on the GPU
				cost_matrix = ot.dist(activations_lang1, activations_lang2)

				# EMD calculation runs on the GPU
				emd_distance_gpu = ot.emd2(activation1_weights, activation2_weights, cost_matrix)
				
				# Store score in matrix
				wasserstein_distance_matrix[layer_id + 1, lang_idx1, lang_idx2] = emd_distance_gpu

				# print(f"Wasserstein Distance between {id2langcode[lang_idx1]} and {id2langcode[lang_idx2]}: {score:.4f}")

# For this well-separated data, you should see a high score (e.g., > 0.8)
# A high score indicates that the clusters are dense and well-separated.

100%|██████████| 35/35 [07:07<00:00, 12.21s/it]


In [17]:
# Get top k minimum distance values in each layer with their language pair identifiers

# Set k as a variable
k = 50  # You can change this value as needed

# Initialize lists to store results
all_min_distances = []
all_min_distance_pairs = []

# For each layer, find the top k minimum distances and corresponding language pairs
for layer_id in range(-1, num_layers):
	layer_distances = wasserstein_distance_matrix[layer_id + 1]
	
	# Create a mask to exclude diagonal elements and upper triangle (to avoid duplicates)
	mask = torch.tril(torch.ones_like(layer_distances, dtype=torch.bool), diagonal=-1)
	
	# Get all distances excluding diagonal and upper triangle, then flatten
	lower_triangle_distances = layer_distances[mask]
	
	# Get top k minimum distances
	topk_min_values, topk_indices = torch.topk(lower_triangle_distances, k=k, largest=False)
	
	# Convert flat indices back to 2D coordinates
	layer_min_distances = []
	layer_min_pairs = []
	
	# Get indices of lower triangle elements
	lower_i, lower_j = torch.where(mask)
	
	for idx in range(k):
		flat_idx = topk_indices[idx].item()
		lang_i_idx = lower_i[flat_idx].item()
		lang_j_idx = lower_j[flat_idx].item()
		
		layer_min_distances.append(topk_min_values[idx].item())
		layer_min_pairs.append((languages[lang_i_idx], languages[lang_j_idx]))
	
	all_min_distances.append(layer_min_distances)
	all_min_distance_pairs.append(layer_min_pairs)

# Create a detailed DataFrame
detailed_results = []
for layer_id in range(-1, num_layers):
	for rank in range(k):
		lang1, lang2 = all_min_distance_pairs[layer_id + 1][rank]
		detailed_results.append({
			'Layer': layer_id,
			'Rank': rank + 1,
			'Min_Distance': all_min_distances[layer_id + 1][rank],
			'Language_Pair': f"{LANGCODE2LANGNAME[lang1]} ({lang1}) - {LANGCODE2LANGNAME[lang2]} ({lang2})"
		})

results_df = pd.DataFrame(detailed_results)

print(f"Top {k} minimum distances per layer:")
print(results_df.to_string(index=False))

# Also print a summary view grouped by layer
print("\nSummary by layer:")
for layer_id in range(-1, num_layers):
	print(f"\nLayer {layer_id:2d}:")
	for rank in range(k):
		lang1, lang2 = all_min_distance_pairs[layer_id + 1][rank]
		distance = all_min_distances[layer_id + 1][rank]
		print(f"  {rank+1}. {distance:.6f} - {LANGCODE2LANGNAME[lang1]} ({lang1}) & {LANGCODE2LANGNAME[lang2]} ({lang2})")

Top 50 minimum distances per layer:
 Layer  Rank  Min_Distance                                                        Language_Pair
    -1     1  2.929688e-03                               German (deu_Latn) - English (eng_Latn)
    -1     2  2.929688e-03                              Spanish (spa_Latn) - English (eng_Latn)
    -1     3  2.929688e-03                               Spanish (spa_Latn) - German (deu_Latn)
    -1     4  2.929688e-03                               French (fra_Latn) - English (eng_Latn)
    -1     5  2.929688e-03                                French (fra_Latn) - German (deu_Latn)
    -1     6  2.929688e-03                               French (fra_Latn) - Spanish (spa_Latn)
    -1     7  2.929688e-03                         MSA (Arabic) (arb_Arab) - English (eng_Latn)
    -1     8  2.929688e-03                          MSA (Arabic) (arb_Arab) - German (deu_Latn)
    -1     9  2.929688e-03                         MSA (Arabic) (arb_Arab) - Spanish (spa_Latn)
    

In [16]:
# Calculate mean and std deviation of the distances for each layer
layer_means = wasserstein_distance_matrix.mean(dim=(1, 2)).cpu().numpy()
layer_stds = wasserstein_distance_matrix.std(dim=(1, 2)).cpu().numpy()

# Create a DataFrame to display the statistics
layer_stats_df = pd.DataFrame({
	'Layer': range(-1, num_layers),
	'Mean_Distance': layer_means,
	'Std_Distance': layer_stds
})

print("Layer-wise Distance Statistics:")
print(layer_stats_df.to_string(index=False))

Layer-wise Distance Statistics:
 Layer  Mean_Distance  Std_Distance
    -1   2.852590e-03  4.691254e-04
     0   6.386620e+02  5.423558e+02
     1   1.813951e+03  1.223371e+03
     2   3.709408e+03  2.356150e+03
     3   5.247855e+03  2.999606e+03
     4   1.030571e+04  7.206918e+03
     5   1.445517e+04  9.123320e+03
     6   3.149641e+04  2.058406e+04
     7   5.193120e+04  3.755681e+04
     8   1.039401e+05  1.072516e+05
     9   1.236785e+05  1.112519e+05
    10   5.913339e+05  7.385030e+05
    11   1.273864e+06  1.538876e+06
    12   1.227095e+06  1.536666e+06
    13   9.167796e+05  9.011554e+05
    14   1.501129e+06  1.844294e+06
    15   2.085146e+06  2.176387e+06
    16   3.144592e+06  3.231702e+06
    17   7.488772e+06  6.520528e+06
    18   9.458525e+06  7.799436e+06
    19   1.195854e+07  8.861363e+06
    20   1.440090e+07  9.155014e+06
    21   1.890945e+07  1.163127e+07
    22   2.342922e+07  1.380477e+07
    23   3.219347e+07  2.115258e+07
    24   6.297903e+07  2.924195e

In [None]:
import numpy as np
from sklearn.cluster import AgglomerativeClustering
import matplotlib.pyplot as plt
from scipy.cluster.hierarchy import dendrogram, linkage
# Create subplots for dendrograms across all layers
fig, axes = plt.subplots(nrows=9, ncols=5, figsize=(25, 80))
fig.suptitle('Hierarchical Clustering Dendrograms Across All Layers', fontsize=16)

# Flatten axes for easier indexing
axes = axes.flatten()

for layer_id in range(-1, num_layers):
	ax_idx = layer_id + 1
	
	# Perform clustering for this layer
	linked = linkage(wasserstein_distance_matrix[layer_id + 1].cpu().numpy(), 'complete')
	
	# Create dendrogram in the corresponding subplot
	dendrogram(linked, 
			   labels=[LANGCODE2LANGNAME[lang] for lang in languages], 
			   orientation='right',
			   ax=axes[ax_idx])
	
	axes[ax_idx].set_title(f"Layer {layer_id}", fontsize=12)
	axes[ax_idx].tick_params(axis='y', labelsize=8)
	axes[ax_idx].tick_params(axis='x', labelsize=8)

plt.tight_layout()
plt.show()


In [None]:
import numpy as np
from sklearn.cluster import AgglomerativeClustering
import matplotlib.pyplot as plt
from scipy.cluster.hierarchy import dendrogram, linkage

# Create subplots for dendrograms across all layers
fig, axes = plt.subplots(nrows=9, ncols=5, figsize=(25, 80))
fig.suptitle('Hierarchical Clustering Dendrograms Across All Layers', fontsize=16)

# Flatten axes for easier indexing
axes = axes.flatten()

# First pass: find the maximum distance across all layers
max_distance = 0
linkage_results = []
for layer_id in range(-1, num_layers):
	linked = linkage(wasserstein_distance_matrix[layer_id + 1].cpu().numpy(), 'complete')
	linkage_results.append(linked)
	max_distance = max(max_distance, linked[:, 2].max())

# Second pass: create dendrograms with consistent scale
for layer_id in range(-1, num_layers):
	ax_idx = layer_id + 1
	
	# Create dendrogram in the corresponding subplot
	dendrogram(linkage_results[ax_idx], 
			   labels=[LANGCODE2LANGNAME[lang] for lang in languages], 
			   orientation='right',
			   ax=axes[ax_idx])
	
	axes[ax_idx].set_title(f"Layer {layer_id}", fontsize=12)
	axes[ax_idx].tick_params(axis='y', labelsize=8)
	axes[ax_idx].tick_params(axis='x', labelsize=8)
	# axes[ax_idx].set_xlim([0, max_distance])  # Set consistent x-axis limits

plt.tight_layout()
plt.show()
