# Multi-Variable Phenotype Clustering Analysis
## Chen 2004 Yeast Cell Cycle Model - Whole Network Phenotypes

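This notebook implements the second definition of phenotype, which considers the whole gene regulatory network rather than a single molecule. Each simulation yields the concentrations of all m species at n time points; flattening that matrix gives one (m × n)-dimensional point representing the full system output, and we use BIRCH clustering to group these points into phenotypes.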

In [1]:
# === IMPORTS ===
import tellurium as te
import numpy as np
import matplotlib.pyplot as plt
import random
import os
import roadrunner
import multiprocessing as mp
from multiprocessing import Pool
from sklearn.cluster import Birch
from sklearn.metrics import silhouette_score, pairwise_distances
import pandas as pd
import time
from collections import Counter
import warnings
warnings.filterwarnings('ignore')

roadrunner.Logger.setLevel(roadrunner.Logger.LOG_CRITICAL)
print("✓ All imports loaded")

ModuleNotFoundError: No module named 'sklearn'

In [None]:
# === CONFIGURATION ===
def get_model_path():
    """Locate the Chen 2004 SBML model file on the local machine.

    The known Linux location is tried first, then the macOS one.

    Returns:
        The first candidate path that exists on disk.

    Raises:
        FileNotFoundError: If none of the candidate paths exists.
    """
    candidate_paths = (
        "/home/gijs/Documents/OxfordEvolution/Yeast/Chen/chen2004_biomd56.xml",
        "/Users/gijsbartholomeus/Documents/STUDIE/OxfordEvolution/code/Yeast/Chen/chen2004_biomd56.xml",
    )

    for candidate in candidate_paths:
        if os.path.exists(candidate):
            return candidate

    raise FileNotFoundError("Could not find chen2004_biomd56.xml")

# Configuration
model_path = get_model_path()
# Scaling factors 0.25, 0.50, ..., 2.00 applied to kinetic parameters
# (multiples of 0.25 are represented exactly in binary floating point).
multipliers = [0.25 * step for step in range(1, 9)]
SIMULATION_TIME = 1000
N_TIME_POINTS = 50  # Discretize to 50 time points as per paper
SAMPLE_SIZE = 100_000  # 100K samples for initial testing
DIVERGENCE_THRESHOLD = 25

# BIRCH parameters
BIRCH_THRESHOLD = 0.5  # Distance threshold for clustering
BIRCH_BRANCHING_FACTOR = 50
BIRCH_N_CLUSTERS = None  # Let BIRCH determine automatically

# Echo the effective configuration so the run is self-documenting.
print(
    "✓ Configuration loaded",
    f"   Model: {model_path}",
    f"   Sample size: {SAMPLE_SIZE:,}",
    f"   Time points: {N_TIME_POINTS}",
    sep="\n",
)

In [None]:
# === CORE FUNCTIONS ===
def get_kinetic_parameters(rr):
    """Return the IDs of global parameters that are treated as kinetic.

    A parameter is skipped when its name/value pattern marks it as a
    regulatory switch or structural constant (see ``_is_excluded``).

    Args:
        rr: Model object exposing ``getGlobalParameterIds()`` and
            ``getValue(pid)``.

    Returns:
        List of parameter IDs, in the order the model reports them.
    """

    def _is_excluded(pid, value):
        """Tell whether a parameter matches any of the exclusion rules."""
        name = pid.lower()
        is_binary = value in (0.0, 1.0)

        if name.endswith('t') and is_binary:
            return True
        if name.startswith('d') and name.endswith('n'):
            return True
        if 'flag' in name or 'switch' in name:
            return True
        if value == 0.0:
            return True
        if pid == 'cell':  # case-sensitive match on the raw ID
            return True
        return 'total' in name and is_binary

    return [
        pid
        for pid in rr.getGlobalParameterIds()
        if not _is_excluded(pid, rr.getValue(pid))
    ]

def sample_parameters(rr, wildtype=False):
    """Reset the model and rescale each kinetic parameter by a random factor.

    Args:
        rr: Model object; it is reset with ``resetAll()`` and then mutated
            in place through ``setValue``.
        wildtype: When true, no parameter is changed and a genotype of all
            ones is returned.

    Returns:
        List with one multiplier per kinetic parameter (the "genotype").
        A parameter whose read/write raises ``RuntimeError`` is recorded
        with factor 1.0.
    """
    rr.resetAll()
    kinetic_params = get_kinetic_parameters(rr)

    # Wildtype: leave the freshly reset model untouched.
    if wildtype:
        return [1.0] * len(kinetic_params)

    genotype = []
    for pid in kinetic_params:
        try:
            baseline = rr.getValue(pid)
            factor = random.choice(multipliers)
            rr.setValue(pid, baseline * factor)
        except RuntimeError:
            # Parameter could not be read or written; record it as unscaled.
            factor = 1.0
        genotype.append(factor)

    return genotype

def simulate_full_network(rr):
    """Simulate every floating species of the model and screen the result.

    Args:
        rr: Model object; its ``selections`` attribute is overwritten with
            ``"time"`` followed by all floating species IDs.

    Returns:
        A 2-tuple whose first element encodes the outcome:

        * ``(None, None)`` if the integrator raised ``RuntimeError``.
        * ``("divergent", None)`` if any value is non-finite or exceeds
          ``DIVERGENCE_THRESHOLD`` in absolute value.
        * ``("negative", None)`` if any concentration is below zero.
        * ``(species_ids, concentrations)`` on success, where
          ``concentrations`` has one row per output time point and one
          column per species (the time column is dropped).
    """
    species_ids = rr.getFloatingSpeciesIds()

    # Record time plus every floating species.
    rr.selections = ["time"] + list(species_ids)

    try:
        result = rr.simulate(0, SIMULATION_TIME, N_TIME_POINTS + 1)
    except RuntimeError:
        return None, None

    # Column 0 is time; the remaining columns are species concentrations.
    concentrations = result[:, 1:]

    # NaN compares False against both thresholds below, so it must be
    # rejected explicitly; otherwise it would reach scaling and clustering.
    if not np.all(np.isfinite(concentrations)):
        return "divergent", None

    # Check for divergence in any species
    if np.any(np.abs(concentrations) > DIVERGENCE_THRESHOLD):
        return "divergent", None

    # Check for negative concentrations (unphysical)
    if np.any(concentrations < 0):
        return "negative", None

    return species_ids, concentrations

print("✓ Core functions defined")

In [None]:
# === DATA COLLECTION ===
def collect_network_data(n_samples=1000):
    """Simulate ``n_samples`` randomly perturbed models and gather trajectories.

    The first sample (index 0) is the wildtype; every other sample has its
    kinetic parameters rescaled by ``sample_parameters``. Each successful
    simulation is flattened into one vector of length
    ``n_timepoints * n_species``.

    Args:
        n_samples: Number of simulations to attempt. Zero or a negative
            value attempts nothing and returns ``None``.

    Returns:
        ``None`` if no simulation succeeded, otherwise a dict with keys
        ``'trajectories'`` (2-D array, one row per successful sample),
        ``'genotypes'`` (list of multiplier lists), ``'species_names'``,
        ``'success_count'``, ``'divergent_count'`` and ``'negative_count'``.
    """
    print(f"Collecting {n_samples:,} network trajectories...")

    all_trajectories = []
    all_genotypes = []
    species_names = None

    success_count = 0
    divergent_count = 0
    negative_count = 0

    start_time = time.time()

    for i in range(n_samples):
        if (i + 1) % 1000 == 0:
            elapsed = time.time() - start_time
            # Guard against a zero reading from a coarse clock.
            rate = (i + 1) / elapsed if elapsed > 0 else float("inf")
            print(f"  Progress: {i+1:,}/{n_samples:,} ({rate:.1f}/s) | Success: {success_count:,}")

        # Load fresh model and sample parameters.
        # NOTE(review): sample_parameters() already calls rr.resetAll(); reusing
        # one loaded model might be much faster — confirm resetAll() restores
        # every perturbed parameter before changing this.
        rr = te.loadSBMLModel(model_path)
        genotype = sample_parameters(rr, wildtype=(i == 0))  # First sample is wildtype

        # Simulate full network
        species_ids, concentrations = simulate_full_network(rr)

        # A string in the first slot is a rejection reason.
        if isinstance(species_ids, str):
            if species_ids == "divergent":
                divergent_count += 1
            elif species_ids == "negative":
                negative_count += 1
            continue

        # None means the integrator itself failed; counted as divergent.
        if species_ids is None:
            divergent_count += 1
            continue

        # Store species names from first successful run
        if species_names is None:
            species_names = species_ids
            print(f"  Network has {len(species_names)} species: {species_names[:5]}...")

        # Flatten (n_timepoints, m_species) -> (n_timepoints * m_species,)
        all_trajectories.append(concentrations.flatten())
        all_genotypes.append(genotype)
        success_count += 1

    elapsed_total = time.time() - start_time
    # Avoid ZeroDivisionError when nothing was attempted.
    success_rate = success_count / n_samples * 100 if n_samples > 0 else 0.0

    print(f"\n✅ Data collection completed:")
    print(f"   Success: {success_count:,}/{n_samples:,} ({success_rate:.1f}%)")
    print(f"   Divergent: {divergent_count:,}")
    print(f"   Negative: {negative_count:,}")
    print(f"   Total time: {elapsed_total:.1f}s")

    if success_count == 0:
        return None

    trajectory_matrix = np.array(all_trajectories)
    print(f"   Trajectory matrix shape: {trajectory_matrix.shape}")
    print(f"   (samples × (timepoints × species): {success_count} × ({N_TIME_POINTS+1} × {len(species_names)}))")

    return {
        'trajectories': trajectory_matrix,
        'genotypes': all_genotypes,
        'species_names': species_names,
        'success_count': success_count,
        'divergent_count': divergent_count,
        'negative_count': negative_count,
    }

# Collect the data: attempts SAMPLE_SIZE simulations. network_data is None
# when no simulation succeeded; later cells check for that before using it.
network_data = collect_network_data(SAMPLE_SIZE)

In [None]:
# === BIRCH CLUSTERING ===
# Standardize the flattened trajectories, cluster them with BIRCH and store
# the labels, cluster count and cluster sizes back into network_data.
if network_data is not None:
    print("\n🔬 Applying BIRCH clustering...")

    trajectories = network_data['trajectories']

    # Normalize data (important for Euclidean distance)
    from sklearn.preprocessing import StandardScaler
    scaler = StandardScaler()
    trajectories_normalized = scaler.fit_transform(trajectories)

    print(f"   Data shape: {trajectories_normalized.shape}")
    print(f"   Data normalized: mean={np.mean(trajectories_normalized):.3f}, std={np.std(trajectories_normalized):.3f}")

    # Apply BIRCH clustering
    start_time = time.time()

    birch = Birch(
        n_clusters=BIRCH_N_CLUSTERS,
        threshold=BIRCH_THRESHOLD,
        branching_factor=BIRCH_BRANCHING_FACTOR
    )

    cluster_labels = birch.fit_predict(trajectories_normalized)

    clustering_time = time.time() - start_time

    # Analyze clustering results
    n_clusters = len(np.unique(cluster_labels))
    cluster_counts = Counter(cluster_labels)
    cluster_sizes = list(cluster_counts.values())

    print(f"\n✅ BIRCH clustering completed:")
    print(f"   Clustering time: {clustering_time:.1f}s")
    print(f"   Number of clusters (phenotypes): {n_clusters:,}")
    print(f"   Largest cluster: {max(cluster_sizes):,} genotypes")
    print(f"   Smallest cluster: {min(cluster_sizes):,} genotypes")
    print(f"   Average cluster size: {np.mean(cluster_sizes):.1f}")

    # Calculate silhouette score (if computationally feasible)
    n_clustered = len(trajectories_normalized)
    if n_clustered <= 10000:  # Only for smaller datasets
        # silhouette_score raises ValueError unless the number of labels is
        # between 2 and n_samples - 1, so skip the degenerate cases.
        if 1 < n_clusters < n_clustered:
            silhouette = silhouette_score(trajectories_normalized, cluster_labels)
            print(f"   Silhouette score: {silhouette:.3f}")
        else:
            print("   Silhouette score: undefined (needs between 2 and n_samples - 1 clusters)")

    network_data['cluster_labels'] = cluster_labels
    network_data['n_clusters'] = n_clusters
    network_data['cluster_counts'] = cluster_counts

else:
    print("❌ No data collected - skipping clustering")

In [None]:
# === CLUSTER VALIDATION ===
# Compare pairwise Euclidean distances within clusters against distances
# between clusters on (a random subset of) the standardized trajectories.
if network_data is not None and 'cluster_labels' in network_data:
    print("\n🔍 Validating cluster separation...")

    trajectories_normalized = scaler.transform(network_data['trajectories'])
    cluster_labels = network_data['cluster_labels']

    # Sample subset for distance calculations (computational efficiency)
    max_samples_for_validation = 5000
    if len(trajectories_normalized) > max_samples_for_validation:
        indices = np.random.choice(len(trajectories_normalized), max_samples_for_validation, replace=False)
        sample_trajectories = trajectories_normalized[indices]
        sample_labels = cluster_labels[indices]
        print(f"   Validation on {max_samples_for_validation:,} random samples")
    else:
        sample_trajectories = trajectories_normalized
        sample_labels = cluster_labels

    # Calculate pairwise distances
    distances = pairwise_distances(sample_trajectories, metric='euclidean')

    # Select each unordered pair (i < j) once via the strict upper triangle,
    # then split the pairs by whether both points share a cluster label.
    # Boolean-mask indexing replaces a pure-Python loop over all pairs.
    labels_column = np.asarray(sample_labels)
    n_validation = len(labels_column)
    upper_triangle = np.triu(np.ones((n_validation, n_validation), dtype=bool), k=1)
    same_cluster = labels_column[:, None] == labels_column[None, :]

    intra_cluster_distances = distances[upper_triangle & same_cluster]
    inter_cluster_distances = distances[upper_triangle & ~same_cluster]

    # Statistics
    if intra_cluster_distances.size and inter_cluster_distances.size:
        mean_intra = np.mean(intra_cluster_distances)
        mean_inter = np.mean(inter_cluster_distances)
        std_intra = np.std(intra_cluster_distances)
        std_inter = np.std(inter_cluster_distances)

        separation_ratio = mean_inter / mean_intra

        print(f"\n📊 Cluster validation results:")
        print(f"   Intra-cluster distances: {mean_intra:.3f} ± {std_intra:.3f}")
        print(f"   Inter-cluster distances: {mean_inter:.3f} ± {std_inter:.3f}")
        print(f"   Separation ratio: {separation_ratio:.2f}")

        if separation_ratio > 1.5:
            print(f"   ✅ Good cluster separation (ratio > 1.5)")
        elif separation_ratio > 1.0:
            print(f"   ⚠️  Moderate cluster separation (ratio > 1.0)")
        else:
            print(f"   ❌ Poor cluster separation (ratio ≤ 1.0)")

        network_data['validation'] = {
            'mean_intra_distance': mean_intra,
            'mean_inter_distance': mean_inter,
            'separation_ratio': separation_ratio,
            'n_intra_comparisons': int(intra_cluster_distances.size),
            'n_inter_comparisons': int(inter_cluster_distances.size)
        }
    else:
        # Happens when every sampled point is its own cluster, or when all
        # sampled points fall into a single cluster.
        print("   ⚠️  No intra- or inter-cluster pairs in the sample - separation not assessed")

else:
    print("❌ No clustering results - skipping validation")

In [None]:
# === PHENOTYPE ANALYSIS ===
# Plot and summarize how genotypes are distributed over clusters (phenotypes).
if network_data is not None and 'cluster_labels' in network_data:
    print("\n📈 Analyzing phenotype distribution...")

    cluster_counts = network_data['cluster_counts']
    cluster_labels = network_data['cluster_labels']
    n_clusters = network_data['n_clusters']
    total_genotypes = network_data['success_count']

    # Cluster sizes, largest first
    frequencies = sorted(cluster_counts.values(), reverse=True)

    # Plot cluster size distribution
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    # Plot 1: Cluster size histogram
    axes[0].hist(frequencies, bins=50, alpha=0.7, color='skyblue', edgecolor='black')
    axes[0].set_xlabel('Cluster Size (Number of Genotypes)')
    axes[0].set_ylabel('Number of Clusters')
    axes[0].set_title(f'Cluster Size Distribution\n{n_clusters:,} clusters total')
    axes[0].set_yscale('log')
    axes[0].grid(True, alpha=0.3)

    # Plot 2: Rank-frequency plot
    ranks = np.arange(1, len(frequencies) + 1)
    axes[1].loglog(ranks, frequencies, 'bo-', markersize=3, alpha=0.7)
    axes[1].set_xlabel('Phenotype Rank')
    axes[1].set_ylabel('Phenotype Frequency')
    axes[1].set_title('Rank-Frequency Distribution\n(Log-log scale)')
    axes[1].grid(True, alpha=0.3)

    # Plot 3: Cumulative distribution
    cumulative_genotypes = np.cumsum(frequencies)
    cumulative_fraction = cumulative_genotypes / total_genotypes
    axes[2].semilogx(ranks, cumulative_fraction, 'ro-', markersize=3, alpha=0.7)
    axes[2].set_xlabel('Number of Most Frequent Clusters')
    axes[2].set_ylabel('Fraction of All Genotypes')
    axes[2].set_title('Cumulative Genotype Distribution')
    axes[2].grid(True, alpha=0.3)
    axes[2].axhline(y=0.5, color='gray', linestyle='--', alpha=0.7, label='50% of genotypes')
    axes[2].legend()

    plt.tight_layout()
    plt.savefig('birch_clustering_analysis.png', dpi=300, bbox_inches='tight')
    plt.show()

    # Summary statistics
    print(f"\n📊 Phenotype distribution summary:")
    print(f"   Total genotypes analyzed: {total_genotypes:,}")
    print(f"   Total phenotypes (clusters): {n_clusters:,}")
    print(f"   Largest phenotype: {frequencies[0]:,} genotypes ({frequencies[0]/total_genotypes*100:.1f}%)")
    print(f"   Smallest phenotype: {frequencies[-1]:,} genotypes")
    print(f"   Singleton phenotypes: {sum(1 for f in frequencies if f == 1):,}")
    print(f"   Average phenotype size: {np.mean(frequencies):.1f} genotypes")
    print(f"   Median phenotype size: {np.median(frequencies):.1f} genotypes")

    # Find number of clusters containing 50% of genotypes.
    # Compare against the exact half (total / 2) rather than the floored
    # total // 2, which for odd totals would accept slightly less than 50%.
    # side='left' gives the first rank whose cumulative count is >= half; the
    # last cumulative value equals the total, so the index is always valid.
    clusters_for_half = int(np.searchsorted(cumulative_genotypes, total_genotypes / 2, side='left')) + 1

    print(f"   Top {clusters_for_half:,} clusters contain 50% of all genotypes")
    print(f"   Cluster diversity: {clusters_for_half/n_clusters*100:.1f}% of clusters needed for 50% coverage")

else:
    print("❌ No clustering results - skipping phenotype analysis")

In [None]:
# === EXAMPLE PHENOTYPES ===
# Draw a 3x3 grid: rows are the largest, a mid-ranked and the smallest
# cluster; columns are the first three species; each panel overlays up to
# five randomly chosen member trajectories.
if network_data is None or 'cluster_labels' not in network_data:
    print("❌ No clustering results - skipping example visualization")

else:
    print("\n🎨 Visualizing example phenotypes...")

    cluster_counts = network_data['cluster_counts']
    cluster_labels = network_data['cluster_labels']
    trajectories = network_data['trajectories']
    species_names = network_data['species_names']

    # Clusters ordered by size, biggest first (ties keep insertion order)
    sorted_clusters = cluster_counts.most_common()

    largest_cluster_id = sorted_clusters[0][0]
    medium_cluster_id = sorted_clusters[len(sorted_clusters)//2][0]  # middle of the ranking
    smallest_cluster_id = sorted_clusters[-1][0]

    clusters_to_plot = [
        (cid, label, cluster_counts[cid])
        for cid, label in (
            (largest_cluster_id, "Largest"),
            (medium_cluster_id, "Medium"),
            (smallest_cluster_id, "Smallest"),
        )
    ]

    fig, axes = plt.subplots(3, 3, figsize=(18, 12))
    fig.suptitle('Example Phenotypes: Network Trajectories by Cluster Size', fontsize=16)

    # Only the first three species are shown
    key_species = species_names[:3]
    n_timepoints = N_TIME_POINTS + 1
    n_species = len(species_names)

    time_points = np.linspace(0, SIMULATION_TIME, n_timepoints)

    for row, (cluster_id, cluster_type, cluster_size) in enumerate(clusters_to_plot):
        # Members of this cluster
        cluster_indices = np.flatnonzero(cluster_labels == cluster_id)

        # Draw at most five members without replacement
        n_examples = min(5, len(cluster_indices))
        example_indices = np.random.choice(cluster_indices, n_examples, replace=False)

        # Undo the flattening once per example: (timepoints, species)
        example_matrices = [
            trajectories[example_idx].reshape(n_timepoints, n_species)
            for example_idx in example_indices
        ]

        for col, species in enumerate(key_species):
            species_idx = species_names.index(species)
            panel = axes[row, col]

            for trajectory_matrix in example_matrices:
                panel.plot(time_points, trajectory_matrix[:, species_idx], alpha=0.7, linewidth=1)

            panel.set_title(f'{cluster_type} Cluster: {species}\n({cluster_size:,} genotypes)')
            panel.set_xlabel('Time (min)')
            panel.set_ylabel('Concentration')
            panel.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('example_phenotype_trajectories.png', dpi=300, bbox_inches='tight')
    plt.show()

    print("✅ Example phenotype visualizations saved")
    for cluster_id, cluster_type, cluster_size in clusters_to_plot:
        print(f"   {cluster_type} cluster ({cluster_id}): {cluster_size:,} genotypes")

In [None]:
# === FINAL SUMMARY ===
# Print a recap of the dataset, the clustering outcome, the validation
# metrics (when present) and the genotype-to-phenotype statistics.
if network_data is not None:
    print("\n" + "="*80)
    print("MULTI-VARIABLE PHENOTYPE CLUSTERING SUMMARY")
    print("="*80)

    print(f"\n🔬 DATASET OVERVIEW:")
    print(f"   Samples attempted: {SAMPLE_SIZE:,}")
    print(f"   Successful simulations: {network_data['success_count']:,} ({network_data['success_count']/SAMPLE_SIZE*100:.1f}%)")
    print(f"   Network species: {len(network_data['species_names'])}")
    print(f"   Time points: {N_TIME_POINTS + 1}")
    print(f"   Feature dimensions: {len(network_data['species_names']) * (N_TIME_POINTS + 1):,}")

    if 'cluster_labels' in network_data:
        print(f"\n🎯 CLUSTERING RESULTS:")
        print(f"   Phenotypes discovered: {network_data['n_clusters']:,}")
        print(f"   BIRCH threshold: {BIRCH_THRESHOLD}")
        print(f"   Branching factor: {BIRCH_BRANCHING_FACTOR}")

        frequencies = list(network_data['cluster_counts'].values())
        print(f"   Largest phenotype: {max(frequencies):,} genotypes")
        print(f"   Smallest phenotype: {min(frequencies):,} genotypes")
        print(f"   Average phenotype size: {np.mean(frequencies):.1f} genotypes")

        singleton_count = sum(1 for f in frequencies if f == 1)
        print(f"   Singleton phenotypes: {singleton_count:,} ({singleton_count/network_data['n_clusters']*100:.1f}%)")

        if 'validation' in network_data:
            val = network_data['validation']
            print(f"\n✅ CLUSTER VALIDATION:")
            print(f"   Separation ratio: {val['separation_ratio']:.2f}")
            print(f"   Intra-cluster distance: {val['mean_intra_distance']:.3f}")
            print(f"   Inter-cluster distance: {val['mean_inter_distance']:.3f}")

    print(f"\n📊 GENOTYPE-PHENOTYPE MAPPING:")
    if 'cluster_labels' in network_data:
        total_genotypes = network_data['success_count']
        total_phenotypes = network_data['n_clusters']
        redundancy = total_genotypes / total_phenotypes

        print(f"   Many-to-one ratio: {redundancy:.1f} genotypes per phenotype (average)")
        print(f"   Phenotype diversity: {total_phenotypes/total_genotypes*100:.2f}% unique phenotypes")

        # Calculate entropy of the cluster-size distribution.
        # With a single phenotype the maximum entropy log2(1) is 0, so the
        # normalized value would be 0/0 (nan); report 0 for both instead.
        frequencies = np.array(list(network_data['cluster_counts'].values()))
        probabilities = frequencies / total_genotypes
        if total_phenotypes > 1:
            entropy = -np.sum(probabilities * np.log2(probabilities))
            max_entropy = np.log2(total_phenotypes)
            normalized_entropy = entropy / max_entropy
        else:
            entropy = 0.0
            max_entropy = 0.0
            normalized_entropy = 0.0

        print(f"   Phenotype entropy: {entropy:.2f} bits ({normalized_entropy:.2f} normalized)")

    print(f"\n📁 OUTPUT FILES:")
    print(f"   • birch_clustering_analysis.png - Cluster distribution plots")
    print(f"   • example_phenotype_trajectories.png - Example phenotype trajectories")

    print("\n" + "="*80)
    print("ANALYSIS COMPLETED SUCCESSFULLY! 🎉")
    print("="*80)

else:
    print("\n❌ Analysis failed - no successful simulations")