# Final Candidate Analysis
## BiTE Rigid Linker - 최종 후보 분석

### 필요한 파일
```
/content/
├── reference.pdb                              ← ColabFold PDB
├── predictions/                               ← 외부 AlphaFold2 결과
│   ├── *_relaxed_rank_001.pdb
│   ├── *_relaxed_rank_002.pdb
│   ├── ...
│   ├── *_scores_rank_001.json                 ← pLDDT/PAE 데이터
│   └── ...
└── Include_antigen_with_perfect_distance.pdb  ← 항원-항체 복합체
```

### 분석 항목
| 항목 | 설명 | 출처 |
|------|------|------|
| RMSD | Reference와의 구조 차이 | PDB |
| Antigen Distance | CD3e-HER2 거리 (목표: 130Å) | PDB |
| pTM, pLDDT | 구조 예측 신뢰도 | JSON (JSON이 없으면 PDB B-factor에서 pLDDT만 추출) |
| PAE | 도메인 간 오차 예측 | JSON |
| NMA | 유연성/강성 분석 | PDB |

### 출력
- **모든 메트릭을 컬럼으로 하는 전체 테이블**
- CSV 파일 저장

---

In [None]:
# ================================================================================
# CELL 1: RMSD ANALYSIS FUNCTIONS
# ================================================================================

#@title **Cell 1: Load RMSD Analysis Functions** { display-mode: "form" }
#@markdown This cell loads all core functions for RMSD calculation.

# --- Package Installation ---
!pip install biopython --quiet

# --- Imports ---
import numpy as np
import pandas as pd
from pathlib import Path
from typing import List, Tuple, Dict, Optional
import warnings
warnings.filterwarnings('ignore')

from Bio.PDB import PDBParser, Superimposer, PPBuilder
from Bio.PDB.Atom import Atom
from Bio.PDB.Residue import Residue
from Bio.PDB.Polypeptide import is_aa

# --- Constants ---
# Positional domain boundaries used throughout the notebook: the first
# SCFV1_LENGTH residues are scFv1, the last SCFV2_LENGTH residues are scFv2,
# and any residues in between are treated as the linker.
SCFV1_LENGTH = 240  # First scFv domain length
SCFV2_LENGTH = 242  # Second scFv domain length

_banner = "=" * 70
print(_banner)
print("CELL 1: Loading RMSD Analysis Functions")
print(_banner)

# =============================================================================
# CORE FUNCTIONS
# =============================================================================

def get_ca_atoms(residues: List[Residue]) -> List[Atom]:
    """Return the C-alpha atom of every residue that has one.

    Residues without a ``'CA'`` atom are skipped, so the result may be shorter
    than *residues*. Input order is preserved.
    """
    return [residue['CA'] for residue in residues if 'CA' in residue]


def extract_residues(structure) -> List[Residue]:
    """Flatten *structure* into a list of standard amino-acid residues.

    Walks every model, then every chain, then every residue, in iteration
    order. Only residues accepted by ``is_aa(..., standard=True)`` are kept.
    With a multi-model file, the residues of all models are concatenated.
    """
    return [
        residue
        for model in structure
        for chain in model
        for residue in chain
        if is_aa(residue, standard=True)
    ]


def align_and_calculate_rmsd(
    ref_atoms: List[Atom],
    target_atoms: List[Atom],
    full_target_atoms: List[Atom] = None
) -> Tuple[float, List[Atom]]:
    """Superimpose *target_atoms* onto *ref_atoms* and return the fit RMSD.

    The rotation/translation fitted on the paired atoms is applied IN PLACE,
    either to *full_target_atoms* (when given, typically every C-alpha of the
    target) or to *target_atoms* alone.

    Returns:
        ``(rms, target_atoms)``: the RMSD of the fitted pairs and the same
        *target_atoms* list object that was passed in.

    Raises:
        ValueError: if the two lists differ in length, or are empty.
    """
    n_ref = len(ref_atoms)
    n_target = len(target_atoms)
    if n_ref != n_target:
        raise ValueError(f"Atom count mismatch: ref={n_ref}, target={n_target}")
    if not n_ref:
        raise ValueError("No atoms to align")

    superimposer = Superimposer()
    superimposer.set_atoms(ref_atoms, target_atoms)

    # Move the whole selection when one is supplied; otherwise just the fitted atoms.
    moving_atoms = target_atoms if full_target_atoms is None else full_target_atoms
    superimposer.apply(moving_atoms)

    return superimposer.rms, target_atoms


def calculate_rmsd_only(atoms1: List[Atom], atoms2: List[Atom]) -> float:
    """Return the RMSD of two paired atom lists in their current coordinates.

    No superposition is performed. Atoms are paired by position, so the
    caller must already have placed both lists in a common frame.

    Raises:
        ValueError: if the lists differ in length.
    """
    if len(atoms1) != len(atoms2):
        raise ValueError(f"Atom count mismatch: {len(atoms1)} vs {len(atoms2)}")

    first = np.array([atom.get_coord() for atom in atoms1])
    second = np.array([atom.get_coord() for atom in atoms2])
    squared_distances = ((first - second) ** 2).sum(axis=1)
    return np.sqrt(squared_distances.mean())


def load_reference_structure(ref_path: str) -> Tuple[List[Atom], List[Atom]]:
    """Parse the reference PDB and return its scFv1 and scFv2 C-alpha atoms.

    scFv1 is the first ``SCFV1_LENGTH`` standard residues and scFv2 is the
    last ``SCFV2_LENGTH``. Residue and CA counts are printed as progress output.

    Raises:
        ValueError: if the structure has fewer than
            ``SCFV1_LENGTH + SCFV2_LENGTH`` standard residues.
    """
    structure = PDBParser(QUIET=True).get_structure('reference', ref_path)
    residues = extract_residues(structure)

    n_total = len(residues)
    print(f"Reference structure: {n_total} residues")

    if n_total < SCFV1_LENGTH + SCFV2_LENGTH:
        raise ValueError(f"Reference structure too short: {n_total} residues")

    scfv1_ca = get_ca_atoms(residues[:SCFV1_LENGTH])
    scfv2_ca = get_ca_atoms(residues[-SCFV2_LENGTH:])

    for label, atoms in (("scFv1", scfv1_ca), ("scFv2", scfv2_ca)):
        print(f"  • {label}: {len(atoms)} CA atoms")

    return scfv1_ca, scfv2_ca


def analyze_single_target(
    ref_scfv1_ca: List[Atom],
    ref_scfv2_ca: List[Atom],
    target_path: str
) -> Dict:
    """Compare one predicted structure against the reference scFv domains.

    The target's standard residues are split by position: the first
    ``SCFV1_LENGTH`` are scFv1, the last ``SCFV2_LENGTH`` are scFv2, and any
    residues in between are counted as the linker. Two superpositions are
    tried:

    1. fit target scFv1 onto reference scFv1, then measure scFv2 RMSD
       without re-fitting;
    2. fit target scFv2 onto reference scFv2, then measure scFv1 RMSD
       without re-fitting.

    The superposition with the lower mean of (fit RMSD, unfitted-domain
    RMSD) is kept. ``Final_RMSD`` is the RMSD of the *unfitted* domain under
    that superposition.

    Args:
        ref_scfv1_ca: reference scFv1 C-alpha atoms.
        ref_scfv2_ca: reference scFv2 C-alpha atoms.
        target_path: path of the predicted PDB file.

    Returns:
        A dict with the keys initialised below. ``Status`` becomes
        ``'Success'`` only when both superpositions complete. Otherwise
        ``Error`` holds the reason. Exceptions raised during the analysis
        are caught and stored in ``Error``, not propagated.
    """
    result = {
        'Filename': Path(target_path).name,
        'Status': 'Failed',
        'Total_Residues': 0,
        'Linker_Length': 0,
        'scFv1_RMSD': None,
        'scFv2_RMSD': None,
        'Avg_RMSD': None,
        'Final_RMSD': None,
        'Best_Alignment': None,
        'Error': None
    }

    try:
        parser = PDBParser(QUIET=True)
        target_structure = parser.get_structure('target', target_path)
        target_residues = extract_residues(target_structure)

        total_residues = len(target_residues)
        result['Total_Residues'] = total_residues

        # A target of exactly SCFV1_LENGTH + SCFV2_LENGTH residues passes (linker length 0).
        if total_residues < SCFV1_LENGTH + SCFV2_LENGTH:
            result['Error'] = f"Too few residues: {total_residues}"
            return result

        linker_length = total_residues - SCFV1_LENGTH - SCFV2_LENGTH
        result['Linker_Length'] = linker_length

        target_scfv1 = target_residues[:SCFV1_LENGTH]
        target_scfv2 = target_residues[-SCFV2_LENGTH:]

        target_scfv1_ca = get_ca_atoms(target_scfv1)
        target_scfv2_ca = get_ca_atoms(target_scfv2)
        target_all_ca = get_ca_atoms(target_residues)

        # Residues missing a CA shorten a domain's atom list; reject rather than mis-pair.
        if len(target_scfv1_ca) != len(ref_scfv1_ca):
            result['Error'] = f"scFv1 atom mismatch: {len(target_scfv1_ca)} vs {len(ref_scfv1_ca)}"
            return result
        if len(target_scfv2_ca) != len(ref_scfv2_ca):
            result['Error'] = f"scFv2 atom mismatch: {len(target_scfv2_ca)} vs {len(ref_scfv2_ca)}"
            return result

        # Align on scFv1. The transform is applied in place to every target CA
        # (target_all_ca), so the scFv2 CA atoms read next already carry the
        # scFv1-fitted coordinates.
        rmsd_scfv1, _ = align_and_calculate_rmsd(ref_scfv1_ca, target_scfv1_ca, target_all_ca)
        aligned_scfv2_ca = get_ca_atoms(target_residues[-SCFV2_LENGTH:])
        rmsd_scfv2_after_scfv1_align = calculate_rmsd_only(ref_scfv2_ca, aligned_scfv2_ca)

        # The first fit mutated the atom coordinates in place. Re-parse the
        # file so the second fit starts from the original coordinates.
        parser2 = PDBParser(QUIET=True)
        target_structure2 = parser2.get_structure('target2', target_path)
        target_residues2 = extract_residues(target_structure2)
        target_scfv2_ca_2 = get_ca_atoms(target_residues2[-SCFV2_LENGTH:])
        target_all_ca_2 = get_ca_atoms(target_residues2)

        # Align on scFv2, then measure scFv1 in that frame.
        rmsd_scfv2, _ = align_and_calculate_rmsd(ref_scfv2_ca, target_scfv2_ca_2, target_all_ca_2)
        aligned_scfv1_ca = get_ca_atoms(target_residues2[:SCFV1_LENGTH])
        rmsd_scfv1_after_scfv2_align = calculate_rmsd_only(ref_scfv1_ca, aligned_scfv1_ca)

        # Score each superposition by the mean of its fit RMSD and the RMSD of
        # the unfitted domain.
        avg_rmsd_align_scfv1 = (rmsd_scfv1 + rmsd_scfv2_after_scfv1_align) / 2
        avg_rmsd_align_scfv2 = (rmsd_scfv2 + rmsd_scfv1_after_scfv2_align) / 2

        # Ties go to the scFv1 alignment. Final_RMSD is the unfitted domain's RMSD.
        if avg_rmsd_align_scfv1 <= avg_rmsd_align_scfv2:
            result['Final_RMSD'] = rmsd_scfv2_after_scfv1_align
            result['Best_Alignment'] = 'scFv1'
            result['scFv1_RMSD'] = rmsd_scfv1
            result['scFv2_RMSD'] = rmsd_scfv2_after_scfv1_align
        else:
            result['Final_RMSD'] = rmsd_scfv1_after_scfv2_align
            result['Best_Alignment'] = 'scFv2'
            result['scFv1_RMSD'] = rmsd_scfv1_after_scfv2_align
            result['scFv2_RMSD'] = rmsd_scfv2

        result['Avg_RMSD'] = min(avg_rmsd_align_scfv1, avg_rmsd_align_scfv2)
        result['Status'] = 'Success'

    except Exception as e:
        # Best-effort per file: record the failure and let the batch continue.
        result['Error'] = str(e)

    return result

# =============================================================================
# SUMMARY
# =============================================================================

_rule = "-" * 70
print("\n" + _rule)
print("✓ Functions loaded successfully!")
print(_rule)
print("\nDomain structure:")
# Report the positional domain layout that every later cell relies on.
for _label, _length, _terminus in (("scFv1", SCFV1_LENGTH, "N-terminal"),
                                   ("scFv2", SCFV2_LENGTH, "C-terminal")):
    print(f"  • {_label}: {_length} residues ({_terminus})")
print("=" * 70)


In [None]:
# ================================================================================
# CELL 2: RMSD ANALYSIS (ColabFold vs External AlphaFold2)
# ================================================================================

#@title **Cell 2: RMSD Comparison** { display-mode: "form" }
#@markdown ### Configuration
#@markdown **Reference:** ColabFold에서 얻은 PDB (1개)
REFERENCE_PATH = "/content/reference.pdb"  #@param {type:"string"}

#@markdown **Predictions:** 외부 AlphaFold2에서 얻은 relaxed PDB들 (5개)
PREDICTIONS_FOLDER = "/content/predictions/"  #@param {type:"string"}

#@markdown ---
#@markdown **Requires:** Cell 1 must be run first.

# Cell 2: RMSD of every prediction PDB against the reference.
# REFERENCE_PATH is the single reference PDB from ColabFold. PREDICTIONS_FOLDER
# holds the relaxed PDBs from an external AlphaFold2 run (the form text says 5).
# Every *.pdb file in that folder is analyzed.
# Globals produced here (rmsd_df, comparison_df, ...) are read by later cells,
# so their names are part of this cell's contract.

print("=" * 70)
print("CELL 2: RMSD Analysis (ColabFold vs External AlphaFold2)")
print("=" * 70)

# =============================================================================
# VALIDATION
# =============================================================================

print("\n--- Validation ---")

if not Path(REFERENCE_PATH).exists():
    raise FileNotFoundError(f"Reference file not found: {REFERENCE_PATH}")
print(f"✓ Reference: {REFERENCE_PATH}")

pred_folder = Path(PREDICTIONS_FOLDER)
if not pred_folder.exists():
    raise FileNotFoundError(f"Predictions folder not found: {PREDICTIONS_FOLDER}")

# Sorted so results come out in a stable (rank-name) order.
pdb_files = sorted(pred_folder.glob("*.pdb"))
if not pdb_files:
    raise FileNotFoundError(f"No PDB files found in {PREDICTIONS_FOLDER}")
print(f"✓ Found {len(pdb_files)} prediction PDB files")

# =============================================================================
# LOAD REFERENCE & ANALYZE
# =============================================================================

print("\n" + "=" * 70)
print("LOADING REFERENCE STRUCTURE")
print("=" * 70)

ref_scfv1, ref_scfv2 = load_reference_structure(REFERENCE_PATH)

# Also get reference info. The reference is parsed a second time because
# load_reference_structure() returns only the two CA lists, not the residue count.
parser = PDBParser(QUIET=True)
ref_structure = parser.get_structure('ref', REFERENCE_PATH)
ref_residues = extract_residues(ref_structure)
ref_linker_length = len(ref_residues) - SCFV1_LENGTH - SCFV2_LENGTH

print(f"  • Linker length: {ref_linker_length} residues")
print(f"  • Total length: {len(ref_residues)} residues")

print("\n" + "=" * 70)
print("ANALYZING PREDICTIONS vs REFERENCE")
print("=" * 70)

results = []

for i, pdb_file in enumerate(pdb_files, 1):
    print(f"\n  [{i}/{len(pdb_files)}] {pdb_file.name}", end="")

    # analyze_single_target never raises; failures come back with Status 'Failed'.
    result = analyze_single_target(ref_scfv1, ref_scfv2, str(pdb_file))
    # Keep the full path so the NMA / pLDDT / antigen cells can re-open the file.
    result['PDB_Path'] = str(pdb_file)

    if result['Status'] == 'Success':
        print(f" → RMSD: {result['Final_RMSD']:.4f} Å ({result['Best_Alignment']})")
    else:
        print(f" → FAILED: {result['Error']}")

    results.append(result)

# Create DataFrame: one row per prediction PDB, including failed ones.
rmsd_df = pd.DataFrame(results)

# =============================================================================
# DISPLAY RESULTS
# =============================================================================

print("\n" + "=" * 70)
print("RMSD RESULTS")
print("=" * 70)

# Format a copy only; rmsd_df keeps the numeric values.
display_df = rmsd_df.copy()
display_df['Final_RMSD'] = display_df['Final_RMSD'].apply(
    lambda x: f"{x:.4f}" if pd.notna(x) else "N/A"
)
display_df['scFv1_RMSD'] = display_df['scFv1_RMSD'].apply(
    lambda x: f"{x:.4f}" if pd.notna(x) else "N/A"
)
display_df['scFv2_RMSD'] = display_df['scFv2_RMSD'].apply(
    lambda x: f"{x:.4f}" if pd.notna(x) else "N/A"
)

display_cols = ['Filename', 'Final_RMSD', 'scFv1_RMSD', 'scFv2_RMSD', 'Best_Alignment']
print(display_df[display_cols].to_string(index=False))

# Statistics over successful rows only.
successful = rmsd_df[rmsd_df['Status'] == 'Success']
if len(successful) > 0:
    print(f"\nStatistics:")
    print(f"  • Mean RMSD: {successful['Final_RMSD'].mean():.4f} Å")
    print(f"  • Std RMSD:  {successful['Final_RMSD'].std():.4f} Å")
    print(f"  • Min RMSD:  {successful['Final_RMSD'].min():.4f} Å")
    print(f"  • Max RMSD:  {successful['Final_RMSD'].max():.4f} Å")

# Store for later cells. Jobname is the filename with ".pdb" stripped.
comparison_df = rmsd_df.copy()
comparison_df['Jobname'] = comparison_df['Filename'].apply(lambda x: x.replace('.pdb', ''))

print("\n" + "=" * 70)
print("✓ Cell 2 Complete")
print("=" * 70)
print(f"\nData saved to: rmsd_df, comparison_df ({len(rmsd_df)} rows)")
print("→ Run Cell 3 for NMA analysis")
print("=" * 70)


In [None]:
# ================================================================================
# CELL 3: NORMAL MODE ANALYSIS (NMA)
# ================================================================================

#@title **Cell 3: Normal Mode Analysis (NMA)** { display-mode: "form" }
#@markdown ### NMA Parameters
GNM_CUTOFF = 10.0  #@param {type:"number"}
ANM_CUTOFF = 15.0  #@param {type:"number"}
N_MODES = 20  #@param {type:"integer"}

#@markdown ---
#@markdown **Requires:** Cell 2 must be run first.

# GNM_CUTOFF / ANM_CUTOFF are the elastic-network contact cutoffs passed to
# buildKirchhoff / buildHessian. N_MODES is the number of modes to compute.

import warnings
warnings.filterwarnings('ignore')

print("=" * 70)
print("CELL 3: Normal Mode Analysis (NMA)")
print("=" * 70)

# Install ProDy on first use.
# - Run pip through the current interpreter (sys.executable -m pip) so the
#   package lands in this kernel's environment, not whatever `pip` is on PATH.
# - check=True surfaces installation failures instead of hiding them.
# - invalidate_caches() lets the fresh install be imported in this process.
# The star import provides parsePDB, GNM, ANM and calcSqFlucts used below.
try:
    import prody
    from prody import *
except ImportError:
    print("Installing ProDy...")
    import importlib
    import subprocess
    import sys
    subprocess.run([sys.executable, '-m', 'pip', 'install', '-q', 'prody'], check=True)
    importlib.invalidate_caches()
    import prody
    from prody import *

prody.confProDy(verbosity='warning')
print(f"ProDy Version: {prody.__version__}")

# =============================================================================
# NMA ANALYZER CLASS
# =============================================================================

class LinkerFlexibilityAnalyzer:
    """Elastic-network (GNM/ANM) flexibility analysis of an scFv-linker-scFv model.

    The C-alpha trace is split by position: the first ``SCFV1_LENGTH``
    C-alphas are scFv1, the last ``SCFV2_LENGTH`` are scFv2, and everything
    in between is the linker. Computed metrics accumulate in
    ``self.results``. ``get_summary_dict()`` flattens the headline numbers
    into one table row.

    Uses ProDy's ``parsePDB``, ``GNM``, ``ANM`` and ``calcSqFlucts``, which
    this cell brings into scope via ``from prody import *``.
    """

    def __init__(self, structure_path: str, name: str = None):
        """Parse *structure_path* and compute the domain boundaries.

        Args:
            structure_path: PDB file to analyze.
            name: label for the models and summary row; defaults to the file stem.

        Raises:
            ValueError: if no C-alpha atoms are selected, or if no residues
                remain for a linker (see ``_calculate_domain_indices``).
        """
        self.name = name or Path(structure_path).stem
        self.structure = parsePDB(structure_path)
        self.calphas = self.structure.select('calpha')

        if self.calphas is None:
            raise ValueError(f"No CA atoms found in {structure_path}")

        self.n_residues = self.calphas.numAtoms()
        self.domain_indices = self._calculate_domain_indices()
        self.linker_length = self.domain_indices['linker_length']

        # Models are built on demand. The calc_* methods build them with
        # default parameters if run_gnm()/run_anm() has not been called.
        self._gnm = None
        self._anm = None
        self.results = {}

    def _calculate_domain_indices(self):
        """Return half-open ``(start, end)`` C-alpha index ranges per domain.

        Raises:
            ValueError: if the trace is not longer than scFv1 + scFv2.
        """
        linker_length = self.n_residues - SCFV1_LENGTH - SCFV2_LENGTH
        if linker_length <= 0:
            raise ValueError(f"Invalid total length {self.n_residues}")

        return {
            'scfv1': (0, SCFV1_LENGTH),
            'linker': (SCFV1_LENGTH, SCFV1_LENGTH + linker_length),
            'scfv2': (SCFV1_LENGTH + linker_length, self.n_residues),
            'linker_length': linker_length
        }

    def run_gnm(self, cutoff=10.0, n_modes=20):
        """Build the GNM Kirchhoff matrix on the C-alphas and compute *n_modes* modes."""
        self._gnm = GNM(self.name)
        self._gnm.buildKirchhoff(self.calphas, cutoff=cutoff)
        self._gnm.calcModes(n_modes=n_modes)
        return self._gnm

    def calc_msf_gnm(self):
        """Compute GNM mean-square fluctuations, overall and per domain.

        Stores ``msf_gnm``, ``msf_gnm_mean``, and ``<domain>_msf_mean`` /
        ``<domain>_msf_std`` for scfv1, scfv2 and linker in ``self.results``.
        Returns the per-residue MSF array.
        """
        if self._gnm is None:
            self.run_gnm()

        msf = calcSqFlucts(self._gnm)
        self.results['msf_gnm'] = msf
        self.results['msf_gnm_mean'] = np.mean(msf)

        for domain in ('scfv1', 'scfv2', 'linker'):
            start, end = self.domain_indices[domain]
            domain_msf = msf[start:end]
            self.results[f'{domain}_msf_mean'] = np.mean(domain_msf)
            self.results[f'{domain}_msf_std'] = np.std(domain_msf)

        return msf

    def run_anm(self, cutoff=15.0, n_modes=20):
        """Build the ANM Hessian on the C-alphas and compute *n_modes* modes."""
        self._anm = ANM(self.name)
        self._anm.buildHessian(self.calphas, cutoff=cutoff)
        self._anm.calcModes(n_modes=n_modes)
        return self._anm

    def calc_deformability(self, n_modes=None):
        """Per-residue deformability from the ANM modes, min-max scaled to [0, 1].

        For each residue, sums |u|^2 / eigenvalue^2 over the first *n_modes*
        modes (all computed modes when ``None``). Stores the profile, its
        mean, and the linker mean in ``self.results``.
        """
        if self._anm is None:
            self.run_anm()

        # Never request more modes than the ANM computed; otherwise the reshape
        # below fails. Mirrors the clamp in calc_inter_domain_distance_fluctuation.
        n_available = len(self._anm)
        n_modes = n_available if n_modes is None else min(n_modes, n_available)

        eigenvalues = self._anm.getEigvals()[:n_modes]
        eigenvectors = self._anm.getEigvecs()[:, :n_modes]

        # (3N, k) -> (N, 3, k). Assumes ProDy stores each atom's x/y/z components
        # consecutively in the eigenvector.
        eigenvectors_reshaped = eigenvectors.reshape(self.n_residues, 3, n_modes)
        u_squared = np.sum(eigenvectors_reshaped**2, axis=1)
        deformability_raw = np.sum(u_squared / eigenvalues**2, axis=1)

        # Min-max scaling. A flat profile would divide by zero and give silent
        # NaNs (warnings are suppressed in this notebook), so map it to zeros.
        raw_min = deformability_raw.min()
        raw_range = deformability_raw.max() - raw_min
        if raw_range > 0:
            deformability = (deformability_raw - raw_min) / raw_range
        else:
            deformability = np.zeros_like(deformability_raw)

        self.results['deformability'] = deformability
        self.results['deformability_mean'] = np.mean(deformability)

        linker_start, linker_end = self.domain_indices['linker']
        self.results['linker_deformability_mean'] = np.mean(deformability[linker_start:linker_end])

        return deformability

    def calc_stiffness(self, n_modes=10):
        """Mean of the first *n_modes* GNM eigenvalues (default 10), stored as 'stiffness'."""
        if self._gnm is None:
            self.run_gnm()

        eigenvalues = self._gnm.getEigvals()[:n_modes]
        stiffness = np.mean(eigenvalues)
        self.results['stiffness'] = stiffness

        return stiffness

    def calc_inter_domain_distance_fluctuation(self, n_modes=20):
        """Weighted spread of the scFv1-scFv2 C-alpha centroid distance along ANM modes.

        The structure is displaced along each of the first *n_modes* ANM modes
        with amplitude ``5.0 / sqrt(eigenvalue)``. For each mode, the change
        in distance between the two domain centroids is recorded. Returns the
        standard deviation of those changes, weighted by 1/eigenvalue.
        """
        if self._anm is None:
            self.run_anm(n_modes=n_modes)

        coords = self.calphas.getCoords()
        scfv1_start, scfv1_end = self.domain_indices['scfv1']
        scfv2_start, scfv2_end = self.domain_indices['scfv2']

        com1_initial = np.mean(coords[scfv1_start:scfv1_end], axis=0)
        com2_initial = np.mean(coords[scfv2_start:scfv2_end], axis=0)
        initial_distance = np.linalg.norm(com2_initial - com1_initial)

        eigenvectors = self._anm.getEigvecs()
        eigenvalues = self._anm.getEigvals()
        n_modes_available = min(n_modes, len(self._anm))

        distance_changes = []
        for mode_idx in range(n_modes_available):
            mode_vector = eigenvectors[:, mode_idx].reshape(-1, 3)
            amplitude = 5.0 / np.sqrt(eigenvalues[mode_idx])
            displaced_coords = coords + amplitude * mode_vector

            com1_displaced = np.mean(displaced_coords[scfv1_start:scfv1_end], axis=0)
            com2_displaced = np.mean(displaced_coords[scfv2_start:scfv2_end], axis=0)
            new_distance = np.linalg.norm(com2_displaced - com1_displaced)
            distance_changes.append(new_distance - initial_distance)

        # Softer (low-eigenvalue) modes get proportionally larger weight.
        weights = 1.0 / eigenvalues[:n_modes_available]
        weights = weights / np.sum(weights)
        mean_change = np.average(distance_changes, weights=weights)
        variance = np.average((np.array(distance_changes) - mean_change)**2, weights=weights)
        inter_domain_fluctuation = np.sqrt(variance)

        self.results['inter_domain_distance_initial'] = initial_distance
        self.results['inter_domain_fluctuation'] = inter_domain_fluctuation

        return inter_domain_fluctuation

    def run_full_analysis(self, gnm_cutoff=10.0, anm_cutoff=15.0, n_modes=20):
        """Run GNM (MSF, stiffness), then ANM (deformability, domain fluctuation)."""
        self.run_gnm(cutoff=gnm_cutoff, n_modes=n_modes)
        self.calc_msf_gnm()
        self.calc_stiffness()

        self.run_anm(cutoff=anm_cutoff, n_modes=n_modes)
        self.calc_deformability(n_modes=n_modes)
        self.calc_inter_domain_distance_fluctuation(n_modes=n_modes)

        return self.results

    def get_summary_dict(self):
        """Return one table row of headline metrics; missing metrics become NaN."""
        return {
            'Filename': self.name,
            'Linker_Length': self.linker_length,
            'Effective_Stiffness': self.results.get('stiffness', np.nan),
            'Inter_Domain_Fluctuation': self.results.get('inter_domain_fluctuation', np.nan),
            'Initial_COM_Distance': self.results.get('inter_domain_distance_initial', np.nan),
            'MSF_Linker_Mean': self.results.get('linker_msf_mean', np.nan),
            'Deformability_Linker_Mean': self.results.get('linker_deformability_mean', np.nan),
        }

# =============================================================================
# RUN NMA ON ALL PREDICTIONS
# =============================================================================

print("\n--- Running NMA Analysis ---")

# Fail early with a clear message if Cell 2 has not produced rmsd_df.
try:
    _ = rmsd_df
    print(f"✓ Found {len(rmsd_df)} structures to analyze")
except NameError:
    raise RuntimeError("rmsd_df not found. Run Cell 2 first.")

# One summary row per structure in rmsd_df. A failure is recorded as a row
# holding only Filename and Error, so nma_df still has one row per input file.
nma_results = []

# idx+1 is used as a progress counter; rmsd_df is built with pd.DataFrame(list)
# in Cell 2, so its index is 0..n-1.
for idx, row in rmsd_df.iterrows():
    pdb_path = row['PDB_Path']
    filename = row['Filename']

    print(f"\n  [{idx+1}/{len(rmsd_df)}] {filename}", end="")

    try:
        analyzer = LinkerFlexibilityAnalyzer(pdb_path, name=filename.replace('.pdb', ''))
        analyzer.run_full_analysis(gnm_cutoff=GNM_CUTOFF, anm_cutoff=ANM_CUTOFF, n_modes=N_MODES)

        summary = analyzer.get_summary_dict()
        nma_results.append(summary)

        print(f" → Stiffness: {summary['Effective_Stiffness']:.6f}")

    except Exception as e:
        # Best-effort per structure: keep going and record the error.
        print(f" → ERROR: {e}")
        nma_results.append({
            'Filename': filename.replace('.pdb', ''),
            'Error': str(e)
        })

nma_df = pd.DataFrame(nma_results)

# =============================================================================
# DISPLAY RESULTS
# =============================================================================

print("\n" + "=" * 70)
print("NMA RESULTS")
print("=" * 70)

# Only show columns that exist: if every structure failed, nma_df has just
# Filename and Error.
display_cols = ['Filename', 'Effective_Stiffness', 'Inter_Domain_Fluctuation',
                'MSF_Linker_Mean', 'Deformability_Linker_Mean']
display_cols = [c for c in display_cols if c in nma_df.columns]

display_df = nma_df[display_cols].copy()
for col in display_df.columns:
    if col != 'Filename':
        display_df[col] = display_df[col].apply(
            lambda x: f"{x:.4f}" if pd.notna(x) else "N/A"
        )

print(display_df.to_string(index=False))

print("\n" + "=" * 70)
print("✓ Cell 3 Complete")
print("=" * 70)
print(f"\nData saved to: nma_df ({len(nma_df)} rows)")
print("→ Run Cell 4 for pLDDT/PAE analysis")
print("=" * 70)


In [None]:
# ================================================================================
# CELL 4: pLDDT & PAE ANALYSIS (From AlphaFold JSON files)
# ================================================================================

#@title **Cell 4: pLDDT & PAE Analysis** { display-mode: "form" }
#@markdown ### Configuration
#@markdown JSON files should be in the same folder as PDB files or specify separately:
JSON_FOLDER = "/content/predictions/"  #@param {type:"string"}

#@markdown ---
#@markdown **Note:** AlphaFold2 generates JSON files with pLDDT and PAE data.
#@markdown
#@markdown **Requires:** Cell 2 must be run first.

# JSON_FOLDER is searched for prediction JSON files later in this cell.
# When none are found, pLDDT is read instead from the CA B-factors of the
# PDBs listed in rmsd_df (Cell 2).
import json
# NOTE(review): os is not referenced in the visible part of this notebook;
# possibly kept for a later cell — confirm before removing.
import os

print("=" * 70)
print("CELL 4: pLDDT & PAE Analysis")
print("=" * 70)

# =============================================================================
# ANALYSIS FUNCTIONS
# =============================================================================

def load_alphafold_json(filepath):
    """Read a prediction JSON file and return the decoded object.

    Returns ``None`` when the file cannot be opened, is not valid UTF-8, or
    is not valid JSON, so callers can skip it. Any other exception indicates
    a programming error and is allowed to propagate.
    """
    try:
        with open(filepath, 'r', encoding='utf-8') as f:
            return json.load(f)
    except (OSError, UnicodeDecodeError, json.JSONDecodeError):
        return None


def calculate_domain_indices(total_length):
    """Split *total_length* residues into scFv1 / linker / scFv2 index ranges.

    Returns a dict of half-open ``(start, end)`` pairs for ``'scfv1'``,
    ``'linker'`` and ``'scfv2'``, plus ``'linker_length'``. Returns ``None``
    when there is no room for a linker, i.e.
    ``total_length <= SCFV1_LENGTH + SCFV2_LENGTH``.
    """
    linker_end = total_length - SCFV2_LENGTH
    linker_length = linker_end - SCFV1_LENGTH
    if linker_length <= 0:
        return None

    return {
        'scfv1': (0, SCFV1_LENGTH),
        'linker': (SCFV1_LENGTH, linker_end),
        'scfv2': (linker_end, total_length),
        'linker_length': linker_length,
    }


def extract_plddt_metrics(plddt, indices):
    """Summarize per-residue pLDDT for the whole chain and for each domain.

    Args:
        plddt: per-residue pLDDT values.
        indices: dict from ``calculate_domain_indices`` with half-open
            ``(start, end)`` pairs under 'scfv1', 'linker' and 'scfv2'.

    Returns:
        Dict of global, scFv1, scFv2 and linker means, plus the linker
        minimum, as Python floats.
    """
    values = np.asarray(plddt)

    def _segment(key):
        start, end = indices[key]
        return values[start:end]

    linker = _segment('linker')
    return {
        'pLDDT_Global_Mean': float(values.mean()),
        'pLDDT_scFv1_Mean': float(_segment('scfv1').mean()),
        'pLDDT_scFv2_Mean': float(_segment('scfv2').mean()),
        'pLDDT_Linker_Mean': float(linker.mean()),
        'pLDDT_Linker_Min': float(linker.min()),
    }


def extract_pae_metrics(pae, indices):
    """Summarize a PAE matrix into inter-domain and linker-to-domain means.

    Args:
        pae: square predicted-aligned-error matrix (row = aligned residue,
            column = scored residue, as indexed below).
        indices: dict from ``calculate_domain_indices``.

    Returns:
        ``PAE_Inter_Domain``: mean of the scFv1->scFv2 and scFv2->scFv1 block means.
        ``PAE_Linker_to_scFv1`` / ``PAE_Linker_to_scFv2``: means of the
        linker-row blocks.
        ``PAE_Linker_Stability``: the average of those two linker values.
    """
    matrix = np.array(pae)
    scfv1 = slice(*indices['scfv1'])
    scfv2 = slice(*indices['scfv2'])
    linker = slice(*indices['linker'])

    # Inter-domain PAE, symmetrised over both block orientations.
    inter_domain = (np.mean(matrix[scfv1, scfv2]) + np.mean(matrix[scfv2, scfv1])) / 2

    # Linker stability: linker rows against each domain's columns.
    linker_to_1 = np.mean(matrix[linker, scfv1])
    linker_to_2 = np.mean(matrix[linker, scfv2])

    return {
        'PAE_Inter_Domain': float(inter_domain),
        'PAE_Linker_to_scFv1': float(linker_to_1),
        'PAE_Linker_to_scFv2': float(linker_to_2),
        'PAE_Linker_Stability': float((linker_to_1 + linker_to_2) / 2),
    }


def analyze_json_file(json_path, name=None):
    """Build one metrics row from an AlphaFold/ColabFold score JSON.

    Reads per-residue ``plddt`` (required) and, when present, a PAE matrix
    under ``pae`` or ``predicted_aligned_error``. pTM is read from ``ptm``,
    ``pTM`` or ``ranking_confidence``, in that order, else NaN.

    Returns ``None`` when the file cannot be loaded, has no (or an empty)
    ``plddt`` list, or is too short to contain a linker.
    """
    data = load_alphafold_json(json_path)
    if data is None:
        return None

    plddt = data.get('plddt', [])
    if not plddt:
        return None

    n_residues = len(plddt)
    domain_indices = calculate_domain_indices(n_residues)
    if domain_indices is None:
        return None

    ptm = data.get('ptm', data.get('pTM', data.get('ranking_confidence', np.nan)))
    summary = {
        'Filename': name or Path(json_path).stem,
        'Total_Length': n_residues,
        'pTM_Score': ptm,
    }
    summary.update(extract_plddt_metrics(plddt, domain_indices))

    pae = data.get('pae', data.get('predicted_aligned_error', None))
    if pae is not None:
        summary.update(extract_pae_metrics(pae, domain_indices))

    return summary

# =============================================================================
# FIND AND ANALYZE JSON FILES
# =============================================================================

print("\n--- Finding JSON Files ---")

json_folder = Path(JSON_FOLDER)
# Prefer ColabFold per-model score files, which carry the per-residue 'plddt',
# 'pae' and 'ptm' keys read by analyze_json_file(). Only when none exist do we
# fall back to every *.json in the folder.
# The narrow pattern must be tried first: "*_scores*.json" is a subset of
# "*.json", so checking the generic glob first made the fallback unreachable.
# It also pulled in config / PAE-only JSONs, which then report "Failed to parse".
# Sorting gives a stable, rank-ordered run.
json_files = sorted(json_folder.glob("*_scores*.json"))
if not json_files:
    json_files = sorted(json_folder.glob("*.json"))

print(f"Found {len(json_files)} JSON files")

# If no JSON files, try to extract pLDDT from B-factor in PDB files
USE_PDB_BFACTOR = len(json_files) == 0

if USE_PDB_BFACTOR:
    print("\n⚠️ No JSON files found. Extracting pLDDT from PDB B-factors...")

    plddt_results = []

    for idx, row in rmsd_df.iterrows():
        pdb_path = row['PDB_Path']
        filename = row['Filename']

        print(f"\n  [{idx+1}/{len(rmsd_df)}] {filename}", end="")

        try:
            parser = PDBParser(QUIET=True)
            structure = parser.get_structure('pdb', pdb_path)

            # AlphaFold writes pLDDT into the B-factor column; read it from each
            # standard residue's CA atom.
            bfactors = []
            for model in structure:
                for chain in model:
                    for residue in chain:
                        if is_aa(residue, standard=True) and 'CA' in residue:
                            bfactors.append(residue['CA'].get_bfactor())

            if len(bfactors) > 0:
                total_length = len(bfactors)
                indices = calculate_domain_indices(total_length)

                if indices:
                    plddt_array = np.array(bfactors)
                    linker_plddt = plddt_array[indices['linker'][0]:indices['linker'][1]]

                    # B-factor mode has no pTM / PAE and no per-domain pLDDT means.
                    result = {
                        'Filename': filename.replace('.pdb', ''),
                        'pLDDT_Global_Mean': np.mean(plddt_array),
                        'pLDDT_Linker_Mean': np.mean(linker_plddt),
                        'pLDDT_Linker_Min': np.min(linker_plddt),
                    }
                    plddt_results.append(result)
                    print(f" → pLDDT: {result['pLDDT_Global_Mean']:.1f}")
                else:
                    print(f" → Invalid length")
            else:
                print(f" → No B-factors")

        except Exception as e:
            # Best-effort per structure: report and continue with the next file.
            print(f" → ERROR: {e}")

    plddt_df = pd.DataFrame(plddt_results) if plddt_results else pd.DataFrame()

else:
    print("\n--- Analyzing JSON Files ---")

    plddt_results = []

    for idx, json_file in enumerate(json_files, 1):
        print(f"\n  [{idx}/{len(json_files)}] {json_file.name}", end="")

        result = analyze_json_file(str(json_file), name=json_file.stem)

        if result:
            plddt_results.append(result)
            ptm = result.get('pTM_Score', np.nan)
            ptm_str = f"{ptm:.3f}" if pd.notna(ptm) else "N/A"
            print(f" → pTM: {ptm_str}, pLDDT: {result['pLDDT_Global_Mean']:.1f}")
        else:
            print(f" → Failed to parse")

    plddt_df = pd.DataFrame(plddt_results) if plddt_results else pd.DataFrame()

# =============================================================================
# DISPLAY RESULTS
# =============================================================================

print("\n" + "=" * 70)
print("pLDDT/PAE RESULTS")
print("=" * 70)

if len(plddt_df) > 0:
    # B-factor mode lacks pTM/PAE columns, so show only the columns present.
    display_cols = ['Filename', 'pTM_Score', 'pLDDT_Global_Mean', 'pLDDT_Linker_Mean',
                   'PAE_Inter_Domain', 'PAE_Linker_Stability']
    display_cols = [c for c in display_cols if c in plddt_df.columns]

    display_df = plddt_df[display_cols].copy()
    for col in display_df.columns:
        if col != 'Filename':
            display_df[col] = display_df[col].apply(
                lambda x: f"{x:.2f}" if pd.notna(x) else "N/A"
            )

    print(display_df.to_string(index=False))
else:
    print("⚠️ No pLDDT/PAE data available")

print("\n" + "=" * 70)
print("✓ Cell 4 Complete")
print("=" * 70)
print(f"\nData saved to: plddt_df ({len(plddt_df)} rows)")
print("→ Run Cell 5 for Antigen Distance analysis")
print("=" * 70)


In [None]:
# ================================================================================
# CELL 5: ANTIGEN DISTANCE MEASUREMENT (130Å Target)
# ================================================================================

#@title **Cell 5: Antigen Distance (130Å Target)** { display-mode: "form" }
#@markdown ### Configuration
ANTIGEN_COMPLEX_PATH = "/content/Include_antigen_with_perfect_distance.pdb"  #@param {type:"string"}

#@markdown ### Antigen Distance Settings
CHAIN_B_RESNUM = 102  #@param {type:"integer"}
CHAIN_D_RESNUM = 652  #@param {type:"integer"}
TARGET_DISTANCE = 130.0  #@param {type:"number"}

#@markdown ---
#@markdown **Requires:** Cell 2 must be run first.

# CHAIN_B_RESNUM / CHAIN_D_RESNUM: residue numbers (as in the complex PDB) whose
# CA atoms in chains B and D are used as the distance marker atoms.

print("=" * 70)
print("CELL 5: Antigen Distance Measurement")
print("=" * 70)
print(f"\nTarget: {TARGET_DISTANCE} Å (CD3e-HER2 membrane distance)")

# =============================================================================
# VALIDATION
# =============================================================================

print("\n--- Validation ---")

# A missing complex is not fatal: ANTIGEN_AVAILABLE gates the rest of the cell.
if not Path(ANTIGEN_COMPLEX_PATH).exists():
    print(f"⚠️ Antigen complex not found: {ANTIGEN_COMPLEX_PATH}")
    print("  Antigen distance measurement will be skipped.")
    ANTIGEN_AVAILABLE = False
else:
    print(f"✓ Antigen complex found")
    ANTIGEN_AVAILABLE = True

# =============================================================================
# LOAD ANTIGEN-ANTIBODY COMPLEX
# =============================================================================

if ANTIGEN_AVAILABLE:
    print("\n" + "=" * 70)
    print("LOADING ANTIGEN-ANTIBODY COMPLEX")
    print("=" * 70)

    parser = PDBParser(QUIET=True)
    antigen_complex = parser.get_structure('antigen_complex', ANTIGEN_COMPLEX_PATH)

    # Standard amino-acid residues of chains A-D, gathered over all models in
    # file order. Residues of any other chain are ignored.
    residues_by_chain = {chain_id: [] for chain_id in ('A', 'B', 'C', 'D')}
    for model in antigen_complex:
        for chain in model:
            bucket = residues_by_chain.get(chain.get_id())
            if bucket is not None:
                bucket.extend(res for res in chain if is_aa(res, standard=True))

    chain_A_residues = residues_by_chain['A']
    chain_B_residues = residues_by_chain['B']
    chain_C_residues = residues_by_chain['C']
    chain_D_residues = residues_by_chain['D']

    print(f"  Chain A (scFv1): {len(chain_A_residues)} residues")
    print(f"  Chain B (CD3e): {len(chain_B_residues)} residues")
    print(f"  Chain C (scFv2): {len(chain_C_residues)} residues")
    print(f"  Chain D (HER2): {len(chain_D_residues)} residues")

    chain_A_ca = get_ca_atoms(chain_A_residues)
    chain_C_ca = get_ca_atoms(chain_C_residues)

    # Find target residues: the CA of residue CHAIN_B_RESNUM in chain B and of
    # CHAIN_D_RESNUM in chain D. With several models, the last model that
    # contains the residue wins.
    chain_B_target_ca = None
    chain_D_target_ca = None

    for model in antigen_complex:
        if 'B' in model:
            for residue in model['B']:
                if residue.get_id()[1] == CHAIN_B_RESNUM and 'CA' in residue:
                    chain_B_target_ca = residue['CA']
                    break

        if 'D' in model:
            for residue in model['D']:
                if residue.get_id()[1] == CHAIN_D_RESNUM and 'CA' in residue:
                    chain_D_target_ca = residue['CA']
                    break

    # "Found" means "not None"; don't rely on the truthiness of Atom objects.
    if chain_B_target_ca is not None and chain_D_target_ca is not None:
        # Subtracting two Bio.PDB atoms gives their Euclidean distance.
        original_distance = chain_B_target_ca - chain_D_target_ca
        print(f"\n  ✓ Target residues found")
        print(f"  Original complex distance: {original_distance:.2f} Å")
    else:
        print(f"  ⚠️ Could not find target residues")
        ANTIGEN_AVAILABLE = False

# =============================================================================
# MEASURE ANTIGEN DISTANCE FOR EACH PDB
# =============================================================================

if ANTIGEN_AVAILABLE:
    print("\n" + "=" * 70)
    print("MEASURING ANTIGEN DISTANCES")
    print("=" * 70)

    antigen_results = []
    n_total = len(rmsd_df)

    # One parser serves every prediction file.
    parser = PDBParser(QUIET=True)

    # Target CA coordinates in the complex's frame; they never change per file.
    chain_B_coord = chain_B_target_ca.get_coord().copy()
    chain_D_coord = chain_D_target_ca.get_coord().copy()

    # enumerate() gives a true 1-based progress counter even when rmsd_df has
    # been sorted or re-indexed (its index labels need not be 0..n-1).
    for counter, (_, row) in enumerate(rmsd_df.iterrows(), start=1):
        pdb_path = row['PDB_Path']
        filename = row['Filename']

        print(f"\n  [{counter}/{n_total}] {filename}", end="")

        try:
            af_structure = parser.get_structure('af', pdb_path)
            af_residues = extract_residues(af_structure)

            af_scfv1_ca = get_ca_atoms(af_residues[:SCFV1_LENGTH])
            af_scfv2_ca = get_ca_atoms(af_residues[-SCFV2_LENGTH:])

            # Superimposer pairs atoms by position, so the counts must match.
            if len(af_scfv1_ca) != len(chain_A_ca) or len(af_scfv2_ca) != len(chain_C_ca):
                print(f" → Atom mismatch")
                continue

            # Superimpose Chain A (moving) onto AlphaFold scFv1 (fixed), then
            # carry the CD3e target atom along with the same transform.
            # Biopython's rotran is a RIGHT-multiplying rotation:
            #     new = coord @ rot + tran   (same convention as Atom.transform)
            # Using rot.T here would apply the inverse rotation.
            sup1 = Superimposer()
            sup1.set_atoms(af_scfv1_ca, chain_A_ca)
            rot1, tran1 = sup1.rotran
            chain_B_transformed = np.dot(chain_B_coord, rot1) + tran1

            # Superimpose Chain C (moving) onto AlphaFold scFv2 (fixed) and
            # carry the HER2 target atom along.
            sup2 = Superimposer()
            sup2.set_atoms(af_scfv2_ca, chain_C_ca)
            rot2, tran2 = sup2.rotran
            chain_D_transformed = np.dot(chain_D_coord, rot2) + tran2

            # Calculate distance between the two placed antigen atoms
            antigen_distance = np.linalg.norm(chain_B_transformed - chain_D_transformed)
            distance_from_target = antigen_distance - TARGET_DISTANCE

            print(f" → {antigen_distance:.2f} Å (Δ: {distance_from_target:+.2f} Å)")

            antigen_results.append({
                'Filename': filename.replace('.pdb', ''),
                'Antigen_Distance': antigen_distance,
                'Distance_From_Target': distance_from_target,
                'Abs_Error': abs(distance_from_target),
            })

        except Exception as e:
            # Per-file boundary: report and continue with the next prediction.
            print(f" → ERROR: {e}")

    antigen_df = pd.DataFrame(antigen_results)
else:
    antigen_df = pd.DataFrame()

# =============================================================================
# DISPLAY RESULTS
# =============================================================================

print("\n" + "=" * 70)
print("ANTIGEN DISTANCE RESULTS")
print("=" * 70)

if not len(antigen_df):
    print("⚠️ No antigen distance data available")
else:
    # String-formatted copy for printing; antigen_df keeps the raw floats.
    display_df = antigen_df.assign(
        Antigen_Distance=antigen_df['Antigen_Distance'].apply(lambda x: f"{x:.2f}"),
        Distance_From_Target=antigen_df['Distance_From_Target'].apply(lambda x: f"{x:+.2f}"),
    )

    print(f"\n(Target: {TARGET_DISTANCE} Å)")
    shown_cols = ['Filename', 'Antigen_Distance', 'Distance_From_Target']
    print(display_df[shown_cols].to_string(index=False))

    # Candidate with the smallest |distance - target|.
    best_idx = antigen_df['Abs_Error'].idxmin()
    best = antigen_df.loc[best_idx]
    print(f"\n★ Closest to target: {best['Filename']}")
    print(f"  Distance: {best['Antigen_Distance']:.2f} Å (Δ: {best['Distance_From_Target']:+.2f} Å)")

rule = "=" * 70
print("\n" + rule)
print("✓ Cell 5 Complete")
print(rule)
print(f"\nData saved to: antigen_df ({len(antigen_df)} rows)")
print("→ Run Cell 6 for Final Summary")
print(rule)


In [None]:
# ================================================================================
# CELL 6: FINAL SUMMARY - ALL METRICS TABLE
# ================================================================================

#@title **Cell 6: Final Summary Table** { display-mode: "form" }
#@markdown Combines all analysis results into one comprehensive table.

print("=" * 70)
print("CELL 6: Final Summary - All Metrics")
print("=" * 70)

# =============================================================================
# MERGE ALL DATAFRAMES
# =============================================================================


def _to_jobname(filename):
    """Return the merge key for a result row: the filename without '.pdb'."""
    # str.replace (not removesuffix) to match how Cell 5 builds its Filename column.
    return filename.replace('.pdb', '')


def _merge_optional(base_df, var_name, wanted_cols, label):
    """Left-merge the notebook-level DataFrame named `var_name` into `base_df`.

    The merge key is 'Jobname', derived from the source frame's 'Filename'
    column. Only the `wanted_cols` that actually exist in the source frame are
    merged, so a partially-populated result table does not break the merge.

    Returns `base_df` unchanged when the source frame is missing or empty
    (its cell was not run), or when `base_df` has no 'Jobname' column (the
    RMSD base table could not be built). Otherwise returns the merged frame.
    Errors such as a missing 'Filename' column propagate to the caller.
    """
    source_df = globals().get(var_name)
    if source_df is None or len(source_df) == 0:
        return base_df
    if 'Jobname' not in base_df.columns:
        print(f"  ⚠️ {label} merge skipped: no RMSD base table")
        return base_df

    merge_df = source_df.copy()
    merge_df['Jobname'] = merge_df['Filename'].apply(_to_jobname)
    cols = ['Jobname'] + [c for c in wanted_cols if c in merge_df.columns]
    merged = base_df.merge(merge_df[cols], on='Jobname', how='left')
    print(f"  ✓ {label}: merged")
    return merged


print("\n--- Collecting Results ---")

# Start with RMSD results (the base table every other metric is joined onto)
final_df = pd.DataFrame()

try:
    final_df = rmsd_df[['Filename', 'Final_RMSD', 'scFv1_RMSD', 'scFv2_RMSD',
                        'Linker_Length', 'Best_Alignment', 'PDB_Path']].copy()
    final_df['Jobname'] = final_df['Filename'].apply(_to_jobname)
    print(f"  ✓ RMSD: {len(final_df)} entries")
except Exception as e:
    print(f"  ✗ RMSD data error: {e}")

# Merge NMA results
try:
    final_df = _merge_optional(
        final_df, 'nma_df',
        ['Effective_Stiffness', 'Inter_Domain_Fluctuation',
         'Initial_COM_Distance', 'MSF_Linker_Mean', 'Deformability_Linker_Mean'],
        'NMA',
    )
except Exception as e:
    print(f"  ⚠️ NMA merge error: {e}")

# Merge pLDDT/PAE results
try:
    final_df = _merge_optional(
        final_df, 'plddt_df',
        ['pTM_Score', 'pLDDT_Global_Mean', 'pLDDT_scFv1_Mean',
         'pLDDT_scFv2_Mean', 'pLDDT_Linker_Mean', 'pLDDT_Linker_Min',
         'PAE_Inter_Domain', 'PAE_Linker_to_scFv1', 'PAE_Linker_to_scFv2',
         'PAE_Linker_Stability'],
        'pLDDT/PAE',
    )
except Exception as e:
    print(f"  ⚠️ pLDDT/PAE merge error: {e}")

# Merge Antigen Distance results
try:
    final_df = _merge_optional(
        final_df, 'antigen_df',
        ['Antigen_Distance', 'Distance_From_Target'],
        'Antigen Distance',
    )
except Exception as e:
    print(f"  ⚠️ Antigen distance merge error: {e}")

# =============================================================================
# ORGANIZE COLUMNS
# =============================================================================

# Display order for the summary table, grouped by metric family.
# Dict insertion order is preserved, so flattening keeps this exact order.
column_groups = {
    'Identification': ['Jobname'],
    'RMSD': ['Final_RMSD', 'scFv1_RMSD', 'scFv2_RMSD', 'Best_Alignment', 'Linker_Length'],
    # Target value comes from TARGET_DISTANCE in Cell 5
    'Antigen Distance': ['Antigen_Distance', 'Distance_From_Target'],
    'pLDDT': ['pTM_Score', 'pLDDT_Global_Mean', 'pLDDT_scFv1_Mean',
              'pLDDT_scFv2_Mean', 'pLDDT_Linker_Mean', 'pLDDT_Linker_Min'],
    'PAE': ['PAE_Inter_Domain', 'PAE_Linker_to_scFv1', 'PAE_Linker_to_scFv2',
            'PAE_Linker_Stability'],
    'NMA': ['Effective_Stiffness', 'Inter_Domain_Fluctuation', 'Initial_COM_Distance',
            'MSF_Linker_Mean', 'Deformability_Linker_Mean'],
}
column_order = [name for group in column_groups.values() for name in group]

# Keep only columns that were actually produced (skipped cells leave gaps);
# this also drops helper columns such as Filename and PDB_Path.
available_cols = [name for name in column_order if name in final_df.columns]
final_df = final_df[available_cols]

# =============================================================================
# DISPLAY FULL TABLE
# =============================================================================

print("\n" + "=" * 70)
print("ALL METRICS TABLE")
print("=" * 70)

# Lift pandas' display limits and print floats with 4 decimals.
# These are session-wide options, so they stay in effect after this cell.
display_options = {
    'display.max_columns': None,
    'display.width': None,
    'display.max_colwidth': None,
    'display.float_format': lambda x: f'{x:.4f}' if pd.notna(x) else 'NaN',
}
for option_name, option_value in display_options.items():
    pd.set_option(option_name, option_value)

n_rows, n_cols = final_df.shape
print(f"\nTotal entries: {n_rows}")
print(f"Total columns: {n_cols}")
print("\n")

# Print the whole table as plain text, without the index column.
print(final_df.to_string(index=False))

# =============================================================================
# COLUMN DESCRIPTIONS
# =============================================================================

print("\n" + "=" * 70)
print("COLUMN DESCRIPTIONS")
print("=" * 70)

# Label the antigen-distance target from Cell 5's TARGET_DISTANCE parameter
# (falling back to 130.0 if Cell 5 was not run) instead of hard-coding it.
# ':g' renders 130.0 as "130", so the default text reads "130Å" as before.
_target_label = f"{globals().get('TARGET_DISTANCE', 130.0):g}Å"

descriptions = {
    'Jobname': 'Sample name',
    'Final_RMSD': 'RMSD vs Reference (Å) - lower is more similar',
    'scFv1_RMSD': 'scFv1 domain RMSD (Å)',
    'scFv2_RMSD': 'scFv2 domain RMSD (Å)',
    'Best_Alignment': 'Which domain was used for alignment',
    'Linker_Length': 'Linker length (residues)',
    'Antigen_Distance': f'CD3e-HER2 membrane distance (Å) - target: {_target_label}',
    'Distance_From_Target': f'Difference from {_target_label} target (Å)',
    'pTM_Score': 'Predicted TM score (0-1) - higher is better',
    'pLDDT_Global_Mean': 'Global pLDDT (0-100) - higher is better',
    'pLDDT_scFv1_Mean': 'scFv1 pLDDT mean',
    'pLDDT_scFv2_Mean': 'scFv2 pLDDT mean',
    'pLDDT_Linker_Mean': 'Linker pLDDT mean - higher is more confident',
    'pLDDT_Linker_Min': 'Linker pLDDT minimum',
    'PAE_Inter_Domain': 'Inter-domain PAE (Å) - lower is better',
    'PAE_Linker_to_scFv1': 'Linker to scFv1 PAE (Å)',
    'PAE_Linker_to_scFv2': 'Linker to scFv2 PAE (Å)',
    'PAE_Linker_Stability': 'Average linker attachment PAE - lower is better',
    'Effective_Stiffness': 'GNM stiffness (eigenvalue mean) - higher is more rigid',
    'Inter_Domain_Fluctuation': 'Domain distance fluctuation (Å) - lower is more rigid',
    'Initial_COM_Distance': 'Initial scFv1-scFv2 COM distance (Å)',
    'MSF_Linker_Mean': 'Linker mean square fluctuation - lower is more rigid',
    'Deformability_Linker_Mean': 'Linker deformability (0-1) - lower is more rigid',
}

# Describe only the columns that made it into the final table.
for col in available_cols:
    if col in descriptions:
        print(f"  • {col}: {descriptions[col]}")

# =============================================================================
# SAVE TO CSV (best-effort)
# =============================================================================

csv_path = '/content/final_analysis_results.csv'
try:
    final_df.to_csv(csv_path, index=False)
    print(f"\n✓ Saved to: {csv_path}")
except OSError as e:
    # e.g. the /content directory does not exist when running outside Colab.
    # Report and continue so the summary statistics below are still printed.
    print(f"\n⚠️ Could not save CSV to {csv_path}: {e}")

# =============================================================================
# SUMMARY STATISTICS
# =============================================================================

rule = "=" * 70
print("\n" + rule)
print("SUMMARY STATISTICS")
print(rule)

# Only numeric columns are summarized; text columns such as Jobname are excluded.
numeric_cols = list(final_df.select_dtypes(include=[np.number]).columns)

for col in numeric_cols:
    # Guard clause: a column that is NaN for every entry has nothing to report.
    if not final_df[col].notna().any():
        continue
    vals = final_df[col].dropna()
    print(f"\n  {col}:")
    print(f"    Mean: {vals.mean():.4f}")
    print(f"    Std:  {vals.std():.4f}")
    print(f"    Min:  {vals.min():.4f}")
    print(f"    Max:  {vals.max():.4f}")

print("\n" + rule)
print("✓ ANALYSIS COMPLETE")
print(rule)
print("\nDataFrame 'final_df' contains all metrics.")
print("Use final_df.to_clipboard() to copy to Excel.")
print(rule)
