From 8e021862b52a3275e218fb71431603ebf72ab387 Mon Sep 17 00:00:00 2001 From: Zeming Lin Date: Mon, 1 Jun 2026 15:09:08 +0000 Subject: [PATCH 1/6] release binder design notebook --- README.md | 24 +- cookbook/tutorials/esmc_finetune.ipynb | 2 +- .../esmfold2_esmc_binder_design.ipynb | 1682 +++++++++++++++++ .../tutorials/esmfold2_esmc_binder_design.py | 1278 +++++++++++++ esm/models/esmfold2/processor.py | 2 +- esm/sdk/api.py | 2 +- 6 files changed, 2976 insertions(+), 14 deletions(-) create mode 100644 cookbook/tutorials/esmfold2_esmc_binder_design.ipynb create mode 100644 cookbook/tutorials/esmfold2_esmc_binder_design.py diff --git a/README.md b/README.md index 7c9d3a6..1809841 100644 --- a/README.md +++ b/README.md @@ -18,14 +18,14 @@ We are releasing a world model for protein biology: a scientific engine for pred -**[ESMFold2](https://huggingface.co/Biohub/ESMFold2)**, built on the ESMC 6B model, is a state-of-the-art structure prediction model that has been validated for the design of protein-protein interactions. ESMFold2 surpasses other models in DockQ pass-rate on Foldbench protein-protein and antibody-antigen complexes, and can be used in single-sequence mode for an order of magnitude speedup in folding. +**[ESMFold2](https://huggingface.co/biohub/ESMFold2)**, built on the ESMC 6B model, is a state-of-the-art structure prediction model that has been validated for the design of protein-protein interactions. ESMFold2 surpasses other models in DockQ pass-rate on Foldbench protein-protein and antibody-antigen complexes, and can be used in single-sequence mode for an order of magnitude speedup in folding.
-ESMFold2 is validated in the lab across five therapeutic targets. Inversion of ESMFold2 enables generation of de novo minibinders and antibody-derived scFvs with high hit rates, nanomolar affinities, target specificity, and functional activity. We're planning to release a notebook that walks through the full design loop from target sequence to ranked binder candidates. The full protocol is also described in the [preprint](https://biohub.ai/papers/esm_protein.pdf). +ESMFold2 is validated in the lab across five therapeutic targets. Inversion of ESMFold2 enables generation of de novo minibinders and antibody-derived scFvs with high hit rates, nanomolar affinities, target specificity, and functional activity. We've released the full protocol from target sequence to ranked binder design in this [tutorial](https://github.com/Biohub/esm/blob/main/cookbook/tutorials/esmfold2_esmc_binder_design.ipynb). For additional details, please refer to the [preprint](https://biohub.ai/papers/esm_protein.pdf).
@@ -77,10 +77,10 @@ login() sequences = ["MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTFSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHNIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDELYK"] model = AutoModelForMaskedLM.from_pretrained( - "Biohub/ESMC-6B", + "biohub/ESMC-6B", device_map="auto", ).eval() -tokenizer = AutoTokenizer.from_pretrained("Biohub/ESMC-6B") +tokenizer = AutoTokenizer.from_pretrained("biohub/ESMC-6B") inputs = tokenizer(sequences, return_tensors="pt", padding=True) inputs = {k: v.to(model.device) for k, v in inputs.items()} @@ -150,10 +150,10 @@ from transformers import AutoModel, AutoTokenizer sequence = "MGSNKSKPKDASQRRRSLEPAENVHGAGGGAFPASQTPSKPASADGHRGPSAAFAPAAAEPKLFGGFNSSDTVTSPQRAGPLAGGVTTFVALYDYESRTETDLSFKKGERLQIVNNTEGDWWLAHSLSTGQTGYIPSNYVAPSDSIQAEEWYFGKITRRESERLLLNAENPRGTFLVRESETTKGAYCLSVSDFDNAKGLNVKHYKIRKLDSGGFYITSRTQFNSLQQLVAYYSKHADGLCHRLTTVCPTSKPQTQGLAKDAWEIPRESLRLEVKLGQGCFGEVWMGTWNGTTRVAIKTLKPGTMSPEAFLQEAQVMKKLRHEKLVQLYAVVSEEPIYIVTEYMSKGSLLDFLKGETGKYLRLPQLVDMAAQIASGMAYVERMNYVHRDLRAANILVGENLVCKVADFGLARLIEDNEYTARQGAKFPIKWTAPEAALYGRFTIKSDVWSFGILLTELTTKGRVPYPGMVNREVLDQVERGYRMPCPPECPESLHDLMCQCWRKEPEERPTFEYLQAFLEDYFTSTEPQYQPGENL" -model = AutoModel.from_pretrained("Biohub/ESMC-6B", device_map="auto").eval() -tokenizer = AutoTokenizer.from_pretrained("Biohub/ESMC-6B") +model = AutoModel.from_pretrained("biohub/ESMC-6B", device_map="auto").eval() +tokenizer = AutoTokenizer.from_pretrained("biohub/ESMC-6B") sae = AutoModel.from_pretrained( - "Biohub/ESMC-6B-sae-k64-codebook16384", + "biohub/ESMC-6B-sae-k64-codebook16384", allow_patterns=["config.json", "layer_30.safetensors", "layer_60.safetensors"], device=model.device, ) @@ -176,11 +176,11 @@ For tutorials on how to use ESMC SAEs, see our [tutorials](https://github.com/Bi ## ESMFold2 -[ESMFold2](https://huggingface.co/Biohub/ESMFold2) is a state-of-the-art protein structure prediction model that 
combines ESMC (6B parameter) language model embeddings with a diffusion-based structure prediction architecture. +[ESMFold2](https://huggingface.co/biohub/ESMFold2) is a state-of-the-art protein structure prediction model that combines ESMC (6B parameter) language model embeddings with a diffusion-based structure prediction architecture. The model predicts high-resolution, all-atom 3D protein structures directly from amino acid sequences, with optional multiple sequence alignment (MSA) input for enhanced accuracy on challenging targets. ESMFold2 achieves state-of-the-art performance matching or exceeding AlphaFold3 across diverse evaluation datasets, while offering improved computational efficiency through optimized diffusion sampling and architectural innovations. -Codebase, model weights, and model variants for ESMFold2 are available through [Hugging Face](https://huggingface.co/Biohub/ESMFold2) +Codebase, model weights, and model variants for ESMFold2 are available through [Hugging Face](https://huggingface.co/biohub/ESMFold2) ### Running ESMFold2 Locally @@ -232,9 +232,11 @@ with open("1mht_pred.cif", "w") as f: f.write(result.complex.to_mmcif()) ``` +> **AMD ROCm users:** use ROCm 6.4 with PyTorch 2.9 or newer. + ### Running ESMFold2 Through the Biohub Platform -Install the `esm` Python package +Install the `esm` Python package ``` pip install esm@git+https://github.com/Biohub/esm.git@main @@ -283,7 +285,7 @@ Informed by our risk assessments, we are releasing the source code and model wei Evaluations: Prior to release, we conducted evaluations to inform our understanding of capability uplift for specific misuse-relevant functional tasks. The full details of these evaluations are available in our corresponding paper appendix. -The Biohub Platform: We implement guardrails that detect and restrict the use of keywords and sequences corresponding to controlled pathogens and toxins on our freely accessible platform. 
For further details regarding these guardrails, please refer to our Biohub Platform Resources page. We recognize there are many legitimate reasons to use AI models to understand and model these sequences and proteins. If you are a researcher whose work is impacted by these guardrails, you can request elevated access to our platform via [Biohub.ai](http://Biohub.ai). +The Biohub Platform: We implement guardrails that detect and restrict the use of keywords and sequences corresponding to controlled pathogens and toxins on our freely accessible platform. For further details regarding these guardrails, please refer to our Biohub Platform Resources page. We recognize there are many legitimate reasons to use AI models to understand and model these sequences and proteins. If you are a researcher whose work is impacted by these guardrails, you can request elevated access to our platform via [biohub.ai](https://biohub.ai). Please follow our [Acceptable Use Policy](https://biohub.org/acceptable-use-policy/) when using the model. 
diff --git a/cookbook/tutorials/esmc_finetune.ipynb b/cookbook/tutorials/esmc_finetune.ipynb index e6fc6f1..06d559a 100644 --- a/cookbook/tutorials/esmc_finetune.ipynb +++ b/cookbook/tutorials/esmc_finetune.ipynb @@ -307,7 +307,7 @@ "metadata": {}, "outputs": [], "source": [ - "MODEL_PATH = \"Biohub/ESMC-300M\"\n", + "MODEL_PATH = \"biohub/ESMC-300M\"\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)\n", "model = ESMCForSequenceClassification.from_pretrained(\n", diff --git a/cookbook/tutorials/esmfold2_esmc_binder_design.ipynb b/cookbook/tutorials/esmfold2_esmc_binder_design.ipynb new file mode 100644 index 0000000..c04234b --- /dev/null +++ b/cookbook/tutorials/esmfold2_esmc_binder_design.ipynb @@ -0,0 +1,1682 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "b5b44288", + "metadata": {}, + "source": [ + "## [Tutorial](https://github.com/biohub/esm/tree/main/cookbook/tutorials): How to run minibinder + scFv design fully end-to-end.\n", + "\n", + "In this tutorial we will use [Modal](https://modal.com/) to parallelize binder design and synthesize a selection,\n", + "using the protocol described in the ESMC and ESMFold2 paper titled [\"Language Modeling Materializes a World Model of Protein Biology\"](https://biohub.ai/papers/esm_protein.pdf).\n", + "\n", + "Biohub used this approach to design minibinders and scFvs against five therapeutically relevant targets — PDGFRB, EGFR, PD-L1, CD45, and CTLA4 — spanning receptor tyrosine kinases, immune checkpoints, and cell-surface phosphatases. Binders exhibit nanomolar affinity, target specificity, and functional activity in laboratory assays." + ] + }, + { + "cell_type": "markdown", + "id": "4421cdaa", + "metadata": {}, + "source": [ + "### One-time setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "99783792", + "metadata": {}, + "outputs": [], + "source": [ + "# Environment\n", + "! pip install esm@git+https://github.com/Biohub/esm.git@main\n", + "! 
pip install modal py3dmol pyarrow" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d3de7d00", + "metadata": {}, + "outputs": [], + "source": [ + "# Confirm you have a modal token, or make one\n", + "! modal token info # Check\n", + "# ! modal token new # Create" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2ee89767", + "metadata": {}, + "outputs": [], + "source": [ + "# Deploy (or redeploy after changing esmfold2_esmc_binder_design.py).\n", + "# This only needs to be run a single time, unless code in esmfold2_esmc_binder_design.py changes.\n", + "! modal deploy esmfold2_esmc_binder_design.py" + ] + }, + { + "cell_type": "markdown", + "id": "dc8456da", + "metadata": {}, + "source": [ + "### Imports" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "37c03b59", + "metadata": {}, + "outputs": [], + "source": [ + "from itertools import product\n", + "from pathlib import Path\n", + "\n", + "import modal\n", + "import pandas as pd\n", + "import py3Dmol\n", + "from Bio.SeqUtils.ProtParam import ProteinAnalysis\n", + "from tqdm.auto import tqdm" + ] + }, + { + "cell_type": "markdown", + "id": "1fe9a141", + "metadata": {}, + "source": [ + "### App setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "291d8bfe", + "metadata": {}, + "outputs": [], + "source": [ + "ESMFold2Design = modal.Cls.from_name(\"esmfold2-design\", \"ESMFold2DesignModal\")\n", + "# Set 'use_scaling_critics' to evaluate with the additional critics.\n", + "# Off by default. 
But cells below were populated with them enabled.\n", + "app = ESMFold2Design(use_scaling_critics=False)" + ] + }, + { + "cell_type": "markdown", + "id": "159b63df", + "metadata": {}, + "source": [ + "### Run one job - interactive" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "826c88d1", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'https://modal.com/id/fc-01KSTCT9W9PYKN3HEKEZ168VJP'" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# ---- Option 1: Use presets. ----\n", + "# Relies on the registry in esmfold2_esmc_binder_design.py::{TARGET_SEQUENCES,BINDER_PROMPT_FACTORIES}, which can be modified.\n", + "future = app.design.spawn(target_name=\"ctla4\", binder_name=\"minibinder\")\n", + "future.get_dashboard_url() # A clickable link to Modal dashboard" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "c1bda1a6", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'https://modal.com/id/fc-01KSTCT9YCT8HH50718ZBABJRT'" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# ---- Option 2: Provide your own sequences. ----\n", + "# Our pd-l1 sequence crop.\n", + "pdl1_sequence = \"AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKNIIQFVHGEEDLKVQHSSYRQRARLLKDQLSLGNAALQITDVKLQDAGVYRCMISYGGADYKRITVKVNA\"\n", + "# A sample of 'trastuzumab_framework_vhvl' template. 
From esmfold2_esmc_binder_design.py::BINDER_PROMPT_FACTORIES.\n", + "trastuzumab_framework_vhvl = \"EVQLVESGGGLVQPGGSLRLSCAAS#######YIHWVRQAPGKGLEWVARI#####TRYADSVKGRFTISADTSKNTAYLQMNSLRAEDTAVYYCSR###########WGQGTLVTVSSGGGSGGGSGGGSGGGSDIQMTQSPSSLSASVGDRVTITC###########WYQQKPGKAPKLLIY#######GVPSRFSGSRSGTDFTLTISSLQPEDFATYYC#########FGQGTKVEIK\"\n", + "future2 = app.design.spawn(\n", + " target_sequence=pdl1_sequence,\n", + " binder_sequence=trastuzumab_framework_vhvl,\n", + " is_antibody=True,\n", + ")\n", + "future2.get_dashboard_url() # A clickable link to Modal dashboard" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e442e64e", + "metadata": {}, + "outputs": [], + "source": [ + "# ---- Monitor ----\n", + "# Tail a function's output here in jupyter\n", + "! modal app logs esmfold2-design -f --function-call {future2.object_id}" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "79ea37f3", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Best sequences: ['AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKNIIQFVHGEEDLKVQHSSYRQRARLLKDQLSLGNAALQITDVKLQDAGVYRCMISYGGADYKRITVKVNA|EVQLVESGGGLVQPGGSLRLSCAASEPADEDDYIHWVRQAPGKGLEWVARITYEEKTRYADSVKGRFTISADTSKNTAYLQMNSLRAEDTAVYYCSRWTAMAIGNDVAWGQGTLVTVSSGGGSGGGSGGGSGGGSDIQMTQSPSSLSASVGDRVTITCRFSQDVTIRLSWYQQKPGKAPKLLIYFAFILANGVPSRFSGSRSGTDFTLTISSLQPEDFATYYCNYTRYSSSRFGQGTKVEIK']\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + "
is_antibodycritic_namebatch_idxdesigned_sequencefinal_lossiptmdistogram_iptm_proxycdr_distogram_iptm_proxy
0TrueESMFold2-Experimental-Fast0AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...4.5445020.9284950.8509760.873329
1TrueESMFold2-Experimental-Fast-Cutoff20250AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...4.5445020.9148860.8370670.856937
2TrueESMFold2-Experimental0AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...4.5445020.9145340.8241510.839054
3TrueESMFold2-Experimental-Cutoff20250AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...4.5445020.9276020.8350800.858924
4TrueESMFold2-Experimental-Fast-base300M-step250k0AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...4.544502NaN0.7605680.791366
5TrueESMFold2-Experimental-Fast-base300M-step500k0AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...4.544502NaN0.8171970.826138
6TrueESMFold2-Experimental-Fast-base300M-step750k0AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...4.544502NaN0.7536130.773483
7TrueESMFold2-Experimental-Fast-base300M-step1000k0AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...4.544502NaN0.8311060.860911
8TrueESMFold2-Experimental-Fast-base300M-step1500k0AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...4.544502NaN0.7933530.737717
9TrueESMFold2-Experimental-Fast-base600M-step250k0AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...4.544502NaN0.7943460.810242
10TrueESMFold2-Experimental-Fast-base600M-step500k0AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...4.544502NaN0.7863990.816203
11TrueESMFold2-Experimental-Fast-base600M-step750k0AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...4.544502NaN0.8102420.834086
12TrueESMFold2-Experimental-Fast-base600M-step1000k0AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...4.544502NaN0.7814310.810242
13TrueESMFold2-Experimental-Fast-base600M-step1500k0AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...4.544502NaN0.7953400.824151
14TrueESMFold2-Experimental-Fast-base6B-step250k0AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...4.544502NaN0.7913660.819184
15TrueESMFold2-Experimental-Fast-base6B-step500k0AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...4.544502NaN0.8003070.828125
16TrueESMFold2-Experimental-Fast-base6B-step750k0AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...4.544502NaN0.8122290.845015
17TrueESMFold2-Experimental-Fast-base6B-step1000k0AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...4.544502NaN0.8340860.859917
18TrueESMFold2-Experimental-Fast-base6B-step1500k0AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...4.544502NaN0.7774570.798320
\n", + "
" + ], + "text/plain": [ + " is_antibody critic_name batch_idx \\\n", + "0 True ESMFold2-Experimental-Fast 0 \n", + "1 True ESMFold2-Experimental-Fast-Cutoff2025 0 \n", + "2 True ESMFold2-Experimental 0 \n", + "3 True ESMFold2-Experimental-Cutoff2025 0 \n", + "4 True ESMFold2-Experimental-Fast-base300M-step250k 0 \n", + "5 True ESMFold2-Experimental-Fast-base300M-step500k 0 \n", + "6 True ESMFold2-Experimental-Fast-base300M-step750k 0 \n", + "7 True ESMFold2-Experimental-Fast-base300M-step1000k 0 \n", + "8 True ESMFold2-Experimental-Fast-base300M-step1500k 0 \n", + "9 True ESMFold2-Experimental-Fast-base600M-step250k 0 \n", + "10 True ESMFold2-Experimental-Fast-base600M-step500k 0 \n", + "11 True ESMFold2-Experimental-Fast-base600M-step750k 0 \n", + "12 True ESMFold2-Experimental-Fast-base600M-step1000k 0 \n", + "13 True ESMFold2-Experimental-Fast-base600M-step1500k 0 \n", + "14 True ESMFold2-Experimental-Fast-base6B-step250k 0 \n", + "15 True ESMFold2-Experimental-Fast-base6B-step500k 0 \n", + "16 True ESMFold2-Experimental-Fast-base6B-step750k 0 \n", + "17 True ESMFold2-Experimental-Fast-base6B-step1000k 0 \n", + "18 True ESMFold2-Experimental-Fast-base6B-step1500k 0 \n", + "\n", + " designed_sequence final_loss iptm \\\n", + "0 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 4.544502 0.928495 \n", + "1 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 4.544502 0.914886 \n", + "2 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 4.544502 0.914534 \n", + "3 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 4.544502 0.927602 \n", + "4 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 4.544502 NaN \n", + "5 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 4.544502 NaN \n", + "6 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 4.544502 NaN \n", + "7 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 4.544502 NaN \n", + "8 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 4.544502 NaN \n", + "9 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 
4.544502 NaN \n", + "10 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 4.544502 NaN \n", + "11 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 4.544502 NaN \n", + "12 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 4.544502 NaN \n", + "13 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 4.544502 NaN \n", + "14 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 4.544502 NaN \n", + "15 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 4.544502 NaN \n", + "16 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 4.544502 NaN \n", + "17 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 4.544502 NaN \n", + "18 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 4.544502 NaN \n", + "\n", + " distogram_iptm_proxy cdr_distogram_iptm_proxy \n", + "0 0.850976 0.873329 \n", + "1 0.837067 0.856937 \n", + "2 0.824151 0.839054 \n", + "3 0.835080 0.858924 \n", + "4 0.760568 0.791366 \n", + "5 0.817197 0.826138 \n", + "6 0.753613 0.773483 \n", + "7 0.831106 0.860911 \n", + "8 0.793353 0.737717 \n", + "9 0.794346 0.810242 \n", + "10 0.786399 0.816203 \n", + "11 0.810242 0.834086 \n", + "12 0.781431 0.810242 \n", + "13 0.795340 0.824151 \n", + "14 0.791366 0.819184 \n", + "15 0.800307 0.828125 \n", + "16 0.812229 0.845015 \n", + "17 0.834086 0.859917 \n", + "18 0.777457 0.798320 " + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# ---- Load result ----\n", + "best_sequences, trajectory, critic_results = future2.get()\n", + "print(\"Best sequences: \", best_sequences)\n", + "df = pd.DataFrame(critic_results)\n", + "df.drop(columns=[\"logits\", \"complex\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "d80597fa", + "metadata": {}, + "outputs": [ + { + "data": { + "application/3dmoljs_load.v0": "
\n

3Dmol.js failed to load for some reason. Please check your browser console for error messages.

\n
\n", + "text/html": [ + "
\n", + "

3Dmol.js failed to load for some reason. Please check your browser console for error messages.

\n", + "
\n", + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# ---- Visualize ----\n", + "protein_complex = (\n", + " df[df.critic_name.eq(\"ESMFold2-Experimental-Cutoff2025\")].iloc[0].complex\n", + ")\n", + "(\n", + " py3Dmol.view(width=600, height=600)\n", + " .addModel(protein_complex.to_pdb_string(), \"pdb\")\n", + " .setStyle({\"chain\": \"A\"}, {\"cartoon\": {\"color\": \"green\"}}) # pyright: ignore\n", + " .setStyle({\"chain\": \"B\"}, {\"cartoon\": {\"color\": \"cyan\"}}) # pyright: ignore\n", + " .addStyle( # pyright: ignore\n", + " {\"not\": {\"atom\": [\"N\", \"CA\", \"C\", \"O\"]}},\n", + " {\"stick\": {\"colorscheme\": \"default\", \"radius\": 0.2}},\n", + " )\n", + " .center() # pyright: ignore\n", + " .zoomTo() # pyright: ignore\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "fc105292", + "metadata": {}, + "source": [ + "### Run a sweep - async" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "ac02bbaa", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
target_nametarget_sequencebinder_namebinder_sequenceuse_scaling_criticsseedbatch_size
0pd-l1NoneminibinderNoneFalse01
1pd-l1NoneminibinderNoneFalse11
\n", + "
" + ], + "text/plain": [ + " target_name target_sequence binder_name binder_sequence \\\n", + "0 pd-l1 None minibinder None \n", + "1 pd-l1 None minibinder None \n", + "\n", + " use_scaling_critics seed batch_size \n", + "0 False 0 1 \n", + "1 False 1 1 " + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "(16, 7)" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# ---- Config ----\n", + "save_dir = Path(\"sweep\")\n", + "save_dir.mkdir(exist_ok=True)\n", + "\n", + "# Sweep settings - each key-value pair is an axis of a grid sweep.\n", + "line_sweeps = dict(\n", + " target_name=[\"pd-l1\"],\n", + " target_sequence=[None],\n", + " binder_name=[\"minibinder\", \"trastuzumab_framework_vhvl\"], # two modalities\n", + " binder_sequence=[None],\n", + " use_scaling_critics=[False],\n", + " seed=list(range(8)), # 8 seeds each\n", + " batch_size=[1],\n", + ")\n", + "df = pd.DataFrame(product(*line_sweeps.values()), columns=list(line_sweeps.keys()))\n", + "display(df.head(2))\n", + "df.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "9b768813", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Spawned 16 jobs. It is safe to close the notebook.The next cell will resume from call_id's, saved by Modal for up to 7 days.\n" + ] + } + ], + "source": [ + "# ---- Launch ----\n", + "df[\"call_id\"] = [\n", + " app.design.spawn(\n", + " target_name=row.target_name,\n", + " target_sequence=row.target_sequence,\n", + " binder_name=row.binder_name,\n", + " binder_sequence=row.binder_sequence,\n", + " seed=row.seed,\n", + " batch_size=row.batch_size,\n", + " ).object_id\n", + " for row in df.itertuples()\n", + "]\n", + "df.to_parquet(save_dir / \"manifest.parquet\", index=False)\n", + "print(\n", + " f\"Spawned {len(df)} jobs. 
It is safe to close the notebook.\"\n", + " \"The next cell will resume from call_id's, saved by Modal for up to 7 days.\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "b9a637f0", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "First task url: https://modal.com/id/fc-01KSTCTA27QYFGWNB67BPKZ72Z\n" + ] + }, + { + "data": { + "text/plain": [ + "status\n", + "SUCCESS 16\n", + "Name: count, dtype: int64" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# ---- Monitor ----\n", + "df = pd.read_parquet(save_dir / \"manifest.parquet\")\n", + "df[\"future\"] = df.call_id.transform(modal.FunctionCall.from_id)\n", + "df[\"status\"] = df.future.transform(lambda f: f.get_call_graph()[0].status.name)\n", + "print(\"First task url: \", df.at[0, \"future\"].get_dashboard_url()) # pyright: ignore\n", + "df.status.value_counts()" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "973f7d6f", + "metadata": {}, + "outputs": [], + "source": [ + "# ---- Collect ----\n", + "df[\"result\"] = modal.FunctionCall.gather(\n", + " *df.future.tolist()\n", + ") # Blocks until all jobs are complete.\n", + "df[\"result_df\"] = [pd.DataFrame(r[2]) for r in df.result] # pyright: ignore" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "39fd979c", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c85b04099dab48209fa0884367f98e21", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/304 [00:00= 6.\n", + "df_result[\"binder_sequence\"] = df_result.designed_sequence.str.split(r\"\\|\").str[1]\n", + "df_result[\"isoelectric_point\"] = [\n", + " ProteinAnalysis(seq).isoelectric_point()\n", + " for seq in tqdm(df_result.binder_sequence.values)\n", + "]\n", + "# Isoelectric point filter\n", + "df_filter = df_result[df_result.is_antibody | 
df_result.isoelectric_point.lt(6)]\n", + "\n", + "\n", + "# Select the top 84 designs from each (target, binder) combination\n", + "def select(df: pd.DataFrame) -> pd.DataFrame:\n", + " # Where the cdr-specific iptm proxy exists, use it (antibodies).\n", + " # Else use the full distogram iptm proxy.\n", + " # If neither exists (use_scaling_critics=False), then there is no contribution from this term.\n", + " df[\"iptm_proxy\"] = df.cdr_distogram_iptm_proxy.combine_first(\n", + " df.distogram_iptm_proxy\n", + " ).fillna(0)\n", + " df = df.groupby(\"designed_sequence\", as_index=False).agg(\n", + " dict(iptm=\"mean\", iptm_proxy=\"mean\")\n", + " )\n", + " df[\"selection_score\"] = 0.5 * df.iptm + 0.5 * df.iptm_proxy\n", + " return df.nlargest(min(len(df), 84), \"selection_score\")\n", + "\n", + "\n", + "df_select = df_filter.groupby([\"target_name\", \"binder_name\"]).apply(\n", + " select, include_groups=False\n", + ")\n", + "df_select.to_parquet(save_dir / \"selection.parquet\", index=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "b54a71de", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
is_antibodycritic_namebatch_idxdesigned_sequencefinal_lossiptmdistogram_iptm_proxycdr_distogram_iptm_proxytarget_nametarget_sequencebinder_namebinder_sequenceuse_scaling_criticsseedbatch_sizecall_idfuturestatusisoelectric_point
3FalseESMFold2-Experimental-Cutoff20250AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...2.8870060.9494710.902141NaNpd-l1NoneminibinderQSSDDEIDKEVNKVAAEIALAVAELTRAAADGDDKEVDKQLKKALK...False01fc-01KSTCTA27QYFGWNB67BPKZ72ZFunctionCall.from_id('fc-01KSTCTA27QYFGWNB67BP...SUCCESS9.521739
22FalseESMFold2-Experimental-Cutoff20250AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...4.9459100.3925570.610550NaNpd-l1NoneminibinderKWEIWRLLWKIGNNLWNNNNNNNNWNAIWTIWWWLIWWLIWWLLIN...False11fc-01KSTCTA52ZQZXC2RZ4F12ZJNCFunctionCall.from_id('fc-01KSTCTA52ZQZXC2RZ4F1...SUCCESS10.605259
41FalseESMFold2-Experimental-Cutoff20250AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...4.6284140.9204260.857930NaNpd-l1NoneminibinderSIIRILIIIVIKAIKKVSKIAKILKKALKELAKSGASKEIVEILIE...False21fc-01KSTCTA7ZBCYQZCJ1DHGT94BZFunctionCall.from_id('fc-01KSTCTA7ZBCYQZCJ1DHG...SUCCESS10.170871
60FalseESMFold2-Experimental-Cutoff20250AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...4.7705130.9031920.831106NaNpd-l1NoneminibinderMSLEELLKEIVEALKSGDFKKAAKAIKEAAKIIFSENIEVASAKIL...False31fc-01KSTCTAASYNGFY59G85TYT5SZFunctionCall.from_id('fc-01KSTCTAASYNGFY59G85T...SUCCESS7.856585
79FalseESMFold2-Experimental-Cutoff20250AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...3.7112600.9143340.849982NaNpd-l1NoneminibinderQNSNNNNNNNNEEDEEIDIKILKILIKLLIIIILLKKSPSSSSKKK...False41fc-01KSTCTAEAAQN1QKMYCDD92HKNFunctionCall.from_id('fc-01KSTCTAEAAQN1QKMYCDD...SUCCESS9.874187
98FalseESMFold2-Experimental-Cutoff20250AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...3.9598030.9340150.855943NaNpd-l1NoneminibinderSLILNILNIRINEINNLITNASKNELILYLKNLNIILKILLILLQN...False51fc-01KSTCTAHDBY4M10ESJP0HVS5ZFunctionCall.from_id('fc-01KSTCTAHDBY4M10ESJP0...SUCCESS5.117010
117FalseESMFold2-Experimental-Cutoff20250AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...4.2762570.9363310.892703NaNpd-l1NoneminibinderLLELLKILVKNAKNFSSSELYIVIMLLEILSNEDPREALILVEEII...False61fc-01KSTCTAM1V79QCSF9612BD1RVFunctionCall.from_id('fc-01KSTCTAM1V79QCSF9612...SUCCESS4.560045
136FalseESMFold2-Experimental-Cutoff20250AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...3.6636910.8467630.783418NaNpd-l1NoneminibinderQQLQLLIIQLILLIIVKILLQIANILLQEAKLSDSDDSEKIIKTLK...False71fc-01KSTCTAP3G7YRJ1RSQ2NFSTB5FunctionCall.from_id('fc-01KSTCTAP3G7YRJ1RSQ2N...SUCCESS9.399378
155TrueESMFold2-Experimental-Cutoff20250AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...3.8230500.9289520.8390540.851969pd-l1Nonetrastuzumab_framework_vhvlEVQLVESGGGLVQPGGSLRLSCAASSDRSYSVSYIHWVRQAPGKGL...False01fc-01KSTCTAR7DKSNJXBHARD4FFZ2FunctionCall.from_id('fc-01KSTCTAR7DKSNJXBHARD...SUCCESS6.984682
174TrueESMFold2-Experimental-Cutoff20250AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...4.1171040.9088980.7814310.797327pd-l1Nonetrastuzumab_framework_vhvlEVQLVESGGGLVQPGGSLRLSCAASEPLSYRIYIHWVRQAPGKGLE...False11fc-01KSTCTATESMP6KMYPFTQ1B2PFFunctionCall.from_id('fc-01KSTCTATESMP6KMYPFTQ...SUCCESS8.632139
193TrueESMFold2-Experimental-Cutoff20250AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...4.4055540.9339340.8271320.846008pd-l1Nonetrastuzumab_framework_vhvlEVQLVESGGGLVQPGGSLRLSCAASKGMEADDYIHWVRQAPGKGLE...False21fc-01KSTCTAXFAQJ8BCKSZ67NRH75FunctionCall.from_id('fc-01KSTCTAXFAQJ8BCKSZ67...SUCCESS6.863217
212TrueESMFold2-Experimental-Cutoff20250AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...3.7723780.9225570.8251450.848989pd-l1Nonetrastuzumab_framework_vhvlEVQLVESGGGLVQPGGSLRLSCAASGFAISDDYIHWVRQAPGKGLE...False31fc-01KSTCTAZTMDPNCHDTGCQJV0NNFunctionCall.from_id('fc-01KSTCTAZTMDPNCHDTGCQ...SUCCESS7.069998
231TrueESMFold2-Experimental-Cutoff20250AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...4.4308480.8985120.7595740.781431pd-l1Nonetrastuzumab_framework_vhvlEVQLVESGGGLVQPGGSLRLSCAASGDDDNLGYIHWVRQAPGKGLE...False41fc-01KSTCTB1QKCF0JA1FZQGV4RF1FunctionCall.from_id('fc-01KSTCTB1QKCF0JA1FZQG...SUCCESS8.622018
250TrueESMFold2-Experimental-Cutoff20250AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...4.1636610.9264100.8211710.834086pd-l1Nonetrastuzumab_framework_vhvlEVQLVESGGGLVQPGGSLRLSCAASEAFDTRVVLYIHWVRQAPGKG...False51fc-01KSTCTB4SZT3HVA7S6JR57Q60FunctionCall.from_id('fc-01KSTCTB4SZT3HVA7S6JR...SUCCESS6.982352
269TrueESMFold2-Experimental-Cutoff20250AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...4.0762990.7140020.6324070.648303pd-l1Nonetrastuzumab_framework_vhvlEVQLVESGGGLVQPGGSLRLSCAASSSLDDDDFAYIHWVRQAPGKG...False61fc-01KSTCTB6N74561E0PVPNHH68HFunctionCall.from_id('fc-01KSTCTB6N74561E0PVPN...SUCCESS6.856681
288TrueESMFold2-Experimental-Cutoff20250AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...4.1314370.9176700.8022940.817197pd-l1Nonetrastuzumab_framework_vhvlEVQLVESGGGLVQPGGSLRLSCAASPDLNFVLNYIHWVRQAPGKGL...False71fc-01KSTCTB9YC1GY5QAQVSZK67GAFunctionCall.from_id('fc-01KSTCTB9YC1GY5QAQVSZ...SUCCESS4.925178
\n", + "
" + ], + "text/plain": [ + " is_antibody critic_name batch_idx \\\n", + "3 False ESMFold2-Experimental-Cutoff2025 0 \n", + "22 False ESMFold2-Experimental-Cutoff2025 0 \n", + "41 False ESMFold2-Experimental-Cutoff2025 0 \n", + "60 False ESMFold2-Experimental-Cutoff2025 0 \n", + "79 False ESMFold2-Experimental-Cutoff2025 0 \n", + "98 False ESMFold2-Experimental-Cutoff2025 0 \n", + "117 False ESMFold2-Experimental-Cutoff2025 0 \n", + "136 False ESMFold2-Experimental-Cutoff2025 0 \n", + "155 True ESMFold2-Experimental-Cutoff2025 0 \n", + "174 True ESMFold2-Experimental-Cutoff2025 0 \n", + "193 True ESMFold2-Experimental-Cutoff2025 0 \n", + "212 True ESMFold2-Experimental-Cutoff2025 0 \n", + "231 True ESMFold2-Experimental-Cutoff2025 0 \n", + "250 True ESMFold2-Experimental-Cutoff2025 0 \n", + "269 True ESMFold2-Experimental-Cutoff2025 0 \n", + "288 True ESMFold2-Experimental-Cutoff2025 0 \n", + "\n", + " designed_sequence final_loss iptm \\\n", + "3 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 2.887006 0.949471 \n", + "22 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 4.945910 0.392557 \n", + "41 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 4.628414 0.920426 \n", + "60 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 4.770513 0.903192 \n", + "79 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 3.711260 0.914334 \n", + "98 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 3.959803 0.934015 \n", + "117 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 4.276257 0.936331 \n", + "136 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 3.663691 0.846763 \n", + "155 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 3.823050 0.928952 \n", + "174 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 4.117104 0.908898 \n", + "193 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 4.405554 0.933934 \n", + "212 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 3.772378 0.922557 \n", + "231 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 
4.430848 0.898512 \n", + "250 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 4.163661 0.926410 \n", + "269 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 4.076299 0.714002 \n", + "288 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 4.131437 0.917670 \n", + "\n", + " distogram_iptm_proxy cdr_distogram_iptm_proxy target_name \\\n", + "3 0.902141 NaN pd-l1 \n", + "22 0.610550 NaN pd-l1 \n", + "41 0.857930 NaN pd-l1 \n", + "60 0.831106 NaN pd-l1 \n", + "79 0.849982 NaN pd-l1 \n", + "98 0.855943 NaN pd-l1 \n", + "117 0.892703 NaN pd-l1 \n", + "136 0.783418 NaN pd-l1 \n", + "155 0.839054 0.851969 pd-l1 \n", + "174 0.781431 0.797327 pd-l1 \n", + "193 0.827132 0.846008 pd-l1 \n", + "212 0.825145 0.848989 pd-l1 \n", + "231 0.759574 0.781431 pd-l1 \n", + "250 0.821171 0.834086 pd-l1 \n", + "269 0.632407 0.648303 pd-l1 \n", + "288 0.802294 0.817197 pd-l1 \n", + "\n", + " target_sequence binder_name \\\n", + "3 None minibinder \n", + "22 None minibinder \n", + "41 None minibinder \n", + "60 None minibinder \n", + "79 None minibinder \n", + "98 None minibinder \n", + "117 None minibinder \n", + "136 None minibinder \n", + "155 None trastuzumab_framework_vhvl \n", + "174 None trastuzumab_framework_vhvl \n", + "193 None trastuzumab_framework_vhvl \n", + "212 None trastuzumab_framework_vhvl \n", + "231 None trastuzumab_framework_vhvl \n", + "250 None trastuzumab_framework_vhvl \n", + "269 None trastuzumab_framework_vhvl \n", + "288 None trastuzumab_framework_vhvl \n", + "\n", + " binder_sequence use_scaling_critics \\\n", + "3 QSSDDEIDKEVNKVAAEIALAVAELTRAAADGDDKEVDKQLKKALK... False \n", + "22 KWEIWRLLWKIGNNLWNNNNNNNNWNAIWTIWWWLIWWLIWWLLIN... False \n", + "41 SIIRILIIIVIKAIKKVSKIAKILKKALKELAKSGASKEIVEILIE... False \n", + "60 MSLEELLKEIVEALKSGDFKKAAKAIKEAAKIIFSENIEVASAKIL... False \n", + "79 QNSNNNNNNNNEEDEEIDIKILKILIKLLIIIILLKKSPSSSSKKK... False \n", + "98 SLILNILNIRINEINNLITNASKNELILYLKNLNIILKILLILLQN... False \n", + "117 LLELLKILVKNAKNFSSSELYIVIMLLEILSNEDPREALILVEEII... 
False \n", + "136 QQLQLLIIQLILLIIVKILLQIANILLQEAKLSDSDDSEKIIKTLK... False \n", + "155 EVQLVESGGGLVQPGGSLRLSCAASSDRSYSVSYIHWVRQAPGKGL... False \n", + "174 EVQLVESGGGLVQPGGSLRLSCAASEPLSYRIYIHWVRQAPGKGLE... False \n", + "193 EVQLVESGGGLVQPGGSLRLSCAASKGMEADDYIHWVRQAPGKGLE... False \n", + "212 EVQLVESGGGLVQPGGSLRLSCAASGFAISDDYIHWVRQAPGKGLE... False \n", + "231 EVQLVESGGGLVQPGGSLRLSCAASGDDDNLGYIHWVRQAPGKGLE... False \n", + "250 EVQLVESGGGLVQPGGSLRLSCAASEAFDTRVVLYIHWVRQAPGKG... False \n", + "269 EVQLVESGGGLVQPGGSLRLSCAASSSLDDDDFAYIHWVRQAPGKG... False \n", + "288 EVQLVESGGGLVQPGGSLRLSCAASPDLNFVLNYIHWVRQAPGKGL... False \n", + "\n", + " seed batch_size call_id \\\n", + "3 0 1 fc-01KSTCTA27QYFGWNB67BPKZ72Z \n", + "22 1 1 fc-01KSTCTA52ZQZXC2RZ4F12ZJNC \n", + "41 2 1 fc-01KSTCTA7ZBCYQZCJ1DHGT94BZ \n", + "60 3 1 fc-01KSTCTAASYNGFY59G85TYT5SZ \n", + "79 4 1 fc-01KSTCTAEAAQN1QKMYCDD92HKN \n", + "98 5 1 fc-01KSTCTAHDBY4M10ESJP0HVS5Z \n", + "117 6 1 fc-01KSTCTAM1V79QCSF9612BD1RV \n", + "136 7 1 fc-01KSTCTAP3G7YRJ1RSQ2NFSTB5 \n", + "155 0 1 fc-01KSTCTAR7DKSNJXBHARD4FFZ2 \n", + "174 1 1 fc-01KSTCTATESMP6KMYPFTQ1B2PF \n", + "193 2 1 fc-01KSTCTAXFAQJ8BCKSZ67NRH75 \n", + "212 3 1 fc-01KSTCTAZTMDPNCHDTGCQJV0NN \n", + "231 4 1 fc-01KSTCTB1QKCF0JA1FZQGV4RF1 \n", + "250 5 1 fc-01KSTCTB4SZT3HVA7S6JR57Q60 \n", + "269 6 1 fc-01KSTCTB6N74561E0PVPNHH68H \n", + "288 7 1 fc-01KSTCTB9YC1GY5QAQVSZK67GA \n", + "\n", + " future status \\\n", + "3 FunctionCall.from_id('fc-01KSTCTA27QYFGWNB67BP... SUCCESS \n", + "22 FunctionCall.from_id('fc-01KSTCTA52ZQZXC2RZ4F1... SUCCESS \n", + "41 FunctionCall.from_id('fc-01KSTCTA7ZBCYQZCJ1DHG... SUCCESS \n", + "60 FunctionCall.from_id('fc-01KSTCTAASYNGFY59G85T... SUCCESS \n", + "79 FunctionCall.from_id('fc-01KSTCTAEAAQN1QKMYCDD... SUCCESS \n", + "98 FunctionCall.from_id('fc-01KSTCTAHDBY4M10ESJP0... SUCCESS \n", + "117 FunctionCall.from_id('fc-01KSTCTAM1V79QCSF9612... SUCCESS \n", + "136 FunctionCall.from_id('fc-01KSTCTAP3G7YRJ1RSQ2N... 
SUCCESS \n", + "155 FunctionCall.from_id('fc-01KSTCTAR7DKSNJXBHARD... SUCCESS \n", + "174 FunctionCall.from_id('fc-01KSTCTATESMP6KMYPFTQ... SUCCESS \n", + "193 FunctionCall.from_id('fc-01KSTCTAXFAQJ8BCKSZ67... SUCCESS \n", + "212 FunctionCall.from_id('fc-01KSTCTAZTMDPNCHDTGCQ... SUCCESS \n", + "231 FunctionCall.from_id('fc-01KSTCTB1QKCF0JA1FZQG... SUCCESS \n", + "250 FunctionCall.from_id('fc-01KSTCTB4SZT3HVA7S6JR... SUCCESS \n", + "269 FunctionCall.from_id('fc-01KSTCTB6N74561E0PVPN... SUCCESS \n", + "288 FunctionCall.from_id('fc-01KSTCTB9YC1GY5QAQVSZ... SUCCESS \n", + "\n", + " isoelectric_point \n", + "3 9.521739 \n", + "22 10.605259 \n", + "41 10.170871 \n", + "60 7.856585 \n", + "79 9.874187 \n", + "98 5.117010 \n", + "117 4.560045 \n", + "136 9.399378 \n", + "155 6.984682 \n", + "174 8.632139 \n", + "193 6.863217 \n", + "212 7.069998 \n", + "231 8.622018 \n", + "250 6.982352 \n", + "269 6.856681 \n", + "288 4.925178 " + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_result[df_result.critic_name.eq(\"ESMFold2-Experimental-Cutoff2025\")].drop(\n", + " columns=[\"complex\", \"logits\"]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "851365c0", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
designed_sequenceiptmiptm_proxyselection_score
target_namebinder_name
pd-l1minibinder0AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...0.9301000.8566490.893375
1AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...0.9373930.8376940.887544
trastuzumab_framework_vhvl4AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...0.9252670.8100590.867663
3AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...0.9199290.7883860.854157
6AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...0.9172050.7809600.849083
5AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...0.9064160.7534040.829910
1AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...0.9044100.7192590.811835
0AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...0.8778600.7230240.800442
2AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...0.8684860.7109970.789742
7AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...0.7633470.6446420.703995
\n", + "
" + ], + "text/plain": [ + " designed_sequence \\\n", + "target_name binder_name \n", + "pd-l1 minibinder 0 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... \n", + " 1 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... \n", + " trastuzumab_framework_vhvl 4 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... \n", + " 3 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... \n", + " 6 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... \n", + " 5 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... \n", + " 1 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... \n", + " 0 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... \n", + " 2 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... \n", + " 7 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... \n", + "\n", + " iptm iptm_proxy \\\n", + "target_name binder_name \n", + "pd-l1 minibinder 0 0.930100 0.856649 \n", + " 1 0.937393 0.837694 \n", + " trastuzumab_framework_vhvl 4 0.925267 0.810059 \n", + " 3 0.919929 0.788386 \n", + " 6 0.917205 0.780960 \n", + " 5 0.906416 0.753404 \n", + " 1 0.904410 0.719259 \n", + " 0 0.877860 0.723024 \n", + " 2 0.868486 0.710997 \n", + " 7 0.763347 0.644642 \n", + "\n", + " selection_score \n", + "target_name binder_name \n", + "pd-l1 minibinder 0 0.893375 \n", + " 1 0.887544 \n", + " trastuzumab_framework_vhvl 4 0.867663 \n", + " 3 0.854157 \n", + " 6 0.849083 \n", + " 5 0.829910 \n", + " 1 0.811835 \n", + " 0 0.800442 \n", + " 2 0.789742 \n", + " 7 0.703995 " + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_select" + ] + }, + { + "cell_type": "markdown", + "id": "7e174c1e", + "metadata": {}, + "source": [ + "## Appendix" + ] + }, + { + "cell_type": "markdown", + "id": "3f79528d", + "metadata": {}, + "source": [ + "### Modal Primer" + ] + }, + { + "cell_type": "markdown", + "id": "4297bc4d", + "metadata": {}, + "source": [ + "- **info: ephemeral vs deployment** \n", + " Ephemeral = temporary app from `modal run` or `app.run()`, stopped when 
the client exits. Deployment = persistent named app from `modal deploy`, reused and observable across runs. ([modal.com](https://modal.com/docs/guide/apps?utm_source=openai))\n", + "\n", + "- **info: dashboard** \n", + " Generic dashboard/apps page: [https://modal.com/apps](https://modal.com/apps). Modal also prints app/deployment links during runs/deploys. ([modal.com](https://modal.com/docs/guide/apps?utm_source=openai))\n", + "\n", + "- **cli: ephemeral run** \n", + " ```bash\n", + " modal run path/to/app.py\n", + " ```\n", + "\n", + "- **cli: deploy/redeploy** \n", + " ```bash\n", + " modal deploy path/to/app.py\n", + " ```\n", + " Running this on an existing app name redeploys a new version. ([modal.com](https://modal.com/docs/reference/cli/deploy?utm_source=openai))\n", + "\n", + "- **local: ephemeral from Python** \n", + " ```python\n", + " with modal.enable_output():\n", + " with modal_app.run():\n", + " result = local_modal_obj.remote(...)\n", + " ```\n", + "\n", + "- **local: call a deployment** \n", + " ```python\n", + " Cls = modal.Cls.from_name(\"app-name\", \"ClassName\")\n", + " obj = Cls(...)\n", + " result = obj.method.remote(...)\n", + " ```\n", + " `Cls.from_name` references a class from a deployed app lazily. 
([modal.com](https://modal.com/docs/reference/modal.Cls?utm_source=openai))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "modal-test", + "language": "python", + "name": "modal-test" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/cookbook/tutorials/esmfold2_esmc_binder_design.py b/cookbook/tutorials/esmfold2_esmc_binder_design.py new file mode 100644 index 0000000..b92f144 --- /dev/null +++ b/cookbook/tutorials/esmfold2_esmc_binder_design.py @@ -0,0 +1,1278 @@ +# /// script +# requires-python = "<=3.13" +# dependencies = [ +# "abnumber", +# "esm@git+https://github.com/Biohub/esm.git@main", +# "modal", +# ] +# /// +""" +Code for binder design with ESMFold2 and ESMC. + +As described in [Language Modeling Materializes a World Model of Protein Biology](https://biohub.ai/papers/esm_protein.pdf). 
+""" + +import logging +import math +import os +import random +import string +from dataclasses import dataclass +from functools import cache, partial +from typing import Any + +import biotite.structure +import modal +import numpy as np +import torch +import torch.nn.functional as F +import torch.optim as optim +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + CheckpointImpl, + apply_activation_checkpointing, + checkpoint_wrapper, +) +from transformers.models.esmc.modeling_esmc import ESMCForMaskedLM +from transformers.models.esmc.modeling_esmc import ( + UnifiedTransformerBlock as TransformerBlock, +) +from transformers.models.esmc.tokenization_esmc import ESMCTokenizer +from transformers.models.esmfold2.modeling_esmfold2_common import ( + CUE_AVAILABLE, + PairUpdateBlock, +) +from transformers.models.esmfold2.modeling_esmfold2_common import ( + _seed_context as seed_context, +) +from transformers.models.esmfold2.modeling_esmfold2_experimental import ( + ESMFold2ExperimentalModel, +) +from transformers.models.esmfold2.modeling_esmfold2_experimental import ( + MSAEncoder as ESMFold2MSAEncoder, +) + +from esm.models.esmfold2 import ( + ELEMENT_NUMBER_TO_SYMBOL, + ProteinInput, + StructurePredictionInput, + load_ccd, + prepare_esmfold2_input, +) +from esm.models.esmfold2.constants import ( + MOL_TYPE_NONPOLYMER, + PROTEIN_1TO3, + PROTEIN_3TO1, + RES_TYPE_TO_CCD, +) +from esm.utils.structure.protein_chain import ProteinChain +from esm.utils.structure.protein_complex import ProteinComplex + +os.environ["HF_XET_HIGH_PERFORMANCE"] = "1" +logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s: %(message)s") +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +# ---- Constants ---- + + +# General +TOKENS = ["", "-"] + [RES_TYPE_TO_CCD[i] for i in range(2, 33)] +ELEMENTS = ["X"] * (max(ELEMENT_NUMBER_TO_SYMBOL) + 1) +ELEMENTS[0] = "" +for _atomic_num, _symbol in ELEMENT_NUMBER_TO_SYMBOL.items(): + 
ELEMENTS[_atomic_num] = _symbol[:1] + _symbol[1:].lower() +TOKEN_IDS = {token: idx for idx, token in enumerate(TOKENS)} +AA_DIMS = 20 +# Cysteine index in the 20-dim AA space (TOKEN_IDS are offset by 2 for and -) +CYS_IDX = TOKEN_IDS[PROTEIN_1TO3["C"]] - 2 +MUTABLE_TOKEN = "#" +# Contains AA chars at fixed positions and MUTABLE_TOKEN at mutable positions. +BinderPromptStr = str + +# Design +LOSS_WEIGHTS = {"intra_contact": 0.5, "inter_contact": 0.5, "glob": 0.2} +STEPS = 150 +LOG_INTERVAL = 5 +LEARNING_RATE = 0.1 +TEMPERATURE_MIN = 1e-2 +ESMC_MASK_FRACTION = 0.15 +CHECKPOINT_LM = False +COMPILE = False +# NOTE - This significantly reduces VRAM usage. +# On config (target_name="cd45", binder_name="trastuzumab_framework_vhvl", batch_size=1) +# this reduces VRAM from 51GB -> 27GB. And enables increasing batch size up to 6. +# We are testing this setting in silico, and may change the default to True, in the future. +REUSE_ESMC = False + + +# ---- Prompts ---- + + +@dataclass(frozen=True) +class PromptFactory: + """A simple factory for making binder prompt strings.""" + + name: str + template: str # string with format fields + length_ranges: dict[str, tuple[int, int]] # map from field name to length range + is_antibody: bool # Used to set LM loss weight for antibodies. 
+ + def sample(self, seed: int) -> BinderPromptStr: + random.seed(seed) + return self.template.format( + **{ + key: MUTABLE_TOKEN * random.randint(low, high) + for key, (low, high) in self.length_ranges.items() + } + ) + + +# fmt: off +BINDER_PROMPT_FACTORIES = { + "minibinder": PromptFactory(name="minibinder", template="{seq}", length_ranges={"seq": (60, 200)}, is_antibody=False), + "trastuzumab_framework_vhvl": PromptFactory( + name="trastuzumab_framework_vhvl", + template="EVQLVESGGGLVQPGGSLRLSCAAS{hcdr1}YIHWVRQAPGKGLEWVARI{hcdr2}TRYADSVKGRFTISADTSKNTAYLQMNSLRAEDTAVYYCSR{hcdr3}WGQGTLVTVSSGGGSGGGSGGGSGGGSDIQMTQSPSSLSASVGDRVTITC{lcdr1}WYQQKPGKAPKLLIY{lcdr2}GVPSRFSGSRSGTDFTLTISSLQPEDFATYYC{lcdr3}FGQGTKVEIK", + length_ranges = {"hcdr1": (7, 9), "hcdr2": (5, 6), "hcdr3": (9, 15), "lcdr1": (11, 16), "lcdr2": (7, 7), "lcdr3": (9, 9)}, + is_antibody=True, + ), + "atezolizumab_framework_vhvl": PromptFactory( + name="atezolizumab_framework_vhvl", + template="EVQLVESGGGLVQPGGSLRLSCAAS{hcdr1}WIHWVRQAPGKGLEWVAWI{hcdr2}TYYADSVKGRFTISADTSKNTAYLQMNSLRAEDTAVYYCAR{hcdr3}WGQGTLVTVSSGGGSGGGSGGGSGGGSDIQMTQSPSSLSASVGDRVTITC{lcdr1}WYQQKPGKAPKLLIY{lcdr2}GVPSRFSGSGSGTDFTLTISSLQPEDFATYYC{lcdr3}FGQGTKVEIK", + length_ranges = {"hcdr1": (7, 9), "hcdr2": (5, 6), "hcdr3": (9, 15), "lcdr1": (11, 16), "lcdr2": (7, 7), "lcdr3": (9, 9)}, + is_antibody=True, + ), + "ocankitug_framework_vhvl": PromptFactory( + name="ocankitug_framework_vhvl", + template="QVQLVQSGAEVKKPGSSVKVSCKAS{hcdr1}WMHWVRQAPGQGLEWMGII{hcdr2}TSLNQKFQGRVTITADTSTSTAYMELSSLRSEDTAVYYCAR{hcdr3}WGQGTLVTVSSGGGSGGGSGGGSGGGSDIQMTQSPSSLSASVGDRVTITC{lcdr1}WYQQKPGKAPKLLIY{lcdr2}GVPSRFSGSGSGTDFTLTISSLQPEDFATYYC{lcdr3}FGQGTKVEIK", + length_ranges = {"hcdr1": (7, 9), "hcdr2": (5, 6), "hcdr3": (8, 14), "lcdr1": (11, 16), "lcdr2": (7, 7), "lcdr3": (9, 9)}, + is_antibody=True, + ) +} + + +TARGET_SEQUENCES = { + # https://www.uniprot.org/uniprotkb/P08575 389-574 + "cd45": 
"GSPGEPQIIFCRSEAAHQGVITWNPPQRSFHNFTLCYIKETEKDCLNLDKNLIKYDLQNLKPYTKYVLSLHAYIIAKVQRNGSAAMCHFTTKSAPPSQVWNMTVSMTSDNSMHVKCRPPRDRNGPHERYHLEVEAGNTLVRNESHKNCDFRVKDLQYSTDYTFKAYFHNGDYPGEPFILHHSTSY", + # https://www.uniprot.org/uniprotkb/P16410 37-155 + "ctla4": "MHVAQPAVVLASSRGIASFVCEYASPGKATEVRVTVLRQADSQVTEVCAATYMMGNELTFLDDSICTGTSSGNQVNLTIQGLRAMDTGLYICKVELMYPPPYYLGIGNGTQIYVIDPE", + # https://www.uniprot.org/uniprotkb/P00533 333-524 + "egfr": "RKVCNGIGIGEFKDSLSINATNIKHFKNCTSISGDLHILPVAFRGDSFTHTPPLDPQELDILKTVKEITGFLLIQAWPENRTDLHAFENLEIIRGRTKQHGQFSLAVVSLNITSLGLRSLKEISDGDVIISGNKNLCYANTINWKKLFGTSGQKTKIISNRGENSCKATGQVCHALCSPEGCWGPEPRDCV", + # https://www.uniprot.org/uniprotkb/Q9NZQ7 17-132 + "pd-l1": "AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKNIIQFVHGEEDLKVQHSSYRQRARLLKDQLSLGNAALQITDVKLQDAGVYRCMISYGGADYKRITVKVNA", + # https://www.uniprot.org/uniprotkb/P09619 125-312 + "pdgfr": "GFLPNDAEELFIFLTEITEITIPCRVTDPQLVVTLHEKKGDVALPVPYDHQRGFSGIFEDRSYICKTTIGDREVDSDAYYVYRLQVSSINVSVNAVQTVVRQGENITLMCIVIGNEVVNFEWTYPRKESGRLVEPVTDFLLDMPYHIRSILHIPSAELEDSGTYTCNVTESVNDHQDEKAINITVVE", +} +# fmt: on + + +# ---- Helper functions ---- + + +def build_initial_soft_sequence_logits(sequence: str, batch_size: int) -> torch.Tensor: + """ + Initialize logits with: + - High confidence (10.0) for fixed positions + - Random (~0) for mutable positions + - -1e6 for cysteines + """ + if all(aa == MUTABLE_TOKEN for aa in sequence): + logits = 0.01 * torch.randn([batch_size, len(sequence), AA_DIMS]) + logits[:, :, CYS_IDX] = -1e6 # remove cysteines + else: + logits = torch.zeros([batch_size, len(sequence), AA_DIMS]) + for i, aa in enumerate(sequence): + if aa == MUTABLE_TOKEN: # mutable position - random + logits[:, i, :] = 0.01 * torch.randn(batch_size, AA_DIMS) + logits[:, i, CYS_IDX] = -1e6 + else: # fixed position + assert aa in PROTEIN_1TO3, aa + token_id = TOKEN_IDS[PROTEIN_1TO3[aa]] + logits[:, i, token_id - 2] = 10.0 + + return logits.requires_grad_(True) + + +def build_gradient_mask(sequence: str, 
batch_size: int) -> torch.Tensor: + """ + Build gradient mask [B, L, V]: + - 0 for fixed (all amino acids) + - 0 for cysteine at all positions + - 1 for non-cysteine amino acids at mutable positions + """ + mask = torch.ones([batch_size, len(sequence), AA_DIMS]) + fixed_positions = [i for i, aa in enumerate(sequence) if aa != MUTABLE_TOKEN] + mask[:, fixed_positions, :] = 0.0 + mask[:, :, CYS_IDX] = 0.0 + return mask + + +def sequence_to_one_hot(sequence: str, device="cuda") -> torch.Tensor: + """Convert target string to one-hot tensor [1, L_target, num_tokens].""" + + const_dict = {token: i for i, token in enumerate(TOKENS)} + target_index = [const_dict[PROTEIN_1TO3[letter]] for letter in sequence] + one_hot = F.one_hot(torch.tensor(target_index), num_classes=len(TOKENS)) + return one_hot.to(device).unsqueeze(0).float() + + +def get_mid_points() -> torch.Tensor: + """128 distance bin midpoints (2-52 Angstrom range).""" + + boundaries = torch.linspace(2, 52.0, 127) + lower = torch.tensor([1.0]) + upper = torch.tensor([52.0 + 5.0]) + exp_boundaries = torch.cat((lower, boundaries, upper)) + return (exp_boundaries[:-1] + exp_boundaries[1:]) / 2 + + +def binned_entropy( + dgram: torch.Tensor, bin_distance: torch.Tensor, cutoff: float +) -> torch.Tensor: + """Entropy of distance distribution within cutoff (design losses only).""" + + bin_mask = ~(bin_distance < cutoff) + masked_dgram = dgram - (1e7 * bin_mask) + px = torch.softmax(masked_dgram, dim=-1) + log_px = torch.log_softmax(dgram, dim=-1) + return -(px * log_px).sum(-1) + + +def masked_min_k(x: torch.Tensor, mask: torch.Tensor, k: int) -> torch.Tensor: + """Mean of the smallest k values in x under mask along the last dimension.""" + + mask = mask.bool() + y = torch.sort(torch.where(mask, x, float("nan")))[0] + k_mask = (torch.arange(y.shape[-1]).to(y.device) < k) & (~torch.isnan(y)) + return torch.where(k_mask, y, 0).sum(-1) / (k_mask.sum(-1) + 1e-8) + + +def masked_average(x: torch.Tensor, mask: torch.Tensor) 
-> torch.Tensor: + """Masked mean along last axis.""" + + mask = mask.bool() + return torch.where(mask, x, 0).sum(-1) / (torch.where(mask, 1, 0).sum(-1) + 1e-8) + + +# ---- Loss functions ---- + + +def compute_contact_loss( + distogram_logits: torch.Tensor, + bin_distance: torch.Tensor, + num_contacts: int, + min_sep: int, + cutoff: float, + chain_mask: torch.Tensor, + binder_mask: torch.Tensor, +) -> torch.Tensor: + """Algorithm 12 Contact Losses. + + Entropy-based contact loss with sequence separation constraint.""" + + con_loss = binned_entropy(distogram_logits, bin_distance, cutoff) + position = torch.arange(distogram_logits.shape[1]) + p_dist = position[:, None] - position[None, :] + if min_sep > 0: + separation_mask = (torch.abs(p_dist) >= min_sep).to(distogram_logits.device) + binder_mask = torch.logical_and(separation_mask, binder_mask) + per_residue = masked_min_k(con_loss, mask=binder_mask, k=num_contacts).to( + distogram_logits.device + ) + return masked_average(per_residue, mask=chain_mask).to(distogram_logits.device) + + +def compute_intra_contact_loss( + distogram_logits: torch.Tensor, binder_length: int, bin_distance: torch.Tensor +) -> torch.Tensor: + """Binder internal contacts (k=2, min_sep=9, cutoff=14A).""" + + full_len = distogram_logits.shape[1] + is_binder = torch.ones(full_len, device=distogram_logits.device) + is_binder[:-binder_length] *= 0.0 + return compute_contact_loss( + distogram_logits, + bin_distance, + num_contacts=2, + min_sep=9, + cutoff=14.0, + chain_mask=is_binder, + binder_mask=is_binder, + ) + + +def compute_inter_contact_loss( + distogram_logits: torch.Tensor, binder_length: int, bin_distance: torch.Tensor +) -> torch.Tensor: + """Binder-target interface (k=1, min_sep=0, cutoff=22A).""" + + full_len = distogram_logits.shape[1] + is_binder = torch.ones(full_len, device=distogram_logits.device) + is_binder[:-binder_length] *= 0.0 + return compute_contact_loss( + distogram_logits, + bin_distance, + num_contacts=1, + min_sep=0, 
+ cutoff=22.0, + chain_mask=1 - is_binder, + binder_mask=is_binder, + ) + + +def compute_globularity_loss( + distogram_logits: torch.Tensor, binder_length: int, bin_distance: torch.Tensor +) -> torch.Tensor: + """Algorithm 13 Globularity Loss. + + Radius of gyration vs theoretical packed protein.""" + + binder_disto = distogram_logits[:, -binder_length:, -binder_length:, :] + n = binder_disto.shape[1] + disto_probs = torch.softmax(binder_disto, dim=-1) + bin_distance = bin_distance.clamp(max=27) + e_sq_dist = torch.sum(disto_probs * torch.square(bin_distance), dim=-1) + sum_sq_dist = torch.sum(torch.tril(e_sq_dist, diagonal=-1), dim=(1, 2)) + rg_term = torch.sqrt(sum_sq_dist / (n * n)) + rg_th = 2.38 * (n**0.365) + return F.elu(rg_term - rg_th) + + +def compute_structure_losses( + distogram_logits: torch.Tensor, binder_length: int +) -> dict[str, torch.Tensor]: + """Compute structural losses and a weighted total.""" + + bin_distance = get_mid_points().to(distogram_logits.device) + losses: dict[str, torch.Tensor] = {} + losses["intra_contact_loss"] = compute_intra_contact_loss( + distogram_logits, binder_length, bin_distance + ) + losses["inter_contact_loss"] = compute_inter_contact_loss( + distogram_logits, binder_length, bin_distance + ) + losses["glob_loss"] = compute_globularity_loss( + distogram_logits, binder_length, bin_distance + ) + B = distogram_logits.size(0) + total = torch.tensor([0.0] * B, device=distogram_logits.device, requires_grad=True) + total = total + LOSS_WEIGHTS["intra_contact"] * losses["intra_contact_loss"] + total = total + LOSS_WEIGHTS["inter_contact"] * losses["inter_contact_loss"] + total = total + LOSS_WEIGHTS["glob"] * losses["glob_loss"] + losses["total_loss"] = total + return losses + + +# ---- Distogram iptm proxy ---- + + +def _binding_confidence_entropy( + dgram: torch.Tensor, bin_distance: torch.Tensor, cutoff: float +) -> torch.Tensor: + """Pair entropy within cutoff; matches rd3 contact_score scoring.""" + + probs = 
torch.softmax(dgram, dim=-1) + cutoff_mask = bin_distance < cutoff + p_cut = probs[..., cutoff_mask] + p_cut = p_cut / (p_cut.sum(-1, keepdim=True) + 1e-8) + return -(p_cut * torch.log(p_cut + 1e-10)).sum(-1) + + +def _entropy_to_confidence(mean_entropy: float) -> float: + """Map mean pair entropy to [0, 1]; lower entropy → higher score.""" + return float(max(0.0, min(1.0, 1.0 - mean_entropy / math.log(51)))) + + +def _cdr_indices(binder_sequence: str) -> list[int]: + """0-based binder indices for all Chothia CDRs.""" + from abnumber import Chain + from abnumber.common import _anarci_align + + result = _anarci_align( + sequences=[binder_sequence], scheme="chothia", allowed_species=None + )[0] + chains = [ + Chain("".join(result[i][0].values()), scheme="chothia") + for i in range(len(result)) + ] + if len(chains) == 2 and not chains[0].is_heavy_chain(): + chains.reverse() + indices: list[int] = [] + for chain in chains: + for cdr in (chain.cdr1_seq, chain.cdr2_seq, chain.cdr3_seq): + start = binder_sequence.find(cdr) + assert start >= 0 + indices.extend(range(start, start + len(cdr))) + return indices + + +def compute_distogram_iptm_proxy( + distogram_logits: torch.Tensor, + target_length: int, + binder_sequence: str, + is_antibody: bool, +) -> dict[str, float]: + """Algorithm 15 Distogram ipTM Proxy. + + Distogram iptm proxy for a target|binder complex (binder at suffix). + + Returns distogram_iptm_proxy for all designs and + cdr_distogram_iptm_proxy when the binder can be annotated as an + antibody; otherwise the CDR score is NaN. 
+ """ + if distogram_logits.ndim == 4: + distogram_logits = distogram_logits[0] + + binder_length = len(binder_sequence) + assert distogram_logits.shape[0] == target_length + binder_length + + bin_distance = get_mid_points().to(distogram_logits.device) + binder_start = target_length + + def _mean_lowest_k(entropies: torch.Tensor, k: int) -> float: + sorted_entropies, _ = torch.sort(entropies.reshape(-1)) + k = min(k, sorted_entropies.numel()) + return float(sorted_entropies[:k].mean()) + + binder_to_target_entropy = _binding_confidence_entropy( + distogram_logits[binder_start:, :target_length, :], bin_distance, cutoff=22.0 + ) + distogram_iptm_proxy = _entropy_to_confidence( + _mean_lowest_k(binder_to_target_entropy, k=binder_length) + ) + + if not is_antibody: + cdr_distogram_iptm_proxy = float("nan") + else: + cdr_indices = _cdr_indices(binder_sequence) + cdr_rows = [binder_start + i for i in cdr_indices] + cdr_to_target_entropy = _binding_confidence_entropy( + distogram_logits[cdr_rows, :target_length, :], bin_distance, cutoff=22.0 + ) + cdr_distogram_iptm_proxy = _entropy_to_confidence( + _mean_lowest_k(cdr_to_target_entropy, k=len(cdr_indices)) + ) + + return { + "distogram_iptm_proxy": distogram_iptm_proxy, + "cdr_distogram_iptm_proxy": cdr_distogram_iptm_proxy, + } + + +# ---- Folding ---- + + +def _resize_tensor(tensor: torch.Tensor, *, dim: int, size: int) -> torch.Tensor: + current = tensor.shape[dim] + if current >= size: + return tensor.narrow(dim, 0, size) + + pad_shape = list(tensor.shape) + pad_shape[dim] = size - current + pad = torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device) + return torch.cat((tensor, pad), dim=dim) + + +_ATOM_FEATURE_DIMS = { + "ref_pos": 0, + "ref_element": 0, + "ref_charge": 0, + "ref_atom_name_chars": 0, + "ref_space_uid": 0, + "atom_attention_mask": 0, + "atom_to_token": 0, + "is_resolved": 0, + "gt_coords": 1, +} + + +@cache +def _ensure_ccd_loaded() -> None: + load_ccd() + + +def prepare_esmfold2_tensors( + 
input: StructurePredictionInput, + max_tokens: int | None = None, + max_atoms: int | None = None, + max_seqs: int = 16384, + pad_to_max_seqs: bool = False, + seed: int | None = None, + use_vectorized_msa_assembly: bool = True, +) -> dict[str, torch.Tensor]: + del max_tokens, max_seqs, pad_to_max_seqs, use_vectorized_msa_assembly + _ensure_ccd_loaded() + features, _ = prepare_esmfold2_input(input, seed=seed) + if max_atoms is not None: + for key, dim in _ATOM_FEATURE_DIMS.items(): + if key in features: + features[key] = _resize_tensor(features[key], dim=dim, size=max_atoms) + return features + + +def fold_and_get_distogram( + model: ESMFold2ExperimentalModel, + target_seq: str, + target_one_hot: torch.Tensor, + design: torch.Tensor, + num_loops: int = 0, + num_sampling_steps: int = 1, + calculate_confidence: bool = False, + seed: int | None = None, +) -> dict: + """Prepare inputs, run model forward, return distogram_logits + raw output.""" + padding = (2, 11) + padded_design = F.pad(design, padding, mode="constant", value=0) + + # Argmax to get the designed sequence string. 
+ token_lists = torch.argmax(padded_design, dim=-1) + designed_seq = [ + [PROTEIN_3TO1[TOKENS[int(tkn.item())]] for tkn in token_list] + for token_list in token_lists + ] + seq_list = [target_seq + "|" + "".join(seq) for seq in designed_seq] + max_atoms = None if len(seq_list) == 1 else ((len(seq_list[0]) - 1) * 14) // 32 * 32 + + inputs_list = [] + for seq in seq_list: + sequences = { + sequence: [str(idx)] for idx, sequence in enumerate(seq.split("|")) + } + inputs_raw = StructurePredictionInput( + sequences=[ + ProteinInput(id=chain_id, sequence=sequence, msa=None) + for sequence, chain_id in sequences.items() + ] + ) + inputs_list.append(prepare_esmfold2_tensors(inputs_raw, max_atoms=max_atoms)) + + inputs = { + key: torch.stack([inp[key] for inp in inputs_list], dim=0).cuda() + for key in inputs_list[0] + } + inputs["res_type_soft"] = torch.cat( + (target_one_hot.repeat(design.size(0), 1, 1), padded_design), dim=1 + ) + + with seed_context(seed): + output = model( + **inputs, + num_diffusion_samples=1, + num_sampling_steps=num_sampling_steps, + num_loops=num_loops, + calculate_confidence=calculate_confidence, + seed=seed, + ) + + result: dict = { + "distogram_logits": output["distogram_logits"], + "inputs": inputs, + "inputs_list": inputs_list, + "output": output, + "seq_list": seq_list, + } + if calculate_confidence: + result.update( + { + "ptm": output.get("ptm"), + "iptm": output.get("iptm"), + "plddt": output.get("plddt"), + } + ) + return result + + +_CHAIN_ID_ALPHABET = string.ascii_uppercase + string.ascii_lowercase + string.digits + + +def _asym_id_to_chain_label(asym_id: int) -> str: + if asym_id < 0: + raise ValueError(f"asym_id must be >= 0, got {asym_id}") + label = "" + n = len(_CHAIN_ID_ALPHABET) + while True: + label = _CHAIN_ID_ALPHABET[asym_id % n] + label + asym_id = asym_id // n - 1 + if asym_id < 0: + return label + + +def to_atom_array( + coords: np.ndarray, + atom_to_token: np.ndarray, + res_type: np.ndarray, + residue_index: np.ndarray, 
+ asym_id: np.ndarray, + mol_type: np.ndarray, + ref_atom_name_chars: np.ndarray, + ref_element: np.ndarray, + atom_attention_mask: np.ndarray, + plddt_per_atom: np.ndarray | None = None, +) -> biotite.structure.AtomArray: + atoms = [] + for atom_i, ( + atom_coord, + token_idx, + atom_name_chars, + element_idx, + is_not_pad, + ) in enumerate( + zip( + coords, atom_to_token, ref_atom_name_chars, ref_element, atom_attention_mask + ) + ): + if not is_not_pad: + continue + atoms.append( + biotite.structure.Atom( + coord=atom_coord, + chain_id=_asym_id_to_chain_label(int(asym_id[token_idx])), + res_id=residue_index[token_idx] + 1, + res_name=TOKENS[res_type[token_idx]], + atom_name="".join(chr(c + 32) for c in atom_name_chars if c != 0), + element=ELEMENTS[element_idx], + ins_code=" ", + hetero=mol_type[token_idx] == MOL_TYPE_NONPOLYMER, + b_factor=float(plddt_per_atom[atom_i]) + if plddt_per_atom is not None + else 0.0, + ) + ) + return biotite.structure.array(atoms) + + +def build_complex( + inputs: dict[str, torch.Tensor], output: dict[str, Any] +) -> ProteinComplex: + """Build ProteinComplex from model output.""" + atom_arr = to_atom_array( + coords=output["sample_atom_coords"][0].cpu().numpy(), + atom_to_token=inputs["atom_to_token"][0].cpu().numpy(), + res_type=inputs["res_type"][0].cpu().numpy(), + residue_index=inputs["token_index"][0].cpu().numpy(), + asym_id=inputs["asym_id"][0].cpu().numpy(), + mol_type=inputs["mol_type"][0].cpu().numpy(), + ref_atom_name_chars=inputs["ref_atom_name_chars"][0].cpu().numpy(), + ref_element=inputs["ref_element"][0].cpu().numpy(), + atom_attention_mask=inputs["atom_attention_mask"][0].cpu().numpy(), + ) + return ProteinComplex.from_chains( + [ProteinChain.from_atomarray(a) for a in biotite.structure.chain_iter(atom_arr)] + ) + + +# ---- LM loss ---- + + +@cache +def _folding_trunk_to_lm_aa_vocab_matrix(device: torch.device) -> torch.Tensor: + """Build a matrix of shape [ft_aas=20, lm_aas=20].""" + three_to_one_map = {v: k for k, 
v in PROTEIN_1TO3.items()} + ft_aas = [three_to_one_map[tok_3letter] for tok_3letter in TOKENS[2:22]] + + lm_vocab = sorted(ESMCTokenizer().vocab.items(), key=lambda x: x[1]) + lm_aas = [lm_vocab[i][0] for i in range(4, 24)] + + ft_to_lm_aa_matrix = torch.zeros(20, 20) + for ft_idx, ft_aa in enumerate(ft_aas): + lm_idx = lm_aas.index(ft_aa) + ft_to_lm_aa_matrix[ft_idx, lm_idx] = 1 + + return ft_to_lm_aa_matrix.to(device=device) + + +def _one_hot_from_probs(probs: torch.Tensor) -> torch.Tensor: + return F.one_hot(torch.argmax(probs, dim=-1), num_classes=probs.size(-1)).to( + probs.dtype + ) + + +def _straight_through(discrete: torch.Tensor, continuous: torch.Tensor) -> torch.Tensor: + return continuous + (discrete - continuous).detach() + + +def compute_esmc_pseudoperplexity_nll( + esmc_model: ESMCForMaskedLM, + binder_design: torch.Tensor, + score_mask: torch.Tensor, + batch_size: int = 4, + n_passes: int = 4, +) -> torch.Tensor: + """Algorithm 14 ESMC Pseudo-perplexity Sequence Regularization. 
+ + Approximate pseudoperplexity NLL via multiple sampled masks.""" + device = binder_design.device + lm_vocab_size = esmc_model.config.vocab_size + model_dtype = esmc_model.esmc.embed.weight.dtype + + target_esm = binder_design @ _folding_trunk_to_lm_aa_vocab_matrix(device) + input_esm = _straight_through(_one_hot_from_probs(target_esm), target_esm) + input_ids = torch.zeros( + (binder_design.size(0), binder_design.size(1) + 2, lm_vocab_size), + dtype=model_dtype, + device=device, + ) + tokenizer = ESMCTokenizer() + input_ids[:, 0, tokenizer.cls_token_id] = 1 + input_ids[:, -1, tokenizer.eos_token_id] = 1 + input_ids[:, 1:-1, 4:24] = input_esm.to(model_dtype) + + if score_mask.ndim == 1: + score_mask = score_mask.unsqueeze(0).expand(binder_design.size(0), -1) + elif score_mask.shape != binder_design.shape[:2]: + raise ValueError( + f"Expected score_mask with shape {(binder_design.size(0), binder_design.size(1))}, " + f"got {tuple(score_mask.shape)}" + ) + score_mask = score_mask.to(device=device, dtype=torch.bool) + + mask_token = torch.zeros(lm_vocab_size, dtype=model_dtype, device=device) + mask_token[esmc_model.config.mask_token_id] = 1 + esmc = esmc_model.esmc + + losses = [] + for batch_idx in range(binder_design.size(0)): + position_indices = score_mask[batch_idx].nonzero(as_tuple=False).flatten() + num_positions = int(position_indices.numel()) + if num_positions == 0: + raise ValueError( + "ESMC pseudoperplexity score mask selected zero positions." 
+ ) + + num_masked = max(1, math.ceil(ESMC_MASK_FRACTION * num_positions)) + random_scores = torch.rand((n_passes, num_positions), device=device) + masked_offsets = random_scores.topk(num_masked, dim=-1, largest=False).indices + pass_masks = torch.zeros( + (n_passes, binder_design.size(1)), dtype=torch.bool, device=device + ) + pass_masks[ + torch.arange(n_passes, device=device)[:, None], + position_indices[masked_offsets], + ] = True + + masked_sequences = input_ids[batch_idx : batch_idx + 1].repeat(n_passes, 1, 1) + mask_rows, mask_cols = pass_masks.nonzero(as_tuple=True) + masked_sequences[mask_rows, mask_cols + 1] = mask_token + + target_weights = target_esm[batch_idx] + masked_nlls = [] + for start in range(0, n_passes, batch_size): + stop = min(start + batch_size, n_passes) + chunk = masked_sequences[start:stop] + with torch.autocast( + device_type="cuda", dtype=torch.bfloat16, enabled=device.type == "cuda" + ): + hidden, *_ = esmc.transformer( + chunk @ esmc.embed.weight.to(chunk.dtype), + sequence_id=None, + layers_to_collect=[], + output_attentions=False, + ) + logits = esmc_model.lm_head(hidden) + log_probs = logits.log_softmax(dim=-1)[:, 1:-1, 4:24] + nlls = -(log_probs * target_weights.to(log_probs.dtype).unsqueeze(0)).sum( + dim=-1 + ) + masked_nlls.append(nlls[pass_masks[start:stop]]) + + losses.append(torch.cat(masked_nlls, dim=0).mean()) + + return torch.stack(losses, dim=0) + + +# ---- Design ---- + + +def normalized_gradient_tensor( + grad: torch.Tensor, gradient_mask: torch.Tensor +) -> torch.Tensor: + masked_grad = grad * gradient_mask + index_has_nonzero_grad = torch.square(masked_grad).sum(-1) > 0 # (B, L) + eff_L = index_has_nonzero_grad.sum(-1) # (B,) + grad_norm = torch.linalg.norm(masked_grad, axis=(-1, -2)) # (B,) + normalized_grad = (masked_grad / (grad_norm[:, None, None] + 1e-7)) * torch.sqrt( + eff_L[:, None, None] + ) + return normalized_grad * gradient_mask + + +def design_binder( + inversion_models: dict[str, 
ESMFold2ExperimentalModel], + hf_critic_models: dict[str, ESMFold2ExperimentalModel], + esmc_model: ESMCForMaskedLM, + target_name: str | None, + target_sequence: str | None, + binder_name: str | None, + binder_sequence: str | None, + is_antibody: bool | None, + seed: int, + batch_size: int = 1, +) -> tuple[list[str], dict[int, dict[str, torch.Tensor]], list[dict]]: + """ + Algorithm 11 Gradient-Guided Binder Sequence Optimization. + + Run the full optimization loop. + Returns dict with designed_sequence, complex, and trajectory. + + Every critic is folded once on the best designed sequence via HF ESMFold2. + Hero critics expose iPTM; scaling critics contribute distogram scores only. + ``distogram_binding_confidence`` / ``cdr_distogram_binding_confidence`` come + from the distogram in all cases. + """ + # Vet inputs + assert (target_name is None) ^ ( + target_sequence is None + ), "Provide either target name or sequence." + assert (binder_name is None) ^ ( + binder_sequence is None + ), "Provide either binder name or sequence." + + # Setup + device = "cuda" + if target_name is not None: + target_sequence = TARGET_SEQUENCES[target_name] + else: + assert target_sequence is not None + target_one_hot = sequence_to_one_hot(target_sequence, device=device) + + if binder_name is None: + assert binder_sequence is not None + # If no binder_name and is_antibody is not specified, assume False. + if is_antibody is None: + is_antibody = False + else: + binder_prompt_factor = BINDER_PROMPT_FACTORIES[binder_name] + if is_antibody is not None: + assert ( + binder_prompt_factor.is_antibody == is_antibody + ), "Conflict in is_antibody settings." + is_antibody = binder_prompt_factor.is_antibody + binder_sequence = binder_prompt_factor.sample(seed=seed) + + binder_length = len(binder_sequence) + + # By default, we only support single binder and target chains. 
+ # To support this case, remove the asserts below and check that losses + # and selection metrics are appropriate for your multi-chain case. + assert "|" not in target_sequence + assert "|" not in binder_sequence + + with seed_context(seed), torch.device(device): + logits = build_initial_soft_sequence_logits( + binder_sequence, batch_size=batch_size + ) + gradient_mask = build_gradient_mask(binder_sequence, batch_size=batch_size) + + # step -> {loss_name: [B] tensor on CPU} + trajectory: dict[int, dict[str, torch.Tensor]] = {} + global_step = 0 + + def run_step( + logits: torch.Tensor, + optimizer: optim.Optimizer, + temperature: float, + calculate_confidence: bool, + ) -> tuple[torch.Tensor, list[str], list[float] | None]: + nonlocal global_step + optimizer.zero_grad() + + random.seed(seed + global_step) + replicate_choice = random.randint(0, len(inversion_models) - 1) + inversion_model = list(inversion_models.values())[replicate_choice] + design = F.softmax(logits / temperature, dim=-1) + + fold_result = fold_and_get_distogram( + inversion_model, + target_sequence, + target_one_hot, + design, + num_loops=1, + num_sampling_steps=50 if calculate_confidence else 1, + calculate_confidence=calculate_confidence, + seed=seed + global_step, + ) + sequences: list[str] = fold_result["seq_list"] + losses = compute_structure_losses( + fold_result["distogram_logits"], binder_length + ) + structure_loss = losses["total_loss"] + structure_grad = torch.autograd.grad(structure_loss.mean(), logits)[0] + + # Recompute the logits -> design transform for a fresh graph. 
+ design = F.softmax(logits / temperature, dim=-1) + score_mask = gradient_mask.sum(dim=-1) > 0 + with seed_context(seed + global_step): + plm_loss = compute_esmc_pseudoperplexity_nll( + esmc_model=esmc_model, + binder_design=design, + score_mask=score_mask, + batch_size=4, + n_passes=4, + ) + plm_grad = torch.autograd.grad(plm_loss.mean(), logits)[0] + + logits.grad = normalized_gradient_tensor(structure_grad, gradient_mask) + ( + 0.05 if is_antibody else 0.15 + ) * normalized_gradient_tensor(plm_grad, gradient_mask) + + for g in optimizer.param_groups: + g["lr"] = LEARNING_RATE * temperature + + optimizer.step() + + step = global_step + step_losses = {k: v.detach().cpu() for k, v in losses.items()} + step_losses["plm_loss"] = plm_loss.detach().cpu() + step_losses["total_loss"] = (structure_loss + plm_loss).detach().cpu() + trajectory[step] = step_losses + loss_str = " ".join( + f"{k}={v.mean().item():.4f}" for k, v in step_losses.items() + ) + if step % LOG_INTERVAL == 0: + logger.info(f" step {step:3d} | {loss_str} T={temperature:.4f}") + global_step += 1 + return logits, sequences, fold_result.get("iptm", None) + + # Optimize + optimizer = optim.SGD([logits], lr=LEARNING_RATE) + best_iptm: list[float] = [-1.0] * batch_size + best_sequences: list[str] = [""] * batch_size + for step in range(STEPS): + # Cosine schedule + t = (step + 1) / STEPS + remaining = 0.5 * (1 + math.cos(math.pi * t)) + temperature = TEMPERATURE_MIN + (1 - TEMPERATURE_MIN) * remaining + logits, sequences, iptm = run_step( + logits, + optimizer, + temperature=temperature, + calculate_confidence=temperature < 0.05, + ) + if iptm is not None: + for b in range(batch_size): + if iptm[b] is not None and iptm[b] > best_iptm[b]: + best_iptm[b] = iptm[b] + best_sequences[b] = sequences[b] + assert all(seq != "" for seq in best_sequences) + + # Score + critic_results: list[dict] = [] + target_length = len(target_sequence.replace("|", "")) + for batch_idx in range(batch_size): + best_seq = 
best_sequences[batch_idx] + binder_seq = best_seq.split("|")[-1] + binder_design = sequence_to_one_hot(binder_seq)[..., 2:22] + for critic_name, critic_model in hf_critic_models.items(): + is_scaling_critic = "ESMFold2-Experimental-Fast-base" in critic_name + if is_scaling_critic: + critic_model.cuda() + final_fold = fold_and_get_distogram( + critic_model, + target_sequence, + target_one_hot, + binder_design, + num_loops=3, + num_sampling_steps=200, + calculate_confidence=True, + seed=seed, + ) + if is_scaling_critic: + critic_model.cpu() + pred_complex = build_complex(final_fold["inputs"], final_fold["output"]) + iptm_proxy_scores = compute_distogram_iptm_proxy( + final_fold["distogram_logits"], target_length, binder_seq, is_antibody + ) + iptm = final_fold["iptm"].item() if final_fold["iptm"] is not None else None + critic_results.append( + { + "is_antibody": is_antibody, + "critic_name": critic_name, + "batch_idx": batch_idx, + "designed_sequence": best_seq, + "complex": pred_complex, + "final_loss": trajectory[global_step - 1]["total_loss"][ + batch_idx + ].item(), + "iptm": iptm, + "logits": logits[batch_idx].detach().cpu(), + **iptm_proxy_scores, + } + ) + + if not critic_results: + for batch_idx in range(batch_size): + critic_results.append( + { + "is_antibody": is_antibody, + "batch_idx": batch_idx, + "designed_sequence": best_sequences[batch_idx], + "final_loss": trajectory[global_step - 1]["total_loss"][ + batch_idx + ].item(), + "logits": logits[batch_idx].detach().cpu(), + } + ) + + return best_sequences, trajectory, critic_results + + +# ---- Model Loading ---- + +_ESMC = None + + +def _load_hf_model( + critic_name: str, lm_dropout: float, cache_esmc: bool, device: str +) -> Any: + """Loads ESMFold2 from huggingface. 
Will cache ESMC-6B among + all non-scaling checkpoints, to save on VRAM and load time.""" + global _ESMC + repo_id = f"biohub/{critic_name}" + model = ESMFold2ExperimentalModel.from_pretrained(repo_id, load_esmc=not cache_esmc) + if cache_esmc: + if _ESMC is None: + model.load_esmc(model.config.esmc_id) + _ESMC = model._esmc + else: + model._esmc = _ESMC + model.configure_lm_dropout(lm_dropout, force_lm_dropout_during_inference=True) + model.set_kernel_backend("cuequivariance" if CUE_AVAILABLE else None) + return model.to(device=device).eval().requires_grad_(False) + + +def _apply_torch_compile(model: torch.nn.Module) -> None: + """A helper for torch compiling the model.""" + torch._dynamo.config.cache_size_limit = 512 + torch._dynamo.config.accumulated_cache_size_limit = 512 + + compile_targets = (ESMFold2MSAEncoder, PairUpdateBlock, TransformerBlock) + + def _maybe_compile_module(module: torch.nn.Module) -> None: + if not isinstance(module, compile_targets): + return + module.forward = torch.compile(module.forward) # pyright: ignore + + model.apply(_maybe_compile_module) + + +class ESMFold2Design: + lm_name = "biohub/ESMC-6B" + inversion_model_names: list[str] = [ + "ESMFold2-Experimental-Fast", + "ESMFold2-Experimental-Fast-Cutoff2025", + ] + hero_critic_hf_paths: list[str] = [ + "ESMFold2-Experimental-Fast", + "ESMFold2-Experimental-Fast-Cutoff2025", + "ESMFold2-Experimental", + "ESMFold2-Experimental-Cutoff2025", + ] + scaling_critic_hf_paths: list[str] = [] + + def load(self, use_scaling_critics: bool): + if use_scaling_critics: + self.scaling_critic_hf_paths = [ + f"ESMFold2-Experimental-Fast-base{size}-step{step}k" + for size in ("300M", "600M", "6B") + for step in ("250", "500", "750", "1000", "1500") + ] + + self.inversion_models = { + model_name: _load_hf_model( + model_name, lm_dropout=0.5, cache_esmc=True, device="cuda" + ) + for model_name in self.inversion_model_names + } + if COMPILE: + for model in self.inversion_models.values(): + 
_apply_torch_compile(model) + + self.hf_critic_models: dict[str, Any] = {} + for name in self.hero_critic_hf_paths: + self.hf_critic_models[name] = _load_hf_model( + name, lm_dropout=0.25, cache_esmc=True, device="cuda" + ) + for name in self.scaling_critic_hf_paths: + self.hf_critic_models[name] = _load_hf_model( + name, lm_dropout=0.25, cache_esmc=False, device="cpu" + ) + + self.esmc_model = ESMCForMaskedLM.from_pretrained( + self.lm_name, torch_dtype=torch.float32 + ) + if REUSE_ESMC: + del self.esmc_model.esmc + torch.cuda.empty_cache() + self.esmc_model.esmc = self.inversion_models[ + "ESMFold2-Experimental-Fast" + ]._esmc + self.esmc_model = self.esmc_model.cuda().eval().requires_grad_(False) + + if CHECKPOINT_LM: + apply_activation_checkpointing( + self.esmc_model, + checkpoint_wrapper_fn=partial( + checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT + ), + check_fn=lambda module: isinstance(module, TransformerBlock), + ) + + def design( + self, + target_name: str | None = None, + target_sequence: str | None = None, + binder_name: str | None = None, + binder_sequence: str | None = None, + is_antibody: bool | None = None, + seed: int = 0, + batch_size: int = 1, + ) -> tuple[list[str], dict[int, dict[str, torch.Tensor]], list[dict]]: + return design_binder( + self.inversion_models, + self.hf_critic_models, + self.esmc_model, + target_name=target_name, + target_sequence=target_sequence, + binder_name=binder_name, + binder_sequence=binder_sequence, + is_antibody=is_antibody, + seed=seed, + batch_size=batch_size, + ) + + +# ---- Modal ---- + + +FLASH_ATTN_WHEEL = ( + "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/" + "v2.8.3/flash_attn-2.8.3+cu12torch2.8cxx11abiFALSE-cp312-cp312-linux_x86_64.whl" +) +TRANSFORMER_ENGINE_TORCH_WHEEL = ( + "transformer-engine-torch @ https://github.com/evolutionaryscale/wheels/" + "releases/download/transformer-engine-torch/v2.13.0-pt2.8-cu128-cp312/" + 
"transformer_engine_torch-2.13.0-cp312-cp312-linux_x86_64.whl" +) + + +def get_base_image(): + return ( + modal.Image.micromamba(python_version="3.12") + .run_commands("apt update && apt install -y git build-essential") + .micromamba_install( + "anarci>=2020.04.03", + "hmmer=3.4", + "cuda-version=12.8", + "cuda-libraries-dev=12.8", + channels=["conda-forge", "bioconda"], + ) + .pip_install( + "torch==2.8.0", + "triton==3.4.0", + index_url="https://download.pytorch.org/whl/cu128", + ) + .pip_install( + FLASH_ATTN_WHEEL, + "transformer-engine[core-cu12,pytorch]==2.13.0", + TRANSFORMER_ENGINE_TORCH_WHEEL, + "xformers==0.0.32.post1", + ) + .pip_install( + "abnumber", "esm@git+https://github.com/Biohub/esm.git@main", "modal" + ) + .env( + { + "HF_HOME": "/models", + "HF_XET_HIGH_PERFORMANCE": "1", + "XFORMERS_IGNORE_FLASH_VERSION_CHECK": "1", + } + ) + ) + + +app = modal.App( + name="esmfold2-design", + image=get_base_image(), + volumes={ + "/models": modal.Volume.from_name("esmfold2-models", create_if_missing=True) + }, +) + + +# If use_scaling_checkpoints is True, `memory` should be increased to 60 * 1024. +@app.cls(gpu="H100", timeout=60 * 60, cpu=16, memory=10 * 1024) +class ESMFold2DesignModal(ESMFold2Design): + """Modal entrypoint. Hero critics are HF experimental exports with + confidence heads. Set ``use_scaling_critics=True`` to also load the + 15-checkpoint rd3 scaling-experiment ensemble (distogram binding confidence only). 
+ """ + + use_scaling_critics: bool = modal.parameter(default=False) + + @modal.enter() + def load(self): + return super().load(self.use_scaling_critics) + + @modal.method() + def design(self, *args, **kws): + return super().design(*args, **kws) + + +@app.local_entrypoint() +def main( + target_name: str | None = None, + target_sequence: str | None = None, + binder_name: str | None = None, + binder_sequence: str | None = None, + use_scaling_critics: bool = False, + is_antibody: bool | None = None, + local: bool = False, + seed: int = 0, + batch_size: int = 1, +): + if local: + assert not use_scaling_critics, ( + "'abnumber' will fail if running this script with uv run. " + "It requires conda packages. To be addressed soon." + ) + app = ESMFold2Design() + app.load(use_scaling_critics) + run_fn = app.design + else: + app = ESMFold2DesignModal(use_scaling_critics=use_scaling_critics) + run_fn = app.design.remote + + seq, trajectory, results = run_fn( + target_name=target_name, + target_sequence=target_sequence, + binder_name=binder_name, + binder_sequence=binder_sequence, + is_antibody=is_antibody, + seed=seed, + batch_size=batch_size, + ) + + avg_final_loss = sum(r["final_loss"] for r in results) / len(results) + logger.info(f"\nDesigned sequence: {seq}") + logger.info(f"Trajectory length: {len(trajectory)} steps") + logger.info(f"Average final loss: {avg_final_loss:.4f}") + + +if __name__ == "__main__": + # Run a single local design. 
+ main( + # Example case 1 + target_name="pd-l1", + binder_name="minibinder", + is_antibody=False, + # Example case 2 + # target_name="cd45", + # binder_name="trastuzumab_framework_vhvl", + # is_antibody=True, + # Common settings + seed=0, + batch_size=1, + local=True, + use_scaling_critics=False, + ) diff --git a/esm/models/esmfold2/processor.py b/esm/models/esmfold2/processor.py index 3616188..0a67b5d 100644 --- a/esm/models/esmfold2/processor.py +++ b/esm/models/esmfold2/processor.py @@ -288,7 +288,7 @@ def fold( model: Any, input: StructurePredictionInput, *, - num_loops: int = 3, + num_loops: int = 20, num_sampling_steps: int = 200, num_diffusion_samples: int = 1, seed: int | None = None, diff --git a/esm/sdk/api.py b/esm/sdk/api.py index d9230a9..54221ff 100644 --- a/esm/sdk/api.py +++ b/esm/sdk/api.py @@ -366,7 +366,7 @@ class FoldingConfig: include_pae: bool = False include_pair_chains_iptm: bool = False num_sampling_steps: int = 100 - num_loops: int = 10 + num_loops: int = 20 include_embeddings: bool = False From b20a1730afd0f33c3a6f7a855738b0f7b3a4337b Mon Sep 17 00:00:00 2001 From: Zeming Lin Date: Mon, 1 Jun 2026 15:45:14 +0000 Subject: [PATCH 2/6] Change tutorial to notebook --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 1809841..26d1ca5 100644 --- a/README.md +++ b/README.md @@ -25,7 +25,7 @@ We are releasing a world model for protein biology: a scientific engine for pred
-ESMFold2 is validated in the lab across five therapeutic targets. Inversion of ESMFold2 enables generation of de novo minibinders and antibody-derived scFvs with high hit rates, nanomolar affinities, target specificity, and functional activity. We've released the full protocol from target sequence to ranked binder design in this [tutorial](https://github.com/Biohub/esm/blob/main/cookbook/tutorials/esmfold2_esmc_binder_design.ipynb). For additional details, please refer to the [preprint](https://biohub.ai/papers/esm_protein.pdf). +ESMFold2 is validated in the lab across five therapeutic targets. Inversion of ESMFold2 enables generation of de novo minibinders and antibody-derived scFvs with high hit rates, nanomolar affinities, target specificity, and functional activity. We've released the full protocol from target sequence to ranked binder design in this [notebook](https://github.com/Biohub/esm/blob/main/cookbook/tutorials/esmfold2_esmc_binder_design.ipynb). For additional details, please refer to the [preprint](https://biohub.ai/papers/esm_protein.pdf).
From 6c5282359dc9766f6ab03904d59a35cf7606eb51 Mon Sep 17 00:00:00 2001 From: Zeming Lin Date: Mon, 1 Jun 2026 15:59:10 +0000 Subject: [PATCH 3/6] fix tests --- .pre-commit-config.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 848375d..46c4c15 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -31,6 +31,8 @@ repos: language: system types: [python] pass_filenames: true # For speed, we only check the files that are changed + # Modal-app tutorial: deps (modal, abnumber) and dynamic decorators aren't resolvable in the lint env. + exclude: ^cookbook/tutorials/esmfold2_esmc_binder_design\.py$ - repo: https://github.com/gitleaks/gitleaks rev: v8.24.2 hooks: From 50b4ef09873173b8f52630c82e0f04ac9b4f8d79 Mon Sep 17 00:00:00 2001 From: Zeming Lin Date: Mon, 1 Jun 2026 17:07:39 +0000 Subject: [PATCH 4/6] notebook bugfix --- .../esmfold2_esmc_binder_design.ipynb | 1348 ++++------------- .../tutorials/esmfold2_esmc_binder_design.py | 4 +- 2 files changed, 277 insertions(+), 1075 deletions(-) diff --git a/cookbook/tutorials/esmfold2_esmc_binder_design.ipynb b/cookbook/tutorials/esmfold2_esmc_binder_design.ipynb index c04234b..2372298 100644 --- a/cookbook/tutorials/esmfold2_esmc_binder_design.ipynb +++ b/cookbook/tutorials/esmfold2_esmc_binder_design.ipynb @@ -97,7 +97,8 @@ "metadata": {}, "outputs": [], "source": [ - "ESMFold2Design = modal.Cls.from_name(\"esmfold2-design\", \"ESMFold2DesignModal\")\n", + "# ESMFold2Design = modal.Cls.from_name(\"esmfold2-design-jun1-11am\", \"ESMFold2DesignModal\")\n", + "ESMFold2Design = modal.Cls.from_name(\"esmfold2-design-jun1-12pm\", \"ESMFold2DesignModal\")\n", "# Set 'use_scaling_critics' to evaluate with the additional critics.\n", "# Off by default. 
But cells below were populated with them enabled.\n", "app = ESMFold2Design(use_scaling_critics=False)" @@ -120,7 +121,7 @@ { "data": { "text/plain": [ - "'https://modal.com/id/fc-01KSTCT9W9PYKN3HEKEZ168VJP'" + "'https://modal.com/id/fc-01KT1ZA3NQ0JTF2B4HNCR159NM'" ] }, "execution_count": 4, @@ -137,21 +138,10 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "id": "c1bda1a6", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'https://modal.com/id/fc-01KSTCT9YCT8HH50718ZBABJRT'" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "# ---- Option 2: Provide your own sequences. ----\n", "# Our pd-l1 sequence crop.\n", @@ -175,340 +165,18 @@ "source": [ "# ---- Monitor ----\n", "# Tail a function's output here in jupyter\n", - "! modal app logs esmfold2-design -f --function-call {future2.object_id}" + "! modal app logs esmfold2-design -f --function-call {future.object_id}" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "id": "79ea37f3", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Best sequences: ['AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKNIIQFVHGEEDLKVQHSSYRQRARLLKDQLSLGNAALQITDVKLQDAGVYRCMISYGGADYKRITVKVNA|EVQLVESGGGLVQPGGSLRLSCAASEPADEDDYIHWVRQAPGKGLEWVARITYEEKTRYADSVKGRFTISADTSKNTAYLQMNSLRAEDTAVYYCSRWTAMAIGNDVAWGQGTLVTVSSGGGSGGGSGGGSGGGSDIQMTQSPSSLSASVGDRVTITCRFSQDVTIRLSWYQQKPGKAPKLLIYFAFILANGVPSRFSGSRSGTDFTLTISSLQPEDFATYYCNYTRYSSSRFGQGTKVEIK']\n" - ] - }, - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " 
\n", - " \n", - " \n", - " \n", - " \n", - "
is_antibodycritic_namebatch_idxdesigned_sequencefinal_lossiptmdistogram_iptm_proxycdr_distogram_iptm_proxy
0TrueESMFold2-Experimental-Fast0AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...4.5445020.9284950.8509760.873329
1TrueESMFold2-Experimental-Fast-Cutoff20250AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...4.5445020.9148860.8370670.856937
2TrueESMFold2-Experimental0AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...4.5445020.9145340.8241510.839054
3TrueESMFold2-Experimental-Cutoff20250AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...4.5445020.9276020.8350800.858924
4TrueESMFold2-Experimental-Fast-base300M-step250k0AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...4.544502NaN0.7605680.791366
5TrueESMFold2-Experimental-Fast-base300M-step500k0AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...4.544502NaN0.8171970.826138
6TrueESMFold2-Experimental-Fast-base300M-step750k0AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...4.544502NaN0.7536130.773483
7TrueESMFold2-Experimental-Fast-base300M-step1000k0AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...4.544502NaN0.8311060.860911
8TrueESMFold2-Experimental-Fast-base300M-step1500k0AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...4.544502NaN0.7933530.737717
9TrueESMFold2-Experimental-Fast-base600M-step250k0AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...4.544502NaN0.7943460.810242
10TrueESMFold2-Experimental-Fast-base600M-step500k0AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...4.544502NaN0.7863990.816203
11TrueESMFold2-Experimental-Fast-base600M-step750k0AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...4.544502NaN0.8102420.834086
12TrueESMFold2-Experimental-Fast-base600M-step1000k0AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...4.544502NaN0.7814310.810242
13TrueESMFold2-Experimental-Fast-base600M-step1500k0AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...4.544502NaN0.7953400.824151
14TrueESMFold2-Experimental-Fast-base6B-step250k0AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...4.544502NaN0.7913660.819184
15TrueESMFold2-Experimental-Fast-base6B-step500k0AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...4.544502NaN0.8003070.828125
16TrueESMFold2-Experimental-Fast-base6B-step750k0AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...4.544502NaN0.8122290.845015
17TrueESMFold2-Experimental-Fast-base6B-step1000k0AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...4.544502NaN0.8340860.859917
18TrueESMFold2-Experimental-Fast-base6B-step1500k0AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...4.544502NaN0.7774570.798320
\n", - "
" - ], - "text/plain": [ - " is_antibody critic_name batch_idx \\\n", - "0 True ESMFold2-Experimental-Fast 0 \n", - "1 True ESMFold2-Experimental-Fast-Cutoff2025 0 \n", - "2 True ESMFold2-Experimental 0 \n", - "3 True ESMFold2-Experimental-Cutoff2025 0 \n", - "4 True ESMFold2-Experimental-Fast-base300M-step250k 0 \n", - "5 True ESMFold2-Experimental-Fast-base300M-step500k 0 \n", - "6 True ESMFold2-Experimental-Fast-base300M-step750k 0 \n", - "7 True ESMFold2-Experimental-Fast-base300M-step1000k 0 \n", - "8 True ESMFold2-Experimental-Fast-base300M-step1500k 0 \n", - "9 True ESMFold2-Experimental-Fast-base600M-step250k 0 \n", - "10 True ESMFold2-Experimental-Fast-base600M-step500k 0 \n", - "11 True ESMFold2-Experimental-Fast-base600M-step750k 0 \n", - "12 True ESMFold2-Experimental-Fast-base600M-step1000k 0 \n", - "13 True ESMFold2-Experimental-Fast-base600M-step1500k 0 \n", - "14 True ESMFold2-Experimental-Fast-base6B-step250k 0 \n", - "15 True ESMFold2-Experimental-Fast-base6B-step500k 0 \n", - "16 True ESMFold2-Experimental-Fast-base6B-step750k 0 \n", - "17 True ESMFold2-Experimental-Fast-base6B-step1000k 0 \n", - "18 True ESMFold2-Experimental-Fast-base6B-step1500k 0 \n", - "\n", - " designed_sequence final_loss iptm \\\n", - "0 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 4.544502 0.928495 \n", - "1 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 4.544502 0.914886 \n", - "2 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 4.544502 0.914534 \n", - "3 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 4.544502 0.927602 \n", - "4 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 4.544502 NaN \n", - "5 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 4.544502 NaN \n", - "6 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 4.544502 NaN \n", - "7 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 4.544502 NaN \n", - "8 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 4.544502 NaN \n", - "9 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 
4.544502 NaN \n", - "10 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 4.544502 NaN \n", - "11 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 4.544502 NaN \n", - "12 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 4.544502 NaN \n", - "13 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 4.544502 NaN \n", - "14 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 4.544502 NaN \n", - "15 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 4.544502 NaN \n", - "16 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 4.544502 NaN \n", - "17 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 4.544502 NaN \n", - "18 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 4.544502 NaN \n", - "\n", - " distogram_iptm_proxy cdr_distogram_iptm_proxy \n", - "0 0.850976 0.873329 \n", - "1 0.837067 0.856937 \n", - "2 0.824151 0.839054 \n", - "3 0.835080 0.858924 \n", - "4 0.760568 0.791366 \n", - "5 0.817197 0.826138 \n", - "6 0.753613 0.773483 \n", - "7 0.831106 0.860911 \n", - "8 0.793353 0.737717 \n", - "9 0.794346 0.810242 \n", - "10 0.786399 0.816203 \n", - "11 0.810242 0.834086 \n", - "12 0.781431 0.810242 \n", - "13 0.795340 0.824151 \n", - "14 0.791366 0.819184 \n", - "15 0.800307 0.828125 \n", - "16 0.812229 0.845015 \n", - "17 0.834086 0.859917 \n", - "18 0.777457 0.798320 " - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "# ---- Load result ----\n", - "best_sequences, trajectory, critic_results = future2.get()\n", + "best_sequences, trajectory, critic_results = future.get()\n", "print(\"Best sequences: \", best_sequences)\n", "df = pd.DataFrame(critic_results)\n", "df.drop(columns=[\"logits\", \"complex\"])" @@ -516,16 +184,16 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 54, "id": "d80597fa", "metadata": {}, "outputs": [ { "data": { - "application/3dmoljs_load.v0": "
\n

3Dmol.js failed to load for some reason. Please check your browser console for error messages.

\n
\n", + "application/3dmoljs_load.v0": "
\n

3Dmol.js failed to load for some reason. Please check your browser console for error messages.

\n
\n", "text/html": [ - "
\n", - "

3Dmol.js failed to load for some reason. Please check your browser console for error messages.

\n", + "
\n", + "

3Dmol.js failed to load for some reason. Please check your browser console for error messages.

\n", "
\n", "" ] @@ -581,10 +249,10 @@ { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 9, + "execution_count": 54, "metadata": {}, "output_type": "execute_result" } @@ -600,7 +268,7 @@ " .setStyle({\"chain\": \"A\"}, {\"cartoon\": {\"color\": \"green\"}}) # pyright: ignore\n", " .setStyle({\"chain\": \"B\"}, {\"cartoon\": {\"color\": \"cyan\"}}) # pyright: ignore\n", " .addStyle( # pyright: ignore\n", - " {\"not\": {\"atom\": [\"N\", \"CA\", \"C\", \"O\"]}},\n", + " {\"not\": {\"atom\": [\"N\", \"C\", \"O\"]}},\n", " {\"stick\": {\"colorscheme\": \"default\", \"radius\": 0.2}},\n", " )\n", " .center() # pyright: ignore\n", @@ -618,7 +286,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "id": "ac02bbaa", "metadata": {}, "outputs": [ @@ -693,10 +361,10 @@ { "data": { "text/plain": [ - "(16, 7)" + "(256, 7)" ] }, - "execution_count": 6, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -713,7 +381,7 @@ " binder_name=[\"minibinder\", \"trastuzumab_framework_vhvl\"], # two modalities\n", " binder_sequence=[None],\n", " use_scaling_critics=[False],\n", - " seed=list(range(8)), # 8 seeds each\n", + " seed=list(range(128)),\n", " batch_size=[1],\n", ")\n", "df = pd.DataFrame(product(*line_sweeps.values()), columns=list(line_sweeps.keys()))\n", @@ -723,7 +391,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "id": "9b768813", "metadata": {}, "outputs": [ @@ -731,7 +399,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Spawned 16 jobs. It is safe to close the notebook.The next cell will resume from call_id's, saved by Modal for up to 7 days.\n" + "Spawned 256 jobs. 
It is safe to close the notebook.The next cell will resume from call_id's, saved by Modal for up to 7 days.\n" ] } ], @@ -757,686 +425,133 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 34, "id": "b9a637f0", "metadata": {}, "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "700b75ea416b483890f292b204d4d0fb", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/256 [00:00= 6.\n", - "df_result[\"binder_sequence\"] = df_result.designed_sequence.str.split(r\"\\|\").str[1]\n", - "df_result[\"isoelectric_point\"] = [\n", - " ProteinAnalysis(seq).isoelectric_point()\n", - " for seq in tqdm(df_result.binder_sequence.values)\n", - "]\n", - "# Isoelectric point filter\n", - "df_filter = df_result[df_result.is_antibody | df_result.isoelectric_point.lt(6)]\n", - "\n", - "\n", - "# Select the top 84 designs from each (target, binder) combination\n", - "def select(df: pd.DataFrame) -> pd.DataFrame:\n", - " # Where the cdr-specific iptm proxy exists, use it (antibodies).\n", - " # Else use the full distogram iptm proxy.\n", - " # If neither exists (use_scaling_checkpoints=False), then there is no contribution from this term.\n", - " df[\"iptm_proxy\"] = df.cdr_distogram_iptm_proxy.combine_first(\n", - " df.distogram_iptm_proxy\n", - " ).fillna(0)\n", - " df = df.groupby(\"designed_sequence\", as_index=False).agg(\n", - " dict(iptm=\"mean\", iptm_proxy=\"mean\")\n", - " )\n", - " df[\"selection_score\"] = 0.5 * df.iptm + 0.5 * df.iptm_proxy\n", - " return df.nlargest(min(len(df), 84), \"selection_score\")\n", - "\n", - "\n", - "df_select = df_filter.groupby([\"target_name\", \"binder_name\"]).apply(\n", - " select, include_groups=False\n", - ")\n", - "df_select.to_parquet(save_dir / \"selection.parquet\", index=False)" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "b54a71de", - "metadata": {}, - "outputs": [ + }, { "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " 
\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
is_antibodycritic_namebatch_idxdesigned_sequencefinal_lossiptmdistogram_iptm_proxycdr_distogram_iptm_proxytarget_nametarget_sequencebinder_namebinder_sequenceuse_scaling_criticsseedbatch_sizecall_idfuturestatusisoelectric_point
3FalseESMFold2-Experimental-Cutoff20250AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...2.8870060.9494710.902141NaNpd-l1NoneminibinderQSSDDEIDKEVNKVAAEIALAVAELTRAAADGDDKEVDKQLKKALK...False01fc-01KSTCTA27QYFGWNB67BPKZ72ZFunctionCall.from_id('fc-01KSTCTA27QYFGWNB67BP...SUCCESS9.521739
22FalseESMFold2-Experimental-Cutoff20250AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...4.9459100.3925570.610550NaNpd-l1NoneminibinderKWEIWRLLWKIGNNLWNNNNNNNNWNAIWTIWWWLIWWLIWWLLIN...False11fc-01KSTCTA52ZQZXC2RZ4F12ZJNCFunctionCall.from_id('fc-01KSTCTA52ZQZXC2RZ4F1...SUCCESS10.605259
41FalseESMFold2-Experimental-Cutoff20250AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...4.6284140.9204260.857930NaNpd-l1NoneminibinderSIIRILIIIVIKAIKKVSKIAKILKKALKELAKSGASKEIVEILIE...False21fc-01KSTCTA7ZBCYQZCJ1DHGT94BZFunctionCall.from_id('fc-01KSTCTA7ZBCYQZCJ1DHG...SUCCESS10.170871
60FalseESMFold2-Experimental-Cutoff20250AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...4.7705130.9031920.831106NaNpd-l1NoneminibinderMSLEELLKEIVEALKSGDFKKAAKAIKEAAKIIFSENIEVASAKIL...False31fc-01KSTCTAASYNGFY59G85TYT5SZFunctionCall.from_id('fc-01KSTCTAASYNGFY59G85T...SUCCESS7.856585
79FalseESMFold2-Experimental-Cutoff20250AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...3.7112600.9143340.849982NaNpd-l1NoneminibinderQNSNNNNNNNNEEDEEIDIKILKILIKLLIIIILLKKSPSSSSKKK...False41fc-01KSTCTAEAAQN1QKMYCDD92HKNFunctionCall.from_id('fc-01KSTCTAEAAQN1QKMYCDD...SUCCESS9.874187
98FalseESMFold2-Experimental-Cutoff20250AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...3.9598030.9340150.855943NaNpd-l1NoneminibinderSLILNILNIRINEINNLITNASKNELILYLKNLNIILKILLILLQN...False51fc-01KSTCTAHDBY4M10ESJP0HVS5ZFunctionCall.from_id('fc-01KSTCTAHDBY4M10ESJP0...SUCCESS5.117010
117FalseESMFold2-Experimental-Cutoff20250AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...4.2762570.9363310.892703NaNpd-l1NoneminibinderLLELLKILVKNAKNFSSSELYIVIMLLEILSNEDPREALILVEEII...False61fc-01KSTCTAM1V79QCSF9612BD1RVFunctionCall.from_id('fc-01KSTCTAM1V79QCSF9612...SUCCESS4.560045
136FalseESMFold2-Experimental-Cutoff20250AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...3.6636910.8467630.783418NaNpd-l1NoneminibinderQQLQLLIIQLILLIIVKILLQIANILLQEAKLSDSDDSEKIIKTLK...False71fc-01KSTCTAP3G7YRJ1RSQ2NFSTB5FunctionCall.from_id('fc-01KSTCTAP3G7YRJ1RSQ2N...SUCCESS9.399378
155TrueESMFold2-Experimental-Cutoff20250AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...3.8230500.9289520.8390540.851969pd-l1Nonetrastuzumab_framework_vhvlEVQLVESGGGLVQPGGSLRLSCAASSDRSYSVSYIHWVRQAPGKGL...False01fc-01KSTCTAR7DKSNJXBHARD4FFZ2FunctionCall.from_id('fc-01KSTCTAR7DKSNJXBHARD...SUCCESS6.984682
174TrueESMFold2-Experimental-Cutoff20250AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...4.1171040.9088980.7814310.797327pd-l1Nonetrastuzumab_framework_vhvlEVQLVESGGGLVQPGGSLRLSCAASEPLSYRIYIHWVRQAPGKGLE...False11fc-01KSTCTATESMP6KMYPFTQ1B2PFFunctionCall.from_id('fc-01KSTCTATESMP6KMYPFTQ...SUCCESS8.632139
193TrueESMFold2-Experimental-Cutoff20250AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...4.4055540.9339340.8271320.846008pd-l1Nonetrastuzumab_framework_vhvlEVQLVESGGGLVQPGGSLRLSCAASKGMEADDYIHWVRQAPGKGLE...False21fc-01KSTCTAXFAQJ8BCKSZ67NRH75FunctionCall.from_id('fc-01KSTCTAXFAQJ8BCKSZ67...SUCCESS6.863217
212TrueESMFold2-Experimental-Cutoff20250AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...3.7723780.9225570.8251450.848989pd-l1Nonetrastuzumab_framework_vhvlEVQLVESGGGLVQPGGSLRLSCAASGFAISDDYIHWVRQAPGKGLE...False31fc-01KSTCTAZTMDPNCHDTGCQJV0NNFunctionCall.from_id('fc-01KSTCTAZTMDPNCHDTGCQ...SUCCESS7.069998
231TrueESMFold2-Experimental-Cutoff20250AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...4.4308480.8985120.7595740.781431pd-l1Nonetrastuzumab_framework_vhvlEVQLVESGGGLVQPGGSLRLSCAASGDDDNLGYIHWVRQAPGKGLE...False41fc-01KSTCTB1QKCF0JA1FZQGV4RF1FunctionCall.from_id('fc-01KSTCTB1QKCF0JA1FZQG...SUCCESS8.622018
250TrueESMFold2-Experimental-Cutoff20250AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...4.1636610.9264100.8211710.834086pd-l1Nonetrastuzumab_framework_vhvlEVQLVESGGGLVQPGGSLRLSCAASEAFDTRVVLYIHWVRQAPGKG...False51fc-01KSTCTB4SZT3HVA7S6JR57Q60FunctionCall.from_id('fc-01KSTCTB4SZT3HVA7S6JR...SUCCESS6.982352
269TrueESMFold2-Experimental-Cutoff20250AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...4.0762990.7140020.6324070.648303pd-l1Nonetrastuzumab_framework_vhvlEVQLVESGGGLVQPGGSLRLSCAASSSLDDDDFAYIHWVRQAPGKG...False61fc-01KSTCTB6N74561E0PVPNHH68HFunctionCall.from_id('fc-01KSTCTB6N74561E0PVPN...SUCCESS6.856681
288TrueESMFold2-Experimental-Cutoff20250AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...4.1314370.9176700.8022940.817197pd-l1Nonetrastuzumab_framework_vhvlEVQLVESGGGLVQPGGSLRLSCAASPDLNFVLNYIHWVRQAPGKGL...False71fc-01KSTCTB9YC1GY5QAQVSZK67GAFunctionCall.from_id('fc-01KSTCTB9YC1GY5QAQVSZ...SUCCESS4.925178
\n", - "
" - ], + "application/vnd.jupyter.widget-view+json": { + "model_id": "3d25e164d3234b5b84686f99a91bf393", + "version_major": 2, + "version_minor": 0 + }, "text/plain": [ - " is_antibody critic_name batch_idx \\\n", - "3 False ESMFold2-Experimental-Cutoff2025 0 \n", - "22 False ESMFold2-Experimental-Cutoff2025 0 \n", - "41 False ESMFold2-Experimental-Cutoff2025 0 \n", - "60 False ESMFold2-Experimental-Cutoff2025 0 \n", - "79 False ESMFold2-Experimental-Cutoff2025 0 \n", - "98 False ESMFold2-Experimental-Cutoff2025 0 \n", - "117 False ESMFold2-Experimental-Cutoff2025 0 \n", - "136 False ESMFold2-Experimental-Cutoff2025 0 \n", - "155 True ESMFold2-Experimental-Cutoff2025 0 \n", - "174 True ESMFold2-Experimental-Cutoff2025 0 \n", - "193 True ESMFold2-Experimental-Cutoff2025 0 \n", - "212 True ESMFold2-Experimental-Cutoff2025 0 \n", - "231 True ESMFold2-Experimental-Cutoff2025 0 \n", - "250 True ESMFold2-Experimental-Cutoff2025 0 \n", - "269 True ESMFold2-Experimental-Cutoff2025 0 \n", - "288 True ESMFold2-Experimental-Cutoff2025 0 \n", - "\n", - " designed_sequence final_loss iptm \\\n", - "3 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 2.887006 0.949471 \n", - "22 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 4.945910 0.392557 \n", - "41 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 4.628414 0.920426 \n", - "60 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 4.770513 0.903192 \n", - "79 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 3.711260 0.914334 \n", - "98 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 3.959803 0.934015 \n", - "117 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 4.276257 0.936331 \n", - "136 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 3.663691 0.846763 \n", - "155 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 3.823050 0.928952 \n", - "174 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 4.117104 0.908898 \n", - "193 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 
4.405554 0.933934 \n", - "212 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 3.772378 0.922557 \n", - "231 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 4.430848 0.898512 \n", - "250 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 4.163661 0.926410 \n", - "269 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 4.076299 0.714002 \n", - "288 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 4.131437 0.917670 \n", - "\n", - " distogram_iptm_proxy cdr_distogram_iptm_proxy target_name \\\n", - "3 0.902141 NaN pd-l1 \n", - "22 0.610550 NaN pd-l1 \n", - "41 0.857930 NaN pd-l1 \n", - "60 0.831106 NaN pd-l1 \n", - "79 0.849982 NaN pd-l1 \n", - "98 0.855943 NaN pd-l1 \n", - "117 0.892703 NaN pd-l1 \n", - "136 0.783418 NaN pd-l1 \n", - "155 0.839054 0.851969 pd-l1 \n", - "174 0.781431 0.797327 pd-l1 \n", - "193 0.827132 0.846008 pd-l1 \n", - "212 0.825145 0.848989 pd-l1 \n", - "231 0.759574 0.781431 pd-l1 \n", - "250 0.821171 0.834086 pd-l1 \n", - "269 0.632407 0.648303 pd-l1 \n", - "288 0.802294 0.817197 pd-l1 \n", - "\n", - " target_sequence binder_name \\\n", - "3 None minibinder \n", - "22 None minibinder \n", - "41 None minibinder \n", - "60 None minibinder \n", - "79 None minibinder \n", - "98 None minibinder \n", - "117 None minibinder \n", - "136 None minibinder \n", - "155 None trastuzumab_framework_vhvl \n", - "174 None trastuzumab_framework_vhvl \n", - "193 None trastuzumab_framework_vhvl \n", - "212 None trastuzumab_framework_vhvl \n", - "231 None trastuzumab_framework_vhvl \n", - "250 None trastuzumab_framework_vhvl \n", - "269 None trastuzumab_framework_vhvl \n", - "288 None trastuzumab_framework_vhvl \n", - "\n", - " binder_sequence use_scaling_critics \\\n", - "3 QSSDDEIDKEVNKVAAEIALAVAELTRAAADGDDKEVDKQLKKALK... False \n", - "22 KWEIWRLLWKIGNNLWNNNNNNNNWNAIWTIWWWLIWWLIWWLLIN... False \n", - "41 SIIRILIIIVIKAIKKVSKIAKILKKALKELAKSGASKEIVEILIE... False \n", - "60 MSLEELLKEIVEALKSGDFKKAAKAIKEAAKIIFSENIEVASAKIL... 
False \n", - "79 QNSNNNNNNNNEEDEEIDIKILKILIKLLIIIILLKKSPSSSSKKK... False \n", - "98 SLILNILNIRINEINNLITNASKNELILYLKNLNIILKILLILLQN... False \n", - "117 LLELLKILVKNAKNFSSSELYIVIMLLEILSNEDPREALILVEEII... False \n", - "136 QQLQLLIIQLILLIIVKILLQIANILLQEAKLSDSDDSEKIIKTLK... False \n", - "155 EVQLVESGGGLVQPGGSLRLSCAASSDRSYSVSYIHWVRQAPGKGL... False \n", - "174 EVQLVESGGGLVQPGGSLRLSCAASEPLSYRIYIHWVRQAPGKGLE... False \n", - "193 EVQLVESGGGLVQPGGSLRLSCAASKGMEADDYIHWVRQAPGKGLE... False \n", - "212 EVQLVESGGGLVQPGGSLRLSCAASGFAISDDYIHWVRQAPGKGLE... False \n", - "231 EVQLVESGGGLVQPGGSLRLSCAASGDDDNLGYIHWVRQAPGKGLE... False \n", - "250 EVQLVESGGGLVQPGGSLRLSCAASEAFDTRVVLYIHWVRQAPGKG... False \n", - "269 EVQLVESGGGLVQPGGSLRLSCAASSSLDDDDFAYIHWVRQAPGKG... False \n", - "288 EVQLVESGGGLVQPGGSLRLSCAASPDLNFVLNYIHWVRQAPGKGL... False \n", - "\n", - " seed batch_size call_id \\\n", - "3 0 1 fc-01KSTCTA27QYFGWNB67BPKZ72Z \n", - "22 1 1 fc-01KSTCTA52ZQZXC2RZ4F12ZJNC \n", - "41 2 1 fc-01KSTCTA7ZBCYQZCJ1DHGT94BZ \n", - "60 3 1 fc-01KSTCTAASYNGFY59G85TYT5SZ \n", - "79 4 1 fc-01KSTCTAEAAQN1QKMYCDD92HKN \n", - "98 5 1 fc-01KSTCTAHDBY4M10ESJP0HVS5Z \n", - "117 6 1 fc-01KSTCTAM1V79QCSF9612BD1RV \n", - "136 7 1 fc-01KSTCTAP3G7YRJ1RSQ2NFSTB5 \n", - "155 0 1 fc-01KSTCTAR7DKSNJXBHARD4FFZ2 \n", - "174 1 1 fc-01KSTCTATESMP6KMYPFTQ1B2PF \n", - "193 2 1 fc-01KSTCTAXFAQJ8BCKSZ67NRH75 \n", - "212 3 1 fc-01KSTCTAZTMDPNCHDTGCQJV0NN \n", - "231 4 1 fc-01KSTCTB1QKCF0JA1FZQGV4RF1 \n", - "250 5 1 fc-01KSTCTB4SZT3HVA7S6JR57Q60 \n", - "269 6 1 fc-01KSTCTB6N74561E0PVPNHH68H \n", - "288 7 1 fc-01KSTCTB9YC1GY5QAQVSZK67GA \n", - "\n", - " future status \\\n", - "3 FunctionCall.from_id('fc-01KSTCTA27QYFGWNB67BP... SUCCESS \n", - "22 FunctionCall.from_id('fc-01KSTCTA52ZQZXC2RZ4F1... SUCCESS \n", - "41 FunctionCall.from_id('fc-01KSTCTA7ZBCYQZCJ1DHG... SUCCESS \n", - "60 FunctionCall.from_id('fc-01KSTCTAASYNGFY59G85T... SUCCESS \n", - "79 FunctionCall.from_id('fc-01KSTCTAEAAQN1QKMYCDD... 
SUCCESS \n", - "98 FunctionCall.from_id('fc-01KSTCTAHDBY4M10ESJP0... SUCCESS \n", - "117 FunctionCall.from_id('fc-01KSTCTAM1V79QCSF9612... SUCCESS \n", - "136 FunctionCall.from_id('fc-01KSTCTAP3G7YRJ1RSQ2N... SUCCESS \n", - "155 FunctionCall.from_id('fc-01KSTCTAR7DKSNJXBHARD... SUCCESS \n", - "174 FunctionCall.from_id('fc-01KSTCTATESMP6KMYPFTQ... SUCCESS \n", - "193 FunctionCall.from_id('fc-01KSTCTAXFAQJ8BCKSZ67... SUCCESS \n", - "212 FunctionCall.from_id('fc-01KSTCTAZTMDPNCHDTGCQ... SUCCESS \n", - "231 FunctionCall.from_id('fc-01KSTCTB1QKCF0JA1FZQG... SUCCESS \n", - "250 FunctionCall.from_id('fc-01KSTCTB4SZT3HVA7S6JR... SUCCESS \n", - "269 FunctionCall.from_id('fc-01KSTCTB6N74561E0PVPN... SUCCESS \n", - "288 FunctionCall.from_id('fc-01KSTCTB9YC1GY5QAQVSZ... SUCCESS \n", - "\n", - " isoelectric_point \n", - "3 9.521739 \n", - "22 10.605259 \n", - "41 10.170871 \n", - "60 7.856585 \n", - "79 9.874187 \n", - "98 5.117010 \n", - "117 4.560045 \n", - "136 9.399378 \n", - "155 6.984682 \n", - "174 8.632139 \n", - "193 6.863217 \n", - "212 7.069998 \n", - "231 8.622018 \n", - "250 6.982352 \n", - "269 6.856681 \n", - "288 4.925178 " + " 0%| | 0/256 [00:00\n", " \n", " designed_sequence\n", - " iptm\n", - " iptm_proxy\n", + " iptm_score\n", + " iptm_proxy_score\n", " selection_score\n", " \n", " \n", @@ -1477,128 +592,215 @@ " \n", " \n", " \n", - " pd-l1\n", - " minibinder\n", - " 0\n", + " pd-l1\n", + " minibinder\n", + " 70\n", " AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...\n", - " 0.930100\n", - " 0.856649\n", - " 0.893375\n", + " 0.967982\n", + " 0.944397\n", + " 0.956190\n", " \n", " \n", - " 1\n", + " 13\n", " AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...\n", - " 0.937393\n", - " 0.837694\n", - " 0.887544\n", + " 0.964470\n", + " 0.926216\n", + " 0.945343\n", " \n", " \n", - " trastuzumab_framework_vhvl\n", - " 4\n", + " 78\n", " AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...\n", - " 0.925267\n", - " 0.810059\n", - " 0.867663\n", + " 0.961920\n", + " 
0.918186\n", + " 0.940053\n", " \n", " \n", - " 3\n", + " 75\n", " AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...\n", - " 0.919929\n", - " 0.788386\n", - " 0.854157\n", + " 0.957516\n", + " 0.916083\n", + " 0.936799\n", " \n", " \n", - " 6\n", + " 45\n", " AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...\n", - " 0.917205\n", - " 0.780960\n", - " 0.849083\n", + " 0.964665\n", + " 0.908466\n", + " 0.936566\n", + " \n", + " \n", + " ...\n", + " ...\n", + " ...\n", + " ...\n", + " ...\n", + " ...\n", " \n", " \n", - " 5\n", + " trastuzumab_framework_vhvl\n", + " 2\n", " AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...\n", - " 0.906416\n", - " 0.753404\n", - " 0.829910\n", + " 0.921884\n", + " 0.725663\n", + " 0.823773\n", " \n", " \n", - " 1\n", + " 107\n", " AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...\n", - " 0.904410\n", - " 0.719259\n", - " 0.811835\n", + " 0.920620\n", + " 0.726391\n", + " 0.823506\n", " \n", " \n", - " 0\n", + " 46\n", " AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...\n", - " 0.877860\n", - " 0.723024\n", - " 0.800442\n", + " 0.879064\n", + " 0.767323\n", + " 0.823194\n", " \n", " \n", - " 2\n", + " 68\n", " AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...\n", - " 0.868486\n", - " 0.710997\n", - " 0.789742\n", + " 0.921196\n", + " 0.724139\n", + " 0.822668\n", " \n", " \n", - " 7\n", + " 50\n", " AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN...\n", - " 0.763347\n", - " 0.644642\n", - " 0.703995\n", + " 0.894414\n", + " 0.747255\n", + " 0.820834\n", " \n", " \n", "\n", + "

168 rows × 4 columns

\n", "
" ], "text/plain": [ - " designed_sequence \\\n", - "target_name binder_name \n", - "pd-l1 minibinder 0 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... \n", - " 1 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... \n", - " trastuzumab_framework_vhvl 4 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... \n", - " 3 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... \n", - " 6 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... \n", - " 5 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... \n", - " 1 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... \n", - " 0 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... \n", - " 2 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... \n", - " 7 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... \n", + " designed_sequence \\\n", + "target_name binder_name \n", + "pd-l1 minibinder 70 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... \n", + " 13 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... \n", + " 78 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... \n", + " 75 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... \n", + " 45 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... \n", + "... ... \n", + " trastuzumab_framework_vhvl 2 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... \n", + " 107 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... \n", + " 46 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... \n", + " 68 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... \n", + " 50 AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKN... 
\n", "\n", - " iptm iptm_proxy \\\n", - "target_name binder_name \n", - "pd-l1 minibinder 0 0.930100 0.856649 \n", - " 1 0.937393 0.837694 \n", - " trastuzumab_framework_vhvl 4 0.925267 0.810059 \n", - " 3 0.919929 0.788386 \n", - " 6 0.917205 0.780960 \n", - " 5 0.906416 0.753404 \n", - " 1 0.904410 0.719259 \n", - " 0 0.877860 0.723024 \n", - " 2 0.868486 0.710997 \n", - " 7 0.763347 0.644642 \n", + " iptm_score iptm_proxy_score \\\n", + "target_name binder_name \n", + "pd-l1 minibinder 70 0.967982 0.944397 \n", + " 13 0.964470 0.926216 \n", + " 78 0.961920 0.918186 \n", + " 75 0.957516 0.916083 \n", + " 45 0.964665 0.908466 \n", + "... ... ... \n", + " trastuzumab_framework_vhvl 2 0.921884 0.725663 \n", + " 107 0.920620 0.726391 \n", + " 46 0.879064 0.767323 \n", + " 68 0.921196 0.724139 \n", + " 50 0.894414 0.747255 \n", "\n", - " selection_score \n", - "target_name binder_name \n", - "pd-l1 minibinder 0 0.893375 \n", - " 1 0.887544 \n", - " trastuzumab_framework_vhvl 4 0.867663 \n", - " 3 0.854157 \n", - " 6 0.849083 \n", - " 5 0.829910 \n", - " 1 0.811835 \n", - " 0 0.800442 \n", - " 2 0.789742 \n", - " 7 0.703995 " + " selection_score \n", + "target_name binder_name \n", + "pd-l1 minibinder 70 0.956190 \n", + " 13 0.945343 \n", + " 78 0.940053 \n", + " 75 0.936799 \n", + " 45 0.936566 \n", + "... ... 
\n", + " trastuzumab_framework_vhvl 2 0.823773 \n", + " 107 0.823506 \n", + " 46 0.823194 \n", + " 68 0.822668 \n", + " 50 0.820834 \n", + "\n", + "[168 rows x 4 columns]" ] }, - "execution_count": 15, + "execution_count": 49, "metadata": {}, "output_type": "execute_result" } ], + "source": [ + "# ---- Select ----\n", + "\n", + "# Join all result_df's, with other fields in df broadcasted as metadata.\n", + "df_result = pd.concat(\n", + " [\n", + " row.result_df.assign(**row.drop([\"result\", \"result_df\"]).to_dict()) # pyright: ignore\n", + " for _, row in df_success.iterrows()\n", + " ],\n", + " ignore_index=True,\n", + " axis=0,\n", + ")\n", + "\n", + "# Filter minibinder designs with isoelectric point >= 6.\n", + "df_result[\"binder_sequence\"] = df_result.designed_sequence.str.split(r\"\\|\").str[1]\n", + "df_result[\"isoelectric_point\"] = [\n", + " ProteinAnalysis(seq).isoelectric_point()\n", + " for seq in tqdm(df_result.binder_sequence.values)\n", + "]\n", + "# Isoelectric point filter\n", + "df_filter = df_result[df_result.is_antibody | df_result.isoelectric_point.lt(6)]\n", + "\n", + "\n", + "# Select the top 84 designs from each (target, binder) combination\n", + "SCALING_CHECKPOINT_SUBSTRING = \"ESMFold2-Experimental-Fast-base\"\n", + "\n", + "\n", + "def select(df: pd.DataFrame) -> pd.DataFrame:\n", + " df = df.copy()\n", + " is_scaling = df.critic_name.str.contains(\n", + " SCALING_CHECKPOINT_SUBSTRING, regex=False, na=False\n", + " )\n", + " iptm_proxy = df.distogram_iptm_proxy.where(\n", + " ~df.is_antibody, df.cdr_distogram_iptm_proxy\n", + " )\n", + "\n", + " df[\"iptm_score\"] = df.iptm.where(~is_scaling)\n", + " df[\"iptm_proxy_score\"] = iptm_proxy.where(is_scaling)\n", + " scores = df.groupby(\"designed_sequence\", as_index=False).agg(\n", + " iptm_score=(\"iptm_score\", \"mean\"), iptm_proxy_score=(\"iptm_proxy_score\", \"mean\")\n", + " )\n", + " scores[\"selection_score\"] = 0.5 * scores.iptm_score.fillna(\n", + " 0\n", + " ) + 0.5 * 
scores.iptm_proxy_score.fillna(0)\n", + "\n", + " return scores.nlargest(min(len(scores), 84), \"selection_score\")\n", + "\n", + "\n", + "df_select = df_filter.groupby([\"target_name\", \"binder_name\"]).apply(\n", + " select, include_groups=False\n", + ")\n", + "df_select.to_parquet(save_dir / \"selection.parquet\", index=False)\n", + "df_select" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b54a71de", + "metadata": {}, + "outputs": [], + "source": [ + "df_result[df_result.critic_name.eq(\"ESMFold2-Experimental-Cutoff2025\")].drop(\n", + " columns=[\"complex\", \"logits\"]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "851365c0", + "metadata": {}, + "outputs": [], "source": [ "df_select" ] @@ -1660,9 +862,9 @@ ], "metadata": { "kernelspec": { - "display_name": "modal-test", + "display_name": "default", "language": "python", - "name": "modal-test" + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -1674,7 +876,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.0" + "version": "3.12.13" } }, "nbformat": 4, diff --git a/cookbook/tutorials/esmfold2_esmc_binder_design.py b/cookbook/tutorials/esmfold2_esmc_binder_design.py index b92f144..bbde18e 100644 --- a/cookbook/tutorials/esmfold2_esmc_binder_design.py +++ b/cookbook/tutorials/esmfold2_esmc_binder_design.py @@ -370,7 +370,7 @@ def compute_structure_losses( def _binding_confidence_entropy( dgram: torch.Tensor, bin_distance: torch.Tensor, cutoff: float ) -> torch.Tensor: - """Pair entropy within cutoff; matches rd3 contact_score scoring.""" + """Pair entropy within cutoff.""" probs = torch.softmax(dgram, dim=-1) cutoff_mask = bin_distance < cutoff @@ -1205,7 +1205,7 @@ def get_base_image(): class ESMFold2DesignModal(ESMFold2Design): """Modal entrypoint. Hero critics are HF experimental exports with confidence heads. 
Set ``use_scaling_critics=True`` to also load the - 15-checkpoint rd3 scaling-experiment ensemble (distogram binding confidence only). + 15-checkpoint scaling-experiment ensemble (distogram binding confidence only). """ use_scaling_critics: bool = modal.parameter(default=False) From 918c2307f47bc0a7e68b9c10adf05d08ae8ed805 Mon Sep 17 00:00:00 2001 From: Zeming Lin Date: Mon, 1 Jun 2026 21:57:41 +0000 Subject: [PATCH 5/6] update the protocol --- cookbook/tutorials/README.md | 5 ++- .../tutorials/esmfold2_esmc_binder_design.py | 40 ++----------------- 2 files changed, 6 insertions(+), 39 deletions(-) diff --git a/cookbook/tutorials/README.md b/cookbook/tutorials/README.md index abcec8a..6bf5b3f 100644 --- a/cookbook/tutorials/README.md +++ b/cookbook/tutorials/README.md @@ -23,9 +23,10 @@ ESMC is a protein language model that embeds sequences into rich numerical repre ESMFold2 predicts 3D protein structure from sequence, including DNA/RNA and small molecules. -| Notebook | Colab Notebook | Description | +| Notebook | Colab Notebook | Description | | :---- | :---- | :---- | -| Folding with ESMFold2 | `esmfold2.ipynb`
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/biohub/esm/blob/main/cookbook/tutorials/esmfold2.ipynb) | Fold proteins in combination with DNA, RNA and small-molecule ligands. +| Folding with ESMFold2 | `esmfold2.ipynb`
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/biohub/esm/blob/main/cookbook/tutorials/esmfold2.ipynb) | Fold proteins in combination with DNA, RNA and small-molecule ligands. | +| Binder design | `esmfold2_esmc_binder_design.ipynb`
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/biohub/esm/blob/main/cookbook/tutorials/esmfold2_esmc_binder_design.ipynb) | Design antibodies and minibinders with high hit rates. Implements the protocol featured in our paper, which produced binders exhibiting nanomolar affinity, target specificity, and functional activity in laboratory assays. | ## **ESM3** diff --git a/cookbook/tutorials/esmfold2_esmc_binder_design.py b/cookbook/tutorials/esmfold2_esmc_binder_design.py index bbde18e..8040f16 100644 --- a/cookbook/tutorials/esmfold2_esmc_binder_design.py +++ b/cookbook/tutorials/esmfold2_esmc_binder_design.py @@ -1145,49 +1145,15 @@ def design( # ---- Modal ---- -FLASH_ATTN_WHEEL = ( - "flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/" - "v2.8.3/flash_attn-2.8.3+cu12torch2.8cxx11abiFALSE-cp312-cp312-linux_x86_64.whl" -) -TRANSFORMER_ENGINE_TORCH_WHEEL = ( - "transformer-engine-torch @ https://github.com/evolutionaryscale/wheels/" - "releases/download/transformer-engine-torch/v2.13.0-pt2.8-cu128-cp312/" - "transformer_engine_torch-2.13.0-cp312-cp312-linux_x86_64.whl" -) - - def get_base_image(): return ( modal.Image.micromamba(python_version="3.12") .run_commands("apt update && apt install -y git build-essential") .micromamba_install( - "anarci>=2020.04.03", - "hmmer=3.4", - "cuda-version=12.8", - "cuda-libraries-dev=12.8", - channels=["conda-forge", "bioconda"], - ) - .pip_install( - "torch==2.8.0", - "triton==3.4.0", - index_url="https://download.pytorch.org/whl/cu128", - ) - .pip_install( - FLASH_ATTN_WHEEL, - "transformer-engine[core-cu12,pytorch]==2.13.0", - TRANSFORMER_ENGINE_TORCH_WHEEL, - "xformers==0.0.32.post1", - ) - .pip_install( - "abnumber", "esm@git+https://github.com/Biohub/esm.git@main", "modal" - ) - .env( - { - "HF_HOME": "/models", - "HF_XET_HIGH_PERFORMANCE": "1", - "XFORMERS_IGNORE_FLASH_VERSION_CHECK": "1", - } + "anarci>=2020.04.03", 
"hmmer=3.4", channels=["conda-forge", "bioconda"] ) + .pip_install("abnumber", "esm@git+https://github.com/Biohub/esm.git@main") + .env({"HF_HOME": "/models", "HF_XET_HIGH_PERFORMANCE": "1"}) ) From fa0720a7d45598ff602acd285388dd545aee57ef Mon Sep 17 00:00:00 2001 From: Zeming Lin Date: Mon, 1 Jun 2026 22:26:52 +0000 Subject: [PATCH 6/6] Add more fixes --- .pre-commit-config.yaml | 2 +- README.md | 2 +- cookbook/tutorials/README.md | 2 +- ...fold2_esmc_binder_design.ipynb => binder_design.ipynb} | 8 ++++---- .../{esmfold2_esmc_binder_design.py => binder_design.py} | 0 5 files changed, 7 insertions(+), 7 deletions(-) rename cookbook/tutorials/{esmfold2_esmc_binder_design.ipynb => binder_design.ipynb} (99%) rename cookbook/tutorials/{esmfold2_esmc_binder_design.py => binder_design.py} (100%) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 46c4c15..a127f34 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -32,7 +32,7 @@ repos: types: [python] pass_filenames: true # For speed, we only check the files that are changed # Modal-app tutorial: deps (modal, abnumber) and dynamic decorators aren't resolvable in the lint env. - exclude: ^cookbook/tutorials/esmfold2_esmc_binder_design\.py$ + exclude: ^cookbook/tutorials/binder_design\.py$ - repo: https://github.com/gitleaks/gitleaks rev: v8.24.2 hooks: diff --git a/README.md b/README.md index 26d1ca5..dfbd0b8 100644 --- a/README.md +++ b/README.md @@ -25,7 +25,7 @@ We are releasing a world model for protein biology: a scientific engine for pred
-ESMFold2 is validated in the lab across five therapeutic targets. Inversion of ESMFold2 enables generation of de novo minibinders and antibody-derived scFvs with high hit rates, nanomolar affinities, target specificity, and functional activity. We've released the full protocol from target sequence to ranked binder design in this [notebook](https://github.com/Biohub/esm/blob/main/cookbook/tutorials/esmfold2_esmc_binder_design.ipynb). For additional details, please refer to the [preprint](https://biohub.ai/papers/esm_protein.pdf). +ESMFold2 is validated in the lab across five therapeutic targets. Inversion of ESMFold2 enables generation of de novo minibinders and antibody-derived scFvs with high hit rates, nanomolar affinities, target specificity, and functional activity. We've released the full protocol from target sequence to ranked binder design in this [notebook](https://github.com/Biohub/esm/blob/main/cookbook/tutorials/binder_design.ipynb). For additional details, please refer to the [preprint](https://biohub.ai/papers/esm_protein.pdf).
diff --git a/cookbook/tutorials/README.md b/cookbook/tutorials/README.md index 6bf5b3f..f2d51d7 100644 --- a/cookbook/tutorials/README.md +++ b/cookbook/tutorials/README.md @@ -26,7 +26,7 @@ ESMFold2 predicts 3D protein structure from sequence, including DNA/RNA and smal | Notebook | Colab Notebook | Description | | :---- | :---- | :---- | | Folding with ESMFold2 | `esmfold2.ipynb`
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/biohub/esm/blob/main/cookbook/tutorials/esmfold2.ipynb) | Fold proteins in combination with DNA, RNA and small-molecule ligands. | -| Binder design | `esmfold2_esmc_binder_design.ipynb`
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/biohub/esm/blob/main/cookbook/tutorials/esmfold2_esmc_binder_design.ipynb) | Design antibodies and minibinders with high hit rates. Implements the protocol featured in our paper, which produced binders exhibiting nanomolar affinity, target specificity, and functional activity in laboratory assays. | +| Binder design | `binder_design.ipynb`
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/biohub/esm/blob/main/cookbook/tutorials/binder_design.ipynb) | Design antibodies and minibinders with high hit rates. Implements the protocol featured in our paper, which produced binders exhibiting nanomolar affinity, target specificity, and functional activity in laboratory assays. | ## **ESM3** diff --git a/cookbook/tutorials/esmfold2_esmc_binder_design.ipynb b/cookbook/tutorials/binder_design.ipynb similarity index 99% rename from cookbook/tutorials/esmfold2_esmc_binder_design.ipynb rename to cookbook/tutorials/binder_design.ipynb index 2372298..a921958 100644 --- a/cookbook/tutorials/esmfold2_esmc_binder_design.ipynb +++ b/cookbook/tutorials/binder_design.ipynb @@ -52,9 +52,9 @@ "metadata": {}, "outputs": [], "source": [ - "# Deploy (or redeploy after changing modal_binder_design.py).\n", - "# This only needs to be run a single time, unless code in esmfold2_esmc_binder_design.py changes.\n", - "! modal deploy esmfold2_esmc_binder_design.py" + "# Deploy (or redeploy after changing binder_design.py).\n", + "# This only needs to be run a single time, unless code in binder_design.py changes.\n", + "! modal deploy binder_design.py" ] }, { @@ -146,7 +146,7 @@ "# ---- Option 2: Provide your own sequences. ----\n", "# Our pd-l1 sequence crop.\n", "pdl1_sequence = \"AFTVTVPKDLYVVEYGSNMTIECKFPVEKQLDLAALIVYWEMEDKNIIQFVHGEEDLKVQHSSYRQRARLLKDQLSLGNAALQITDVKLQDAGVYRCMISYGGADYKRITVKVNA\"\n", - "# A sample of 'trastuzumab_framework_vhvl' template. From esmfold2_esmc_binder_design.py::BINDER_PROMPT_FACTORIES.\n", + "# A sample of 'trastuzumab_framework_vhvl' template. 
From binder_design.py::BINDER_PROMPT_FACTORIES.\n", "trastuzumab_framework_vhvl = \"EVQLVESGGGLVQPGGSLRLSCAAS#######YIHWVRQAPGKGLEWVARI#####TRYADSVKGRFTISADTSKNTAYLQMNSLRAEDTAVYYCSR###########WGQGTLVTVSSGGGSGGGSGGGSGGGSDIQMTQSPSSLSASVGDRVTITC###########WYQQKPGKAPKLLIY#######GVPSRFSGSRSGTDFTLTISSLQPEDFATYYC#########FGQGTKVEIK\"\n", "future2 = app.design.spawn(\n", " target_sequence=pdl1_sequence,\n", diff --git a/cookbook/tutorials/esmfold2_esmc_binder_design.py b/cookbook/tutorials/binder_design.py similarity index 100% rename from cookbook/tutorials/esmfold2_esmc_binder_design.py rename to cookbook/tutorials/binder_design.py