<a href="https://colab.research.google.com/github/Kuhlman-Lab/ThermoMPNN-D/blob/main/ThermoMPNN-D.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# <center>**This is the Colab implementation of ThermoMPNN-D**</center>


<center><img src='https://drive.google.com/uc?export=view&id=1qXMpih7MLeZfRDZF9-iYSlL6SXEY3FdS'></center>

---

ThermoMPNN-D is an updated version of ThermoMPNN for predicting the stability effects of double point mutations. It was trained on an augmented version of the Megascale double mutant dataset and is state-of-the-art at predicting stabilizing double mutations.

For convenience, we also provide a single-mutant ThermoMPNN model and an "additive" model that scores mutation pairs naively as the sum of two single-mutant predictions, ignoring epistatic interactions. For details, see the [ThermoMPNN-D paper](https://doi.org/10.1101/2024.08.20.608844).

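To make the distinction concrete, an "additive" score for a double mutant can be computed as the sum of the two corresponding single-mutant ddG values, so any coupling between the two sites is ignored. The sketch below is illustrative only: `additive_double_ddg` is a hypothetical helper operating on a toy `[L, 20]` ddG matrix, not the ThermoMPNN-D implementation.

```python
import itertools

import numpy as np

ALPHABET = "ACDEFGHIKLMNPQRSTVWY"  # 20 canonical amino acids, column order of the toy matrix


def additive_double_ddg(single_ddg: np.ndarray, seq: str):
    """Hypothetical helper: score every double mutant as ddG(i, a) + ddG(j, b).

    single_ddg: [L, 20] single-mutant ddG values, columns ordered as ALPHABET.
    seq: wildtype sequence of length L.
    Returns a list of (mutation_string, ddG) tuples with 1-based positions.
    """
    results = []
    for i, j in itertools.combinations(range(len(seq)), 2):  # each position pair once, i < j
        for a_idx, a in enumerate(ALPHABET):
            if a == seq[i]:
                continue  # skip the wildtype residue at position i
            for b_idx, b in enumerate(ALPHABET):
                if b == seq[j]:
                    continue  # skip the wildtype residue at position j
                # no epistatic coupling term: the pair score is just the sum of the singles
                ddg = single_ddg[i, a_idx] + single_ddg[j, b_idx]
                results.append((f"{seq[i]}{i + 1}{a}:{seq[j]}{j + 1}{b}", float(ddg)))
    return results
```

For a protein of length L this enumerates 361 · L(L-1)/2 candidate pairs (19 × 19 substitutions per position pair), so the output grows quickly with protein size.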
### **COLAB TIPS:**
- The cells of this notebook are meant to be executed *in order*, so users should start from the top and work their way down.
- Executable cells can be run by clicking the PLAY button (>) that appears when you hover over each cell, or by using **Shift+Enter**.
- Make sure GPU is enabled by checking `Runtime` -> `Change Runtime Type`
  - Make sure that `Runtime type` is set to `Python 3`
  - Make sure that `Hardware accelerator` is set to `GPU`
  - Click `Save` to confirm

- If the notebook freezes up or otherwise crashes, go to `Runtime` -> `Restart Runtime` and try again.


In [1]:
%%capture

#@title # 1. Set up **ThermoMPNN environment**
#@markdown Install ThermoMPNN-D and its dependencies in this session. This may take a minute or two.

#@markdown You only need to do this once *per session*. To re-run ThermoMPNN on a new protein, you may start from Step 3.

#@markdown ---

# cleaning out any remaining data
%cd /content
!rm -rf /content/ThermoMPNN-D
!rm -rf /content/sample_data
!rm -f /content/*.pdb
!rm -f /content/*.csv

# import ThermoMPNN-D github repo
import os
if not os.path.exists("/content/ThermoMPNN-D"):
  !git clone https://github.com/Kuhlman-Lab/ThermoMPNN-D.git
  %cd /content/ThermoMPNN-D

# downloading various dependencies - add more if needed later
! pip install omegaconf wandb pytorch-lightning biopython nglview


In [9]:
%%capture
#@title # **2. Set up ThermoMPNN imports and functions**

from google.colab import files
import os
import sys
from urllib import request
from urllib.error import HTTPError
from google.colab._message import MessageError

import re
import torch
import numpy as np
from dataclasses import dataclass
from Bio.PDB import PDBParser
from omegaconf import OmegaConf
import pandas as pd
from copy import deepcopy
from tqdm import tqdm
import time
from scipy.spatial.distance import cdist
from torch.utils.data import Dataset, DataLoader

tMPNN_path = '/content/ThermoMPNN-D'
if tMPNN_path not in sys.path:
  sys.path.append(tMPNN_path)

from thermompnn.datasets.dataset_utils import Mutation
from thermompnn.datasets.v2_datasets import tied_featurize_mut
from thermompnn.model.v2_model import batched_index_select, _dist

from thermompnn.train_thermompnn import parse_cfg
from thermompnn.trainer.v2_trainer import TransferModelPLv2, TransferModelPLv2Siamese

def download_pdb(pdbcode, datadir, downloadurl="https://files.rcsb.org/download/"):
    """
    Downloads a PDB file from the Internet and saves it in a data directory.
    :param pdbcode: The standard PDB ID e.g. '3ICB' or '3icb'
    :param datadir: The directory where the downloaded file will be saved
    :param downloadurl: The base PDB download URL, cf.
        `https://www.rcsb.org/pages/download/http#structures` for details
    :return: the full path to the downloaded PDB file or None if something went wrong
    """

    pdbfn = pdbcode + ".pdb"
    url = downloadurl + pdbfn
    outfnm = os.path.join(datadir, pdbfn)
    try:
        request.urlretrieve(url, outfnm)
        return outfnm
    except Exception as err:
        print(str(err), file=sys.stderr)
        return None

def alt_parse_PDB_biounits(x, atoms=['N', 'CA', 'C'], chain=None):
    '''
  input:  x = PDB filename
          atoms = atoms to extract (optional)
  output: (length, atoms, coords=(x,y,z)), sequence
  '''

    alpha_1 = list("ARNDCQEGHILKMFPSTWYV-")
    states = len(alpha_1)
    alpha_3 = ['ALA', 'ARG', 'ASN', 'ASP', 'CYS', 'GLN', 'GLU', 'GLY', 'HIS', 'ILE',
               'LEU', 'LYS', 'MET', 'PHE', 'PRO', 'SER', 'THR', 'TRP', 'TYR', 'VAL', 'GAP']

    aa_1_N = {a: n for n, a in enumerate(alpha_1)}
    aa_3_N = {a: n for n, a in enumerate(alpha_3)}
    aa_N_1 = {n: a for n, a in enumerate(alpha_1)}
    aa_1_3 = {a: b for a, b in zip(alpha_1, alpha_3)}
    aa_3_1 = {b: a for a, b in zip(alpha_1, alpha_3)}

    def AA_to_N(x):
        # ["ARND"] -> [[0,1,2,3]]
        x = np.array(x);
        if x.ndim == 0: x = x[None]
        return [[aa_1_N.get(a, states - 1) for a in y] for y in x]

    def N_to_AA(x):
        # [[0,1,2,3]] -> ["ARND"]
        x = np.array(x);
        if x.ndim == 1: x = x[None]
        return ["".join([aa_N_1.get(a, "-") for a in y]) for y in x]

    xyz, seq, min_resn, max_resn = {}, {}, 1e6, -1e6
    resn_list = []
    for line in open(x, "rb"):
        line = line.decode("utf-8", "ignore").rstrip()

        # handling MSE and SEC residues
        if line[:6] == "HETATM" and line[17:17 + 3] == "MSE":
            line = line.replace("HETATM", "ATOM  ")
            line = line.replace("MSE", "MET")
        elif line[17:17 + 3] == "MSE":
            line = line.replace("MSE", "MET")
        elif line[17:17 + 3] == "SEC":
            line = line.replace("SEC", "CYS")

        if line[:4] == "ATOM":
            ch = line[21:22]
            if ch == chain or chain is None:
                atom = line[12:12 + 4].strip()
                resi = line[17:17 + 3]
                resn = line[22:22 + 5].strip()

                # check for gaps and add them if needed
                if (resn not in resn_list) and len(resn_list) > 0:
                  _, num, ins_code = re.split(r'(\d+)', resn)
                  _, num_prior, ins_code_prior = re.split(r'(\d+)', resn_list[-1])
                  gap = int(num) - int(num_prior) - 1
                  for g in range(gap + 1):
                    resn_list.append(str(int(num_prior) + g))

                # RAW resn is defined HERE
                resn_list.append(resn) # NEED to keep ins code here

                x, y, z = [float(line[i:(i + 8)]) for i in [30, 38, 46]]
                if resn[-1].isalpha():
                    resa, resn = resn[-1], int(resn[:-1]) - 1
                else:
                    resa, resn = "", int(resn) - 1
                if resn < min_resn:
                    min_resn = resn
                if resn > max_resn:
                    max_resn = resn
                if resn not in xyz:
                    xyz[resn] = {}
                if resa not in xyz[resn]:
                    xyz[resn][resa] = {}
                if resn not in seq:
                    seq[resn] = {}
                if resa not in seq[resn]:
                    seq[resn][resa] = resi

                if atom not in xyz[resn][resa]:
                    xyz[resn][resa][atom] = np.array([x, y, z])

    # convert to numpy arrays, fill in missing values
    seq_, xyz_ = [], []
    try:
        for resn in range(min_resn, max_resn + 1):
            if resn in seq:
                for k in sorted(seq[resn]): seq_.append(aa_3_N.get(seq[resn][k], 20))
            else:
                seq_.append(20)

            if resn in xyz:
                for k in sorted(xyz[resn]):
                    for atom in atoms:
                        if atom in xyz[resn][k]:
                            xyz_.append(xyz[resn][k][atom])
                        else:
                            xyz_.append(np.full(3, np.nan))
            else:
                for atom in atoms: xyz_.append(np.full(3, np.nan))
        return np.array(xyz_).reshape(-1, len(atoms), 3), N_to_AA(np.array(seq_)), list(dict.fromkeys(resn_list))
    except TypeError:
        return 'no_chain', 'no_chain', 'no_chain'

def alt_parse_PDB(path_to_pdb, input_chain_list=None, ca_only=False, side_chains=False, mut_chain=None):
    c = 0
    pdb_dict_list = []
    init_alphabet = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T',
                     'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n',
                     'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
    extra_alphabet = [str(item) for item in list(np.arange(300))]
    chain_alphabet = init_alphabet + extra_alphabet

    if input_chain_list:
        chain_alphabet = input_chain_list

    biounit_names = [path_to_pdb]
    for biounit in biounit_names:
        my_dict = {}
        s = 0
        concat_seq = ''
        concat_N = []
        concat_CA = []
        concat_C = []
        concat_O = []
        concat_mask = []
        coords_dict = {}
        for letter in chain_alphabet:
            if ca_only:
                sidechain_atoms = ['CA']
            elif side_chains:
                sidechain_atoms = ["N", "CA", "C", "O", "CB",
                                   "CG", "CG1", "OG1", "OG2", "CG2", "OG", "SG",
                                   "CD", "SD", "CD1", "ND1", "CD2", "OD1", "OD2", "ND2",
                                   "CE", "CE1", "NE1", "OE1", "NE2", "OE2", "NE", "CE2", "CE3",
                                   "NZ", "CZ", "CZ2", "CZ3", "CH2", "OH", "NH1", "NH2"]
            else:
                sidechain_atoms = ['N', 'CA', 'C', 'O']
            xyz, seq, resn_list = alt_parse_PDB_biounits(biounit, atoms=sidechain_atoms, chain=letter)
            if resn_list != 'no_chain':
              my_dict['resn_list_' + letter] = resn_list
            if type(xyz) != str:
                concat_seq += seq[0]
                my_dict['seq_chain_' + letter] = seq[0]
                coords_dict_chain = {}
                if ca_only:
                    coords_dict_chain['CA_chain_' + letter] = xyz.tolist()
                elif side_chains:
                    coords_dict_chain['SG_chain_' + letter] = xyz[:, 11].tolist()
                else:
                    coords_dict_chain['N_chain_' + letter] = xyz[:, 0, :].tolist()
                    coords_dict_chain['CA_chain_' + letter] = xyz[:, 1, :].tolist()
                    coords_dict_chain['C_chain_' + letter] = xyz[:, 2, :].tolist()
                    coords_dict_chain['O_chain_' + letter] = xyz[:, 3, :].tolist()
                my_dict['coords_chain_' + letter] = coords_dict_chain
                s += 1

        fi = biounit.rfind("/")
        my_dict['name'] = biounit[(fi + 1):-4]
        my_dict['num_of_chains'] = s
        my_dict['seq'] = concat_seq
        if s <= len(chain_alphabet):
            pdb_dict_list.append(my_dict)
            c += 1
    return pdb_dict_list

def get_chains_from_pdb(pdb_file):
  parser = PDBParser(QUIET=True)
  structure = parser.get_structure('', pdb_file)
  return [c.id for c in structure.get_chains()]

def get_chains(pdb_file, chain_list):
  # collect list of chains in PDB to match with input
  pdb_chains = get_chains_from_pdb(pdb_file)
  if len(chain_list) < 1: # fill in all chains if left blank
    chain_list = pdb_chains

  for ch in chain_list:
    assert ch in pdb_chains, f"Chain {ch} not found in PDB file with chains {pdb_chains}"

  return chain_list

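The parser above reads coordinates by slicing fixed-width ATOM records. A minimal standalone sketch of that slicing is below, using a hypothetical `parse_atom_record` helper (not part of ThermoMPNN-D); it shows which columns hold which fields and how an insertion code such as `52A` is separated from the residue number, mirroring the `resn[-1].isalpha()` check in `alt_parse_PDB_biounits`. Unlike the notebook parser, it keeps the residue number 1-based.

```python
def parse_atom_record(line: str):
    """Split one fixed-width PDB ATOM/HETATM record into the fields used above.

    Slices follow the PDB column layout (0-based, end-exclusive).
    """
    atom = line[12:16].strip()   # atom name, e.g. "CA"
    resname = line[17:20]        # three-letter residue name, e.g. "MET"
    chain = line[21:22]          # one-character chain identifier
    resn = line[22:27].strip()   # residue number plus optional insertion code, e.g. "52A"
    xyz = tuple(float(line[i:i + 8]) for i in (30, 38, 46))  # x, y, z (8 columns each)
    # separate a trailing insertion-code letter from the residue number
    if resn[-1].isalpha():
        resnum, icode = int(resn[:-1]), resn[-1]
    else:
        resnum, icode = int(resn), ""
    return atom, resname, chain, resnum, icode, xyz
```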

In [3]:
# %%capture
#@title # **3. Upload or Fetch Input Data**

from google.colab import files
import os
import sys
from urllib import request
from urllib.error import HTTPError
from google.colab._message import MessageError

#@markdown ## You may either specify a PDB code to fetch or upload a custom PDB file.<br><br>

# -------- Collecting Settings for ThermoMPNN run --------- #

!rm /content/*.pdb &> /dev/null

#@markdown PDB code (example: 1PGA):
PDB = "1igy" #@param {type: "string"}

#@markdown Upload Custom PDB?
Custom = False #@param {type: "boolean"}
#@markdown NOTE: If enabled, a `Choose files` button will appear at the bottom of this cell once this cell is run.

#@markdown Chain(s) of Interest (example: A,B,C):
Chains = "" #@param {type:"string"}
#@markdown If left empty, all chains will be used.

# try to upload the PDB file to Colab servers
if Custom:
  try:
    uploaded_pdb = files.upload()
    for fn in uploaded_pdb.keys():
      PDB = os.path.basename(fn)
      if not PDB.endswith('.pdb'):
        raise ValueError(f"Uploaded file {PDB} does not end in '.pdb'. Please check and rename file as needed.")
      os.rename(fn, os.path.join("/content/", PDB))
      pdb_file = os.path.join("/content/", PDB)
  except (MessageError, FileNotFoundError):
    print('\n', '*' * 100, '\n')
    print('Sorry, your input file failed to upload. Please try the backup upload procedure (next cell).')

else:
  # download_pdb() catches download errors itself and returns None on failure
  pdb_file = download_pdb(PDB, "/content/")
  if pdb_file is None:
    raise ValueError(f"Failed to fetch PDB {PDB} from RCSB. Please double-check the PDB code and try again.")


In [None]:
#@title # **3. Backup Data Upload (ONLY needed if initial upload failed)**

#@markdown ## Colab automatic file uploads are not very reliable. If your file failed to upload automatically, you can do so manually by following these steps.<br><br>

#@markdown #### 1. Click the "Files" icon on the left toolbar. This will open the Colab server file folder.

#@markdown #### 2. The only thing in this folder should be the "ThermoMPNN-D" directory. If any other files are in here, delete them.

#@markdown #### 3. Click the "Upload to session storage" button under the "Files" header. Choose your file for upload.

#@markdown #### 4. Run this cell. ThermoMPNN will find your file in session storage and use it.


#@markdown Chain(s) of Interest (example: A,B,C):
Chains = "" #@param {type:"string"}
#@markdown If left empty, all chains will be used.

PDB = ""

# use a separate name so the google.colab `files` module imported earlier is not shadowed
pdb_files = sorted(f for f in os.listdir('/content/') if f.endswith('.pdb'))

if len(pdb_files) < 1:
  raise ValueError('No PDB file found. Please upload your file before running this cell. Make sure it has a .pdb suffix.')
elif len(pdb_files) > 1:
  raise ValueError('Too many PDB files found. Please clear out any other PDBs before running this cell.')
else:
  pdb_file = os.path.join("/content/", pdb_files[0])
  PDB = pdb_files[0].removesuffix('.pdb')
  print('Successfully uploaded PDB file %s' % (pdb_files[0]))

Successfully uploaded PDB file 1bvc.pdb

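The backup cell above accepts the upload only if exactly one `.pdb` file is present in `/content/`. The same check can be written with `pathlib`; `find_single_pdb` below is a hypothetical helper for illustration, not part of ThermoMPNN-D.

```python
from pathlib import Path


def find_single_pdb(directory: str) -> Path:
    """Return the only *.pdb file in `directory`, mirroring the backup-upload checks."""
    pdb_paths = sorted(Path(directory).glob("*.pdb"))
    if not pdb_paths:
        raise ValueError("No PDB file found. Upload a file with a .pdb suffix first.")
    if len(pdb_paths) > 1:
        names = ", ".join(p.name for p in pdb_paths)
        raise ValueError(f"Too many PDB files found ({names}). Remove the extra PDBs first.")
    return pdb_paths[0]
```

In Colab this would be called as `pdb_path = find_single_pdb("/content")`, with `pdb_path.stem` giving the PDB code (for example `1bvc`).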

In [5]:
#@markdown # **4. Run Model**

#@markdown Stability model to use:
Model = "Additive" #@param ["Epistatic", "Additive", "Single"]

#@markdown ##### Model descriptions:
#@markdown * Single: Single mutation SSM sweep. Very fast and accurate.
#@markdown * Additive: Naive double mutation SSM sweep. Ignores non-additive coupling. Very fast but less accurate than Epistatic model for picking stabilizing mutations.
#@markdown * Epistatic: Full double mutation SSM sweep. Slower than Additive model, but more accurate for picking stabilizing mutations.

#@markdown ---------------

#@markdown Allow mutations to cysteine? (Not recommended)
Include = False #@param {type: "boolean"}
#@markdown Due to assay artifacts surrounding disulfide formation, model predictions for cysteine mutations may be overly favorable.

#@markdown ---------------

#@markdown Explicitly penalize disulfide breakage? (Recommended)
Penalize = True #@param {type: "boolean"}

#@markdown ThermoMPNN can usually detect disulfide breakage and penalize accordingly, but you may wish to explicitly forbid disulfide breakage to be safe. This option applies a flat penalty to make sure that breaking disulfides is always disfavored.

#@markdown --------------

#@markdown Batch size for model inference. Only the Epistatic model uses this setting. (Recommended: 2048 for the Epistatic model)
BatchSize = 256 #@param {type: "integer"}
#@markdown If you hit a memory error, try lowering the BatchSize by factors of 2 to reduce memory usage.

#@markdown --------------

#@markdown Threshold for detecting stabilizing mutations. (Recommended: -1.0)
Threshold = -1.0 #@param {type: "number"}
#@markdown Only mutations with predicted ddG below this value will be kept for analysis. Higher thresholds will result in retaining more mutations.

#@markdown --------------

#@markdown Pairwise distance constraint for double mutants. (Recommended: 5.0)
Distance = 5.0 #@param {type: "number"}
#@markdown Only mutation pairs whose CA-CA distance is within this cutoff (in Å) will be kept for analysis. Higher cutoffs will result in slower runtime and more retained mutations.

# parse comma-separated chain IDs (e.g. "A,B" or "A, B"); an empty string means all chains
chain_list = [ch.strip() for ch in Chains.split(',') if ch.strip()]

# validate chain inputs
chain_list = get_chains(pdb_file, chain_list)

# remove cys from alphabet if needed
alphabet = 'ACDEFGHIKLMNPQRSTVWY' if Include else 'ADEFGHIKLMNPQRSTVWY'

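The Threshold, Distance, and cysteine settings above all act as filters on the predicted mutation table. A minimal pandas sketch of applying them to a toy table of double mutants is below; `filter_double_mutants`, the toy coordinates, and the 1-based `A4I:H5W`-style mutation strings without chain IDs are assumptions for illustration, and the notebook itself applies these steps separately inside its own functions.

```python
import numpy as np
import pandas as pd
from scipy.spatial.distance import cdist


def filter_double_mutants(df, ca_coords, threshold=-1.0, distance=5.0,
                          alphabet="ADEFGHIKLMNPQRSTVWY"):
    """Hypothetical helper: keep double mutants with ddG below `threshold`,
    CA-CA distance <= `distance`, and both mutant residues in `alphabet`
    (cysteine excluded by default).

    df: columns 'ddG (kcal/mol)' and 'Mutation' (e.g. 'A4I:H5W', 1-based positions).
    ca_coords: [L, 3] CA coordinates indexed by 0-based sequence position.
    """
    muts = df["Mutation"].str.split(":", expand=True)  # columns 0 and 1: the two mutations
    pos1 = muts[0].str[1:-1].astype(int).to_numpy() - 1
    pos2 = muts[1].str[1:-1].astype(int).to_numpy() - 1
    dmat = cdist(ca_coords, ca_coords)  # all-by-all CA-CA distances in Angstrom
    out = df.assign(**{"CA-CA Distance": dmat[pos1, pos2].round(2)})
    keep = (
        (out["ddG (kcal/mol)"] < threshold)
        & (out["CA-CA Distance"] <= distance)
        & muts[0].str[-1].isin(list(alphabet))
        & muts[1].str[-1].isin(list(alphabet))
    )
    return out.loc[keep].reset_index(drop=True)
```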

In [12]:
#@title # **Run SSM inference**

from v2_ssm import get_config, format_output_single, format_output_double, get_ssm_mutations_double, SSMDataset, run_double, format_output_epistatic
import argparse


def run_single(cfg, model, pdb):
    """Runs single-mutant SSM sweep with ThermoMPNN v2"""

    stime = time.time()

    pdb[0]['mutation'] = Mutation([0], ['A'], ['A'], [0.], '') # placeholder mutation to keep featurization from throwing error

    # featurize input
    device = 'cuda'
    batch = tied_featurize_mut(pdb)
    X, S, mask, lengths, chain_M, chain_encoding_all, residue_idx, mut_positions, mut_wildtype_AAs, mut_mutant_AAs, mut_ddGs, atom_mask = batch

    X = X.to(device)
    S = S.to(device)
    mask = mask.to(device)
    lengths = torch.Tensor(lengths).to(device)
    chain_M = chain_M.to(device)
    chain_encoding_all = chain_encoding_all.to(device)
    residue_idx = residue_idx.to(device)
    mut_ddGs = mut_ddGs.to(device)

    # do single pass through thermompnn
    X = torch.nan_to_num(X, nan=0.0)
    all_mpnn_hid, mpnn_embed, _, mpnn_edges = model.prot_mpnn(X, S, mask, chain_M, residue_idx, chain_encoding_all)

    all_mpnn_hid = torch.cat(all_mpnn_hid[:cfg.model.num_final_layers], -1)
    all_mpnn_hid = torch.squeeze(torch.cat([all_mpnn_hid, mpnn_embed], -1), 0) # [L, E]

    all_mpnn_hid = model.light_attention(torch.unsqueeze(all_mpnn_hid, -1))

    ddg = model.ddg_out(all_mpnn_hid) # [L, 21]

    # subtract wildtype ddgs to normalize
    S = torch.squeeze(S) # [L, ]

    wt_ddg = batched_index_select(ddg, dim=-1, index=S) # [L, 1]
    ddg = ddg - wt_ddg.expand(-1, 21) # [L, 21]
    etime = time.time()
    elapsed = etime - stime
    length = ddg.shape[0]
    print(f'ThermoMPNN single mutant predictions generated in {round(elapsed, 2)} seconds.')
    return ddg, S


def run_epistatic(config, model, pdb, BatchSize, Threshold):
    """Run epistatic model on double mutations """

    stime = time.time()

    pdb[0]['mutation'] = Mutation([0], ['A'], ['A'], [0.], '') # placeholder mutation to keep featurization from throwing error

    # featurize input
    device = 'cuda'
    batch = tied_featurize_mut(pdb)
    X, S, mask, lengths, chain_M, chain_encoding_all, residue_idx, mut_positions, mut_wildtype_AAs, mut_mutant_AAs, mut_ddGs, atom_mask = batch

    X = X.to(device)
    S = S.to(device)
    mask = mask.to(device)
    lengths = torch.Tensor(lengths).to(device)
    chain_M = chain_M.to(device)
    chain_encoding_all = chain_encoding_all.to(device)
    residue_idx = residue_idx.to(device)
    mut_ddGs = mut_ddGs.to(device)

    # do single pass through thermompnn
    X = torch.nan_to_num(X, nan=0.0)
    all_mpnn_hid, mpnn_embed, _, mpnn_edges = model.prot_mpnn(X, S, mask, chain_M, residue_idx, chain_encoding_all)

    # grab double mutation inputs
    MUT_POS, MUT_WT_AA, MUT_MUT_AA = get_ssm_mutations_double(pdb[0])
    dataset = SSMDataset(MUT_POS, MUT_WT_AA, MUT_MUT_AA)
    loader = DataLoader(dataset, shuffle=False, batch_size=BatchSize, num_workers=8)

    args = {'batch_size': BatchSize, 'threshold': Threshold}
    args = argparse.Namespace(**args)
    preds = run_double(all_mpnn_hid, mpnn_embed, config, loader, args, model, X, mask, mpnn_edges)
    ddg, mutations = format_output_epistatic(preds, S, MUT_POS, MUT_WT_AA, MUT_MUT_AA, args.threshold)

    etime = time.time()
    elapsed = etime - stime
    print(f'ThermoMPNN double mutant epistatic model predictions generated in {round(elapsed, 2)} seconds.')
    return ddg, mutations


def renumber_pdb(df, pdb, Model):
    """Renumber output mutations to match PDB numbering for interpretation"""
    # parse PDB
    if (Model == 'Additive') or (Model == 'Epistatic'):
        # grab positions
        df[['mut1', 'mut2']] = df['Mutation'].str.split(':', n=2, expand=True)
        df['pos1'] = df['mut1'].str[1:-1].astype(int) - 1
        df['pos2'] = df['mut2'].str[1:-1].astype(int) - 1

        df['pos1'] = idx_to_pdb_num(pdb, df['pos1'].values)
        df['pos2'] = idx_to_pdb_num(pdb, df['pos2'].values)

        df['wt1'], df['wt2'] = df['mut1'].str[0], df['mut2'].str[0]
        df['mt1'], df['mt2'] = df['mut1'].str[-1], df['mut2'].str[-1]

        df['Mutation'] = df['wt1'] + df['pos1'] + df['mt1'] + ':' + df['wt2'] + df['pos2'] + df['mt2']
        df = df[['ddG (kcal/mol)', 'Mutation', 'CA-CA Distance']].reset_index(drop=True)

    else:
        # grab position
        df['pos'] = df['Mutation'].str[1:-1].astype(int) - 1

        df['pos1'] = idx_to_pdb_num(pdb, df['pos'].values)

        df['wt1'] = df['Mutation'].str[0]
        df['mt1'] = df['Mutation'].str[-1]

        df['Mutation'] = df['wt1'] + df['pos1'] + df['mt1']
        df = df[['ddG (kcal/mol)', 'Mutation']].reset_index(drop=True)

    print(f'ThermoMPNN predictions renumbered.')
    return df


def distance_filter(df, pdb, Distance):
    """filter df based on pdb distances"""

    # grab positions
    df[['mut1', 'mut2']] = df['Mutation'].str.split(':', n=2, expand=True)
    df['pos1'] = df['mut1'].str[1:-1].astype(int) - 1
    df['pos2'] = df['mut2'].str[1:-1].astype(int) - 1

    # get distance matrix
    coords = [k for k in pdb.keys() if k.startswith('coords_chain_')]

    # compile all-by-all coords into big matrix
    coo_all = []
    for coord in coords:
      ch = coord.split('_')[-1]
      coo = np.stack(pdb[coord][f'CA_chain_{ch}']) # [L, 3]
      coo_all.append(coo)
    coo_all = np.concatenate(coo_all) # [L_total, 3]
    dmat = cdist(coo_all, coo_all)

    # filter df based on positions
    pos1, pos2 = df['pos1'].values, df['pos2'].values
    dist_list = []
    for p1, p2 in tqdm(zip(pos1, pos2)):
        dist_list.append(dmat[p1, p2])

    df['CA-CA Distance'] = dist_list
    df = df.loc[df['CA-CA Distance'] <= Distance].copy()  # copy to avoid writing into a slice
    df['CA-CA Distance'] = df['CA-CA Distance'].round(2)

    df = df[['ddG (kcal/mol)', 'Mutation', 'CA-CA Distance']].reset_index(drop=True)
    print(f'Distance matrix generated.')
    return df


def idx_to_pdb_num(pdb, poslist):
  # set up PDB resns and boundaries
  chains = [key[-1] for key in pdb.keys() if key.startswith('resn_list_')]
  resn_lists = [pdb[key] for key in pdb.keys() if key.startswith('resn_list')]
  converter = {}
  offset = 0
  for n, rlist in enumerate(resn_lists):
      chain = chains[n]
      for idx, resid in enumerate(rlist):
          converter[idx + offset] = chain + resid
      offset += len(rlist)  # next chain's indices start after this chain's residues

  return [converter[pos] for pos in poslist]


def disulfide_penalty(df, pdb_file, chain_list, Model):
  """Automatically detects disulfide breakage based on Cys-Cys distance."""

  pdb_dict = alt_parse_PDB(pdb_file, input_chain_list=chain_list, side_chains=True)

  # collect all SG coordinates from all chains
  coords_all = [k for k in pdb_dict[0].keys() if k.startswith('coords')]
  chains = [c[-1] for c in coords_all]
  sg_coords = [pdb_dict[0][c][f'SG_chain_{chain}'] for c, chain in zip(coords_all, chains)]
  sg_coords = np.concatenate(sg_coords, axis=0)

  # calculate pairwise distance and threshold to find disulfides
  dist = cdist(sg_coords, sg_coords)
  dist = np.nan_to_num(dist, nan=10000)  # residues without an SG atom get a large dummy distance
  hits = np.where((dist < 3) & (dist > 0)) # tuple of two [N] arrays of indices

  if Model == 'Single':
    df['wtAA'] = df['Mutation'].str[0]
    df['mutAA'] = df['Mutation'].str[-1]
    df['pos'] = df['Mutation'].str[1:-1].astype(int) - 1

    # match hit indices to actual resns for penalty
    bad_resns = []
    for h in hits[0]:
      bad_resns.append(h)

    print('Identified the following disulfide engaged residues:', bad_resns)

    # apply penalty
    penalty = 2  # in kcal/mol - higher is less stable
    mask = df['pos'].isin(bad_resns) & (df['wtAA'] != df['mutAA'])

    df.loc[mask, 'ddG (kcal/mol)'] = df.loc[mask, 'ddG (kcal/mol)'] + penalty
    return df[['Mutation', 'ddG (kcal/mol)']].reset_index(drop=True)

  else:
    df[['mut1', 'mut2']] = df['Mutation'].str.split(':', n=2, expand=True)
    df['wtAA1'] = df['mut1'].str[0]
    df['mutAA1'] = df['mut1'].str[-1]
    df['pos1'] = df['mut1'].str[1:-1].astype(int) - 1

    df['wtAA2'] = df['mut2'].str[0]
    df['mutAA2'] = df['mut2'].str[-1]
    df['pos2'] = df['mut2'].str[1:-1].astype(int) - 1

    bad_resns = []
    for h in hits[0]:
      bad_resns.append(h)

    print('Identified the following disulfide engaged residues:', bad_resns)

    # apply penalty
    penalty = 2  # in kcal/mol - higher is less stable
    mask = df['pos1'].isin(bad_resns) & (df['wtAA1'] != df['mutAA1'])
    mask2 = df['pos2'].isin(bad_resns) & (df['wtAA2'] != df['mutAA2'])
    mask = mask | mask2


    df.loc[mask, 'ddG (kcal/mol)'] = df.loc[mask, 'ddG (kcal/mol)'] + penalty
    return df[['Mutation', 'ddG (kcal/mol)', 'CA-CA Distance']].reset_index(drop=True)


# load config automatically

config = get_config(Model.lower())
config.platform.thermompnn_dir = '/content/ThermoMPNN-D'

if Model == 'Single' or Model == 'Additive':
  # load model
  model_path = '/content/ThermoMPNN-D/model_weights/ThermoMPNN-ens1.ckpt'
  model = TransferModelPLv2.load_from_checkpoint(checkpoint_path=model_path, cfg=config, device='gpu').model
  model.eval()
  model.cuda()

  # run inference routine
  pdb = alt_parse_PDB(pdb_file, chain_list)
  ddg, S = run_single(config, model, pdb)

  if Model == 'Single':
    ddg, mutations = format_output_single(ddg, S, Threshold)
  elif Model == 'Additive':
    ddg, mutations = format_output_double(ddg, S, Threshold)

else:
  # load model
  model_path = '/content/ThermoMPNN-D/model_weights/ThermoMPNN-D-ens1.ckpt'
  model = TransferModelPLv2Siamese.load_from_checkpoint(model_path, cfg=config, device='gpu').model
  model.eval()
  model.cuda()

  # run inference routine
  pdb = alt_parse_PDB(pdb_file, chain_list)
  ddg, mutations = run_epistatic(config, model, pdb, BatchSize, Threshold)

# compile output dataframe and sort/filter it
df = pd.DataFrame({
    'ddG (kcal/mol)': ddg,
    'Mutation': mutations
})

df.loc[:, 'ddG (kcal/mol)'] = df['ddG (kcal/mol)'].round(4).values

if df.shape[0] == 0:
  raise ValueError("No valid mutations passed your ddG Threshold. Please raise the Threshold value and try again.")

if Model != 'Single':
    df = distance_filter(df, pdb[0], Distance)
    if df.shape[0] == 0:
      raise ValueError("No valid mutations passed your Distance constraint. Please raise the Distance or Threshold value and try again.")

if Penalize:
  df = disulfide_penalty(df, pdb_file, chain_list, Model)

df = df.dropna(subset=['ddG (kcal/mol)'])
if Threshold <= 0.:
  df = df.sort_values(by=['ddG (kcal/mol)'])

if Model != 'Single': # sort to have same output order
    df[['mut1', 'mut2']] = df['Mutation'].str.split(':', n=2, expand=True)
    df['pos1'] = df['mut1'].str[1:-1].astype(int) + 1
    df['pos2'] = df['mut2'].str[1:-1].astype(int) + 1

    df = df.sort_values(by=['pos1', 'pos2'])
    df = df[['ddG (kcal/mol)', 'Mutation', 'CA-CA Distance']].reset_index(drop=True)

try:
  df = renumber_pdb(df, pdb[0], Model)
except (KeyError, IndexError):
  print('PDB renumbering failed (sorry!) You can still use the raw position data. Or, you can renumber your PDB, fill any weird gaps, and try again.')


Loading model %s /content/ThermoMPNN-D/vanilla_model_weights/v_48_020.pt
setting ProteinMPNN dropout: 0.0
MLP HIDDEN SIZES: [384, 64, 32, 21]



ThermoMPNN single mutant predictions generated in 0.03 seconds.


198667it [00:00, 571202.22it/s]


ThermoMPNN double mutant additive model predictions calculated in 4.72 seconds.


190909it [00:00, 1703928.19it/s]


Distance matrix generated.
Identified the following disulfide engaged residues: [21, 86, 132, 192, 212, 233, 307, 357, 423, 450, 452, 455, 457, 489, 555, 605, 671, 711, 776, 822, 882, 902, 923, 997, 1047, 1113, 1140, 1142, 1145, 1145, 1147, 1147, 1179, 1245, 1295, 1361]
ThermoMPNN predictions renumbered.
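
In `run_single`, the raw `[L, 21]` output of `model.ddg_out` is shifted so that each position's wildtype amino acid scores exactly 0 (`batched_index_select(ddg, dim=-1, index=S)` followed by `ddg - wt_ddg.expand(-1, 21)`), making predictions relative to the wildtype residue. A NumPy sketch of the same indexing is below; `normalize_to_wildtype` is a hypothetical name, and `np.take_along_axis` stands in for the repository's PyTorch helper.

```python
import numpy as np


def normalize_to_wildtype(scores: np.ndarray, wt_idx: np.ndarray) -> np.ndarray:
    """Shift each row so the wildtype amino acid scores exactly 0.

    scores: [L, A] raw per-position, per-amino-acid predictions.
    wt_idx: [L] integer index of the wildtype amino acid at each position.
    """
    wt = np.take_along_axis(scores, wt_idx[:, None], axis=1)  # [L, 1] wildtype score per position
    return scores - wt  # broadcast the subtraction across the amino-acid axis
```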

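Model outputs refer to residues by their 0-based index in the concatenated sequence of all selected chains, and `idx_to_pdb_num` converts those indices back to chain-plus-residue labels such as `A4` before `renumber_pdb` rebuilds the mutation strings. A simplified sketch of that bookkeeping follows; `build_index_to_pdb` is a hypothetical name, and it takes a plain `{chain: [residue IDs]}` dict rather than the parsed PDB dictionary.

```python
def build_index_to_pdb(resn_lists: dict) -> dict:
    """Map 0-based concatenated sequence indices to 'chain + residue ID' labels.

    resn_lists: ordered dict {chain_id: [residue IDs as strings, e.g. '1', '52A']}.
    """
    converter = {}
    offset = 0
    for chain, resids in resn_lists.items():
        for idx, resid in enumerate(resids):
            converter[offset + idx] = chain + resid
        offset += len(resids)  # the next chain starts right after this chain's residues
    return converter
```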

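`disulfide_penalty` flags cysteines whose SG atoms lie closer than 3 Å to another SG and adds a flat 2 kcal/mol penalty to any real mutation at those positions. A standalone sketch of the same two steps, with hypothetical helper names and toy coordinates, is below. Note that `np.nan_to_num` takes the replacement value through its `nan=` keyword; its second positional parameter is `copy`.

```python
import numpy as np
from scipy.spatial.distance import cdist


def find_disulfide_residues(sg_coords, cutoff=3.0):
    """Return sorted indices of residues whose SG atom is within `cutoff` Å of another SG.

    sg_coords: [L, 3] SG coordinates, with NaN rows for residues that have no SG atom.
    """
    dist = cdist(sg_coords, sg_coords)   # NaN wherever either residue lacks an SG
    dist = np.nan_to_num(dist, nan=1e4)  # replace NaN with a large dummy distance
    np.fill_diagonal(dist, np.inf)       # ignore each atom's zero distance to itself
    rows, _ = np.where(dist < cutoff)
    return np.unique(rows)


def apply_disulfide_penalty(ddg, pos, wt_aa, mut_aa, bonded, penalty=2.0):
    """Add a flat penalty (kcal/mol; higher = less stable) to real mutations at bonded positions."""
    ddg = np.asarray(ddg, dtype=float).copy()
    hit = np.isin(pos, bonded) & (np.asarray(wt_aa) != np.asarray(mut_aa))
    ddg[hit] += penalty
    return ddg
```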
In [13]:
#@title **Visualize data in an interactive table**
from google.colab import data_table

data_table.enable_dataframe_formatter()
data_table.DataTable(df, include_index=True, num_rows_per_page=10)

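|     | ddG (kcal/mol) | Mutation | CA-CA Distance (Å) |
|-----|----------------|----------|--------------------|
| 0   | -1.0446 | AA4I:HA5W | 3.85 |
| 1   | -1.5014 | AA4I:AA25V | 4.05 |
| 2   | -1.4223 | AA4V:AA25V | 4.05 |
| 3   | -1.1489 | AA4I:AA25I | 4.05 |
| 4   | -1.1082 | AA4M:AA25V | 4.05 |
| ... | ... | ... | ... |
| 301 | 0.9609 | CD456A:SD457T | 3.81 |
| 302 | 0.9766 | CD456F:SD457L | 3.81 |
| 303 | 0.9853 | CD456L:SD457I | 3.81 |
| 304 | -1.0315 | SD457V:ED469C | 4.22 |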

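In the renumbered output, each mutation is written as wildtype residue, chain ID, PDB residue number, and mutant residue, with double mutants joined by a colon; for example, `AA4I:HA5W` is Ala4 to Ile plus His5 to Trp, both on chain A. The sketch below splits such strings the same way the CSV export step does (`str[0]`, `str[1]`, `str[2:-1]`, `str[-1]`); `parse_mutation` is a hypothetical helper, and like the export step it assumes one-character chain IDs.

```python
def parse_mutation(code: str):
    """Split e.g. 'AA4I:HA5W' into (wildtype AA, chain, residue ID, mutant AA) tuples.

    Assumes one-character chain IDs; the residue ID may carry an insertion code.
    """
    parsed = []
    for mut in code.split(":"):
        parsed.append((mut[0], mut[1], mut[2:-1], mut[-1]))
    return parsed
```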

In [None]:
#@title # **Save Output as CSV**

# ---------- Collect output into DF and save as CSV ---------- #
from google.colab import files

#@markdown Specify prefix for file saving (e.g., MyProtein). Leave blank to use input PDB code.
PREFIX = "example" #@param {type:"string"}

#@markdown NOTE: If you wish to retrieve your files manually, you may do so in the **Files** tab in the leftmost toolbar.

#@markdown NOTE: Make sure you click "Allow" if your browser asks to permit downloads at this step.

#@markdown Verbose output? If enabled, the wildtype residue, mutant residue, position, and chain of each mutation are also saved as separate columns.
VERBOSE = True #@param {type: "boolean"}

df['ddG (kcal/mol)'] = df['ddG (kcal/mol)'].round(4)

if len(PREFIX) < 1:
  PREFIX = os.path.splitext(pdb_file)[0]  # e.g. /content/1igy.pdb -> /content/1igy
else:
  PREFIX = os.path.join('/content/', PREFIX)

full_fname = PREFIX + '.csv'

if VERBOSE and Model == 'Single':
  df['Wildtype AA'] = df['Mutation'].str[0]
  df['Mutant AA'] = df['Mutation'].str[-1]
  df['Position'] = df['Mutation'].str[2:-1]
  df['Chain'] = df['Mutation'].str[1]

elif VERBOSE:
  df[['Mutation 1', 'Mutation 2']] = df['Mutation'].str.split(':', n=2, expand=True)
  df['Wildtype AA 1'], df['Wildtype AA 2'] = df['Mutation 1'].str[0], df['Mutation 2'].str[0]
  df['Mutant AA 1'], df['Mutant AA 2'] = df['Mutation 1'].str[-1], df['Mutation 2'].str[-1]
  df['Position 1'], df['Position 2'] = df['Mutation 1'].str[2:-1], df['Mutation 2'].str[2:-1]
  df['Chain 1'], df['Chain 2'] = df['Mutation 1'].str[1], df['Mutation 2'].str[1]

df.to_csv(full_fname, index=True)
files.download(full_fname)



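Because lower (more negative) ddG values indicate more stabilizing mutations in this notebook's convention, a saved CSV can be re-ranked later without re-running the model. A minimal pandas sketch is below, assuming the file was written by the cell above with `index=True`; `top_stabilizing` is a hypothetical helper.

```python
import pandas as pd


def top_stabilizing(csv_path: str, n: int = 10) -> pd.DataFrame:
    """Load a saved ThermoMPNN-D CSV and return the n most stabilizing rows (most negative ddG)."""
    df = pd.read_csv(csv_path, index_col=0)  # first column is the saved DataFrame index
    return df.nsmallest(n, "ddG (kcal/mol)")
```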
# APPENDIX

## License

The source code for ThermoMPNN-D, including license information, is available on GitHub [here](https://github.com/Kuhlman-Lab/ThermoMPNN-D).

## Citation Information

If you use ThermoMPNN-D in your research, please cite the following paper(s):

### Epistatic or Additive model:
Dieckhaus, H., Kuhlman, B. *Protein stability models fail to capture epistatic interactions of double point mutations.* bioRxiv **2024**, doi: https://doi.org/10.1101/2024.08.20.608844.

### Single mutant model:
Dieckhaus, H., Brocidiacono, M., Randolph, N., Kuhlman, B. *Transfer learning to leverage larger datasets for improved prediction of protein stability changes.* Proc Natl Acad Sci **2024**, 121(6), e2314853121, doi: https://doi.org/10.1073/pnas.2314853121.

## Contact Information

Please contact Henry Dieckhaus at dieckhau@unc.edu to report any bugs or issues with this notebook. You may also submit issues on the ThermoMPNN-D GitHub page [here](https://github.com/Kuhlman-Lab/ThermoMPNN-D/issues).
