<a href="https://colab.research.google.com/github/RishikeshMagar/ColabBoltz/blob/main/ColabBoltz.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Colab Notebook to run Boltz-2.

In [None]:
#@title Install dependencies
%%capture
# %%capture hides this cell's output, so pip's log is not shown.
# boltz is upgraded (-U) without its declared dependencies (--no-deps).
!pip install -q --no-deps boltz -U
# These supporting packages are installed separately, with their own dependencies.
!pip install -q py3Dmol rdkit biopython matplotlib pandas


In [None]:
#@title Input protein and ligand(s)
from google.colab import files
import os
import re
import hashlib
import random
import requests

def add_hash(x, y):
    """Return ``x`` followed by ``_`` and the first five hex digits of SHA-1(``y``).

    ``y`` is encoded with ``str.encode()`` (default encoding) before hashing,
    so the same ``y`` always yields the same five-character suffix.
    """
    digest = hashlib.sha1(y.encode()).hexdigest()
    return "_".join((x, digest[:5]))

# Colab form fields. Each value is a plain string; multi-entry fields use ':'
# as the separator. The code further down in this cell removes all whitespace
# from these values and derives the job directory name from `jobname` and
# `query_sequence`.
#@markdown ### Protein(s)
query_sequence = 'PIAQIHILEGRSDEQKETLIREVSEAISRSLDAPLTSVRVIITEMAKGHFGIGGELASKK'  #@param {type:"string"}
#@markdown Use `:` to separate chains (e.g., "SEQ1:SEQ2")

#@markdown ### Ligands
ligand_input = 'N[C@@H](Cc1ccc(O)cc1)C(=O)O'  #@param {type:"string"}
#@markdown - Colon-separated SMILES strings

ligand_input_ccd = 'SAH'  #@param {type:"string"}
#@markdown - Colon-separated CCD codes

ligand_input_common_name = ''  #@param {type:"string"}
#@markdown - Colon-separated common names (e.g., "aspirin")

#@markdown ### DNA
dna_input = ''  #@param {type:"string"}
#@markdown - Colon-separated DNA sequences

#@markdown ### Jobname
jobname = 'test'  #@param {type:"string"}

# Remove every whitespace character (spaces, tabs, newlines) from the raw form values.
query_sequence, ligand_input, ligand_input_ccd, ligand_input_common_name, dna_input = (
    "".join(raw_value.split())
    for raw_value in (
        query_sequence,
        ligand_input,
        ligand_input_ccd,
        ligand_input_common_name,
        dna_input,
    )
)

# Job name = user text reduced to word characters + a short hash of the sequence.
basejobname = re.sub(r'\W+', '', "".join(jobname.split()))
jobname = add_hash(basejobname, query_sequence)

# If that directory is already taken, append the first free numeric suffix.
if os.path.exists(jobname):
    n = 0
    candidate = f"{jobname}_{n}"
    while os.path.exists(candidate):
        n += 1
        candidate = f"{jobname}_{n}"
    jobname = candidate

# Everything produced for this job is stored under this folder.
os.makedirs(jobname, exist_ok=True)
print(f"✅ Job directory created: {jobname}")


✅ Job directory created: test_18b28_0


In [None]:
#@title Input protein sequence(s), then hit `Runtime` -> `Run all`
from google.colab import files
import os
import re
import hashlib
import random
import requests
from string import ascii_uppercase

def add_hash(x, y):
    """Append a five-character SHA-1 fingerprint of ``y`` to ``x``, joined by ``_``."""
    fingerprint = hashlib.sha1(y.encode()).hexdigest()[:5]
    return "_".join([x, fingerprint])

# User inputs (Colab form fields; every value is a plain string and ':' separates entries)
query_sequence = 'PIAQIHILEGRSDEQKETLIREVSEAISRSLDAPLTSVRVIITEMAKGHFGIGGELASKK'  #@param {type:"string"}
#@markdown  - Use `:` to specify inter-protein chainbreaks for **modeling complexes** (supports homo- and hetero-oligomers). For example **PI...SK:PI...SK** for a homodimer
ligand_input = 'N[C@@H](Cc1ccc(O)cc1)C(=O)O'  #@param {type:"string"}
#@markdown  - Use `:` to specify multiple ligands as SMILES strings
ligand_input_ccd = 'SAH'  #@param {type:"string"}
#@markdown - Use `:` to specify multiple ligands as CCD codes (three-letter codes)
ligand_input_common_name = ''  #@param {type:"string"}
#@markdown - Use `:` to specify multiple ligands with their common name (e.g. Aspirin; SMILES fetched from [PubChem](https://pubchem.ncbi.nlm.nih.gov) API)
dna_input = ''  #@param {type:"string"}
#@markdown - Use `:` to specify multiple DNA sequences
jobname = 'test'  #@param {type:"string"}

# Drop all whitespace from every raw form value before using it.
(query_sequence,
 ligand_input,
 ligand_input_ccd,
 ligand_input_common_name,
 dna_input) = [
    "".join(value.split())
    for value in (query_sequence, ligand_input, ligand_input_ccd,
                  ligand_input_common_name, dna_input)
]

# The job name keeps only word characters and gets a short hash of the sequence appended.
basejobname = re.sub(r'\W+', '', "".join(jobname.split()))
jobname = add_hash(basejobname, query_sequence)


def check(folder):
    """Return True when ``folder`` does not exist on disk yet."""
    if os.path.exists(folder):
        return False
    return True


# Pick the first unused name among jobname, jobname_0, jobname_1, ...
if not check(jobname):
    n = 0
    while True:
        numbered = f"{jobname}_{n}"
        if check(numbered):
            break
        n += 1
    jobname = numbered

# Results for this job are written under this directory.
os.makedirs(jobname, exist_ok=True)

from string import ascii_uppercase

# Break each cleaned input into per-chain entries on ':'; a blank input gives [].
protein_sequences, ligand_sequences, ligand_sequences_ccd, ligand_sequences_common_name, dna_sequences = (
    field.strip().split(':') if field.strip() else []
    for field in (
        query_sequence,
        ligand_input,
        ligand_input_ccd,
        ligand_input_common_name,
        dna_input,
    )
)

def get_smiles(compound_name):
    """Resolve a compound's common name to a SMILES string via the PubChem REST API.

    Two requests are made: the autocomplete endpoint turns ``compound_name``
    into PubChem's top suggested compound name, then the PUG REST property
    endpoint is queried for that suggestion's SMILES.

    Args:
        compound_name: Free-text compound name (e.g. "aspirin").

    Returns:
        The SMILES string, or ``None`` when the name is empty, a request fails
        or times out, a response is not HTTP 200 / not valid JSON, or PubChem
        reports no match.
    """
    from urllib.parse import quote  # function-scope import; only needed here

    def _get_json(url):
        """GET ``url`` and return the decoded JSON object, or None on any failure."""
        try:
            # A timeout keeps a stalled connection from hanging the notebook cell.
            response = requests.get(url, timeout=30)
        except requests.RequestException:
            return None
        if response.status_code != 200:
            return None
        try:
            payload = response.json()
        except ValueError:
            return None
        return payload if isinstance(payload, dict) else None

    if not compound_name:
        return None

    # Percent-encode the name so spaces, '/', '#', '?' etc. cannot corrupt the URL path.
    encoded_query = quote(compound_name, safe="")
    autocomplete_url = f"https://pubchem.ncbi.nlm.nih.gov/rest/autocomplete/compound/{encoded_query}/json?limit=1"
    autocomplete_data = _get_json(autocomplete_url)
    if autocomplete_data is None:
        return None
    if autocomplete_data.get("status", {}).get("code") != 0 or autocomplete_data.get("total", 0) == 0:
        return None

    suggested_compound = autocomplete_data.get("dictionary_terms", {}).get("compound", [])
    if not suggested_compound:
        return None
    suggested_compound_name = quote(suggested_compound[0], safe="")

    smiles_url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/{suggested_compound_name}/property/CanonicalSMILES/JSON"
    smiles_data = _get_json(smiles_url)
    if smiles_data is None:
        return None

    properties = smiles_data.get("PropertyTable", {}).get("Properties", [])
    if not properties:
        return None

    record = properties[0]
    # NOTE(review): newer PubChem responses appear to return this property under
    # the key "ConnectivitySMILES" instead of "CanonicalSMILES" -- confirm against
    # the PUG REST documentation. Both keys are accepted here.
    return record.get("CanonicalSMILES") or record.get("ConnectivitySMILES")

# Resolve each common name to SMILES once (cached per name) and treat every
# resolved occurrence as an additional SMILES ligand.
smiles_cache = {}
for name in ligand_sequences_common_name:
    if not name:
        continue  # e.g. "a::b" yields an empty entry; nothing to look up
    if name not in smiles_cache:
        resolved_smiles = get_smiles(name)
        smiles_cache[name] = resolved_smiles
        if resolved_smiles is not None:
            print(f"Mapped compound {name} to {resolved_smiles}")
        else:
            # Report the miss instead of silently dropping a ligand the user asked for.
            print(f"⚠️ Could not resolve compound name '{name}' to a SMILES string; this ligand is skipped.")

    if smiles_cache[name] is not None:
        ligand_sequences.append(smiles_cache[name])

# Chain IDs are handed out in order: 'A', 'B', ... 'Z'.
chain_labels = iter(ascii_uppercase)

fasta_entries = []          # (header, sequence) pairs, one per chain
csv_entries = []            # (seq_id, sequence) pairs, one per distinct protein sequence
chain_label_to_seq_id = {}  # chain ID -> seq_id of the protein on that chain
seq_to_seq_id = {}          # protein sequence -> seq_id (de-duplication map)
seq_id_counter = 0          # number of distinct protein sequences seen so far

# Proteins: identical sequences share one seq_id and therefore one .a3m path.
for seq in (raw.strip() for raw in protein_sequences):
    if seq:
        chain_label = next(chain_labels)
        if seq not in seq_to_seq_id:
            seq_to_seq_id[seq] = f"{jobname}_{seq_id_counter}"
            # Only first occurrences go into the CSV.
            csv_entries.append((seq_to_seq_id[seq], seq))
            seq_id_counter += 1
        seq_id = seq_to_seq_id[seq]
        chain_label_to_seq_id[chain_label] = seq_id
        msa_path = os.path.join(jobname, f"{seq_id}.a3m")
        fasta_entries.append((f">{chain_label}|protein|{msa_path}", seq))

# Non-protein chains, labelled in this order: SMILES ligands, DNA, CCD ligands.
# CCD codes are upper-cased; the other two kinds are stored as entered.
for entries, lig_type, normalise in (
    (ligand_sequences, 'smiles', str),
    (dna_sequences, 'DNA', str),
    (ligand_sequences_ccd, 'ccd', str.upper),
):
    for entry in entries:
        entry = entry.strip()
        if not entry:
            continue  # skip blanks such as the middle of "a::b"
        chain_label = next(chain_labels)
        fasta_entries.append((f">{chain_label}|{lig_type}", normalise(entry)))

# CSV of the distinct protein sequences (header row "id,sequence"), written
# into the job directory; the original comment describes it as input for ColabFold.
queries_path = os.path.join(jobname, f"{jobname}.csv")
csv_rows = ["id,sequence\n"]
csv_rows.extend(f"{entry_id},{entry_seq}\n" for entry_id, entry_seq in csv_entries)
with open(queries_path, "w") as text_file:
    text_file.writelines(csv_rows)





In [None]:
#@title Process inputs and write YAML
import os
import yaml
from string import ascii_uppercase

# Step 1: Break every cleaned input into per-chain entries on ':' (blank input -> [])
protein_sequences, ligand_sequences, ligand_sequences_ccd, ligand_sequences_common_name, dna_sequences = (
    field.strip().split(':') if field.strip() else []
    for field in (
        query_sequence,
        ligand_input,
        ligand_input_ccd,
        ligand_input_common_name,
        dna_input,
    )
)

# Step 2: Resolve common names to SMILES
def get_smiles(compound_name):
    """Look up a SMILES string for ``compound_name`` through the PubChem REST API.

    The autocomplete endpoint supplies PubChem's top suggested compound name;
    the PUG REST property endpoint is then queried for that suggestion.

    Args:
        compound_name: Free-text compound name (e.g. "aspirin").

    Returns:
        The SMILES string, or ``None`` if the name is empty, a request fails or
        times out, a response is not HTTP 200 / not valid JSON, or nothing matches.
    """
    from urllib.parse import quote  # function-scope import; only needed here

    def _fetch(url):
        """Return the JSON object served at ``url``, or None on any failure."""
        try:
            # A timeout keeps a stalled connection from hanging the notebook cell.
            reply = requests.get(url, timeout=30)
            if reply.status_code != 200:
                return None
            body = reply.json()
        except (requests.RequestException, ValueError):
            return None
        return body if isinstance(body, dict) else None

    if not compound_name:
        return None

    # Percent-encode names so spaces, '/', '#', '?' etc. cannot corrupt the URL path.
    encoded_query = quote(compound_name, safe="")
    autocomplete_url = f"https://pubchem.ncbi.nlm.nih.gov/rest/autocomplete/compound/{encoded_query}/json?limit=1"
    autocomplete_data = _fetch(autocomplete_url)
    if autocomplete_data is None:
        return None
    if autocomplete_data.get("status", {}).get("code") != 0 or autocomplete_data.get("total", 0) == 0:
        return None
    suggested = autocomplete_data.get("dictionary_terms", {}).get("compound", [])
    if not suggested:
        return None

    encoded_suggestion = quote(suggested[0], safe="")
    smiles_url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/{encoded_suggestion}/property/CanonicalSMILES/JSON"
    smiles_data = _fetch(smiles_url)
    if smiles_data is None:
        return None
    props = smiles_data.get("PropertyTable", {}).get("Properties", [])
    if not props:
        return None
    # NOTE(review): newer PubChem responses appear to return this property under
    # the key "ConnectivitySMILES" instead of "CanonicalSMILES" -- confirm against
    # the PUG REST documentation. Both keys are accepted here.
    return props[0].get("CanonicalSMILES") or props[0].get("ConnectivitySMILES")

# Look each common name up once (cached per name); every occurrence that
# resolves is appended to the SMILES ligand list.
smiles_cache = {}
for name in ligand_sequences_common_name:
    if not name:
        continue  # e.g. "a::b" yields an empty entry; nothing to look up
    if name not in smiles_cache:
        smiles_cache[name] = get_smiles(name)
        if smiles_cache[name]:
            print(f"Mapped compound {name} to {smiles_cache[name]}")
        else:
            # Report the miss instead of silently dropping a ligand the user asked for.
            print(f"⚠️ Could not resolve compound name '{name}' to a SMILES string; this ligand is skipped.")
    if smiles_cache[name]:
        ligand_sequences.append(smiles_cache[name])

# Step 3: Assign chain labels and collect entries
chain_labels = iter(ascii_uppercase)
sequences_yaml = []   # one {"protein" | "ligand" | "dna": {...}} mapping per chain, in chain-ID order
csv_entries = []      # (seq_id, sequence) for each distinct protein sequence
seq_id_counter = 0    # number of distinct protein sequences seen so far
seq_to_seq_id = {}    # protein sequence -> seq_id, used to de-duplicate


def _next_chain_id():
    """Return the next unused chain ID; raise ValueError once A-Z are used up."""
    try:
        return next(chain_labels)
    except StopIteration:
        raise ValueError(
            "Too many chains: only 26 single-letter chain IDs (A-Z) are available."
        ) from None


for seq in protein_sequences:
    seq = seq.strip()
    if not seq:
        continue
    chain_id = _next_chain_id()
    if seq not in seq_to_seq_id:
        seq_to_seq_id[seq] = f"{jobname}_{seq_id_counter}"
        seq_id_counter += 1
        csv_entries.append((seq_to_seq_id[seq], seq))
    # Build the MSA path from the sequence's own seq_id. Using the running
    # counter here would point a repeated chain (e.g. a homodimer's second
    # copy) at a different sequence's .a3m file, or at one that never exists.
    seq_id = seq_to_seq_id[seq]
    msa_path = os.path.join(jobname, f"{seq_id}.a3m")
    sequences_yaml.append({
        "protein": {
            "id": chain_id,
            "sequence": seq,
            "msa": msa_path
        }
    })

for smiles in ligand_sequences:
    smiles = smiles.strip()
    if not smiles:
        continue
    chain_id = _next_chain_id()
    sequences_yaml.append({
        "ligand": {
            "id": chain_id,
            "smiles": smiles
        }
    })

for ccd in ligand_sequences_ccd:
    ccd = ccd.strip().upper()  # CCD codes are stored upper-case
    if not ccd:
        continue
    chain_id = _next_chain_id()
    sequences_yaml.append({
        "ligand": {
            "id": chain_id,
            "ccd": ccd
        }
    })

for dna in dna_sequences:
    dna = dna.strip()
    if not dna:
        continue
    chain_id = _next_chain_id()
    sequences_yaml.append({
        "dna": {
            "id": chain_id,
            "sequence": dna
        }
    })

# Step 4: CSV of the distinct protein sequences (header row "id,sequence")
csv_path = os.path.join(jobname, f"{jobname}.csv")
csv_rows = ["id,sequence\n"]
csv_rows.extend(f"{entry_id},{entry_seq}\n" for entry_id, entry_seq in csv_entries)
with open(csv_path, 'w') as csv_handle:
    csv_handle.writelines(csv_rows)

# Step 5: YAML input for Boltz; sort_keys=False keeps each entry's keys in insertion order
yaml_path = os.path.join(jobname, "input.yaml")
boltz_input = {"sequences": sequences_yaml}
with open(yaml_path, 'w') as yaml_handle:
    yaml.dump(boltz_input, yaml_handle, sort_keys=False)

for report_line in (
    f"✅ YAML written to: {yaml_path}",
    f"✅ CSV written to: {csv_path}",
):
    print(report_line)


✅ YAML written to: test_18b28_0/input.yaml
✅ CSV written to: test_18b28_0/test_18b28_0.csv


In [None]:
#@title Boltz Runtime Settings
#@markdown These will be passed as CLI flags to `boltz predict`

# Form values are plain module-level variables that the run cell reads.
# The run cell maps diffusion_steps -> --sampling_steps and
# num_samples -> --diffusion_samples.
diffusion_steps = 200  #@param {type:"integer"}
num_samples = 1  #@param {type:"integer"}
# NOTE(review): batch_size is defined here, but the run cell in this notebook
# never adds it to the command line.
batch_size = 1  #@param {type:"integer"}
random_seed = 42  #@param {type:"integer"}
use_msa_server = True  #@param {type:"boolean"}
device_override = ""  #@param {type:"string"}
#@markdown Leave empty to auto-detect GPU

print("✅ Runtime flags captured.")



✅ Runtime flags captured.


In [None]:
#@title Run Boltz Prediction
import shlex

# Required arguments
yaml_path = f"{jobname}/input.yaml"
output_dir = jobname

# Build the command
cmd = ["boltz", "predict", "--out_dir", output_dir]

# Optional flags based on user input
if diffusion_steps:
    cmd += ["--sampling_steps", str(diffusion_steps)]
if num_samples:
    cmd += ["--diffusion_samples", str(num_samples)]

if random_seed:
    cmd += ["--seed", str(random_seed)]
if device_override.strip():
    cmd += ["--device", device_override.strip()]
if use_msa_server:
    cmd += ["--use_msa_server"]

# Append the YAML input path (positional argument)
cmd += [yaml_path]

# Print the full command for verification
print("✅ Running command:\n", " ".join(shlex.quote(arg) for arg in cmd))

# Execute
!{" ".join(shlex.quote(arg) for arg in cmd)}


✅ Running command:
 boltz predict --out_dir test_18b28_0 --sampling_steps 200 --diffusion_samples 1 --seed 42 --use_msa_server test_18b28_0/input.yaml
Seed set to 42
Checking input data.
Found 0 existing processed inputs, skipping them.
Processing 1 inputs with 1 threads.
  0% 0/1 [00:00<?, ?it/s]Generating MSA for test_18b28_0/input.yaml with 1 protein entities.

  0% 0/150 [00:00<?, ?it/s][A
SUBMIT:   0% 0/150 [00:00<?, ?it/s][A
COMPLETE:   0% 0/150 [00:00<?, ?it/s][A
COMPLETE: 100% 150/150 [00:02<00:00, 70.38it/s] 
100% 1/1 [00:02<00:00,  2.40s/it]
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Running structure prediction for 1 input.
/usr/local/lib/python3.11/dist-packages/pytorch_lightning/utilities/migration/utils.py:56: The loaded checkpoint was produced with Lightning v2.5.0.post0, which is newer than your current Lightning version: v2.5.0
You are using a CUDA