In [1]:
import os
import sys

# Change to parent directory
os.chdir('..')
# Add current directory to path
sys.path.append('.')

In [2]:
import pandas as pd
import numpy as np
from src.utils.const import LANGCODE2LANGNAME, LANGNAME2LANGCODE, MODEL2HIDDEN_SIZE, MODEL2NUM_LAYERS, EXP2_CONFIG, EXP3_CONFIG, EXP4_CONFIG, MODEL2HF_NAME
import glob
import torch
from tqdm import tqdm

In [3]:
model_names = glob.glob('outputs_silhouette/exp4/next_token/dev/*')
model_names = [mn.split('/')[-1] for mn in model_names]
model_names.sort()
print(model_names)
model_name = model_names[0]  # Change index to select different models

['gemma-3-12b-it']


In [4]:
languages = []
for family, langs in EXP4_CONFIG['languages'].items():
	languages.extend(langs)

languages = [LANGNAME2LANGCODE[lang] for lang in languages]
languages

['eng_Latn',
 'deu_Latn',
 'nld_Latn',
 'swe_Latn',
 'nob_Latn',
 'isl_Latn',
 'spa_Latn',
 'fra_Latn',
 'ita_Latn',
 'por_Latn',
 'ron_Latn',
 'arb_Arab',
 'heb_Hebr',
 'amh_Ethi',
 'zgh_Tfng',
 'taq_Latn',
 'taq_Tfng',
 'hau_Latn',
 'som_Latn',
 'gaz_Latn',
 'fin_Latn',
 'hun_Latn',
 'ekk_Latn',
 'rus_Cyrl',
 'ukr_Cyrl',
 'srp_Cyrl',
 'bul_Cyrl',
 'slk_Latn',
 'pol_Latn',
 'ces_Latn',
 'lit_Latn',
 'lvs_Latn',
 'tha_Thai',
 'lao_Laoo',
 'khm_Khmr',
 'vie_Latn',
 'yor_Latn',
 'ibo_Latn',
 'swh_Latn',
 'xho_Latn',
 'zul_Latn',
 'urd_Arab',
 'hin_Deva',
 'ben_Beng',
 'mar_Deva',
 'tam_Taml',
 'kan_Knda',
 'tel_Telu',
 'jpn_Jpan',
 'kor_Hang',
 'tur_Latn',
 'azb_Arab',
 'azj_Latn',
 'pes_Arab',
 'kmr_Latn',
 'pbt_Arab',
 'fil_Latn',
 'ceb_Latn',
 'ilo_Latn',
 'war_Latn',
 'cmn_Hans',
 'yue_Hant',
 'cmn_Hant',
 'wuu_Hans',
 'ind_Latn',
 'zsm_Latn',
 'min_Latn',
 'min_Arab',
 'bjn_Latn',
 'bjn_Arab',
 'jav_Latn',
 'sun_Latn',
 'quy_Latn',
 'gug_Latn',
 'ayr_Latn',
 'fij_Latn',
 'mri_Latn',

In [5]:
def get_google_sheet(sheet_id: str, sheet_gid: str) -> pd.DataFrame:
	"""
	Downloads a specific sheet from a Google Sheet into a pandas DataFrame.

	Args:
		sheet_id: The ID of the Google Sheet.
		sheet_gid: The GID of the specific sheet to download.

	Returns:
		A pandas DataFrame containing the data from the specified sheet.
	"""
	url = f'https://docs.google.com/spreadsheets/d/{sheet_id}/export?format=csv&gid={sheet_gid}'
	df = pd.read_csv(url)
	return df
google_sheet_id = '1CmhOZeYTbfePLI2-rMubJpnKHuS6RLEfHYZD6-rVQ0M'  # Replace with your actual Google Sheet ID
gid = '0'  # Replace with the actual GID for the Indonesian sheet
corrected_data = []
try:
	df_lang = get_google_sheet(google_sheet_id, gid)
except Exception as e:
	print(f"An error occurred: {e}")
	print("Please ensure the Google Sheet is shared correctly and the IDs are correct.")

In [6]:
df_lang

Unnamed: 0,Language name,Language code,Joshi’s class,Language family,Language sub-family,Language sub-sub-family,Region,Syntax
0,Amharic,amh_Ethi,2,Afro-Asiatic,Semitic,Amharic-Argobba,Africa,SOV
1,Moroccan Arabic,ary_Arab,-,Afro-Asiatic,Semitic,North African Arabic,Africa,VSO
2,Egyptian Arabic,arz_Arab,3,Afro-Asiatic,Semitic,Egyptic Arabic,Africa,SVO
3,South Levantine Arabic,ajp_Arab,-,Afro-Asiatic,Semitic,Levantine-Cypriot Arabic,Asia 1,VSO
4,North Levantine Arabic,apc_Arab,-,Afro-Asiatic,Semitic,Levantine-Cypriot Arabic,Asia 1,VSO
...,...,...,...,...,...,...,...,...
107,North Azerbaijani,azj_Latn,1,Turkic,Oghuz,Central Oghuz,Asia 1,
108,Turkmen,tuk_Latn,1,Turkic,Oghuz,East Oghuz,Asia 1,SOV
109,Turkish,tur_Latn,4,Turkic,Oghuz,West Oghuz,Asia 1,SOV
110,Uyghur,uig_Arab,1,Turkic,Turkestan,Uyghuric,Asia 1,SOV


In [7]:
df_lang['Syntax'] = df_lang['Syntax'].fillna('Unknown')

In [8]:
df_lang['script'] = df_lang['Language code'].apply(lambda x: x.split('_')[1])

In [9]:
df_lang['Language sub-sub-family'] = df_lang['Language sub-sub-family'].fillna(df_lang['Language sub-family'])
df_lang['Language sub-sub-family'] = df_lang['Language sub-sub-family'].fillna(df_lang['Language family'])

In [10]:
lengths_per_category = {}
for family, langs in EXP4_CONFIG['languages'].items():
	lengths_per_category[family] = len(langs)

In [11]:
lang_to_subsubfamily = {}
for lang in languages:
	lang_name = LANGCODE2LANGNAME[lang]
	family_row = df_lang[df_lang['Language name'] == lang_name]
	if not family_row.empty:
		family = family_row.iloc[0]['Language sub-sub-family']
		lang_to_subsubfamily[lang_name] = family
	else:
		print(f"Warning: Language {lang_name} not found in the DataFrame.")
		lang_to_subsubfamily[lang_name] = 'Unknown'

In [12]:
lang_to_region = {}
for lang in languages:
	lang_name = LANGCODE2LANGNAME[lang]
	region_row = df_lang[df_lang['Language name'] == lang_name]
	if not region_row.empty:
		region = region_row.iloc[0]['Region']
		lang_to_region[lang_name] = region
	else:
		print(f"Warning: Language {lang_name} not found in the DataFrame.")
		lang_to_region[lang_name] = 'Unknown'

In [13]:
lang_to_script = {}
for lang in languages:
	lang_name = LANGCODE2LANGNAME[lang]
	script_row = df_lang[df_lang['Language name'] == lang_name]
	if not script_row.empty:
		script = script_row.iloc[0]['script']
		lang_to_script[lang_name] = script
	else:
		print(f"Warning: Language {lang_name} not found in the DataFrame.")
		lang_to_script[lang_name] = 'Unknown'

In [14]:
lang_to_subfamily = {}
for lang in languages:
	lang_name = LANGCODE2LANGNAME[lang]
	subfamily_row = df_lang[df_lang['Language name'] == lang_name]
	if not subfamily_row.empty:
		subfamily = subfamily_row.iloc[0]['Language sub-family']
		lang_to_subfamily[lang_name] = subfamily
	else:
		print(f"Warning: Language {lang_name} not found in the DataFrame.")
		lang_to_subfamily[lang_name] = 'Unknown'

In [15]:
lang_to_family = {}
for lang in languages:
	lang_name = LANGCODE2LANGNAME[lang]
	family_row = df_lang[df_lang['Language name'] == lang_name]
	if not family_row.empty:
		family = family_row.iloc[0]['Language family']
		lang_to_family[lang_name] = family
	else:
		lang_to_family[lang_name] = 'Unknown'

In [16]:
lang_to_syntax = {}
for lang in languages:
	lang_name = LANGCODE2LANGNAME[lang]
	syntax_row = df_lang[df_lang['Language name'] == lang_name]
	if not syntax_row.empty:
		syntax = syntax_row.iloc[0]['Syntax']
		lang_to_syntax[lang_name] = syntax
	else:
		print(f"Warning: Language {lang_name} not found in the DataFrame.")
		lang_to_syntax[lang_name] = 'Unknown'

In [17]:
set(lang_to_family.values()).__len__(), set(lang_to_region.values()).__len__(), set(lang_to_script.values()).__len__(), set(lang_to_subfamily.values()).__len__(), set(lang_to_subsubfamily.values()).__len__(), set(lang_to_syntax.values()).__len__()

(15, 8, 19, 19, 39, 5)

In [18]:
categories = ['family', 'region', 'script', 'subfamily', 'subsubfamily', 'syntax']

### Dendogram

In [29]:
extraction_mode = 'raw'
token_position = 'last_token'
task_id = 'next_token'
exp_id = 'exp4'
activation_loc = 'residual-postmlp'

In [30]:
sil_path  = f'outputs_silhouette/{exp_id}/{task_id}/dev/{model_name}/{extraction_mode}/{token_position}/{activation_loc}/silhouette_score_matrix.pt'
silhouette_score_matrix = torch.load(sil_path, map_location='cpu')
num_layers = MODEL2NUM_LAYERS[model_name]
hidden_size = MODEL2HIDDEN_SIZE[model_name]

In [31]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from scipy.cluster.hierarchy import dendrogram, linkage
from scipy.spatial.distance import squareform
from tqdm import tqdm
import os

In [32]:
model_names

['gemma-3-12b-it']

In [33]:
from math import ceil
ceil(num_layers / 5)

10

In [34]:
silhouette_score_matrix[1]

tensor([[0.0000, 0.2258, 0.1829,  ..., 0.2000, 0.4143, 0.4065],
        [0.0000, 0.0000, 0.1561,  ..., 0.2552, 0.4891, 0.3448],
        [0.0000, 0.0000, 0.0000,  ..., 0.2097, 0.4637, 0.3802],
        ...,
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.3613, 0.3850],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.5688],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]])

In [35]:
# Check if a matrix is symmetric, if not, show which part of the matrix is not symmetric
def is_symmetric(matrix, tol=1e-8):
	"""
	Check if a matrix is symmetric.
	
	Args:
		matrix: numpy array or torch tensor
		tol: tolerance for comparison
		
	Returns:
		bool: True if symmetric, False otherwise
	"""
	# Convert torch tensor to numpy if needed
	if hasattr(matrix, 'numpy'):
		matrix = matrix.cpu().numpy()
	
	diff = np.abs(matrix - matrix.T)
	is_sym = np.all(diff < tol)
	
	if not is_sym:
		# Find indices where symmetry is violated
		non_sym_indices = np.where(diff >= tol)
		# print(f"Matrix is NOT symmetric. Found {len(non_sym_indices[0])} violations.")
		# print(f"First 10 non-symmetric positions (i, j, |M[i,j] - M[j,i]|):")
		# for idx in range(min(10, len(non_sym_indices[0]))):
		# 	i, j = non_sym_indices[0][idx], non_sym_indices[1][idx]
		# 	print(f"  ({i}, {j}): difference = {diff[i, j]:.6f}")
	return is_sym
for layer_id in range(-1, num_layers):
	if not is_symmetric(silhouette_score_matrix[layer_id + 1]):
		# Make the matrix symmetric by copying the upper triangle to the lower triangle
		matrix = silhouette_score_matrix[layer_id + 1].cpu().numpy()
		symmetric_matrix = (matrix + matrix.T) / 2
		silhouette_score_matrix[layer_id + 1] = torch.tensor(symmetric_matrix)
		print(f"Corrected symmetry for layer {layer_id + 1}. {is_symmetric(silhouette_score_matrix[layer_id + 1])}")

Corrected symmetry for layer 1. True
Corrected symmetry for layer 2. True
Corrected symmetry for layer 3. True
Corrected symmetry for layer 4. True
Corrected symmetry for layer 5. True
Corrected symmetry for layer 6. True
Corrected symmetry for layer 7. True
Corrected symmetry for layer 8. True
Corrected symmetry for layer 9. True
Corrected symmetry for layer 10. True
Corrected symmetry for layer 11. True
Corrected symmetry for layer 12. True
Corrected symmetry for layer 13. True
Corrected symmetry for layer 14. True
Corrected symmetry for layer 15. True
Corrected symmetry for layer 16. True
Corrected symmetry for layer 17. True
Corrected symmetry for layer 18. True
Corrected symmetry for layer 19. True
Corrected symmetry for layer 20. True
Corrected symmetry for layer 21. True
Corrected symmetry for layer 22. True
Corrected symmetry for layer 23. True
Corrected symmetry for layer 24. True
Corrected symmetry for layer 25. True
Corrected symmetry for layer 26. True
Corrected symmetry fo

In [37]:
# Check negative values
for layer_id in range(-1, num_layers):
	if (silhouette_score_matrix[layer_id + 1] < 0).any():
		num_neg = (silhouette_score_matrix[layer_id + 1] < 0).sum().item()
		print(f"Layer {layer_id + 1} has {num_neg} negative values.")
	
	# Print the negative values
	neg_values = silhouette_score_matrix[layer_id + 1][silhouette_score_matrix[layer_id + 1] < 0]
	print(f"Negative values in layer {layer_id + 1}: {neg_values}")

Negative values in layer 0: tensor([])
Negative values in layer 1: tensor([])
Negative values in layer 2: tensor([])
Negative values in layer 3: tensor([])
Negative values in layer 4: tensor([])
Negative values in layer 5: tensor([])
Negative values in layer 6: tensor([])
Negative values in layer 7: tensor([])
Negative values in layer 8: tensor([])
Negative values in layer 9: tensor([])
Negative values in layer 10: tensor([])
Negative values in layer 11: tensor([])
Negative values in layer 12: tensor([])
Negative values in layer 13: tensor([])
Negative values in layer 14: tensor([])
Negative values in layer 15: tensor([])
Negative values in layer 16: tensor([])
Negative values in layer 17: tensor([])
Negative values in layer 18: tensor([])
Negative values in layer 19: tensor([])
Negative values in layer 20: tensor([])
Negative values in layer 21: tensor([])
Negative values in layer 22: tensor([])
Negative values in layer 23: tensor([])
Negative values in layer 24: tensor([])
Negative v

In [38]:
model_names

['gemma-3-12b-it']

In [39]:
for model_name in model_names:
	# if '-14B' in model_name:
	# 	continue
	extraction_mode = 'raw'
	token_position = 'last_token'
	task_id = 'next_token'
	exp_id = 'exp4'
	activation_loc = 'residual-postmlp'

	num_layers = MODEL2NUM_LAYERS[model_name]
	hidden_size = MODEL2HIDDEN_SIZE[model_name]

	sil_path  = f'outputs_silhouette/{exp_id}/{task_id}/dev/{model_name}/{extraction_mode}/{token_position}/{activation_loc}/silhouette_score_matrix.pt'
	silhouette_score_matrix = torch.load(sil_path, map_location='cpu')
	print(f"Processing model: {model_name} with {num_layers} layers and hidden size {hidden_size}")
	for category in categories:
		print(f"Generating dendrograms colored by {category}...")
		
		# ==========================================
		# 1. COLOR MAP SETUP
		# ==========================================
		# We use 'nipy_spectral' which is great for high cardinality (many categories)
		unique_categories = list(set(globals()[f"lang_to_{category}"].values()))
		n_categories = len(unique_categories)

		# Create a color map object
		cmap = plt.get_cmap('nipy_spectral', n_categories)

		# Map every family name to a specific RGBA color
		category_colors = {cat: cmap(i) for i, cat in enumerate(unique_categories)}

		# ==========================================
		# 2. PLOTTING LOOP
		# ==========================================

		# Create subplots
		ncols = 5
		nrows = ceil(num_layers / ncols)
		# Make the figure size dynamic based on number of rows
		fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(35, 10 * nrows))
		fig.suptitle('Hierarchical Clustering Dendrograms Across All Layers', fontsize=24, y=0.99)

		axes = axes.flatten()

		# First pass: find max distance (kept from your code)
		max_distance = 0
		linkage_results = []
		# Assuming num_layers is defined elsewhere in your variable scope
		for layer_id in range(-1, num_layers):
			if not is_symmetric(silhouette_score_matrix[layer_id + 1]):
				# Make the matrix symmetric by copying the upper triangle to the lower triangle
				matrix = silhouette_score_matrix[layer_id + 1].cpu().numpy()
				symmetric_matrix = (matrix + matrix.T) / 2
				silhouette_score_matrix[layer_id + 1] = torch.tensor(symmetric_matrix)
				# print(f"Corrected symmetry for layer {layer_id + 1}. {is_symmetric(silhouette_score_matrix[layer_id + 1])}")
			
			# Cap the values at 0 for distance metric
			silhouette_score_matrix[layer_id + 1] = torch.clamp(silhouette_score_matrix[layer_id + 1], min=0)

			linked = linkage(squareform(silhouette_score_matrix[layer_id + 1].cpu().numpy()), 'complete', metric='precomputed')
			linkage_results.append(linked)
			max_distance = max(max_distance, linked[:, 2].max())


		# Second pass: Plot and Color
		for layer_id in tqdm(range(0, num_layers)):
			ax_idx = layer_id + 1
			current_ax = axes[ax_idx - 1]
			
			# 1. Create Dendrogram
			d = dendrogram(
				linkage_results[ax_idx], 
				labels=[LANGCODE2LANGNAME[lang] for lang in languages], 
				orientation='right',
				ax=current_ax
			)
			
			current_ax.set_title(f"Layer {layer_id}", fontsize=12)
			current_ax.tick_params(axis='x', labelsize=6)
			
			# Set consistent x-axis limits
			current_ax.set_xlim([0, max_distance])  

			# ==========================================
			# 3. APPLY COLORS TO Y-AXIS LABELS
			# ==========================================
			
			# Get all text objects on the Y-axis
			y_labels = current_ax.get_ymajorticklabels()
			
			for label in y_labels:
				lang_name = label.get_text()
				
				# Look up the family, default to 'Unknown' if missing
				fam = globals()[f"lang_to_{category}"].get(lang_name, "Unknown")
				
				# Get the color, default to black if family not found in color map
				col = category_colors.get(fam, "black")
				
				# Set the color and force font size
				label.set_color(col)
				label.set_fontsize(6) # Ensuring size is readable


		# ==========================================
		# 4. CREATE A LEGEND (Optional but recommended)
		# ==========================================
		# Since you have 35 categories, we put the legend on the very first 
		# plot or a dedicated space, or outside the figure.
		# Here we add it to the top of the Figure.

		handles = [mpatches.Patch(color=category_colors[f], label=f) for f in unique_categories]
		fig.legend(handles=handles, loc='upper center', bbox_to_anchor=(0.5, 0.98), 
				ncol=7, fontsize=10, title=f"Language {category.capitalize()}s")

		# Save image
		image_path = os.path.join(os.path.dirname(sil_path), f'dendrogram_by_{category}.png')
		print(f"Saving dendrogram figure to {image_path}...")
		# Increased top margin (rect) to make room for the big legend
		plt.tight_layout(rect=[0, 0.0, 1, 0.96]) 
		plt.savefig(image_path, dpi=300)
		plt.close()

Processing model: gemma-3-12b-it with 48 layers and hidden size 3840
Generating dendrograms colored by family...


100%|██████████| 48/48 [00:04<00:00,  9.88it/s]


Saving dendrogram figure to outputs_silhouette/exp4/next_token/dev/gemma-3-12b-it/raw/last_token/residual-postmlp/dendrogram_by_family.png...
Generating dendrograms colored by region...


100%|██████████| 48/48 [00:04<00:00, 10.17it/s]


Saving dendrogram figure to outputs_silhouette/exp4/next_token/dev/gemma-3-12b-it/raw/last_token/residual-postmlp/dendrogram_by_region.png...
Generating dendrograms colored by script...


100%|██████████| 48/48 [00:04<00:00,  9.83it/s]


Saving dendrogram figure to outputs_silhouette/exp4/next_token/dev/gemma-3-12b-it/raw/last_token/residual-postmlp/dendrogram_by_script.png...
Generating dendrograms colored by subfamily...


100%|██████████| 48/48 [00:05<00:00,  9.49it/s]


Saving dendrogram figure to outputs_silhouette/exp4/next_token/dev/gemma-3-12b-it/raw/last_token/residual-postmlp/dendrogram_by_subfamily.png...
Generating dendrograms colored by subsubfamily...


100%|██████████| 48/48 [00:05<00:00,  9.27it/s]


Saving dendrogram figure to outputs_silhouette/exp4/next_token/dev/gemma-3-12b-it/raw/last_token/residual-postmlp/dendrogram_by_subsubfamily.png...
Generating dendrograms colored by syntax...


100%|██████████| 48/48 [00:05<00:00,  8.74it/s]


Saving dendrogram figure to outputs_silhouette/exp4/next_token/dev/gemma-3-12b-it/raw/last_token/residual-postmlp/dendrogram_by_syntax.png...


In [40]:
for activation_loc in ['residual-preattn', 'residual-postattn']:
	for model_name in model_names:
		# if '-14B' in model_name:
		# 	continue
		extraction_mode = 'raw'
		token_position = 'last_token'
		task_id = 'next_token'
		exp_id = 'exp4'

		num_layers = MODEL2NUM_LAYERS[model_name]
		hidden_size = MODEL2HIDDEN_SIZE[model_name]

		sil_path  = f'outputs_silhouette/{exp_id}/{task_id}/dev/{model_name}/{extraction_mode}/{token_position}/{activation_loc}/silhouette_score_matrix.pt'
		silhouette_score_matrix = torch.load(sil_path, map_location='cpu')
		print(f"Processing model: {model_name} with {num_layers} layers and hidden size {hidden_size}")
		for category in categories:
			print(f"Generating dendrograms colored by {category}...")
			
			# ==========================================
			# 1. COLOR MAP SETUP
			# ==========================================
			# We use 'nipy_spectral' which is great for high cardinality (many categories)
			unique_categories = list(set(globals()[f"lang_to_{category}"].values()))
			n_categories = len(unique_categories)

			# Create a color map object
			cmap = plt.get_cmap('nipy_spectral', n_categories)

			# Map every family name to a specific RGBA color
			category_colors = {cat: cmap(i) for i, cat in enumerate(unique_categories)}

			# ==========================================
			# 2. PLOTTING LOOP
			# ==========================================

			# Create subplots
			ncols = 5
			nrows = ceil(num_layers / ncols)
			# Make the figure size dynamic based on number of rows
			fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(35, 10 * nrows))
			fig.suptitle('Hierarchical Clustering Dendrograms Across All Layers', fontsize=24, y=0.99)

			axes = axes.flatten()

			# First pass: find max distance (kept from your code)
			max_distance = 0
			linkage_results = []
			# Assuming num_layers is defined elsewhere in your variable scope
			for layer_id in range(-1, num_layers):
				if not is_symmetric(silhouette_score_matrix[layer_id + 1]):
					# Make the matrix symmetric by copying the upper triangle to the lower triangle
					matrix = silhouette_score_matrix[layer_id + 1].cpu().numpy()
					symmetric_matrix = (matrix + matrix.T) / 2
					silhouette_score_matrix[layer_id + 1] = torch.tensor(symmetric_matrix)
					# print(f"Corrected symmetry for layer {layer_id + 1}. {is_symmetric(silhouette_score_matrix[layer_id + 1])}")
				
				# Cap the values at 0 for distance metric
				silhouette_score_matrix[layer_id + 1] = torch.clamp(silhouette_score_matrix[layer_id + 1], min=0)

				linked = linkage(squareform(silhouette_score_matrix[layer_id + 1].cpu().numpy()), 'complete', metric='precomputed')
				linkage_results.append(linked)
				max_distance = max(max_distance, linked[:, 2].max())


			# Second pass: Plot and Color
			for layer_id in tqdm(range(0, num_layers)):
				ax_idx = layer_id + 1
				current_ax = axes[ax_idx - 1]
				
				# 1. Create Dendrogram
				d = dendrogram(
					linkage_results[ax_idx], 
					labels=[LANGCODE2LANGNAME[lang] for lang in languages], 
					orientation='right',
					ax=current_ax
				)
				
				current_ax.set_title(f"Layer {layer_id}", fontsize=12)
				current_ax.tick_params(axis='x', labelsize=6)
				
				# Set consistent x-axis limits
				current_ax.set_xlim([0, max_distance])  

				# ==========================================
				# 3. APPLY COLORS TO Y-AXIS LABELS
				# ==========================================
				
				# Get all text objects on the Y-axis
				y_labels = current_ax.get_ymajorticklabels()
				
				for label in y_labels:
					lang_name = label.get_text()
					
					# Look up the family, default to 'Unknown' if missing
					fam = globals()[f"lang_to_{category}"].get(lang_name, "Unknown")
					
					# Get the color, default to black if family not found in color map
					col = category_colors.get(fam, "black")
					
					# Set the color and force font size
					label.set_color(col)
					label.set_fontsize(6) # Ensuring size is readable


			# ==========================================
			# 4. CREATE A LEGEND (Optional but recommended)
			# ==========================================
			# Since you have 35 categories, we put the legend on the very first 
			# plot or a dedicated space, or outside the figure.
			# Here we add it to the top of the Figure.

			handles = [mpatches.Patch(color=category_colors[f], label=f) for f in unique_categories]
			fig.legend(handles=handles, loc='upper center', bbox_to_anchor=(0.5, 0.98), 
					ncol=7, fontsize=10, title=f"Language {category.capitalize()}s")

			# Save image
			image_path = os.path.join(os.path.dirname(sil_path), f'dendrogram_by_{category}.png')
			print(f"Saving dendrogram figure to {image_path}...")
			# Increased top margin (rect) to make room for the big legend
			plt.tight_layout(rect=[0, 0.0, 1, 0.96]) 
			plt.savefig(image_path, dpi=300)
			plt.close()

Processing model: gemma-3-12b-it with 48 layers and hidden size 3840
Generating dendrograms colored by family...


  ax.set_xlim([0, dvw])
100%|██████████| 48/48 [00:04<00:00, 11.17it/s]


Saving dendrogram figure to outputs_silhouette/exp4/next_token/dev/gemma-3-12b-it/raw/last_token/residual-preattn/dendrogram_by_family.png...
Generating dendrograms colored by region...


  ax.set_xlim([0, dvw])
100%|██████████| 48/48 [00:05<00:00,  9.53it/s]


Saving dendrogram figure to outputs_silhouette/exp4/next_token/dev/gemma-3-12b-it/raw/last_token/residual-preattn/dendrogram_by_region.png...
Generating dendrograms colored by script...


  ax.set_xlim([0, dvw])
100%|██████████| 48/48 [00:04<00:00, 10.01it/s]


Saving dendrogram figure to outputs_silhouette/exp4/next_token/dev/gemma-3-12b-it/raw/last_token/residual-preattn/dendrogram_by_script.png...
Generating dendrograms colored by subfamily...


  ax.set_xlim([0, dvw])
100%|██████████| 48/48 [00:04<00:00,  9.60it/s]


Saving dendrogram figure to outputs_silhouette/exp4/next_token/dev/gemma-3-12b-it/raw/last_token/residual-preattn/dendrogram_by_subfamily.png...
Generating dendrograms colored by subsubfamily...


  ax.set_xlim([0, dvw])
100%|██████████| 48/48 [00:05<00:00,  9.42it/s]


Saving dendrogram figure to outputs_silhouette/exp4/next_token/dev/gemma-3-12b-it/raw/last_token/residual-preattn/dendrogram_by_subsubfamily.png...
Generating dendrograms colored by syntax...


  ax.set_xlim([0, dvw])
100%|██████████| 48/48 [00:05<00:00,  9.08it/s]


Saving dendrogram figure to outputs_silhouette/exp4/next_token/dev/gemma-3-12b-it/raw/last_token/residual-preattn/dendrogram_by_syntax.png...
Processing model: gemma-3-12b-it with 48 layers and hidden size 3840
Generating dendrograms colored by family...


100%|██████████| 48/48 [00:05<00:00,  8.71it/s]


Saving dendrogram figure to outputs_silhouette/exp4/next_token/dev/gemma-3-12b-it/raw/last_token/residual-postattn/dendrogram_by_family.png...
Generating dendrograms colored by region...


100%|██████████| 48/48 [00:04<00:00, 11.14it/s]


Saving dendrogram figure to outputs_silhouette/exp4/next_token/dev/gemma-3-12b-it/raw/last_token/residual-postattn/dendrogram_by_region.png...
Generating dendrograms colored by script...


100%|██████████| 48/48 [00:05<00:00,  8.08it/s]


Saving dendrogram figure to outputs_silhouette/exp4/next_token/dev/gemma-3-12b-it/raw/last_token/residual-postattn/dendrogram_by_script.png...
Generating dendrograms colored by subfamily...


100%|██████████| 48/48 [00:04<00:00, 11.30it/s]


Saving dendrogram figure to outputs_silhouette/exp4/next_token/dev/gemma-3-12b-it/raw/last_token/residual-postattn/dendrogram_by_subfamily.png...
Generating dendrograms colored by subsubfamily...


100%|██████████| 48/48 [00:06<00:00,  7.77it/s]


Saving dendrogram figure to outputs_silhouette/exp4/next_token/dev/gemma-3-12b-it/raw/last_token/residual-postattn/dendrogram_by_subsubfamily.png...
Generating dendrograms colored by syntax...


100%|██████████| 48/48 [00:04<00:00, 11.39it/s]


Saving dendrogram figure to outputs_silhouette/exp4/next_token/dev/gemma-3-12b-it/raw/last_token/residual-postattn/dendrogram_by_syntax.png...


### Heatmap

In [None]:
# Heatmap per layer of the silhouette scores between languages, make a subplot for each layer with 4 columns
import seaborn as sns
import matplotlib.pyplot as plt

# Find global min and max across all layers for consistent scale
global_min = float('inf')
global_max = float('-inf')
for layer_id in range(-1, num_layers):
	layer_distances = silhouette_score_matrix[layer_id + 1]
	global_min = min(global_min, layer_distances.min().item())
	global_max = max(global_max, layer_distances.max().item())

fig, axes = plt.subplots(nrows=(num_layers + 1) // 4 + 1, ncols=4, figsize=(20, 5 * ((num_layers + 1) // 4 + 1)))
axes = axes.flatten()
for layer_id in tqdm(range(0, num_layers)):
	layer_distances = silhouette_score_matrix[layer_id + 1]
	sns.heatmap(layer_distances.numpy(), 
		xticklabels=[LANGCODE2LANGNAME[lang] for lang in languages], 
		yticklabels=[LANGCODE2LANGNAME[lang] for lang in languages], 
		ax=axes[layer_id], 
		cmap='viridis_r',
		vmin=global_min,
		vmax=global_max,
		# annot=True,
		# fmt='.2f',
	)
	axes[layer_id].set_title(f'Layer {layer_id} Silhouette Scores')
	plt.setp(axes[layer_id].get_xticklabels(), rotation=90, ha='right', rotation_mode='anchor')

	# Setup fontsize of the x and y label
	axes[layer_id].tick_params(axis='x', labelsize=5)
	axes[layer_id].tick_params(axis='y', labelsize=5)
	
	# Add bolded grid lines to separate language families
	# Calculate cumulative positions for family boundaries
	cumulative_langs = 0
	for family, num_langs in lengths_per_category.items():
		cumulative_langs += num_langs
		# Draw horizontal and vertical lines at family boundaries
		axes[layer_id].axhline(cumulative_langs, color='black', linewidth=1.5)
		axes[layer_id].axvline(cumulative_langs, color='black', linewidth=1.5)

	# if layer_id + 1 == 3:
	# 	break
	
plt.tight_layout()
plt.savefig(os.path.join(os.path.dirname(sil_path), f'silhouette_scores_heatmap_{model_name}_testing.png'), dpi=300, bbox_inches='tight')
plt.close()

In [None]:
import seaborn as sns

# Make 5 bins of the silhouette scores and color the heatmap accordingly
import matplotlib.pyplot as plt

# Find global min and max across all layers for consistent scale
global_min = float('inf')
global_max = float('-inf')
for layer_id in range(-1, num_layers):
	layer_distances = silhouette_score_matrix[layer_id + 1]
	# Exclude diagonal elements (which are 0)
	mask_no_diag = ~torch.eye(len(languages), dtype=torch.bool, device='cuda')
	non_diag_values = layer_distances[mask_no_diag]
	global_min = min(global_min, non_diag_values.min().item())
	global_max = max(global_max, non_diag_values.max().item())

# Create 3 bins
n_bins = 5
bin_edges = np.linspace(global_min, global_max, n_bins + 1)

# Create subplots
fig, axes = plt.subplots(nrows=(num_layers + 1) // 4 + 1, ncols=4, figsize=(20, 5 * ((num_layers + 1) // 4 + 1)))
axes = axes.flatten()

for layer_id in tqdm(range(0, num_layers)):
	layer_distances = silhouette_score_matrix[layer_id + 1].cpu().numpy()
	
	# Discretize the distances into bins
	binned_distances = np.digitize(layer_distances, bin_edges) - 1
	binned_distances = np.clip(binned_distances, 0, n_bins - 1)
	
	sns.heatmap(binned_distances, 
		xticklabels=[LANGCODE2LANGNAME[lang] for lang in languages], 
		yticklabels=[LANGCODE2LANGNAME[lang] for lang in languages], 
		ax=axes[layer_id], 
		cmap='viridis_r',
		vmin=0,
		vmax=n_bins - 1,
		cbar_kws={'label': 'Bin', 'ticks': range(n_bins)}
	)
	axes[layer_id].set_title(f'Layer {layer_id} Binned Silhouette Scores')
	plt.setp(axes[layer_id].get_xticklabels(), rotation=90, ha='right', rotation_mode='anchor')

	# Setup fontsize of the x and y label
	axes[layer_id].tick_params(axis='x', labelsize=5)
	axes[layer_id].tick_params(axis='y', labelsize=5)
	
	# Add bolded grid lines to separate language families
	cumulative_langs = 0
	for family, num_langs in lengths_per_category.items():
		cumulative_langs += num_langs
		axes[layer_id].axhline(cumulative_langs, color='black', linewidth=1.5)
		axes[layer_id].axvline(cumulative_langs, color='black', linewidth=1.5)
	
plt.tight_layout()
plt.savefig(f'silhouette_scores_binned{n_bins}_heatmap_{model_name}.png', dpi=300, bbox_inches='tight')
plt.close()

print(f"Bin edges: {bin_edges}")