<a href="https://colab.research.google.com/github/Kuhlman-Lab/ThermoMPNN-D/blob/main/ThermoMPNN-D.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# <center>**This is a Colab implementation of ThermoMPNN-D**</center>

---

ThermoMPNN-D is an updated version of ThermoMPNN that predicts the stability effects of double point mutations.

It was trained on an augmented version of the Megascale double mutant dataset. It is state-of-the-art at predicting stabilizing double mutations.


### **COLAB TIPS:**
- The cells of this notebook are meant to be executed *in order*, so users should start from the top and work their way down.
- Executable cells can be run by clicking the PLAY button (>) that appears when you hover over each cell, or by using **Shift+Enter**.
- Make sure a GPU is enabled under `Runtime` -> `Change Runtime Type`:
  - Make sure that `Runtime type` is set to `Python 3`
  - Make sure that `Hardware accelerator` is set to `GPU`
  - Click `Save` to confirm

- If the notebook freezes up or otherwise crashes, go to `Runtime` -> `Restart Runtime` and try again.
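
One quick way to confirm that the GPU setting took effect is to query PyTorch from a code cell. The sketch below assumes PyTorch is available in the runtime (it falls back to `False` if it is not); `gpu_available` is an illustrative helper name, not part of ThermoMPNN-D.

```python
def gpu_available() -> bool:
    """Return True if PyTorch is importable and can see a CUDA device."""
    try:
        import torch
    except ImportError:
        # PyTorch is missing, so no GPU can be used from this session
        return False
    return torch.cuda.is_available()


print("GPU available:", gpu_available())
```

If this prints `False` on Colab, re-check the `Hardware accelerator` setting and restart the runtime.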


In [1]:
%%capture

#@title # 1. Set up **ThermoMPNN environment**
#@markdown Import ThermoMPNN and its dependencies to this session. This may take a minute or two.

#@markdown You only need to do this once *per session*. To re-run ThermoMPNN on a new protein, you may start on Step 3.

#@markdown ---

# cleaning out any remaining data
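# use %cd (not !cd): a `!cd` runs in a throwaway subshell and does not change the notebook's directory
%cd /content
!rm -rf /content/ThermoMPNN-D
!rm -rf /content/sample_data
# -f avoids an error when no matching files exist
!rm -f /content/*.pdb
!rm -f /content/*.csv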

# import ThermoMPNN-D github repo
import os
if not os.path.exists("/content/ThermoMPNN-D"):
  !git clone https://github.com/Kuhlman-Lab/ThermoMPNN-D.git
  %cd /content/ThermoMPNN-D

# downloading various dependencies - add more if needed later
! pip install omegaconf wandb pytorch-lightning biopython nglview


In [7]:
#@title # **2. Set up ThermoMPNN imports and functions**

# Setting up imports and functions for use later in the protocol

from google.colab import files
import os
import sys
from urllib import request
from urllib.error import HTTPError
from google.colab._message import MessageError

import torch
import numpy as np
from dataclasses import dataclass
from Bio.PDB import PDBParser
from omegaconf import OmegaConf
import pandas as pd
from copy import deepcopy
from tqdm import tqdm

from torch.utils.data import Dataset, DataLoader

tMPNN_path = '/content/ThermoMPNN-D'
if tMPNN_path not in sys.path:
  sys.path.append(tMPNN_path)

from thermompnn.datasets.dataset_utils import Mutation
from thermompnn.datasets.v2_datasets import tied_featurize_mut

from thermompnn.train_thermompnn import parse_cfg
from thermompnn.trainer.v2_trainer import TransferModelPLv2

def download_pdb(pdbcode, datadir, downloadurl="https://files.rcsb.org/download/"):
    """
    Downloads a PDB file from the Internet and saves it in a data directory.
    :param pdbcode: The standard PDB ID e.g. '3ICB' or '3icb'
    :param datadir: The directory where the downloaded file will be saved
    :param downloadurl: The base PDB download URL, cf.
        `https://www.rcsb.org/pages/download/http#structures` for details
    :return: the full path to the downloaded PDB file or None if something went wrong
    """

    pdbfn = pdbcode + ".pdb"
    url = downloadurl + pdbfn
    outfnm = os.path.join(datadir, pdbfn)
    try:
        request.urlretrieve(url, outfnm)
        return outfnm
    except Exception as err:
        print(str(err), file=sys.stderr)
        return None

def alt_parse_PDB_biounits(x, atoms=['N', 'CA', 'C'], chain=None):
    '''
    input:  x = PDB filename
            atoms = atoms to extract (optional)
            chain = chain ID to extract (optional; None keeps all chains)
    output: (length, atoms, coords=(x,y,z)), sequence, list of raw residue numbers
    '''

    alpha_1 = list("ARNDCQEGHILKMFPSTWYV-")
    states = len(alpha_1)
    alpha_3 = ['ALA', 'ARG', 'ASN', 'ASP', 'CYS', 'GLN', 'GLU', 'GLY', 'HIS', 'ILE',
               'LEU', 'LYS', 'MET', 'PHE', 'PRO', 'SER', 'THR', 'TRP', 'TYR', 'VAL', 'GAP']

    aa_1_N = {a: n for n, a in enumerate(alpha_1)}
    aa_3_N = {a: n for n, a in enumerate(alpha_3)}
    aa_N_1 = {n: a for n, a in enumerate(alpha_1)}
    aa_1_3 = {a: b for a, b in zip(alpha_1, alpha_3)}
    aa_3_1 = {b: a for a, b in zip(alpha_1, alpha_3)}

    def AA_to_N(x):
        # ["ARND"] -> [[0,1,2,3]]
        x = np.array(x)
        if x.ndim == 0: x = x[None]
        return [[aa_1_N.get(a, states - 1) for a in y] for y in x]

    def N_to_AA(x):
        # [[0,1,2,3]] -> ["ARND"]
        x = np.array(x)
        if x.ndim == 1: x = x[None]
        return ["".join([aa_N_1.get(a, "-") for a in y]) for y in x]

    xyz, seq, min_resn, max_resn = {}, {}, 1e6, -1e6
    resn_list = []
    for line in open(x, "rb"):
        line = line.decode("utf-8", "ignore").rstrip()

        # handling MSE and SEC residues
        if line[:6] == "HETATM" and line[17:17 + 3] == "MSE":
            line = line.replace("HETATM", "ATOM  ")
            line = line.replace("MSE", "MET")
        elif line[17:17 + 3] == "MSE":
            line = line.replace("MSE", "MET")
        elif line[17:17 + 3] == "SEC":
            line = line.replace("SEC", "CYS")

        if line[:4] == "ATOM":
            ch = line[21:22]
            if ch == chain or chain is None:
                atom = line[12:12 + 4].strip()
                resi = line[17:17 + 3]
                resn = line[22:22 + 5].strip()

                # RAW resn is defined HERE
                if resn not in resn_list:
                    resn_list.append(resn)  # NEED to keep ins code here
                x, y, z = [float(line[i:(i + 8)]) for i in [30, 38, 46]]
                if resn[-1].isalpha():
                    resa, resn = resn[-1], int(resn[:-1]) - 1
                else:
                    resa, resn = "", int(resn) - 1
                if resn < min_resn:
                    min_resn = resn
                if resn > max_resn:
                    max_resn = resn
                if resn not in xyz:
                    xyz[resn] = {}
                if resa not in xyz[resn]:
                    xyz[resn][resa] = {}
                if resn not in seq:
                    seq[resn] = {}
                if resa not in seq[resn]:
                    seq[resn][resa] = resi

                if atom not in xyz[resn][resa]:
                    xyz[resn][resa][atom] = np.array([x, y, z])
                # print(resn_list)

    # convert to numpy arrays, fill in missing values
    seq_, xyz_ = [], []
    # print(seq)
    try:
        for resn in range(min_resn, max_resn + 1):
            if resn in seq:
              # CLEANED resn is defined HERE
                # resn_list.append(str(resn + 1))
                # print(resn)
                for k in sorted(seq[resn]): seq_.append(aa_3_N.get(seq[resn][k], 20))
            else:
                seq_.append(20)

            if resn in xyz:
                for k in sorted(xyz[resn]):
                    for atom in atoms:
                        if atom in xyz[resn][k]:
                            xyz_.append(xyz[resn][k][atom])
                        else:
                            xyz_.append(np.full(3, np.nan))
            else:
                for atom in atoms: xyz_.append(np.full(3, np.nan))
        return np.array(xyz_).reshape(-1, len(atoms), 3), N_to_AA(np.array(seq_)), list(dict.fromkeys(resn_list))
    except TypeError:
        return 'no_chain', 'no_chain', 'no_chain'

def alt_parse_PDB(path_to_pdb, input_chain_list=None, ca_only=False, side_chains=False, mut_chain=None):
    c = 0
    pdb_dict_list = []
    init_alphabet = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T',
                     'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n',
                     'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
    extra_alphabet = [str(item) for item in list(np.arange(300))]
    chain_alphabet = init_alphabet + extra_alphabet

    if input_chain_list:
        chain_alphabet = input_chain_list

    biounit_names = [path_to_pdb]
    for biounit in biounit_names:
        my_dict = {}
        s = 0
        concat_seq = ''
        concat_N = []
        concat_CA = []
        concat_C = []
        concat_O = []
        concat_mask = []
        coords_dict = {}
        for letter in chain_alphabet:
            if ca_only:
                sidechain_atoms = ['CA']
            elif side_chains:
                sidechain_atoms = ["N", "CA", "C", "O", "CB",
                                   "CG", "CG1", "OG1", "OG2", "CG2", "OG", "SG",
                                   "CD", "SD", "CD1", "ND1", "CD2", "OD1", "OD2", "ND2",
                                   "CE", "CE1", "NE1", "OE1", "NE2", "OE2", "NE", "CE2", "CE3",
                                   "NZ", "CZ", "CZ2", "CZ3", "CH2", "OH", "NH1", "NH2"]
            else:
                sidechain_atoms = ['N', 'CA', 'C', 'O']
            xyz, seq, resn_list = alt_parse_PDB_biounits(biounit, atoms=sidechain_atoms, chain=letter)
            if mut_chain is not None:
              if resn_list != 'no_chain' and letter in mut_chain:
                  my_dict['resn_list'] = list(resn_list)
            if type(xyz) != str:
                concat_seq += seq[0]
                my_dict['seq_chain_' + letter] = seq[0]
                coords_dict_chain = {}
                if ca_only:
                    coords_dict_chain['CA_chain_' + letter] = xyz.tolist()
                elif side_chains:
                    coords_dict_chain['SG_chain_' + letter] = xyz[:, 11].tolist()
                else:
                    coords_dict_chain['N_chain_' + letter] = xyz[:, 0, :].tolist()
                    coords_dict_chain['CA_chain_' + letter] = xyz[:, 1, :].tolist()
                    coords_dict_chain['C_chain_' + letter] = xyz[:, 2, :].tolist()
                    coords_dict_chain['O_chain_' + letter] = xyz[:, 3, :].tolist()
                my_dict['coords_chain_' + letter] = coords_dict_chain
                s += 1

        fi = biounit.rfind("/")
        if mut_chain is None:
          my_dict['resn_list'] = list(resn_list)
        my_dict['name'] = biounit[(fi + 1):-4]
        my_dict['num_of_chains'] = s
        my_dict['seq'] = concat_seq
        # my_dict['resn_list'] = list(resn_list)
        if s <= len(chain_alphabet):
            pdb_dict_list.append(my_dict)
            c += 1
    return pdb_dict_list

def get_chains_from_pdb(pdb_file):
  parser = PDBParser(QUIET=True)
  structure = parser.get_structure('', pdb_file)
  return [c.id for c in structure.get_chains()]

def get_chains(pdb_file, chain_list):
  # collect list of chains in PDB to match with input
  pdb_chains = get_chains_from_pdb(pdb_file)
  if len(chain_list) < 1: # fill in all chains if left blank
    chain_list = pdb_chains

  for ch in chain_list:
    assert ch in pdb_chains, f"Chain {ch} not found in PDB file with chains {pdb_chains}"

  return chain_list
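
The fixed-column slicing used in `alt_parse_PDB_biounits` can be hard to follow in context. The sketch below applies the same column slices to a single synthetic `ATOM` record and shows how an insertion code (the `A` in `52A`) is separated from the zero-based residue index. `parse_atom_line` and the example record are illustrative and are not part of the ThermoMPNN-D code base.

```python
def parse_atom_line(line):
    """Split one PDB ATOM record using the same column slices as the parser above."""
    atom = line[12:16].strip()   # atom name, e.g. "CA"
    resname = line[17:20]        # three-letter residue name
    chain = line[21:22]          # chain ID
    resn = line[22:27].strip()   # residue number plus optional insertion code
    x, y, z = (float(line[i:i + 8]) for i in (30, 38, 46))

    # A trailing letter is an insertion code; the numeric part becomes a 0-based index
    if resn[-1].isalpha():
        icode, index = resn[-1], int(resn[:-1]) - 1
    else:
        icode, index = "", int(resn) - 1

    return {"atom": atom, "resname": resname, "chain": chain,
            "icode": icode, "index": index, "xyz": (x, y, z)}


# Build a synthetic record with format widths so every field lands in its PDB column
example = "ATOM  {:>5d} {:<4s}{:1s}{:>3s} {:1s}{:>4d}{:1s}   {:8.3f}{:8.3f}{:8.3f}".format(
    2, " CA", "", "MET", "A", 52, "A", 26.981, 24.694, 4.209)

record = parse_atom_line(example)
print(record)
```

Keeping the insertion code separate from the numeric index is what lets residues such as `52` and `52A` coexist without overwriting each other.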

In [5]:
# %%capture
#@title # **3. Upload or Fetch Input Data**

from google.colab import files
import os
import sys
from urllib import request
from urllib.error import HTTPError
from google.colab._message import MessageError

#@markdown ## You may either specify a PDB code to fetch or upload a custom PDB file.<br><br>

# -------- Collecting Settings for ThermoMPNN run --------- #

!rm /content/*.pdb &> /dev/null

#@markdown PDB code (example: 1PGA):
PDB = "1PGA" #@param {type: "string"}

#@markdown Upload Custom PDB?
Custom = False #@param {type: "boolean"}
#@markdown NOTE: If enabled, a `Choose files` button will appear at the bottom of this cell once this cell is run.


# try to upload the PDB file to Colab servers
if Custom:
  try:
    uploaded_pdb = files.upload()
    for fn in uploaded_pdb.keys():
      PDB = os.path.basename(fn)
      if not PDB.endswith('.pdb'):
        raise ValueError(f"Uploaded file {PDB} does not end in '.pdb'. Please check and rename file as needed.")
      os.rename(fn, os.path.join("/content/", PDB))
      pdb_file = os.path.join("/content/", PDB)
  except (MessageError, FileNotFoundError):
    print('\n', '*' * 100, '\n')
    print('Sorry, your input file failed to upload. Please try the backup upload procedure (next cell).')

else:
  try:
    fn = download_pdb(PDB, "/content/")
    if fn is None:
      raise ValueError("Failed to fetch PDB from RCSB. Please double-check PDB code and try again.")
    else:
      pdb_file = fn
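  except HTTPError as err:
    # HTTPError cannot be constructed from a message alone, so re-raise as ValueError
    raise ValueError(f"No protein with code {PDB} exists in the RCSB PDB. Please double-check PDB code and try again.") from err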
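
Because `download_pdb` simply appends the code and `.pdb` to the base URL, a malformed code only surfaces as a failed download. The sketch below shows one way to check the code first and build the same URL. `build_pdb_url` is a hypothetical helper, and the pattern assumes classic four-character PDB IDs that start with a digit; extended ID formats are not handled.

```python
import re

RCSB_DOWNLOAD_URL = "https://files.rcsb.org/download/"


def build_pdb_url(pdbcode, downloadurl=RCSB_DOWNLOAD_URL):
    """Validate a classic 4-character PDB ID and return its download URL."""
    code = pdbcode.strip()
    # Assumption: one digit followed by three alphanumeric characters
    if not re.fullmatch(r"[0-9][A-Za-z0-9]{3}", code):
        raise ValueError(f"'{pdbcode}' does not look like a 4-character PDB ID")
    return downloadurl + code + ".pdb"


print(build_pdb_url("1PGA"))
```

A check like this could run before the fetch so that typos are reported without a network round trip.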


In [None]:
#@title # **3. Backup Data Upload (ONLY needed if initial upload failed)**

#@markdown ## Colab automatic file uploads are not very reliable. If your file failed to upload automatically, you can do so manually by following these steps.<br><br>

#@markdown #### 1. Click the "Files" icon on the left toolbar. This will open the Colab server file folder.

#@markdown #### 2. The only thing in this folder should be the "ThermoMPNN-D" directory. If any other files are in here, delete them.

#@markdown #### 3. Click the "Upload to session storage" button under the "Files" header. Choose your file for upload.

#@markdown #### 4. Run this cell. ThermoMPNN will find your file in session storage and use it.

PDB = ""

# NOTE: use a name other than `files` so the google.colab `files` module stays usable in Step 3
pdb_files = sorted(os.listdir('/content/'))
pdb_files = [f for f in pdb_files if f.endswith('.pdb')]

if len(pdb_files) < 1:
  raise ValueError('No PDB file found. Please upload your file before running this cell. Make sure it has a .pdb suffix.')
elif len(pdb_files) > 1:
  raise ValueError('Too many PDB files found. Please clear out any other PDBs before running this cell.')
else:
  pdb_file = os.path.join("/content/", pdb_files[0])
  PDB = pdb_files[0].removesuffix('.pdb')
  print('Successfully uploaded PDB file %s' % (pdb_files[0]))

In [10]:
#@markdown # **4. Run Model**

#@markdown Stability model to use:
Model = "Epistatic" #@param ["Epistatic", "Additive", "Single"]

#@markdown Chain(s) of Interest (example: A,B,C):
Chains = "A" #@param {type:"string"}
#@markdown If left empty, all chains will be used.

#@markdown ---------------
#@markdown Include Cysteines?
Include = False #@param {type: "boolean"}
#@markdown NOTE: Due to assay artifacts surrounding disulfide formation, model predictions for surface cysteine mutations may be overly favorable.

#@markdown Explicitly penalize disulfide breakage? Recommended.
Penalize = True #@param {type: "boolean"}

#@markdown --------------

BatchSize = 256 #@param {type: "integer"}
#@markdown Recommended: 256 (single/additive) or 2048 (epistatic). Should be lowered if you hit a memory error.

Threshold = -0.5 #@param {type: "number"}
#@markdown Ignored for single mutant model. Only mutations BELOW (more stable than) this threshold are kept. This is useful for speed and memory, since the number of double mutants grows quadratically with the number of positions.

Distance = 10.0 #@param {type: "number"}
#@markdown Ignored for single mutant model. Only mutation pairs BELOW (within) this distance (in angstroms) are kept. Useful for speed and memory savings.

# ---------- END OPTIONS -------------- #

# use input_chain_list to grab correct protein chain
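# drop empty entries so a blank field yields [] and get_chains() falls back to all chains
chain_list = [c.strip() for c in Chains.split(',') if c.strip()]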

# validate chain inputs
chain_list = get_chains(pdb_file, chain_list)
print(f'Using ThermoMPNN-D to predict stability across chain(s): {chain_list}')

# remove cys from alphabet if needed
alphabet = 'ACDEFGHIKLMNPQRSTVWY' if Include else 'ADEFGHIKLMNPQRSTVWY'

Using ThermoMPNN-D to predict stability across chain(s): ['A']
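
To see why the `Threshold` and `Distance` options matter, it helps to count how many double mutants a chain can have and how the two filters shrink that set. The sketch below is a plain-Python illustration only: the function names are hypothetical, it assumes 19 alternative residues per position, and it does not reproduce how ThermoMPNN-D applies these filters internally.

```python
from itertools import combinations
from math import comb, dist


def count_double_mutants(n_positions, n_alternatives=19):
    """Number of distinct double mutants: position pairs times substitutions at each site."""
    return comb(n_positions, 2) * n_alternatives ** 2


def pairs_within(ca_coords, cutoff):
    """Index pairs whose C-alpha atoms lie closer than `cutoff` (angstroms)."""
    return [(i, j)
            for (i, a), (j, b) in combinations(enumerate(ca_coords), 2)
            if dist(a, b) < cutoff]


def keep_stabilizing(predictions, threshold):
    """Keep predictions below `threshold` (more negative means more stabilizing)."""
    return {mut: ddg for mut, ddg in predictions.items() if ddg < threshold}


# A 56-residue chain already yields over half a million double mutants
print(count_double_mutants(56))  # 555940

# Toy coordinates: only the first two positions are within 10 angstroms
coords = [(0.0, 0.0, 0.0), (3.8, 0.0, 0.0), (20.0, 0.0, 0.0)]
close_pairs = pairs_within(coords, cutoff=10.0)
print(close_pairs)

# Toy predictions with made-up mutation labels and values
kept = keep_stabilizing({"A1G:L5K": -0.8, "A1G:T9V": 0.3}, threshold=-0.5)
print(kept)
```

Doubling the chain length roughly quadruples the count, which is why restricting pairs by distance and by predicted stability keeps memory use manageable.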


In [None]:
# load config and model

# load mutation dataset

# run inference

# collate output data