# Multi-Family Multi-Model Information Content Benchmark

This notebook performs comprehensive information-theoretic benchmarks on **multiple protein families** using **multiple FoldTree2 models** with different alphabet sizes, comparing them against standard amino acid representations.

## Purpose

Evaluate how well discrete structural representations (DSRs) capture biologically relevant information compared to amino acid sequences. Four questions are tested:

1. **Fold Discrimination**: Do DSR k-mers separate protein folds better than amino acid k-mers?
2. **Entropy Rate**: How compressible are the representations?
3. **Position-Specific Conservation**: Does entropy reflect functional constraints?
4. **Cross-Representation Information**: How well do structural features predict sequence features?

## Family Data Sources

This notebook supports **two family data sources**:

### Option 1: Custom Families (USE_MAKESUBMAT_FAMILIES = False)
Manual curation of protein families in separate folders:
```
FAMILIES_DIR/
	├── rhodopsin/
	│   ├── structure1.pdb
	│   └── structure2.pdb
	├── kinase/
	│   ├── structure3.pdb
	│   └── structure4.pdb
	└── ...
```

### Option 2: AFDB Cluster Families (USE_MAKESUBMAT_FAMILIES = True) ⭐ RECOMMENDED
**Uses the same 200+ families as `makesubmat.py`**, taken from AlphaFold Database clusters:
```
struct_align/              # Created by makesubmat.py
	├── AF-A0A024RBG1-F1/   # AFDB cluster representative ID
	│   └── structs/
	│       ├── AF-A0A024RBG1-F1-model_v4.pdb
	│       ├── AF-P12345-F1-model_v4.pdb
	│       └── ...
	├── AF-A0A075B6H9-F1/
	│   └── structs/
	│       └── ...
	└── ... (200+ families)
```

**Prerequisites for Option 2**:
1. Run `makesubmat.py --download_structs` to fetch AFDB cluster structures
2. Point `MAKESUBMAT_BASE_DIR` to your datasets directory
3. Set `USE_MAKESUBMAT_FAMILIES = True` in Cell #3

**Advantages of Option 2**:
- **Consistency**: Benchmark on same data used for substitution matrices
- **Scale**: 200+ diverse protein families, discovered automatically
- **Reproducibility**: Standard dataset from AlphaFold Database clusters
- **No manual curation**: Families pre-defined by structural clustering

## Benchmarks Performed (Per Family, Per Model)

### 1. **K-mer Fold Discrimination**
- Compute k-mer frequency distributions (k=1,2,3,4)
- Perform KMeans clustering
- Calculate silhouette scores as discrimination metric
- **Higher scores** = better fold separation
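
The feature-building step above can be sketched as follows. This is a minimal illustration, not the notebook's own code: `kmer_count_matrix` and the toy strings are hypothetical names and data. The point it shows is that every sequence must be counted against one shared k-mer vocabulary, so that a given column means the same k-mer in every row before clustering.

```python
from collections import Counter

def kmer_count_matrix(seqs, k):
    """Return (vocab, rows): one count vector per sequence over a shared k-mer vocabulary."""
    # Count overlapping k-mers in each sequence
    counters = [Counter(s[i:i + k] for i in range(len(s) - k + 1)) for s in seqs]
    # A sorted union of all observed k-mers fixes the column order
    vocab = sorted(set().union(*counters))
    # Counter returns 0 for k-mers a sequence does not contain
    rows = [[c[kmer] for kmer in vocab] for c in counters]
    return vocab, rows

# Toy structural-alphabet strings (hypothetical data)
toy_seqs = ["ABAB", "ABBA", "BBBB"]
vocab, rows = kmer_count_matrix(toy_seqs, k=2)
print(vocab)  # ['AB', 'BA', 'BB']
print(rows)   # [[2, 1, 0], [1, 1, 1], [0, 0, 3]]
```

The resulting rows could then be passed to `sklearn.cluster.KMeans` and scored with `sklearn.metrics.silhouette_score`, which requires at least two clusters and fewer clusters than samples.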

### 2. **Entropy Rate Estimation**  
- Build k-order Markov models (orders 0-3)
- Cross-validation for entropy estimation
- **Lower entropy** = more predictable/compressible representation
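
As a rough sketch of the estimate described above, the function below fits an order-`order` Markov model on training sequences and reports the cross-entropy of held-out sequences in bits per symbol. The name `markov_cross_entropy` and the toy data are assumptions for illustration; `alpha` plays the role of `alpha_smoothing` in `BENCHMARK_PARAMS`. Order 0 is handled by the same code path because the context is then the empty string.

```python
import math
from collections import Counter

def markov_cross_entropy(train_seqs, test_seqs, order, alphabet, alpha=0.001):
    """Bits per symbol of test_seqs under an order-`order` Markov model fit on train_seqs."""
    pair_counts = Counter()     # (context, next symbol) -> count
    context_counts = Counter()  # context -> count
    for seq in train_seqs:
        for i in range(len(seq) - order):
            context = seq[i:i + order]
            pair_counts[(context, seq[i + order])] += 1
            context_counts[context] += 1

    total_bits, n_symbols = 0.0, 0
    for seq in test_seqs:
        for i in range(len(seq) - order):
            context = seq[i:i + order]
            # Additive smoothing keeps unseen transitions at non-zero probability
            p = (pair_counts[(context, seq[i + order])] + alpha) / (
                context_counts[context] + alpha * len(alphabet))
            total_bits -= math.log2(p)
            n_symbols += 1
    return total_bits / n_symbols if n_symbols else float("nan")

# Toy data (hypothetical): a strictly alternating two-letter sequence
train, test = ["ABABABAB"], ["ABAB"]
h0 = markov_cross_entropy(train, test, order=0, alphabet="AB")
h1 = markov_cross_entropy(train, test, order=1, alphabet="AB")
print(f"order 0: {h0:.3f} bits/symbol, order 1: {h1:.3f} bits/symbol")
```

On the alternating toy sequence the order-0 model needs about one bit per symbol, while the order-1 model needs almost none because each symbol is determined by its predecessor. In a k-fold setting the function would be called once per fold and the results averaged.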

### 3. **Per-Position Entropy**
- Create Multiple Sequence Alignments (MSAs)
- Apply Henikoff sequence reweighting
- Calculate position-specific entropy
- **Entropy profile** reveals conserved vs variable positions
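
The weighting and entropy steps can be sketched as below. This is an illustrative version under stated assumptions: `henikoff_weights` and `column_entropies` are hypothetical helper names, gaps are treated as an ordinary symbol when computing weights, and the non-gap weight in each column is renormalized to sum to 1 before the entropy is taken. The 0.3 gap cutoff corresponds to an occupancy threshold of 0.7.

```python
import math
from collections import Counter

def henikoff_weights(msa):
    """Position-based sequence weights, normalized to sum to 1."""
    weights = [0.0] * len(msa)
    for column in zip(*msa):
        counts = Counter(column)
        for i, symbol in enumerate(column):
            # Rare symbols in diverse columns contribute more weight
            weights[i] += 1.0 / (len(counts) * counts[symbol])
    total = sum(weights)
    return [w / total for w in weights]

def column_entropies(msa, weights, gap="-", max_gap_fraction=0.3):
    """Weighted Shannon entropy (bits) per column; None where gaps exceed the cutoff."""
    entropies = []
    for column in zip(*msa):
        mass = Counter()
        for symbol, w in zip(column, weights):
            mass[symbol] += w
        gap_mass = mass.pop(gap, 0.0)
        if gap_mass > max_gap_fraction:
            entropies.append(None)
            continue
        occupied = 1.0 - gap_mass  # renormalize over non-gap symbols
        entropies.append(-sum((m / occupied) * math.log2(m / occupied)
                              for m in mass.values() if m > 0))
    return entropies

# Toy alignment (hypothetical): two identical sequences and one divergent one
toy_msa = ["AAC-", "AAC-", "ABDE"]
toy_weights = henikoff_weights(toy_msa)
toy_entropy = column_entropies(toy_msa, toy_weights)
print(toy_weights)
print(toy_entropy)
```

In the toy alignment the divergent third sequence receives the largest weight, the fully conserved first column has zero entropy, and the last column is skipped because most of its weight falls on gaps.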

### 4. **Cross-Representation Mutual Information**
- Train Ridge regression: FoldTree2 features → AA features
- Cross-validation R² scores
- Spearman correlation between representations
- **Higher scores** = more shared information
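
For readers unfamiliar with the Spearman score used here, it is the Pearson correlation of the ranks of two variables, so it rewards any monotonic relationship rather than only a linear one. The sketch below is a dependency-free illustration with hypothetical helper names; the notebook itself relies on `scipy.stats.spearmanr`.

```python
import math

def average_ranks(values):
    """1-based ranks, with tied values sharing their average rank."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        # Extend j over the run of values tied with position i
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        avg = (i + j) / 2.0 + 1.0
        for idx in order[i:j + 1]:
            ranks[idx] = avg
        i = j + 1
    return ranks

def pearson(x, y):
    """Pearson correlation; NaN if either input is constant."""
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    vx = sum((a - mx) ** 2 for a in x)
    vy = sum((b - my) ** 2 for b in y)
    if vx == 0 or vy == 0:
        return float("nan")
    return cov / math.sqrt(vx * vy)

def spearman(x, y):
    """Spearman rank correlation: Pearson correlation of the ranks."""
    return pearson(average_ranks(x), average_ranks(y))

# A monotonic but non-linear relationship still scores (approximately) 1
rho = spearman([1, 2, 3, 4], [1, 4, 9, 16])
print(rho)
```

A constant column, such as a k-mer that never varies across a family, yields NaN, which is why the benchmark drops NaN correlations before averaging.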

## Outputs

1. **JSON**: `all_families_results.json` - Complete results
2. **CSV**: `all_families_results.csv` - Flattened table
3. **Plots**:
	 - `multi_family_benchmark_comparison.png` - Line plots across families
	 - `model_family_heatmaps.png` - Performance heatmaps

## How to Use

1. **Configure** (Cell #3):
	 - **Choose data source**: Set `USE_MAKESUBMAT_FAMILIES = True` (AFDB) or `False` (custom)
	 - **Set paths**: 
		 - If using AFDB: `MAKESUBMAT_BASE_DIR` (datasets directory)
		 - If custom: `FAMILIES_DIR` (your families directory)
	 - **Optional**: Set `MAX_FAMILIES` to limit number (useful for testing)
	 - Set `MODELS` to list of FoldTree2 model paths
	 - Adjust `BENCHMARK_PARAMS` if needed

2. **Run cells sequentially** (Cells 1-9)

3. **View results**:
	 - Terminal output shows progress and summary statistics
	 - Plots display comparative performance
	 - CSV/JSON files for detailed analysis

## Quick Start with AFDB Families

```python
# In Cell #3:
USE_MAKESUBMAT_FAMILIES = True
MAKESUBMAT_BASE_DIR = '/path/to/your/datasets'  # Contains struct_align/
MAX_FAMILIES = 50  # Start with 50 families for quick test (or None for all)

# Then run all cells
```

In [3]:
import os
import sys
import tqdm

import glob
import torch
import numpy as np
import pandas as pd
from pathlib import Path
from typing import List, Dict, Tuple
import matplotlib.pyplot as plt
import seaborn as sns
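from collections import Counter, defaultdict
# Counter, KMeans and silhouette_score are used by the benchmark cell
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score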
import warnings
warnings.filterwarnings('ignore')

# Import FoldTree2 modules
sys.path.insert(0, '/home/dmoi/projects/foldtree2')
from foldtree2.src import encoder as ft2
from foldtree2.src.pdbgraph import PDB2PyG, StructureDataset
from torch_geometric.loader import DataLoader  # torch_geometric.data.DataLoader is deprecated in PyG 2.x
from Bio import SeqIO
from Bio.PDB import PDBParser

# Set matplotlib style
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")

print("✓ Imports successful")

✓ Imports successful


In [4]:
# ==================== CONFIGURATION ====================

from pathlib import Path

# Model paths (list of models to benchmark)
MODELS = [
	'/home/dmoi/projects/foldtree2/models/model_10_embeddings',
	'/home/dmoi/projects/foldtree2/models/model_20_embeddings',
	'/home/dmoi/projects/foldtree2/models/model_30_embeddings',
	'/home/dmoi/projects/foldtree2/models/model_40_embeddings',
]

# ==================== FAMILY DATA SOURCE CONFIGURATION ====================
# Choose one of two modes:
#   1. Custom families in separate folders (USE_MAKESUBMAT_FAMILIES = False)
#   2. AFDB cluster families from makesubmat (USE_MAKESUBMAT_FAMILIES = True)

USE_MAKESUBMAT_FAMILIES = True  # Set to True to use makesubmat AFDB families

if USE_MAKESUBMAT_FAMILIES:
	# Path to the struct_align directory created by makesubmat.py
	# This contains one subfolder per AFDB cluster representative
	# Structure: MAKESUBMAT_BASE_DIR/struct_align/{repId}/structs/*.pdb
	MAKESUBMAT_BASE_DIR = '/mnt/data2/datasets'
	FAMILIES_DIR = os.path.join(MAKESUBMAT_BASE_DIR, 'struct_align')
	
	# Optional: Limit number of families (useful for quick tests)
	# Set to None to use all available families
	MAX_FAMILIES = None  # or e.g., 50 for first 50 families
	
	print(f"Using makesubmat AFDB cluster families from: {FAMILIES_DIR}")
else:
	# Custom families directory containing subfolders for each protein family
	# Each subfolder should contain PDB files for that family
	# Example structure:
	#   FAMILIES_DIR/
	#     rhodopsin/
	#       structure1.pdb
	#       structure2.pdb
	#     kinase/
	#       structure3.pdb
	#       structure4.pdb
	FAMILIES_DIR = '/home/dmoi/projects/foldtree2/families/examples'
	MAX_FAMILIES = None
	
	print(f"Using custom families from: {FAMILIES_DIR}")

# Output directory for results
OUTPUT_DIR = '/home/dmoi/projects/foldtree2/benchmark_output'

# Device configuration
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Benchmark parameters
BENCHMARK_PARAMS = {
	'k_mer_sizes': [1, 2, 3, 4],
	'markov_orders': [0, 1, 2, 3],
	'cv_folds': 5,
	'alpha_smoothing': 0.001,
	'occupancy_threshold': 0.7,
	'reweight_threshold': 0.8,
}

# Standard amino acid alphabet
AA_ALPHABET = list('ACDEFGHIKLMNPQRSTVWY')

print(f"\nConfiguration loaded:")
print(f"  Models: {len(MODELS)}")
print(f"  Families directory: {FAMILIES_DIR}")
print(f"  Output: {OUTPUT_DIR}")
print(f"  Device: {DEVICE}")

# ==================== FAMILY DISCOVERY ====================

def discover_families(families_dir, use_makesubmat=False, max_families=None):
	"""
	Discover protein families from directory structure.
	
	Supports two modes:
	1. Custom families: Each subfolder is a family
	2. Makesubmat AFDB: Each subfolder contains a 'structs' subdirectory
	
	Args:
		families_dir: Base directory containing families
		use_makesubmat: If True, expect makesubmat structure (repId/structs/)
		max_families: Maximum number of families to include (None = all)
	
	Returns:
		dict: {family_name: {'path': str, 'n_structures': int}}
	"""
	families = {}
	
	if not os.path.isdir(families_dir):
		print(f"ERROR: Families directory not found: {families_dir}")
		return families
	
	print(f"Scanning directories in {families_dir}...")
	
	# Use os.scandir for efficient directory traversal
	family_count = 0
	with tqdm.tqdm(desc="Discovering families") as pbar:
		for entry in os.scandir(families_dir):
			if not entry.is_dir():
				continue
			
			# Check if we've reached the max_families limit
			if max_families is not None and family_count >= max_families:
				break
			
			family_base_path = entry.path
			
			# Determine the actual structures directory
			if use_makesubmat:
				# Makesubmat structure: {repId}/structs/*.pdb
				structs_dir = os.path.join(family_base_path, 'structs')
				if not os.path.isdir(structs_dir):
					pbar.update(1)
					continue
				family_path = structs_dir
			else:
				# Custom structure: {family_name}/*.pdb
				family_path = family_base_path
			
			# Count PDB files using os.scandir (much faster than glob)
			pdb_count = 0
			try:
				for file_entry in os.scandir(family_path):
					if file_entry.is_file() and file_entry.name.endswith('.pdb'):
						pdb_count += 1
			except (PermissionError, OSError):
				pbar.update(1)
				continue
			
			# Only include if there are structures
			if pdb_count > 0:
				families[entry.name] = {
					'path': family_path,
					'n_structures': pdb_count
				}
				family_count += 1
			
			pbar.update(1)
	
	return families

# Discover families
FAMILIES = discover_families(
	FAMILIES_DIR, 
	use_makesubmat=USE_MAKESUBMAT_FAMILIES,
	max_families=MAX_FAMILIES
)

print(f"\nDiscovered {len(FAMILIES)} protein families:")
if len(FAMILIES) == 0:
	print(f"ERROR: No families found!")
	if USE_MAKESUBMAT_FAMILIES:
		print(f"Expected structure: {FAMILIES_DIR}/{{repId}}/structs/*.pdb")
		print(f"Make sure you've run makesubmat.py with --download_structs first")
	else:
		print(f"Expected structure: {FAMILIES_DIR}/{{family_name}}/*.pdb")
else:
	# Show first 10 families as sample
	sample_families = list(FAMILIES.items())[:10]
	for family_name, family_info in sample_families:
		print(f"  - {family_name}: {family_info['n_structures']} structures")
	
	if len(FAMILIES) > 10:
		print(f"  ... and {len(FAMILIES) - 10} more families")
	
	# Show statistics
	structure_counts = [info['n_structures'] for info in FAMILIES.values()]
	print(f"\nFamily statistics:")
	print(f"  Total families: {len(FAMILIES)}")
	print(f"  Total structures: {sum(structure_counts)}")
	print(f"  Structures per family (mean ± std): {np.mean(structure_counts):.1f} ± {np.std(structure_counts):.1f}")
	print(f"  Range: {min(structure_counts)} - {max(structure_counts)}")

Using makesubmat AFDB cluster families from: /mnt/data2/datasets/struct_align

Configuration loaded:
  Models: 4
  Families directory: /mnt/data2/datasets/struct_align
  Output: /home/dmoi/projects/foldtree2/benchmark_output
  Device: cuda
Scanning directories in /mnt/data2/datasets/struct_align...


Discovering families: 2508it [00:15, 159.18it/s]


KeyboardInterrupt: 

In [None]:
from directory_tree import DisplayTree

# ==================== CLEANUP: REMOVE EMPTY DIRECTORIES ====================

def cleanup_empty_directories(base_dir, use_makesubmat=False, dry_run=True):
	"""
	Remove directories that don't contain any structure files.
	
	Args:
		base_dir: Base directory to scan
		use_makesubmat: If True, look for structs/ subdirectory
		dry_run: If True, only report what would be deleted (don't actually delete)
	
	Returns:
		tuple: (n_deleted, list_of_deleted_paths)
	"""
	if not os.path.isdir(base_dir):
		print(f"ERROR: Directory not found: {base_dir}")
		return 0, []
	
	deleted_dirs = []
	print(f"Scanning for empty directories in: {base_dir}")
	print(f"Mode: {'DRY RUN (no deletion)' if dry_run else 'DELETION ENABLED'}")
	
	rmcount = 0
	with tqdm.tqdm(desc="Scanning directories") as pbar:
		for entry in os.scandir(base_dir):
			if not entry.is_dir():
				continue
			
			# Determine the actual structures directory
			if use_makesubmat:
				# Check if structs/ subdirectory exists
				structs_dir = os.path.join(entry.path, 'structs')
				if not os.path.isdir(structs_dir):
					# No structs directory - mark for deletion
					check_dir = entry.path
					has_structures = False
				else:
					check_dir = structs_dir
					# Count PDB files in structs/
					has_structures = False
					try:
						for file_entry in os.scandir(check_dir):
							if file_entry.is_file() and file_entry.name.endswith('.pdb'):
								has_structures = True
								break
					except (PermissionError, OSError):
						# Unreadable directory: treat as non-empty so it is never deleted
						has_structures = True
			else:
				# Check for PDB files directly in directory
				check_dir = entry.path
				has_structures = False
				try:
					for file_entry in os.scandir(check_dir):
						if file_entry.is_file() and file_entry.name.endswith('.pdb'):
							has_structures = True
							break
				except (PermissionError, OSError):
					# Unreadable directory: treat as non-empty so it is never deleted
					has_structures = True
			
			# If no structures found, mark for deletion
			if not has_structures:
				deleted_dirs.append(entry.path)
				if not dry_run:
					import shutil
					try:
						shutil.rmtree(entry.path)
						#print(f"  ✓ Deleted: {entry.name}")
					except Exception as e:
						print(f"  ✗ Failed to delete {entry.name}: {e}")
				else:
					print(f"  [DRY RUN] Would delete: {entry.name}")
					# Print the directory tree structure
					DisplayTree(entry.path, maxDepth=3, showHidden=True)

				rmcount += 1
				if dry_run and rmcount > 30:
					print("  ... Dry run limit reached (30). Stopping further checks.")
					break
			pbar.update(1)
	
	print(f"\n{'='*70}")
	print(f"Summary:")
	print(f"  Empty directories found: {len(deleted_dirs)}")
	if not dry_run:
		print(f"  Directories deleted: {len(deleted_dirs)}")
	else:
		print(f"  (Set dry_run=False to actually delete)")
	print(f"{'='*70}")
	
	return len(deleted_dirs), deleted_dirs

# Run the cleanup. WARNING: dry_run=False permanently deletes directories.
# Run once with dry_run=True first to preview what would be removed.

n_deleted, deleted_paths = cleanup_empty_directories(
	FAMILIES_DIR,
	use_makesubmat=USE_MAKESUBMAT_FAMILIES,
	dry_run=False  # Set to True for a preview without deletion
)


print("✓ Cleanup utility loaded")
print("  Run cleanup_empty_directories() to scan for empty directories")
print("  Set dry_run=False to actually delete them")

Scanning for empty directories in: /mnt/data2/datasets/struct_align
Mode: DELETION ENABLED


Scanning directories: 1110161it [4:22:56, 59.40it/s] 

In [None]:
# ==================== DATA LOADING & ENCODING ====================

class ModelBenchmark:
	"""Container for model encoding results and metadata"""
	def __init__(self, model_path: str, device):
		self.model_path = model_path
		self.model_name = Path(model_path).stem
		self.device = device
		
		# Load encoder
		print(f"  Loading model: {self.model_name}")
		self.encoder = torch.load(model_path + '.pt', map_location=device, weights_only=False)
		self.encoder = self.encoder.to(device)
		self.encoder.device = device
		self.encoder.eval()
		
		# Extract model metadata
		self.num_embeddings = self.encoder.num_embeddings
		self.embedding_dim = self.encoder.out_channels
		
		# Storage for encoded sequences
		self.encoded_fasta = None
		self.encoded_df = None
		self.alphabet = None
		self.char_position_map = None
		
		print(f"    ✓ Loaded: {self.num_embeddings} embeddings, dim={self.embedding_dim}")
	
	def encode_structures(self, structures_loader, output_dir, family_name="encoded"):
		"""Encode structures using this model"""
		output_path = os.path.join(output_dir, f"{family_name}_{self.model_name}_encoded.fasta")
		
		print(f"  Encoding structures with {self.model_name}...")
		self.encoder.encode_structures_fasta(
			structures_loader, 
			output_path, 
			replace=True
		)
		
		self.encoded_fasta = output_path
		self.encoded_df = ft2.load_encoded_fasta(output_path, alphabet=None, replace=False)
		self._build_alphabet()
		
		print(f"    ✓ Encoded {len(self.encoded_df)} sequences")
		return output_path
	
	def _build_alphabet(self):
		"""Build alphabet from encoded sequences"""
		char_set = set()
		for seq in self.encoded_df.seq:
			char_set.update(seq)
		self.alphabet = sorted(char_set)
		self.char_position_map = {char: i for i, char in enumerate(self.alphabet)}
		print(f"    ✓ Alphabet size: {len(self.alphabet)} characters")

def load_structures(structures_dir: str, converter: PDB2PyG, verbose: bool = True):
	"""Load and convert PDB structures to PyG format"""
	pdb_files = glob.glob(os.path.join(structures_dir, "*.pdb"))
	if verbose:
		print(f"  Found {len(pdb_files)} PDB files")
	
	if len(pdb_files) == 0:
		raise ValueError(f"No PDB files found in {structures_dir}")
	
	def struct_generator():
		for pdb_file in pdb_files:
			try:
				data = converter.struct2pyg(pdb_file)
				if data:
					yield data
			except Exception as e:
				if verbose:
					print(f"    Warning: Failed to convert {Path(pdb_file).name}: {e}")
				continue
	
	return struct_generator()

def extract_aa_sequences(structures_dir: str, output_dir: str, family_name: str = "sequences", verbose: bool = True):
	"""Extract amino acid sequences from PDB files"""
	parser = PDBParser(QUIET=True)
	aa_dict = {
		'ALA': 'A', 'ARG': 'R', 'ASN': 'N', 'ASP': 'D', 'CYS': 'C',
		'GLN': 'Q', 'GLU': 'E', 'GLY': 'G', 'HIS': 'H', 'ILE': 'I',
		'LEU': 'L', 'LYS': 'K', 'MET': 'M', 'PHE': 'F', 'PRO': 'P',
		'SER': 'S', 'THR': 'T', 'TRP': 'W', 'TYR': 'Y', 'VAL': 'V'
	}
	
	pdb_files = glob.glob(os.path.join(structures_dir, "*.pdb"))
	sequences = {}
	
	if verbose:
		print(f"  Extracting amino acid sequences from {len(pdb_files)} structures...")
	for pdb_file in pdb_files:
		structure_id = Path(pdb_file).stem
		try:
			structure = parser.get_structure(structure_id, pdb_file)
			seq = ""
			for model in structure:
				for chain in model:
					for residue in chain:
						if residue.get_resname() in aa_dict:
							seq += aa_dict[residue.get_resname()]
					break  # Only first chain
				break  # Only first model
			
			if seq:
				sequences[structure_id] = seq
		except Exception as e:
			if verbose:
				print(f"    Warning: Failed to extract sequence from {Path(pdb_file).name}: {e}")
	
	# Write to FASTA with family-specific naming
	output_path = os.path.join(output_dir, f"{family_name}_aa_sequences.fasta")
	with open(output_path, 'w') as f:
		for struct_id, seq in sequences.items():
			f.write(f">{struct_id}\n{seq}\n")
	
	if verbose:
		print(f"    ✓ Extracted {len(sequences)} AA sequences")
	return output_path, sequences

# Create output directory
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Initialize converter
converter = PDB2PyG(aapropcsv='/home/dmoi/projects/foldtree2/foldtree2/config/aaindex1.csv')

print("✓ Setup complete")

In [None]:
# ==================== RUN ALL BENCHMARKS PER FAMILY ====================

# Store results per family
all_results = {}

for family_name, family_info in FAMILIES.items():
	print(f"\n{'='*70}")
	print(f"FAMILY: {family_name}")
	print(f"{'='*70}")
	print(f"  {family_info['n_structures']} structures in {family_info['path']}")
	
	family_path = family_info['path']
	family_results = {
		'n_structures': family_info['n_structures'],
		'models': [],
		'alphabet_sizes': [],
		'k_mer_discrimination': {},
		'entropy_rates': {},
		'per_position_entropy': {},
		'cross_representation_mi': {}
	}
	
	# Extract AA sequences for this family
	print(f"\n1. Extracting amino acid sequences for {family_name}...")
	aa_fasta_path, aa_sequences = extract_aa_sequences(
		family_path, OUTPUT_DIR, family_name, verbose=True
	)
	
	# Structures are converted inside the model loop below: load_structures
	# returns a one-shot generator, so each model needs a fresh one
	print(f"\n2. Loading structures for {family_name}...")
	
	# Encode structures with each model
	model_benchmarks = []
	for model_path in MODELS:
		print(f"\n3. Processing model: {Path(model_path).stem}")
		mb = ModelBenchmark(model_path, DEVICE)
		
		# Load and encode structures
		structures_loader = load_structures(family_path, converter, verbose=True)
		mb.encode_structures(structures_loader, OUTPUT_DIR, family_name)
		model_benchmarks.append(mb)
		
		family_results['models'].append(mb.model_name)
		family_results['alphabet_sizes'].append(mb.num_embeddings)
	
	print(f"\n{'─'*70}")
	print(f"BENCHMARKING {family_name}")
	print(f"{'─'*70}")
	
	# ==================== BENCHMARK 1: K-MER FOLD DISCRIMINATION ====================
	print("\n4. K-mer Fold Discrimination Test")
	print("   Clustering k-mer count vectors and scoring cluster separation...")
	
	def kmer_silhouette(seqs, k, n_clusters=3):
		"""Silhouette score of KMeans clusters over k-mer count vectors."""
		counters = [Counter(s[i:i+k] for i in range(len(s)-k+1)) for s in seqs]
		# Shared, sorted vocabulary so each column is the same k-mer in every row
		vocab = sorted(set().union(*counters))
		X = np.array([[c[kmer] for kmer in vocab] for c in counters])
		# silhouette_score needs 2 <= n_clusters <= n_samples - 1
		n_clusters = min(n_clusters, len(X) - 1)
		if n_clusters < 2 or not vocab:
			return float('nan')
		labels = KMeans(n_clusters=n_clusters, random_state=42).fit_predict(X)
		if len(set(labels)) < 2:
			return float('nan')
		return silhouette_score(X, labels)
	
	for mb in model_benchmarks:
		model_results = {f'k={k}': kmer_silhouette(mb.encoded_df.seq, k) for k in [1, 2, 3, 4]}
		family_results['k_mer_discrimination'][mb.model_name] = model_results
		print(f"   {mb.model_name}: k=1: {model_results['k=1']:.3f}, k=2: {model_results['k=2']:.3f}, k=3: {model_results['k=3']:.3f}, k=4: {model_results['k=4']:.3f}")
	
	# Also run for AA sequences
	aa_results = {f'k={k}': kmer_silhouette(aa_sequences.values(), k) for k in [1, 2, 3, 4]}
	family_results['k_mer_discrimination']['AA'] = aa_results
	print(f"   AA: k=1: {aa_results['k=1']:.3f}, k=2: {aa_results['k=2']:.3f}, k=3: {aa_results['k=3']:.3f}, k=4: {aa_results['k=4']:.3f}")
	
	# ==================== BENCHMARK 2: ENTROPY RATE ESTIMATION ====================
	print("\n5. Entropy Rate Estimation")
	print("   Estimating entropy rate using k-order Markov models...")
	
	for mb in model_benchmarks:
		model_results = {}
		sequences = list(mb.encoded_df.seq)
		
		for order in [0, 1, 2, 3]:
			# k-fold cross-validation
			n_folds = min(5, len(sequences))
			fold_size = len(sequences) // n_folds
			entropies = []
			
			for fold in range(n_folds):
				# Split data
				test_start = fold * fold_size
				test_end = test_start + fold_size
				test_seqs = sequences[test_start:test_end]
				train_seqs = sequences[:test_start] + sequences[test_end:]
				
				# Build k-order Markov model from training data
				if order == 0:
					# 0-order: character frequencies
					char_counts = Counter()
					for seq in train_seqs:
						char_counts.update(seq)
					total = sum(char_counts.values())
					probs = {c: (count + 0.001) / (total + 0.001 * len(char_counts)) 
							for c, count in char_counts.items()}
					
					# Calculate entropy on test set
					test_entropy = 0
					for seq in test_seqs:
						for c in seq:
							if c in probs:
								test_entropy += -np.log2(probs[c])
							else:
								test_entropy += -np.log2(0.001 / (total + 0.001 * len(char_counts)))
					test_entropy /= sum(len(s) for s in test_seqs)
				else:
					# k-order: k-gram frequencies
					kgram_counts = Counter()
					context_counts = Counter()
					for seq in train_seqs:
						for i in range(len(seq) - order):
							context = seq[i:i+order]
							next_char = seq[i+order]
							kgram_counts[(context, next_char)] += 1
							context_counts[context] += 1
					
					# Calculate entropy on test set
					test_entropy = 0
					for seq in test_seqs:
						for i in range(len(seq) - order):
							context = seq[i:i+order]
							next_char = seq[i+order]
							# Additive smoothing with alpha = 0.001
							count = kgram_counts.get((context, next_char), 0) + 0.001
							total = context_counts.get(context, 0) + 0.001 * len(mb.alphabet)
							prob = count / total
							test_entropy += -np.log2(prob)
					test_entropy /= sum(len(s) - order for s in test_seqs)
				
				entropies.append(test_entropy)
			
			model_results[f'order={order}'] = np.mean(entropies)
		
		family_results['entropy_rates'][mb.model_name] = model_results
		print(f"   {mb.model_name}: 0-order: {model_results['order=0']:.3f}, 1-order: {model_results['order=1']:.3f}, 2-order: {model_results['order=2']:.3f}, 3-order: {model_results['order=3']:.3f}")
	
	# Also run for AA sequences
	aa_alphabet = set('ACDEFGHIKLMNPQRSTVWY')
	aa_results = {}
	sequences = list(aa_sequences.values())
	
	for order in [0, 1, 2, 3]:
		n_folds = min(5, len(sequences))
		fold_size = len(sequences) // n_folds
		entropies = []
		
		for fold in range(n_folds):
			test_start = fold * fold_size
			test_end = test_start + fold_size
			test_seqs = sequences[test_start:test_end]
			train_seqs = sequences[:test_start] + sequences[test_end:]
			
			if order == 0:
				char_counts = Counter()
				for seq in train_seqs:
					char_counts.update(seq)
				total = sum(char_counts.values())
				probs = {c: (count + 0.001) / (total + 0.001 * len(char_counts)) 
						for c, count in char_counts.items()}
				
				test_entropy = 0
				for seq in test_seqs:
					for c in seq:
						if c in probs:
							test_entropy += -np.log2(probs[c])
						else:
							test_entropy += -np.log2(0.001 / (total + 0.001 * len(char_counts)))
				test_entropy /= sum(len(s) for s in test_seqs)
			else:
				kgram_counts = Counter()
				context_counts = Counter()
				for seq in train_seqs:
					for i in range(len(seq) - order):
						context = seq[i:i+order]
						next_char = seq[i+order]
						kgram_counts[(context, next_char)] += 1
						context_counts[context] += 1
				
				test_entropy = 0
				for seq in test_seqs:
					for i in range(len(seq) - order):
						context = seq[i:i+order]
						next_char = seq[i+order]
						count = kgram_counts.get((context, next_char), 0) + 0.001
						total = context_counts.get(context, 0) + 0.001 * len(aa_alphabet)
						prob = count / total
						test_entropy += -np.log2(prob)
				test_entropy /= sum(len(s) - order for s in test_seqs)
			
			entropies.append(test_entropy)
		
		aa_results[f'order={order}'] = np.mean(entropies)
	
	family_results['entropy_rates']['AA'] = aa_results
	print(f"   AA: 0-order: {aa_results['order=0']:.3f}, 1-order: {aa_results['order=1']:.3f}, 2-order: {aa_results['order=2']:.3f}, 3-order: {aa_results['order=3']:.3f}")
	
	# ==================== BENCHMARK 3: PER-POSITION ENTROPY ====================
	print("\n6. Per-Position Entropy")
	print("   Calculating position-specific entropy from MSAs...")
	
	for mb in model_benchmarks:
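		# Build a pseudo-MSA. No alignment is performed here, so columns
		# correspond only where sequences share the same residue numbering
		sequences = list(mb.encoded_df.seq)
		max_len = max(len(s) for s in sequences)
		
		# Right-pad shorter sequences with gaps
		padded = [s + '-' * (max_len - len(s)) for s in sequences]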
		
		# Calculate Henikoff weights
		weights = []
		for i, seq in enumerate(padded):
			weight = 0
			for pos in range(max_len):
				char = seq[pos]
				n_different = len(set(s[pos] for s in padded))
				n_with_char = sum(1 for s in padded if s[pos] == char)
				weight += 1.0 / (n_different * n_with_char)
			weights.append(weight)
		
		# Normalize weights
		total_weight = sum(weights)
		weights = [w / total_weight for w in weights]
		
		# Calculate per-position entropy
		position_entropies = []
		for pos in range(max_len):
			char_weights = {}
			for i, seq in enumerate(padded):
				char = seq[pos]
				char_weights[char] = char_weights.get(char, 0) + weights[i]
			
			# Filter positions with low occupancy
			gap_weight = char_weights.get('-', 0)
			if gap_weight > 0.3:  # Skip if >30% gaps
				continue
			
			# Calculate entropy over non-gap symbols (weights are not renormalized after dropping gaps)
			entropy = 0
			for char, weight in char_weights.items():
				if char != '-' and weight > 0:
					entropy += -weight * np.log2(weight)
			
			position_entropies.append(entropy)
		
		avg_entropy = np.mean(position_entropies) if position_entropies else 0
		family_results['per_position_entropy'][mb.model_name] = {
			'mean': avg_entropy,
			'std': np.std(position_entropies) if position_entropies else 0
		}
		print(f"   {mb.model_name}: mean={avg_entropy:.3f}, std={np.std(position_entropies) if position_entropies else 0:.3f}")
	
	# Also run for AA sequences
	sequences = list(aa_sequences.values())
	max_len = max(len(s) for s in sequences)
	padded = [s + '-' * (max_len - len(s)) for s in sequences]
	
	weights = []
	for i, seq in enumerate(padded):
		weight = 0
		for pos in range(max_len):
			char = seq[pos]
			n_different = len(set(s[pos] for s in padded))
			n_with_char = sum(1 for s in padded if s[pos] == char)
			weight += 1.0 / (n_different * n_with_char)
		weights.append(weight)
	
	total_weight = sum(weights)
	weights = [w / total_weight for w in weights]
	
	position_entropies = []
	for pos in range(max_len):
		char_weights = {}
		for i, seq in enumerate(padded):
			char = seq[pos]
			char_weights[char] = char_weights.get(char, 0) + weights[i]
		
		gap_weight = char_weights.get('-', 0)
		if gap_weight > 0.3:
			continue
		
		entropy = 0
		for char, weight in char_weights.items():
			if char != '-' and weight > 0:
				entropy += -weight * np.log2(weight)
		
		position_entropies.append(entropy)
	
	avg_entropy = np.mean(position_entropies) if position_entropies else 0
	family_results['per_position_entropy']['AA'] = {
		'mean': avg_entropy,
		'std': np.std(position_entropies) if position_entropies else 0
	}
	print(f"   AA: mean={avg_entropy:.3f}, std={np.std(position_entropies) if position_entropies else 0:.3f}")
	
	# ==================== BENCHMARK 4: CROSS-REPRESENTATION MUTUAL INFORMATION ====================
	print("\n7. Cross-Representation Mutual Information")
	print("   Fitting Ridge regressors (FoldTree2 k-mers -> AA k-mers) as an MI proxy...")
	
	for mb in model_benchmarks:
		# Prepare data: pair each FoldTree2 sequence with the AA sequence of the
		# same structure, so rows stay aligned even if some IDs are missing
		pairs = [
			(ft2_seq, aa_sequences[struct_id])
			for struct_id, ft2_seq in zip(mb.encoded_df.protid.values, mb.encoded_df.seq.values)
			if struct_id in aa_sequences
		]
		
		# Count k-mers per sequence
		k = 2
		ft2_counters = [Counter(s[i:i+k] for i in range(len(s)-k+1)) for s, _ in pairs]
		aa_counters = [Counter(s[i:i+k] for i in range(len(s)-k+1)) for _, s in pairs]
		
		# Index columns by a shared, sorted k-mer vocabulary so that each
		# column refers to the same k-mer in every row
		ft2_vocab = sorted(set().union(*ft2_counters))
		aa_vocab = sorted(set().union(*aa_counters))
		
		X_ft2 = np.array([[c[kmer] for kmer in ft2_vocab] for c in ft2_counters])
		X_aa = np.array([[c[kmer] for kmer in aa_vocab] for c in aa_counters])
		
		# Train simple linear predictor: FoldTree2 -> AA
		from sklearn.linear_model import Ridge
		model = Ridge(alpha=1.0)
		
		# Cross-validation
		from sklearn.model_selection import cross_val_score
		scores = cross_val_score(model, X_ft2, X_aa, cv=min(5, len(X_ft2)), 
								 scoring='r2')
		
		# Spearman correlation as an MI proxy (in-sample: the model is refit on all data)
		model.fit(X_ft2, X_aa)
		predictions = model.predict(X_ft2)
		
		# Spearman correlation between predicted and actual
		from scipy.stats import spearmanr
		correlations = []
		for i in range(X_aa.shape[1]):
			corr, _ = spearmanr(predictions[:, i], X_aa[:, i])
			if not np.isnan(corr):
				correlations.append(corr)
		
		avg_correlation = np.mean(correlations) if correlations else 0
		family_results['cross_representation_mi'][mb.model_name] = {
			'r2_score': scores.mean(),
			'spearman_corr': avg_correlation
		}
		print(f"   {mb.model_name}: R²={scores.mean():.3f}, Spearman={avg_correlation:.3f}")
	
	# Store family results
	all_results[family_name] = family_results
	print(f"\n✓ Completed benchmarks for {family_name}")

print(f"\n{'='*70}")
print(f"✓ ALL FAMILIES COMPLETE")
print(f"{'='*70}")
print(f"Processed {len(all_results)} families:")

In [None]:
# ==================== SAVE RESULTS ====================

# Save full results to JSON
results_json_path = os.path.join(OUTPUT_DIR, 'benchmark_results.json')
with open(results_json_path, 'w') as f:
	json.dump(all_results, f, indent=2)
print(f"\n✓ Results saved to: {results_json_path}")

# Create summary DataFrame
summary_data = []
for model_name in all_results['models']:
	row = {
		'model': model_name,
		'alphabet_size': all_results['alphabet_sizes'][all_results['models'].index(model_name)]
	}
	
	# K-mer discrimination
	if model_name in all_results['k_mer_discrimination']:
		for k_key, k_data in all_results['k_mer_discrimination'][model_name].items():
			row[f'kmers_{k_key}_aa_sil'] = k_data['aa_silhouette']
			row[f'kmers_{k_key}_enc_sil'] = k_data['encoded_silhouette']
	
	# Entropy rates
	if model_name in all_results['entropy_rates']:
		for order_key, order_data in all_results['entropy_rates'][model_name].items():
			row[f'entropy_{order_key}_aa'] = order_data['aa_entropy_rate']
			row[f'entropy_{order_key}_enc'] = order_data['encoded_entropy_rate']
	
	# Per-position entropy
	if model_name in all_results['per_position_entropy'] and all_results['per_position_entropy'][model_name]:
		pos_data = all_results['per_position_entropy'][model_name]
		row['pos_entropy_aa_mean'] = pos_data['aa_mean_entropy']
		row['pos_entropy_enc_mean'] = pos_data['encoded_mean_entropy']
	
	# Cross-representation MI
	if model_name in all_results['cross_representation_mi'] and all_results['cross_representation_mi'][model_name]:
		mi_data = all_results['cross_representation_mi'][model_name]
		# The benchmark loop stores this value under 'spearman_corr'
		row['cross_rep_correlation'] = mi_data.get('spearman_corr', np.nan)
	
	summary_data.append(row)

summary_df = pd.DataFrame(summary_data)

# Save summary CSV
summary_csv_path = os.path.join(OUTPUT_DIR, 'benchmark_summary.csv')
summary_df.to_csv(summary_csv_path, index=False)
print(f"✓ Summary saved to: {summary_csv_path}")

# Display summary
print("\n" + "="*70)
print("BENCHMARK SUMMARY")
print("="*70)
print(summary_df.to_string(index=False))
print("="*70)

In [None]:
# ==================== VISUALIZATION ====================

fig, axes = plt.subplots(2, 2, figsize=(16, 12))
fig.suptitle('Multi-Model Benchmark Comparison', fontsize=16, fontweight='bold')

# Extract alphabet sizes for x-axis
alphabet_sizes = summary_df['alphabet_size'].values
models = summary_df['model'].values

# Color palette
colors = sns.color_palette("husl", len(models))

# ========== 1. K-mer Discrimination (Silhouette Scores) ==========
ax = axes[0, 0]
k_cols = [col for col in summary_df.columns if 'kmers_k' in col and '_enc_sil' in col]
if k_cols:
	for i, col in enumerate(k_cols):
		k_value = col.split('_')[1]  # Extract k value
		ax.plot(alphabet_sizes, summary_df[col].values, marker='o', 
				label=f'{k_value} (encoded)', linewidth=2, markersize=8)
	
	ax.set_xlabel('Alphabet Size', fontsize=12)
	ax.set_ylabel('Silhouette Score', fontsize=12)
	ax.set_title('K-mer Fold Discrimination', fontsize=14, fontweight='bold')
	ax.legend(fontsize=10)
	ax.grid(True, alpha=0.3)
	ax.set_ylim([-0.1, 1.0])

# ========== 2. Entropy Rate ==========
ax = axes[0, 1]
entropy_cols = [col for col in summary_df.columns if 'entropy_order' in col and '_enc' in col]
if entropy_cols:
	for col in entropy_cols:
		order_value = col.split('_')[1]  # Extract order
		ax.plot(alphabet_sizes, summary_df[col].values, marker='s',
				label=f'{order_value} (encoded)', linewidth=2, markersize=8)
	
	ax.set_xlabel('Alphabet Size', fontsize=12)
	ax.set_ylabel('Entropy Rate (bits)', fontsize=12)
	ax.set_title('Markov Entropy Rate Estimation', fontsize=14, fontweight='bold')
	ax.legend(fontsize=10)
	ax.grid(True, alpha=0.3)

# ========== 3. Per-Position Entropy ==========
ax = axes[1, 0]
if 'pos_entropy_enc_mean' in summary_df.columns:
	ax.scatter(alphabet_sizes, summary_df['pos_entropy_aa_mean'].values,
			   s=150, alpha=0.7, label='AA sequences', marker='o', color='blue')
	ax.scatter(alphabet_sizes, summary_df['pos_entropy_enc_mean'].values,
			   s=150, alpha=0.7, label='Encoded sequences', marker='s', color='orange')
	
	# Add error bars if std available
	if 'pos_entropy_enc_std' in summary_df.columns:
		ax.errorbar(alphabet_sizes, summary_df['pos_entropy_enc_mean'].values,
				   yerr=summary_df['pos_entropy_enc_std'].values, fmt='none', 
				   color='orange', alpha=0.5, capsize=5)
	
	ax.set_xlabel('Alphabet Size', fontsize=12)
	ax.set_ylabel('Mean Positional Entropy (bits)', fontsize=12)
	ax.set_title('Per-Position Entropy (MSA)', fontsize=14, fontweight='bold')
	ax.legend(fontsize=10)
	ax.grid(True, alpha=0.3)

# ========== 4. Model Comparison Table ==========
ax = axes[1, 1]
ax.axis('tight')
ax.axis('off')

# Create comparison table
table_data = []
for idx, row in summary_df.iterrows():
	table_row = [
		row['model'][:20],  # Truncate model name
		f"{row['alphabet_size']}",
	]
	
	# Add best k-mer score
	k_scores = [row[col] for col in summary_df.columns if 'enc_sil' in col and not pd.isna(row[col])]
	table_row.append(f"{max(k_scores):.3f}" if k_scores else "N/A")
	
	# Add mean entropy rate across Markov orders
	ent_scores = [row[col] for col in summary_df.columns if 'entropy_order' in col and '_enc' in col and not pd.isna(row[col])]
	table_row.append(f"{np.mean(ent_scores):.3f}" if ent_scores else "N/A")
	
	table_data.append(table_row)

table = ax.table(cellText=table_data,
				colLabels=['Model', 'Alphabet', 'Best Sil.', 'Avg Entropy'],
				cellLoc='left',
				loc='center',
				colWidths=[0.4, 0.2, 0.2, 0.2])
table.auto_set_font_size(False)
table.set_fontsize(10)
table.scale(1, 2)

# Style header row
for i in range(4):
	table[(0, i)].set_facecolor('#4CAF50')
	table[(0, i)].set_text_props(weight='bold', color='white')

# Alternate row colors
for i in range(1, len(table_data) + 1):
	for j in range(4):
		if i % 2 == 0:
			table[(i, j)].set_facecolor('#f0f0f0')

ax.set_title('Benchmark Summary', fontsize=14, fontweight='bold', pad=20)

plt.tight_layout()
plot_path = os.path.join(OUTPUT_DIR, 'benchmark_comparison.png')
plt.savefig(plot_path, dpi=300, bbox_inches='tight')
print(f"\n✓ Visualization saved to: {plot_path}")
plt.show()

In [None]:
# ==================== CONFIGURATION ====================

# Models to benchmark (list of model paths without .pt extension)
MODELS = [
	"/home/dmoi/projects/foldtree2/models/model_10_embeddings_best_encoder",
	"/home/dmoi/projects/foldtree2/models/model_20_embeddings_best_encoder",
	"/home/dmoi/projects/foldtree2/models/model_30_embeddings_best_encoder",
	"/home/dmoi/projects/foldtree2/models/model_40_embeddings_best_encoder",
]

# Structure directory (PDB files)
STRUCTURES_DIR = "/home/dmoi/projects/foldtree2/alphafold_benchmark/rhodopsin/structs"

# Output directory for encoded sequences
OUTPUT_DIR = "/home/dmoi/projects/foldtree2/benchmark_results"

# Device configuration
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Benchmark parameters
BENCHMARK_PARAMS = {
	'k_mer_sizes': [2, 3, 4, 5],  # k-mer lengths to test
	'markov_orders': [0, 1, 2, 3, 4],  # Markov model orders
	'cv_folds': 5,  # Cross-validation folds
	'alpha_smoothing': 0.1,  # Smoothing parameter
	'occupancy_threshold': 0.7,  # MSA column occupancy threshold
	'reweight_threshold': 0.8,  # Sequence reweighting identity threshold
}

print(f"✓ Configuration loaded")
print(f"  - Models to benchmark: {len(MODELS)}")
print(f"  - Structure directory: {STRUCTURES_DIR}")
print(f"  - Output directory: {OUTPUT_DIR}")
print(f"  - Device: {DEVICE}")

In [None]:
# ==================== DATA LOADING & ENCODING ====================

class ModelBenchmark:
	"""Container for model encoding results and metadata"""
	def __init__(self, model_path: str, device):
		self.model_path = model_path
		self.model_name = Path(model_path).stem
		self.device = device
		
		# Load encoder
		print(f"Loading model: {self.model_name}")
		self.encoder = torch.load(model_path + '.pt', map_location=device, weights_only=False)
		self.encoder = self.encoder.to(device)
		self.encoder.device = device
		self.encoder.eval()
		
		# Extract model metadata
		self.num_embeddings = self.encoder.num_embeddings
		self.embedding_dim = self.encoder.out_channels
		
		# Storage for encoded sequences
		self.encoded_fasta = None
		self.encoded_df = None
		self.alphabet = None
		self.char_position_map = None
		
		print(f"  ✓ Loaded: {self.num_embeddings} embeddings, dim={self.embedding_dim}")
	
	def encode_structures(self, structures_loader, output_dir):
		"""Encode structures using this model"""
		output_path = os.path.join(output_dir, f"{self.model_name}_encoded.fasta")
		
		print(f"  Encoding structures with {self.model_name}...")
		self.encoder.encode_structures_fasta(
			structures_loader, 
			output_path, 
			replace=True
		)
		
		self.encoded_fasta = output_path
		self.encoded_df = ft2.load_encoded_fasta(output_path, alphabet=None, replace=False)
		self._build_alphabet()
		
		print(f"  ✓ Encoded {len(self.encoded_df)} sequences")
		return output_path
	
	def _build_alphabet(self):
		"""Build alphabet from encoded sequences"""
		char_set = set()
		for seq in self.encoded_df.seq:
			char_set = char_set.union(set(seq))
		self.alphabet = sorted(list(char_set))
		self.char_position_map = {char: i for i, char in enumerate(self.alphabet)}
		print(f"  ✓ Alphabet size: {len(self.alphabet)} characters")

def load_structures(structures_dir: str, converter: PDB2PyG):
	"""Load and convert PDB structures to PyG format"""
	pdb_files = glob.glob(os.path.join(structures_dir, "*.pdb"))
	print(f"\nFound {len(pdb_files)} PDB files in {structures_dir}")
	
	if len(pdb_files) == 0:
		raise ValueError(f"No PDB files found in {structures_dir}")
	
	def struct_generator():
		for pdb_file in pdb_files:
			try:
				data = converter.struct2pyg(pdb_file)
				if data:
					yield data
			except Exception as e:
				print(f"  Warning: Failed to convert {pdb_file}: {e}")
				continue
	
	return struct_generator()

def extract_aa_sequences(structures_dir: str, output_dir: str):
	"""Extract amino acid sequences from PDB files"""
	parser = PDBParser(QUIET=True)
	aa_dict = {
		'ALA': 'A', 'ARG': 'R', 'ASN': 'N', 'ASP': 'D', 'CYS': 'C',
		'GLN': 'Q', 'GLU': 'E', 'GLY': 'G', 'HIS': 'H', 'ILE': 'I',
		'LEU': 'L', 'LYS': 'K', 'MET': 'M', 'PHE': 'F', 'PRO': 'P',
		'SER': 'S', 'THR': 'T', 'TRP': 'W', 'TYR': 'Y', 'VAL': 'V'
	}
	
	pdb_files = glob.glob(os.path.join(structures_dir, "*.pdb"))
	sequences = {}
	
	print(f"\nExtracting amino acid sequences from {len(pdb_files)} structures...")
	for pdb_file in pdb_files:
		structure_id = Path(pdb_file).stem
		try:
			structure = parser.get_structure(structure_id, pdb_file)
			seq = ""
			for model in structure:
				for chain in model:
					for residue in chain:
						if residue.get_resname() in aa_dict:
							seq += aa_dict[residue.get_resname()]
					break  # Only first chain
				break  # Only first model
			
			if seq:
				sequences[structure_id] = seq
		except Exception as e:
			print(f"  Warning: Failed to extract sequence from {pdb_file}: {e}")
	
	# Write to FASTA
	output_path = os.path.join(output_dir, "aa_sequences.fasta")
	with open(output_path, 'w') as f:
		for struct_id, seq in sequences.items():
			f.write(f">{struct_id}\n{seq}\n")
	
	print(f"  ✓ Extracted {len(sequences)} AA sequences")
	return output_path, sequences

# Create output directory
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Initialize converter
converter = PDB2PyG(aapropcsv='/home/dmoi/projects/foldtree2/foldtree2/config/aaindex1.csv')

print("✓ Setup complete")

In [None]:
# NOTE: try alignments from FoldMason, FoldTree2, and regular MAFFT.
# The aligner can also be passed as an argument for a particular character model.

## K-mer Frequency Analysis for Fold Discrimination

This experiment analyzes the discriminative power of FoldTree2 (DSR) versus amino acid representations by examining k-mer frequency distributions. The goal is to determine which alphabet better distinguishes between different protein folds.

### Methodology

- **K-mer extraction**: Compute frequency distributions of subsequences of length k for both AA and DSR sequences
- **Within-family distances**: Calculate Jensen-Shannon divergences between k-mer distributions of sequences within the same fold family
- **Between-family distances**: Measure k-mer distribution differences across distinct fold families
- **Discrimination analysis**: Compare the separation between within-family (similar folds) and between-family (different folds) distance distributions
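
The steps above can be sketched on toy data. This is a minimal illustration, not the notebook's benchmark code: the helper names `kmer_distribution` and `js_divergence`, the three-letter alphabet, and the two toy families are all assumptions made for the example.

```python
from itertools import combinations
import numpy as np

def kmer_distribution(seq, k, alphabet, pseudocount=1e-9):
    """Normalized k-mer frequency vector of length len(alphabet)**k."""
    index = {c: i for i, c in enumerate(alphabet)}
    A = len(alphabet)
    counts = np.zeros(A ** k)
    for start in range(len(seq) - k + 1):
        window = seq[start:start + k]
        if all(c in index for c in window):
            code = 0
            for c in window:
                code = code * A + index[c]  # base-A encoding of the k-mer
            counts[code] += 1
    p = counts + pseudocount  # tiny pseudocount keeps every entry positive
    return p / p.sum()

def js_divergence(p, q):
    """Jensen-Shannon divergence in bits (0 for identical distributions)."""
    m = 0.5 * (p + q)
    def kl(a, b):
        return float(np.sum(a * np.log2(a / b)))
    return 0.5 * (kl(p, m) + kl(q, m))

# Toy data (hypothetical): two "families" over a 3-letter alphabet
toy_families = {
    "fam1": ["ABABABAB", "ABABABBA"],
    "fam2": ["CCCACCCA", "CCACCCAC"],
}
alphabet = "ABC"
dists = {fam: [kmer_distribution(s, 2, alphabet) for s in seqs]
         for fam, seqs in toy_families.items()}

# Within-family: all pairs inside one family; between-family: all cross pairs
within = [js_divergence(p, q) for ps in dists.values() for p, q in combinations(ps, 2)]
between = [js_divergence(p, q) for p in dists["fam1"] for q in dists["fam2"]]
mean_within = float(np.mean(within))
mean_between = float(np.mean(between))
print(f"mean within-family JSD:  {mean_within:.3f}")
print(f"mean between-family JSD: {mean_between:.3f}")
```

A representation discriminates well when the between-family distances are clearly larger than the within-family ones; in this toy case the two families share no 2-mers, so the gap is close to the maximum of 1 bit.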

### Key Questions

1. **Fold specificity**: Does the FoldTree2 alphabet capture fold-specific sequence patterns more effectively than amino acid sequences?
2. **Optimal k-mer length**: What subsequence length provides the best discrimination for each representation?
3. **Distribution separation**: Which alphabet shows clearer separation between intra-fold similarity and inter-fold dissimilarity?

The analysis will reveal whether structural alphabets provide enhanced discriminative power for protein fold classification compared to traditional sequence-based representations.

In [None]:
def kmer_freqs(seq, k, alpha_map, A):
	"""Return the k-mer probability vector (length A**k) of seq; uniform if seq has fewer than k valid symbols."""
	idx = [alpha_map[c] for c in seq if c in alpha_map]
	if len(idx) < k: return np.ones(A**k)/(A**k)
	counts = np.zeros(A**k)
	for t in range(len(idx)-k+1):
		code = 0
		for j in range(k):
			code = code*A + idx[t+j]
		counts[code]+=1
	p = counts + 1e-9
	p /= p.sum()
	return p

def jsd(p, q):
	m = 0.5*(p+q)
	def KL(a,b): 
		mask = (a>0)
		return (a[mask]*np.log2(a[mask]/b[mask])).sum()
	return 0.5*(KL(p,m)+KL(q,m))

In [None]:
from typing import List, Tuple, Dict
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score

# ---------------------------- k-mer utils ------------------------------

def build_alpha_index(alphabet: str) -> Dict[str, int]:
	return {c:i for i,c in enumerate(alphabet)}

def kmers_counts(seq: str, k: int, alpha_idx: Dict[str,int], A: int) -> np.ndarray:
	"""Return counts vector of length A^k for overlapping k-mers in seq (skip k-mer if any unseen char)."""
	L = len(seq)
	if L < k or k == 0:
		return np.zeros(A**max(k,1), dtype=np.float64)
	v = np.zeros(A**k, dtype=np.float64)
	code = -1
	for i, ch in enumerate(seq):
		if ch not in alpha_idx:
			code = -1
		else:
			x = alpha_idx[ch]
			if code == -1:
				if i >= k-1:
					ok = True
					code_tmp = 0
					for j in range(i-k+1, i+1):
						c2 = seq[j]
						if c2 not in alpha_idx:
							ok = False; break
						code_tmp = code_tmp * A + alpha_idx[c2]
					if ok:
						code = code_tmp
						v[code] += 1.0
			else:
				code = (code % (A**(k-1))) * A + x
				v[code] += 1.0
	return v

def kmer_prob(seq: str, k: int, alphabet: str, pseudocount: float = 1e-9) -> np.ndarray:
	A = len(alphabet)
	idx = build_alpha_index(alphabet)
	c = kmers_counts(seq, k, idx, A)
	total = c.sum()
	if total == 0:
		# no valid k-mers: fall back to a uniform distribution
		p = np.ones(A**k, dtype=np.float64)
		p /= p.sum()
		return p
	p = (c + pseudocount) / (total + pseudocount * c.shape[0])
	return p

def build_feature_matrix(fasta: List[Tuple[str,str]], alphabet: str, k: int, pseudocount: float):
	ids = [name for name,_ in fasta]
	seqs = [seq.upper() for _,seq in fasta]
	P = np.vstack([kmer_prob(s, k, alphabet, pseudocount=pseudocount) for s in seqs])
	return ids, P


In [None]:

# --------------------------- KMeans + eval -----------------------------

def kmeans_cluster(P: np.ndarray, K: int, n_init: int = 20, max_iter: int = 300, random_state: int = 0):
	# KMeans on probability vectors (Euclidean). For probability geometry, you can also sqrt-transform (Hellinger).
	model = KMeans(n_clusters=K, n_init=n_init, max_iter=max_iter, random_state=random_state)
	labels = model.fit_predict(P)
	return labels, model

def run_rep(
	fasta_path: str, labels_map: Dict[str,str],
	alphabet: str, k: int, pseudocount: float,
	target_clusters: int, random_state: int, n_init: int, max_iter: int
):
	fasta = read_fasta(fasta_path)
	ids, P = build_feature_matrix(fasta, alphabet, k, pseudocount)

	# align labels and filter
	y = []
	keep = []
	for i, sid in enumerate(ids):
		if sid in labels_map:
			y.append(labels_map[sid])
			keep.append(i)
	if not keep:
		raise ValueError("No IDs from FASTA matched labels.tsv")
	ids = [ids[i] for i in keep]
	P = P[keep]
	y = np.array(y)

	uniq = {lab:i for i,lab in enumerate(sorted(set(y)))}
	y_int = np.array([uniq[lab] for lab in y], dtype=int)
	K = target_clusters or len(uniq)

	# KMeans
	labels_pred, model = kmeans_cluster(
		P, K=K, n_init=n_init, max_iter=max_iter, random_state=random_state
	)

	ari = adjusted_rand_score(y_int, labels_pred)
	nmi = normalized_mutual_info_score(y_int, labels_pred)

	return {
		"ids": ids,
		"P": P,
		"y_int": y_int,
		"y_str": y,
		"labels_pred": labels_pred,
		"K": K,
		"ari": ari,
		"nmi": nmi,
	}

## Entropy Rate Estimation with k-order Markov Models

This experiment estimates the global entropy rate using k-order Markov models across multiple protein families. The analysis compares two different sequence representations:

- **AA sequences**: Traditional amino acid sequences using the 20-letter alphabet
- **DSR sequences**: Discrete structure representation using a K-token alphabet

### Methodology

The experiment uses a **backoff smoothing approach** with cross-validation to estimate entropy rates:

1. **k-order Markov modeling**: Models conditional probabilities P(x|context) where context length = k
2. **Backoff smoothing**: Handles sparse data by interpolating between different order models (k-gram → (k-1)-gram → ... → unigram)
3. **5-fold cross-validation**: Splits the sequences of each family into folds, so the entropy rate is always measured on held-out sequences rather than on the training data
4. **Additive smoothing**: Regularizes maximum likelihood estimates with parameter α
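
The interplay of additive smoothing and backoff can be sketched in a few lines. This is a simplified, self-contained illustration under stated assumptions: contexts are plain strings, the function names are hypothetical, and the backoff weight `delta / (delta + n)` is one common interpolation choice rather than the only possible one.

```python
import math
from collections import Counter, defaultdict

def train_counts(seqs, max_order):
    """counts[j][context][symbol] for every order j = 0..max_order."""
    counts = [defaultdict(Counter) for _ in range(max_order + 1)]
    for seq in seqs:
        for j in range(max_order + 1):
            for t in range(j, len(seq)):
                counts[j][seq[t - j:t]][seq[t]] += 1
    return counts

def backoff_prob(counts, ctx, x, A, alpha=0.1, delta=None):
    """Interpolate the order-len(ctx) estimate with all lower orders."""
    delta = A if delta is None else delta
    seen = counts[len(ctx)].get(ctx, Counter())
    n = sum(seen.values())
    p_smooth = (seen.get(x, 0) + alpha) / (n + alpha * A)  # additive smoothing
    if not ctx:
        return p_smooth  # order 0: nothing left to back off to
    gamma = delta / (delta + n)  # rarely seen contexts lean on the shorter context
    return (1 - gamma) * p_smooth + gamma * backoff_prob(counts, ctx[1:], x, A, alpha, delta)

def heldout_bits_per_token(counts, seqs, k, A):
    """Cross-entropy (bits per token) of held-out sequences under the order-k model."""
    total, n = 0.0, 0
    for seq in seqs:
        for t in range(k, len(seq)):
            total -= math.log2(backoff_prob(counts, seq[t - k:t], seq[t], A))
            n += 1
    return total / max(1, n)

# Toy data (hypothetical): a strictly alternating two-letter source
train = ["ABABABABAB", "ABABABAB"]
test = ["ABABAB"]
counts = train_counts(train, max_order=1)
h0 = heldout_bits_per_token(counts, test, k=0, A=2)
h1 = heldout_bits_per_token(counts, test, k=1, A=2)
print(f"order 0: {h0:.3f} bits/token, order 1: {h1:.3f} bits/token")
```

On this alternating toy source the order-0 model sees both letters equally often, while the order-1 model predicts the next letter almost perfectly from the previous one, so its held-out cross-entropy is much lower.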

### Aggregation Strategy

Results are aggregated at two levels:
- **Macro-averaging**: Equal weight per family (family-centric view)
- **Micro-averaging**: Weight by total tokens (sequence-centric view)

This allows comparison of structural vs. sequence-based entropy rates across different context lengths (k=0,1,2,3,4).
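
The difference between the two aggregation levels is easiest to see with numbers. The sketch below uses hypothetical per-family entropy rates and token counts; `aggregate_entropy` is an illustrative helper, not a function from this notebook.

```python
def aggregate_entropy(per_family_bits, per_family_tokens):
    """Return (macro, micro) averages of per-family entropy rates."""
    families = list(per_family_bits)
    # Macro: every family counts once, regardless of size
    macro = sum(per_family_bits[f] for f in families) / len(families)
    # Micro: every token counts once, so large families dominate
    total_tokens = sum(per_family_tokens[f] for f in families)
    micro = sum(per_family_bits[f] * per_family_tokens[f] for f in families) / total_tokens
    return macro, micro

# Hypothetical values: one large low-entropy family, one small high-entropy family
bits = {"famA": 2.0, "famB": 4.0}
tokens = {"famA": 9000, "famB": 1000}
macro, micro = aggregate_entropy(bits, tokens)
print(f"macro = {macro:.2f} bits, micro = {micro:.2f} bits")
```

Here the macro average is 3.0 bits and the micro average is 2.2 bits: the large family pulls the micro average toward its own entropy rate. Reporting both makes it visible when a few large families drive the result.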

In [None]:
# 1) Prepare your data as dict: family -> list of sequences (strings)
# Placeholder entries shown; add one key per family (e.g. "PF00002").
families_AA  = {"PF00001": ["MKT...", "MSS..."]}
families_DSR = {"PF00001": ["QAB...", "QAA..."]}

# 2) Define alphabets
alphabet_AA  = list("ACDEFGHIKLMNPQRSTVWY")       # or include 'X' if you keep it
# K (the number of DSR tokens) must be defined before this line
alphabet_DSR = [chr(i) for i in range(65, 65+K)]   # e.g., 'A'.. for K tokens, or your actual token set

# 3) Run
ks = (0,1,2,3,4)
aa_res, aa_agg   = run_entropy_over_families(families_AA,  alphabet_AA,  k_values=ks, alpha=0.1, delta=None, folds=5)
dsr_res, dsr_agg = run_entropy_over_families(families_DSR, alphabet_DSR, k_values=ks, alpha=0.1, delta=None, folds=5)

# 4) Compare and plot:
#   - aa_agg[k]['macro'] vs dsr_agg[k]['macro']
#   - per-family deltas: {fam: dsr_res[k][fam]-aa_res[k][fam]}

In [None]:
from collections import defaultdict, Counter
import random, math
from typing import List, Dict, Tuple

# ---------- utils
def tokenize(seq: str, alpha_map: Dict[str,int]) -> List[int]:
	return [alpha_map[c] for c in seq if c in alpha_map]

def k_context(stream: List[int], k: int):
	# yields (context_tuple, symbol) skipping first k
	if k == 0:
		for x in stream: yield (), x
	else:
		ctx = []
		for x in stream:
			ctx.append(x)
			if len(ctx) > k:
				yield tuple(ctx[-k-1:-1]), x

def add_counts(counts, stream: List[int], k: int):
	for ctx, x in k_context(stream, k):
		counts[ctx][x] += 1

def build_counts(seqs: List[List[int]], k: int, A: int):
	counts = defaultdict(lambda: Counter())
	total_tokens = 0
	for s in seqs:
		total_tokens += max(0, len(s)-k)
		add_counts(counts, s, k)
	return counts, total_tokens

# ---------- smoothed conditional with simple backoff
class BackoffKModel:
	def __init__(self, counts_k_list, A: int, alpha=0.1, delta=None):
		"""
		counts_k_list: list where idx j holds counts for order j (0..k)
		A: alphabet size
		alpha: additive smoothing for MLE
		delta: backoff strength; if None, set to A (alphabet size)
		"""
		self.counts = counts_k_list
		self.A = A
		self.alpha = alpha
		self.delta = delta if delta is not None else A

	def p_cond(self, ctx: Tuple[int,...], x: int) -> float:
		k = len(ctx)
		return self._p_k(k, ctx, x)

	def _p_k(self, k: int, ctx: Tuple[int,...], x: int) -> float:
		# base: unigram (order 0)
		if k == 0:
			cnts0 = self.counts[0][()]
			num = cnts0.get(x, 0) + self.alpha
			den = sum(cnts0.values()) + self.alpha * self.A
			return num / den

		# order-k MLE with smoothing
		# .get avoids inserting unseen contexts into the defaultdict during evaluation
		cnts_k = self.counts[k].get(ctx, Counter())
		num = cnts_k.get(x, 0) + self.alpha
		den = sum(cnts_k.values()) + self.alpha * self.A
		p_mle = num / den

		# backoff weight
		gamma = self.delta / (self.delta + sum(cnts_k.values()))
		# suffix context
		suffix = ctx[1:]
		return (1 - gamma) * p_mle + gamma * self._p_k(k-1, suffix, x)

# ---------- cross-entropy (held-out)
def cross_entropy_bits(model: BackoffKModel, seqs: List[List[int]], k: int) -> float:
	tot_logloss = 0.0
	tot_tokens = 0
	for s in seqs:
		# iterate tokens with contexts; boundaries reset by per-seq processing
		for ctx, x in k_context(s, k):
			p = model.p_cond(ctx, x)
			tot_logloss += -math.log2(max(p, 1e-300))
			tot_tokens += 1
	return tot_logloss / max(1, tot_tokens)

# ---------- 5-fold CV by sequence
def entropy_rate_cv(seqs: List[List[int]], A: int, k: int, alpha=0.1, delta=None, folds=5, seed=0):
	seqs = list(seqs)  # copy so the caller's list is not reordered between calls
	random.Random(seed).shuffle(seqs)
	if len(seqs) < folds: folds = max(2, len(seqs))
	fold_size = math.ceil(len(seqs)/folds)
	losses = []
	for f in range(folds):
		test = seqs[f*fold_size:(f+1)*fold_size]
		train = seqs[:f*fold_size] + seqs[(f+1)*fold_size:]
		# build counts for orders 0..k
		counts_k_list = []
		for j in range(k+1):
			counts_j, _ = build_counts(train, j, A)
			counts_k_list.append(counts_j)
		model = BackoffKModel(counts_k_list, A, alpha=alpha, delta=delta)
		H = cross_entropy_bits(model, test, k)
		losses.append((H, sum(max(0, len(s)-k) for s in test)))
	# micro-average over folds
	num = sum(H*n for H, n in losses)
	den = sum(n for _, n in losses) or 1
	return num/den

# ---------- run over families and ks
def run_entropy_over_families(
	families: Dict[str, List[str]],
	alphabet: List[str],
	k_values=(0,1,2,3,4),
	alpha=0.1, delta=None, folds=5, seed=0
):
	alpha_map = {c:i for i,c in enumerate(alphabet)}
	A = len(alphabet)

	# tokenize
	fam_tok = {
		fam: [tokenize(s, alpha_map) for s in seqs if len(tokenize(s, alpha_map))>0]
		for fam, seqs in families.items()
	}

	results = {k:{} for k in k_values}

	for k in k_values:
		fam_H = {}
		for fam, seqs in fam_tok.items():
			if len(seqs)==0: continue
			Hk = entropy_rate_cv(seqs, A, k, alpha=alpha, delta=delta, folds=folds, seed=seed)
			fam_H[fam] = Hk
		results[k] = fam_H

	# aggregates
	aggregates = {}
	for k in k_values:
		fam_H = results[k]
		fam_list = list(fam_H.items())
		if not fam_list:
			aggregates[k] = dict(macro=None, micro=None, n_families=0)
			continue
		macro = sum(h for _,h in fam_list)/len(fam_list)
		# micro weight by total tokens (approximate using lengths at this k)
		weights = {fam: sum(max(0, len(s)-k) for s in fam_tok[fam]) for fam,_ in fam_list}
		num = sum(fam_H[fam]*weights[fam] for fam,_ in fam_list)
		den = sum(weights.values()) or 1
		micro = num/den
		aggregates[k] = dict(macro=macro, micro=micro, n_families=len(fam_list))
	return results, aggregates


## Cross-Representation Information Analysis

This experiment investigates the **mutual information** between amino acid (AA) and discrete structure representation (DSR) sequences by training probabilistic mappings in both directions using local sequence windows.

### Approach

- **Windowed prediction**: Train small softmax classifiers (a single linear layer) to predict target tokens from source context windows (e.g., predict AA from 7-token DSR window)
- **Bidirectional mapping**: Learn both DSR→AA and AA→DSR predictors to estimate cross-entropies
- **Information bounds**: Derive mutual information lower bounds using H(target) - H(target|source_window)
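
The information bound can be illustrated with a toy calculation. A predictor's held-out cross-entropy is, in expectation, at least the true conditional entropy, so subtracting it from the target entropy gives a conservative estimate of the mutual information. The sketch below assumes natural-log probabilities as produced by a log-softmax and converts them to bits; the helper names and the toy predictor are hypothetical, and with finite samples both terms are only estimates.

```python
import numpy as np

def entropy_bits(labels, n_classes):
    """Empirical Shannon entropy (bits) of integer labels."""
    counts = np.bincount(labels, minlength=n_classes).astype(float)
    p = counts[counts > 0] / counts.sum()
    return float(-(p * np.log2(p)).sum())

def cross_entropy_bits(log_probs_nats, labels):
    """Mean negative log-likelihood of the true labels, converted from nats to bits."""
    picked = log_probs_nats[np.arange(len(labels)), labels]
    return float(-picked.mean() / np.log(2))

# Toy held-out set (hypothetical): 4 target classes, uniformly distributed
labels = np.array([0, 1, 2, 3] * 25)

# Toy predictor: probability 0.7 on the true class, 0.1 on each other class
probs = np.full((len(labels), 4), 0.1)
probs[np.arange(len(labels)), labels] = 0.7

h_target = entropy_bits(labels, 4)                   # H(target)
h_cond = cross_entropy_bits(np.log(probs), labels)   # estimate of H(target | source window)
mi_lower_bound = h_target - h_cond
print(f"H = {h_target:.3f} bits, CE = {h_cond:.3f} bits, MI >= {mi_lower_bound:.3f} bits")
```

A better predictor lowers the cross-entropy and therefore tightens the bound, which is why the estimate depends on the model class and the window size as well as on the data.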

### Key Questions

1. **Complementarity**: How much structural information is captured in DSR that's not present in AA sequences?
2. **Redundancy**: What fraction of sequence information is already encoded in structural representations?
3. **Context dependence**: How does prediction accuracy vary with window size and MSA position?

The analysis will reveal whether the two representations contain overlapping or complementary information about protein structure and function.

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Cross-representation information (Strategy 3)
- Learn small probabilistic mappers DSR→AA and AA→DSR using windowed tokens
- Estimate per-position and global H(AA), H(DSR), and conditional cross-entropies
- Derive MI lower bounds: I_hat = H - H_hat(cond)
"""

import argparse, math, random
from collections import Counter, defaultdict
from pathlib import Path
from typing import List, Tuple, Dict

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# ----------------------------- FASTA utils -----------------------------

def read_fasta(path: str) -> List[Tuple[str,str]]:
	seqs = []
	name = None
	buf = []
	with open(path, 'r') as f:
		for line in f:
			line = line.strip()
			if not line: 
				continue
			if line.startswith('>'):
				if name is not None:
					seqs.append((name, ''.join(buf)))
				name = line[1:].strip()
				buf = []
			else:
				buf.append(line)
	if name is not None:
		seqs.append((name, ''.join(buf)))
	return seqs

def ensure_same_shape(msa1: List[str], msa2: List[str]):
	assert len(msa1) == len(msa2), "MSA row count mismatch between AA and DSR."
	L1 = len(msa1[0])
	for s in msa1:
		assert len(s) == L1, "AA MSA rows must have equal length."
	for s in msa2:
		assert len(s) == L1, "DSR MSA must have same number of columns as AA MSA."

# ------------------------- Sequence reweighting ------------------------

def seq_identity(a: str, b: str) -> float:
	matches = 0
	comps = 0
	for x, y in zip(a, b):
		if x == '-' or y == '-':
			continue
		comps += 1
		if x == y:
			matches += 1
	if comps == 0: return 0.0
	return matches / comps

def reweight_sequences(msa: List[str], thresh: float = 0.8) -> np.ndarray:
	n = len(msa)
	w = np.zeros(n, dtype=float)
	for i in range(n):
		c = 0
		for j in range(n):
			if seq_identity(msa[i], msa[j]) >= thresh:
				c += 1
		w[i] = 1.0 / max(1, c)
	if w.sum() > 0:
		w *= (n / w.sum())
	return w

# ----------------------- Entropy (empirical) ---------------------------

def shannon_entropy_from_counts(counts: np.ndarray, pseudocount: float = 0.0) -> float:
	A = counts.shape[0]
	total = counts.sum()
	p = (counts + pseudocount) / (total + pseudocount * A)
	p = p[p > 0]
	return float(-(p * np.log2(p)).sum())

def per_position_entropy(msa: List[str], alphabet: List[str], weights: np.ndarray,
						 occupancy_threshold: float = 0.7,
						 pseudocount: float = None) -> Tuple[np.ndarray, np.ndarray]:
	A = len(alphabet)
	if pseudocount is None:
		pseudocount = 1.0 / A
	alpha_index = {c:i for i,c in enumerate(alphabet)}
	L = len(msa[0])
	ent = np.full(L, np.nan, dtype=float)
	mask = np.zeros(L, dtype=bool)

	occ = np.zeros(L, dtype=float)
	for i, seq in enumerate(msa):
		g = np.fromiter((c != '-' for c in seq), dtype=bool, count=L)
		occ += weights[i] * g.astype(float)
	occ /= max(1e-9, weights.sum())

	for t in range(L):
		if occ[t] < occupancy_threshold: continue
		counts = np.zeros(A, dtype=float)
		for i, seq in enumerate(msa):
			ch = seq[t]
			if ch == '-' or ch not in alpha_index: continue
			counts[alpha_index[ch]] += weights[i]
		if counts.sum() <= 0: continue
		ent[t] = shannon_entropy_from_counts(counts, pseudocount=pseudocount)
		mask[t] = True
	return ent, mask

# ------------------------ Windowed samples -----------------------------

def build_window_samples(
	src_msa: List[str], tgt_msa: List[str], weights: np.ndarray,
	src_alpha: List[str], tgt_alpha: List[str],
	pos_mask: np.ndarray, win: int
):
	"""
	Build samples to predict target token at t from a src window around t.
	- Requires no gaps in the src window or target at t.
	- pos_mask picks columns to consider (e.g., occupancy intersection).
	Returns X (one-hot per-window), y (int labels), w (sample weights), pos_idx (column indices).
	"""
	assert win % 2 == 1, "Window must be odd size."
	half = win // 2
	A_src, A_tgt = len(src_alpha), len(tgt_alpha)
	src_idx = {c:i for i,c in enumerate(src_alpha)}
	tgt_idx = {c:i for i,c in enumerate(tgt_alpha)}

	L = len(src_msa[0])
	n = len(src_msa)

	feats = []
	labels = []
	sw = []
	pos_idx = []

	# precompute valid (non-gap and in alphabet) masks
	src_valid = np.array([[ (c!='-' and c in src_idx) for c in row] for row in src_msa], dtype=bool)
	tgt_valid = np.array([[ (c!='-' and c in tgt_idx) for c in row] for row in tgt_msa], dtype=bool)

	for i in range(n):
		w_i = weights[i]
		src_row = src_msa[i]
		tgt_row = tgt_msa[i]
		for t in range(half, L-half):
			if not pos_mask[t]: continue
			if not tgt_valid[i, t]: continue
			# window must be fully valid in src
			if not src_valid[i, t-half:t+half+1].all(): continue

			# one-hot encode window: concat per-position one-hots
			v = np.zeros((win, A_src), dtype=np.float32)
			for j, col in enumerate(range(t-half, t+half+1)):
				s = src_row[col]
				v[j, src_idx[s]] = 1.0
			feats.append(v.reshape(-1))  # (win*A_src,)
			labels.append(tgt_idx[tgt_row[t]])
			sw.append(w_i)
			pos_idx.append(t)

	if len(feats) == 0:
		return (np.zeros((0, win*A_src), dtype=np.float32),
				np.zeros((0,), dtype=np.int64),
				np.zeros((0,), dtype=np.float32),
				np.zeros((0,), dtype=np.int32))
	X = np.stack(feats, axis=0)
	y = np.array(labels, dtype=np.int64)
	w = np.array(sw, dtype=np.float32)
	pos_idx = np.array(pos_idx, dtype=np.int32)
	return X, y, w, pos_idx

# --------------------------- Torch model --------------------------------

class SoftmaxLinear(nn.Module):
	def __init__(self, D_in: int, C_out: int, bias=True):
		super().__init__()
		self.lin = nn.Linear(D_in, C_out, bias=bias)
	def forward(self, x):
		return self.lin(x)  # logits

class NumpyDataset(Dataset):
	def __init__(self, X, y, w):
		self.X = torch.from_numpy(X)
		self.y = torch.from_numpy(y)
		self.w = torch.from_numpy(w)
	def __len__(self):
		return self.X.shape[0]
	def __getitem__(self, idx):
		return self.X[idx], self.y[idx], self.w[idx]

def split_by_sequence_indices(n_seq: int, seed=0, ratios=(0.7, 0.15, 0.15)):
	idx = list(range(n_seq))
	random.Random(seed).shuffle(idx)
	n_train = int(ratios[0]*n_seq)
	n_val   = int(ratios[1]*n_seq)
	train_ids = set(idx[:n_train])
	val_ids   = set(idx[n_train:n_train+n_val])
	test_ids  = set(idx[n_train+n_val:])
	return train_ids, val_ids, test_ids

def mask_samples_by_seqpos(seq_pos_of_sample: List[int], seq_ids_set: set):
	return np.array([sp in seq_ids_set for sp in seq_pos_of_sample], dtype=bool)

def train_softmax(
	X_train, y_train, w_train,
	X_val,   y_val,   w_val,
	D_in, C_out, lr=1e-2, epochs=20, bs=4096, weight_decay=1e-4, device="cpu"
):
	model = SoftmaxLinear(D_in, C_out).to(device)
	opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

	ds_tr = NumpyDataset(X_train, y_train, w_train)
	dl_tr = DataLoader(ds_tr, batch_size=bs, shuffle=True, drop_last=False)

	def eval_ce(X, y, w):
		if X.shape[0] == 0: return float('nan')
		with torch.no_grad():
			X_t = torch.from_numpy(X).to(device)
			y_t = torch.from_numpy(y).to(device)
			w_t = torch.from_numpy(w).to(device)
			logits = model(X_t)
			ce = nn.functional.cross_entropy(logits, y_t, reduction='none')
			return float((ce * w_t).sum().item() / max(1e-9, w_t.sum().item()))

	best_val = float('inf')
	best_state = None
	patience, bad = 4, 0

	for ep in range(epochs):
		model.train()
		for xb, yb, wb in dl_tr:
			xb, yb, wb = xb.to(device), yb.to(device), wb.to(device)
			logits = model(xb)
			ce = nn.functional.cross_entropy(logits, yb, reduction='none')
			loss = (ce * wb).sum() / (wb.sum() + 1e-9)
			opt.zero_grad()
			loss.backward()
			opt.step()

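		if X_val.shape[0] == 0:
			continue  # no validation samples: skip early stopping and train for all epochs
		val_ce = eval_ce(X_val, y_val, w_val)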
		if val_ce < best_val - 1e-5:
			best_val = val_ce
			best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
			bad = 0
		else:
			bad += 1
			if bad >= patience:
				break

	if best_state is not None:
		model.load_state_dict(best_state)
	return model

def predict_log_probs(model: nn.Module, X: np.ndarray, bs=8192, device="cpu") -> np.ndarray:
	if X.shape[0] == 0: return np.zeros((0, model.lin.out_features), dtype=np.float32)
	model.eval()
	out = []
	with torch.no_grad():
		for i in range(0, X.shape[0], bs):
			xb = torch.from_numpy(X[i:i+bs]).to(device)
			logits = model(xb)
			logp = nn.functional.log_softmax(logits, dim=-1)
			out.append(logp.detach().cpu().numpy())
	return np.vstack(out)
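
`predict_log_probs` returns natural-log probabilities, and `eval_ce` reports a sample-weighted cross-entropy in nats. If the entropy values elsewhere in the notebook are in bits (an assumption; they may be in nats throughout), the two can be put on the same scale as in the sketch below. `weighted_cross_entropy_bits` is a hypothetical helper, not part of the notebook, and the toy arrays stand in for real model output.

```python
import numpy as np

def weighted_cross_entropy_bits(logp, y, w):
    """Weighted mean of -log2 p(y_i | x_i), given natural-log probabilities.

    logp: (N, C) array of log-probabilities, e.g. from a log-softmax
    y:    (N,) integer class labels
    w:    (N,) non-negative sample weights
    """
    logp = np.asarray(logp, dtype=np.float64)
    y = np.asarray(y, dtype=np.int64)
    w = np.asarray(w, dtype=np.float64)
    if y.size == 0 or w.sum() <= 0:
        return float("nan")
    # Negative log-likelihood of the true class for every sample (in nats)
    nll_nats = -logp[np.arange(y.size), y]
    # Weighted average, then convert nats -> bits by dividing by ln(2)
    return float((nll_nats * w).sum() / w.sum() / np.log(2.0))

# Toy check: a uniform prediction over 4 classes costs log2(4) = 2 bits per sample,
# regardless of labels and weights.
logp = np.log(np.full((3, 4), 0.25))
y = np.array([0, 1, 3])
w = np.array([1.0, 2.0, 0.5])
ce_bits = weighted_cross_entropy_bits(logp, y, w)
print(f"{ce_bits:.3f} bits")
```

Dividing by `w.sum()` mirrors the normalisation used in `eval_ce`, so heavily weighted samples dominate the average in the same way during training and evaluation.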

In [None]:
# ==================== CROSS-FAMILY RESULTS SUMMARY ====================

print("\n" + "="*70)
print("CROSS-FAMILY BENCHMARK RESULTS SUMMARY")
print("="*70)

# Display results for each family
for family_name in all_results.keys():
	print(f"\n  - {family_name}: {all_results[family_name]['n_structures']} structures")

# Save comprehensive results to JSON
import json

def convert_to_serializable(obj):
	"""Recursively convert numpy scalars and arrays to native Python types for JSON."""
	if isinstance(obj, np.bool_):
		return bool(obj)
	elif isinstance(obj, np.integer):
		return int(obj)
	elif isinstance(obj, np.floating):
		return float(obj)
	elif isinstance(obj, np.ndarray):
		return obj.tolist()
	elif isinstance(obj, dict):
		return {k: convert_to_serializable(v) for k, v in obj.items()}
	elif isinstance(obj, (list, tuple)):
		return [convert_to_serializable(item) for item in obj]
	else:
		return obj

# Convert before opening the file so a failed conversion does not leave an empty file behind
serializable_results = convert_to_serializable(all_results)
results_path = os.path.join(OUTPUT_DIR, 'all_families_results.json')
with open(results_path, 'w') as f:
	json.dump(serializable_results, f, indent=2)

print(f"\n✓ Results saved to: {results_path}")

# Create DataFrame for easy viewing
import pandas as pd

# Collect all results into structured format
summary_data = []

for family_name, family_results in all_results.items():
	for model_name in family_results['models'] + ['AA']:
		row = {
			'Family': family_name,
			'Model': model_name,
			'N_Structures': family_results['n_structures']
		}
		
		# K-mer discrimination scores
		if model_name in family_results['k_mer_discrimination']:
			kmer_results = family_results['k_mer_discrimination'][model_name]
			for k_val in ['k=1', 'k=2', 'k=3', 'k=4']:
				if k_val in kmer_results:
					row[f'KMer_{k_val}'] = kmer_results[k_val]
		
		# Entropy rates
		if model_name in family_results['entropy_rates']:
			entropy_results = family_results['entropy_rates'][model_name]
			for order in ['order=0', 'order=1', 'order=2', 'order=3']:
				if order in entropy_results:
					row[f'Entropy_{order}'] = entropy_results[order]
		
		# Per-position entropy
		if model_name in family_results['per_position_entropy']:
			pp_entropy = family_results['per_position_entropy'][model_name]
			row['PerPos_Entropy_Mean'] = pp_entropy['mean']
			row['PerPos_Entropy_Std'] = pp_entropy['std']
		
		# Cross-representation MI
		if model_name in family_results['cross_representation_mi']:
			mi_results = family_results['cross_representation_mi'][model_name]
			row['CrossMI_R2'] = mi_results['r2_score']
			row['CrossMI_Spearman'] = mi_results['spearman_corr']
		
		summary_data.append(row)

results_df = pd.DataFrame(summary_data)

# Save to CSV
csv_path = os.path.join(OUTPUT_DIR, 'all_families_results.csv')
results_df.to_csv(csv_path, index=False)
print(f"✓ CSV results saved to: {csv_path}")

# Display summary statistics
print("\n" + "─"*70)
print("SUMMARY STATISTICS")
print("─"*70)

# Group by model and show average performance across families
print("\n1. K-MER DISCRIMINATION (k=2) - Average Silhouette Score:")
if 'KMer_k=2' in results_df.columns:
	kmer_summary = results_df.groupby('Model')['KMer_k=2'].mean().sort_values(ascending=False)
	for model, score in kmer_summary.items():
		print(f"   {model:20s}: {score:.4f}")

print("\n2. ENTROPY RATE (1st-order) - Average Entropy:")
if 'Entropy_order=1' in results_df.columns:
	entropy_summary = results_df.groupby('Model')['Entropy_order=1'].mean().sort_values()
	for model, score in entropy_summary.items():
		print(f"   {model:20s}: {score:.4f}")

print("\n3. PER-POSITION ENTROPY - Average Mean Entropy:")
if 'PerPos_Entropy_Mean' in results_df.columns:
	pp_summary = results_df.groupby('Model')['PerPos_Entropy_Mean'].mean().sort_values()
	for model, score in pp_summary.items():
		print(f"   {model:20s}: {score:.4f}")

print("\n4. CROSS-REPRESENTATION MI - Average R² Score:")
if 'CrossMI_R2' in results_df.columns:
	mi_summary = results_df.groupby('Model')['CrossMI_R2'].mean().sort_values(ascending=False)
	for model, score in mi_summary.items():
		print(f"   {model:20s}: {score:.4f}")

print("\n" + "="*70)
print("✓ BENCHMARK COMPLETE")
print("="*70)
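
The recursive conversion in the cell above rebuilds the whole results dictionary before writing it. An alternative, sketched here under the assumption that only dictionary *values* carry numpy types, is to pass a fallback function through the `default` argument of `json.dump`/`json.dumps`, which the encoder calls only for objects it cannot serialize itself. `numpy_default` and `toy_results` are hypothetical names used for illustration.

```python
import json
import numpy as np

def numpy_default(obj):
    """Fallback for json.dump(s): invoked only for objects json cannot encode."""
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    # Anything else is a genuine error; re-raise in the style of the json module
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

# Hypothetical miniature of a results dictionary
toy_results = {
    "famA": {
        "n_structures": np.int64(12),
        "entropy": np.float32(1.5),
        "counts": np.arange(3),
    }
}

encoded = json.dumps(toy_results, default=numpy_default)
decoded = json.loads(encoded)
print(decoded)
```

The `default` hook is not applied to dictionary keys, so numpy-typed keys would still need converting beforehand; the recursive approach has the same limitation unless keys are converted explicitly.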

In [None]:
# ==================== MULTI-FAMILY VISUALIZATION ====================

import matplotlib.pyplot as plt
import seaborn as sns

# Set style
sns.set_style("whitegrid")
plt.rcParams['figure.figsize'] = (18, 12)

fig, axes = plt.subplots(2, 2, figsize=(18, 12))
fig.suptitle('Multi-Family Multi-Model Benchmark Comparison', fontsize=16, fontweight='bold')

# Define colors for models
unique_models = results_df['Model'].unique()
model_colors = {model: color for model, color in zip(unique_models, sns.color_palette("husl", len(unique_models)))}

# ========== 1. K-MER DISCRIMINATION ACROSS FAMILIES ==========
ax = axes[0, 0]

for model in unique_models:
	model_data = results_df[results_df['Model'] == model]
	if 'KMer_k=2' in model_data.columns:
		families = model_data['Family'].values
		scores = model_data['KMer_k=2'].values
		ax.plot(families, scores, marker='o', label=model, 
				color=model_colors[model], linewidth=2, markersize=8, alpha=0.7)

ax.set_xlabel('Family', fontsize=12, fontweight='bold')
ax.set_ylabel('Silhouette Score (k=2)', fontsize=12, fontweight='bold')
ax.set_title('K-mer Fold Discrimination', fontsize=14, fontweight='bold')
ax.legend(loc='best', fontsize=10)
ax.grid(True, alpha=0.3)
ax.tick_params(axis='x', rotation=45)

# ========== 2. ENTROPY RATES ACROSS FAMILIES ==========
ax = axes[0, 1]

for model in unique_models:
	model_data = results_df[results_df['Model'] == model]
	if 'Entropy_order=1' in model_data.columns:
		families = model_data['Family'].values
		entropies = model_data['Entropy_order=1'].values
		ax.plot(families, entropies, marker='s', label=model,
				color=model_colors[model], linewidth=2, markersize=8, alpha=0.7)

ax.set_xlabel('Family', fontsize=12, fontweight='bold')
ax.set_ylabel('Entropy Rate (1st-order)', fontsize=12, fontweight='bold')
ax.set_title('Markov Entropy Estimation', fontsize=14, fontweight='bold')
ax.legend(loc='best', fontsize=10)
ax.grid(True, alpha=0.3)
ax.tick_params(axis='x', rotation=45)

# ========== 3. PER-POSITION ENTROPY ACROSS FAMILIES ==========
ax = axes[1, 0]

if 'PerPos_Entropy_Mean' in results_df.columns:
	for model in unique_models:
		model_data = results_df[results_df['Model'] == model]
		families = model_data['Family'].values
		entropies = model_data['PerPos_Entropy_Mean'].values
		ax.plot(families, entropies, marker='^', label=model,
				color=model_colors[model], linewidth=2, markersize=8, alpha=0.7)

	ax.set_xlabel('Family', fontsize=12, fontweight='bold')
	ax.set_ylabel('Mean Entropy', fontsize=12, fontweight='bold')
	ax.set_title('Per-Position Entropy', fontsize=14, fontweight='bold')
	ax.legend(loc='best', fontsize=10)
	ax.grid(True, alpha=0.3)
	ax.tick_params(axis='x', rotation=45)

# ========== 4. CROSS-REPRESENTATION MI ACROSS FAMILIES ==========
ax = axes[1, 1]

if 'CrossMI_R2' in results_df.columns:
	for model in unique_models:
		if model != 'AA':  # Skip AA baseline for cross-representation
			model_data = results_df[results_df['Model'] == model]
			families = model_data['Family'].values
			r2_scores = model_data['CrossMI_R2'].values
			ax.plot(families, r2_scores, marker='D', label=model,
					color=model_colors[model], linewidth=2, markersize=8, alpha=0.7)

	ax.set_xlabel('Family', fontsize=12, fontweight='bold')
	ax.set_ylabel('R² Score', fontsize=12, fontweight='bold')
	ax.set_title('Cross-Representation MI (FoldTree2→AA)', fontsize=14, fontweight='bold')
	ax.legend(loc='best', fontsize=10)
	ax.grid(True, alpha=0.3)
	ax.tick_params(axis='x', rotation=45)

plt.tight_layout()
fig_path = os.path.join(OUTPUT_DIR, 'multi_family_benchmark_comparison.png')
plt.savefig(fig_path, dpi=300, bbox_inches='tight')
print(f"\n✓ Visualization saved to: {fig_path}")
plt.show()

# ==================== HEATMAP: MODEL PERFORMANCE BY FAMILY ====================

print("\n" + "="*70)
print("HEATMAP: MODEL PERFORMANCE BY FAMILY")
print("="*70)

# Create pivot tables for the heatmaps; each panel is drawn only if its metric column exists
heatmap_metrics = ['KMer_k=2', 'Entropy_order=1', 'PerPos_Entropy_Mean', 'CrossMI_R2']
if any(col in results_df.columns for col in heatmap_metrics):
	fig, axes = plt.subplots(2, 2, figsize=(16, 12))
	fig.suptitle('Model Performance Heatmaps by Family', fontsize=16, fontweight='bold')
	
	# K-mer discrimination heatmap
	if 'KMer_k=2' in results_df.columns:
		ax = axes[0, 0]
		pivot_kmer = results_df.pivot(index='Model', columns='Family', values='KMer_k=2')
		sns.heatmap(pivot_kmer, annot=True, fmt='.3f', cmap='YlGnBu', ax=ax, cbar_kws={'label': 'Score'})
		ax.set_title('K-mer Discrimination (k=2)', fontsize=14, fontweight='bold')
		ax.set_xlabel('Family', fontsize=12)
		ax.set_ylabel('Model', fontsize=12)
	
	# Entropy rate heatmap
	if 'Entropy_order=1' in results_df.columns:
		ax = axes[0, 1]
		pivot_entropy = results_df.pivot(index='Model', columns='Family', values='Entropy_order=1')
		sns.heatmap(pivot_entropy, annot=True, fmt='.3f', cmap='YlOrRd', ax=ax, cbar_kws={'label': 'Entropy'})
		ax.set_title('Entropy Rate (1st-order)', fontsize=14, fontweight='bold')
		ax.set_xlabel('Family', fontsize=12)
		ax.set_ylabel('Model', fontsize=12)
	
	# Per-position entropy heatmap
	if 'PerPos_Entropy_Mean' in results_df.columns:
		ax = axes[1, 0]
		pivot_pp = results_df.pivot(index='Model', columns='Family', values='PerPos_Entropy_Mean')
		sns.heatmap(pivot_pp, annot=True, fmt='.3f', cmap='Purples', ax=ax, cbar_kws={'label': 'Entropy'})
		ax.set_title('Per-Position Entropy', fontsize=14, fontweight='bold')
		ax.set_xlabel('Family', fontsize=12)
		ax.set_ylabel('Model', fontsize=12)
	
	# Cross-MI heatmap
	if 'CrossMI_R2' in results_df.columns:
		ax = axes[1, 1]
		# Filter out AA for cross-representation
		mi_data = results_df[results_df['Model'] != 'AA']
		pivot_mi = mi_data.pivot(index='Model', columns='Family', values='CrossMI_R2')
		sns.heatmap(pivot_mi, annot=True, fmt='.3f', cmap='Greens', ax=ax, cbar_kws={'label': 'R²'})
		ax.set_title('Cross-Representation MI', fontsize=14, fontweight='bold')
		ax.set_xlabel('Family', fontsize=12)
		ax.set_ylabel('Model', fontsize=12)
	
	plt.tight_layout()
	heatmap_path = os.path.join(OUTPUT_DIR, 'model_family_heatmaps.png')
	plt.savefig(heatmap_path, dpi=300, bbox_inches='tight')
	print(f"✓ Heatmaps saved to: {heatmap_path}")
	plt.show()

print("\n" + "="*70)
print("✓ ALL VISUALIZATIONS COMPLETE")
print("="*70)
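
With the ~200+ AFDB cluster families of Option 2, an annotated heatmap with one column per family becomes hard to read, and `DataFrame.pivot` raises an error if any (Model, Family) pair occurs more than once. The sketch below shows one possible workaround, not the notebook's method: `pivot_table` averages duplicate rows, and only the families with the largest spread between models are kept. `pivot_metric`, `max_families`, and the toy data are hypothetical.

```python
import pandas as pd

def pivot_metric(df, value_col, max_families=30):
    """Model x Family table of one metric, limited to the most informative families.

    Duplicate (Model, Family) rows are averaged instead of raising an error.
    Families are ranked by the spread (max - min) of the metric across models.
    """
    table = df.pivot_table(index="Model", columns="Family",
                           values=value_col, aggfunc="mean")
    spread = (table.max(axis=0) - table.min(axis=0)).sort_values(ascending=False)
    return table[spread.index[:max_families]]

# Toy data: model A has two rows for fam1, and fam2 does not separate the models
toy = pd.DataFrame({
    "Model":    ["A", "A", "B", "A", "B"],
    "Family":   ["fam1", "fam1", "fam1", "fam2", "fam2"],
    "KMer_k=2": [0.2, 0.4, 0.9, 0.5, 0.5],
})

small = pivot_metric(toy, "KMer_k=2", max_families=1)
print(small)
```

The resulting table can be passed to `sns.heatmap` in place of the full pivot. Ranking by spread is only one choice; sorting by family size or by the mean score would work the same way.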