<a href="https://colab.research.google.com/github/JKourelis/Colab_Boltz-2/blob/main/Boltz_2_csv.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<img src="https://raw.githubusercontent.com/jwohlwend/boltz/main/docs/boltz2_title.png" height="200" align="right" style="height:240px">

## Boltz-2: Democratizing Biomolecular Interaction Modeling

Easy-to-use protein structure and binding affinity prediction with [Boltz-2](https://doi.org/10.1101/2025.06.14.659707). Boltz-2 is a biomolecular foundation model that jointly models complex structures and binding affinities. Its structure predictions approach [AlphaFold3](https://www.nature.com/articles/s41586-024-07487-w) accuracy, and its affinity predictions approach the accuracy of physics-based free-energy perturbation (FEP) methods while running roughly 1000x faster.

**Key Features:**
- **Structure Prediction**: Protein, DNA, RNA, and ligand complexes with accuracy approaching AlphaFold3
- **Binding Affinity**: First deep learning model to approach FEP-level accuracy for binding affinity prediction
- **Open Source**: MIT license for academic and commercial use
- **Fast**: Affinity predictions roughly 1000x faster than physics-based FEP methods

**Usage Options:**
1. **Manual Input**: Enter sequences directly in the configuration boxes below
2. **FASTA Upload**: Upload FASTA files for batch processing

**Repository:**
- [Boltz-2 Colab Repository](https://github.com/JKourelis/Colab_Boltz-2)

**Citations:**

[Wohlwend J, Corso G, Passaro S, et al. Boltz-1: Democratizing Biomolecular Interaction Modeling. *bioRxiv*, 2024](https://doi.org/10.1101/2024.11.19.624167)

[Passaro S, Corso G, Wohlwend J, et al. Boltz-2: Towards Accurate and Efficient Binding Affinity Prediction. *bioRxiv*, 2025](https://doi.org/10.1101/2025.06.14.659707)

If using automatic MSA generation: [Mirdita M, Schütze K, Moriwaki Y, et al. ColabFold: making protein folding accessible to all. *Nature Methods*, 2022](https://doi.org/10.1038/s41592-022-01488-1)
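
Cell 1 picks a PyTorch wheel index from the CUDA version that `nvidia-smi` reports. That number is the highest CUDA version the installed driver supports. The same table appears twice in the cell, once to show the plan before the restart and once to drive the install after it. The mapping can be written as a single pure function, which also makes it testable without a GPU. The sketch below makes a few assumptions: `parse_driver_cuda` and `select_pytorch_build` are hypothetical names, and the table restates the one in Cell 1 rather than any official PyTorch compatibility matrix.

```python
import re
from typing import Optional, Tuple

# Hypothetical table restating Cell 1: (highest CUDA 12 minor covered, wheel tag, torch version)
_CUDA12_WHEELS = [
    (1, "cu121", "2.5.1"),  # CUDA 12.0-12.1
    (2, "cu121", "2.5.1"),  # CUDA 12.2 (relies on minor-version compatibility)
    (5, "cu124", "2.6.0"),  # CUDA 12.3-12.5
    (7, "cu126", "2.6.0"),  # CUDA 12.6-12.7
]


def parse_driver_cuda(nvidia_smi_output: str) -> Optional[Tuple[int, int]]:
    """Extract (major, minor) from the 'CUDA Version: X.Y' field of nvidia-smi output."""
    match = re.search(r"CUDA Version: (\d+)\.(\d+)", nvidia_smi_output)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


def select_pytorch_build(major: int, minor: int) -> Tuple[str, str]:
    """Return (wheel tag, torch version) for a driver CUDA version."""
    if major == 11:
        return "cu118", "2.6.0"
    if major == 12:
        for max_minor, tag, torch_version in _CUDA12_WHEELS:
            if minor <= max_minor:
                return tag, torch_version
        return "cu128", "2.7.0"  # CUDA 12.8 and newer
    raise ValueError(f"Unsupported CUDA major version: {major}")


# Illustrative excerpt of an nvidia-smi header line
sample = "Driver Version: 550.54.15    CUDA Version: 12.4"
print(select_pytorch_build(*parse_driver_cuda(sample)))  # ('cu124', '2.6.0')
```

Keeping the table in one place means the pre-restart plan and the post-restart install cannot drift apart.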

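After installation, Cell 1 removes cached NumPy and pandas modules from `sys.modules` so that the freshly installed versions are imported again. `sys.modules` is keyed by module name, not by import alias. Matching on aliases such as `np` or `pd`, or on bare string prefixes, can therefore catch unrelated modules: `pdb` starts with `pd`, and `numpydoc` starts with `numpy`. Below is a minimal sketch of a stricter filter. `modules_to_purge` is a hypothetical helper and not part of the notebook.

```python
from typing import Iterable, List


def modules_to_purge(module_names: Iterable[str], packages: Iterable[str]) -> List[str]:
    """Return module names that are one of `packages` or a submodule of one.

    "numpy" and "numpy.linalg" match the package "numpy"; "numpydoc" does not.
    """
    packages = tuple(packages)
    dotted_prefixes = tuple(f"{pkg}." for pkg in packages)
    return [
        name for name in module_names
        if name in packages or name.startswith(dotted_prefixes)
    ]


cached = ["numpy", "numpy.linalg", "numpydoc", "pandas.core.frame", "pdb", "npyscreen"]
print(modules_to_purge(cached, ["numpy", "pandas"]))
# ['numpy', 'numpy.linalg', 'pandas.core.frame']
```

In the notebook, the input would be `list(sys.modules)`, and each returned name would then be deleted from `sys.modules`.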
In [1]:
#@title Cell 1: Install Boltz-2 with cuEquivariance Kernel Test
%%time
import subprocess
import sys
import os
import re

# Restart marker to handle Colab Feb 2025 NumPy issue
restart_marker = "/content/.boltz_numpy_restart"
is_post_restart = os.path.exists(restart_marker)

def run_cmd(cmd, desc):
    """Execute command with output suppression unless error"""
    print(f"[{desc}]")
    result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
    if result.returncode != 0:
        print(f"FAILED: {result.stderr[:300]}")
        return False
    print("OK")
    return True

def get_cuda_version():
    """Detect the CUDA version reported by nvidia-smi (the highest version the driver supports)"""
    try:
        result = subprocess.run(['nvidia-smi'], capture_output=True, text=True)
        if result.returncode == 0:
            match = re.search(r'CUDA Version: (\d+\.\d+)', result.stdout)
            if match:
                version = match.group(1)
                major = int(version.split('.')[0])
                minor = int(version.split('.')[1])
                return major, minor, version
    except Exception as e:
        print(f"⚠️  Could not detect CUDA version: {e}")
    return None, None, None

def test_cuequivariance_kernels():
    """Test if cuEquivariance triangle kernels are available"""
    print("\n" + "=" * 60)
    print("CUEQUIVARIANCE KERNEL PREFLIGHT TEST")
    print("=" * 60)

    try:
        import torch
        print(f"✅ PyTorch: {torch.__version__}")
        print(f"✅ CUDA available: {torch.cuda.is_available()}")
        if torch.cuda.is_available():
            print(f"✅ CUDA version: {torch.version.cuda}")
            print(f"✅ GPU: {torch.cuda.get_device_name(0)}")
    except Exception as e:
        print(f"❌ PyTorch check failed: {e}")
        return False

    # Test cuequivariance-torch import
    try:
        import cuequivariance_torch
        print(f"✅ cuequivariance-torch installed")
    except ImportError as e:
        print(f"⚠️  cuequivariance-torch not found: {e}")
        return False

    # Test cuequivariance-ops-torch-cu12 import
    try:
        import cuequivariance_ops_torch
        print(f"✅ cuequivariance-ops-torch-cu12 installed")
    except ImportError as e:
        print(f"⚠️  cuequivariance-ops-torch-cu12 not found: {e}")
        return False

    # CRITICAL TEST: triangle_multiplicative_update
    try:
        from cuequivariance_ops_torch.triangle import triangle_multiplicative_update
        print(f"✅ triangle_multiplicative_update import: SUCCESS")

        if callable(triangle_multiplicative_update):
            print(f"✅ triangle_multiplicative_update is callable")
        else:
            print(f"❌ triangle_multiplicative_update exists but is not callable")
            return False

    except ImportError as e:
        print(f"❌ triangle_multiplicative_update import FAILED: {e}")
        print(f"   This error requires --no_kernels flag")
        return False
    except Exception as e:
        print(f"❌ Unexpected error testing triangle kernels: {e}")
        return False

    # Test triangle_attention_update
    try:
        from cuequivariance_ops_torch.triangle import triangle_attention_update
        print(f"✅ triangle_attention_update import: SUCCESS")
    except ImportError as e:
        print(f"⚠️  triangle_attention_update import failed: {e}")
        return False

    print("=" * 60)
    return True

if not is_post_restart:
    # PRE-RESTART: PREFLIGHT CHECKS
    print("=" * 60)
    print("PREFLIGHT CHECKS")
    print("=" * 60)

    # Check GPU
    cuda_major, cuda_minor, cuda_version = get_cuda_version()
    if cuda_major is None:
        print("❌ No CUDA detected - cannot proceed")
        sys.exit(1)

    print(f"✅ CUDA Version: {cuda_version}")
    print(f"   Driver CUDA: {cuda_major}.{cuda_minor}")

    # Check GPU type
    result = subprocess.run(['nvidia-smi', '--query-gpu=name', '--format=csv,noheader'],
                          capture_output=True, text=True)
    if result.returncode == 0:
        gpu_name = result.stdout.strip()
        print(f"✅ GPU: {gpu_name}")

    # Map CUDA version to PyTorch (SHOW PLAN BEFORE RESTART)
    if cuda_major == 12:
        if cuda_minor <= 1:
            pytorch_cuda = "cu121"
            pytorch_version = "2.5.1"
            reason = "cu121 compatible with CUDA 12.0-12.1 (PyTorch 2.5.1 max)"
        elif cuda_minor <= 2:
            pytorch_cuda = "cu121"
            pytorch_version = "2.5.1"
            reason = "cu121 forward compatible with CUDA 12.2 (PyTorch 2.5.1)"
        elif cuda_minor <= 5:
            pytorch_cuda = "cu124"
            pytorch_version = "2.6.0"
            reason = "cu124 compatible with CUDA 12.3-12.5 (PyTorch 2.6.0)"
        elif cuda_minor <= 7:
            pytorch_cuda = "cu126"
            pytorch_version = "2.6.0"
            reason = "cu126 compatible with CUDA 12.6-12.7 (PyTorch 2.6.0)"
        else:
            pytorch_cuda = "cu128"
            pytorch_version = "2.7.0"
            reason = "cu128 for CUDA 12.8+ (PyTorch 2.7.0)"
    elif cuda_major == 11:
        pytorch_cuda = "cu118"
        pytorch_version = "2.6.0"
        reason = "cu118 for CUDA 11.x (PyTorch 2.6.0)"
    else:
        print(f"❌ Unsupported CUDA major version: {cuda_major}")
        sys.exit(1)

    print(f"\n📦 Selected PyTorch configuration:")
    print(f"   PyTorch: {pytorch_version}")
    print(f"   CUDA wheels: {pytorch_cuda}")
    print(f"   Rationale: {reason}")

    # Check pre-loaded NumPy
    result = subprocess.run([sys.executable, "-c", "import numpy; print(numpy.__version__)"],
                          capture_output=True, text=True)
    if result.returncode == 0:
        numpy_ver = result.stdout.strip()
        print(f"\n⚠️  Pre-loaded NumPy: {numpy_ver}")
        if numpy_ver.startswith('2.'):
            print("   → Colab Feb 2025 issue detected")
            print("   → NumPy 2.x must be cleared from memory")

    print("\n" + "=" * 60)
    print("RESTARTING RUNTIME TO CLEAR NUMPY 2.X")
    print("=" * 60)
    print("Runtime will restart in 2 seconds")
    print("After restart: Run this cell again to install")
    print("=" * 60)

    with open(restart_marker, "w") as f:
        f.write("restarted")

    import time
    time.sleep(2)
    os.kill(os.getpid(), 9)

else:
    # POST-RESTART: INSTALLATION
    print("=" * 60)
    print("INSTALLING BOLTZ-2")
    print("=" * 60)

    # Detect CUDA version
    cuda_major, cuda_minor, cuda_version = get_cuda_version()

    if cuda_major is None:
        print("❌ No CUDA detected - cannot proceed")
        sys.exit(1)

    print(f"✅ CUDA Version: {cuda_version}")
    print(f"   Driver CUDA: {cuda_major}.{cuda_minor}")

    # Map CUDA version to PyTorch wheel
    if cuda_major == 12:
        if cuda_minor <= 1:
            pytorch_cuda = "cu121"
            pytorch_version = "2.5.1"
            reason = "cu121 compatible with CUDA 12.0-12.1"
        elif cuda_minor <= 2:
            pytorch_cuda = "cu121"
            pytorch_version = "2.5.1"
            reason = "cu121 forward compatible with CUDA 12.2"
        elif cuda_minor <= 5:
            pytorch_cuda = "cu124"
            pytorch_version = "2.6.0"
            reason = "cu124 compatible with CUDA 12.3-12.5"
        elif cuda_minor <= 7:
            pytorch_cuda = "cu126"
            pytorch_version = "2.6.0"
            reason = "cu126 compatible with CUDA 12.6-12.7"
        else:
            pytorch_cuda = "cu128"
            pytorch_version = "2.7.0"
            reason = "cu128 for CUDA 12.8+"
    elif cuda_major == 11:
        pytorch_cuda = "cu118"
        pytorch_version = "2.6.0"
        reason = "cu118 for CUDA 11.x"
    else:
        print(f"❌ Unsupported CUDA major version: {cuda_major}")
        sys.exit(1)

    print(f"\n📦 Selected PyTorch configuration:")
    print(f"   PyTorch: {pytorch_version}")
    print(f"   CUDA wheels: {pytorch_cuda}")
    print(f"   Rationale: {reason}")

    # Clean existing installations
    print("\n" + "=" * 60)
    print("[Cleanup]")
    cleanup_packages = [
        'torch', 'torchvision', 'torchaudio',
        'pytorch-lightning', 'torchmetrics',
        'boltz', 'numpy', 'pandas'
    ]
    subprocess.run(
        f"{sys.executable} -m pip uninstall {' '.join(cleanup_packages)} -y",
        shell=True, capture_output=True
    )
    print("OK")

    # Install NumPy FIRST
    print("\n" + "=" * 60)
    print("[1/4] NumPy 1.26.4 - FIRST (Colab compatibility)")
    if not run_cmd(
        f"{sys.executable} -m pip install --no-cache-dir 'numpy==1.26.4'",
        "Installing numpy==1.26.4"
    ):
        print("❌ NumPy installation failed")
        sys.exit(1)

    # Verify NumPy
    result = subprocess.run(
        [sys.executable, "-c", "import numpy; print(numpy.__version__)"],
        capture_output=True, text=True
    )
    if '1.26' not in result.stdout:
        print(f"❌ NumPy version wrong: {result.stdout.strip()}")
        sys.exit(1)
    print(f"   ✅ Verified: NumPy {result.stdout.strip()}")

    # Install PyTorch
    print("\n" + "=" * 60)
    print(f"[2/4] PyTorch {pytorch_version} ({pytorch_cuda})")
    pytorch_url = f"https://download.pytorch.org/whl/{pytorch_cuda}"
    if not run_cmd(
        f"{sys.executable} -m pip install torch=={pytorch_version} torchvision torchaudio --index-url {pytorch_url}",
        f"Installing PyTorch {pytorch_version}"
    ):
        print("❌ PyTorch installation failed")
        sys.exit(1)

    # Install pytorch-lightning
    print("\n" + "=" * 60)
    print("[3/4] Lightning stack")
    if not run_cmd(
        f"{sys.executable} -m pip install pytorch-lightning==2.5.0 torchmetrics==1.4.0",
        "Installing Lightning"
    ):
        print("❌ Lightning installation failed")
        sys.exit(1)

    # Install Boltz
    print("\n" + "=" * 60)
    print("[4/4] Boltz-2 (+ dependencies)")
    print("   Note: cuequivariance packages managed by Boltz")
    if not run_cmd(
        f"{sys.executable} -m pip install boltz",
        "Installing Boltz-2"
    ):
        print("❌ Boltz installation failed")
        sys.exit(1)

    # PERMANENT FIX: Create sitecustomize.py
    print("\n" + "=" * 60)
    print("INSTALLING PERMANENT SYS.PATH FIX")
    print("=" * 60)

    sitecustomize_content = """# Colab Feb 2025 NumPy priority fix
import sys
import os

if '/env/python' in sys.path:
    sys.path.remove('/env/python')

if 'PYTHONPATH' in os.environ:
    del os.environ['PYTHONPATH']
"""

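    # Derive the dist-packages directory from the running interpreter
    # (resolves to .../python3.12/dist-packages on the current Colab image)
    py_ver = f"python{sys.version_info.major}.{sys.version_info.minor}"
    sitecustomize_path = f"/usr/local/lib/{py_ver}/dist-packages/sitecustomize.py"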
    with open(sitecustomize_path, "w") as f:
        f.write(sitecustomize_content)

    print(f"   ✅ Created {sitecustomize_path}")
    print("   This fixes sys.path on EVERY future kernel start")

    # IPython startup script
    ipython_startup_dir = "/root/.ipython/profile_default/startup"
    os.makedirs(ipython_startup_dir, exist_ok=True)

    ipython_fix_path = os.path.join(ipython_startup_dir, "00-fix_syspath.py")
    with open(ipython_fix_path, "w") as f:
        f.write(sitecustomize_content)

    print(f"   ✅ Created {ipython_fix_path}")
    print("   Backup fix for IPython kernels")

    # APPLY FIX TO CURRENT KERNEL
    print("\n" + "=" * 60)
    print("APPLYING FIX TO CURRENT KERNEL")
    print("=" * 60)

    if '/env/python' in sys.path:
        sys.path.remove('/env/python')
        print("   ✅ Removed /env/python from sys.path")
    else:
        print("   ✅ /env/python not in sys.path")

    if 'PYTHONPATH' in os.environ:
        del os.environ['PYTHONPATH']
        print("   ✅ Cleared PYTHONPATH environment variable")

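    # Clear cached imports: match module names (numpy, numpy.*), not the np/pd aliases,
    # which would also catch unrelated modules such as pdb
    modules_to_clear = [key for key in list(sys.modules)
                        if key in ('numpy', 'pandas') or key.startswith(('numpy.', 'pandas.'))]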
    for mod in modules_to_clear:
        del sys.modules[mod]

    if modules_to_clear:
        print(f"   ✅ Cleared {len(modules_to_clear)} cached modules")

    # VERIFICATION IN CURRENT KERNEL
    print("\n" + "=" * 60)
    print("VERIFICATION")
    print("=" * 60)

    print("\n[Testing imports in current kernel]")
    try:
        import numpy as np
        import pandas as pd

        print(f"   ✅ NumPy {np.__version__}")
        print(f"   ✅ Pandas {pd.__version__}")

        if not np.__version__.startswith('1.26'):
            print(f"   ❌ NumPy version wrong: expected 1.26.x, got {np.__version__}")
            sys.exit(1)

        print("\n   🎉 All imports working!")

    except Exception as e:
        print(f"   ❌ Import failed: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)

    # CUEQUIVARIANCE KERNEL TEST
    kernels_available = test_cuequivariance_kernels()

    if kernels_available:
        print("\n✅ KERNEL TEST PASSED")
        print("   cuEquivariance kernels are available")
        print("   Will run Boltz-2 WITHOUT --no_kernels flag")
        use_no_kernels = False
    else:
        print("\n❌ KERNEL TEST FAILED")
        print("   cuEquivariance kernels are NOT available")
        print("   Will run Boltz-2 WITH --no_kernels flag")
        print("   Performance penalty: ~12 seconds per prediction")
        use_no_kernels = True

    # Store result for execution cell
    if 'global_settings' not in globals():
        global_settings = {}
    global_settings['use_no_kernels'] = use_no_kernels
    global_settings['kernels_tested'] = True

    print(f"\n🔧 Flag stored: use_no_kernels = {use_no_kernels}")

    # Show installed versions
    print("\n" + "=" * 60)
    print("INSTALLED PACKAGE VERSIONS")
    print("=" * 60)

    result = subprocess.run(
        [sys.executable, "-m", "pip", "list", "--format=freeze"],
        capture_output=True, text=True
    )

    all_packages = result.stdout.strip().split('\n')
    relevant = [
        'numpy', 'pandas', 'scipy',
        'torch', 'torchvision', 'torchaudio',
        'pytorch-lightning', 'torchmetrics',
        'boltz', 'cuequivariance-torch',
        'cuequivariance-ops-torch-cu11',
        'cuequivariance-ops-torch-cu12'
    ]

    print("\n📋 Core packages:")
    for pkg in relevant:
        for line in all_packages:
            if line.lower().startswith(pkg.lower() + '=='):
                print(f"   {line}")
                break
        else:
            for line in all_packages:
                if pkg.lower().replace('-', '_') in line.lower():
                    print(f"   {line}")
                    break

    # Save complete requirements.txt
    print("\n📄 Saving complete requirements.txt...")
    with open("/content/requirements_boltz.txt", "w") as f:
        f.write(f"# Boltz-2 Installation - CUDA {cuda_version}\n")
        f.write(f"# PyTorch {pytorch_version} ({pytorch_cuda})\n\n")
        f.write(result.stdout)
    print("   ✅ Saved to: /content/requirements_boltz.txt")

    # Cleanup and mark ready
    os.remove(restart_marker)
    with open("/content/BOLTZ_READY", "w") as f:
        f.write("Ready")

    print("\n" + "=" * 60)
    print("✅ BOLTZ-2 INSTALLATION COMPLETE")
    print("=" * 60)
    print("Next: Run Cell 2 to set up CSV processor")

INSTALLING BOLTZ-2
✅ CUDA Version: 12.4
   Driver CUDA: 12.4

📦 Selected PyTorch configuration:
   PyTorch: 2.6.0
   CUDA wheels: cu124
   Rationale: cu124 compatible with CUDA 12.3-12.5

[Cleanup]
OK

[1/4] NumPy 1.26.4 - FIRST (Colab compatibility)
[Installing numpy==1.26.4]
OK
   ✅ Verified: NumPy 1.26.4

[2/4] PyTorch 2.6.0 (cu124)
[Installing PyTorch 2.6.0]
OK

[3/4] Lightning stack
[Installing Lightning]
OK

[4/4] Boltz-2 (+ dependencies)
   Note: cuequivariance packages managed by Boltz
[Installing Boltz-2]
OK

INSTALLING PERMANENT SYS.PATH FIX
   ✅ Created /usr/local/lib/python3.12/dist-packages/sitecustomize.py
   This fixes sys.path on EVERY future kernel start
   ✅ Created /root/.ipython/profile_default/startup/00-fix_syspath.py
   Backup fix for IPython kernels

APPLYING FIX TO CURRENT KERNEL
   ✅ Removed /env/python from sys.path
   ✅ Cleared PYTHONPATH environment variable
   ✅ Cleared 86 cached modules

VERIFICATION

[Testing imports in current kernel]
   ✅ NumPy 1.26.4


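Cell 2 renames user-chosen chain names to Boltz chain IDs (A, B, C, ...) inside modification, pocket-contact, and covalent-bond strings. Renaming with one `str.replace` call per mapping entry is order-dependent. When user names overlap the target IDs, a chain can be renamed twice, and a short name can also match the tail of a longer one (for example, `b:` inside `ab:`).

The sketch below is a minimal field-based, single-pass alternative. `remap_chain_fields` and `remap_with_replace` are hypothetical helpers. The chain field positions come from the formats documented in Cell 2's parsers: field 0 for `CHAIN:POSITION:CCD_CODE` and `CHAIN:RESIDUE`, and fields 0 and 3 for `CHAIN:RES:ATOM:CHAIN:RES:ATOM`.

```python
from typing import Dict, Sequence


def remap_chain_fields(spec: str, mapping: Dict[str, str], chain_fields: Sequence[int]) -> str:
    """Rename chain IDs in a ';'-separated spec such as 'heavy:12:SEP;light:5:SEP'.

    Only the ':'-separated fields listed in `chain_fields` are renamed, and each
    is looked up exactly once, so the result does not depend on mapping order.
    """
    entries = []
    for entry in spec.split(";"):
        parts = entry.split(":")
        for idx in chain_fields:
            if idx < len(parts):
                name = parts[idx].strip()
                parts[idx] = mapping.get(name, name)
        entries.append(":".join(parts))
    return ";".join(entries)


def remap_with_replace(spec: str, mapping: Dict[str, str]) -> str:
    """Sequential str.replace, shown only for comparison."""
    for old, new in mapping.items():
        spec = spec.replace(f"{old}:", f"{new}:")
    return spec


# User chains named "B" and "A" (in that order) should become A and B
mapping = {"B": "A", "A": "B"}
print(remap_with_replace("B:5:SEP", mapping))                 # B:5:SEP (renamed twice)
print(remap_chain_fields("B:5:SEP", mapping, (0,)))           # A:5:SEP
print(remap_chain_fields("A:12:SG:B:1:C1", mapping, (0, 3)))  # B:12:SG:A:1:C1
```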
In [2]:
#@title Cell 2: Boltz CSV Processor Setup
import pandas as pd
import os
import re
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Any
from io import StringIO

class BoltzJobProcessor:
    """Process CSV files for Boltz-2 batch predictions"""

    EMBEDDED_REFERENCE = """Type,CCD Code,Name,Target Residue,Molecular Formula,All Atom Count,Heavy Atom Count
PTM,SEP,Phosphoserine,SER,C3H8NO6P,18,11
PTM,TPO,Phosphothreonine,THR,C4H10NO6P,21,12
PTM,PTR,Phosphotyrosine,TYR,C9H12NO6P,28,17
PTM,MLY,N-Methyllysine,LYS,C7H16N2O2,27,11
PTM,ALY,N-Acetyllysine,LYS,C8H16N2O3,29,13
PTM,HYP,Hydroxyproline,PRO,C5H9NO3,18,9
PTM,M3L,N-Trimethyllysine,LYS,C9H20N2O2,33,13
PTM,MLZ,N-Methyllysine,LYS,C7H16N2O2,27,11
PTM,CSD,Cysteine sulfinic acid,CYS,C3H7NO4S,16,9
PTM,CSO,S-Hydroxycysteine,CYS,C3H7NO3S,15,8
PTM,CGU,Gamma-carboxyglutamic acid,GLU,C5H7NO6,21,12
PTM,FME,N-Formylmethionine,MET,C6H11NO3S,22,11
PTM,NEP,N-(phosphonoethyl)isoleucine,ILE,C8H18NO5P,32,15
PTM,HIC,4-Methyl-histidine,HIS,C7H11N3O2,23,12
PTM,CAS,S-(dimethylarsenic)cysteine,CYS,C5H11AsNO2S,20,9
Ligand,ATP,Adenosine triphosphate,NA,C10H16N5O13P3,47,31
Ligand,ADP,Adenosine diphosphate,NA,C10H15N5O10P2,42,27
Ligand,AMP,Adenosine monophosphate,NA,C10H14N5O7P,37,23
Ligand,GTP,Guanosine triphosphate,NA,C10H16N5O14P3,48,32
Ligand,GDP,Guanosine diphosphate,NA,C10H15N5O11P2,43,28
Ligand,GMP,Guanosine monophosphate,NA,C10H14N5O8P,38,24
Ligand,CTP,Cytidine triphosphate,NA,C9H16N3O14P3,45,29
Ligand,CDP,Cytidine diphosphate,NA,C9H15N3O11P2,40,25
Ligand,UTP,Uridine triphosphate,NA,C9H15N2O15P3,44,29
Ligand,UDP,Uridine diphosphate,NA,C9H14N2O12P2,39,25
Ligand,NAD,Nicotinamide adenine dinucleotide,NA,C21H27N7O14P2,71,44
Ligand,NAP,NADP,NA,C21H28N7O17P3,86,55
Ligand,FAD,Flavin adenine dinucleotide,NA,C27H33N9O15P2,91,53
Ligand,FMN,Flavin mononucleotide,NA,C17H21N4O9P,52,31
Ligand,HEM,Heme,NA,C34H32FeN4O4,75,43
Ligand,SAH,S-Adenosyl-L-homocysteine,NA,C14H20N6O5S,46,26
Ligand,SAM,S-Adenosyl-L-methionine,NA,C15H22N6O5S,49,27
Ligand,COA,Coenzyme A,NA,C21H36N7O16P3S,90,57
Ligand,ACO,Acetyl coenzyme A,NA,C23H38N7O17P3S,99,61
Ligand,PLP,Pyridoxal-5-phosphate,NA,C8H10NO6P,25,16
Ligand,TPP,Thiamine diphosphate,NA,C12H19N4O7P2S,45,25
Ligand,BTN,Biotin,NA,C10H16N2O3S,32,16
Ligand,MTA,5-Methylthioadenosine,NA,C11H15N5O3S,35,20
Ligand,THM,Thiamine,NA,C12H17ClN4OS,38,18
Ion,MG,Magnesium ion,NA,Mg,1,1
Ion,ZN,Zinc ion,NA,Zn,1,1
Ion,CA,Calcium ion,NA,Ca,1,1
Ion,FE,Iron ion,NA,Fe,1,1
Ion,MN,Manganese ion,NA,Mn,1,1
Ion,CU,Copper ion,NA,Cu,1,1
Ion,CO,Cobalt ion,NA,Co,1,1
Ion,NI,Nickel ion,NA,Ni,1,1
Ion,K,Potassium ion,NA,K,1,1
Ion,NA,Sodium ion,NA,Na,1,1
Ion,CL,Chloride ion,NA,Cl,1,1
Glycan,NAG,N-Acetyl-D-glucosamine,NA,C8H15NO6,30,15
Glycan,MAN,D-Mannose,NA,C6H12O6,24,12
Glycan,FUC,L-Fucose,NA,C6H12O5,23,11
Glycan,GAL,D-Galactose,NA,C6H12O6,24,12
Glycan,SIA,N-Acetylneuraminic acid,NA,C11H19NO9,40,21
Glycan,GLC,D-Glucose,NA,C6H12O6,24,12
Glycan,BMA,beta-D-Mannose,NA,C6H12O6,24,12
Glycan,NDG,N-Acetyl-D-glucosamine,NA,C8H15NO6,30,15
Glycan,A2G,N-Acetyl-D-galactosamine,NA,C8H15NO6,30,15
Glycan,FUL,L-Fucose,NA,C6H12O5,23,11
DNA_Mod,5MC,5-Methylcytosine,DC,C10H15N3O7P,36,21
DNA_Mod,6MA,N6-Methyladenine,DA,C11H16N5O6P,39,23
DNA_Mod,5HMC,5-Hydroxymethylcytosine,DC,C10H15N3O8P,37,22
DNA_Mod,8OG,8-Oxoguanine,DG,C10H13N5O8P,37,24
DNA_Mod,M2G,N2-Methylguanine,DG,C11H16N5O7P,40,24
DNA_Mod,4MC,N4-Methylcytosine,DC,C10H15N3O7P,36,21
DNA_Mod,1MA,1-Methyladenine,DA,C11H16N5O6P,39,23
DNA_Mod,3MA,3-Methyladenine,DA,C11H16N5O6P,39,23
RNA_Mod,PSU,Pseudouridine,U,C9H12N2O9P,33,21
RNA_Mod,1MA,1-Methyladenosine,A,C11H15N5O7P,39,24
RNA_Mod,7MG,7-Methylguanosine,G,C11H15N5O8P,40,25
RNA_Mod,5MC,5-Methylcytidine,C,C10H15N3O8P,37,22
RNA_Mod,2MA,2-Methyladenosine,A,C11H15N5O7P,39,24
RNA_Mod,M2G,N2-Methylguanosine,G,C11H15N5O8P,40,25
RNA_Mod,M7G,7-Methylguanosine,G,C11H15N5O8P,40,25
RNA_Mod,M1A,1-Methyladenosine,A,C11H15N5O7P,39,24
RNA_Mod,OMC,2'-O-Methylcytidine,C,C10H15N3O8P,37,22
RNA_Mod,OMG,2'-O-Methylguanosine,G,C11H15N5O8P,40,25"""

    def __init__(self, reference_csv: Optional[str] = None):
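        """Initialize processor with reference data"""
        # keep_default_na=False: "NA" is both the CCD code for sodium and the
        # "no target residue" placeholder, so it must stay a literal string
        csv_text = reference_csv if reference_csv else self.EMBEDDED_REFERENCE
        self.reference_data = pd.read_csv(StringIO(csv_text), keep_default_na=False)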

        self.ptm_lookup = self._create_lookup('PTM')
        self.ligand_lookup = self._create_lookup('Ligand')
        self.ion_lookup = self._create_lookup('Ion')
        self.glycan_lookup = self._create_lookup('Glycan')
        self.dna_mod_lookup = self._create_lookup('DNA_Mod')
        self.rna_mod_lookup = self._create_lookup('RNA_Mod')

        self.aa_3to1 = {
            'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU': 'E', 'PHE': 'F',
            'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L',
            'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN': 'Q', 'ARG': 'R',
            'SER': 'S', 'THR': 'T', 'VAL': 'V', 'TRP': 'W', 'TYR': 'Y'
        }

        self.amino_acids = set('ACDEFGHIKLMNPQRSTVWY')

    def _create_lookup(self, type_name: str) -> Dict[str, Dict[str, Any]]:
        """Create lookup dictionary for a specific type"""
        type_data = self.reference_data[self.reference_data['Type'] == type_name]
        lookup = {}
        for _, row in type_data.iterrows():
            if pd.notna(row['CCD Code']):
                lookup[row['CCD Code']] = {
                    'name': row['Name'],
                    'target_residue': row['Target Residue'] if pd.notna(row['Target Residue']) else 'NA'
                }
        return lookup

    def _validate_sequence_characters(self, sequence: str, seq_type: str) -> List[str]:
        """Validate sequence contains only allowed characters.

        Whitespace is ignored. Offending characters are reported as CHAR@POSITION
        (1-based); only the first 10 are listed.
        """
        alphabets = {
            'protein': (self.amino_acids, 'Invalid amino acids'),
            'dna': (set('ATCG'), 'Invalid DNA bases'),
            'rna': (set('AUCG'), 'Invalid RNA bases'),
        }
        spec = alphabets.get(seq_type.lower())
        if spec is None:
            return []  # Other entity types are not validated here

        valid_chars, label = spec
        invalid_chars = [
            f"{char}@{i}"
            for i, char in enumerate(sequence.upper(), 1)
            if not char.isspace() and char not in valid_chars
        ]
        if not invalid_chars:
            return []
        suffix = "..." if len(invalid_chars) > 10 else ""
        return [f"{label}: {', '.join(invalid_chars[:10])}{suffix}"]

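    def _is_smiles(self, ligand_string: str) -> bool:
        """Loosely check whether a string might be a SMILES representation.

        Heuristic only: the character set includes common element letters (the
        two-letter symbols 'Cl' and 'Br' contribute their individual characters),
        so most CCD codes of three or more characters, such as 'ATP' or 'NAD',
        also return True. Check known CCD codes before relying on this result.
        """
        if len(ligand_string) < 3:
            return False
        smiles_chars = set('[]()=#@+-\\/CNOPSFClBrI0123456789')
        return any(char in smiles_chars for char in ligand_string)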

    @staticmethod
    def _remap_chain_fields(spec_string: str, name_mapping: Dict[str, str], chain_fields: Tuple[int, ...]) -> str:
        """Remap user chain names to A, B, C... in a ';'-separated spec string.

        Each entry is split on ':' and only the fields listed in `chain_fields`
        are renamed, each exactly once. Sequential str.replace calls could rename
        a chain twice or match the tail of a longer chain name.
        """
        if pd.isna(spec_string) or str(spec_string).strip() == '':
            return spec_string

        remapped_entries = []
        for entry in str(spec_string).strip().split(';'):
            parts = entry.split(':')
            for idx in chain_fields:
                if idx < len(parts):
                    name = parts[idx].strip()
                    parts[idx] = name_mapping.get(name, name)
            remapped_entries.append(':'.join(parts))
        return ';'.join(remapped_entries)

    def _remap_modification_chains(self, mod_string: str, name_mapping: Dict[str, str]) -> str:
        """Remap modification chain IDs (CHAIN:POSITION:CCD_CODE) to A, B, C... format"""
        return self._remap_chain_fields(mod_string, name_mapping, (0,))

    def _remap_contact_chains(self, contacts_string: str, name_mapping: Dict[str, str]) -> str:
        """Remap pocket contact chain IDs (CHAIN:RESIDUE) to A, B, C... format"""
        return self._remap_chain_fields(contacts_string, name_mapping, (0,))

    def _remap_bond_chains(self, bonds_string: str, name_mapping: Dict[str, str]) -> str:
        """Remap covalent bond chain IDs (CHAIN:RES:ATOM:CHAIN:RES:ATOM) to A, B, C... format"""
        return self._remap_chain_fields(bonds_string, name_mapping, (0, 3))

    def _parse_modifications(self, mod_string: str, sequence: str, chain_id: str, seq_type: str) -> Tuple[List[Dict], List[str]]:
        """Parse modifications in format: CHAIN:POSITION:CCD_CODE"""
        errors = []
        mods = []

        if pd.isna(mod_string) or str(mod_string).strip() == '':
            return mods, errors

        mod_string = str(mod_string).strip()
        mod_entries = [e.strip() for e in mod_string.split(';') if e.strip()]

        if seq_type.lower() == 'protein':
            mod_lookups = [self.ptm_lookup, self.glycan_lookup]
            mod_types = ['PTM', 'Glycan']
        elif seq_type.lower() == 'dna':
            mod_lookups = [self.dna_mod_lookup]
            mod_types = ['DNA Mod']
        elif seq_type.lower() == 'rna':
            mod_lookups = [self.rna_mod_lookup]
            mod_types = ['RNA Mod']
        else:
            return mods, errors

        for entry in mod_entries:
            parts = entry.split(':')
            if len(parts) != 3:
                errors.append(f"Invalid modification format: '{entry}'. Use CHAIN:POSITION:CCD_CODE")
                continue

            mod_chain, pos_str, ccd_code = parts
            mod_chain = mod_chain.strip()
            ccd_code = ccd_code.strip()

            if mod_chain != chain_id:
                continue

            try:
                position = int(pos_str.strip())

                found = False
                for lookup, mod_type in zip(mod_lookups, mod_types):
                    if ccd_code in lookup:
                        found = True

                        if position < 1 or position > len(sequence):
                            errors.append(f"{mod_type} position {position} out of range (sequence length: {len(sequence)})")
                            break

                        target_residue = lookup[ccd_code]['target_residue']
                        actual_residue = sequence[position - 1].upper()

                        # Glycans have no target residue in the reference table ('NA'),
                        # so their N/T/S attachment check must run unconditionally
                        if mod_type == 'Glycan':
                            if actual_residue not in ('N', 'T', 'S'):
                                errors.append(f"Glycan {ccd_code} requires N/T/S but found {actual_residue} at position {position}")
                                break
                        elif target_residue != 'NA':
                            if mod_type == 'PTM':
                                target_1letter = self.aa_3to1.get(target_residue.upper())
                                if target_1letter and actual_residue != target_1letter:
                                    errors.append(f"PTM {ccd_code} targets {target_residue}({target_1letter}) but found {actual_residue} at position {position}")
                                    break
                            else:
                                # DNA targets use deoxy codes (DA, DC, ...); compare the base letter
                                if len(target_residue) > 1 and target_residue.startswith('D'):
                                    target_residue = target_residue[1]
                                if actual_residue != target_residue:
                                    errors.append(f"{mod_type} {ccd_code} targets {target_residue} but found {actual_residue} at position {position}")
                                    break

                        mods.append({
                            'chain_id': mod_chain,
                            'position': position,
                            'ccd': ccd_code
                        })
                        break

                if not found:
                    errors.append(f"Unknown modification code: '{ccd_code}'")

            except ValueError:
                errors.append(f"Invalid position in modification: '{entry}'")

        return mods, errors

    def _parse_pocket(self, binder: str, contacts_string: str) -> Tuple[Optional[Dict], List[str]]:
        """Parse pocket constraint"""
        errors = []

        if pd.isna(binder) or str(binder).strip() == '':
            return None, errors

        binder = str(binder).strip()

        if pd.isna(contacts_string) or str(contacts_string).strip() == '':
            errors.append(f"Pocket binder '{binder}' specified but no contacts provided")
            return None, errors

        contacts_string = str(contacts_string).strip()
        contact_entries = [e.strip() for e in contacts_string.split(';') if e.strip()]

        contacts = []
        for entry in contact_entries:
            parts = entry.split(':')
            if len(parts) != 2:
                errors.append(f"Invalid contact format: '{entry}'. Use CHAIN:RESIDUE")
                continue

            try:
                chain_id, residue_num = parts
                contacts.append([chain_id.strip(), int(residue_num.strip())])
            except ValueError:
                errors.append(f"Invalid contact specification: '{entry}'")

        if not contacts:
            errors.append(f"No valid contacts parsed for pocket binder '{binder}'")
            return None, errors

        return {
            'binder': binder,
            'contacts': contacts
        }, errors

    def _parse_covalent_bonds(self, bonds_string: str) -> Tuple[List[Dict], List[str]]:
        """Parse covalent bond constraints"""
        errors = []
        bonds = []

        if pd.isna(bonds_string) or str(bonds_string).strip() == '':
            return bonds, errors

        bonds_string = str(bonds_string).strip()
        bond_entries = [e.strip() for e in bonds_string.split(';') if e.strip()]

        for entry in bond_entries:
            parts = entry.split(':')
            if len(parts) != 6:
                errors.append(f"Invalid covalent bond format: '{entry}'. Use CHAIN:RES:ATOM:CHAIN:RES:ATOM")
                continue

            try:
                chain1, res1, atom1, chain2, res2, atom2 = parts
                bonds.append({
                    'atom1': [chain1.strip(), int(res1.strip()), atom1.strip()],
                    'atom2': [chain2.strip(), int(res2.strip()), atom2.strip()]
                })
            except ValueError:
                errors.append(f"Invalid covalent bond specification: '{entry}'")

        return bonds, errors

    def _sanitize_jobname(self, jobname: str) -> Tuple[str, List[str]]:
        """Sanitize jobname for filesystem compatibility"""
        errors = []

        if pd.isna(jobname) or str(jobname).strip() == '':
            errors.append("Missing jobname")
            return '', errors

        original = str(jobname)
        sanitized = original.lower()
        sanitized = sanitized.replace('-', '_')
        sanitized = re.sub(r'[^a-z0-9_]', '', sanitized)

        if len(sanitized) > 128:
            sanitized = sanitized[:128]
            errors.append("Jobname truncated to 128 characters")

        if not sanitized.strip():
            errors.append(f"Jobname '{original}' became empty after sanitization")
            return '', errors

        return sanitized, errors

    def _generate_yaml(self, job: Dict) -> str:
        """Generate YAML format for Boltz-2"""
        lines = ["version: 1", "sequences:"]

        protein_groups = {}
        dna_groups = {}
        rna_groups = {}
        ligand_groups = {}

        for seq in job['sequences']:
            seq_type = seq['type']

            if seq_type == 'protein':
                key = (seq['sequence'], tuple(sorted((m['position'], m['ccd']) for m in seq['modifications'])) if seq['modifications'] else ())
                if key not in protein_groups:
                    protein_groups[key] = []
                protein_groups[key].append(seq)

            elif seq_type == 'dna':
                key = (seq['sequence'], tuple(sorted((m['position'], m['ccd']) for m in seq['modifications'])) if seq['modifications'] else ())
                if key not in dna_groups:
                    dna_groups[key] = []
                dna_groups[key].append(seq)

            elif seq_type == 'rna':
                key = (seq['sequence'], tuple(sorted((m['position'], m['ccd']) for m in seq['modifications'])) if seq['modifications'] else ())
                if key not in rna_groups:
                    rna_groups[key] = []
                rna_groups[key].append(seq)

            elif seq_type == 'ligand':
                if 'smiles' in seq:
                    key = ('smiles', seq['smiles'])
                else:
                    key = ('ccd', seq['ccd'])
                if key not in ligand_groups:
                    ligand_groups[key] = []
                ligand_groups[key].append(seq)

        for (sequence, mod_tuple), seqs in protein_groups.items():
            lines.append("  - protein:")
            chain_ids = [s['id'] for s in seqs]
            if len(chain_ids) == 1:
                lines.append(f"      id: {chain_ids[0]}")
            else:
                lines.append(f"      id: [{', '.join(chain_ids)}]")
            lines.append(f"      sequence: {sequence}")

            if seqs[0]['modifications']:
                # Boltz YAML expects each modification as a position/ccd pair
                lines.append("      modifications:")
                for mod in seqs[0]['modifications']:
                    lines.append(f"        - position: {mod['position']}")
                    lines.append(f"          ccd: {mod['ccd']}")

        for (sequence, mod_tuple), seqs in dna_groups.items():
            lines.append("  - dna:")
            chain_ids = [s['id'] for s in seqs]
            if len(chain_ids) == 1:
                lines.append(f"      id: {chain_ids[0]}")
            else:
                lines.append(f"      id: [{', '.join(chain_ids)}]")
            lines.append(f"      sequence: {sequence}")

            if seqs[0]['modifications']:
                # Boltz YAML expects each modification as a position/ccd pair
                lines.append("      modifications:")
                for mod in seqs[0]['modifications']:
                    lines.append(f"        - position: {mod['position']}")
                    lines.append(f"          ccd: {mod['ccd']}")

        for (sequence, mod_tuple), seqs in rna_groups.items():
            lines.append("  - rna:")
            chain_ids = [s['id'] for s in seqs]
            if len(chain_ids) == 1:
                lines.append(f"      id: {chain_ids[0]}")
            else:
                lines.append(f"      id: [{', '.join(chain_ids)}]")
            lines.append(f"      sequence: {sequence}")

            if seqs[0]['modifications']:
                # Boltz YAML expects each modification as a position/ccd pair
                lines.append("      modifications:")
                for mod in seqs[0]['modifications']:
                    lines.append(f"        - position: {mod['position']}")
                    lines.append(f"          ccd: {mod['ccd']}")

        for (lig_type, lig_value), seqs in ligand_groups.items():
            lines.append("  - ligand:")
            chain_ids = [s['id'] for s in seqs]
            if len(chain_ids) == 1:
                lines.append(f"      id: {chain_ids[0]}")
            else:
                lines.append(f"      id: [{', '.join(chain_ids)}]")

            if lig_type == 'smiles':
                lines.append(f"      smiles: '{lig_value}'")
            else:
                lines.append(f"      ccd: {lig_value}")

        if job.get('pocket') or job.get('covalent_bonds'):
            lines.append("constraints:")

            if job.get('pocket'):
                pocket = job['pocket']
                lines.append("  - pocket:")
                lines.append(f"      binder: {pocket['binder']}")
                lines.append("      contacts:")
                for contact in pocket['contacts']:
                    lines.append(f"        - [{contact[0]}, {contact[1]}]")

            if job.get('covalent_bonds'):
                for bond in job['covalent_bonds']:
                    lines.append("  - bond:")
                    lines.append(f"      atom1: [{bond['atom1'][0]}, {bond['atom1'][1]}, {bond['atom1'][2]}]")
                    lines.append(f"      atom2: [{bond['atom2'][0]}, {bond['atom2'][1]}, {bond['atom2'][2]}]")

        return '\n'.join(lines)

    def _process_job(self, row: pd.Series) -> Tuple[Optional[Dict], List[str]]:
        """Process a single job row from CSV"""
        errors = []
        all_sequences = []

        jobname, jobname_errors = self._sanitize_jobname(row.get('jobname', ''))
        errors.extend(jobname_errors)

        if not jobname:
            return None, errors

        pocket_binder = row.get('pocket_binder', '')
        pocket_contacts = row.get('pocket_contacts', '')
        covalent_bonds_str = row.get('covalent_bonds', '')

        chain_id_counter = 0
        name_to_chain_mapping = {}

        for i in range(1, 11):
            name_col = f'seq{i}_name'
            type_col = f'seq{i}_type'
            copies_col = f'seq{i}_copies'
            seq_col = f'seq{i}'
            mods_col = f'seq{i}_mods'

            if name_col not in row or type_col not in row or seq_col not in row:
                continue

            user_name = row.get(name_col, '')
            seq_type = row.get(type_col, '')
            copies = row.get(copies_col, 1)
            sequence = row.get(seq_col, '')
            mods = row.get(mods_col, '')

            if pd.isna(sequence) or str(sequence).strip() == '':
                continue

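            sequence = str(sequence).strip()
            # Empty type cells are read as NaN; normalize so .lower() below cannot fail
            seq_type = str(seq_type).strip() if pd.notna(seq_type) else ''
            copies = int(copies) if pd.notna(copies) and str(copies).strip() != '' else 1
            user_name = str(user_name).strip() if pd.notna(user_name) else ''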

            chain_ids = []
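            for copy_num in range(copies):
                # Only single-letter chain IDs A-Z are generated
                if chain_id_counter >= 26:
                    errors.append("Too many chains: at most 26 chain IDs (A-Z) can be assigned")
                    break
                chain_id = chr(ord('A') + chain_id_counter)
                chain_ids.append(chain_id)

                if user_name and copy_num == 0:
                    name_to_chain_mapping[user_name] = chain_id

                chain_id_counter += 1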

            if seq_type.lower() in ['protein', 'dna', 'rna']:
                char_errors = self._validate_sequence_characters(sequence, seq_type)
                errors.extend(char_errors)

                for idx, chain_id in enumerate(chain_ids):
                    remapped_mods = self._remap_modification_chains(mods, name_to_chain_mapping)
                    mods_list, mod_errors = self._parse_modifications(remapped_mods, sequence, chain_id, seq_type)
                    errors.extend(mod_errors)

                    seq_dict = {
                        'type': seq_type.lower(),
                        'id': chain_id,
                        'sequence': sequence,
                        'modifications': mods_list if mods_list else None
                    }
                    all_sequences.append(seq_dict)

            elif seq_type.lower() == 'ligand':
                ligand_string = sequence.strip()
                is_smiles = self._is_smiles(ligand_string)

                if not is_smiles:
                    if ligand_string not in self.ligand_lookup and ligand_string not in self.ion_lookup:
                        errors.append(f"Unknown ligand/ion CCD code: '{ligand_string}'")

                for chain_id in chain_ids:
                    if is_smiles:
                        seq_dict = {
                            'type': 'ligand',
                            'id': chain_id,
                            'smiles': ligand_string
                        }
                    else:
                        seq_dict = {
                            'type': 'ligand',
                            'id': chain_id,
                            'ccd': ligand_string
                        }
                    all_sequences.append(seq_dict)
            else:
                errors.append(f"Unsupported sequence type: {seq_type}")

        if not all_sequences:
            errors.append("No valid sequences found")
            return None, errors

        remapped_pocket_binder = name_to_chain_mapping.get(pocket_binder, pocket_binder) if pocket_binder else pocket_binder
        remapped_contacts = self._remap_contact_chains(pocket_contacts, name_to_chain_mapping)

        pocket, pocket_errors = self._parse_pocket(remapped_pocket_binder, remapped_contacts)
        errors.extend(pocket_errors)

        remapped_bonds = self._remap_bond_chains(covalent_bonds_str, name_to_chain_mapping)
        covalent_bonds, bond_errors = self._parse_covalent_bonds(remapped_bonds)
        errors.extend(bond_errors)

        has_modifications = any(seq.get('modifications') for seq in all_sequences)

        job = {
            'name': jobname,
            'sequences': all_sequences,
            'pocket': pocket,
            'covalent_bonds': covalent_bonds,
            'has_modifications': has_modifications,
            'has_pocket': pocket is not None,
            'has_covalent': len(covalent_bonds) > 0
        }

        return job, errors

    def process_csv(self, csv_path: str) -> Tuple[List[Dict], pd.DataFrame]:
        """Process CSV file and return jobs list and summary DataFrame"""
        df = pd.read_csv(csv_path)

        jobs = []
        summary_rows = []

        for idx, row in df.iterrows():
            job, errors = self._process_job(row)

            if job:
                jobs.append(job)
                summary_rows.append({
                    'jobname': job['name'],
                    'sequences': len(job['sequences']),
                    'has_modifications': job['has_modifications'],
                    'has_pocket': job['has_pocket'],
                    'has_covalent': job['has_covalent'],
                    'status': 'ERROR' if errors else 'OK',
                    'errors': '; '.join(errors) if errors else ''
                })
            else:
                summary_rows.append({
                    'jobname': f"Row {idx+1}",
                    'sequences': 0,
                    'has_modifications': False,
                    'has_pocket': False,
                    'has_covalent': False,
                    'status': 'FAILED',
                    'errors': '; '.join(errors)
                })

        summary_df = pd.DataFrame(summary_rows)
        return jobs, summary_df

print("🔧 Initializing Boltz CSV Processor...")
boltz_processor = BoltzJobProcessor()

print("✅ Using embedded reference data: 79 entries")
print("✅ Processor ready")
print("📋 Reference data includes:")
print(f"   • 15 PTM types")
print(f"   • 24 ligand types")
print(f"   • 11 ion types")
print(f"   • 10 glycan types")
print(f"   • 8 DNA modification types")
print(f"   • 10 RNA modification types")
print("\n💡 Using embedded reference data (common PTMs, ligands, ions, glycans, DNA/RNA mods)")
print("   To use custom reference: upload file in Cell 3")
print("\n📝 Note: Chain IDs are assigned as A, B, C, D... sequentially")
print("   Sequence identity is preserved in job names")

🔧 Initializing Boltz CSV Processor...
✅ Using embedded reference data: 79 entries
✅ Processor ready
📋 Reference data includes:
   • 15 PTM types
   • 24 ligand types
   • 11 ion types
   • 10 glycan types
   • 8 DNA modification types
   • 10 RNA modification types

💡 Using embedded reference data (common PTMs, ligands, ions, glycans, DNA/RNA mods)
   To use custom reference: upload file in Cell 3

📝 Note: Chain IDs are assigned as A, B, C, D... sequentially
   Sequence identity is preserved in job names
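The chain-naming and grouping steps in `_process_job` and `_generate_yaml` can be illustrated with a simplified, self-contained sketch. It uses hypothetical helper names (`SeqSpec`, `assign_chains`, `to_yaml`) that are not part of Boltz or this notebook, treats every ligand as a CCD code, omits modifications and constraints, and keeps entities in input order instead of grouping them by type as the processor does.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass
class SeqSpec:
    """Hypothetical simplified record for one seqN_* column group."""
    name: str
    seq_type: str  # 'protein', 'dna', 'rna' or 'ligand'
    value: str     # sequence for polymers, CCD code for ligands (assumption)
    copies: int = 1


def assign_chains(specs: List[SeqSpec]) -> Tuple[List[Tuple[SeqSpec, str]], Dict[str, str]]:
    """Give every copy the next letter A, B, C, ...; map each name to its first chain."""
    entries: List[Tuple[SeqSpec, str]] = []
    name_to_chain: Dict[str, str] = {}
    counter = 0
    for spec in specs:
        for copy_num in range(spec.copies):
            if counter >= 26:
                raise ValueError("only 26 single-letter chain IDs (A-Z) are available")
            chain_id = chr(ord('A') + counter)
            if spec.name and copy_num == 0:
                name_to_chain[spec.name] = chain_id
            entries.append((spec, chain_id))
            counter += 1
    return entries, name_to_chain


def to_yaml(entries: List[Tuple[SeqSpec, str]]) -> str:
    """Group identical entities so that copies share one entry with a list of IDs."""
    groups: Dict[Tuple[str, str], List[str]] = {}
    for spec, chain_id in entries:
        # dicts preserve insertion order, so entities appear in input order
        groups.setdefault((spec.seq_type, spec.value), []).append(chain_id)

    lines = ["version: 1", "sequences:"]
    for (seq_type, value), ids in groups.items():
        lines.append(f"  - {seq_type}:")
        lines.append(f"      id: {ids[0]}" if len(ids) == 1 else f"      id: [{', '.join(ids)}]")
        if seq_type == 'ligand':
            lines.append(f"      ccd: {value}")
        else:
            lines.append(f"      sequence: {value}")
    return '\n'.join(lines)


specs = [SeqSpec('kinase', 'protein', 'MKTAYIAK', copies=2), SeqSpec('atp', 'ligand', 'ATP')]
entries, mapping = assign_chains(specs)
print(mapping)  # {'kinase': 'A', 'atp': 'C'}
print(to_yaml(entries))
```

With two copies of `kinase`, chains A and B share one `protein` entry and the ligand receives chain C. The name-to-chain mapping is what the processor uses to translate sequence names in the pocket and bond columns into chain IDs.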


In [3]:
#@title Cell 3: Upload CSV and Configure Batch Jobs
from google.colab import files
import pandas as pd

# Configuration options
upload_custom_reference = False #@param {type:"boolean"}
#@markdown - Upload custom reference file (optional - embedded data includes common PTMs/ligands)

setup_google_drive = True #@param {type:"boolean"}
#@markdown - Setup Google Drive for automatic result upload

gdrive_folder_name = "Boltz2_Predictions" #@param {type:"string"}
#@markdown - Google Drive folder name for batch results

print("=" * 60)
print("📁 CSV UPLOAD FOR BATCH PROCESSING")
print("=" * 60)

# Handle custom reference file upload if requested
custom_ref_file = None
if upload_custom_reference:
    print("\n📤 Upload custom reference CSV file...")
    print("Required columns: Type, CCD Code, Name, Target Residue, Heavy Atom Count")
    uploaded_ref = files.upload()

    if uploaded_ref:
        custom_ref_file = list(uploaded_ref.keys())[0]
        print(f"✅ Custom reference uploaded: {custom_ref_file}")

# Initialize processor (will use embedded data if no custom file)
try:
    boltz_processor = BoltzJobProcessor(custom_ref_file)
except Exception as e:
    print(f"❌ Failed to initialize processor: {e}")
    raise

# Upload input CSV
print("\n📊 Upload your input CSV file with job specifications")
print("Required columns: jobname, seq1_name, seq1_type, seq1")
print("Optional: seq1_copies, seq1_mods, pocket_binder, pocket_contacts, covalent_bonds")
print("Supports up to 10 sequences per job (seq1 through seq10)")

uploaded_csv = files.upload()

if not uploaded_csv:
    raise ValueError("No CSV file uploaded")

csv_filename = list(uploaded_csv.keys())[0]
print(f"\n✅ Uploaded: {csv_filename}")

# Process CSV
print("\n🔄 Processing CSV...")
try:
    jobs, summary_df = boltz_processor.process_csv(csv_filename)
except Exception as e:
    print(f"❌ CSV processing failed: {e}")
    raise

# Display summary
print("\n" + "=" * 60)
print("📋 JOB SUMMARY")
print("=" * 60)
print(summary_df.to_string(index=False))

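# Check for errors (jobs with ERROR status are still parsed, so filter them explicitly)
problem_rows = summary_df[summary_df['status'].isin(['ERROR', 'FAILED'])]
if len(problem_rows) > 0:
    print(f"\n⚠️  {len(problem_rows)} jobs have errors:")
    for _, row in problem_rows.iterrows():
        print(f"  • {row['jobname']}: {row['errors']}")

    proceed = input("\nProceed with valid jobs only? (yes/no): ").strip().lower()
    if proceed not in ['yes', 'y']:
        raise ValueError("Processing cancelled by user")

    # Keep only jobs whose summary status is OK
    ok_names = set(summary_df.loc[summary_df['status'] == 'OK', 'jobname'])
    jobs = [job for job in jobs if job['name'] in ok_names]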

# Setup Google Drive if requested
drive = None
if setup_google_drive:
    try:
        from pydrive2.drive import GoogleDrive
        from pydrive2.auth import GoogleAuth
        from google.colab import auth
        from oauth2client.client import GoogleCredentials

        print("\n☁️  Setting up Google Drive...")
        auth.authenticate_user()
        gauth = GoogleAuth()
        gauth.credentials = GoogleCredentials.get_application_default()
        drive = GoogleDrive(gauth)
        print("✅ Google Drive connected")
    except Exception as e:
        print(f"⚠️  Google Drive setup failed: {e}")
        drive = None

# Store in global settings
if 'global_settings' not in globals():
    global_settings = {}

global_settings.update({
    'batch_jobs': jobs,
    'csv_filename': csv_filename,
    'drive': drive,
    'gdrive_folder_name': gdrive_folder_name,
    'processor': boltz_processor
})

print("\n" + "=" * 60)
print(f"✅ READY TO PROCESS {len(jobs)} JOBS")
print("=" * 60)
print("\n📌 Next steps:")
print("  1. Configure MSA settings (Cell 4)")
print("  2. Configure prediction parameters (Cell 5)")
print("  3. Run batch predictions (Cell 6)")

📁 CSV UPLOAD FOR BATCH PROCESSING

📊 Upload your input CSV file with job specifications
Required columns: jobname, seq1_name, seq1_type, seq1
Optional: seq1_copies, seq1_mods, pocket_binder, pocket_contacts, covalent_bonds
Supports up to 10 sequences per job (seq1 through seq10)


Saving boltz_csv_template_Nick.csv to boltz_csv_template_Nick.csv

✅ Uploaded: boltz_csv_template_Nick.csv

🔄 Processing CSV...

📋 JOB SUMMARY
                jobname  sequences  has_modifications  has_pocket  has_covalent status errors
     hrip1_xp_016502092          2              False       False         False     OK       
     hrip1_xp_016467230          2              False       False         False     OK       
     hrip1_xp_016454296          2              False       False         False     OK       
     hrip1_xp_016481883          2              False       False         False     OK       
   mohrip1_xp_016448570          2              False       False         False     OK       
   mohrip1_xp_016436137          2              False       False         False     OK       
     mgnlp_xp_016504433          2              False       False         False     OK       
     mgnlp_xp_016475599          2              False       False         False     OK       
  mgnlp_xp_
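For preparing an input file offline, here is a minimal sketch using Python's standard `csv` module. It only uses column names listed in Cell 3; the job names, the placeholder sequence, and the residue numbers are illustrative. It assumes the ligand is given as a CCD code and pocket contacts are written as `CHAIN:RESIDUE` entries separated by `;`, as `_parse_pocket` expects. The binder is referred to by its sequence name, which the processor maps to the first chain assigned to that sequence (here `C`, because `kinase` has two copies).

```python
import csv
import io

# Column names taken from Cell 3; the rows below are illustrative placeholders.
FIELDS = [
    'jobname',
    'seq1_name', 'seq1_type', 'seq1', 'seq1_copies',
    'seq2_name', 'seq2_type', 'seq2', 'seq2_copies',
    'pocket_binder', 'pocket_contacts',
]

rows = [
    {
        'jobname': 'kinase_dimer_atp',
        'seq1_name': 'kinase', 'seq1_type': 'protein',
        'seq1': 'MKTAYIAKQRQISFVKSHFSRQ', 'seq1_copies': 2,
        'seq2_name': 'atp', 'seq2_type': 'ligand', 'seq2': 'ATP', 'seq2_copies': 1,
        # Binder given by sequence name; contacts as CHAIN:RESIDUE;CHAIN:RESIDUE
        'pocket_binder': 'atp', 'pocket_contacts': 'A:5;A:9',
    },
    {
        # Columns left out of a row are written as empty cells
        'jobname': 'kinase_monomer',
        'seq1_name': 'kinase', 'seq1_type': 'protein',
        'seq1': 'MKTAYIAKQRQISFVKSHFSRQ', 'seq1_copies': 1,
    },
]

buffer = io.StringIO()
writer = csv.DictWriter(buffer, fieldnames=FIELDS)  # restval defaults to ''
writer.writeheader()
writer.writerows(rows)
csv_text = buffer.getvalue()
print(csv_text)

# To produce a file for upload, write the same text to disk, e.g.:
# with open('boltz_jobs.csv', 'w', newline='') as f:
#     f.write(csv_text)
```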

In [4]:
#@title Cell 4: MSA Configuration
msa_mode = "mmseqs2_uniref_env" #@param ["mmseqs2_uniref_env", "mmseqs2_uniref", "single_sequence"]
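#@markdown - **MSA generation method**: both mmseqs2 modes use the ColabFold MSA server; this notebook does not pass the database choice to Boltz, so the two mmseqs2 options currently behave the same.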

msa_pairing_strategy = "greedy" #@param ["greedy", "complete"]
#@markdown - **Pairing strategy**: `greedy` = pair any matching subsets, `complete` = all sequences must match

# Check if global_settings exists
if 'global_settings' not in globals():
    print("⚠️  Please run the CSV Upload cell first")
else:
    # Configure MSA settings
    if "mmseqs2" in msa_mode:
        use_msa_server = True
        msa_server_url = "https://api.colabfold.com"
    else:
        use_msa_server = False
        msa_server_url = None

    # Store MSA settings
    global_settings.update({
        'msa_mode': msa_mode,
        'msa_pairing_strategy': msa_pairing_strategy,
        'use_msa_server': use_msa_server,
        'msa_server_url': msa_server_url
    })

    print("✅ MSA configuration set:")
    print(f"  Mode: {msa_mode}")
    print(f"  Pairing strategy: {msa_pairing_strategy}")
    print(f"  Use MSA server: {use_msa_server}")

✅ MSA configuration set:
  Mode: mmseqs2_uniref_env
  Pairing strategy: greedy
  Use MSA server: True
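The MSA settings only change which flags Cell 6 appends to the `boltz predict` call. A minimal sketch of that translation, using a hypothetical helper `msa_flags` that mirrors the logic in Cells 4 and 6:

```python
from typing import List


def msa_flags(msa_mode: str, pairing_strategy: str,
              server_url: str = "https://api.colabfold.com") -> List[str]:
    """Hypothetical helper: translate Cell 4's settings into boltz predict arguments.

    Any mode containing "mmseqs2" enables the ColabFold MSA server;
    "single_sequence" adds no MSA-related flags.
    """
    if "mmseqs2" not in msa_mode:
        return []
    return [
        "--use_msa_server",
        "--msa_server_url", server_url,
        "--msa_pairing_strategy", pairing_strategy,
    ]


print(msa_flags("mmseqs2_uniref_env", "greedy"))
print(msa_flags("single_sequence", "greedy"))
```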


In [5]:
#@title Cell 5: Advanced Prediction Settings
# Structure Prediction Settings
recycling_steps = 6 #@param {type:"integer"}
#@markdown - **Iterative refinement passes**: Each cycle refines the structure using updated predictions. Higher values improve local geometry and confidence scores. **Time**: ~linear scaling (3 steps = 3x base time). **VRAM**: +20-30% per additional step for intermediate states.

sampling_steps = 200 #@param {type:"integer"}
#@markdown - **Diffusion denoising iterations**: Controls how many steps the diffusion model takes to generate structures from noise. More steps = smoother, higher quality structures. **Time**: Linear scaling (50 steps = 4x faster than 200). **VRAM**: +10-15% for intermediate diffusion states.

diffusion_samples = 5 #@param {type:"integer"}
#@markdown - **Independent structure predictions**: Number of different structures generated per input. More samples increase diversity and reliability of results. **Time**: Linear scaling (5 samples = 5x base time). **VRAM**: Depends on max_parallel_samples setting.

max_parallel_samples = 5 #@param {type:"integer"}
#@markdown - **GPU memory management**: How many diffusion samples are processed simultaneously. Critical for large complexes - each parallel sample requires full model memory allocation. **Time**: Minimal impact on total time. **VRAM**: ~Linear scaling (2 parallel = ~2x memory, 5 parallel = ~5x memory).

step_scale = 1.638 #@param {type:"number"}
#@markdown - **Sampling temperature**: Controls diversity among diffusion samples. Lower values increase diversity; Boltz recommends values between 1 and 2, and 1.638 is the default. **Time**: No impact. **VRAM**: No impact.

# Affinity Prediction Settings
predict_affinity = False #@param {type:"boolean"}
#@markdown - **Binding affinity prediction**: Runs the additional affinity model, which reports a binder likelihood and an affinity value (log10 IC50 in μM). Supported for small-molecule ligand binders only. **Time**: +50-100% total time. **VRAM**: +40-60% for affinity model loading.

affinity_mw_correction = False #@param {type:"boolean"}
#@markdown - **Molecular weight adjustment**: Applies size-based corrections to affinity predictions. Only affects affinity calculation, not structure. **Time**: Minimal impact. **VRAM**: No impact.

sampling_steps_affinity = 200 #@param {type:"integer"}
#@markdown - **Affinity model diffusion steps**: Controls quality of affinity predictions. Similar to sampling_steps but for the affinity model. **Time**: Linear scaling within affinity prediction. **VRAM**: +5-10% for affinity diffusion states.

diffusion_samples_affinity = 5 #@param {type:"integer"}
#@markdown - **Affinity prediction ensemble size**: Number of independent affinity predictions to average for final binding strength. More samples = more reliable Kd estimates. **Time**: Linear scaling for affinity portion. **VRAM**: Minimal additional impact.

# Output and Optimization Settings
output_format = "mmcif" #@param ["mmcif", "pdb"]
#@markdown - **Structure file format**: mmCIF supports more metadata and modern features, PDB is more widely compatible. Both contain same structural information. **Time**: No impact. **VRAM**: No impact.

write_full_pae = True #@param {type:"boolean"}
#@markdown - **Save Predicted Aligned Error matrix**: Confidence scores between all residue pairs. Essential for assessing interface quality and domain reliability. **Time**: +5-10% for matrix computation and I/O. **VRAM**: +10-20% for large complexes during matrix storage.

write_full_pde = False #@param {type:"boolean"}
#@markdown - **Save Predicted Distance Error matrix**: Distance confidence predictions between residue pairs. Useful for validation and uncertainty quantification. **Time**: +5-10% for matrix computation and I/O. **VRAM**: +10-20% for large complexes during matrix storage.

use_potentials = True #@param {type:"boolean"}
#@markdown - **Inference-time steering potentials**: Applies physics-inspired potentials during diffusion sampling to reduce steric clashes and improve physical plausibility, particularly at interfaces. **Time**: +30-50% total time. **VRAM**: +15-25% for physics calculation buffers.

# Check if global_settings exists
if 'global_settings' not in globals():
    print("⚠️  Please run the CSV Upload cell (Cell 3) first")
else:
    # Store advanced settings
    advanced_settings = {
        'recycling_steps': recycling_steps,
        'sampling_steps': sampling_steps,
        'diffusion_samples': diffusion_samples,
        'max_parallel_samples': max_parallel_samples,
        'step_scale': step_scale,
        'predict_affinity': predict_affinity,
        'affinity_mw_correction': affinity_mw_correction,
        'sampling_steps_affinity': sampling_steps_affinity,
        'diffusion_samples_affinity': diffusion_samples_affinity,
        'output_format': output_format,
        'write_full_pae': write_full_pae,
        'write_full_pde': write_full_pde,
        'use_potentials': use_potentials,
        'max_msa_seqs': 8192,
        'subsample_msa': False,
        'num_subsampled_msa': 1024
    }

    global_settings.update(advanced_settings)

    print("✅ Advanced settings configured:")
    print(f"  Recycling steps: {recycling_steps}")
    print(f"  Sampling steps: {sampling_steps}")
    print(f"  Diffusion samples: {diffusion_samples}")
    print(f"  Predict affinity: {predict_affinity}")
    print(f"  Output format: {output_format}")
    print(f"  Use potentials: {use_potentials}")

✅ Advanced settings configured:
  Recycling steps: 6
  Sampling steps: 200
  Diffusion samples: 5
  Predict affinity: False
  Output format: mmcif
  Use potentials: True
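Cell 6 turns these settings into command-line options. The sketch below shows one way to build that argument list from a settings dictionary; `build_command`, `VALUE_OPTIONS`, and `SWITCHES` are hypothetical names, and only a subset of the options is covered. It includes `--use_potentials` for the `use_potentials` setting, assuming your installed Boltz version accepts that flag (check `boltz predict --help`).

```python
from typing import Dict, List

# Value-taking options: (CLI flag, settings key, default used in Cell 6)
VALUE_OPTIONS = [
    ("--recycling_steps", "recycling_steps", 6),
    ("--sampling_steps", "sampling_steps", 200),
    ("--diffusion_samples", "diffusion_samples", 5),
    ("--max_parallel_samples", "max_parallel_samples", 5),
    ("--step_scale", "step_scale", 1.638),
    ("--output_format", "output_format", "mmcif"),
]

# Boolean switches: the flag is appended only when the setting is True
SWITCHES = [
    ("--write_full_pae", "write_full_pae"),
    ("--write_full_pde", "write_full_pde"),
    ("--use_potentials", "use_potentials"),  # assumption: supported by your Boltz version
]


def build_command(yaml_file: str, out_dir: str, settings: Dict) -> List[str]:
    """Hypothetical helper: turn Cell 5 settings into a boltz predict argument list."""
    cmd = ["boltz", "predict", yaml_file, "--out_dir", out_dir]
    for flag, key, default in VALUE_OPTIONS:
        cmd += [flag, str(settings.get(key, default))]
    for flag, key in SWITCHES:
        if settings.get(key, False):
            cmd.append(flag)
    return cmd


settings = {"recycling_steps": 6, "diffusion_samples": 5,
            "use_potentials": True, "write_full_pae": True}
cmd = build_command("job/job.yaml", "job", settings)
print(" ".join(cmd))
# The list can be run directly, without shell=True:
# subprocess.run(cmd, capture_output=True, text=True, timeout=7200)
```

Keeping the command as a list lets it be passed straight to `subprocess.run`, so paths containing spaces or shell metacharacters need no extra quoting.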


In [None]:
#@title Cell 6: Run Batch Predictions with Smart Kernel Detection
%%time
import subprocess
import os
import zipfile
import shutil
import time
from datetime import datetime

# Verify setup
if 'global_settings' not in globals():
    print("❌ Error: Please run the previous configuration cells first")
elif not global_settings.get('batch_jobs'):
    print("❌ Error: No batch jobs configured. Run CSV upload cell first")
else:
    # GPU verification
    print("🔍 Checking GPU availability...")
    try:
        import torch
        if torch.cuda.is_available():
            print(f"✅ GPU: {torch.cuda.get_device_name(0)} ({torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB)")
        else:
            print("⚠️  WARNING: No GPU detected - predictions will be very slow")
    except ImportError:
        print("❌ PyTorch not available")

    # Check kernel test status
    if not global_settings.get('kernels_tested', False):
        print("\n⚠️  WARNING: Kernel preflight test not run!")
        print("   Running with --no_kernels by default for safety")
        global_settings['use_no_kernels'] = True

    use_no_kernels_flag = global_settings.get('use_no_kernels', True)
    print(f"\n🔧 Kernel mode: {'--no_kernels' if use_no_kernels_flag else 'WITH kernels'}")

    if use_no_kernels_flag:
        print("   (cuEquivariance kernels disabled: PyTorch fallback on the GPU, slower but more compatible)")
    else:
        print("   (Using cuEquivariance CUDA kernels for faster inference)")

    # Helper functions for Google Drive uploads and result packaging
    def find_or_create_folder(drive, folder_name, parent_id='root'):
        if not drive:
            return None
        try:
            file_list = drive.ListFile({
                'q': f"title='{folder_name}' and '{parent_id}' in parents and mimeType='application/vnd.google-apps.folder' and trashed=false"
            }).GetList()
            if file_list:
                print(f"✅ Found existing folder: {folder_name}")
                return file_list[0]['id']
            else:
                folder = drive.CreateFile({
                    'title': folder_name,
                    'mimeType': 'application/vnd.google-apps.folder',
                    'parents': [{'id': parent_id}]
                })
                folder.Upload()
                print(f"✅ Created new folder: {folder_name}")
                return folder['id']
        except Exception as e:
            print(f"❌ Error with folder '{folder_name}': {e}")
            return None

    def upload_to_gdrive(drive, file_path, folder_id, job_name):
        if not drive or not os.path.exists(file_path):
            return None
        try:
            uploaded_file = drive.CreateFile({
                'title': os.path.basename(file_path),
                'parents': [{'id': folder_id}]
            })
            uploaded_file.SetContentFile(file_path)
            uploaded_file.Upload()
            file_url = f"https://drive.google.com/file/d/{uploaded_file['id']}/view"
            print(f"  ☁️  Uploaded to Google Drive: {file_url}")
            return file_url
        except Exception as e:
            print(f"  ⚠️  Upload failed: {e}")
            return None

    def create_results_zip(job_dir, output_filename):
        with zipfile.ZipFile(output_filename, 'w', zipfile.ZIP_DEFLATED) as zipf:
            results_dirs = [d for d in os.listdir(job_dir) if d.startswith('boltz_results_')]
            if results_dirs:
                predictions_dir = os.path.join(job_dir, results_dirs[0])
                if os.path.exists(predictions_dir):
                    for root, dirs, files in os.walk(predictions_dir):
                        for file in files:
                            file_path = os.path.join(root, file)
                            arc_path = os.path.relpath(file_path, predictions_dir)
                            zipf.write(file_path, arc_path)

    def run_single_prediction(job, settings, job_num, total_jobs):
        """Run a single Boltz-2 prediction with improved diagnostics"""
        job_start_time = time.time()

        print(f"\n{'='*60}")
        print(f"🚀 Job {job_num}/{total_jobs}: {job['name']}")
        print(f"{'='*60}")

        job_name = job['name']
        job_dir = job_name
        os.makedirs(job_dir, exist_ok=True)

        # Generate YAML file
        yaml_content = settings['processor']._generate_yaml(job)
        yaml_file = os.path.join(job_dir, f"{job_name}.yaml")

        with open(yaml_file, 'w') as f:
            f.write(yaml_content)

        print(f"📝 Generated YAML configuration")

        # Build Boltz command - USE PREFLIGHT RESULT
        cmd_parts = [
            "boltz", "predict", yaml_file,
            "--out_dir", job_dir,
            "--recycling_steps", str(settings.get('recycling_steps', 6)),
            "--sampling_steps", str(settings.get('sampling_steps', 200)),
            "--diffusion_samples", str(settings.get('diffusion_samples', 5)),
            "--max_parallel_samples", str(settings.get('max_parallel_samples', 5)),
            "--step_scale", str(settings.get('step_scale', 1.638)),
            "--output_format", settings.get('output_format', 'mmcif'),
            "--max_msa_seqs", str(settings.get('max_msa_seqs', 8192)),
            "--override"
        ]

        # Conditionally add --no_kernels based on preflight test
        if settings.get('use_no_kernels', True):
            cmd_parts.append("--no_kernels")

        # Add MSA server if configured
        if settings.get('use_msa_server', True):
            cmd_parts.extend([
                "--use_msa_server",
                "--msa_server_url", settings.get('msa_server_url', 'https://api.colabfold.com'),
                "--msa_pairing_strategy", settings.get('msa_pairing_strategy', 'greedy')
            ])

        # Add optional flags
        if settings.get('write_full_pae', False):
            cmd_parts.append("--write_full_pae")
        if settings.get('write_full_pde', False):
            cmd_parts.append("--write_full_pde")
        if settings.get('predict_affinity', False):
            cmd_parts.extend([
                "--predict_affinity",
                "--sampling_steps_affinity", str(settings.get('sampling_steps_affinity', 200)),
                "--diffusion_samples_affinity", str(settings.get('diffusion_samples_affinity', 5))
            ])
            if settings.get('affinity_mw_correction', False):
                cmd_parts.append("--affinity_mw_correction")

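        cmd = " ".join(cmd_parts)
        print(f"🔧 Command: {cmd}")

        # Run prediction with stdout/stderr captured. Passing the argument list
        # directly (no shell) keeps job names containing spaces or shell
        # metacharacters from breaking the command.
        try:
            result = subprocess.run(
                cmd_parts,
                capture_output=True,
                text=True,
                timeout=7200  # 2 hour timeout
            )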

            # CRITICAL: Always show stderr if present, even with returncode 0
            if result.stderr and result.stderr.strip():
                print(f"\n📋 Boltz output/warnings:")
                # Show last 50 lines of stderr
                stderr_lines = result.stderr.strip().split('\n')
                for line in stderr_lines[-50:]:
                    if line.strip():
                        print(f"   {line}")

            if result.returncode == 0:
                # Check for output files
                results_dirs = [d for d in os.listdir(job_dir) if d.startswith('boltz_results_')]

                if not results_dirs:
                    print(f"\n❌ No results directory found")
                    print(f"   Expected directory starting with 'boltz_results_' in {job_dir}")
                    print(f"   Actual contents: {os.listdir(job_dir)}")
                    return False

                predictions_dir = os.path.join(job_dir, results_dirs[0])

                # Count structure files
                structure_count = 0
                structure_files = []
                for root, dirs, files in os.walk(predictions_dir):
                    for f in files:
                        if f.endswith(('.cif', '.pdb', '.mmcif')):
                            structure_count += 1
                            structure_files.append(os.path.join(root, f))

                if structure_count == 0:
                    print(f"\n❌ No structure files generated")
                    print(f"   Checked directory: {predictions_dir}")
                    print(f"   Directory contents:")
                    for root, dirs, files in os.walk(predictions_dir):
                        level = root.replace(predictions_dir, '').count(os.sep)
                        indent = ' ' * 2 * level
                        print(f"{indent}{os.path.basename(root)}/")
                        sub_indent = ' ' * 2 * (level + 1)
                        for f in files:
                            print(f"{sub_indent}{f}")

                    # Show full stderr for debugging
                    if result.stderr:
                        print(f"\n🔍 Full Boltz stderr for debugging:")
                        print(result.stderr)

                    return False

                print(f"✅ Generated {structure_count} structure files")
                for sf in structure_files[:5]:  # Show first 5
                    print(f"   📄 {os.path.basename(sf)}")
                if structure_count > 5:
                    print(f"   ... and {structure_count - 5} more files")

                # Create results zip
                zip_filename = f"{job_name}_results.zip"
                create_results_zip(job_dir, zip_filename)
                print(f"📦 Created: {zip_filename}")

                # Upload to Google Drive
                if global_settings.get('drive') and gdrive_folder_id:
                    upload_url = upload_to_gdrive(
                        global_settings['drive'],
                        zip_filename,
                        gdrive_folder_id,
                        job_name
                    )
                    if upload_url:
                        uploaded_files.append({'job': job_name, 'url': upload_url})

                # Cleanup
                try:
                    shutil.rmtree(job_dir)
                    if os.path.exists(zip_filename):
                        os.remove(zip_filename)
                except Exception as e:
                    print(f"⚠️  Cleanup warning: {e}")

                job_duration = time.time() - job_start_time
                print(f"⏱️  Completed in {job_duration:.1f}s")
                return True
            else:
                print(f"\n❌ Prediction failed (return code: {result.returncode})")
                if result.stderr:
                    print(f"\n🔍 Error details:")
                    print(result.stderr[-1000:])  # Last 1000 chars
                return False

        except subprocess.TimeoutExpired:
            print(f"⏰ Prediction timed out (2 hour limit)")
            return False
        except Exception as e:
            print(f"💥 Error: {e}")
            return False

    # Setup Google Drive folder
    gdrive_folder_id = None
    uploaded_files = []

    if global_settings.get('drive'):
        gdrive_folder_id = find_or_create_folder(
            global_settings['drive'],
            global_settings['gdrive_folder_name']
        )

    # Run batch predictions
    batch_jobs = global_settings['batch_jobs']
    start_time = datetime.now()

    print(f"\n{'='*60}")
    print(f"🚀 STARTING BATCH PROCESSING")
    print(f"{'='*60}")
    print(f"Total jobs: {len(batch_jobs)}")
    print(f"Started: {start_time.strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"{'='*60}")

    successful_jobs = 0
    failed_jobs = []

    for i, job in enumerate(batch_jobs, 1):
        success = run_single_prediction(
            job,
            global_settings,
            job_num=i,
            total_jobs=len(batch_jobs)
        )

        if success:
            successful_jobs += 1
        else:
            failed_jobs.append(job['name'])

        # Clear GPU cache between jobs
        try:
            import torch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except Exception:
            pass  # torch unavailable or CUDA error; skip cache clearing

    # Final summary
    end_time = datetime.now()
    duration = end_time - start_time

    print(f"\n{'='*60}")
    print(f"📊 BATCH PROCESSING COMPLETE")
    print(f"{'='*60}")
    print(f"⏱️  Total duration: {duration}")
    print(f"✅ Successful: {successful_jobs}/{len(batch_jobs)} jobs")

    if failed_jobs:
        print(f"❌ Failed jobs: {len(failed_jobs)}")
        for job_name in failed_jobs:
            print(f"  • {job_name}")

    if uploaded_files:
        print(f"\n☁️  Files uploaded to Google Drive: {len(uploaded_files)}")
        print(f"📁 Folder: {global_settings['gdrive_folder_name']}")
        for file_info in uploaded_files[:5]:
            print(f"  • {file_info['job']}")
        if len(uploaded_files) > 5:
            print(f"  ... and {len(uploaded_files) - 5} more files")

    print(f"{'='*60}")
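
# Optional sketch: re-running failed jobs.
# The batch loop above records the names of unsuccessful jobs in `failed_jobs`.
# This is a minimal, hypothetical helper (not part of Boltz or of the original
# notebook) that re-runs only those jobs. It assumes `run_fn` has the same
# signature as run_single_prediction(job, settings, job_num, total_jobs) and
# returns True on success. Retrying mainly helps with transient problems such
# as a timeout or a failed MSA-server request, not with malformed inputs.
def retry_failed_jobs(batch_jobs, failed_names, run_fn, settings, max_attempts=1):
    """Re-run jobs named in `failed_names`; return the names that still fail."""
    wanted = set(failed_names)
    remaining = [job for job in batch_jobs if job['name'] in wanted]
    for _ in range(max_attempts):
        still_failing = []
        for i, job in enumerate(remaining, 1):
            if not run_fn(job, settings, job_num=i, total_jobs=len(remaining)):
                still_failing.append(job)
        remaining = still_failing
        if not remaining:
            break  # everything succeeded; no further attempts needed
    return [job['name'] for job in remaining]

# Example (left commented out so this cell's behaviour is unchanged):
# still_failed = retry_failed_jobs(batch_jobs, failed_jobs,
#                                  run_single_prediction, global_settings)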

🔍 Checking GPU availability...
✅ GPU: NVIDIA A100-SXM4-40GB (42.5 GB)

🔧 Kernel mode: --no_kernels
   (Using CPU fallback - slower but more compatible)
✅ Found existing folder: Boltz2_Predictions

🚀 STARTING BATCH PROCESSING
Total jobs: 37
Started: 2025-10-27 16:11:58

🚀 Job 1/37: hrip1_xp_016502092
📝 Generated YAML configuration
🔧 Command: boltz predict hrip1_xp_016502092/hrip1_xp_016502092.yaml --out_dir hrip1_xp_016502092 --recycling_steps 6 --sampling_steps 200 --diffusion_samples 5 --max_parallel_samples 5 --step_scale 1.638 --output_format mmcif --max_msa_seqs 8192 --override --no_kernels --use_msa_server --msa_server_url https://api.colabfold.com --msa_pairing_strategy greedy --write_full_pae

   0%|          | 0/1 [00:00<?, ?it/s]
     0%|          | 0/300 [elapsed: 00:00 remaining: ?]
   SUBMIT:   0%|          | 0/300 [elapsed: 00:00 remaining: ?]
   COMPLETE:   0%|          | 0/300 [elapsed: 00:00 remaining: ?]
   COMPLETE: 100%|██████████| 300/300 [elapsed: 00:00 remaining: 