<a href="https://colab.research.google.com/github/ChenSTeam/Colab_Computation/blob/main/01_Colab_Boltz.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Boltz using Colab notebook
Chen Yulin 2025-08-20

[Github](https://github.com/ChenSTeam/Colab_Computation)

Multiple sequence alignments (MSAs) are generated with [ColabFold](https://github.com/sokrypton/ColabFold).

Structure and affinity predictions are performed by [Boltz](https://github.com/jwohlwend/boltz).

Notebook references:

1. https://github.com/kimjc95/computational-chemistry
2. https://colab.research.google.com/github/sokrypton/ColabFold/blob/main/Boltz1.ipynb

&nbsp;


**Structure + affinity prediction for a complex of 1 protein (452 aa) and 1 small molecule (ATP) takes ~10 min on a T4 GPU.**

**(4 min for MSA, 4 min for structure and 2 min for affinity)**


In [1]:
#@title Install dependencies
#@markdown Run this cell to install applications and packages.
%%time

import subprocess
import yaml
import os

import torch

if torch.cuda.is_available():
    device = 'gpu'
else:
    device = 'cpu'

if not os.path.isfile("COLABFOLD_READY"):
  print("installing colabfold...")
  os.system("pip install -q --no-warn-conflicts 'colabfold[alphafold-minus-jax] @ git+https://github.com/sokrypton/ColabFold'")
  if os.environ.get('TPU_NAME', False) != False:
    os.system("pip uninstall -y jax jaxlib")
    os.system("pip install --no-warn-conflicts --upgrade dm-haiku==0.0.10 'jax[cuda12_pip]'==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html")
  os.system("ln -s /usr/local/lib/python3.*/dist-packages/colabfold colabfold")
  os.system("ln -s /usr/local/lib/python3.*/dist-packages/alphafold alphafold")
  # hack to fix TF crash
  os.system("rm -f /usr/local/lib/python3.*/dist-packages/tensorflow/core/kernels/libtfkernel_sobol_op.so")
  os.system("touch COLABFOLD_READY")

if not os.path.isfile("BOLTZ_READY"):
  print("installing boltz...")
  if device == 'gpu':
    os.system("pip install -q --no-warn-conflicts boltz[cuda] -U")
  else:
    os.system("pip install -q --no-warn-conflicts boltz -U")
  os.system("touch BOLTZ_READY")

installing colabfold...
installing boltz...
CPU times: user 1.73 s, sys: 389 ms, total: 2.12 s
Wall time: 1min 56s


In [2]:
#@title Enter inputs
#@markdown Type the job name (without blanks) in the box below.
job_name = "test1" #@param {type:"string"}
#@markdown Run this cell, then enter the molecule data using the interactive widgets below.

#@markdown For small-molecule ligands, you can enter a CCD ID (Chemical Component Dictionary code), which can be looked up on the [PDBeChem website](https://www.ebi.ac.uk/pdbe-srv/pdbechem/). Entered CCD IDs are checked against the RCSB ligand files.

import ipywidgets as widgets
from IPython.display import display, HTML
import requests
from rdkit import Chem, RDLogger
from rdkit.Chem import Draw, AllChem
from datetime import datetime
import re

# Unique job ID = job name + launch timestamp. It is used as a directory name,
# a file-name prefix and inside `!` shell commands in later cells, so every run
# of characters other than letters, digits, '_', '-' and '.' is replaced by '_'.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
safe_job_name = re.sub(r'[^A-Za-z0-9_.-]+', '_', job_name.strip()) or 'job'
job_id = f"{safe_job_name}_{timestamp}"

def validate_input(text, input_type)->bool:
    """
    Validate the molecule description entered for one chain.

    Parameters
    ----------
    text : str
        Sequence (protein / DNA / RNA), SMILES string, or CCD code typed by the user.
    input_type : str
        One of 'protein', 'dna', 'rna', 'smiles', 'ccd'.

    Returns
    -------
    bool
        True if `text` is valid for `input_type`, otherwise False (also False for
        an unrecognised `input_type`). Polymer sequences are checked
        case-insensitively.
    """
    if input_type == 'protein':
        # Every letter must be one of the 20 canonical amino acids. fullmatch
        # (unlike match with '$') also rejects a trailing newline.
        return re.fullmatch(r'[AC-IK-NP-TVWY]+', text.upper()) is not None
    elif input_type == 'dna':
        # Every letter must be A, C, G or T.
        return re.fullmatch(r'[ACGT]+', text.upper()) is not None
    elif input_type == 'rna':
        # Every letter must be A, C, G or U.
        return re.fullmatch(r'[ACGU]+', text.upper()) is not None
    elif input_type == 'smiles':
        # Use RDKit to parse and sanitize the SMILES string.
        RDLogger.DisableLog('rdApp.*')
        try:
            mol = Chem.MolFromSmiles(text, sanitize=True)
        except Exception:
            return False
        # An empty string parses to a molecule with zero atoms rather than None,
        # so require at least one atom.
        return mol is not None and mol.GetNumAtoms() > 0
    elif input_type == 'ccd':
        # The code is interpolated into a URL, so only accept alphanumerics
        # (this also avoids a network request for empty input).
        if re.fullmatch(r'[A-Za-z0-9]+', text) is None:
            return False
        # Ask the RCSB ligand file server whether the ideal-coordinates CIF exists.
        url = f"https://files.rcsb.org/ligands/download/{text.upper()}_ideal.cif"
        try:
            response = requests.head(url, timeout=5)
            return response.status_code == 200
        except requests.exceptions.RequestException:
            return False
    return False

class modify_entries():
    """
    Interactive editor for the molecules (chains) of the complex.

    `container` is expected to be a VBox whose first child is a title widget,
    whose last child is the row holding the '+' / 'confirm' buttons, and whose
    children in between are molecule entries, each a
    VBox([HBox([chain label, type dropdown, text box, cyclic checkbox, '-' button]),
          HTML error message]).
    """
    def __init__(self, container):
        self.container = container
        # List of dicts (chain, type, sequence, cyclic, msa); filled by
        # update_seq_data() when the confirm button is pressed.
        self.seq_data = []

    def remove_seq_entry(self, b):
        """
        Callback for an entry's '-' button: remove that entry and relabel the
        following entries so chain IDs stay consecutive from 'A'.

        `b` is the clicked Button; it is matched against the last widget of each
        entry's line.
        """
        for i in range(1, len(self.container.children)-1):
            # when the minus button is hit
            if self.container.children[i].children[0].children[-1] == b:
                newList = []
                # The entry at container index j moves to index j-1, whose chain
                # ID is chr(ord('A') + j - 2) (container index 1 is chain 'A').
                for j in range(i+1, len(self.container.children)-1):
                    show_chain = widgets.Label(value= 'chain '+str(chr(ord('A')+j-2)))
                    newline = widgets.HBox([show_chain]+list(self.container.children[j].children[0].children[1:]))
                    newEntry = widgets.VBox([newline]+list(self.container.children[j].children[1:]))
                    newList.append(newEntry)

                self.container.children = list(self.container.children[:i]) + newList + [self.container.children[-1]]
                break

    def add_seq_entry(self, b):
        """
        Callback for the '+' button: append a new, empty molecule entry just
        before the button row. Its chain ID follows the existing entries.
        """

        # Chain ID: container index 1 is chain 'A', so the new entry (inserted at
        # index len-1) gets chr(ord('A') + len - 2).
        show_chain = widgets.Label(value= 'chain '+str(chr(ord('A')+len(self.container.children)-2)))

        # Molecule type
        select_type = widgets.Dropdown(
            options=['protein', 'dna', 'rna', 'smiles', 'ccd'],
            description=' is ',
            style={'description_width': 'initial'},
            layout=widgets.Layout(width='120px'))

        # Molecule info string
        enter_sequence = widgets.Text(
            description=' described as :',
            placeholder='MAKEY... or CC1=CC=CC=C1 or ATP',
            style={'description_width': 'initial'},
            layout=widgets.Layout(width='600px'))

        # check for cyclic polymers
        cyclic = widgets.Checkbox(
            value=False,
            description=' cyclic? ',
            style={'description_width': 'initial'},
            layout=widgets.Layout(width='100px'))

        # minus button
        remove_btn = widgets.Button(description='-', layout=widgets.Layout(width='30px'))
        remove_btn.on_click(self.remove_seq_entry)

        # Error message; update_seq_data() skips entries whose message is non-empty.
        message = widgets.HTML(value='', layout=widgets.Layout(width='600px', padding='5px'))

        def validate_string(change):
            """
            Re-validate this entry whenever its type, text or cyclic flag changes,
            showing a red error message when the input is invalid.
            """

            if select_type.value in ['smiles', 'ccd'] and cyclic.value:
                message.value = f"<span style='color: red;'>Only polymers (protein, DNA, and RNA) can be specified as cyclic!</span>"
            elif validate_input(enter_sequence.value, select_type.value):
                message.value = ""
            elif select_type.value in ['protein', 'dna', 'rna']:
                message.value = f"<span style='color: red;'>Enter the valid {select_type.value} sequence.</span>"
            else:
                message.value = f"<span style='color: red;'>Enter the valid {select_type.value} string.</span>"

        enter_sequence.observe(validate_string, names='value')
        select_type.observe(validate_string, names='value')
        cyclic.observe(validate_string, names='value')

        line = widgets.HBox([show_chain, select_type, enter_sequence, cyclic, remove_btn])
        entry = widgets.VBox([line, message])

        self.container.children = list(self.container.children[:-1]) + [entry, self.container.children[-1]]

    def update_seq_data(self, b):
        """
        Callback for the confirm button: rebuild `self.seq_data` from the
        current entries.

        Entries showing an error message, or with empty text, are skipped. The
        text is stripped of surrounding whitespace, and everything except SMILES
        is upper-cased. Validation accepts lowercase sequences and CCD codes, but
        downstream tools expect uppercase (lowercase letters mean insertions in
        an a3m query, and CCD codes are uppercase). SMILES case is significant
        (aromaticity), so it is kept as typed.
        """

        def msa_file(mole_type, chain):
            # Path where the MSA step is expected to write this chain's
            # alignment; only protein chains get one.
            if mole_type == 'protein':
                return f'{job_id}/{chain}.a3m'
            return None

        self.seq_data.clear()
        # Skip the title (first child) and the button row (last child).
        for entry in self.container.children[1:-1]:
            line, message = entry.children[0], entry.children[1]
            # Filter invalid lines
            if message.value != '':
                continue
            text = line.children[2].value.strip()
            # Filter empty lines
            if text == '':
                continue

            chain = line.children[0].value[-1]
            mol_type = line.children[1].value
            if mol_type != 'smiles':
                text = text.upper()

            self.seq_data.append({'chain': chain,
                                  'type': mol_type,
                                  'sequence': text,
                                  'cyclic': line.children[3].value,
                                  'msa': msa_file(mol_type, chain)})

# Assemble the molecule editor: a VBox holding the title and the button row.
# Molecule entries are inserted between them by modify_entries.add_seq_entry.
title = widgets.HTML("<h4>Click the plus button to add molecules, and minus button to remove ones. Click the confirm button after entering all entries.</h4>")
add_button = widgets.Button(description='+', layout=widgets.Layout(width='30px'))
confirm_button = widgets.Button(description='confirm', layout=widgets.Layout(width='100px'))
buttons = widgets.HBox([add_button, confirm_button])
seq_container = widgets.VBox([title, buttons])

# Later cells read the confirmed entries from add_new_seq.seq_data.
add_new_seq = modify_entries(seq_container)

# '+' appends a new entry; 'confirm' snapshots the valid entries into seq_data.
add_button.on_click(add_new_seq.add_seq_entry)
confirm_button.on_click(add_new_seq.update_seq_data)

display(seq_container)

VBox(children=(HTML(value='<h4>Click the plus button to add molecules, and minus button to remove ones. Click …

In [3]:
#@title Ligand selection for affinity (optional)
#@markdown If affinity prediction is required, run this cell and select the ligand.

import io

# Confirmed ligand entries (SMILES or CCD). Their chain IDs become the dropdown
# choices; 'None' means no affinity prediction.
ligands = [entry for entry in add_new_seq.seq_data if entry['type'] in ('smiles', 'ccd')]

lig_select = widgets.Dropdown(
    options=['None', *(ligand['chain'] for ligand in ligands)],
    value='None',
    description='Ligand chain ID:',
    layout=widgets.Layout(width='300px'),
    style={'description_width': 'initial'},
)

# Empty placeholder; the lig_select observer fills it with a 2D depiction.
lig_image = widgets.Image(format='png', value=b'', height=300, width=400)

def update_sdf(change): # callback function to interactively update viewer
    """
    Observer for `lig_select`: draw a 2D depiction of the selected ligand chain
    into `lig_image`.

    SMILES entries are parsed directly. CCD entries are read from
    `<CCD>_ideal.sdf`, which is downloaded from RCSB into the working directory
    on first use. Choosing 'None', or failing to build a molecule, clears the
    image.

    `change` is the change notification passed by `observe`; it is not used.
    """
    ligand = lig_select.value
    if ligand == 'None':
        # Clear any previous depiction so the viewer matches the selection.
        lig_image.value = b''
        return

    entry = next((s for s in add_new_seq.seq_data if s['chain'] == ligand), None)
    mol = None
    RDLogger.DisableLog('rdApp.*')
    if entry is None:
        print(f"Warning: chain {ligand} is not among the confirmed entries; re-run this cell.")
    elif entry['type'] == 'smiles':
        mol = Chem.MolFromSmiles(entry['sequence'], sanitize=True)
    elif entry['type'] == 'ccd':
        ccd = entry['sequence'].upper()
        sdf_file = f'{ccd}_ideal.sdf'
        if not os.path.exists(sdf_file):
            # Pass an argument list (no shell) so the user-supplied code is never
            # interpreted by a shell. Without -O, wget creates no file on an HTTP
            # error, so the existence check below detects a failed download.
            subprocess.run(['wget', '-q', f'https://files.rcsb.org/ligands/download/{sdf_file}'])
        if os.path.exists(sdf_file):
            # First record of the SDF, or None if the file holds no parsable molecule.
            mol = next(iter(Chem.rdmolfiles.SDMolSupplier(sdf_file)), None)

    if mol is None:
        print(f"Warning: could not build a molecule for ligand chain {ligand}.")
        lig_image.value = b''
        return

    AllChem.Compute2DCoords(mol)
    img = Draw.MolToImage(mol, size=(400, 300))

    with io.BytesIO() as output:
        img.save(output, format="PNG")
        lig_image.value = output.getvalue()

# Redraw the ligand depiction whenever the selected chain ID changes.
lig_select.observe(update_sdf, names='value')

print("Select the ligand's chain ID and check the structure. Select None to cancel.")

# The chosen chain ID is read from lig_select.value when the YAML input is built.
display(lig_select, lig_image)

Select the ligand's chain ID and check the structure. Select None to cancel.


Dropdown(description='Ligand chain ID:', layout=Layout(width='300px'), options=('None', 'B'), style=Descriptio…

Image(value=b'', height='300', width='400')

In [4]:
#@title Create YAML file from the input data
#@markdown Once you are confident with all the inputs above, run this cell to generate the input file.

import yaml

# Assemble the Boltz input (YAML `version: 1`) from the entries confirmed in the
# input cell.
if not add_new_seq.seq_data:
    raise ValueError("No molecules confirmed: add entries in the input cell and press 'confirm' first.")

data = {'version': 1, 'sequences':[]}

for s in add_new_seq.seq_data:
    if s['type'] in ['protein', 'dna', 'rna']:
        seq = {s['type']:{'id':s['chain'], 'sequence':s['sequence']}}
        # Protein chains reference the .a3m file the MSA step is expected to write.
        if s['type'] == 'protein' and s['msa']:
            seq['protein']['msa'] = s['msa']
        seq[s['type']]['cyclic'] = s['cyclic']
    elif s['type'] == 'smiles':
        seq = {'ligand':{'id':s['chain'], 'smiles':s['sequence']}}
    elif s['type'] == 'ccd':
        seq = {'ligand':{'id':s['chain'], 'ccd':s['sequence']}}
    else:
        raise ValueError(f"Unsupported molecule type: {s['type']}")
    data['sequences'].append(seq)

# The ligand-selection cell is optional, so `lig_select` may not exist in this
# kernel. Its value may also be stale if the entries were re-confirmed afterwards.
affinity_selector = globals().get('lig_select')
if affinity_selector is not None and affinity_selector.value != 'None':
    ligand_chains = {s['chain'] for s in add_new_seq.seq_data if s['type'] in ('smiles', 'ccd')}
    if affinity_selector.value in ligand_chains:
        data['properties'] = [{'affinity':{'binder':affinity_selector.value}}]
    else:
        print(f"Warning: ligand chain {affinity_selector.value} is not among the confirmed ligands; "
              "affinity will not be predicted. Re-run the ligand selection cell.")

with open(f'{job_id}.yaml', 'w') as f:
    yaml.dump(data, f, default_flow_style=False, sort_keys=False)
    print('Done!')
    print(f'Save yaml file as {job_id}.yaml!')
    print(f'Save yaml file as {job_id}.yaml!')

Done!
Save yaml file as test1_20250820_074912.yaml!


In [5]:
#@title Run MSA and Boltz prediction

#@markdown Run this cell for prediction.

import os
# Output directory for both the ColabFold MSAs and the Boltz results.
os.makedirs(job_id, exist_ok=True)

# ColabFold batch input: a CSV with header "id,sequence". Only protein chains
# are written, since only they were given an msa path in the YAML.
# NOTE(review): with no protein chains the CSV holds only the header line;
# colabfold_batch behaviour for that case has not been verified.
queries_path = f"{job_id}.csv"
csv_entries = []

for s in add_new_seq.seq_data:
    if s['type'] == 'protein':
      csv_entries.append((s['chain'],s['sequence']))

with open(queries_path, "w") as text_file:
    text_file.write("id,sequence\n")
    for seq_id, seq in csv_entries:
        text_file.write(f"{seq_id},{seq}\n")

# Step 1: MSA search only; outputs go to {job_id}/, where the YAML expects {chain}.a3m.
!colabfold_batch "{job_id}.csv" "{job_id}" --msa-only
# Step 2: Boltz prediction from the YAML; results land in {job_id}/boltz_results_{job_id}/.
!boltz predict --out_dir "{job_id}" --accelerator "{device}" "{job_id}.yaml"

2025-08-20 07:49:50,332 Running colabfold 1.5.5 (c3e8ab010a2d2c5f8f18f653817c6bfbf58118e8)

limited shared resource only capable of processing a few thousand MSAs per day. Please
submit jobs only from a single IP address. We reserve the right to limit access to the
server case-by-case when usage exceeds fair use. If you require more MSAs: You can 
precompute all MSAs with `colabfold_search` or host your own API and pass it to `--host-url`

2025-08-20 07:49:53,479 Running on GPU
2025-08-20 07:49:53,940 Found 4 citations for tools or databases
2025-08-20 07:49:53,940 Query 1/1: A (length 452)
COMPLETE: 100% 150/150 [00:01<00:00, 96.85it/s] 
2025-08-20 07:49:55,594 Saved test1_20250820_074912/A.pickle
2025-08-20 07:49:59,489 Done
Downloading the CCD data to /root/.boltz/mols.tar. This may take a bit of time. You may change the cache directory with the --cache flag.
Extracting the CCD data to /root/.boltz/mols. This may take a bit of time. You may change the cache directory with the --cach

In [13]:
#@title Download the result
#@markdown Run this cell for the result downloading as zip file.

#@markdown The affinity is reported as log10(IC50), with IC50 in μM.

#@markdown - IC50 of $10^{-9}$ M $\longrightarrow$ Boltz outputs $-3$ (strong binder)
#@markdown - IC50 of $10^{-6}$ M $\longrightarrow$ Boltz outputs $0$ (moderate binder)
#@markdown - IC50 of $10^{-4}$ M $\longrightarrow$ Boltz outputs $2$ (weak binder / decoy)
# Import necessary modules
import os
import zipfile
from google.colab import files
import glob

# Name of the zip file
zip_filename = f"results_{job_id}.zip"

# Boltz writes into {job_id}/boltz_results_{job_id}/, with per-input predictions
# under predictions/{job_id}/.
boltz_dir = os.path.join(job_id, f'boltz_results_{job_id}')
prediction_dir = os.path.join(boltz_dir, 'predictions', job_id)

# Inputs and MSA artefacts, then structures, run hyper-parameters and JSON
# outputs (confidence / affinity).
input_files = (glob.glob(os.path.join(job_id, '*_coverage.png'))
               + glob.glob(os.path.join(job_id, '*.a3m'))
               + glob.glob(f'{job_id}.yaml')
               + glob.glob(f'{job_id}.csv'))
cif_files = glob.glob(os.path.join(prediction_dir, '*.cif'))
json_files = glob.glob(os.path.join(prediction_dir, '*.json'))

hparams_file = os.path.join(boltz_dir, 'lightning_logs', 'version_0', 'hparams.yaml')
hparams_files = [hparams_file] if os.path.exists(hparams_file) else []

if not cif_files:
    print(f"Warning: no predicted structures (*.cif) found in {prediction_dir}; check the prediction cell output.")
if not hparams_files:
    print(f"Warning: {hparams_file} not found.")

# Store every file at the archive root (directory structure is not preserved).
# ZIP_DEFLATED compresses; the zipfile default (ZIP_STORED) does not.
with zipfile.ZipFile(zip_filename, 'w', compression=zipfile.ZIP_DEFLATED) as zipf:
    for file in input_files + cif_files + hparams_files + json_files:
        zipf.write(file, arcname=os.path.basename(file))

# Download the zip file
files.download(zip_filename)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>