In [None]:
import os
import sys

# Change to parent directory
os.chdir('..')
# Add current directory to path
sys.path.append('.')

In [None]:
import pandas as pd
import numpy as np
from src.utils.const import LANGCODE2LANGNAME, LANGNAME2LANGCODE, MODEL2HIDDEN_SIZE, MODEL2NUM_LAYERS, EXP2_CONFIG, EXP3_CONFIG, EXP4_CONFIG, MODEL2HF_NAME
import glob
import torch
from tqdm import tqdm

In [None]:
model_names = glob.glob('outputs_silhouette/exp4/next_token/dev/*')
model_names = [mn.split('/')[-1] for mn in model_names]
model_names.sort()
print(model_names)
model_name = model_names[1]

In [None]:
languages = []
for family, langs in EXP4_CONFIG['languages'].items():
	languages.extend(langs)

languages = [LANGNAME2LANGCODE[lang] for lang in languages]
languages

In [None]:
def get_google_sheet(sheet_id: str, sheet_gid: str) -> pd.DataFrame:
	"""
	Downloads a specific sheet from a Google Sheet into a pandas DataFrame.

	Args:
		sheet_id: The ID of the Google Sheet.
		sheet_gid: The GID of the specific sheet to download.

	Returns:
		A pandas DataFrame containing the data from the specified sheet.
	"""
	url = f'https://docs.google.com/spreadsheets/d/{sheet_id}/export?format=csv&gid={sheet_gid}'
	df = pd.read_csv(url)
	return df
google_sheet_id = '1CmhOZeYTbfePLI2-rMubJpnKHuS6RLEfHYZD6-rVQ0M'  # Replace with your actual Google Sheet ID
gid = '0'  # Replace with the actual GID for the Indonesian sheet
corrected_data = []
try:
	df_lang = get_google_sheet(google_sheet_id, gid)
except Exception as e:
	print(f"An error occurred: {e}")
	print("Please ensure the Google Sheet is shared correctly and the IDs are correct.")

In [None]:
df_lang

In [None]:
df_lang['Syntax'] = df_lang['Syntax'].fillna('Unknown')

In [None]:
df_lang['script'] = df_lang['Language code'].apply(lambda x: x.split('_')[1])

In [None]:
df_lang['Language sub-sub-family'] = df_lang['Language sub-sub-family'].fillna(df_lang['Language sub-family'])
df_lang['Language sub-sub-family'] = df_lang['Language sub-sub-family'].fillna(df_lang['Language family'])

In [None]:
lengths_per_category = {}
for family, langs in EXP4_CONFIG['languages'].items():
	lengths_per_category[family] = len(langs)

In [None]:
lang_to_subsubfamily = {}
for lang in languages:
	lang_name = LANGCODE2LANGNAME[lang]
	family_row = df_lang[df_lang['Language name'] == lang_name]
	if not family_row.empty:
		family = family_row.iloc[0]['Language sub-sub-family']
		lang_to_subsubfamily[lang_name] = family
	else:
		print(f"Warning: Language {lang_name} not found in the DataFrame.")
		lang_to_subsubfamily[lang_name] = 'Unknown'

In [None]:
lang_to_region = {}
for lang in languages:
	lang_name = LANGCODE2LANGNAME[lang]
	region_row = df_lang[df_lang['Language name'] == lang_name]
	if not region_row.empty:
		region = region_row.iloc[0]['Region']
		lang_to_region[lang_name] = region
	else:
		print(f"Warning: Language {lang_name} not found in the DataFrame.")
		lang_to_region[lang_name] = 'Unknown'

In [None]:
lang_to_script = {}
for lang in languages:
	lang_name = LANGCODE2LANGNAME[lang]
	script_row = df_lang[df_lang['Language name'] == lang_name]
	if not script_row.empty:
		script = script_row.iloc[0]['script']
		lang_to_script[lang_name] = script
	else:
		print(f"Warning: Language {lang_name} not found in the DataFrame.")
		lang_to_script[lang_name] = 'Unknown'

In [None]:
lang_to_subfamily = {}
for lang in languages:
	lang_name = LANGCODE2LANGNAME[lang]
	subfamily_row = df_lang[df_lang['Language name'] == lang_name]
	if not subfamily_row.empty:
		subfamily = subfamily_row.iloc[0]['Language sub-family']
		lang_to_subfamily[lang_name] = subfamily
	else:
		print(f"Warning: Language {lang_name} not found in the DataFrame.")
		lang_to_subfamily[lang_name] = 'Unknown'

In [None]:
lang_to_family = {}
for lang in languages:
	lang_name = LANGCODE2LANGNAME[lang]
	family_row = df_lang[df_lang['Language name'] == lang_name]
	if not family_row.empty:
		family = family_row.iloc[0]['Language family']
		lang_to_family[lang_name] = family
	else:
		lang_to_family[lang_name] = 'Unknown'

In [None]:
lang_to_syntax = {}
for lang in languages:
	lang_name = LANGCODE2LANGNAME[lang]
	syntax_row = df_lang[df_lang['Language name'] == lang_name]
	if not syntax_row.empty:
		syntax = syntax_row.iloc[0]['Syntax']
		lang_to_syntax[lang_name] = syntax
	else:
		print(f"Warning: Language {lang_name} not found in the DataFrame.")
		lang_to_syntax[lang_name] = 'Unknown'

In [None]:
set(lang_to_family.values()).__len__(), set(lang_to_region.values()).__len__(), set(lang_to_script.values()).__len__(), set(lang_to_subfamily.values()).__len__(), set(lang_to_subsubfamily.values()).__len__(), set(lang_to_syntax.values()).__len__()

In [None]:
categories = ['family', 'region', 'script', 'subfamily', 'subsubfamily', 'syntax']

### Dendogram

In [None]:
extraction_mode = 'raw'
token_position = 'last_token'
task_id = 'next_token'
exp_id = 'exp4'
activation_loc = 'residual-postmlp'

In [None]:
sil_path  = f'outputs_silhouette/{exp_id}/{task_id}/dev/{model_name}/{extraction_mode}/{token_position}/{activation_loc}/silhouette_score_matrix.pt'
silhouette_score_matrix = torch.load(sil_path, map_location='cpu')
num_layers = MODEL2NUM_LAYERS[model_name]
hidden_size = MODEL2HIDDEN_SIZE[model_name]

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from scipy.cluster.hierarchy import dendrogram, linkage
from tqdm import tqdm
import os

In [78]:
model_names[5:]

['gemma-3-12b-it', 'gemma-3-4b-it', 'pythia-6.9b-deduped']

In [79]:
from math import ceil
ceil(num_layers / 5)

10

In [80]:
num_layers

48

In [81]:
for model_name in model_names[5:]:
	if '-14B' in model_name:
		continue
	extraction_mode = 'raw'
	token_position = 'last_token'
	task_id = 'next_token'
	exp_id = 'exp4'
	activation_loc = 'residual-postmlp'

	sil_path  = f'outputs_silhouette/{exp_id}/{task_id}/dev/{model_name}/{extraction_mode}/{token_position}/{activation_loc}/silhouette_score_matrix.pt'
	silhouette_score_matrix = torch.load(sil_path, map_location='cpu')
	num_layers = MODEL2NUM_LAYERS[model_name]
	hidden_size = MODEL2HIDDEN_SIZE[model_name]
	print(f"Processing model: {model_name} with {num_layers} layers and hidden size {hidden_size}")
	for category in categories:
		print(f"Generating dendrograms colored by {category}...")
		
		# ==========================================
		# 1. COLOR MAP SETUP
		# ==========================================
		# We use 'nipy_spectral' which is great for high cardinality (many categories)
		unique_categories = list(set(globals()[f"lang_to_{category}"].values()))
		n_categories = len(unique_categories)

		# Create a color map object
		cmap = plt.get_cmap('nipy_spectral', n_categories)

		# Map every family name to a specific RGBA color
		category_colors = {cat: cmap(i) for i, cat in enumerate(unique_categories)}

		# ==========================================
		# 2. PLOTTING LOOP
		# ==========================================

		# Create subplots
		ncols = 5
		nrows = ceil(num_layers / ncols)
		# Make the figure size dynamic based on number of rows
		fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(35, 10 * nrows))
		fig.suptitle('Hierarchical Clustering Dendrograms Across All Layers', fontsize=24, y=0.99)

		axes = axes.flatten()

		# First pass: find max distance (kept from your code)
		max_distance = 0
		linkage_results = []
		# Assuming num_layers is defined elsewhere in your variable scope
		for layer_id in range(-1, num_layers):
			linked = linkage(silhouette_score_matrix[layer_id + 1].cpu().numpy(), 'complete')
			linkage_results.append(linked)
			max_distance = max(max_distance, linked[:, 2].max())
			pass 


		# Second pass: Plot and Color
		for layer_id in tqdm(range(0, num_layers)):
			ax_idx = layer_id + 1
			current_ax = axes[ax_idx - 1]
			
			# 1. Create Dendrogram
			d = dendrogram(
				linkage_results[ax_idx], 
				labels=[LANGCODE2LANGNAME[lang] for lang in languages], 
				orientation='right',
				ax=current_ax
			)
			
			current_ax.set_title(f"Layer {layer_id}", fontsize=12)
			current_ax.tick_params(axis='x', labelsize=6)
			
			# Set consistent x-axis limits
			current_ax.set_xlim([0, max_distance])  

			# ==========================================
			# 3. APPLY COLORS TO Y-AXIS LABELS
			# ==========================================
			
			# Get all text objects on the Y-axis
			y_labels = current_ax.get_ymajorticklabels()
			
			for label in y_labels:
				lang_name = label.get_text()
				
				# Look up the family, default to 'Unknown' if missing
				fam = globals()[f"lang_to_{category}"].get(lang_name, "Unknown")
				
				# Get the color, default to black if family not found in color map
				col = category_colors.get(fam, "black")
				
				# Set the color and force font size
				label.set_color(col)
				label.set_fontsize(6) # Ensuring size is readable


		# ==========================================
		# 4. CREATE A LEGEND (Optional but recommended)
		# ==========================================
		# Since you have 35 categories, we put the legend on the very first 
		# plot or a dedicated space, or outside the figure.
		# Here we add it to the top of the Figure.

		handles = [mpatches.Patch(color=category_colors[f], label=f) for f in unique_categories]
		fig.legend(handles=handles, loc='upper center', bbox_to_anchor=(0.5, 0.98), 
				ncol=7, fontsize=10, title=f"Language {category.capitalize()}s")

		# Save image
		image_path = os.path.join(os.path.dirname(sil_path), f'dendrogram_by_{category}.png')
		# Increased top margin (rect) to make room for the big legend
		plt.tight_layout(rect=[0, 0.0, 1, 0.96]) 
		plt.savefig(image_path, dpi=300)
		plt.close()

Processing model: gemma-3-12b-it with 48 layers and hidden size 3840
Generating dendrograms colored by family...


  linked = linkage(silhouette_score_matrix[layer_id + 1].cpu().numpy(), 'complete')
100%|██████████| 48/48 [00:04<00:00, 11.73it/s]


Generating dendrograms colored by region...


  linked = linkage(silhouette_score_matrix[layer_id + 1].cpu().numpy(), 'complete')
100%|██████████| 48/48 [00:04<00:00, 11.86it/s]


Generating dendrograms colored by script...


  linked = linkage(silhouette_score_matrix[layer_id + 1].cpu().numpy(), 'complete')
100%|██████████| 48/48 [00:13<00:00,  3.60it/s]


Generating dendrograms colored by subfamily...


  linked = linkage(silhouette_score_matrix[layer_id + 1].cpu().numpy(), 'complete')
100%|██████████| 48/48 [00:04<00:00,  9.85it/s]


Generating dendrograms colored by subsubfamily...


  linked = linkage(silhouette_score_matrix[layer_id + 1].cpu().numpy(), 'complete')
100%|██████████| 48/48 [00:05<00:00,  9.42it/s]


Generating dendrograms colored by syntax...


  linked = linkage(silhouette_score_matrix[layer_id + 1].cpu().numpy(), 'complete')
100%|██████████| 48/48 [00:03<00:00, 12.02it/s]


Processing model: gemma-3-4b-it with 34 layers and hidden size 2560
Generating dendrograms colored by family...


  linked = linkage(silhouette_score_matrix[layer_id + 1].cpu().numpy(), 'complete')
100%|██████████| 34/34 [00:04<00:00,  7.95it/s]


Generating dendrograms colored by region...


  linked = linkage(silhouette_score_matrix[layer_id + 1].cpu().numpy(), 'complete')
100%|██████████| 34/34 [00:02<00:00, 11.51it/s]


Generating dendrograms colored by script...


  linked = linkage(silhouette_score_matrix[layer_id + 1].cpu().numpy(), 'complete')
100%|██████████| 34/34 [00:02<00:00, 11.89it/s]


Generating dendrograms colored by subfamily...


  linked = linkage(silhouette_score_matrix[layer_id + 1].cpu().numpy(), 'complete')
100%|██████████| 34/34 [00:04<00:00,  7.35it/s]


Generating dendrograms colored by subsubfamily...


  linked = linkage(silhouette_score_matrix[layer_id + 1].cpu().numpy(), 'complete')
100%|██████████| 34/34 [00:02<00:00, 12.08it/s]


Generating dendrograms colored by syntax...


  linked = linkage(silhouette_score_matrix[layer_id + 1].cpu().numpy(), 'complete')
100%|██████████| 34/34 [00:02<00:00, 12.08it/s]


Processing model: pythia-6.9b-deduped with 32 layers and hidden size 4096
Generating dendrograms colored by family...


  linked = linkage(silhouette_score_matrix[layer_id + 1].cpu().numpy(), 'complete')
100%|██████████| 32/32 [00:04<00:00,  6.58it/s]


Generating dendrograms colored by region...


  linked = linkage(silhouette_score_matrix[layer_id + 1].cpu().numpy(), 'complete')
100%|██████████| 32/32 [00:02<00:00, 12.17it/s]


Generating dendrograms colored by script...


  linked = linkage(silhouette_score_matrix[layer_id + 1].cpu().numpy(), 'complete')
100%|██████████| 32/32 [00:02<00:00, 12.12it/s]


Generating dendrograms colored by subfamily...


  linked = linkage(silhouette_score_matrix[layer_id + 1].cpu().numpy(), 'complete')
100%|██████████| 32/32 [00:02<00:00, 12.15it/s]


Generating dendrograms colored by subsubfamily...


  linked = linkage(silhouette_score_matrix[layer_id + 1].cpu().numpy(), 'complete')
100%|██████████| 32/32 [00:05<00:00,  5.97it/s]


Generating dendrograms colored by syntax...


  linked = linkage(silhouette_score_matrix[layer_id + 1].cpu().numpy(), 'complete')
100%|██████████| 32/32 [00:02<00:00, 12.10it/s]


In [83]:
for model_name in model_names:
	extraction_mode = 'raw'
	token_position = 'last_token'
	task_id = 'next_token'
	exp_id = 'exp4'
	activation_loc = 'residual-postattn'
	if 'aya-101' in model_name:
		activation_loc = 'residual-postselfattn'
	else:
		continue

	sil_path  = f'outputs_silhouette/{exp_id}/{task_id}/dev/{model_name}/{extraction_mode}/{token_position}/{activation_loc}/silhouette_score_matrix.pt'
	silhouette_score_matrix = torch.load(sil_path, map_location='cpu')
	num_layers = MODEL2NUM_LAYERS[model_name]
	hidden_size = MODEL2HIDDEN_SIZE[model_name]
	print(f"Processing model: {model_name} with {num_layers} layers and hidden size {hidden_size}")
	for category in categories:
		print(f"Generating dendrograms colored by {category}...")
		
		# ==========================================
		# 1. COLOR MAP SETUP
		# ==========================================
		# We use 'nipy_spectral' which is great for high cardinality (many categories)
		unique_categories = list(set(globals()[f"lang_to_{category}"].values()))
		n_categories = len(unique_categories)

		# Create a color map object
		cmap = plt.get_cmap('nipy_spectral', n_categories)

		# Map every family name to a specific RGBA color
		category_colors = {cat: cmap(i) for i, cat in enumerate(unique_categories)}

		# ==========================================
		# 2. PLOTTING LOOP
		# ==========================================

		# Create subplots
		ncols = 5
		nrows = ceil(num_layers / ncols)
		# Make the figure size dynamic based on number of rows
		fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(35, 10 * nrows))
		fig.suptitle('Hierarchical Clustering Dendrograms Across All Layers', fontsize=24, y=0.99)

		axes = axes.flatten()

		# First pass: find max distance (kept from your code)
		max_distance = 0
		linkage_results = []
		# Assuming num_layers is defined elsewhere in your variable scope
		for layer_id in range(-1, num_layers):
			linked = linkage(silhouette_score_matrix[layer_id + 1].cpu().numpy(), 'complete')
			linkage_results.append(linked)
			max_distance = max(max_distance, linked[:, 2].max())
			pass 


		# Second pass: Plot and Color
		for layer_id in tqdm(range(0, num_layers)):
			ax_idx = layer_id + 1
			current_ax = axes[ax_idx - 1]
			
			# 1. Create Dendrogram
			d = dendrogram(
				linkage_results[ax_idx], 
				labels=[LANGCODE2LANGNAME[lang] for lang in languages], 
				orientation='right',
				ax=current_ax
			)
			
			current_ax.set_title(f"Layer {layer_id}", fontsize=12)
			current_ax.tick_params(axis='x', labelsize=6)
			
			# Set consistent x-axis limits
			current_ax.set_xlim([0, max_distance])  

			# ==========================================
			# 3. APPLY COLORS TO Y-AXIS LABELS
			# ==========================================
			
			# Get all text objects on the Y-axis
			y_labels = current_ax.get_ymajorticklabels()
			
			for label in y_labels:
				lang_name = label.get_text()
				
				# Look up the family, default to 'Unknown' if missing
				fam = globals()[f"lang_to_{category}"].get(lang_name, "Unknown")
				
				# Get the color, default to black if family not found in color map
				col = category_colors.get(fam, "black")
				
				# Set the color and force font size
				label.set_color(col)
				label.set_fontsize(6) # Ensuring size is readable


		# ==========================================
		# 4. CREATE A LEGEND (Optional but recommended)
		# ==========================================
		# Since you have 35 categories, we put the legend on the very first 
		# plot or a dedicated space, or outside the figure.
		# Here we add it to the top of the Figure.

		handles = [mpatches.Patch(color=category_colors[f], label=f) for f in unique_categories]
		fig.legend(handles=handles, loc='upper center', bbox_to_anchor=(0.5, 0.98), 
				ncol=7, fontsize=10, title=f"Language {category.capitalize()}s")

		# Save image
		image_path = os.path.join(os.path.dirname(sil_path), f'dendrogram_by_{category}.png')
		# Increased top margin (rect) to make room for the big legend
		plt.tight_layout(rect=[0, 0.0, 1, 0.96]) 
		plt.savefig(image_path, dpi=300)
		plt.close()

  linked = linkage(silhouette_score_matrix[layer_id + 1].cpu().numpy(), 'complete')


Processing model: aya-101 with 24 layers and hidden size 4096
Generating dendrograms colored by family...


  0%|          | 0/24 [00:00<?, ?it/s]

  ax.set_xlim([0, dvw])
100%|██████████| 24/24 [00:18<00:00,  1.28it/s]
  linked = linkage(silhouette_score_matrix[layer_id + 1].cpu().numpy(), 'complete')


Generating dendrograms colored by region...


  ax.set_xlim([0, dvw])
100%|██████████| 24/24 [00:02<00:00,  9.97it/s]
  linked = linkage(silhouette_score_matrix[layer_id + 1].cpu().numpy(), 'complete')


Generating dendrograms colored by script...


  ax.set_xlim([0, dvw])
100%|██████████| 24/24 [00:02<00:00,  9.64it/s]
  linked = linkage(silhouette_score_matrix[layer_id + 1].cpu().numpy(), 'complete')


Generating dendrograms colored by subfamily...


  ax.set_xlim([0, dvw])
100%|██████████| 24/24 [00:02<00:00,  9.18it/s]
  linked = linkage(silhouette_score_matrix[layer_id + 1].cpu().numpy(), 'complete')


Generating dendrograms colored by subsubfamily...


  ax.set_xlim([0, dvw])
100%|██████████| 24/24 [00:01<00:00, 12.05it/s]
  linked = linkage(silhouette_score_matrix[layer_id + 1].cpu().numpy(), 'complete')


Generating dendrograms colored by syntax...


  ax.set_xlim([0, dvw])
100%|██████████| 24/24 [00:02<00:00,  8.54it/s]


### Heatmap

In [None]:
# Heatmap per layer of the silhouette scores between languages, make a subplot for each layer with 4 columns
import seaborn as sns
import matplotlib.pyplot as plt

# Find global min and max across all layers for consistent scale
global_min = float('inf')
global_max = float('-inf')
for layer_id in range(-1, num_layers):
	layer_distances = silhouette_score_matrix[layer_id + 1]
	global_min = min(global_min, layer_distances.min().item())
	global_max = max(global_max, layer_distances.max().item())

fig, axes = plt.subplots(nrows=(num_layers + 1) // 4 + 1, ncols=4, figsize=(20, 5 * ((num_layers + 1) // 4 + 1)))
axes = axes.flatten()
for layer_id in tqdm(range(0, num_layers)):
	layer_distances = silhouette_score_matrix[layer_id + 1]
	sns.heatmap(layer_distances.numpy(), 
		xticklabels=[LANGCODE2LANGNAME[lang] for lang in languages], 
		yticklabels=[LANGCODE2LANGNAME[lang] for lang in languages], 
		ax=axes[layer_id], 
		cmap='viridis_r',
		vmin=global_min,
		vmax=global_max,
		# annot=True,
		# fmt='.2f',
	)
	axes[layer_id].set_title(f'Layer {layer_id} Silhouette Scores')
	plt.setp(axes[layer_id].get_xticklabels(), rotation=90, ha='right', rotation_mode='anchor')

	# Setup fontsize of the x and y label
	axes[layer_id].tick_params(axis='x', labelsize=5)
	axes[layer_id].tick_params(axis='y', labelsize=5)
	
	# Add bolded grid lines to separate language families
	# Calculate cumulative positions for family boundaries
	cumulative_langs = 0
	for family, num_langs in lengths_per_category.items():
		cumulative_langs += num_langs
		# Draw horizontal and vertical lines at family boundaries
		axes[layer_id].axhline(cumulative_langs, color='black', linewidth=1.5)
		axes[layer_id].axvline(cumulative_langs, color='black', linewidth=1.5)

	# if layer_id + 1 == 3:
	# 	break
	
plt.tight_layout()
plt.savefig(os.path.join(os.path.dirname(sil_path), f'silhouette_scores_heatmap_{model_name}_testing.png'), dpi=300, bbox_inches='tight')
plt.close()

In [None]:
import seaborn as sns

# Make 5 bins of the silhouette scores and color the heatmap accordingly
import matplotlib.pyplot as plt

# Find global min and max across all layers for consistent scale
global_min = float('inf')
global_max = float('-inf')
for layer_id in range(-1, num_layers):
	layer_distances = silhouette_score_matrix[layer_id + 1]
	# Exclude diagonal elements (which are 0)
	mask_no_diag = ~torch.eye(len(languages), dtype=torch.bool, device='cuda')
	non_diag_values = layer_distances[mask_no_diag]
	global_min = min(global_min, non_diag_values.min().item())
	global_max = max(global_max, non_diag_values.max().item())

# Create 3 bins
n_bins = 5
bin_edges = np.linspace(global_min, global_max, n_bins + 1)

# Create subplots
fig, axes = plt.subplots(nrows=(num_layers + 1) // 4 + 1, ncols=4, figsize=(20, 5 * ((num_layers + 1) // 4 + 1)))
axes = axes.flatten()

for layer_id in tqdm(range(0, num_layers)):
	layer_distances = silhouette_score_matrix[layer_id + 1].cpu().numpy()
	
	# Discretize the distances into bins
	binned_distances = np.digitize(layer_distances, bin_edges) - 1
	binned_distances = np.clip(binned_distances, 0, n_bins - 1)
	
	sns.heatmap(binned_distances, 
		xticklabels=[LANGCODE2LANGNAME[lang] for lang in languages], 
		yticklabels=[LANGCODE2LANGNAME[lang] for lang in languages], 
		ax=axes[layer_id], 
		cmap='viridis_r',
		vmin=0,
		vmax=n_bins - 1,
		cbar_kws={'label': 'Bin', 'ticks': range(n_bins)}
	)
	axes[layer_id].set_title(f'Layer {layer_id} Binned Silhouette Scores')
	plt.setp(axes[layer_id].get_xticklabels(), rotation=90, ha='right', rotation_mode='anchor')

	# Setup fontsize of the x and y label
	axes[layer_id].tick_params(axis='x', labelsize=5)
	axes[layer_id].tick_params(axis='y', labelsize=5)
	
	# Add bolded grid lines to separate language families
	cumulative_langs = 0
	for family, num_langs in lengths_per_category.items():
		cumulative_langs += num_langs
		axes[layer_id].axhline(cumulative_langs, color='black', linewidth=1.5)
		axes[layer_id].axvline(cumulative_langs, color='black', linewidth=1.5)
	
plt.tight_layout()
plt.savefig(f'silhouette_scores_binned{n_bins}_heatmap_{model_name}.png', dpi=300, bbox_inches='tight')
plt.close()

print(f"Bin edges: {bin_edges}")