# BiTE Rigid Linker Analysis Pipeline
## AI Co-Scientist Challenge Korea 2026

### Pipeline Overview
| Cell | Function | Output |
|:----:|----------|--------|
| 1 | Load RMSD functions | Functions ready |
| 2 | Analyze structures & extract sequences | `sequence_df` |
| 3 | Filter & create **combined** FASTA | `all_sequences.fasta` |
| 4 | **AlphaFold (Fast Batch)** | PDB + JSON files |
| 5 | **RMSD + Antigen Distance (130Å target)** | `comparison_df` |
| 6 | Normal Mode Analysis (NMA) | `nma_results_df` |
| 7 | pLDDT & PAE Analysis | `json_results_df` |
| 8 | **Final Summary & Download** | Excel + ZIP |

### ⚡ Key Features
- **Fast MSA:** All sequences in one FASTA file (single MSA search)
- **Antigen Distance:** Measures CD3e-HER2 distance after scFv alignment
- **Target:** 130Å immunological synapse distance

### Required Files
Upload to `/content/`:
- `reference.pdb` - Reference BiTE structure
- `predictions/` - Folder with RFdiffusion PDB files
- `Include_antigen_with_perfect_distance.pdb` - Antigen-antibody complex

### Quick Start
1. `Runtime → Change runtime type → GPU`
2. Upload required files
3. `Runtime → Run all`
4. **First time:** Cell 4 installs ColabFold and then stops; restart the runtime when prompted, then `Run all` again

---

In [None]:
# ================================================================================
# CELL 1: RMSD ANALYSIS FUNCTIONS
# ================================================================================

#@title **Cell 1: Load RMSD Analysis Functions** { display-mode: "form" }
#@markdown This cell loads all core functions for RMSD calculation.
#@markdown
#@markdown **Run this cell first** - Functions will be reused in Cells 2 and 5.

# --- Package Installation ---
!pip install biopython --quiet

# --- Imports ---
import numpy as np
import pandas as pd
from pathlib import Path
from typing import List, Tuple, Dict, Optional
import warnings
warnings.filterwarnings('ignore')

from Bio.PDB import PDBParser, Superimposer, PPBuilder
from Bio.PDB.Atom import Atom
from Bio.PDB.Residue import Residue
from Bio.PDB.Polypeptide import is_aa

# --- Constants ---
SCFV1_LENGTH = 240  # First scFv domain length
SCFV2_LENGTH = 242  # Second scFv domain length

print("=" * 70)
print("CELL 1: Loading RMSD Analysis Functions")
print("=" * 70)

# =============================================================================
# CORE FUNCTIONS
# =============================================================================

def get_ca_atoms(residues: List[Residue]) -> List[Atom]:
    """Return the CA (alpha carbon) atom of every residue that has one.

    Residues without a 'CA' entry are skipped, so the result can be shorter
    than ``residues``. Input order is preserved.
    """
    return [residue['CA'] for residue in residues if 'CA' in residue]


def extract_residues(structure) -> List[Residue]:
    """Collect the standard amino-acid residues of a structure into one flat list.

    The structure is walked model -> chain -> residue, and a residue is kept
    when ``is_aa(residue, standard=True)`` is true. Traversal order is kept.
    """
    return [
        res
        for model in structure
        for chain in model
        for res in chain
        if is_aa(res, standard=True)
    ]


def align_and_calculate_rmsd(
    ref_atoms: List[Atom],
    target_atoms: List[Atom],
    full_target_atoms: List[Atom] = None
) -> Tuple[float, List[Atom]]:
    """
    Superimpose target atoms onto reference atoms and report the RMSD.

    The fit is computed by Superimposer from the paired ``ref_atoms`` /
    ``target_atoms`` and then applied to ``full_target_atoms`` when that is
    given, otherwise to ``target_atoms``.

    Returns:
        (rmsd, target_atoms): the Superimposer's ``rms`` value and the same
        ``target_atoms`` list that was passed in.

    Raises:
        ValueError: if the two atom lists differ in length, or are both empty.
    """
    n_ref, n_target = len(ref_atoms), len(target_atoms)
    if n_ref != n_target:
        raise ValueError(f"Atom count mismatch: ref={n_ref}, target={n_target}")
    if n_ref == 0:
        raise ValueError("No atoms to align")

    superimposer = Superimposer()
    superimposer.set_atoms(ref_atoms, target_atoms)

    # Move either the caller-supplied full atom set or just the fitted atoms.
    atoms_to_transform = target_atoms if full_target_atoms is None else full_target_atoms
    superimposer.apply(atoms_to_transform)

    return superimposer.rms, target_atoms


def calculate_rmsd_only(atoms1: List[Atom], atoms2: List[Atom]) -> float:
    """Calculate RMSD between two sets of atoms without alignment.

    Atoms are paired by position; each atom's coordinates are read with
    ``get_coord()`` and compared as-is, with no superposition.

    Args:
        atoms1: First list of atoms.
        atoms2: Second list of atoms, same length as ``atoms1``.

    Returns:
        Square root of the mean squared distance between paired atoms.

    Raises:
        ValueError: if the lists differ in length, or if both are empty.
    """
    if len(atoms1) != len(atoms2):
        raise ValueError(f"Atom count mismatch: {len(atoms1)} vs {len(atoms2)}")

    if len(atoms1) == 0:
        # Without this guard the empty coordinate arrays are 1-D and the
        # axis=1 reduction below fails with an opaque numpy axis error.
        raise ValueError("No atoms to compare")

    coords1 = np.array([atom.get_coord() for atom in atoms1])
    coords2 = np.array([atom.get_coord() for atom in atoms2])

    squared_distances = np.sum((coords1 - coords2) ** 2, axis=1)
    return np.sqrt(np.mean(squared_distances))


def load_reference_structure(ref_path: str) -> Tuple[List[Atom], List[Atom]]:
    """Parse the reference PDB and return the CA atoms of its two scFv domains.

    scFv1 is taken as the first SCFV1_LENGTH residues and scFv2 as the last
    SCFV2_LENGTH residues of the list returned by extract_residues(). A short
    summary is printed along the way.

    Returns:
        (scFv1 CA atoms, scFv2 CA atoms).

    Raises:
        ValueError: if the structure has fewer than
            SCFV1_LENGTH + SCFV2_LENGTH residues.
    """
    structure = PDBParser(QUIET=True).get_structure('reference', ref_path)
    residues = extract_residues(structure)

    n_residues = len(residues)
    print(f"Reference structure: {n_residues} residues")

    minimum_required = SCFV1_LENGTH + SCFV2_LENGTH
    if n_residues < minimum_required:
        raise ValueError(f"Reference structure too short: {n_residues} residues")

    scfv1_ca = get_ca_atoms(residues[:SCFV1_LENGTH])
    scfv2_ca = get_ca_atoms(residues[-SCFV2_LENGTH:])

    for label, atoms in (("scFv1", scfv1_ca), ("scFv2", scfv2_ca)):
        print(f"  • {label}: {len(atoms)} CA atoms")

    return scfv1_ca, scfv2_ca


def analyze_single_target(
    ref_scfv1_ca: List[Atom],
    ref_scfv2_ca: List[Atom],
    target_path: str
) -> Dict:
    """Analyze a single target structure against reference domains.

    The target is superimposed on the reference twice: once fitted on scFv1
    and once fitted on scFv2. After each fit the RMSD of the *other* domain is
    measured without further fitting. The fit with the lower average of the
    two domain RMSDs is the one reported.

    Args:
        ref_scfv1_ca: Reference CA atoms for scFv1.
        ref_scfv2_ca: Reference CA atoms for scFv2.
        target_path: Path of the target PDB file.

    Returns:
        A dict with keys Filename, Status ('Success' or 'Failed'),
        Total_Residues, Linker_Length, scFv1_RMSD, scFv2_RMSD, Avg_RMSD,
        Final_RMSD, Best_Alignment ('scFv1' or 'scFv2') and Error.
        Final_RMSD is the RMSD of the domain that was NOT used for the fit.
        Exceptions are not propagated; their message is stored under Error
        and Status stays 'Failed'.
    """
    # Pessimistic defaults: every early return below reports a failure.
    result = {
        'Filename': Path(target_path).name,
        'Status': 'Failed',
        'Total_Residues': 0,
        'Linker_Length': 0,
        'scFv1_RMSD': None,
        'scFv2_RMSD': None,
        'Avg_RMSD': None,
        'Final_RMSD': None,
        'Best_Alignment': None,
        'Error': None
    }

    try:
        parser = PDBParser(QUIET=True)
        target_structure = parser.get_structure('target', target_path)
        target_residues = extract_residues(target_structure)

        total_residues = len(target_residues)
        result['Total_Residues'] = total_residues

        # The two domains are sliced from the ends, so both must fit.
        if total_residues < SCFV1_LENGTH + SCFV2_LENGTH:
            result['Error'] = f"Too few residues: {total_residues}"
            return result

        # Whatever lies between the two fixed-length domains is the linker.
        linker_length = total_residues - SCFV1_LENGTH - SCFV2_LENGTH
        result['Linker_Length'] = linker_length

        target_scfv1 = target_residues[:SCFV1_LENGTH]
        target_scfv2 = target_residues[-SCFV2_LENGTH:]

        target_scfv1_ca = get_ca_atoms(target_scfv1)
        target_scfv2_ca = get_ca_atoms(target_scfv2)
        target_all_ca = get_ca_atoms(target_residues)

        # get_ca_atoms skips residues without a CA, so the counts can differ
        # from the reference even when the residue counts agree.
        if len(target_scfv1_ca) != len(ref_scfv1_ca):
            result['Error'] = f"scFv1 atom mismatch: {len(target_scfv1_ca)} vs {len(ref_scfv1_ca)}"
            return result
        if len(target_scfv2_ca) != len(ref_scfv2_ca):
            result['Error'] = f"scFv2 atom mismatch: {len(target_scfv2_ca)} vs {len(ref_scfv2_ca)}"
            return result

        # Align on scFv1: the fit is computed from the scFv1 CA atoms and the
        # resulting transform is applied to every CA atom of the target.
        rmsd_scfv1, _ = align_and_calculate_rmsd(
            ref_scfv1_ca, target_scfv1_ca, target_all_ca
        )

        # Re-read the scFv2 CA atoms from the same (now transformed) residues
        # and measure their deviation with no additional fitting.
        aligned_scfv2_ca = get_ca_atoms(target_residues[-SCFV2_LENGTH:])
        rmsd_scfv2_after_scfv1_align = calculate_rmsd_only(ref_scfv2_ca, aligned_scfv2_ca)

        # Superimposer.apply transforms atoms in place, so the file is parsed
        # a second time to start the scFv2 fit from the original coordinates.
        parser2 = PDBParser(QUIET=True)
        target_structure2 = parser2.get_structure('target2', target_path)
        target_residues2 = extract_residues(target_structure2)
        target_scfv2_ca_2 = get_ca_atoms(target_residues2[-SCFV2_LENGTH:])
        target_all_ca_2 = get_ca_atoms(target_residues2)

        # Align on scFv2 (mirror image of the scFv1 fit above).
        rmsd_scfv2, _ = align_and_calculate_rmsd(
            ref_scfv2_ca, target_scfv2_ca_2, target_all_ca_2
        )

        aligned_scfv1_ca = get_ca_atoms(target_residues2[:SCFV1_LENGTH])
        rmsd_scfv1_after_scfv2_align = calculate_rmsd_only(ref_scfv1_ca, aligned_scfv1_ca)

        # Score each fit by the mean of (fitted-domain RMSD, other-domain RMSD).
        avg_rmsd_align_scfv1 = (rmsd_scfv1 + rmsd_scfv2_after_scfv1_align) / 2
        avg_rmsd_align_scfv2 = (rmsd_scfv2 + rmsd_scfv1_after_scfv2_align) / 2

        # Ties go to the scFv1 fit. Final_RMSD is the non-fitted domain's RMSD.
        if avg_rmsd_align_scfv1 <= avg_rmsd_align_scfv2:
            result['Final_RMSD'] = rmsd_scfv2_after_scfv1_align
            result['Best_Alignment'] = 'scFv1'
            result['scFv1_RMSD'] = rmsd_scfv1
            result['scFv2_RMSD'] = rmsd_scfv2_after_scfv1_align
        else:
            result['Final_RMSD'] = rmsd_scfv1_after_scfv2_align
            result['Best_Alignment'] = 'scFv2'
            result['scFv1_RMSD'] = rmsd_scfv1_after_scfv2_align
            result['scFv2_RMSD'] = rmsd_scfv2

        result['Avg_RMSD'] = min(avg_rmsd_align_scfv1, avg_rmsd_align_scfv2)
        result['Status'] = 'Success'

    except Exception as e:
        # Any failure is recorded on the row; Status remains 'Failed'.
        result['Error'] = str(e)

    return result


def analyze_batch(
    ref_scfv1_ca: List[Atom],
    ref_scfv2_ca: List[Atom],
    target_folder: str,
    pattern: str = "*.pdb"
) -> pd.DataFrame:
    """Run analyze_single_target on every file in a folder matching ``pattern``.

    Files are processed in sorted path order and a progress line is printed
    for each. The per-file result dicts become the rows of the returned
    DataFrame.
    """
    pdb_paths = sorted(Path(target_folder).glob(pattern))
    n_files = len(pdb_paths)

    print(f"\nFound {n_files} PDB files to analyze")
    print("-" * 50)

    rows = []
    for position, pdb_path in enumerate(pdb_paths, start=1):
        print(f"  [{position}/{n_files}] {pdb_path.name}", end="")
        outcome = analyze_single_target(ref_scfv1_ca, ref_scfv2_ca, str(pdb_path))

        if outcome['Status'] != 'Success':
            print(f" → FAILED: {outcome['Error']}")
        else:
            print(f" → RMSD: {outcome['Final_RMSD']:.4f} Å ({outcome['Best_Alignment']})")

        rows.append(outcome)

    return pd.DataFrame(rows)

# =============================================================================
# SUMMARY
# =============================================================================

print("\n" + "-" * 70)
print("✓ Functions loaded successfully!")
print("-" * 70)
print("\nAvailable functions:")
print("  • load_reference_structure(ref_path)")
print("  • analyze_single_target(ref_scfv1_ca, ref_scfv2_ca, target_path)")
print("  • analyze_batch(ref_scfv1_ca, ref_scfv2_ca, target_folder, pattern)")
print("\nDomain structure:")
print(f"  • scFv1: {SCFV1_LENGTH} residues (N-terminal)")
print(f"  • scFv2: {SCFV2_LENGTH} residues (C-terminal)")
print("=" * 70)


In [None]:
# ================================================================================
# CELL 2: INITIAL RMSD ANALYSIS & SEQUENCE EXTRACTION
# ================================================================================

#@title **Cell 2: Analyze Structures & Extract Sequences** { display-mode: "form" }
#@markdown ### Configuration
#@markdown Modify paths if needed:
REFERENCE_PATH = "/content/reference.pdb"  #@param {type:"string"}
TARGET_FOLDER = "/content/predictions/"  #@param {type:"string"}
FILE_PATTERN = "*.pdb"  #@param {type:"string"}

#@markdown ---
#@markdown **Requires:** Cell 1 must be run first.

# --- Display Settings ---
pd.set_option('display.max_colwidth', None)
pd.set_option('display.max_rows', None)
pd.set_option('display.width', None)

print("=" * 70)
print("CELL 2: RMSD Analysis & Sequence Extraction")
print("=" * 70)

# =============================================================================
# SEQUENCE EXTRACTION FUNCTION
# =============================================================================

def extract_sequence_from_pdb(pdb_path: str) -> str:
    """Return the amino-acid sequence of a PDB file as a single string.

    Every polypeptide that PPBuilder builds from every chain of every model
    is converted to its sequence; the pieces are concatenated in traversal
    order with no separator.
    """
    structure = PDBParser(QUIET=True).get_structure('protein', pdb_path)
    builder = PPBuilder()

    return ''.join(
        str(peptide.get_sequence())
        for model in structure
        for chain in model
        for peptide in builder.build_peptides(chain)
    )


# =============================================================================
# VALIDATION
# =============================================================================

print(f"\nConfiguration:")
print(f"  • Reference: {REFERENCE_PATH}")
print(f"  • Target folder: {TARGET_FOLDER}")
print(f"  • Pattern: {FILE_PATTERN}")

print("\n--- Validation ---")

# Check paths
if not Path(REFERENCE_PATH).exists():
    raise FileNotFoundError(f"Reference file not found: {REFERENCE_PATH}")
print(f"✓ Reference file exists")

if not Path(TARGET_FOLDER).exists():
    raise FileNotFoundError(f"Target folder not found: {TARGET_FOLDER}")

pdb_count = len(list(Path(TARGET_FOLDER).glob(FILE_PATTERN)))
if pdb_count == 0:
    raise FileNotFoundError(f"No PDB files found in {TARGET_FOLDER}")
print(f"✓ Found {pdb_count} PDB files")

# =============================================================================
# LOAD REFERENCE & ANALYZE
# =============================================================================

print("\n" + "=" * 70)
print("LOADING REFERENCE STRUCTURE")
print("=" * 70)

ref_scfv1, ref_scfv2 = load_reference_structure(REFERENCE_PATH)

print("\n" + "=" * 70)
print("ANALYZING TARGET STRUCTURES")
print("=" * 70)

results_df = analyze_batch(ref_scfv1, ref_scfv2, TARGET_FOLDER, FILE_PATTERN)

# =============================================================================
# EXTRACT SEQUENCES
# =============================================================================

print("\n" + "=" * 70)
print("EXTRACTING SEQUENCES")
print("=" * 70)

# Read the sequence of every successfully analysed structure. Failed rows get
# None so the list stays aligned with results_df row-for-row.
sequences = []
for status, filename in zip(results_df['Status'], results_df['Filename']):
    if status != 'Success':
        sequences.append(None)
        print(f"  ✗ {filename}: Failed")
        continue

    pdb_path = Path(TARGET_FOLDER) / filename
    seq = extract_sequence_from_pdb(str(pdb_path))
    sequences.append(seq)
    print(f"  ✓ {filename}: {len(seq)} aa")

results_df['Sequence'] = sequences
# Missing (None) and empty sequences both count as length 0.
results_df['Sequence_Length'] = results_df['Sequence'].apply(
    lambda value: 0 if not value else len(value)
)

# =============================================================================
# DETERMINE LINKER GROUP
# =============================================================================

def get_linker_group(linker_len):
    """Map a linker length to its size-category label.

    Upper bounds are inclusive and checked in ascending order: ≤35 is short,
    ≤45 medium, ≤55 long; anything that passes none of them is very long.
    """
    labelled_bounds = (
        (35, "Short (≤35)"),
        (45, "Medium (36-45)"),
        (55, "Long (46-55)"),
    )
    for upper_bound, label in labelled_bounds:
        if linker_len <= upper_bound:
            return label
    return "Very Long (>55)"

results_df['Linker_Group'] = results_df['Linker_Length'].apply(get_linker_group)

# =============================================================================
# DISPLAY RESULTS
# =============================================================================

print("\n" + "=" * 70)
print("ANALYSIS RESULTS")
print("=" * 70)

# Summary statistics
successful = results_df[results_df['Status'] == 'Success']
print(f"\nTotal files: {len(results_df)}")
print(f"Successful: {len(successful)}")
print(f"Failed: {len(results_df) - len(successful)}")

if len(successful) > 0:
    print(f"\nRMSD Statistics:")
    print(f"  • Mean: {successful['Final_RMSD'].mean():.4f} Å")
    print(f"  • Std:  {successful['Final_RMSD'].std():.4f} Å")
    print(f"  • Min:  {successful['Final_RMSD'].min():.4f} Å")
    print(f"  • Max:  {successful['Final_RMSD'].max():.4f} Å")

# Display table
print("\n--- Results Table ---")
display_cols = ['Filename', 'Total_Residues', 'Linker_Length', 'Linker_Group',
                'Final_RMSD', 'Best_Alignment', 'Sequence_Length']
display_df = results_df[display_cols].copy()
display_df['Final_RMSD'] = display_df['Final_RMSD'].apply(
    lambda x: f"{x:.4f}" if pd.notna(x) else "N/A"
)
print(display_df.to_string(index=False))

# =============================================================================
# SAVE FOR NEXT CELL
# =============================================================================

sequence_df = results_df.copy()

print("\n" + "=" * 70)
print("✓ Cell 2 Complete")
print("=" * 70)
print(f"\nData saved to: sequence_df ({len(sequence_df)} rows)")
print("→ Run Cell 3 to filter and create FASTA files")
print("=" * 70)


In [None]:
# ================================================================================
# CELL 3: FILTER BY RMSD & PREPARE COMBINED FASTA FILE
# ================================================================================

#@title **Cell 3: Filter & Create Combined FASTA** { display-mode: "form" }
#@markdown ### Configuration
RMSD_THRESHOLD = 30  #@param {type:"number"}
FASTA_OUTPUT_DIR = "/content/fasta_files/"  #@param {type:"string"}

#@markdown ---
#@markdown **Note:** All sequences are combined into ONE FASTA file for faster MSA.
#@markdown
#@markdown **Requires:** Cell 2 must be run first.

import os
import re

print("=" * 70)
print("CELL 3: Filter & Prepare Combined FASTA")
print("=" * 70)
print(f"\nConfiguration:")
print(f"  • RMSD Threshold: ≤ {RMSD_THRESHOLD} Å")
print(f"  • Output: {FASTA_OUTPUT_DIR}")

# =============================================================================
# VALIDATION
# =============================================================================

print("\n--- Validation ---")

try:
    _ = sequence_df
    print(f"✓ sequence_df found with {len(sequence_df)} sequences")
except NameError:
    raise RuntimeError("✗ sequence_df not found! Run Cell 2 first.")

if len(sequence_df) == 0:
    raise RuntimeError("✗ sequence_df is empty!")

# =============================================================================
# FILTERING
# =============================================================================

print("\n" + "=" * 70)
print("FILTERING SEQUENCES")
print("=" * 70)

filtered_df = sequence_df[
    (sequence_df['Status'] == 'Success') &
    (sequence_df['Final_RMSD'] <= RMSD_THRESHOLD) &
    (sequence_df['Sequence'].notna()) &
    (sequence_df['Sequence'].str.len() > 0)
].copy()

print(f"\nFiltering results:")
print(f"  • Total sequences: {len(sequence_df)}")
print(f"  • Successful: {len(sequence_df[sequence_df['Status'] == 'Success'])}")
print(f"  • RMSD ≤ {RMSD_THRESHOLD}Å: {len(filtered_df)}")

if len(filtered_df) == 0:
    print("\n⚠️ No sequences passed the filter!")
    raise RuntimeError("No sequences to process")

# =============================================================================
# CREATE JOBNAMES
# =============================================================================

def clean_jobname(filename: str) -> str:
    """Create a clean jobname from filename.

    A trailing '.pdb' extension is removed, every character that is not a
    word character or '-' becomes '_', runs of '_' are collapsed to one, and
    leading/trailing '_' are stripped.

    Only the trailing extension is removed: a '.pdb' appearing elsewhere in
    the name is treated like any other text.
    """
    suffix = '.pdb'
    name = filename[:-len(suffix)] if filename.endswith(suffix) else filename
    name = re.sub(r'[^\w\-]', '_', name)
    name = re.sub(r'_+', '_', name)
    return name.strip('_')


def format_jobname(filename: str) -> str:
    """Return the jobname for a PDB filename (delegates to clean_jobname)."""
    return clean_jobname(filename)

# format_jobname() is lossy (e.g. "a b.pdb" and "a_b.pdb" both become "a_b"),
# and the jobname is used as the FASTA header, so colliding names are made
# unique by appending _2, _3, ... to the later ones.
jobnames = []
seen_jobnames = set()
for filename in filtered_df['Filename']:
    base_jobname = format_jobname(filename)
    unique_jobname = base_jobname
    suffix_number = 2
    while unique_jobname in seen_jobnames:
        unique_jobname = f"{base_jobname}_{suffix_number}"
        suffix_number += 1
    if unique_jobname != base_jobname:
        print(f"  ⚠️ Duplicate jobname '{base_jobname}' for {filename}; using '{unique_jobname}'")
    seen_jobnames.add(unique_jobname)
    jobnames.append(unique_jobname)

filtered_df['Jobname'] = jobnames
filtered_df['RMSD'] = filtered_df['Final_RMSD']

# =============================================================================
# CREATE COMBINED FASTA FILE
# =============================================================================

print("\n" + "=" * 70)
print("CREATING COMBINED FASTA FILE")
print("=" * 70)

os.makedirs(FASTA_OUTPUT_DIR, exist_ok=True)

# Combined FASTA path
COMBINED_FASTA_PATH = os.path.join(FASTA_OUTPUT_DIR, "all_sequences.fasta")

print(f"\nCombining {len(filtered_df)} sequences into one FASTA file...")
print("-" * 50)

with open(COMBINED_FASTA_PATH, 'w') as f:
    for jobname, sequence in zip(filtered_df['Jobname'], filtered_df['Sequence']):
        # One FASTA record: the ">" header line, then the sequence wrapped
        # at 80 characters per line.
        record_lines = [f">{jobname}"]
        record_lines.extend(
            sequence[start:start + 80] for start in range(0, len(sequence), 80)
        )
        f.write("".join(f"{line}\n" for line in record_lines))

        print(f"  ✓ {jobname}: {len(sequence)} aa")

print(f"\n✓ Combined FASTA saved: {COMBINED_FASTA_PATH}")

# Store path for Cell 4
filtered_df['FASTA_Path'] = COMBINED_FASTA_PATH

# =============================================================================
# SUMMARY
# =============================================================================

print("\n" + "=" * 70)
print("SUMMARY")
print("=" * 70)

print(f"\n✓ {len(filtered_df)} sequences combined into: {COMBINED_FASTA_PATH}")
print(f"\n⚡ MSA will be computed ONCE for all sequences (much faster!)")

# Statistics by linker group
if 'Linker_Group' in filtered_df.columns:
    print("\n--- By Linker Group ---")
    group_stats = filtered_df.groupby('Linker_Group').agg({
        'Jobname': 'count',
        'Final_RMSD': ['mean', 'std']
    }).round(4)
    group_stats.columns = ['Count', 'Mean_RMSD', 'Std_RMSD']
    print(group_stats.to_string())

# Display filtered sequences
print("\n--- Filtered Sequences ---")
display_cols = ['Jobname', 'Linker_Length', 'Linker_Group', 'Final_RMSD', 'Sequence_Length']
display_df = filtered_df[display_cols].copy()
display_df['Final_RMSD'] = display_df['Final_RMSD'].apply(lambda x: f"{x:.4f}")
print(display_df.to_string(index=False))

# Create fasta_df for Cell 4
fasta_df = filtered_df[['Jobname', 'Filename', 'Sequence', 'Sequence_Length',
                        'Linker_Length', 'Linker_Group', 'RMSD']].copy()
fasta_df['Original_Filename'] = filtered_df['Filename']

# Store combined FASTA path as global variable
COMBINED_FASTA = COMBINED_FASTA_PATH

print("\n" + "=" * 70)
print("✓ Cell 3 Complete")
print("=" * 70)
print(f"\nData saved to:")
print(f"  • filtered_df: {len(filtered_df)} sequences")
print(f"  • fasta_df: {len(fasta_df)} sequences")
print(f"  • COMBINED_FASTA: {COMBINED_FASTA}")
print(f"\n→ Run Cell 4 to start AlphaFold predictions")
print("=" * 70)


In [None]:
# ================================================================================
# CELL 4: ALPHAFOLD BATCH EXECUTION (COMBINED FASTA - FAST MSA)
# ================================================================================

#@title **Cell 4: AlphaFold Predictions (Fast Batch)** { display-mode: "form" }
#@markdown ### AlphaFold Configuration
NUM_MODELS = 1  #@param {type:"integer"}
NUM_RECYCLES = 3  #@param {type:"integer"}
USE_AMBER = False  #@param {type:"boolean"}
MSA_MODE = "mmseqs2_uniref_env"  #@param ["mmseqs2_uniref_env", "mmseqs2_uniref", "single_sequence"]
TIMEOUT_MINUTES = 1000  #@param {type:"integer"}

#@markdown ---
#@markdown **⚡ Fast Mode:** All sequences use ONE MSA search (saves ~5min per sequence)
#@markdown
#@markdown **First run:** ~15 min installation + runtime restart required.

# CRITICAL: Set backend before any matplotlib import
import os
os.environ['MPLBACKEND'] = 'Agg'

import subprocess
import shutil
import time
import sys
from pathlib import Path
import glob

# Paths
ALPHAFOLD_OUTPUT_DIR = "/content/alphafold_results/"
FASTA_INPUT_DIR = "/content/fasta_files/"
CONDA_PATH = "/opt/miniforge"
CONDA_ENV = f"{CONDA_PATH}/envs/colabfold"
COLABFOLD_BIN = f"{CONDA_ENV}/bin/colabfold_batch"
MAMBA_BIN = f"{CONDA_PATH}/bin/mamba"

print("=" * 70)
print("CELL 4: AlphaFold Batch Execution (Fast MSA Mode)")
print("=" * 70)
print(f"\nConfiguration:")
print(f"  • Models: {NUM_MODELS}")
print(f"  • Recycles: {NUM_RECYCLES}")
print(f"  • AMBER: {USE_AMBER}")
print(f"  • MSA mode: {MSA_MODE}")
print(f"  • Timeout: {TIMEOUT_MINUTES} min (total)")
print(f"\n⚡ Fast Mode: Single MSA for all sequences!")

# ============================================================================
# STEP 1: GPU CHECK
# ============================================================================

print("\n[Step 1/4] Checking GPU...")
# If the nvidia-smi executable is not installed, subprocess.run raises
# FileNotFoundError instead of returning a non-zero exit code. Treat both
# cases the same way so the user gets the actionable message below.
try:
    result = subprocess.run(['nvidia-smi', '-L'], capture_output=True, text=True)
    gpu_found = result.returncode == 0
except FileNotFoundError:
    gpu_found = False

if gpu_found:
    # First line of the `nvidia-smi -L` listing
    gpu_name = result.stdout.strip().split('\n')[0]
    print(f"  ✓ {gpu_name}")
else:
    raise RuntimeError("GPU required - Runtime → Change runtime type → GPU")

# ============================================================================
# STEP 2: CHECK/INSTALL COLABFOLD
# ============================================================================

print("\n[Step 2/4] Setting up ColabFold environment...")

def check_installation_complete():
    """Report whether the ColabFold binary, the env's Python and mamba all exist on disk."""
    required_paths = (
        COLABFOLD_BIN,
        f"{CONDA_ENV}/bin/python",
        MAMBA_BIN,
    )
    for required in required_paths:
        if not os.path.exists(required):
            return False
    return True

def test_colabfold_runs():
    """Return True if `colabfold_batch --help` runs and exits with status 0.

    Returns False when the binary does not exist, when the command cannot be
    started, when it does not finish within 30 seconds, or when it exits with
    a non-zero status. KeyboardInterrupt and SystemExit are not swallowed.
    """
    if not os.path.exists(COLABFOLD_BIN):
        return False
    try:
        r = subprocess.run([COLABFOLD_BIN, '--help'],
                          capture_output=True, text=True, timeout=30)
    except (subprocess.SubprocessError, OSError, UnicodeDecodeError):
        # Timeout, failure to launch the binary, or undecodable output all
        # mean "not usable yet".
        return False
    return r.returncode == 0

if check_installation_complete() and test_colabfold_runs():
    print("  ✓ ColabFold is ready!")
else:
    print("  Installing ColabFold...")
    print("  This takes 12-18 minutes (one-time setup)\n")

    def _step_succeeded(label, completed):
        """Return True if an install command exited with status 0.

        On a non-zero status, print a warning followed by the last few
        non-empty lines of the command's captured output, and return False.
        """
        if completed.returncode == 0:
            return True
        print(f"  ⚠️ {label} exited with code {completed.returncode}")
        captured = (completed.stdout or "") + (completed.stderr or "")
        tail = [ln for ln in captured.strip().split('\n') if ln.strip()][-5:]
        for ln in tail:
            print(f"    {ln[:80]}")
        return False

    # Install Miniforge
    print("  [1/4] Installing Miniforge...")
    if not os.path.exists(MAMBA_BIN):
        miniforge_url = "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-x86_64.sh"
        cmd = f'''
wget -q --timeout=120 "{miniforge_url}" -O /tmp/miniforge.sh || curl -sL --max-time 120 "{miniforge_url}" -o /tmp/miniforge.sh
bash /tmp/miniforge.sh -b -p {CONDA_PATH} -u
rm -f /tmp/miniforge.sh
'''
        # The script's exit status is that of the final `rm -f`, so success
        # is judged by the presence of the mamba binary below instead.
        subprocess.run(cmd, shell=True, capture_output=True, text=True, timeout=600)

    if os.path.exists(MAMBA_BIN):
        print("  ✓ Miniforge ready")
    else:
        raise RuntimeError("Miniforge installation failed")

    # Create environment
    print("  [2/4] Creating Python 3.10 environment...")
    if not os.path.exists(f"{CONDA_ENV}/bin/python"):
        cmd = f'{MAMBA_BIN} create -y -p {CONDA_ENV} python=3.10 -q'
        _step_succeeded(
            "Environment creation",
            subprocess.run(cmd, shell=True, capture_output=True, text=True, timeout=300),
        )

    if os.path.exists(f"{CONDA_ENV}/bin/python"):
        print("  ✓ Environment created")
    else:
        raise RuntimeError("Environment creation failed")

    # Install pdbfixer + openmm
    print("  [3/4] Installing pdbfixer + OpenMM...")
    cmd = f'{MAMBA_BIN} install -y -p {CONDA_ENV} -c conda-forge pdbfixer openmm -q'
    # A failure here is reported but not fatal: the check that decides whether
    # installation succeeded is the colabfold_batch existence test in step 4.
    if _step_succeeded(
        "pdbfixer + OpenMM install",
        subprocess.run(cmd, shell=True, capture_output=True, text=True, timeout=600),
    ):
        print("  ✓ pdbfixer + OpenMM installed")

    # Install ColabFold + JAX
    print("  [4/4] Installing ColabFold + JAX...")
    cmd1 = f'{CONDA_ENV}/bin/pip install -q "colabfold[alphafold-minus-jax] @ git+https://github.com/sokrypton/ColabFold"'
    _step_succeeded(
        "ColabFold install",
        subprocess.run(cmd1, shell=True, capture_output=True, text=True, timeout=600),
    )

    cmd2 = f'{CONDA_ENV}/bin/pip install -q "jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html'
    jax_ok = _step_succeeded(
        "JAX install",
        subprocess.run(cmd2, shell=True, capture_output=True, text=True, timeout=300),
    )

    if os.path.exists(COLABFOLD_BIN):
        if jax_ok:
            print("  ✓ ColabFold + JAX installed")
        else:
            print("  ✓ ColabFold installed (JAX install reported errors, see above)")
    else:
        raise RuntimeError("ColabFold installation failed")

    print("\n" + "=" * 70)
    print("✓ INSTALLATION COMPLETE")
    print("=" * 70)
    print("""
⚠️  RUNTIME RESTART REQUIRED

Please:
  1. Click: Runtime → Restart runtime
  2. Wait ~10 seconds
  3. Click: Runtime → Run all

After restart, predictions will run automatically.
""")
    print("=" * 70)
    # Stop this cell here; predictions run on the next pass after the restart.
    sys.exit(0)

# ============================================================================
# STEP 3: CHECK PREREQUISITES
# ============================================================================

print("\n[Step 3/4] Checking prerequisites...")

# Check fasta_df
try:
    _ = fasta_df
    print(f"  ✓ fasta_df: {len(fasta_df)} sequences")
except NameError:
    raise RuntimeError("fasta_df not found. Run Cells 1-3 first.")

# Check combined FASTA file
try:
    combined_fasta_ok = os.path.exists(COMBINED_FASTA)
except (NameError, TypeError):
    # COMBINED_FASTA is undefined (Cell 3 has not run in this session) or is
    # not a path-like value; use the conventional location below.
    combined_fasta_ok = False

if combined_fasta_ok:
    print(f"  ✓ Combined FASTA: {COMBINED_FASTA}")
else:
    # Fallback: look for all_sequences.fasta
    COMBINED_FASTA = os.path.join(FASTA_INPUT_DIR, "all_sequences.fasta")
    if os.path.exists(COMBINED_FASTA):
        print(f"  ✓ Combined FASTA found: {COMBINED_FASTA}")
    else:
        raise RuntimeError("Combined FASTA not found. Run Cell 3 first.")

# Create output directory
output_dir = Path(ALPHAFOLD_OUTPUT_DIR)
output_dir.mkdir(parents=True, exist_ok=True)
print(f"  ✓ Output: {output_dir}")

# ============================================================================
# STEP 4: RUN ALPHAFOLD (SINGLE BATCH)
# ============================================================================

print("\n[Step 4/4] Running AlphaFold predictions...")
print(f"\n  ⚡ Processing {len(fasta_df)} sequences with SINGLE MSA search")
print(f"  Estimated time: ~5min MSA + ~{len(fasta_df) * 2}-{len(fasta_df) * 3}min predictions")

# Temporary output directory for ColabFold
temp_output_dir = output_dir / "temp_batch"
temp_output_dir.mkdir(exist_ok=True)

# Build command
cmd = [COLABFOLD_BIN]
cmd.extend(["--num-recycle", str(NUM_RECYCLES)])
cmd.extend(["--num-models", str(NUM_MODELS)])
cmd.extend(["--model-type", "alphafold2_ptm"])
cmd.extend(["--msa-mode", MSA_MODE])
cmd.extend(["--rank", "plddt"])

if USE_AMBER:
    cmd.extend(["--amber", "--use-gpu-relax"])

cmd.extend([COMBINED_FASTA, str(temp_output_dir)])

print(f"\n  Command: colabfold_batch [combined.fasta] [output_dir]")
print(f"  Running...\n")

t_start = time.time()

try:
    result = subprocess.run(
        cmd,
        capture_output=True,
        text=True,
        timeout=TIMEOUT_MINUTES * 60
    )

    elapsed = time.time() - t_start

    if result.returncode != 0:
        print(f"  ⚠️ ColabFold returned non-zero exit code")
        out = (result.stdout or "") + (result.stderr or "")
        if "pdbfixer" in out.lower():
            print("    → Restart runtime, re-run all cells")
        elif "mmseqs" in out.lower():
            print("    → MSA server issue, try again or use 'single_sequence'")
        else:
            lines = [l for l in out.strip().split('\n') if l.strip()][-5:]
            for line in lines:
                print(f"    {line[:80]}")

except subprocess.TimeoutExpired:
    print(f"  ✗ Timeout after {TIMEOUT_MINUTES} minutes")
    subprocess.run('pkill -9 -f colabfold', shell=True, capture_output=True)
    raise RuntimeError("Prediction timed out")

except Exception as e:
    print(f"  ✗ Exception: {e}")
    raise

print(f"\n  Total batch time: {elapsed/60:.1f} minutes")

# ============================================================================
# ORGANIZE OUTPUT FILES
# ============================================================================

print("\n" + "=" * 70)
print("ORGANIZING OUTPUT FILES")
print("=" * 70)

successful = []
failed = []

# Find all generated PDB files
if USE_AMBER:
    pdb_pattern = "*_relaxed_rank_001*.pdb"
else:
    pdb_pattern = "*_unrelaxed_rank_001*.pdb"

generated_pdbs = list(temp_output_dir.glob(pdb_pattern))
if not generated_pdbs:
    generated_pdbs = list(temp_output_dir.glob("*rank_001*.pdb"))
if not generated_pdbs:
    generated_pdbs = list(temp_output_dir.glob("*.pdb"))

print(f"\nFound {len(generated_pdbs)} PDB files")

# Map generated files to jobnames
for _, row in fasta_df.iterrows():
    jobname = row['Jobname']

    # Find matching PDB (ColabFold uses the FASTA header as prefix).
    # Match the whole "<jobname>_unrelaxed_" / "<jobname>_relaxed_" prefix: a
    # plain substring test would let jobname "design_1" claim the output that
    # belongs to "design_10". Sorting makes the pick deterministic.
    strict_prefixes = (f"{jobname}_unrelaxed_", f"{jobname}_relaxed_")
    matching_pdbs = sorted(
        p for p in generated_pdbs if p.name.startswith(strict_prefixes)
    )
    if not matching_pdbs:
        # Unrecognised naming scheme: accept "<jobname>.pdb" or "<jobname>_*.pdb".
        matching_pdbs = sorted(
            p for p in generated_pdbs
            if p.stem == jobname or p.name.startswith(f"{jobname}_")
        )

    if matching_pdbs:
        src_pdb = matching_pdbs[0]

        # Determine type from the part of the name after the jobname.
        # "unrelaxed" contains "relaxed", so it must be ruled out explicitly.
        name_tail = src_pdb.name[len(jobname):]
        if "relaxed" in name_tail and "unrelaxed" not in name_tail:
            pdb_type = "relaxed"
        else:
            pdb_type = "unrelaxed"

        # Copy to final location with clean name
        final_pdb = output_dir / f"{jobname}_{pdb_type}_rank_001.pdb"
        shutil.copy(src_pdb, final_pdb)

        # Find matching JSON: prefer the rank-1 scores file, then any scores
        # file of this job, then any JSON of this job. The "_" after the
        # jobname keeps "design_1" from matching "design_10_...".
        matching_jsons = []
        for json_pattern in (
            f"{jobname}_scores_rank_001*.json",
            f"{jobname}_*scores*.json",
            f"{jobname}_*.json",
        ):
            matching_jsons = sorted(temp_output_dir.glob(json_pattern))
            if matching_jsons:
                break

        final_json = None
        if matching_jsons:
            final_json = output_dir / f"{jobname}_scores_rank_001.json"
            shutil.copy(matching_jsons[0], final_json)

        successful.append({
            'Jobname': jobname,
            'Original_Filename': row['Original_Filename'],
            'Original_RMSD': row['RMSD'],
            'Linker_Group': row['Linker_Group'],
            'Sequence_Length': row['Sequence_Length'],
            'AlphaFold_PDB': str(final_pdb),
            'AlphaFold_JSON': str(final_json) if final_json else None,
            'PDB_Type': pdb_type,
            'Status': 'Success'
        })

        j = "✓" if final_json else "✗"
        print(f"  ✓ {jobname}: PDB ✓ | JSON {j}")
    else:
        failed.append({'Jobname': jobname, 'Error': 'PDB not found'})
        print(f"  ✗ {jobname}: PDB not found")

# Cleanup temp directory
shutil.rmtree(temp_output_dir, ignore_errors=True)

# ============================================================================
# SUMMARY
# ============================================================================

print("\n" + "=" * 70)
print("EXECUTION SUMMARY")
print("=" * 70)

print(f"\nTotal time: {elapsed/60:.1f} minutes")
print(f"Successful: {len(successful)}/{len(fasta_df)}")

if len(successful) > 0:
    avg_time = elapsed / len(successful)
    print(f"Average: {avg_time/60:.1f} min per sequence (including shared MSA)")

if failed:
    print("\n✗ Failed:")
    for f in failed:
        print(f"  • {f['Jobname']}: {f['Error']}")

# Create results DataFrame
if successful:
    alphafold_results_df = pd.DataFrame(successful)
    alphafold_results_df.to_csv(output_dir / "alphafold_results.csv", index=False)
    print(f"\n✓ Results saved to {output_dir}")
else:
    alphafold_results_df = pd.DataFrame()
    print("\n⚠️ No successful predictions")

print("\n" + "=" * 70)
print("✓ Cell 4 Complete")
print("=" * 70)
if successful:
    print("→ Run Cell 5 for RMSD analysis")
    print("→ Run Cell 6 for NMA analysis")
    print("→ Run Cell 7 for pLDDT/PAE analysis")
    print("→ Run Cell 8 for Final Summary & Download")
print("=" * 70)


In [None]:
# ================================================================================
# CELL 5: ALPHAFOLD RESULTS ANALYSIS & ANTIGEN DISTANCE MEASUREMENT
# ================================================================================

#@title **Cell 5: RMSD Comparison & Antigen Distance (130Å Target)** { display-mode: "form" }
#@markdown ### Configuration
REFERENCE_PATH = "/content/reference.pdb"  #@param {type:"string"}
ANTIGEN_COMPLEX_PATH = "/content/Include_antigen_with_perfect_distance.pdb"  #@param {type:"string"}

#@markdown ### Antigen Distance Settings
#@markdown Chain B (CD3e) membrane-proximal residue:
CHAIN_B_RESNUM = 102  #@param {type:"integer"}
#@markdown Chain D (HER2) membrane-proximal residue:
CHAIN_D_RESNUM = 652  #@param {type:"integer"}
#@markdown Target immunological synapse distance:
TARGET_DISTANCE = 130.0  #@param {type:"number"}

#@markdown ---
#@markdown **Requires:** Cells 1-4 must be run first.

# How the form values above are used further down in this cell:
#   REFERENCE_PATH        -> passed to load_reference_structure()
#   ANTIGEN_COMPLEX_PATH  -> parsed with PDBParser; chains A/B/C/D are read from it
#   CHAIN_B_RESNUM / CHAIN_D_RESNUM -> residue numbers whose CA atoms are looked up
#                            in chains B and D of the antigen complex
#   TARGET_DISTANCE       -> subtracted from each measured antigen distance to give
#                            Distance_From_Target
# pd, PDBParser, Superimposer and the RMSD helper functions are not imported
# here; they are expected to be in the notebook namespace from Cell 1.
# NOTE(review): mpatches and copy do not appear to be referenced in the rest
# of this cell - confirm before removing.
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from pathlib import Path
import copy

print("=" * 70)
print("CELL 5: AlphaFold Analysis & Antigen Distance Measurement")
print("=" * 70)

# Directory searched for rank_001 PDB files; the comparison figure is also
# written here.
ALPHAFOLD_OUTPUT_DIR = "/content/alphafold_results/"

# =============================================================================
# VALIDATION
# =============================================================================

print("\n--- Validation ---")

# --- alphafold_results_df: use the Cell 4 table, or rebuild it from disk ---
try:
    _ = alphafold_results_df  # raises NameError when Cell 4 was never run
    if not len(alphafold_results_df):
        raise ValueError("alphafold_results_df is empty")
except (NameError, ValueError) as problem:
    print(f"⚠️ alphafold_results_df issue: {problem}")
    print("  Attempting to find AlphaFold outputs directly...")

    af_output_path = Path(ALPHAFOLD_OUTPUT_DIR)
    if not af_output_path.exists():
        raise RuntimeError(f"Output directory not found: {ALPHAFOLD_OUTPUT_DIR}")

    # Try the most specific file-name pattern first and stop at the first hit
    af_pdbs = []
    for name_pattern in ("*_relaxed_rank_001.pdb",
                         "*_unrelaxed_rank_001.pdb",
                         "*rank_001*.pdb"):
        af_pdbs = list(af_output_path.glob(name_pattern))
        if af_pdbs:
            break

    if not af_pdbs:
        raise RuntimeError("No AlphaFold output files found. Run Cell 4 first.")

    print(f"  Found {len(af_pdbs)} AlphaFold PDB files")
    alphafold_results_df = pd.DataFrame([
        {
            'Jobname': found.stem.replace('_relaxed_rank_001', '').replace('_unrelaxed_rank_001', ''),
            'AlphaFold_PDB': str(found),
            'Status': 'Success',
        }
        for found in af_pdbs
    ])
else:
    print(f"✓ alphafold_results_df: {len(alphafold_results_df)} predictions")

# --- filtered_df: optional, only needed for the original RMSD values ---
try:
    _ = filtered_df
except NameError:
    print("⚠️ filtered_df not found - original RMSD values may not be available")
    filtered_df = None
else:
    print(f"✓ filtered_df: {len(filtered_df)} sequences")

# --- reference structure: mandatory ---
if not Path(REFERENCE_PATH).exists():
    raise FileNotFoundError(f"Reference file not found: {REFERENCE_PATH}")
print(f"✓ Reference file found")

# --- antigen complex: optional, its absence only disables the distance step ---
ANTIGEN_COMPLEX_AVAILABLE = Path(ANTIGEN_COMPLEX_PATH).exists()
if ANTIGEN_COMPLEX_AVAILABLE:
    print(f"✓ Antigen complex found")
else:
    print(f"⚠️ Antigen complex not found: {ANTIGEN_COMPLEX_PATH}")
    print("  Antigen distance measurement will be skipped.")

# =============================================================================
# LOAD ANTIGEN-ANTIBODY COMPLEX
# =============================================================================

if ANTIGEN_COMPLEX_AVAILABLE:
    print("\n" + "=" * 70)
    print("LOADING ANTIGEN-ANTIBODY COMPLEX")
    print("=" * 70)

    parser = PDBParser(QUIET=True)
    antigen_complex = parser.get_structure('antigen_complex', ANTIGEN_COMPLEX_PATH)

    # Only the first model is read, so every chain's residues are collected
    # exactly once. Looping over all models of a multi-model file would
    # append each chain once per model and the CA counts checked in
    # calculate_antigen_distance() could never match.
    antigen_model = next(iter(antigen_complex), None)

    # Standard amino-acid residues of the four chains of interest:
    #   A = scFv1 (anti-CD3), B = CD3e antigen,
    #   C = scFv2 (anti-HER2), D = HER2 antigen
    residues_by_chain = {'A': [], 'B': [], 'C': [], 'D': []}
    if antigen_model is not None:
        for chain in antigen_model:
            bucket = residues_by_chain.get(chain.get_id())
            if bucket is not None:
                bucket.extend(residue for residue in chain if is_aa(residue, standard=True))

    chain_A_residues = residues_by_chain['A']
    chain_B_residues = residues_by_chain['B']
    chain_C_residues = residues_by_chain['C']
    chain_D_residues = residues_by_chain['D']

    print(f"  Chain A (scFv1): {len(chain_A_residues)} residues")
    print(f"  Chain B (CD3e): {len(chain_B_residues)} residues")
    print(f"  Chain C (scFv2): {len(chain_C_residues)} residues")
    print(f"  Chain D (HER2): {len(chain_D_residues)} residues")

    # Get CA atoms for alignment
    chain_A_ca = get_ca_atoms(chain_A_residues)
    chain_C_ca = get_ca_atoms(chain_C_residues)

    def _find_target_ca(model, chain_id, resnum):
        """Return the CA atom of the first residue numbered `resnum` in chain
        `chain_id` of `model`, or None when the chain or residue is missing."""
        if model is None or chain_id not in model:
            return None
        for residue in model[chain_id]:
            if residue.get_id()[1] == resnum and 'CA' in residue:
                return residue['CA']
        return None

    # Find target residues for distance measurement
    chain_B_target_ca = _find_target_ca(antigen_model, 'B', CHAIN_B_RESNUM)
    chain_D_target_ca = _find_target_ca(antigen_model, 'D', CHAIN_D_RESNUM)

    # Explicit None checks: do not rely on the truthiness of atom objects
    if chain_B_target_ca is not None and chain_D_target_ca is not None:
        # Bio.PDB atoms subtract to the distance between them
        original_distance = chain_B_target_ca - chain_D_target_ca
        print(f"\n  Target residues found:")
        print(f"    Chain B res {CHAIN_B_RESNUM} CA: {chain_B_target_ca.get_coord()}")
        print(f"    Chain D res {CHAIN_D_RESNUM} CA: {chain_D_target_ca.get_coord()}")
        print(f"    Original distance: {original_distance:.2f} Å")
        print(f"    Target distance: {TARGET_DISTANCE:.2f} Å")
    else:
        print(f"  ⚠️ Could not find target residues for distance measurement")
        # Without both anchor atoms the distance step cannot run
        ANTIGEN_COMPLEX_AVAILABLE = False

# =============================================================================
# ANTIGEN DISTANCE CALCULATION FUNCTION
# =============================================================================

def calculate_antigen_distance(af_pdb_path: str) -> dict:
    """
    Calculate the distance between antigen membrane-proximal residues
    after superimposing the antigen-scFv complexes onto the AlphaFold structure.

    Method:
    1. Superimpose Chain A (scFv1) onto AlphaFold's scFv1 → get rotation/translation
    2. Apply that transformation to the Chain B (CD3e) target CA → CD3e moves with scFv1
    3. Superimpose Chain C (scFv2) onto AlphaFold's scFv2 → get rotation/translation
    4. Apply that transformation to the Chain D (HER2) target CA → HER2 moves with scFv2
    5. Measure the distance between the two transformed CA positions
       (residues CHAIN_B_RESNUM and CHAIN_D_RESNUM)

    Args:
        af_pdb_path: Path to the AlphaFold PDB file to evaluate.

    Returns:
        dict with keys 'Antigen_Distance', 'Distance_From_Target',
        'scFv1_Alignment_RMSD' and 'scFv2_Alignment_RMSD' (None when not
        computed). On success 'Chain_B_Transformed' and 'Chain_D_Transformed'
        hold the moved coordinates; on failure an 'Error' message is added.
        The function does not raise: any exception is reported via 'Error'.
    """
    result = {
        'Antigen_Distance': None,
        'Distance_From_Target': None,
        'scFv1_Alignment_RMSD': None,
        'scFv2_Alignment_RMSD': None
    }

    if not ANTIGEN_COMPLEX_AVAILABLE:
        return result

    try:
        # Load AlphaFold structure
        parser = PDBParser(QUIET=True)
        af_structure = parser.get_structure('alphafold', af_pdb_path)
        af_residues = extract_residues(af_structure)

        # scFv1 is taken from the start and scFv2 from the end of the residue list
        af_scfv1_ca = get_ca_atoms(af_residues[:SCFV1_LENGTH])
        af_scfv2_ca = get_ca_atoms(af_residues[-SCFV2_LENGTH:])

        # Superimposer needs equally sized atom lists
        if len(af_scfv1_ca) != len(chain_A_ca):
            result['Error'] = f"scFv1 atom mismatch: AF={len(af_scfv1_ca)}, ChainA={len(chain_A_ca)}"
            return result
        if len(af_scfv2_ca) != len(chain_C_ca):
            result['Error'] = f"scFv2 atom mismatch: AF={len(af_scfv2_ca)}, ChainC={len(chain_C_ca)}"
            return result

        # === Superimpose Chain A onto AlphaFold scFv1 ===
        # We want to move Chain A (and B) to match AlphaFold scFv1
        sup1 = Superimposer()
        sup1.set_atoms(af_scfv1_ca, chain_A_ca)  # fixed=AF, moving=ChainA
        result['scFv1_Alignment_RMSD'] = sup1.rms

        # Biopython's rotran is a RIGHT-multiplying rotation plus a translation:
        # new = dot(old, rot) + tran (the same formula Atom.transform uses).
        # Transposing the rotation here would apply the inverse rotation.
        rot1, tran1 = sup1.rotran
        chain_B_coord_transformed = np.dot(chain_B_target_ca.get_coord(), rot1) + tran1

        # === Superimpose Chain C onto AlphaFold scFv2 ===
        sup2 = Superimposer()
        sup2.set_atoms(af_scfv2_ca, chain_C_ca)  # fixed=AF, moving=ChainC
        result['scFv2_Alignment_RMSD'] = sup2.rms

        rot2, tran2 = sup2.rotran
        chain_D_coord_transformed = np.dot(chain_D_target_ca.get_coord(), rot2) + tran2

        # === Calculate distance ===
        antigen_distance = np.linalg.norm(chain_B_coord_transformed - chain_D_coord_transformed)

        result['Antigen_Distance'] = antigen_distance
        result['Distance_From_Target'] = antigen_distance - TARGET_DISTANCE
        result['Chain_B_Transformed'] = chain_B_coord_transformed
        result['Chain_D_Transformed'] = chain_D_coord_transformed

    except Exception as e:
        # Best effort per structure: record the failure and let the batch continue
        result['Error'] = str(e)

    return result

# =============================================================================
# ANALYZE ALPHAFOLD PREDICTIONS
# =============================================================================

print("\n" + "=" * 70)
print("ANALYZING ALPHAFOLD PREDICTIONS")
print("=" * 70)

# Reference scFv domains that every prediction is compared against
print("\n--- Loading Reference Structure ---")
ref_scfv1, ref_scfv2 = load_reference_structure(REFERENCE_PATH)

# Locate the predicted structures: most specific file-name pattern first,
# falling back to any PDB file in the output directory
af_output_path = Path(ALPHAFOLD_OUTPUT_DIR)
af_pdb_files = []
for glob_pattern in ("*_relaxed_rank_001.pdb",
                     "*_unrelaxed_rank_001.pdb",
                     "*rank_001*.pdb",
                     "*.pdb"):
    af_pdb_files = sorted(af_output_path.glob(glob_pattern))
    if af_pdb_files:
        break

print(f"\n--- Found {len(af_pdb_files)} AlphaFold PDB Files ---")

if not af_pdb_files:
    raise RuntimeError("No PDB files to analyze")

# One result dict per prediction: RMSD fields plus, when available,
# the antigen-distance fields
af_rmsd_results = []

for i, pdb_file in enumerate(af_pdb_files, 1):
    pdb_path_str = str(pdb_file)

    # Job name = file stem without the rank suffix
    jobname = pdb_file.stem
    for rank_suffix in ('_relaxed_rank_001', '_unrelaxed_rank_001'):
        jobname = jobname.replace(rank_suffix, '')

    print(f"\n  [{i}/{len(af_pdb_files)}] {jobname}")

    # RMSD against the reference
    result = analyze_single_target(ref_scfv1, ref_scfv2, pdb_path_str)
    result['Jobname'] = jobname
    result['AlphaFold_PDB'] = pdb_path_str

    if result['Status'] != 'Success':
        print(f"      RMSD: FAILED - {result['Error']}")
    else:
        print(f"      RMSD: {result['Final_RMSD']:.4f} Å ({result['Best_Alignment']})")

    # Antigen distance, merged into the same record
    if ANTIGEN_COMPLEX_AVAILABLE:
        antigen_result = calculate_antigen_distance(pdb_path_str)
        result.update(antigen_result)

        if antigen_result['Antigen_Distance'] is not None:
            dist = antigen_result['Antigen_Distance']
            diff = antigen_result['Distance_From_Target']
            print(f"      Antigen Distance: {dist:.2f} Å (Δ from 130Å: {diff:+.2f} Å)")

    af_rmsd_results.append(result)

af_rmsd_df = pd.DataFrame(af_rmsd_results)

# =============================================================================
# MERGE WITH ORIGINAL DATA
# =============================================================================

print("\n" + "=" * 70)
print("MERGING ORIGINAL & ALPHAFOLD RESULTS")
print("=" * 70)

# Pre-AlphaFold metrics keyed by job name
original_data = {}

if filtered_df is not None:
    # format_jobname is only used when the cell that defines it has been run;
    # otherwise the job name is the file name without its extension
    has_format_jobname = 'format_jobname' in globals()
    for _, row in filtered_df.iterrows():
        if has_format_jobname:
            jobname = format_jobname(row['Filename'])
        else:
            jobname = row['Filename'].replace('.pdb', '')
        original_data[jobname] = {
            'Original_RMSD': row['Final_RMSD'],
            'Linker_Group': row.get('Linker_Group', 'Unknown'),
            'Original_Filename': row['Filename']
        }
elif 'Original_RMSD' in alphafold_results_df.columns:
    for _, row in alphafold_results_df.iterrows():
        original_data[row['Jobname']] = {
            'Original_RMSD': row['Original_RMSD'],
            'Linker_Group': row.get('Linker_Group', 'Unknown'),
            'Original_Filename': row.get('Original_Filename', row['Jobname'])
        }

# Build comparison DataFrame (one row per successfully analysed prediction)
comparison_data = []

for _, af_row in af_rmsd_df.iterrows():
    if af_row['Status'] != 'Success':
        continue

    jobname = af_row['Jobname']
    af_rmsd = af_row['Final_RMSD']

    orig = original_data.get(jobname, {})
    orig_rmsd = orig.get('Original_RMSD')
    linker_group = orig.get('Linker_Group', af_row.get('Linker_Group', 'Unknown'))

    row_data = {
        'Jobname': jobname,
        'Original_RMSD': orig_rmsd,
        'AlphaFold_RMSD': af_rmsd,
        'Linker_Group': linker_group,
        'AF_Best': af_row['Best_Alignment'],
        'AlphaFold_PDB': af_row['AlphaFold_PDB']
    }

    # pd.notna rejects both None and NaN; a plain `is not None` lets NaN through
    if pd.notna(orig_rmsd):
        row_data['RMSD_Change'] = af_rmsd - orig_rmsd
        row_data['RMSD_Improvement'] = orig_rmsd - af_rmsd
        row_data['Percent_Change'] = (af_rmsd - orig_rmsd) / orig_rmsd * 100 if orig_rmsd != 0 else 0

    # Add antigen distance data. A None stored in af_rmsd_results comes back
    # from the DataFrame as NaN, so NaN must be treated as "not measured".
    # Otherwise all-NaN distance columns are created and the sort below picks
    # Abs_Distance_Error even though no distance exists.
    antigen_distance = af_row.get('Antigen_Distance')
    if pd.notna(antigen_distance):
        row_data['Antigen_Distance'] = antigen_distance
        row_data['Distance_From_Target'] = af_row['Distance_From_Target']
        row_data['Abs_Distance_Error'] = abs(af_row['Distance_From_Target'])

    comparison_data.append(row_data)

comparison_df = pd.DataFrame(comparison_data)

if comparison_df.empty:
    # Nothing succeeded: fall back to the raw per-structure table so later
    # steps still have a DataFrame to work with
    print("\n⚠️ No comparison data available!")
    comparison_df = af_rmsd_df.copy()
    if 'AlphaFold_RMSD' not in comparison_df.columns and 'Final_RMSD' in comparison_df.columns:
        comparison_df['AlphaFold_RMSD'] = comparison_df['Final_RMSD']

# Sort by Antigen Distance error (closest to 130Å first)
if 'Abs_Distance_Error' in comparison_df.columns:
    comparison_df = comparison_df.sort_values('Abs_Distance_Error').reset_index(drop=True)
elif 'AlphaFold_RMSD' in comparison_df.columns:
    comparison_df = comparison_df.sort_values('AlphaFold_RMSD').reset_index(drop=True)

# =============================================================================
# DISPLAY RESULTS
# =============================================================================

print("\n" + "=" * 70)
print("COMPARISON RESULTS (Sorted by Closest to 130Å Target)")
print("=" * 70)

# Work on a copy so comparison_df keeps its numeric values
display_df = comparison_df.copy()

# Per-column number format; missing values are shown as "N/A"
column_formats = {
    'Original_RMSD': "{:.4f}",
    'AlphaFold_RMSD': "{:.4f}",
    'Antigen_Distance': "{:.2f}",
    'Distance_From_Target': "{:+.2f}",
}
for col, number_format in column_formats.items():
    if col not in display_df.columns:
        continue
    display_df[col] = display_df[col].apply(
        lambda x, number_format=number_format: number_format.format(x) if pd.notna(x) else "N/A"
    )

# Show only the columns that actually exist, in this order
preferred_order = ['Jobname', 'Antigen_Distance', 'Distance_From_Target',
                   'AlphaFold_RMSD', 'Original_RMSD', 'Linker_Group']
display_cols = [c for c in preferred_order if c in display_df.columns]

print("\n--- Results Table ---")
print(display_df[display_cols].to_string(index=False))

# =============================================================================
# STATISTICS
# =============================================================================

print("\n" + "=" * 70)
print("STATISTICS")
print("=" * 70)

has_antigen_distance = (
    'Antigen_Distance' in comparison_df.columns
    and comparison_df['Antigen_Distance'].notna().any()
)
if has_antigen_distance:
    antigen_dist = comparison_df['Antigen_Distance'].dropna()
    # Position of the measured distance with the smallest absolute deviation
    closest_position = (antigen_dist - TARGET_DISTANCE).abs().argmin()
    closest_distance = antigen_dist.iloc[closest_position]

    print(f"\nAntigen Distance Statistics (Target: {TARGET_DISTANCE} Å):")
    print(f"  • Mean: {antigen_dist.mean():.2f} Å")
    print(f"  • Std:  {antigen_dist.std():.2f} Å")
    print(f"  • Min:  {antigen_dist.min():.2f} Å")
    print(f"  • Max:  {antigen_dist.max():.2f} Å")
    print(f"  • Closest to target: {closest_distance:.2f} Å")

if 'AlphaFold_RMSD' in comparison_df.columns:
    af_rmsd_values = comparison_df['AlphaFold_RMSD']
    print(f"\nAlphaFold RMSD Statistics:")
    print(f"  • Mean: {af_rmsd_values.mean():.4f} Å")
    print(f"  • Std:  {af_rmsd_values.std():.4f} Å")
    print(f"  • Min:  {af_rmsd_values.min():.4f} Å")
    print(f"  • Max:  {af_rmsd_values.max():.4f} Å")

# =============================================================================
# VISUALIZATION
# =============================================================================

print("\n" + "=" * 70)
print("GENERATING VISUALIZATIONS")
print("=" * 70)

# Draws up to three side-by-side panels (antigen distance bars, RMSD bars,
# RMSD-vs-distance scatter) and saves them as one PNG in ALPHAFOLD_OUTPUT_DIR.
# Three panels when antigen distances exist, otherwise two.
n_plots = 3 if ANTIGEN_COMPLEX_AVAILABLE and 'Antigen_Distance' in comparison_df.columns else 2
fig, axes = plt.subplots(1, n_plots, figsize=(5*n_plots, 5))

# Pad to three entries so axes[2] can be tested against None below.
# NOTE(review): in this 2-panel layout Plot 1 is skipped (same condition as
# above) and the RMSD chart is drawn on axes[1], so axes[0] stays empty -
# confirm whether a single panel was intended.
if n_plots == 2:
    axes = [axes[0], axes[1], None]

# Plot 1: Antigen Distance Distribution
if ANTIGEN_COMPLEX_AVAILABLE and 'Antigen_Distance' in comparison_df.columns:
    ax1 = axes[0]
    distances = comparison_df['Antigen_Distance'].dropna()

    # Bar colour by deviation from the target: <10 green, <20 orange, else red
    bars = ax1.bar(range(len(distances)), distances.values,
                   color=['green' if abs(d - TARGET_DISTANCE) < 10 else 'orange' if abs(d - TARGET_DISTANCE) < 20 else 'red'
                          for d in distances.values])
    ax1.axhline(y=TARGET_DISTANCE, color='blue', linestyle='--', linewidth=2, label=f'Target ({TARGET_DISTANCE}Å)')
    ax1.axhspan(TARGET_DISTANCE-10, TARGET_DISTANCE+10, alpha=0.2, color='green', label='±10Å')
    ax1.set_xlabel('Sample')
    ax1.set_ylabel('Antigen Distance (Å)')
    ax1.set_title('Antigen Distance vs Target (130Å)')
    ax1.set_xticks(range(len(distances)))
    # Labels come from the first len(distances) rows.
    # NOTE(review): this lines up with the bars only if rows without a
    # distance sort last - verify for the fallback path where comparison_df
    # is sorted by RMSD instead.
    ax1.set_xticklabels(comparison_df['Jobname'].head(len(distances)), rotation=45, ha='right', fontsize=7)
    ax1.legend(fontsize=8)
    ax1.grid(axis='y', alpha=0.3)

# Plot 2: RMSD comparison
# axes[1] is an Axes object in both layouts, so the axes[0] fallback is not reached.
ax2 = axes[1] if axes[1] is not None else axes[0]
if 'AlphaFold_RMSD' in comparison_df.columns:
    x = np.arange(len(comparison_df))
    width = 0.35

    # Paired bars when any original RMSD is available, single bars otherwise
    if 'Original_RMSD' in comparison_df.columns and comparison_df['Original_RMSD'].notna().any():
        bars1 = ax2.bar(x - width/2, comparison_df['Original_RMSD'], width, label='Original', color='#3498db', alpha=0.8)
        bars2 = ax2.bar(x + width/2, comparison_df['AlphaFold_RMSD'], width, label='AlphaFold', color='#e74c3c', alpha=0.8)
    else:
        bars2 = ax2.bar(x, comparison_df['AlphaFold_RMSD'], width, label='AlphaFold', color='#e74c3c', alpha=0.8)

    ax2.set_xlabel('Sample')
    ax2.set_ylabel('RMSD (Å)')
    ax2.set_title('RMSD Comparison')
    ax2.set_xticks(x)
    ax2.set_xticklabels(comparison_df['Jobname'], rotation=45, ha='right', fontsize=7)
    ax2.legend()
    ax2.grid(axis='y', alpha=0.3)

# Plot 3: Antigen Distance vs RMSD scatter
if ANTIGEN_COMPLEX_AVAILABLE and 'Antigen_Distance' in comparison_df.columns and axes[2] is not None:
    ax3 = axes[2]
    # Points are coloured by absolute distance error.
    # NOTE(review): 'Abs_Distance_Error' is read without an existence check;
    # it may be missing when comparison_df is the af_rmsd_df fallback copy.
    scatter = ax3.scatter(comparison_df['AlphaFold_RMSD'],
                         comparison_df['Antigen_Distance'],
                         c=comparison_df['Abs_Distance_Error'],
                         cmap='RdYlGn_r', s=100, edgecolors='black')
    ax3.axhline(y=TARGET_DISTANCE, color='blue', linestyle='--', alpha=0.7, label=f'Target ({TARGET_DISTANCE}Å)')
    ax3.set_xlabel('AlphaFold RMSD (Å)')
    ax3.set_ylabel('Antigen Distance (Å)')
    ax3.set_title('RMSD vs Antigen Distance')
    plt.colorbar(scatter, ax=ax3, label='|Distance - 130Å|')
    ax3.legend()

    # Label every point that has a distance with its job name
    for _, row in comparison_df.iterrows():
        if pd.notna(row.get('Antigen_Distance')):
            ax3.annotate(row['Jobname'],
                        (row['AlphaFold_RMSD'], row['Antigen_Distance']),
                        fontsize=6, alpha=0.7)

plt.tight_layout()
plt.savefig(Path(ALPHAFOLD_OUTPUT_DIR) / "rmsd_antigen_comparison.png", dpi=300, bbox_inches='tight')
plt.show()

print("  ✓ Saved: rmsd_antigen_comparison.png")

# =============================================================================
# TOP CANDIDATES
# =============================================================================

if 'Antigen_Distance' in comparison_df.columns and comparison_df['Antigen_Distance'].notna().any():
    print("\n" + "=" * 70)
    print(f"TOP 5 CANDIDATES (Closest to {TARGET_DISTANCE}Å Target)")
    print("=" * 70)

    # Rank explicitly instead of trusting the current row order: drop rows
    # without a measured distance (they would print as "nan") and order the
    # rest by absolute deviation from the target. This also holds when
    # comparison_df was sorted by RMSD rather than by distance error.
    candidates = comparison_df[comparison_df['Antigen_Distance'].notna()]
    abs_error = (candidates['Antigen_Distance'] - TARGET_DISTANCE).abs()
    ranking = np.argsort(abs_error.to_numpy(dtype=float), kind='stable')
    top5 = candidates.iloc[ranking[:5]]

    for rank, (_, row) in enumerate(top5.iterrows(), 1):
        print(f"\n#{rank}: {row['Jobname']}")
        print(f"    Antigen Distance: {row['Antigen_Distance']:.2f} Å (Δ: {row['Distance_From_Target']:+.2f} Å)")
        print(f"    AlphaFold RMSD: {row['AlphaFold_RMSD']:.4f} Å")
        if pd.notna(row.get('Original_RMSD')):
            print(f"    Original RMSD: {row['Original_RMSD']:.4f} Å")

# =============================================================================
# SUMMARY
# =============================================================================

print("\n" + "=" * 70)
print("✓ Cell 5 Complete")
print("=" * 70)
print(f"\nData saved to: comparison_df ({len(comparison_df)} rows)")
print("\nKey columns added:")
print("  • Antigen_Distance: Distance between CD3e and HER2 membrane residues")
print("  • Distance_From_Target: How far from 130Å target")
print("  • Abs_Distance_Error: Absolute error from target")
print("\n→ Run Cell 6 for NMA analysis")
print("=" * 70)


In [None]:
# ================================================================================
# CELL 6: NORMAL MODE ANALYSIS (NMA)
# ================================================================================

#@title **Cell 6: Normal Mode Analysis (NMA)** { display-mode: "form" }
#@markdown ### NMA Parameters
GNM_CUTOFF = 10.0  #@param {type:"number"}
ANM_CUTOFF = 15.0  #@param {type:"number"}
N_MODES = 20  #@param {type:"integer"}

#@markdown ---
#@markdown **Metrics calculated:**
#@markdown - Effective Stiffness (GNM)
#@markdown - Mean Square Fluctuation (MSF)
#@markdown - Deformability Index (ANM)
#@markdown - Inter-Domain Distance Fluctuation
#@markdown
#@markdown **Requires:** Cells 1-4 must be run first.

# How the form values above are used by the batch loop at the end of this cell
# (via LinkerFlexibilityAnalyzer.run_full_analysis):
#   GNM_CUTOFF -> cutoff given to GNM.buildKirchhoff
#   ANM_CUTOFF -> cutoff given to ANM.buildHessian
#   N_MODES    -> n_modes given to calcModes for both models
# NOTE(review): the cutoffs are presumably in Å (coordinate units) - confirm
# against the ProDy documentation.

import warnings
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
from pathlib import Path
import os
from typing import Dict, List, Optional, Tuple

print("=" * 70)
print("CELL 6: Normal Mode Analysis (NMA)")
print("=" * 70)

# Install ProDy if needed
try:
    import prody
    from prody import *
except ImportError:
    print("Installing ProDy...")
    import subprocess
    import sys

    # Install with the interpreter that runs this notebook; a bare `pip` on
    # PATH may belong to a different Python environment.
    install = subprocess.run(
        [sys.executable, '-m', 'pip', 'install', '-q', 'prody'],
        capture_output=True, text=True
    )
    if install.returncode != 0:
        # Surface pip's own error instead of failing later with a bare ImportError
        raise RuntimeError(f"ProDy installation failed:\n{install.stderr}")
    import prody
    from prody import *

# The star import supplies parsePDB, GNM, ANM and calcSqFlucts, which the
# analyzer class below uses unqualified.
prody.confProDy(verbosity='warning')
print(f"ProDy Version: {prody.__version__}")

# Domain constants: number of residues taken from the start (scFv1) and the
# end (scFv2) of each chain; everything in between is treated as the linker.
SCFV1_LENGTH = 240
SCFV2_LENGTH = 242
ALPHAFOLD_OUTPUT_DIR = "/content/alphafold_results/"

print(f"\nDomain Structure:")
print(f"  • scFv1: {SCFV1_LENGTH} residues")
print(f"  • scFv2: {SCFV2_LENGTH} residues")

# =============================================================================
# NMA ANALYZER CLASS
# =============================================================================

class LinkerFlexibilityAnalyzer:
    """Flexibility analysis for rigid linker validation using GNM/ANM.

    The structure is reduced to its C-alpha atoms and split by position into
    three consecutive segments: the first SCFV1_LENGTH atoms (scFv1), the last
    SCFV2_LENGTH atoms (scFv2) and everything in between (linker).

    Every ``calc_*`` method writes its outputs into ``self.results``;
    ``get_summary_dict`` flattens the scalar ones into a single row.
    """

    # Segment names, in the order their statistics are written to results
    _DOMAIN_NAMES = ('scfv1', 'scfv2', 'linker')

    def __init__(self, structure_path: str, name: Optional[str] = None):
        """Parse the PDB file and set up the domain boundaries.

        Args:
            structure_path: Path to the PDB file to analyse.
            name: Label for the structure; defaults to the file name with
                '.pdb' removed.

        Raises:
            ValueError: If the file cannot be parsed, contains no CA atoms, or
                is too short to leave a linker between the two scFv segments.
        """
        self.name = name or os.path.basename(structure_path).replace('.pdb', '')
        self.structure = parsePDB(structure_path)
        if self.structure is None:
            # Fail with a clear message rather than an AttributeError on .select
            raise ValueError(f"Could not parse structure from {structure_path}")

        self.calphas = self.structure.select('calpha')
        if self.calphas is None:
            raise ValueError(f"No CA atoms found in {structure_path}")

        self.n_residues = self.calphas.numAtoms()
        self.domain_indices = self._calculate_domain_indices()
        self.linker_length = self.domain_indices['linker_length']

        # Models are built lazily by run_gnm / run_anm
        self._gnm = None
        self._anm = None
        self.results = {}

    def _calculate_domain_indices(self) -> Dict:
        """Return half-open (start, end) CA index ranges for scfv1, linker and
        scfv2, plus 'linker_length'.

        Raises:
            ValueError: If no residues remain for the linker.
        """
        linker_length = self.n_residues - SCFV1_LENGTH - SCFV2_LENGTH
        if linker_length <= 0:
            raise ValueError(f"Invalid total length {self.n_residues}")

        return {
            'scfv1': (0, SCFV1_LENGTH),
            'linker': (SCFV1_LENGTH, SCFV1_LENGTH + linker_length),
            'scfv2': (SCFV1_LENGTH + linker_length, self.n_residues),
            'linker_length': linker_length
        }

    def _domain_slices(self):
        """Yield (domain_name, start, end) for each segment in _DOMAIN_NAMES."""
        for domain in self._DOMAIN_NAMES:
            start, end = self.domain_indices[domain]
            yield domain, start, end

    def run_gnm(self, cutoff: float = 10.0, n_modes: int = 20):
        """Build the GNM Kirchhoff matrix on the CA atoms, compute modes and
        return the GNM object (also cached on the instance)."""
        self._gnm = GNM(self.name)
        self._gnm.buildKirchhoff(self.calphas, cutoff=cutoff)
        self._gnm.calcModes(n_modes=n_modes)
        return self._gnm

    def calc_msf_gnm(self) -> np.ndarray:
        """Compute per-residue square fluctuations from the GNM modes.

        Runs the GNM with default settings if it has not been run yet. Stores
        the array, its global mean/std and per-domain mean/std/max/min in
        ``self.results`` and returns the array.
        """
        if self._gnm is None:
            self.run_gnm()

        msf = calcSqFlucts(self._gnm)
        self.results['msf_gnm'] = msf
        self.results['msf_gnm_mean'] = np.mean(msf)
        self.results['msf_gnm_std'] = np.std(msf)

        # Domain-specific MSF
        for domain, start, end in self._domain_slices():
            domain_msf = msf[start:end]
            self.results[f'{domain}_msf_mean'] = np.mean(domain_msf)
            self.results[f'{domain}_msf_std'] = np.std(domain_msf)
            self.results[f'{domain}_msf_max'] = np.max(domain_msf)
            self.results[f'{domain}_msf_min'] = np.min(domain_msf)

        return msf

    def run_anm(self, cutoff: float = 15.0, n_modes: int = 20):
        """Build the ANM Hessian on the CA atoms, compute modes and return the
        ANM object (also cached on the instance)."""
        self._anm = ANM(self.name)
        self._anm.buildHessian(self.calphas, cutoff=cutoff)
        self._anm.calcModes(n_modes=n_modes)
        return self._anm

    def calc_deformability(self, n_modes: Optional[int] = None) -> np.ndarray:
        """Compute a per-residue deformability index from the ANM modes.

        For each residue the squared displacement in every mode is divided by
        the squared eigenvalue and summed over modes; the result is min-max
        normalised to [0, 1]. Stores the array, its global mean/std and
        per-domain mean/max/min in ``self.results``.

        Args:
            n_modes: Number of leading modes to use; all available modes when
                None. Requests larger than the number of computed modes are
                capped at what is available.

        Returns:
            Array of length n_residues with the normalised deformability.
        """
        if self._anm is None:
            self.run_anm()

        if n_modes is None:
            n_modes = len(self._anm)

        eigenvectors = self._anm.getEigvecs()[:, :n_modes]
        # The model may hold fewer modes than requested; reshape with the
        # number actually returned, not the number asked for.
        n_used = eigenvectors.shape[1]
        eigenvalues = self._anm.getEigvals()[:n_used]

        # Rows of the eigenvector matrix are x/y/z triplets per residue
        eigenvectors_reshaped = eigenvectors.reshape(self.n_residues, 3, n_used)
        u_squared = np.sum(eigenvectors_reshaped**2, axis=1)
        deformability_raw = np.sum(u_squared / eigenvalues**2, axis=1)

        # Normalize to 0-1; a flat profile maps to zeros instead of 0/0 = NaN
        spread = deformability_raw.max() - deformability_raw.min()
        if spread > 0:
            deformability = (deformability_raw - deformability_raw.min()) / spread
        else:
            deformability = np.zeros_like(deformability_raw)

        self.results['deformability'] = deformability
        self.results['deformability_mean'] = np.mean(deformability)
        self.results['deformability_std'] = np.std(deformability)

        for domain, start, end in self._domain_slices():
            domain_def = deformability[start:end]
            self.results[f'{domain}_deformability_mean'] = np.mean(domain_def)
            self.results[f'{domain}_deformability_max'] = np.max(domain_def)
            self.results[f'{domain}_deformability_min'] = np.min(domain_def)

        return deformability

    def calc_stiffness(self) -> float:
        """Return the mean of the first (up to) 10 GNM eigenvalues.

        Runs the GNM with default settings if needed. Also stores the first
        and last of those eigenvalues in ``self.results``.
        """
        if self._gnm is None:
            self.run_gnm()

        eigenvalues = self._gnm.getEigvals()[:10]
        stiffness = np.mean(eigenvalues)

        self.results['stiffness'] = stiffness
        self.results['stiffness_min_eigenvalue'] = eigenvalues[0]
        self.results['stiffness_max_eigenvalue'] = eigenvalues[-1]

        return stiffness

    def calc_inter_domain_distance_fluctuation(self, n_modes: int = 20) -> float:
        """Estimate how much the scFv1-scFv2 centroid distance changes along
        the ANM modes.

        Each mode displaces the CA coordinates by 5.0 / sqrt(eigenvalue) times
        the mode vector; the change in distance between the two domain
        centroids (mean CA position) is recorded per mode. The return value is
        the weighted standard deviation of those changes with weights
        proportional to 1 / eigenvalue. The unweighted standard deviation and
        the initial distance are stored in ``self.results`` as well.

        Args:
            n_modes: Maximum number of modes to use (capped at the number
                available in the ANM).
        """
        if self._anm is None:
            self.run_anm(n_modes=n_modes)

        coords = self.calphas.getCoords()
        scfv1_start, scfv1_end = self.domain_indices['scfv1']
        scfv2_start, scfv2_end = self.domain_indices['scfv2']

        com1_initial = np.mean(coords[scfv1_start:scfv1_end], axis=0)
        com2_initial = np.mean(coords[scfv2_start:scfv2_end], axis=0)
        initial_distance = np.linalg.norm(com2_initial - com1_initial)

        eigenvectors = self._anm.getEigvecs()
        eigenvalues = self._anm.getEigvals()
        n_modes_available = min(n_modes, len(self._anm))

        distance_changes = []
        for mode_idx in range(n_modes_available):
            mode_vector = eigenvectors[:, mode_idx].reshape(-1, 3)
            # Softer modes (small eigenvalue) get a larger displacement
            amplitude = 5.0 / np.sqrt(eigenvalues[mode_idx])
            displaced_coords = coords + amplitude * mode_vector

            com1_displaced = np.mean(displaced_coords[scfv1_start:scfv1_end], axis=0)
            com2_displaced = np.mean(displaced_coords[scfv2_start:scfv2_end], axis=0)
            new_distance = np.linalg.norm(com2_displaced - com1_displaced)
            distance_changes.append(new_distance - initial_distance)

        # Weighted spread of the per-mode distance changes
        weights = 1.0 / eigenvalues[:n_modes_available]
        weights = weights / np.sum(weights)
        mean_change = np.average(distance_changes, weights=weights)
        variance = np.average((np.array(distance_changes) - mean_change)**2, weights=weights)
        inter_domain_fluctuation = np.sqrt(variance)

        self.results['inter_domain_distance_initial'] = initial_distance
        self.results['inter_domain_fluctuation'] = inter_domain_fluctuation
        self.results['inter_domain_fluctuation_simple'] = np.std(distance_changes)

        return inter_domain_fluctuation

    def run_full_analysis(self, gnm_cutoff: float = 10.0, anm_cutoff: float = 15.0, n_modes: int = 20):
        """Run GNM (MSF, stiffness) then ANM (deformability, inter-domain
        fluctuation) and return the populated ``self.results`` dict."""
        self.run_gnm(cutoff=gnm_cutoff, n_modes=n_modes)
        self.calc_msf_gnm()
        self.calc_stiffness()

        self.run_anm(cutoff=anm_cutoff, n_modes=n_modes)
        self.calc_deformability(n_modes=n_modes)
        self.calc_inter_domain_distance_fluctuation(n_modes=n_modes)

        return self.results

    def get_summary_dict(self) -> dict:
        """Return the scalar results as one flat dict (one DataFrame row).

        Metrics that have not been computed are reported as NaN.
        """
        return {
            'Filename': self.name,
            'Total_Residues': self.n_residues,
            'Linker_Length': self.linker_length,
            'Effective_Stiffness': self.results.get('stiffness', np.nan),
            'Min_Eigenvalue': self.results.get('stiffness_min_eigenvalue', np.nan),
            'Max_Eigenvalue_10': self.results.get('stiffness_max_eigenvalue', np.nan),
            'Inter_Domain_Fluctuation': self.results.get('inter_domain_fluctuation', np.nan),
            'Inter_Domain_Fluctuation_Simple': self.results.get('inter_domain_fluctuation_simple', np.nan),
            'Initial_COM_Distance': self.results.get('inter_domain_distance_initial', np.nan),
            'MSF_Global_Mean': self.results.get('msf_gnm_mean', np.nan),
            'MSF_Global_Std': self.results.get('msf_gnm_std', np.nan),
            'MSF_scFv1_Mean': self.results.get('scfv1_msf_mean', np.nan),
            'MSF_scFv2_Mean': self.results.get('scfv2_msf_mean', np.nan),
            'MSF_Linker_Mean': self.results.get('linker_msf_mean', np.nan),
            'MSF_Linker_Std': self.results.get('linker_msf_std', np.nan),
            'MSF_Linker_Max': self.results.get('linker_msf_max', np.nan),
            'MSF_Linker_Min': self.results.get('linker_msf_min', np.nan),
            'Deformability_Global_Mean': self.results.get('deformability_mean', np.nan),
            'Deformability_Global_Std': self.results.get('deformability_std', np.nan),
            'Deformability_scFv1_Mean': self.results.get('scfv1_deformability_mean', np.nan),
            'Deformability_scFv2_Mean': self.results.get('scfv2_deformability_mean', np.nan),
            'Deformability_Linker_Mean': self.results.get('linker_deformability_mean', np.nan),
            'Deformability_Linker_Max': self.results.get('linker_deformability_max', np.nan),
            'Deformability_Linker_Min': self.results.get('linker_deformability_min', np.nan),
        }

# =============================================================================
# BATCH ANALYSIS
# =============================================================================

print("\n" + "-" * 70)
print("Running NMA Analysis...")
print("-" * 70)

# The batch needs the (non-empty) table of predictions produced by Cell 4.
# Catch only the failures this probe can produce; a bare `except:` would also
# swallow KeyboardInterrupt.
try:
    _ = alphafold_results_df
    if len(alphafold_results_df) == 0:
        raise ValueError("Empty")
except (NameError, TypeError, ValueError) as err:
    raise RuntimeError("alphafold_results_df not found. Run Cell 4 first.") from err

nma_results = []

# Count rows with enumerate: the DataFrame index labels need not be 0..n-1
for position, (_, row) in enumerate(alphafold_results_df.iterrows(), 1):
    jobname = row['Jobname']
    pdb_path = row['AlphaFold_PDB']

    print(f"\n[{position}/{len(alphafold_results_df)}] Analyzing: {jobname}")

    if not os.path.exists(pdb_path):
        print(f"  ✗ PDB not found: {pdb_path}")
        continue

    try:
        analyzer = LinkerFlexibilityAnalyzer(pdb_path, name=jobname)
        analyzer.run_full_analysis(gnm_cutoff=GNM_CUTOFF, anm_cutoff=ANM_CUTOFF, n_modes=N_MODES)

        # Carry the provenance columns through when Cell 4 supplied them
        summary = analyzer.get_summary_dict()
        summary['Original_Filename'] = row.get('Original_Filename', jobname)
        summary['Original_RMSD'] = row.get('Original_RMSD', np.nan)
        summary['Linker_Group'] = row.get('Linker_Group', 'Unknown')

        nma_results.append(summary)

        print(f"  ✓ Stiffness: {summary['Effective_Stiffness']:.6f}")
        print(f"  ✓ Inter-Domain Fluct: {summary['Inter_Domain_Fluctuation']:.4f} Å")
        print(f"  ✓ Linker MSF: {summary['MSF_Linker_Mean']:.4f} Ų")

    except Exception as e:
        # Best effort: one failing structure must not abort the batch
        print(f"  ✗ Error: {e}")

# Create DataFrame
if nma_results:
    nma_results_df = pd.DataFrame(nma_results)
    nma_results_df.to_csv(Path(ALPHAFOLD_OUTPUT_DIR) / "nma_analysis_results.csv", index=False)

    print("\n" + "=" * 70)
    print(f"NMA ANALYSIS COMPLETE: {len(nma_results)} structures")
    print("=" * 70)

    # Display summary
    print("\n--- NMA Summary ---")
    display_cols = ['Filename', 'Linker_Length', 'Effective_Stiffness',
                   'Inter_Domain_Fluctuation', 'MSF_Linker_Mean', 'Deformability_Linker_Mean']
    display_df = nma_results_df[display_cols].copy()

    for col in ['Effective_Stiffness']:
        display_df[col] = display_df[col].apply(lambda x: f"{x:.6f}")
    for col in ['Inter_Domain_Fluctuation', 'MSF_Linker_Mean', 'Deformability_Linker_Mean']:
        display_df[col] = display_df[col].apply(lambda x: f"{x:.4f}")

    print(display_df.to_string(index=False))
else:
    # Always define nma_results_df so later cells can test for emptiness
    nma_results_df = pd.DataFrame()
    print("\n⚠️ No NMA results obtained")

print("\n" + "=" * 70)
print("✓ Cell 6 Complete")
print("=" * 70)
print("→ Run Cell 7 for pLDDT/PAE analysis")
print("=" * 70)


In [None]:
# ================================================================================
# CELL 7: ALPHAFOLD JSON ANALYSIS (pLDDT & PAE)
# ================================================================================

#@title **Cell 7: pLDDT & PAE Analysis** { display-mode: "form" }
#@markdown Analyzes AlphaFold JSON output files for confidence metrics.
#@markdown
#@markdown **Metrics calculated:**
#@markdown - pLDDT scores (global, domain-specific, junction)
#@markdown - PAE metrics (inter-domain, linker stability)
#@markdown - pTM/ipTM scores
#@markdown
#@markdown **Requires:** Cell 4 must be run first.

import json
import numpy as np
import pandas as pd
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import warnings
import os
warnings.filterwarnings('ignore')

# Domain constants: residue counts of the two scFv domains, and the folder
# holding the AlphaFold outputs that this cell reads and writes.
SCFV1_LENGTH, SCFV2_LENGTH = 240, 242
ALPHAFOLD_OUTPUT_DIR = "/content/alphafold_results/"

cell7_rule = "=" * 70
print(cell7_rule)
print("CELL 7: AlphaFold JSON Analysis (pLDDT & PAE)")
print(cell7_rule)

print("\nDomain Structure:")
for domain_label, domain_size in (("scFv1", SCFV1_LENGTH), ("scFv2", SCFV2_LENGTH)):
    print(f"  • {domain_label}: {domain_size} residues")

# =============================================================================
# ANALYSIS FUNCTIONS
# =============================================================================

def load_alphafold_json(filepath: str) -> Optional[Dict]:
    """Read an AlphaFold result JSON file and return its parsed content.

    Args:
        filepath: Path of the JSON file to read.

    Returns:
        The decoded JSON document, or ``None`` when the file cannot be read
        or does not contain valid JSON. In that case the problem is printed
        rather than raised, so a batch run can continue with the next file.
    """
    try:
        # Decode explicitly as UTF-8 instead of relying on the platform locale.
        with open(filepath, 'r', encoding='utf-8') as f:
            return json.load(f)
    except (OSError, ValueError) as e:
        # OSError covers missing/unreadable files; json.JSONDecodeError and
        # UnicodeDecodeError are both ValueError subclasses.
        print(f"Error loading {filepath}: {e}")
        return None


def calculate_domain_indices(total_length: int,
                             scfv1_length: Optional[int] = None,
                             scfv2_length: Optional[int] = None) -> Dict:
    """Split a scFv1-linker-scFv2 chain into half-open residue index ranges.

    The chain is assumed to be laid out as scFv1, then linker, then scFv2;
    whatever is not covered by the two scFv lengths is treated as linker.

    Args:
        total_length: Number of residues in the whole chain.
        scfv1_length: Residues in the first scFv. Defaults to the module
            constant ``SCFV1_LENGTH``.
        scfv2_length: Residues in the second scFv. Defaults to the module
            constant ``SCFV2_LENGTH``.

    Returns:
        Dictionary with ``(start, stop)`` tuples for ``'scfv1'``,
        ``'linker'``, ``'scfv2'``, ``'junction_n'`` (first up-to-3 linker
        residues) and ``'junction_c'`` (last up-to-3 linker residues), plus
        the integer ``'linker_length'``.

    Raises:
        ValueError: If no residues are left for the linker.
    """
    if scfv1_length is None:
        scfv1_length = SCFV1_LENGTH
    if scfv2_length is None:
        scfv2_length = SCFV2_LENGTH

    linker_length = total_length - scfv1_length - scfv2_length
    if linker_length <= 0:
        raise ValueError(f"Invalid total length {total_length}")

    linker_start = scfv1_length
    linker_end = linker_start + linker_length
    # Clamp the junction window so that linkers shorter than 3 residues do
    # not pull scFv residues into the junction statistics.
    junction_width = min(3, linker_length)

    return {
        'scfv1': (0, linker_start),
        'linker': (linker_start, linker_end),
        'scfv2': (linker_end, total_length),
        'linker_length': linker_length,
        'junction_n': (linker_start, linker_start + junction_width),
        'junction_c': (linker_end - junction_width, linker_end),
    }


def extract_plddt_metrics(plddt: List[float], indices: Dict) -> Dict[str, float]:
    """Summarise per-residue pLDDT scores for the chain and its regions.

    Args:
        plddt: Per-residue pLDDT values for the whole chain.
        indices: Mapping holding ``(start, stop)`` half-open ranges under
            the keys ``'scfv1'``, ``'scfv2'``, ``'linker'``,
            ``'junction_n'`` and ``'junction_c'``.

    Returns:
        Plain-``float`` statistics keyed ``pLDDT_<Region>_<Stat>``. The
        combined ``Junction`` entries pool both junction windows.
    """
    scores = np.array(plddt)

    def region(name):
        """Return the scores that fall inside the named index range."""
        bounds = indices[name]
        return scores[bounds[0]:bounds[1]]

    dom1 = region('scfv1')
    dom2 = region('scfv2')
    link = region('linker')
    junc_n = region('junction_n')
    junc_c = region('junction_c')
    junc_both = np.concatenate([junc_n, junc_c])

    return {
        'pLDDT_Global_Mean': float(scores.mean()),
        'pLDDT_Global_Std': float(scores.std()),
        'pLDDT_Global_Min': float(scores.min()),
        'pLDDT_Global_Max': float(scores.max()),
        'pLDDT_scFv1_Mean': float(dom1.mean()),
        'pLDDT_scFv1_Std': float(dom1.std()),
        'pLDDT_scFv1_Min': float(dom1.min()),
        'pLDDT_scFv2_Mean': float(dom2.mean()),
        'pLDDT_scFv2_Std': float(dom2.std()),
        'pLDDT_scFv2_Min': float(dom2.min()),
        'pLDDT_Linker_Mean': float(link.mean()),
        'pLDDT_Linker_Std': float(link.std()),
        'pLDDT_Linker_Min': float(link.min()),
        'pLDDT_Linker_Max': float(link.max()),
        'pLDDT_Junction_Mean': float(junc_both.mean()),
        'pLDDT_Junction_Min': float(junc_both.min()),
        'pLDDT_Junction_N_Mean': float(junc_n.mean()),
        'pLDDT_Junction_C_Mean': float(junc_c.mean()),
    }


def extract_pae_metrics(pae: List[List[float]], indices: Dict) -> Dict[str, float]:
    """Average blocks of the PAE matrix between and within chain regions.

    Args:
        pae: Square predicted-aligned-error matrix as nested lists.
        indices: Mapping holding ``(start, stop)`` half-open ranges under
            the keys ``'scfv1'``, ``'scfv2'`` and ``'linker'``.

    Returns:
        Plain-``float`` block means keyed ``PAE_<Name>``. "Internal" values
        exclude the diagonal; ``PAE_Linker_Internal`` is 0.0 when the linker
        has no off-diagonal entries.
    """
    matrix = np.array(pae)

    a_lo, a_hi = indices['scfv1']
    b_lo, b_hi = indices['scfv2']
    l_lo, l_hi = indices['linker']

    def block_mean(r_lo, r_hi, c_lo, c_hi):
        """Mean PAE over rows [r_lo, r_hi) and columns [c_lo, c_hi)."""
        return np.mean(matrix[r_lo:r_hi, c_lo:c_hi])

    def square_and_offdiag(lo, hi):
        """Return the [lo, hi) square block and a mask of its off-diagonal."""
        return matrix[lo:hi, lo:hi], ~np.eye(hi - lo, dtype=bool)

    # Cross-domain blocks, taken in both directions because PAE is asymmetric.
    mean_1_to_2 = block_mean(a_lo, a_hi, b_lo, b_hi)
    mean_2_to_1 = block_mean(b_lo, b_hi, a_lo, a_hi)
    inter_domain = (mean_1_to_2 + mean_2_to_1) / 2

    linker_vs_1 = block_mean(l_lo, l_hi, a_lo, a_hi)
    linker_vs_2 = block_mean(l_lo, l_hi, b_lo, b_hi)

    linker_block, linker_keep = square_and_offdiag(l_lo, l_hi)
    if linker_keep.any():
        linker_internal = np.mean(linker_block[linker_keep])
    else:
        linker_internal = 0.0

    dom1_block, dom1_keep = square_and_offdiag(a_lo, a_hi)
    dom1_internal = np.mean(dom1_block[dom1_keep])

    dom2_block, dom2_keep = square_and_offdiag(b_lo, b_hi)
    dom2_internal = np.mean(dom2_block[dom2_keep])

    # Junction windows: 5 linker rows against the 10 neighbouring scFv columns.
    junction_n = block_mean(l_lo, l_lo + 5, a_hi - 10, a_hi)
    junction_c = block_mean(l_hi - 5, l_hi, b_lo, b_lo + 10)

    return {
        'PAE_Inter_Domain': float(inter_domain),
        'PAE_scFv1_to_scFv2': float(mean_1_to_2),
        'PAE_scFv2_to_scFv1': float(mean_2_to_1),
        'PAE_Linker_to_scFv1': float(linker_vs_1),
        'PAE_Linker_to_scFv2': float(linker_vs_2),
        'PAE_Linker_Stability': float((linker_vs_1 + linker_vs_2) / 2),
        'PAE_Linker_Internal': float(linker_internal),
        'PAE_scFv1_Internal': float(dom1_internal),
        'PAE_scFv2_Internal': float(dom2_internal),
        'PAE_Junction_N': float(junction_n),
        'PAE_Junction_C': float(junction_c),
        'PAE_Junction_Mean': float((junction_n + junction_c) / 2),
        'PAE_Global_Mean': float(np.mean(matrix)),
    }


def analyze_alphafold_json(filepath: str, filename: str = None) -> Optional[Dict]:
    """Compute pLDDT/PAE confidence metrics for one AlphaFold JSON file.

    Args:
        filepath: Path of the AlphaFold JSON output.
        filename: Label stored under ``'Filename'``; defaults to the base
            name of ``filepath``.

    Returns:
        A flat dictionary of metrics, or ``None`` when the file cannot be
        loaded, contains no ``'plddt'`` values, or is too short to hold both
        scFv domains plus a linker.
    """
    payload = load_alphafold_json(filepath)
    if payload is None:
        return None

    label = Path(filepath).name if filename is None else filename

    residue_scores = payload.get('plddt', [])
    if not residue_scores:
        return None

    n_residues = len(residue_scores)
    try:
        layout = calculate_domain_indices(n_residues)
    except ValueError:
        return None

    def first_present(keys, default):
        """Return the value of the first key found in the payload."""
        for key in keys:
            if key in payload:
                return payload[key]
        return default

    results = {
        'Filename': label,
        'Total_Length': n_residues,
        'Linker_Length': layout['linker_length'],
        # Different AlphaFold front-ends spell the score keys differently.
        'pTM_Score': first_present(('ptm', 'pTM', 'ranking_confidence'), np.nan),
        'ipTM_Score': first_present(('iptm', 'ipTM'), np.nan),
    }
    results.update(extract_plddt_metrics(residue_scores, layout))

    pae_matrix = first_present(('pae', 'predicted_aligned_error'), None)
    if pae_matrix is None:
        # Without a PAE matrix only the headline PAE columns are filled (NaN).
        results.update(dict.fromkeys(
            ['PAE_Inter_Domain', 'PAE_Linker_Stability', 'PAE_Junction_Mean'], np.nan))
    else:
        results.update(extract_pae_metrics(pae_matrix, layout))

    return results

# =============================================================================
# BATCH ANALYSIS
# =============================================================================

print("\n" + "-" * 70)
print("Running JSON Analysis...")
print("-" * 70)

# Cell 4 must have left a non-empty alphafold_results_df behind. Only the
# "name missing / not sized" cases are translated; any other error surfaces
# unchanged instead of being hidden behind a blanket except.
try:
    n_jobs = len(alphafold_results_df)
except (NameError, TypeError):
    n_jobs = 0
if n_jobs == 0:
    raise RuntimeError("alphafold_results_df not found or empty. Run Cell 4 first.")

json_results = []

# enumerate() provides the progress position; the DataFrame index label is
# not guaranteed to be a 0-based integer.
for position, (_, row) in enumerate(alphafold_results_df.iterrows(), start=1):
    jobname = row['Jobname']
    json_path = row.get('AlphaFold_JSON', None)

    print(f"\n[{position}/{n_jobs}] Analyzing: {jobname}")

    if json_path is None or not os.path.exists(str(json_path)):
        print("  ✗ JSON not found")
        continue

    # Best effort per job: one broken file must not abort the whole batch.
    try:
        result = analyze_alphafold_json(str(json_path), filename=jobname)

        if result is None:
            # Unreadable file, no pLDDT values, or chain too short.
            print("  ✗ Skipped: JSON could not be analysed")
        else:
            # Carry the provenance columns of Cell 4 along with the metrics.
            result['Original_Filename'] = row.get('Original_Filename', jobname)
            result['Original_RMSD'] = row.get('Original_RMSD', np.nan)
            result['Linker_Group'] = row.get('Linker_Group', 'Unknown')

            json_results.append(result)

            ptm = result.get('pTM_Score', np.nan)
            ptm_str = f"{ptm:.3f}" if pd.notna(ptm) else "N/A"
            print(f"  ✓ pTM: {ptm_str}")
            print(f"  ✓ Global pLDDT: {result['pLDDT_Global_Mean']:.1f}")
            print(f"  ✓ Linker pLDDT: {result['pLDDT_Linker_Mean']:.1f}")
            if pd.notna(result.get('PAE_Inter_Domain', np.nan)):
                print(f"  ✓ Inter-Domain PAE: {result['PAE_Inter_Domain']:.2f}")

    except Exception as e:
        print(f"  ✗ Error: {e}")

# Create DataFrame
if json_results:
    json_results_df = pd.DataFrame(json_results)
    json_results_df.to_csv(Path(ALPHAFOLD_OUTPUT_DIR) / "json_analysis_results.csv", index=False)

    print("\n" + "=" * 70)
    print(f"JSON ANALYSIS COMPLETE: {len(json_results)} files")
    print("=" * 70)

    print("\n--- pLDDT/PAE Summary ---")
    display_cols = ['Filename', 'Linker_Length', 'pTM_Score',
                   'pLDDT_Global_Mean', 'pLDDT_Linker_Mean', 'PAE_Inter_Domain']
    display_cols = [c for c in display_cols if c in json_results_df.columns]
    display_df = json_results_df[display_cols].copy()

    # Fixed-point formatting by metric family; missing values show as "N/A".
    for col in display_df.columns:
        if 'pLDDT' in col:
            display_df[col] = display_df[col].apply(lambda x: f"{x:.1f}" if pd.notna(x) else "N/A")
        elif 'PAE' in col:
            display_df[col] = display_df[col].apply(lambda x: f"{x:.2f}" if pd.notna(x) else "N/A")
        elif 'pTM' in col:
            display_df[col] = display_df[col].apply(lambda x: f"{x:.3f}" if pd.notna(x) else "N/A")

    print(display_df.to_string(index=False))
else:
    # Keep the name defined so that Cell 8 can test for it.
    json_results_df = pd.DataFrame()
    print("\n⚠️ No JSON results obtained")

print("\n" + "=" * 70)
print("✓ Cell 7 Complete")
print("=" * 70)
print("→ Run Cell 8 for Final Summary & Download")
print("=" * 70)


In [None]:
# ================================================================================
# CELL 8: FINAL SUMMARY - COMBINED RESULTS & DOWNLOAD
# ================================================================================
# Combines all results and provides downloadable files:
# - Excel file with all metrics (multiple sheets)
# - ZIP file containing all PDB, JSON, CSV files
# ================================================================================

#@title **Cell 8: Final Summary & Download** { display-mode: "form" }

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
import warnings
import shutil
import zipfile
import os
warnings.filterwarnings('ignore')

# Install openpyxl for Excel export
try:
    import openpyxl
except ImportError:
    import importlib
    import subprocess
    import sys

    # Use the interpreter running this notebook so that the package lands in
    # the environment we import from (a bare "pip" may belong to another one).
    _pip_run = subprocess.run(
        [sys.executable, '-m', 'pip', 'install', '-q', 'openpyxl'],
        capture_output=True, text=True)
    if _pip_run.returncode != 0:
        # Show why the install failed before the import below raises.
        print(_pip_run.stderr)
    # Make the import system notice a package installed after start-up.
    importlib.invalidate_caches()
    import openpyxl

print("=" * 70)
print("CELL 8: Final Summary - Combined Results & Download")
print("=" * 70)

ALPHAFOLD_OUTPUT_DIR = "/content/alphafold_results/"
DOWNLOAD_DIR = "/content/download/"

# Create download directory
os.makedirs(DOWNLOAD_DIR, exist_ok=True)

# ============================================================================
# COLLECT ALL RESULTS
# ============================================================================

print("\n[Step 1/5] Collecting results from all analyses...")

available_dfs = {}

# (key in available_dfs, notebook variable, label when found, label when missing)
_result_sources = [
    ('alphafold', 'alphafold_results_df', 'AlphaFold results', 'AlphaFold results'),            # Cell 4
    ('comparison', 'comparison_df', 'RMSD Comparison results', 'RMSD Comparison results'),      # Cell 5
    ('nma', 'nma_results_df', 'NMA results', 'NMA results'),                                    # Cell 6
    ('json', 'json_results_df', 'JSON (pLDDT/PAE) results', 'JSON results'),                    # Cell 7
]

# Look the variables up by name: a cell that was never run simply has no
# entry, which is reported instead of being silently ignored.
for _key, _var_name, _found_label, _missing_label in _result_sources:
    _frame = globals().get(_var_name)
    if _frame is None:
        print(f"  ✗ {_missing_label} not found")
    elif len(_frame) == 0:
        print(f"  ✗ {_missing_label} are empty")
    else:
        available_dfs[_key] = _frame
        print(f"  ✓ {_found_label}: {len(_frame)} entries")

if not available_dfs:
    print("\n✗ No analysis results found!")
    raise RuntimeError("No results to combine")

# ============================================================================
# MERGE ALL RESULTS
# ============================================================================

print("\n[Step 2/5] Merging all results...")


def _keyed_by_jobname(frame):
    """Return a copy of *frame* whose sample-id column is named 'Jobname'."""
    frame = frame.copy()
    if 'Jobname' not in frame.columns and 'Filename' in frame.columns:
        frame = frame.rename(columns={'Filename': 'Jobname'})
    return frame


def _merge_on_jobname(base, extra):
    """Left-join *extra* onto *base* by 'Jobname' without duplicating columns.

    A column present on both sides (e.g. 'Linker_Length' in the NMA and JSON
    frames) keeps its plain name; gaps in the base values are filled from
    *extra*. A plain merge would rename both to '<name>_x' / '<name>_y' and
    the summary step would then no longer find the column.
    """
    merged = base.merge(extra, on='Jobname', how='left', suffixes=('', '__extra'))
    for dup in [c for c in merged.columns if c.endswith('__extra')]:
        col = dup[:-len('__extra')]
        merged[col] = merged[col].fillna(merged[dup])
        merged = merged.drop(columns=dup)
    return merged


# Start with AlphaFold results as base; otherwise fall back to the first
# available analysis frame, renamed so that it can be joined on 'Jobname'.
if 'alphafold' in available_dfs:
    base_key = 'alphafold'
    combined_df = available_dfs['alphafold'].copy()
    base_cols = ['Jobname', 'Original_Filename', 'Original_RMSD', 'Linker_Group',
                 'Sequence_Length', 'AlphaFold_PDB', 'AlphaFold_JSON', 'PDB_Type', 'Elapsed_Time']
    combined_df = combined_df[[c for c in base_cols if c in combined_df.columns]]
else:
    base_key = next((k for k in ('nma', 'json', 'comparison') if k in available_dfs), None)
    if base_key is None:
        combined_df = pd.DataFrame(columns=['Jobname'])
    else:
        combined_df = _keyed_by_jobname(available_dfs[base_key])

# Merge Comparison results (Cell 5)
if 'comparison' in available_dfs and base_key != 'comparison':
    comp_cols = ['Jobname', 'AlphaFold_RMSD', 'RMSD_Change', 'RMSD_Improvement', 'Percent_Change']
    comp_cols = [c for c in comp_cols if c in available_dfs['comparison'].columns]
    # Without the key column there is nothing to join on.
    if 'Jobname' in comp_cols:
        comp_merge = available_dfs['comparison'][comp_cols].copy()
        combined_df = _merge_on_jobname(combined_df, comp_merge)
        print("  ✓ Merged RMSD Comparison results")

# Merge NMA results (skipped when the NMA frame already is the base)
if 'nma' in available_dfs and base_key != 'nma':
    nma_cols_to_merge = [c for c in available_dfs['nma'].columns
                        if c not in ['Original_Filename', 'Original_RMSD', 'Linker_Group', 'Total_Residues']]
    nma_merge = _keyed_by_jobname(available_dfs['nma'][nma_cols_to_merge])
    combined_df = _merge_on_jobname(combined_df, nma_merge)
    print("  ✓ Merged NMA results")

# Merge JSON results (skipped when the JSON frame already is the base)
if 'json' in available_dfs and base_key != 'json':
    json_cols_to_merge = [c for c in available_dfs['json'].columns
                         if c not in ['Original_Filename', 'Original_RMSD', 'Linker_Group', 'Total_Length']]
    json_merge = _keyed_by_jobname(available_dfs['json'][json_cols_to_merge])
    combined_df = _merge_on_jobname(combined_df, json_merge)
    print("  ✓ Merged JSON results")

print(f"\n  Combined DataFrame: {len(combined_df)} rows × {len(combined_df.columns)} columns")

# ============================================================================
# CALCULATE RIGIDITY SCORES
# ============================================================================

print("\n[Step 3/5] Calculating rigidity scores...")

def normalize_0_1(series, higher_is_better=True):
    """Min-max scale a numeric Series to the range [0, 1].

    Args:
        series: pandas Series of numeric values; NaN entries are allowed.
        higher_is_better: When False the scale is flipped, so the smallest
            input maps to 1 and the largest to 0.

    Returns:
        A Series on the same index as ``series``. NaN inputs stay NaN. If
        every value is NaN the input is returned unchanged; if all observed
        values are equal, each of them maps to the neutral score 0.5.
    """
    if series.isna().all():
        return series
    min_val = series.min()
    max_val = series.max()
    if max_val == min_val:
        # Keep the caller's index (a fresh RangeIndex would misalign when the
        # scores are combined) and leave missing entries missing.
        return pd.Series(0.5, index=series.index).where(series.notna())
    normalized = (series - min_val) / (max_val - min_val)
    if not higher_is_better:
        normalized = 1 - normalized
    return normalized

rigidity_components = {}

# (source column, component name, higher_is_better)
_rigidity_inputs = [
    ('Effective_Stiffness', 'Stiffness_Score', True),
    ('Inter_Domain_Fluctuation', 'InterDomain_Score', False),
    ('pLDDT_Linker_Mean', 'pLDDT_Score', True),
    ('PAE_Linker_Stability', 'PAE_Score', False),
    ('MSF_Linker_Mean', 'MSF_Score', False),
]

# Only metrics that made it into combined_df contribute to the score.
for _source_col, _component, _higher_is_better in _rigidity_inputs:
    if _source_col in combined_df.columns:
        rigidity_components[_component] = normalize_0_1(
            combined_df[_source_col], higher_is_better=_higher_is_better)

if rigidity_components:
    # Build on combined_df's index so each component row lines up with its sample.
    rigidity_df = pd.DataFrame(rigidity_components, index=combined_df.index)
    # Unweighted mean of the available components; missing ones are skipped.
    combined_df['Rigidity_Score'] = rigidity_df.mean(axis=1, skipna=True)
    # A sample with no usable metric has a NaN score and therefore a NaN
    # rank; the nullable 'Int64' dtype keeps that as <NA>, whereas a plain
    # int cast would raise and abort the cell.
    combined_df['Rigidity_Rank'] = (
        combined_df['Rigidity_Score'].rank(ascending=False, method='min').astype('Int64'))
    print(f"  ✓ Calculated Rigidity Score from {len(rigidity_components)} components")

# ============================================================================
# CREATE SUMMARY DATAFRAME (KEY METRICS ONLY)
# ============================================================================

# Key metrics of the summary sheet, in presentation order. Metrics whose
# analysis did not run are absent from combined_df and are left out.
preferred_summary_columns = (
    'Jobname', 'Linker_Length', 'Original_RMSD', 'AlphaFold_RMSD', 'RMSD_Improvement',
    'Rigidity_Score', 'Rigidity_Rank',
    'Effective_Stiffness', 'Inter_Domain_Fluctuation', 'Initial_COM_Distance',
    'MSF_Linker_Mean', 'Deformability_Linker_Mean',
    'pTM_Score', 'pLDDT_Global_Mean', 'pLDDT_Linker_Mean', 'pLDDT_Linker_Min', 'pLDDT_Junction_Mean',
    'PAE_Inter_Domain', 'PAE_Linker_Stability', 'PAE_Junction_Mean',
)
summary_columns = [name for name in preferred_summary_columns if name in combined_df.columns]
summary_df = combined_df.loc[:, summary_columns].copy()

# Most rigid candidates first, when a score could be computed.
if 'Rigidity_Score' in summary_columns:
    summary_df = summary_df.sort_values(by='Rigidity_Score', ascending=False)

# ============================================================================
# SAVE TO EXCEL (MULTIPLE SHEETS)
# ============================================================================

print("\n[Step 4/5] Saving results to Excel and CSV...")

excel_path = Path(DOWNLOAD_DIR) / "BiTE_Analysis_Results.xlsx"

# Per-analysis sheets, written only when that analysis produced results.
optional_sheets = (
    ('nma', 'NMA_Analysis'),
    ('json', 'pLDDT_PAE_Analysis'),
    ('comparison', 'RMSD_Comparison'),
)

# Explanations for the 'Column_Descriptions' sheet, in display order.
column_help = {
    'Jobname': 'Sample identifier',
    'Linker_Length': 'Number of amino acids in linker region',
    'Original_RMSD': 'RMSD of original structure vs reference (Å)',
    'AlphaFold_RMSD': 'RMSD of AlphaFold prediction vs reference (Å)',
    'RMSD_Improvement': 'Original_RMSD - AlphaFold_RMSD (positive = improved)',
    'Rigidity_Score': 'Composite rigidity score (0-1, higher = more rigid)',
    'Rigidity_Rank': 'Rank by rigidity (1 = most rigid)',
    'Effective_Stiffness': 'GNM-based stiffness (higher = more rigid)',
    'Inter_Domain_Fluctuation': 'COM distance fluctuation between scFv1/scFv2 (Å, lower = better)',
    'MSF_Linker_Mean': 'Mean square fluctuation of linker (Ų, lower = more rigid)',
    'Deformability_Linker_Mean': 'ANM deformability of linker (0-1, lower = more rigid)',
    'pTM_Score': 'AlphaFold predicted TM-score',
    'pLDDT_Global_Mean': 'Mean pLDDT across all residues',
    'pLDDT_Linker_Mean': 'Mean pLDDT of linker region (higher = more confident)',
    'pLDDT_Linker_Min': 'Minimum pLDDT in linker region',
    'pLDDT_Junction_Mean': 'Mean pLDDT at linker-scFv junctions',
    'PAE_Inter_Domain': 'PAE between scFv1 and scFv2 (Å, lower = better)',
    'PAE_Linker_Stability': 'PAE of linker to scFv domains (Å, lower = better)',
    'PAE_Junction_Mean': 'PAE at junction regions (Å, lower = better)',
}

with pd.ExcelWriter(excel_path, engine='openpyxl') as writer:
    # Key metrics first, then the full merged table.
    summary_df.to_excel(writer, sheet_name='Summary', index=False)
    print(f"  ✓ Sheet 'Summary': {len(summary_df)} rows × {len(summary_df.columns)} cols")

    combined_df.to_excel(writer, sheet_name='All_Data', index=False)
    print(f"  ✓ Sheet 'All_Data': {len(combined_df)} rows × {len(combined_df.columns)} cols")

    for source_key, sheet_name in optional_sheets:
        if source_key not in available_dfs:
            continue
        sheet_frame = available_dfs[source_key]
        sheet_frame.to_excel(writer, sheet_name=sheet_name, index=False)
        print(f"  ✓ Sheet '{sheet_name}': {len(sheet_frame)} rows")

    descriptions = pd.DataFrame(list(column_help.items()), columns=['Column', 'Description'])
    descriptions.to_excel(writer, sheet_name='Column_Descriptions', index=False)
    print("  ✓ Sheet 'Column_Descriptions': metric explanations")

print(f"\n  ✓ Excel saved: {excel_path}")

# The key-metric table is also written as a stand-alone CSV.
csv_path = Path(DOWNLOAD_DIR) / "BiTE_Analysis_Summary.csv"
summary_df.to_csv(csv_path, index=False)
print(f"  ✓ CSV saved: {csv_path}")

# ============================================================================
# CREATE ZIP FILE WITH ALL RESULTS
# ============================================================================

print("\n[Step 5/5] Creating ZIP file with all results...")

zip_path = Path(DOWNLOAD_DIR) / "BiTE_Analysis_Complete.zip"

with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
    # Add Excel file
    zipf.write(excel_path, "BiTE_Analysis_Results.xlsx")
    print("  ✓ Added: BiTE_Analysis_Results.xlsx")

    # Add CSV file
    zipf.write(csv_path, "BiTE_Analysis_Summary.csv")
    print("  ✓ Added: BiTE_Analysis_Summary.csv")

    # Add PDB files (sorted so that the archive layout is reproducible)
    pdb_dir = Path(ALPHAFOLD_OUTPUT_DIR)
    pdb_files = sorted(pdb_dir.glob("*.pdb"))
    if pdb_files:
        for pdb_file in pdb_files:
            zipf.write(pdb_file, f"PDB_files/{pdb_file.name}")
        print(f"  ✓ Added: {len(pdb_files)} PDB files → PDB_files/")

    # Add JSON files
    json_files = sorted(pdb_dir.glob("*.json"))
    if json_files:
        for json_file in json_files:
            zipf.write(json_file, f"JSON_files/{json_file.name}")
        print(f"  ✓ Added: {len(json_files)} JSON files → JSON_files/")

    # FINAL_summary_plots.png is deliberately not added here: the figure is
    # generated further down in this cell and appended to this archive right
    # after it is saved. Adding a stale copy now would put two members with
    # the same name into the ZIP.

print(f"\n  ✓ ZIP created: {zip_path}")

# ============================================================================
# DISPLAY SUMMARY TABLE
# ============================================================================

table_rule = "=" * 70
print("\n" + table_rule)
print("FINAL SUMMARY TABLE")
print(table_rule)

# Columns of the console table; those missing from the summary are skipped.
wanted_display_cols = ('Jobname', 'Linker_Length', 'Rigidity_Score', 'Rigidity_Rank',
                       'Effective_Stiffness', 'Inter_Domain_Fluctuation',
                       'pLDDT_Linker_Mean', 'PAE_Inter_Domain')
display_cols = [c for c in wanted_display_cols if c in summary_df.columns]
display_df = summary_df[display_cols].copy()


def _display_format(column_name):
    """Return the fixed-point format spec for a column, or None to leave it."""
    if column_name in ('Rigidity_Score', 'Effective_Stiffness'):
        return ".4f"
    if column_name == 'Inter_Domain_Fluctuation':
        return ".3f"
    if 'pLDDT' in column_name:
        return ".1f"
    if 'PAE' in column_name:
        return ".2f"
    return None


# Render numbers as text; missing values are shown as "N/A".
for col in display_df.columns:
    spec = _display_format(col)
    if spec is None:
        continue
    display_df[col] = display_df[col].apply(
        lambda x, spec=spec: format(x, spec) if pd.notna(x) else "N/A")

print(display_df.to_string(index=False))

# ============================================================================
# INTERPRETATION GUIDE
# ============================================================================

# One-line reading aid for each headline metric of the summary table.
guide_rule = "-" * 70
guide_lines = (
    "• Rigidity_Score: Higher = more rigid linker (0-1 scale)",
    "• Effective_Stiffness: Higher = stiffer structure",
    "• Inter_Domain_Fluctuation: Lower = more stable domain separation (Å)",
    "• pLDDT_Linker_Mean: Higher = more confident prediction (>70 good, >90 excellent)",
    "• PAE_Inter_Domain: Lower = better domain orientation confidence (Å)",
)

print("\n" + guide_rule)
print("INTERPRETATION GUIDE")
print(guide_rule)
for guide_line in guide_lines:
    print(guide_line)
print(guide_rule)

# ============================================================================
# TOP CANDIDATES
# ============================================================================

if 'Rigidity_Score' in summary_df.columns:
    print("\n" + "=" * 70)
    print("TOP 5 MOST RIGID CANDIDATES")
    print("=" * 70)

    # summary_df is already sorted by descending rigidity score.
    top5 = summary_df.head(5)

    # Optional detail lines: printed only when the metric exists and is not NaN.
    optional_details = (
        ('Effective_Stiffness', "    Effective Stiffness: {:.6f}"),
        ('Inter_Domain_Fluctuation', "    Inter-Domain Fluctuation: {:.4f} Å"),
        ('pLDDT_Linker_Mean', "    Linker pLDDT: {:.1f}"),
        ('PAE_Inter_Domain', "    Inter-Domain PAE: {:.2f} Å"),
    )

    for rank, (_, row) in enumerate(top5.iterrows(), start=1):
        print(f"\n#{rank}: {row['Jobname']}")
        if 'Linker_Length' in row:
            print(f"    Linker Length: {row['Linker_Length']} aa")
        print(f"    Rigidity Score: {row['Rigidity_Score']:.4f}")
        for metric, template in optional_details:
            if metric in row and pd.notna(row[metric]):
                print(template.format(row[metric]))

# ============================================================================
# GENERATE VISUALIZATION
# ============================================================================

print("\n" + "-" * 70)
print("GENERATING VISUALIZATIONS")
print("-" * 70)

# 2x2 dashboard; a panel stays empty when its input columns are missing.
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Plot 1: Rigidity Score Ranking
if 'Rigidity_Score' in combined_df.columns:
    ax1 = axes[0, 0]
    # Ascending sort puts the most rigid sample at the top of the bar chart.
    sorted_df = combined_df.sort_values('Rigidity_Score', ascending=True)
    colors = plt.cm.RdYlGn(sorted_df['Rigidity_Score'] / sorted_df['Rigidity_Score'].max())
    ax1.barh(range(len(sorted_df)), sorted_df['Rigidity_Score'], color=colors)
    ax1.set_yticks(range(len(sorted_df)))
    ax1.set_yticklabels(sorted_df['Jobname'], fontsize=8)
    ax1.set_xlabel('Rigidity Score')
    ax1.set_title('Rigidity Score Ranking\n(Higher = More Rigid)')
    ax1.axvline(x=sorted_df['Rigidity_Score'].mean(), color='red', linestyle='--', label='Mean')
    ax1.legend()

# Plot 2: Stiffness vs Inter-Domain Fluctuation
if 'Effective_Stiffness' in combined_df.columns and 'Inter_Domain_Fluctuation' in combined_df.columns:
    ax2 = axes[0, 1]
    scatter = ax2.scatter(combined_df['Effective_Stiffness'],
                         combined_df['Inter_Domain_Fluctuation'],
                         c=combined_df.get('Rigidity_Score', 'blue'),
                         cmap='RdYlGn', s=100, edgecolors='black')
    ax2.set_xlabel('Effective Stiffness')
    ax2.set_ylabel('Inter-Domain Fluctuation (Å)')
    ax2.set_title('Stiffness vs Inter-Domain Fluctuation')
    plt.colorbar(scatter, ax=ax2, label='Rigidity Score')
    # Label every point with its job name.
    for idx, row in combined_df.iterrows():
        ax2.annotate(row['Jobname'],
                    (row['Effective_Stiffness'], row['Inter_Domain_Fluctuation']),
                    fontsize=6, alpha=0.7)

# Plot 3: pLDDT Distribution by Domain
if 'pLDDT_Linker_Mean' in combined_df.columns:
    ax3 = axes[1, 0]
    plddt_data = []
    labels = []
    for _plddt_col, _plddt_label in (('pLDDT_scFv1_Mean', 'scFv1'),
                                     ('pLDDT_Linker_Mean', 'Linker'),
                                     ('pLDDT_scFv2_Mean', 'scFv2')):
        if _plddt_col in combined_df.columns:
            plddt_data.append(combined_df[_plddt_col].dropna())
            labels.append(_plddt_label)
    if plddt_data:
        # Tick labels are set on the axis afterwards: boxplot's `labels`
        # keyword is deprecated since Matplotlib 3.9 (renamed `tick_labels`),
        # while set_xticklabels works on old and new versions alike.
        bp = ax3.boxplot(plddt_data, patch_artist=True)
        ax3.set_xticklabels(labels)
        colors = ['#3498db', '#e74c3c', '#2ecc71']
        for patch, color in zip(bp['boxes'], colors[:len(plddt_data)]):
            patch.set_facecolor(color)
            patch.set_alpha(0.7)
        ax3.set_ylabel('pLDDT Score')
        ax3.set_title('pLDDT Distribution by Domain')
        ax3.axhline(y=70, color='orange', linestyle='--', label='Confident (70)')
        ax3.axhline(y=90, color='green', linestyle='--', label='Very High (90)')
        ax3.legend(fontsize=8)

# Plot 4: PAE Metrics
if 'PAE_Inter_Domain' in combined_df.columns:
    ax4 = axes[1, 1]
    pae_cols = ['PAE_Inter_Domain', 'PAE_Linker_Stability', 'PAE_Linker_Internal']
    pae_cols = [c for c in pae_cols if c in combined_df.columns]
    if pae_cols:
        # Grouped bars: one group per sample, one bar per PAE metric.
        x = np.arange(len(combined_df))
        width = 0.25
        for i, col in enumerate(pae_cols):
            ax4.bar(x + i*width, combined_df[col], width, label=col.replace('PAE_', ''))
        ax4.set_xlabel('Sample')
        ax4.set_ylabel('PAE (Å)')
        ax4.set_title('PAE Metrics by Sample')
        ax4.set_xticks(x + width)
        ax4.set_xticklabels(combined_df['Jobname'], rotation=45, ha='right', fontsize=8)
        ax4.legend(fontsize=8)

plt.tight_layout()
plot_path = Path(DOWNLOAD_DIR) / "FINAL_summary_plots.png"
plt.savefig(plot_path, dpi=300, bbox_inches='tight')
plt.show()

print(f"  ✓ Saved: {plot_path}")

# Add plot to ZIP, using the same compression as the rest of the archive
# (append mode otherwise stores the new member uncompressed).
with zipfile.ZipFile(zip_path, 'a', zipfile.ZIP_DEFLATED) as zipf:
    zipf.write(plot_path, "FINAL_summary_plots.png")

# ============================================================================
# DOWNLOAD LINKS (for Colab)
# ============================================================================

print("\n" + "=" * 70)
print("DOWNLOAD FILES")
print("=" * 70)

# Check if running in Colab: the helper module only exists there. Catch just
# the import failure so that any other problem is not mistaken for "not Colab".
try:
    from google.colab import files
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

print(f"\nFiles ready for download in: {DOWNLOAD_DIR}")
print("\n📁 Available files:")
print("  • BiTE_Analysis_Results.xlsx - Excel with all results (6 sheets)")
print("  • BiTE_Analysis_Summary.csv - Key metrics summary")
print("  • BiTE_Analysis_Complete.zip - All files including PDB & JSON")
print("  • FINAL_summary_plots.png - Visualization")

if IN_COLAB:
    print("\n" + "-" * 70)
    print("Click the links below to download:")
    print("-" * 70)

    # Download Excel
    print("\n📊 Downloading Excel file...")
    files.download(str(excel_path))

    # Download ZIP
    print("\n📦 Downloading ZIP file (contains all PDB, JSON, CSV)...")
    files.download(str(zip_path))
else:
    print(f"\n📂 Files are saved in: {DOWNLOAD_DIR}")
    print("   You can download them from the file browser on the left.")

# ============================================================================
# FINAL OUTPUT
# ============================================================================

final_rule = "=" * 70
print("\n" + final_rule)
print("✓ ANALYSIS COMPLETE!")
print(final_rule)

# Count the structure and confidence files present in the AlphaFold folder.
results_dir = Path(ALPHAFOLD_OUTPUT_DIR)
n_pdb_files = sum(1 for _ in results_dir.glob('*.pdb'))
n_json_files = sum(1 for _ in results_dir.glob('*.json'))

print(f"""
Output Summary:
  • Total samples analyzed: {len(combined_df)}
  • Excel file: BiTE_Analysis_Results.xlsx (6 sheets)
  • ZIP file: BiTE_Analysis_Complete.zip
    - PDB files: {n_pdb_files}
    - JSON files: {n_json_files}
    - Summary CSV and Excel
    - Visualization plots
""")

# Independent copies for follow-up work, so later edits do not touch the
# frames used above.
final_results_df = combined_df.copy()
final_summary_df = summary_df.copy()

print(final_rule)
print("Variables available for further analysis:")
print("  • final_results_df - All combined data")
print("  • final_summary_df - Key metrics summary")
print(final_rule)
