In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import json

# Import our custom functions
import sys
sys.path.append('..')
from binder_lab.common import (
    load_predictions_csv, 
    load_and_analyze_predictions,
    get_structure_info,
    get_confidence_summary,
    get_plddt_summary
)

# Set up plotting.
# Order matters: plt.style.use('default') resets matplotlib's rcParams
# (including the colour cycle), so the seaborn palette is applied after it;
# swapping these two lines would discard the palette.
plt.style.use('default')
sns.set_palette("husl")  # seaborn 'husl' colour palette for all subsequent plots
# Render figures inline in the notebook output (the default in modern Jupyter;
# kept explicit for older front-ends).
%matplotlib inline


In [None]:
# Path to your results (adjust as needed).
# NOTE: relative to the kernel's current working directory, not to this file.
results_dir = Path('../results')  # Adjust this path
csv_path = results_dir / 'all_predictions_data.csv'

# Check if the file exists
if csv_path.exists():
    print(f"Loading predictions from: {csv_path}")

    # Load with full analysis (structures, confidence, NPZ data)
    df = load_and_analyze_predictions(csv_path, base_dir=results_dir)

    print(f"\nLoaded {len(df)} predictions")
    print(f"Columns: {list(df.columns)}")
else:
    print(f"CSV file not found: {csv_path}")
    print("Please run the Snakemake workflow first to generate prediction data.")
    # Fall back to a tiny demo frame so the remaining cells still run.
    # All list-valued columns must have the same length, otherwise
    # pd.DataFrame raises "All arrays must be of the same length"; the note is
    # therefore given as a scalar, which pandas broadcasts to every row.
    df = pd.DataFrame({
        'predictor': ['boltz1', 'boltz2'],
        'design_name': ['demo_design', 'demo_design'],
        'model_idx': [0, 0],
        'note': 'This is demo data - run workflow to get real results',
    })


In [None]:
# Basic info about the dataset
if len(df) > 0 and 'predictor' in df.columns:
    print("=== Dataset Overview ===")
    print(f"Total predictions: {len(df)}")
    # dropna(): sorted() raises TypeError when a float NaN is mixed with strings.
    print(f"Predictors: {sorted(df['predictor'].dropna().unique())}")
    if 'design_name' in df.columns:
        print(f"Designs: {sorted(df['design_name'].dropna().unique())}")

    if 'model_idx' in df.columns:
        # Count distinct model indices per (predictor, design) rather than using
        # max(model_idx) + 1, which silently assumes 0-based, gap-free indices.
        group_keys = [col for col in ('predictor', 'design_name') if col in df.columns]
        models_per_design = df.groupby(group_keys)['model_idx'].nunique().max()
        print(f"Models per design: {models_per_design}")

    # Show first few rows
    print("\n=== First Few Rows ===")
    display_cols = ['predictor', 'design_name', 'model_idx']
    if 'plddt_mean' in df.columns:
        display_cols.extend(['plddt_mean', 'struct_num_residues'])

    # Keep only columns that actually exist (e.g. struct_* may be absent).
    available_cols = [col for col in display_cols if col in df.columns]
    display(df[available_cols].head())
else:
    print("No prediction data available - please run the workflow first.")


In [None]:
# Show the original design input recorded for the first prediction.
if len(df) > 0 and 'design_dict' in df.columns:
    first_design_dict = df.iloc[0]['design_dict']
    # The column presumably holds the design serialized as a JSON string (it is
    # decoded with json.loads); an already-parsed dict is accepted as well.
    # Anything else (e.g. NaN from an empty CSV cell) would make json.loads
    # raise TypeError, so it is reported as missing instead.
    if isinstance(first_design_dict, dict):
        original_design = first_design_dict
    elif isinstance(first_design_dict, str):
        original_design = json.loads(first_design_dict)
    else:
        original_design = None

    if original_design is None:
        print("No design data available.")
    else:
        print("=== Original Design Data ===")
        print(f"Design name: {original_design.get('name', 'Unknown')}")
        print("Sequences:")

        for i, seq in enumerate(original_design.get('sequences', [])):
            print(f"  Sequence {i+1}:")
            if 'protein' in seq:
                protein = seq['protein']
                sequence = protein.get('sequence', '')
                print("    Type: Protein")
                print(f"    ID: {protein.get('id')}")
                print(f"    Length: {len(sequence)} residues")
                # Only append an ellipsis when the 50-residue preview truncates.
                preview = sequence if len(sequence) <= 50 else f"{sequence[:50]}..."
                print(f"    Sequence: {preview}")
                if 'designed' in protein:
                    # NOTE(review): assumes 'designed' has one entry per residue with
                    # 'D' marking a designed position — TODO confirm against the schema.
                    designed = protein['designed']
                    num_designed = designed.count('D')
                    print(f"    Designed positions: {num_designed}/{len(designed)}")
            elif 'ligand' in seq:
                ligand = seq['ligand']
                print("    Type: Ligand")
                print(f"    ID: {ligand.get('id')}")
                print(f"    CCD: {ligand.get('ccd')}")
            print()
else:
    print("No design data available.")


In [None]:
# Inspect the structure objects attached to each prediction by the full-analysis loader.
if len(df) > 0 and 'structure' in df.columns:
    # Check how many structures were successfully parsed: a non-null entry in
    # the 'structure' column counts as a successful parse.
    structures_parsed = df['structure'].notna().sum()
    print(f"=== Structure Analysis ===")
    print(f"Structures successfully parsed: {structures_parsed}/{len(df)}")
    
    if structures_parsed > 0:
        # Look at structure info columns (every column prefixed 'struct_';
        # presumably produced by get_structure_info — TODO confirm).
        struct_cols = [col for col in df.columns if col.startswith('struct_')]
        if struct_cols:
            print(f"\nStructure info columns: {struct_cols}")
            print("\nStructure summary:")
            display(df[['predictor', 'design_name', 'model_idx'] + struct_cols].head())
        
        # Get a specific structure object: the first row with a non-null structure.
        # NOTE(review): presumably a gemmi.Structure (the fallback message below
        # mentions gemmi); the code relies on .name, len(), and indexing into
        # models -> chains -> residues.
        first_structure = df[df['structure'].notna()].iloc[0]['structure']
        # Defensive only: the notna() filter above already excludes None.
        if first_structure is not None:
            print(f"\n=== Example Structure Details ===")
            print(f"Structure name: {first_structure.name}")
            print(f"Number of models: {len(first_structure)}")
            
            if len(first_structure) > 0:
                # Only the first model is described.
                model = first_structure[0]
                print(f"Number of chains: {len(model)}")
                
                for chain in model:
                    print(f"  Chain {chain.name}: {len(chain)} residues")
                    if len(chain) > 0:
                        # First and last residue, printed as <residue name><sequence number>.
                        first_res = chain[0]
                        last_res = chain[-1]
                        print(f"    Residues: {first_res.name}{first_res.seqid.num} to {last_res.name}{last_res.seqid.num}")
    else:
        print("No structures were parsed. Check if gemmi is installed and CIF files exist.")
else:
    print("No structure data available.")


In [None]:
# pLDDT overview: summary statistics, a three-panel figure, and per-group tables.
if len(df) > 0 and 'plddt_mean' in df.columns:
    print("=== pLDDT Analysis ===")

    # Summary statistics (pandas reductions skip NaN rows). The mean is computed
    # once and reused for the printout and the histogram marker.
    plddt_overall_mean = df['plddt_mean'].mean()
    print(f"Mean pLDDT across all predictions: {plddt_overall_mean:.2f}")
    print(f"pLDDT range: {df['plddt_mean'].min():.2f} - {df['plddt_mean'].max():.2f}")

    # Look at pLDDT summary columns
    plddt_cols = [col for col in df.columns if col.startswith('plddt_')]
    if plddt_cols:
        print(f"\npLDDT summary columns: {plddt_cols}")
        print("\npLDDT statistics:")
        display(df[['predictor', 'design_name', 'model_idx'] + plddt_cols].head())

    n_predictors = df['predictor'].nunique()
    n_designs = df['design_name'].nunique()

    # Create visualization
    if len(df) > 1:
        print("\n=== Creating Visualizations ===")

        fig, axes = plt.subplots(1, 3, figsize=(18, 5))

        # 1. pLDDT distribution (NaN rows carry no score, so drop them before binning)
        axes[0].hist(df['plddt_mean'].dropna(), bins=20, alpha=0.7,
                     edgecolor='black', color='skyblue')
        axes[0].set_xlabel('Mean pLDDT')
        axes[0].set_ylabel('Count')
        axes[0].set_title('Distribution of Mean pLDDT Scores')
        axes[0].axvline(plddt_overall_mean, color='red', linestyle='--',
                        label=f"Mean: {plddt_overall_mean:.1f}")
        axes[0].legend()

        # 2. pLDDT by predictor (if multiple predictors)
        if n_predictors > 1:
            df.boxplot(column='plddt_mean', by='predictor', ax=axes[1])
            # DataFrame.boxplot(by=...) adds a figure-level
            # "Boxplot grouped by predictor" suptitle that overlaps the panel
            # titles; clear it and rely on the per-axes title instead.
            fig.suptitle('')
            axes[1].set_title('pLDDT by Predictor')
            axes[1].set_xlabel('Predictor')
            axes[1].set_ylabel('Mean pLDDT')
        else:
            axes[1].text(0.5, 0.5, 'Single predictor\n(no comparison)',
                         ha='center', va='center', transform=axes[1].transAxes)
            axes[1].set_title('pLDDT by Predictor')

        # 3. High vs low confidence (if available)
        # NOTE(review): axis labels assume plddt_high_conf / plddt_low_conf are
        # fractions above 90 / below 50 — TODO confirm against the loader.
        if 'plddt_high_conf' in df.columns and 'plddt_low_conf' in df.columns:
            scatter = axes[2].scatter(df['plddt_high_conf'], df['plddt_low_conf'],
                                      alpha=0.6, c=df['plddt_mean'], cmap='viridis')
            axes[2].set_xlabel('Fraction High Confidence (>90)')
            axes[2].set_ylabel('Fraction Low Confidence (<50)')
            axes[2].set_title('High vs Low Confidence Residues')
            fig.colorbar(scatter, ax=axes[2], label='Mean pLDDT')
        else:
            axes[2].text(0.5, 0.5, 'Detailed confidence data\nnot available',
                         ha='center', va='center', transform=axes[2].transAxes)
            axes[2].set_title('Confidence Analysis')

        fig.tight_layout()
        plt.show()

        # Comparative analysis whenever there is something to compare: the
        # per-design table is informative even with a single predictor.
        if n_predictors > 1 or n_designs > 1:
            print("\n=== Comparative Analysis ===")
            if n_predictors > 1:
                print("\nMean pLDDT by predictor:")
                predictor_stats = df.groupby('predictor')['plddt_mean'].agg(['mean', 'std', 'count'])
                display(predictor_stats)

            print("\nMean pLDDT by design:")
            design_stats = df.groupby('design_name')['plddt_mean'].agg(['mean', 'std', 'count'])
            display(design_stats)

else:
    print("No pLDDT data available for analysis.")
