# In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import re
from pathlib import Path
import glob

def parse_colabfold_results(output_text):
    """Parse ColabFold log text into one row per completed job.

    The text is scanned line by line and three kinds of lines are recognised:

    * ``Processing <job>...`` starts a new job and discards any metrics
      collected so far.
    * A line mentioning ``mean_plddt``, ``max_ptm`` and ``max_iptm`` supplies
      the three confidence metrics. The keys may appear in any order.
    * ``Binding score: <value>`` supplies the binding score and, if a job name
      and all four values are known, emits a result row.

    The nanobody name is the first two underscore-separated parts of the job
    name (e.g. ``nbGFP_6xzf``) and the antigen is the third part (e.g. ``GFP``).

    Lines that look like one of the above but carry no extractable value are
    skipped instead of aborting the whole parse.

    Args:
        output_text: Full text of a ColabFold output file.

    Returns:
        A pandas DataFrame with the columns ``job_name``, ``nanobody``,
        ``antigen``, ``mean_plddt``, ``max_ptm``, ``max_iptm`` and
        ``binding_score``. The DataFrame is empty (no columns) when no
        complete job was found.

    Raises:
        ValueError: If an extracted metric or score is not a float literal.
        IndexError: If a completed job's name has fewer than three
            underscore-separated parts.
    """
    metric_keys = ('mean_plddt', 'max_ptm', 'max_iptm')
    # One pattern per key so the keys can be in any order on the line; the
    # value runs up to the next comma, closing brace or whitespace.
    metric_patterns = {
        key: re.compile(r"""['"]?""" + key + r"""['"]?\s*:\s*([^,}\s]+)""")
        for key in metric_keys
    }
    score_pattern = re.compile(r'Binding score:\s*(\S+)')

    results = []
    current_job = None
    current_metrics = {}

    for raw_line in output_text.split('\n'):
        line = raw_line.strip()

        if 'Processing' in line:
            pieces = line.split('Processing ')
            if len(pieces) < 2:
                # "Processing" without a following job name: not a job header.
                continue
            current_job = pieces[1].split('...')[0]
            current_metrics = {}

        elif all(key in line for key in metric_keys):
            matches = {
                key: pattern.search(line)
                for key, pattern in metric_patterns.items()
            }
            if not all(matches.values()):
                # The keys are mentioned but not as "key: value" pairs.
                continue
            for key, match in matches.items():
                current_metrics[key] = float(match.group(1))

        elif 'Binding score:' in line:
            match = score_pattern.search(line)
            if match is None:
                continue
            current_metrics['binding_score'] = float(match.group(1))

            # The binding score is the last value reported for a job, so this
            # is the point where a complete record can be emitted.
            if current_job and len(current_metrics) == 4:
                parts = current_job.split('_')
                results.append({
                    'job_name': current_job,
                    'nanobody': parts[0] + '_' + parts[1],  # like nbGFP_6xzf
                    'antigen': parts[2],  # like GFP
                    'mean_plddt': current_metrics['mean_plddt'],
                    'max_ptm': current_metrics['max_ptm'],
                    'max_iptm': current_metrics['max_iptm'],
                    'binding_score': current_metrics['binding_score'],
                })

    return pd.DataFrame(results)

def load_multiple_seed_files(base_filename, seeds=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9)):
    """Load and combine parsed results from one output file per seed.

    For every seed the file name is derived from ``base_filename`` by
    inserting ``_seed_<seed>`` in front of the file extension, e.g.
    ``results.txt`` -> ``results_seed_0.txt``. A base name without an
    extension gets the suffix appended (``results`` -> ``results_seed_0``),
    so each seed always maps to its own file.

    Missing files are reported on stdout and skipped. Files are read as
    UTF-8; bytes that cannot be decoded are dropped.

    Args:
        base_filename: Path of the un-suffixed output file.
        seeds: Iterable of seed identifiers to look for.

    Returns:
        A pandas DataFrame with the rows returned by
        ``parse_colabfold_results`` for every seed file that yielded results,
        plus a ``seed`` column. An empty DataFrame when nothing was loaded.
    """
    import os  # local import: only used here to split off the extension

    root, extension = os.path.splitext(base_filename)
    all_results = []

    for seed in seeds:
        seed_filename = f"{root}_seed_{seed}{extension}"

        try:
            # errors='ignore' gives the same text as a strict read for valid
            # UTF-8 and drops undecodable bytes otherwise, so one read
            # replaces the strict-then-lenient retry.
            with open(seed_filename, 'r', encoding='utf-8', errors='ignore') as file:
                file_content = file.read()
        except FileNotFoundError:
            print(f"File not found: {seed_filename}")
            continue

        seed_df = parse_colabfold_results(file_content)

        if len(seed_df) > 0:
            # Tag every row with the seed it came from.
            seed_df['seed'] = seed
            all_results.append(seed_df)
            print(f"Loaded {len(seed_df)} results from seed {seed}")
        else:
            print(f"No results found in {seed_filename}")

    if not all_results:
        print("No seed files found!")
        return pd.DataFrame()

    combined_df = pd.concat(all_results, ignore_index=True)
    print(f"Total combined results: {len(combined_df)} experiments across {len(all_results)} seeds")
    return combined_df

def calculate_cross_seed_statistics(df):
    """Summarise each nanobody-antigen pair across seeds.

    For every (nanobody, antigen) group the mean, standard deviation and
    count of ``mean_plddt``, ``max_ptm``, ``max_iptm`` and ``binding_score``
    are computed and rounded to three decimals. The result has one row per
    pair with columns named ``<metric>_<stat>``, plus three boolean flags
    derived from the rounded means:

    * ``high_confidence``: ipTM mean > 0.7 and pTM mean > 0.7
    * ``good_interface``: ipTM mean > 0.5
    * ``good_structure``: pLDDT mean > 80

    Standard deviations that are NaN (a pair seen in a single seed) are
    reported as 0.
    """
    metric_names = ('mean_plddt', 'max_ptm', 'max_iptm', 'binding_score')
    aggregations = {metric: ['mean', 'std', 'count'] for metric in metric_names}

    per_pair = df.groupby(['nanobody', 'antigen']).agg(aggregations).round(3)

    # Collapse the (metric, stat) column MultiIndex into "metric_stat" names.
    per_pair.columns = [f"{metric}_{stat}".strip() for metric, stat in per_pair.columns]
    per_pair = per_pair.reset_index()

    # Quality flags, all based on the per-pair means.
    iptm_mean = per_pair['max_iptm_mean']
    ptm_mean = per_pair['max_ptm_mean']
    per_pair['high_confidence'] = (iptm_mean > 0.7) & (ptm_mean > 0.7)
    per_pair['good_interface'] = iptm_mean > 0.5
    per_pair['good_structure'] = per_pair['mean_plddt_mean'] > 80

    # A single observation has no sample std; report 0 instead of NaN.
    std_names = [name for name in per_pair.columns if '_std' in name]
    per_pair[std_names] = per_pair[std_names].fillna(0)

    return per_pair

def create_binding_matrix_visualization_multiseed(df_stats, score_column='binding_score_mean', 
                                                 std_column='binding_score_std', 
                                                 title_suffix='AlphaFold2 Multimer Score Multi-Seed'):
    """Draw a nanobody x antigen heatmap annotated with mean +/- std.

    Cells are coloured by ``score_column`` and labelled with the rounded mean
    and standard deviation. Cells backed by fewer than three seeds also show
    the seed count, which is read from the column obtained by replacing
    ``_mean`` with ``_count`` in ``score_column``. Rows and columns are sorted
    alphabetically, ignoring case.

    The figure is written to
    ``binding_matrix_alphafold2_multimer_multiseed.png`` in the current
    working directory and then shown with ``plt.show()``.

    Args:
        df_stats: Per-pair statistics with ``nanobody`` and ``antigen``
            columns plus the score, std and count columns named above.
        score_column: Column holding the mean score to plot.
        std_column: Column holding the matching standard deviation.
        title_suffix: Text used in the title and as the colorbar label.

    Returns:
        A ``(mean_matrix, std_matrix)`` tuple of DataFrames indexed by
        nanobody with one column per antigen. ``(None, None)`` when
        ``df_stats`` is empty, so callers can always unpack two values.
    """
    if len(df_stats) == 0:
        print("No data to plot")
        return None, None

    # Sort the names alphabetically (case insensitive)
    nanobodies = sorted(df_stats['nanobody'].unique(), key=str.lower)
    antigens = sorted(df_stats['antigen'].unique(), key=str.lower)

    # pivot() rejects duplicate (nanobody, antigen) keys, so keep the first.
    df_clean = df_stats.drop_duplicates(subset=['nanobody', 'antigen'], keep='first')

    def pivot_to_matrix(value_column):
        # Pivot one statistic into a nanobody x antigen grid in sorted order.
        matrix = df_clean.pivot(index='nanobody', columns='antigen', values=value_column)
        return matrix.reindex(index=nanobodies, columns=antigens)

    mean_matrix = pivot_to_matrix(score_column)
    std_matrix = pivot_to_matrix(std_column)
    # Count matrix shows how many seeds contributed to each cell.
    count_matrix = pivot_to_matrix(score_column.replace('_mean', '_count'))

    # Make the plot bigger
    plt.figure(figsize=(16, 14))

    # Create annotation labels with mean +/- std
    annot_labels = mean_matrix.round(3).astype(str) + '\n+/-' + std_matrix.round(3).astype(str)

    # Add the seed count to cells backed by fewer than 3 seeds. Missing cells
    # are NaN, compare False against 3, and are therefore left untouched.
    for i, j in np.argwhere((count_matrix < 3).to_numpy()):
        seed_count = int(count_matrix.iloc[i, j])
        annot_labels.iloc[i, j] = f"{annot_labels.iloc[i, j]}\n(n={seed_count})"

    ax = sns.heatmap(
        mean_matrix,
        annot=annot_labels,
        fmt='',  # labels are already formatted strings
        cmap='viridis',
        square=True,
        linewidths=0.5,
        cbar_kws={'label': f'{title_suffix}'},
        annot_kws={'size': 14}  # smaller font to fit multiple lines
    )

    # Make colorbar label bigger
    cbar = ax.collections[0].colorbar
    cbar.set_label(f'{title_suffix}', size=20)

    # Highlight the diagonal - these should be the real binding pairs.
    # NOTE(review): this assumes the sorted nanobody and antigen orders line
    # up pair by pair - confirm against the naming scheme.
    for i in range(min(len(nanobodies), len(antigens))):
        ax.add_patch(plt.Rectangle((i, i), 1, 1, fill=False, edgecolor='red', lw=5))

    plt.title(f'Nanobody-Antigen Binding Matrix ({title_suffix})', fontsize=22, pad=20)
    plt.xlabel('Antigens', fontsize=22)
    plt.ylabel('Nanobodies', fontsize=22)
    plt.xticks(rotation=45, ha='right', fontsize=18)
    plt.yticks(rotation=0, fontsize=18)
    plt.tight_layout()

    # Save with descriptive filename
    plt.savefig("binding_matrix_alphafold2_multimer_multiseed.png", dpi=300, bbox_inches='tight')
    plt.show()

    return mean_matrix, std_matrix

def _format_pair_summary(row):
    """Return the one-line report for a nanobody-antigen statistics row."""
    confidence_flag = "high_conf" if row['high_confidence'] else "low_conf "
    return (f"  {confidence_flag} {row['nanobody']} + {row['antigen']}: "
            f"{row['binding_score_mean']:.3f}+/-{row['binding_score_std']:.3f} "
            f"(ipTM: {row['max_iptm_mean']:.3f}+/-{row['max_iptm_std']:.3f}, "
            f"pTM: {row['max_ptm_mean']:.3f}+/-{row['max_ptm_std']:.3f}, n={int(row['binding_score_count'])})")


def _is_expected_pair(nanobody, antigen):
    """Tell whether a nanobody name encodes the given antigen.

    The target is the text before the first underscore with a leading ``nb``
    removed; it is compared to the antigen case-insensitively. Only the
    prefix is stripped, so an ``nb`` inside the target name is kept.
    """
    target = nanobody[2:] if nanobody.startswith('nb') else nanobody
    return target.split('_')[0].upper() == antigen.upper()


def _print_metric_summary(label, values, precision):
    """Print mean, std and range of a raw metric column."""
    print(f"  {label} - Mean: {values.mean():.{precision}f} "
          f"Std: {values.std():.{precision}f} "
          f"Range: {values.min():.{precision}f} to {values.max():.{precision}f}")


def analyze_multiseed_results(df_raw, df_stats):
    """Print a text report of the multi-seed results.

    The report covers the seeds and pair counts, how many pairs have results
    for every seed, the quality flags, the five pairs with the highest mean
    binding score, the pairs whose nanobody name matches the antigen, and
    summary statistics over all raw rows.

    Args:
        df_raw: One row per experiment, with a ``seed`` column and the raw
            ``binding_score``, ``max_iptm``, ``max_ptm`` and ``mean_plddt``.
        df_stats: One row per nanobody-antigen pair, with the
            ``<metric>_mean`` / ``_std`` / ``_count`` columns and the
            ``high_confidence``, ``good_interface`` and ``good_structure``
            flags.

    Returns:
        None. All output goes to stdout.
    """
    print("=== AlphaFold2 Multimer Multi-Seed Results Analysis ===")
    print()

    if len(df_raw) == 0:
        print("ERROR: No data found - check file format")
        return

    # Basic statistics
    n_seeds = df_raw['seed'].nunique()
    n_combinations = len(df_stats)

    print(f"Seeds analyzed: {sorted(df_raw['seed'].unique())}")
    print(f"Total experiments: {len(df_raw)}")
    print(f"Unique nanobody-antigen combinations: {n_combinations}")
    print(f"Number of nanobodies: {df_stats['nanobody'].nunique()}")
    print(f"Number of antigens: {df_stats['antigen'].nunique()}")
    print()

    # Show completeness - how many combinations have all seeds
    complete_combinations = (df_stats['binding_score_count'] == n_seeds).sum()
    print(f"Combinations with all {n_seeds} seeds: {complete_combinations}/{n_combinations}")

    # Show seed distribution
    seed_counts = df_stats['binding_score_count'].value_counts().sort_index()
    print("Seed count distribution:")
    for count, freq in seed_counts.items():
        print(f"  {int(count)} seeds: {freq} combinations")
    print()

    # Quality check on mean values
    high_conf_count = df_stats['high_confidence'].sum()
    good_interface_count = df_stats['good_interface'].sum()
    good_structure_count = df_stats['good_structure'].sum()

    print("Quality check (based on mean across seeds):")
    print(f"  High confidence (ipTM>0.7, pTM>0.7): {high_conf_count}/{n_combinations}")
    print(f"  Good interface (ipTM>0.5): {good_interface_count}/{n_combinations}")
    print(f"  Good structure (pLDDT>80): {good_structure_count}/{n_combinations}")
    print()

    # Top binding pairs by mean score
    print("Top 5 binding pairs (highest mean binding score):")
    for _, row in df_stats.nlargest(5, 'binding_score_mean').iterrows():
        print(_format_pair_summary(row))
    print()

    # Pairs whose nanobody name encodes the antigen (the "diagonal")
    diagonal_pairs = [
        row for _, row in df_stats.iterrows()
        if _is_expected_pair(row['nanobody'], row['antigen'])
    ]

    if diagonal_pairs:
        print("Expected binding pairs (diagonal matches):")
        for pair in diagonal_pairs:
            print(_format_pair_summary(pair))
        print()

    # Summary statistics across all experiments
    print("Summary statistics (across all seeds):")
    _print_metric_summary("Binding Score", df_raw['binding_score'], 3)
    _print_metric_summary("ipTM", df_raw['max_iptm'], 3)
    _print_metric_summary("pTM", df_raw['max_ptm'], 3)
    _print_metric_summary("pLDDT", df_raw['mean_plddt'], 1)

# Run the multi-seed analysis: load every seed file, aggregate per
# nanobody-antigen pair, plot the binding matrix, print the report and write
# the CSV outputs to the current working directory.
if __name__ == "__main__":
    # Base filename - will look for _seed_0.txt, _seed_1.txt, _seed_2.txt versions
    base_filename = "alphafold2_multimer_v3_combined_information.txt"
    seeds_to_analyze = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

    print("=== Loading Multi-Seed AlphaFold2 Multimer Results ===")
    print(f"Looking for files: {base_filename.replace('.txt', '_seed_X.txt')}")
    print()

    # Load all seed files
    df_raw = load_multiple_seed_files(base_filename, seeds_to_analyze)

    if len(df_raw) > 0:
        # Calculate statistics across seeds
        df_stats = calculate_cross_seed_statistics(df_raw)

        print(f"Calculated statistics for {len(df_stats)} unique combinations")
        print()

        # Create visualizations
        print("Creating visualization...")
        matrices = create_binding_matrix_visualization_multiseed(
            df_stats, 'binding_score_mean', 'binding_score_std'
        )
        # Accept either a bare None or a (mean, std) pair, so a "nothing
        # plotted" result cannot break the unpacking.
        mean_matrix, std_matrix = matrices if matrices is not None else (None, None)

        # Analyze results
        analyze_multiseed_results(df_raw, df_stats)

        # Save results, remembering what was written so the summary below
        # only lists files that exist.
        saved_files = []
        df_raw.to_csv('alphafold2_results_all_seeds.csv', index=False)
        saved_files.append("alphafold2_results_all_seeds.csv - Raw data from all seeds")
        df_stats.to_csv('alphafold2_results_statistics.csv', index=False)
        saved_files.append("alphafold2_results_statistics.csv - Mean+/-Std statistics")
        if mean_matrix is not None:
            mean_matrix.to_csv('alphafold2_binding_matrix_mean.csv')
            saved_files.append("alphafold2_binding_matrix_mean.csv - Mean binding matrix")
            std_matrix.to_csv('alphafold2_binding_matrix_std.csv')
            saved_files.append("alphafold2_binding_matrix_std.csv - Std deviation matrix")
            # The PNG is only written when the matrix was actually plotted.
            saved_files.append("binding_matrix_alphafold2_multimer_multiseed.png - Visualization")

        print()
        print("=== Files Saved ===")
        for entry in saved_files:
            print(entry)

    else:
        print("No data found - check that seed files exist!")
        print(f"Expected files like: {base_filename.replace('.txt', '_seed_0.txt')}")

    print("Done!")