![logo](https://raw.githubusercontent.com/x-lab-3D/swiftmhc/main/logo.png)


## **SwiftMHC**
SwiftMHC predicts peptide-MHC structure and binding affinity at the same time. It currently works for HLA-A*02:01 9-mers only. For more details, check out the [code repo](https://github.com/X-lab-3D/swiftmhc) and [paper](https://doi.org/10.1101/2025.01.20.633893).



---


### To run
Either run each cell sequentially, or click on `Runtime -> Run All` after choosing the desired settings.

In [None]:
#@title Install dependencies
import os


# Note: OpenFold requires conda, but conda can safely be omitted for SwiftMHC inference.

def _run_shell(command: str) -> bool:
    """Run `command` through the shell and report whether it exited with status 0.

    A warning naming the command and its status is printed on failure, so a
    failing step is visible in the cell output instead of being dropped.
    """
    status = os.system(command)
    if status != 0:
        print(f"Warning: command exited with status {status}: {command}")
    return status == 0


def _pip_install(label: str, requirement: str, marker: str) -> None:
    """Install `requirement` with pip once, recording success in `marker`.

    The marker file is created only after pip succeeds, so a failed install is
    retried on the next run instead of being skipped forever.

    Raises:
        RuntimeError: if pip exits with a non-zero status.
    """
    if os.path.isfile(marker):
        return
    print(f"Installing {label}...")
    if not _run_shell(f"pip install {requirement}"):
        raise RuntimeError(f"'pip install {requirement}' failed; {label} is not installed.")
    _run_shell(f"touch {marker}")


if not os.path.isfile("OPENFOLD_READY"):
    print("Installing OpenFold...")
    # `git clone` refuses to write into an existing non-empty directory, so a
    # checkout left behind by an earlier attempt is reused rather than re-cloned.
    if not os.path.isdir("openfold") and not _run_shell(
        "git clone https://github.com/aqlaboratory/openfold"
    ):
        raise RuntimeError("Failed to clone the OpenFold repository.")
    # Best effort: a non-zero status from the helper script is reported but not
    # treated as fatal, and the marker is still written.
    # NOTE(review): parts of this script may fail on runtimes without conda —
    # confirm whether a failure here should abort the installation.
    _run_shell("cd openfold && scripts/install_third_party_dependencies.sh")
    _run_shell("touch OPENFOLD_READY")

_pip_install("swiftMHC", "git+https://github.com/X-lab-3D/swiftmhc.git", "SWIFTMHC_READY")
_pip_install("py3Dmol", "py3Dmol", "PY3DMOL_READY")

print("Installation completed!")

Installing OpenFold...


In [None]:
#@title Download SwiftMHC model and reference MHC structure

import os
import sys

ref_structure_path = "HLA-A0201-from-3MRD.hdf5"
model_path = "8k-trained-model.pth"

# Local file name -> download URL. wget stores each file under the last path
# component of its URL, which matches the local name used as the key.
_DOWNLOADS = {
    ref_structure_path: "https://github.com/X-lab-3D/swiftmhc/raw/refs/heads/main/data/HLA-A0201-from-3MRD.hdf5",
    model_path: "https://github.com/X-lab-3D/swiftmhc/raw/refs/heads/main/trained-models/8k-trained-model.pth",
}

for target, url in _DOWNLOADS.items():
    # Files fetched by an earlier run are kept as they are.
    if os.path.isfile(target):
        continue

    status = os.system(f"wget {url}")
    if status != 0 or not os.path.isfile(target):
        # Remove a truncated file so that the next run downloads it again
        # instead of mistaking it for a complete one.
        if os.path.isfile(target):
            os.remove(target)
        raise RuntimeError(f"Failed to download {target} from {url} (wget status {status}).")

print("Downloading completed!")

In [None]:
#@title Settings

import os
import re
import hashlib

#@markdown - Specify the 9-mer peptide sequence
peptide_sequence = "RLGPGKISV"  # @param {type:"string"}

#@markdown - Specify the job name
jobname = "swiftmhc"  # @param {type:"string"}

# Fixed settings, not exposed in the form. They are passed to swiftmhc_predict
# as --num-builders and --batch-size when the prediction is run.
number_builders = 1
batch_size = 1

# ----------------------------
# Peptide sequence validation
# ----------------------------

# Normalize before validating: strip every whitespace character (including
# whitespace inside the sequence) and convert to upper case, so that input such
# as " rlgp gkisv " is accepted.
peptide_sequence = re.sub(r"\s+", "", peptide_sequence).upper()

def validate_peptide(seq: str, expected_len: int = 9) -> None:
    """Check that `seq` is a peptide of the expected length in standard residues.

    Args:
        seq: Peptide sequence written in one-letter amino-acid codes.
        expected_len: Number of residues the peptide must have.

    Raises:
        ValueError: If the length of `seq` differs from `expected_len`, or if
            `seq` contains anything other than the 20 standard upper-case
            one-letter codes.
    """
    standard_residues = frozenset("ACDEFGHIKLMNPQRSTVWY")

    # The length is checked first, so a sequence that is both too short and
    # malformed is reported as a length problem.
    if len(seq) != expected_len:
        message = (
            f"Invalid peptide length: {len(seq)}. "
            f"Expected a {expected_len}-mer peptide. Got: '{seq}'"
        )
        raise ValueError(message)

    # An empty sequence never counts as valid, whatever `expected_len` is.
    only_standard = bool(seq) and set(seq) <= standard_residues
    if only_standard:
        return

    raise ValueError(
        f"Peptide contains invalid characters: '{seq}'. "
        "Expected one-letter amino-acid codes (ACDEFGHIKLMNPQRSTVWY)."
    )

# Fail fast: reject a malformed peptide before any job directory is created.
validate_peptide(peptide_sequence)

# ----------------------------
# Job name handling
# ----------------------------

def add_hash(base: str, key: str, n_chars: int = 5) -> str:
    """Return `base` followed by an underscore and a short hash of `key`.

    Args:
        base: Text to which the suffix is appended.
        key: Text whose UTF-8 encoding is hashed with SHA-1.
        n_chars: Number of leading hexadecimal digits of the hash to keep.

    Returns:
        `base`, an underscore, and the first `n_chars` hex digits of the hash.
    """
    hasher = hashlib.sha1()
    hasher.update(key.encode("utf-8"))
    suffix = hasher.hexdigest()[:n_chars]
    return base + "_" + suffix

# Sanitize the job name: drop all whitespace first, then every remaining
# character that is not a word character.
basejobname = re.sub(r"\W+", "", "".join(jobname.split()))

# A blank or all-punctuation name sanitizes to "", so substitute a default.
basejobname = basejobname or "job"

# Tie the job name to the peptide by appending a short hash of the sequence.
jobname = add_hash(basejobname, peptide_sequence)

# Append _1, _2, ... until the name no longer collides with an existing path.
original_jobname = jobname
n = 1
while os.path.exists(jobname):
    jobname = "_".join((original_jobname, str(n)))
    n += 1

# ----------------------------
# Create the job directory and write the input table
# ----------------------------

# exist_ok defaults to False: an already existing directory raises
# FileExistsError instead of being silently reused.
os.makedirs(jobname)

# Two-line CSV: a header row and a single allele/peptide record.
input_path = os.path.join(jobname, "input.csv")
with open(input_path, "w", encoding="utf-8") as text_file:
    text_file.writelines(
        [
            "allele,peptide\n",
            f"HLA-A*02:01,{peptide_sequence}\n",
        ]
    )

In [None]:
#@title Run prediction

import os
import csv
import subprocess
import zipfile


# ----------------------------
# Remove a stale results.csv
# ----------------------------
affinity_file = os.path.join(jobname, "results.csv")

# A results file can be left over when this cell has already been run for the
# same job directory; delete it so that only fresh output is read back below.
has_stale_results = os.path.exists(affinity_file)
if has_stale_results:
    os.remove(affinity_file)


# ----------------------------
# Build and run SwiftMHC command
# ----------------------------

# Positional arguments: trained model, input table, reference structure and
# the job directory (results.csv is read back from that directory below).
# The model and reference paths reuse the variables set in the download cell,
# so the file names are defined in exactly one place.
command = [
    "swiftmhc_predict",
    "--num-builders", str(number_builders),
    "--batch-size", str(batch_size),
    model_path,
    input_path,
    ref_structure_path,
    jobname,
]

print("Running SwiftMHC...")

try:
    # Output is captured rather than streamed; it is only shown on failure.
    result = subprocess.run(
        command,
        check=True,
        capture_output=True,
        text=True,
    )

except FileNotFoundError as e:
    # Raised when the executable itself cannot be found, which means that
    # swiftmhc is not installed in this runtime.
    raise RuntimeError(
        "swiftmhc_predict was not found. Run the 'Install dependencies' cell first."
    ) from e

except subprocess.CalledProcessError as e:
    print("Error: SwiftMHC failed to run.\n")
    if e.stdout:
        print("=== STDOUT ===")
        print(e.stdout)
        print()
    if e.stderr:
        print("=== STDERR ===")
        print(e.stderr)
        print()
    raise RuntimeError(
        f"swiftmhc_predict exited with return code {e.returncode}"
    ) from e

else:
    print("SwiftMHC completed successfully!\n")


# ----------------------------
# Report the predicted affinity from results.csv
# ----------------------------

if os.path.isfile(affinity_file):
    # Any problem while reading or parsing is reported as a message instead of
    # aborting the cell.
    try:
        with open(affinity_file, newline="") as f:
            # Only the first data row is inspected.
            first_row = next(csv.DictReader(f), None)
            if first_row is None:
                print("results.csv is empty.")
            elif "affinity" in first_row:
                affinity_value = float(first_row["affinity"])
                print(f"Predicted affinity: {affinity_value}")
            else:
                print("Column 'affinity' not found in results.csv.")
                print("Columns available:", list(first_row.keys()))
    except Exception as e:
        print(f"Failed to read affinity value: {e}")
else:
    print(f"\nNo results file found at: {affinity_file}")

In [None]:
#@title Display 3D structure {run: "auto"}
import os
import py3Dmol

#@markdown - Coloring mode for the peptide/MHC complex
color = "chain"  #@param ["chain", "rainbow"]

#@markdown - Show peptide Cα atoms as spheres
show_carbon_alpha = True  #@param {type:"boolean"}

#@markdown - Show peptide side chains as sticks
show_sidechains = True  #@param {type:"boolean"}


# Locate the predicted structure inside the job directory and stop early with
# a clear error when it is missing.
# NOTE(review): the "HLA-Ax02_01-<peptide>.pdb" name presumably mirrors how
# swiftmhc_predict names its output file — confirm against the tool.
pdb_file = os.path.join(jobname, f"HLA-Ax02_01-{peptide_sequence}.pdb")
if not os.path.isfile(pdb_file):
    raise FileNotFoundError(f"PDB file not found: {pdb_file}")


def show_pdb(
    pdb_path: str,
    show_sidechains: bool = False,
    show_carbon_alpha: bool = False,
    color: str = "chain",
):
    """Build a py3Dmol view of the structure stored in a PDB file.

    Args:
        pdb_path: Path of the PDB file to load.
        show_sidechains: Add a stick representation for chain "P". The
            selection is the whole chain, so all of its atoms are drawn as
            sticks.
        show_carbon_alpha: Add spheres on the "CA" atoms of chain "P".
        color: "chain" draws chain "M" in gold and chain "P" in blue,
            "rainbow" draws the cartoon with the "spectrum" coloring, and any
            other value falls back to the default cartoon style.

    Returns:
        The configured py3Dmol view; the caller is responsible for showing it.
    """
    viewer = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js')

    with open(pdb_path, "r") as handle:
        pdb_text = handle.read()
    viewer.addModel(pdb_text, "pdb")

    # Cartoon representation
    # NOTE(review): chain "P" is the peptide; chain "M" is presumably the MHC.
    if color == "chain":
        viewer.setStyle({"chain": "M"}, {"cartoon": {"color": "gold"}})
        viewer.setStyle({"chain": "P"}, {"cartoon": {"color": "blue"}})
    elif color == "rainbow":
        viewer.setStyle({"cartoon": {"color": "spectrum"}})
    else:
        viewer.setStyle({"cartoon": {}})

    # Optional overlays on the peptide; addStyle layers them on top of the
    # cartoon instead of replacing it.
    if show_carbon_alpha:
        ca_atoms_of_peptide = {"chain": "P", "atom": ["CA"]}
        viewer.addStyle(
            ca_atoms_of_peptide,
            {"sphere": {"colorscheme": "WhiteCarbon", "radius": 0.8}},
        )

    if show_sidechains:
        whole_peptide = {"chain": "P"}
        viewer.addStyle(
            whole_peptide,
            {"stick": {"colorscheme": "WhiteCarbon", "radius": 0.3}},
        )

    # No selection is passed, so the zoom is not restricted to the peptide.
    viewer.zoomTo()

    return viewer


view = show_pdb(
    pdb_file,
    show_sidechains=show_sidechains,
    show_carbon_alpha=show_carbon_alpha,
    color=color,
)

# Orient the complex with three successive rotations (angle in degrees, axis
# vector). The order is kept as listed because rotations about different axes
# do not commute.
for angle, axis in (
    (-120, {"x": 1, "y": 0, "z": 0}),  # around the x-axis
    (-200, {"x": 0, "y": 1, "z": 0}),  # around the y-axis
    (20, {"x": 0, "y": 0, "z": 1}),  # around the z-axis
):
    view.rotate(angle, axis)

view.show()

In [None]:
#@title Download the results

# Imported here so that this cell does not rely on the imports of another cell.
import os
import zipfile

# Pack every file of the job directory into <jobname>.zip. Archive member names
# are relative to the job directory, so the archive has no top-level folder.
zip_filename = f"{jobname}.zip"
with zipfile.ZipFile(zip_filename, "w", zipfile.ZIP_DEFLATED) as zipf:
    # The loop variables are deliberately not called `files`, so they cannot
    # clobber the google.colab `files` module imported below.
    for root, _, filenames in os.walk(jobname):
        for filename in filenames:
            full_path = os.path.join(root, filename)
            rel_path = os.path.relpath(full_path, jobname)
            zipf.write(full_path, rel_path)


from google.colab import files

# Announce the download before it is triggered.
print(f"Downloading {zip_filename}")
files.download(zip_filename)
