# <b>SSEmbLab: A joint embedding of protein sequence and structure enables robust variant effect predictions</b>

Welcome! Here we provide a simple pipeline for making SSEmb fitness predictions for all possible variants of a protein sequence based on a user-selected protein structure.

Important notes before running this notebook:

*   Run this notebook in a GPU session (see: `Runtime` -> `Change runtime type`).
*   Run the cells one at a time, in order. Do not use `Run all`, because setting up the environment requires a kernel restart.
*   The `Set-up environment` step is slow and can take up to 20 min to complete. The screen might become unresponsive during this time.
*   The pipeline currently processes a single PDB file and a single user-specified chain ID at a time.
*   The MSA generation procedure used in this notebook differs slightly from the one used to generate the predictions presented in the paper. However, overall performance is expected to be the same.
*   SSEmb predictions are saved to `output/df_ssemb.csv`, from which they can be downloaded to local storage.
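
The final cell of this notebook writes `output/df_ssemb.csv` with the columns `protein`, `variant` and `score_ml`, where each variant is written as wild-type letter, 1-based position and mutant letter. The following is a minimal sketch of how such a file could be parsed with the standard library only. `parse_variant` and `load_scores` are hypothetical helper names, and the two in-memory rows are made-up values used purely for illustration.

```python
import csv
import io
import re

# Variant strings look like "V1A": wild-type residue, 1-based position, mutant residue.
VARIANT_RE = re.compile(r"^([A-Z])(\d+)([A-Z])$")


def parse_variant(variant):
    """Split a variant string such as 'V1A' into (wt, position, mt)."""
    match = VARIANT_RE.match(variant)
    if match is None:
        raise ValueError(f"Unrecognised variant: {variant!r}")
    wt, pos, mt = match.groups()
    return wt, int(pos), mt


def load_scores(handle):
    """Read rows with the columns protein, variant, score_ml into a list of dicts."""
    rows = []
    for row in csv.DictReader(handle):
        wt, pos, mt = parse_variant(row["variant"])
        rows.append(
            {
                "protein": row["protein"],
                "wt": wt,
                "pos": pos,
                "mt": mt,
                "score_ml": float(row["score_ml"]),
            }
        )
    return rows


# Illustrative data only (assumption); real scores come from output/df_ssemb.csv.
sample = io.StringIO(
    "protein,variant,score_ml\n"
    "query_protein_A,V1A,-1.25\n"
    "query_protein_A,V1V,0.0\n"
)
scores = load_scores(sample)
```

To read the real file, the same function can be called on an open file handle, for example `with open("output/df_ssemb.csv", newline="") as fh: scores = load_scores(fh)`.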

# <b> Acknowledgements </b>
Code for the original MSA Transformer was developed by the ESM team at Meta Research:
https://github.com/facebookresearch/esm.

Code for the original GVP-GNN was developed by Jing et al:
https://github.com/drorlab/gvp-pytorch.

We thank Milot Mirdita and the rest of the ColabFold Search team for help in setting up the MSA generation pipeline. Code for querying the MMseqs2 webserver has been copied with permission from: https://github.com/sokrypton/ColabFold/blob/57b220e028610ba7331ebe1ef9c2d0419992469a/colabfold/colabfold.py#L72.

# <b>Set-up environment</b>

In [None]:
try:
    import google.colab
    ! pip install condacolab
    import condacolab
    condacolab.install()
except ModuleNotFoundError:
    pass

Collecting condacolab
  Downloading condacolab-0.1.9-py3-none-any.whl (7.2 kB)
Installing collected packages: condacolab
Successfully installed condacolab-0.1.9
⏬ Downloading https://github.com/conda-forge/miniforge/releases/download/23.11.0-0/Mambaforge-23.11.0-0-Linux-x86_64.sh...
📦 Installing...
📌 Adjusting configuration...
🩹 Patching environment...
⏲ Done in 0:00:11
🔁 Restarting kernel...


In [None]:
# Remove sample data
! rm -r sample_data

# Install dependencies with mamba
! mamba install scipy scikit-learn pdbfixer=1.8.1 openmm=8.0 pandas -c omnia -c conda-forge -c anaconda -c defaults --yes
! mamba install conda-forge::biopython
! mamba install nvidia/label/cuda-12.1.0::cuda-toolkit
! mamba install pytorch=2.3.0 pytorch-cuda=12.1 -c pytorch -c nvidia
! pip install pdb-tools
! pip install torch_geometric
! pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.3.0+cu121.html
! mamba install bioconda::f5c
! mamba install bioconda::mmseqs2
! pip install fair-esm==2.0.0


# <b>Download files</b>

In [None]:
%%bash
mkdir -p data
mkdir -p data/structure
mkdir -p data/structure/raw
mkdir -p data/msa
mkdir -p output
mkdir -p pdb_parser_scripts
mkdir -p models
mkdir -p models/gvp
mkdir -p models/msa_transformer

wget -nc https://raw.github.com/KULL-Centre/_2023_Blaabjerg_SSEmb/main/src/pdb_parser_scripts/clean_pdb.py -P pdb_parser_scripts/ &> /dev/null
wget -nc https://raw.github.com/KULL-Centre/_2023_Blaabjerg_SSEmb/main/src/pdb_parser_scripts/clean_pdbs.sh -P pdb_parser_scripts/ &> /dev/null
wget -nc https://raw.github.com/KULL-Centre/_2023_Blaabjerg_SSEmb/main/src/pdb_parser_scripts/parse_pdbs.py -P pdb_parser_scripts/ &> /dev/null
chmod a+rx pdb_parser_scripts/clean_pdbs.sh

wget -nc https://raw.github.com/KULL-Centre/_2023_Blaabjerg_SSEmb/main/src/models/gvp/__init__.py -P models/gvp/ &> /dev/null
wget -nc https://raw.github.com/KULL-Centre/_2023_Blaabjerg_SSEmb/main/src/models/gvp/data.py -P models/gvp/ &> /dev/null
wget -nc https://raw.github.com/KULL-Centre/_2023_Blaabjerg_SSEmb/main/src/models/gvp/models.py -P models/gvp/ &> /dev/null
wget -nc https://raw.github.com/KULL-Centre/_2023_Blaabjerg_SSEmb/main/src/models/msa_transformer/__init__.py -P models/msa_transformer/ &> /dev/null
wget -nc https://raw.github.com/KULL-Centre/_2023_Blaabjerg_SSEmb/main/src/models/msa_transformer/axial_attention.py -P models/msa_transformer/ &> /dev/null
wget -nc https://raw.github.com/KULL-Centre/_2023_Blaabjerg_SSEmb/main/src/models/msa_transformer/constants.py -P models/msa_transformer/ &> /dev/null
wget -nc https://raw.github.com/KULL-Centre/_2023_Blaabjerg_SSEmb/main/src/models/msa_transformer/data.py -P models/msa_transformer/ &> /dev/null
wget -nc https://raw.github.com/KULL-Centre/_2023_Blaabjerg_SSEmb/main/src/models/msa_transformer/model.py -P models/msa_transformer/ &> /dev/null
wget -nc https://raw.github.com/KULL-Centre/_2023_Blaabjerg_SSEmb/main/src/models/msa_transformer/modules.py -P models/msa_transformer/ &> /dev/null
wget -nc https://raw.github.com/KULL-Centre/_2023_Blaabjerg_SSEmb/main/src/models/msa_transformer/multihead_attention.py -P models/msa_transformer/ &> /dev/null

wget -nc https://zenodo.org/records/12798019/files/weights.tar.gz?download=1 -P data/ &> /dev/null
tar -xvzf data/weights.tar.gz?download=1 -C data/
rm data/weights.tar.gz?download=1

._weights
weights/
weights/._.DS_Store
weights/.DS_Store
weights/._msa_alphabet.pkl
weights/msa_alphabet.pkl
weights/._final_cath_msa_transformer_110.pt
weights/final_cath_msa_transformer_110.pt
weights/._final_cath_gvp_110.pt
weights/final_cath_gvp_110.pt


tar: Ignoring unknown extended header keyword 'LIBARCHIVE.xattr.com.apple.quarantine'
tar: Ignoring unknown extended header keyword 'SCHILY.fflags'
tar: Ignoring unknown extended header keyword 'LIBARCHIVE.xattr.com.apple.FinderInfo'
tar: Ignoring unknown extended header keyword 'LIBARCHIVE.xattr.com.apple.quarantine'
tar: Ignoring unknown extended header keyword 'LIBARCHIVE.xattr.com.apple.macl'
tar: Ignoring unknown extended header keyword 'LIBARCHIVE.xattr.com.apple.quarantine'
tar: Ignoring unknown extended header keyword 'LIBARCHIVE.xattr.com.apple.quarantine'
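
The archive listing above shows three files that later cells load from `data/weights/`. As a hedged sketch, a small check like the one below could confirm that extraction succeeded before continuing. `missing_weights` is a hypothetical helper, and the demonstration runs against a temporary directory rather than the real `data/weights` folder.

```python
import tempfile
from pathlib import Path

# File names taken from the tar listing printed by the download cell.
EXPECTED_WEIGHTS = (
    "msa_alphabet.pkl",
    "final_cath_msa_transformer_110.pt",
    "final_cath_gvp_110.pt",
)


def missing_weights(weights_dir):
    """Return the expected weight files that are absent from weights_dir."""
    weights_dir = Path(weights_dir)
    return [name for name in EXPECTED_WEIGHTS if not (weights_dir / name).is_file()]


# Demonstration on a temporary directory containing only one of the files.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "msa_alphabet.pkl").touch()
    demo_missing = missing_weights(tmp)
```

In the notebook itself, `missing_weights("data/weights")` would be expected to return an empty list once the download cell has finished.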


# <b>Define functions</b>

In [None]:
import os
import sys
import glob
import subprocess
import torch
import torch.nn.functional as F
import models.gvp.data, models.gvp.models
import json
import copy
import numpy as np
import torch_geometric
from functools import partial
import esm
import random
from models.msa_transformer.model import MSATransformer
from models.gvp.models import SSEmbGNN
import pickle
import Bio
import Bio.PDB
from Bio import SeqIO
from Bio.SeqRecord import SeqRecord
import pandas as pd
import re
from collections import OrderedDict
from typing import Tuple, List
import tqdm
import tarfile
import requests
import string
import time
import logging
logger = logging.getLogger(__name__)

def parse(pdb_dir):
    # Load PDBS
    pdb_filenames = sorted(glob.glob(f"{pdb_dir}/cleaned/*.pdb"))

    # Create fasta file
    fh = open(f"{pdb_dir}/seqs.fasta","w")

    # Initialize list of pdb dicts
    pdb_dict_list = []

    # Loop over proteins
    for pdb_filename in pdb_filenames:

        # Parse structure with Biopython
        pdb_parser = Bio.PDB.PDBParser()
        pdb_id = os.path.basename(pdb_filename).split("/")[-1][:-4]
        structure = pdb_parser.get_structure(pdb_id, pdb_filename)
        first_model = structure.get_list()[0]
        first_model.child_list = sorted(first_model.child_list) # Sort chains alphabetically

        # Iterate over chain,residue,atoms and extract features
        for chain in first_model: # Loop over chains even though there is only 1

            # Initialize
            chain_id = chain.id
            seq = []
            coords = []
            pdb_dict = {}

            for j, residue in enumerate(chain):
                atom_names = []
                backbone_coords = []

                for atom in residue:
                    # Extract atom features
                    if atom.name in ["N","CA","C","O"]:
                        atom_names.append(atom.name)
                        #backbone_coords.append(list(atom.coord))
                        backbone_coords.append([str(x) for x in atom.coord])

                # Check that all backbone atoms are present
                if atom_names == ["N","CA","C","O"] and len(backbone_coords)==4 and residue._id[0].startswith("H_") == False: # HETATM check

                    # Add coordinates
                    coords.append(backbone_coords)

                    # Add residue to sequence
                    seq.append(Bio.PDB.Polypeptide.protein_letters_3to1.get(residue.resname))

            # Save coords+seq to dict
            pdb_dict["name"] = pdb_id
            pdb_dict["coords"] = coords
            pdb_dict["seq"] = "".join(seq)
            pdb_dict_list.append(pdb_dict)

            # Output seq to fasta
            fh.write(f">{pdb_dict['name']}\n")
            fh.write("".join(seq))
            fh.write("\n")
    fh.close()

    # Save total coord dict
    with open(f'{pdb_dir}/coords.json', 'w') as fp:
        json.dump(pdb_dict_list, fp)

TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'

def run_mmseqs2(x, prefix="mmseqs2", use_env=True, use_filter=True,
                use_templates=False, filter=None, use_pairing=False, pairing_strategy="greedy",
                host_url="https://api.colabfold.com",
                user_agent="SSEmb") -> None:
  # The following code has been copied with permission from the ColabFold team
  # here: https://github.com/sokrypton/ColabFold/blob/57b220e028610ba7331ebe1ef9c2d0419992469a/colabfold/colabfold.py#L72
  submission_endpoint = "ticket/pair" if use_pairing else "ticket/msa"

  headers = {}
  if user_agent != "":
    headers['User-Agent'] = user_agent
  else:
    logger.warning("No user agent specified. Please set a user agent (e.g., 'toolname/version contact@email') to help us debug in case of problems. This warning will become an error in the future.")

  def submit(seqs, mode, N=101):
    n, query = N, ""
    for seq in seqs:
      query += f">{n}\n{seq}\n"
      n += 1

    error_count = 0
    while True:
      try:
        # https://requests.readthedocs.io/en/latest/user/advanced/#advanced
        # "good practice to set connect timeouts to slightly larger than a multiple of 3"
        res = requests.post(f'{host_url}/{submission_endpoint}', data={ 'q': query, 'mode': mode }, timeout=6.02, headers=headers)
      except requests.exceptions.Timeout:
        logger.warning("Timeout while submitting to MSA server. Retrying...")
        continue
      except Exception as e:
        error_count += 1
        logger.warning(f"Error while fetching result from MSA server. Retrying... ({error_count}/5)")
        logger.warning(f"Error: {e}")
        time.sleep(5)
        if error_count > 5:
          raise
        continue
      break

    try:
      out = res.json()
    except ValueError:
      logger.error(f"Server didn't reply with json: {res.text}")
      out = {"status":"ERROR"}
    return out

  def status(ID):
    error_count = 0
    while True:
      try:
        res = requests.get(f'{host_url}/ticket/{ID}', timeout=6.02, headers=headers)
      except requests.exceptions.Timeout:
        logger.warning("Timeout while fetching status from MSA server. Retrying...")
        continue
      except Exception as e:
        error_count += 1
        logger.warning(f"Error while fetching result from MSA server. Retrying... ({error_count}/5)")
        logger.warning(f"Error: {e}")
        time.sleep(5)
        if error_count > 5:
          raise
        continue
      break
    try:
      out = res.json()
    except ValueError:
      logger.error(f"Server didn't reply with json: {res.text}")
      out = {"status":"ERROR"}
    return out

  def download(ID, path):
    error_count = 0
    while True:
      try:
        res = requests.get(f'{host_url}/result/download/{ID}', timeout=6.02, headers=headers)
      except requests.exceptions.Timeout:
        logger.warning("Timeout while fetching result from MSA server. Retrying...")
        continue
      except Exception as e:
        error_count += 1
        logger.warning(f"Error while fetching result from MSA server. Retrying... ({error_count}/5)")
        logger.warning(f"Error: {e}")
        time.sleep(5)
        if error_count > 5:
          raise
        continue
      break
    with open(path,"wb") as out: out.write(res.content)

  # process input x
  seqs = [x] if isinstance(x, str) else x

  # compatibility to old option
  if filter is not None:
    use_filter = filter

  # setup mode
  if use_filter:
    mode = "env" if use_env else "all"
  else:
    mode = "env-nofilter" if use_env else "nofilter"

  if use_pairing:
    use_templates = False
    mode = ""
    # greedy is default, complete was the previous behavior
    if pairing_strategy == "greedy":
      mode = "pairgreedy"
    elif pairing_strategy == "complete":
      mode = "paircomplete"
    if use_env:
      mode = mode + "-env"

  # define path
  path = f"data/msa/"
  if not os.path.isdir(path): os.mkdir(path)

  # call mmseqs2 api
  tar_gz_file = f'{path}/out.tar.gz'
  N,REDO = 101,True

  # deduplicate and keep track of order
  seqs_unique = []
  [seqs_unique.append(x) for x in seqs if x not in seqs_unique]
  Ms = [N + seqs_unique.index(seq) for seq in seqs]
  # lets do it!
  if not os.path.isfile(tar_gz_file):
    TIME_ESTIMATE = 150 * len(seqs_unique)
    with tqdm.tqdm(total=TIME_ESTIMATE, bar_format=TQDM_BAR_FORMAT) as pbar:
      while REDO:
        pbar.set_description("SUBMIT")

        # Resubmit job until it goes through
        out = submit(seqs_unique, mode, N)
        while out["status"] in ["UNKNOWN", "RATELIMIT"]:
          sleep_time = 5 + random.randint(0, 5)
          logger.error(f"Sleeping for {sleep_time}s. Reason: {out['status']}")
          # resubmit
          time.sleep(sleep_time)
          out = submit(seqs_unique, mode, N)

        if out["status"] == "ERROR":
          raise Exception(f'MMseqs2 API is giving errors. Please confirm your input is a valid protein sequence. If error persists, please try again an hour later.')

        if out["status"] == "MAINTENANCE":
          raise Exception(f'MMseqs2 API is undergoing maintenance. Please try again in a few minutes.')

        # wait for job to finish
        ID,TIME = out["id"],0
        pbar.set_description(out["status"])
        while out["status"] in ["UNKNOWN","RUNNING","PENDING"]:
          t = 5 + random.randint(0,5)
          logger.error(f"Sleeping for {t}s. Reason: {out['status']}")
          time.sleep(t)
          out = status(ID)
          pbar.set_description(out["status"])
          if out["status"] == "RUNNING":
            TIME += t
            pbar.update(n=t)

        if out["status"] == "COMPLETE":
          if TIME < TIME_ESTIMATE:
            pbar.update(n=(TIME_ESTIMATE-TIME))
          REDO = False

        if out["status"] == "ERROR":
          REDO = False
          raise Exception(f'MMseqs2 API is giving errors. Please confirm your input is a valid protein sequence. If error persists, please try again an hour later.')

      # Download results
      download(ID, tar_gz_file)

  # prep list of a3m files
  if use_pairing:
    a3m_files = [f"{path}/pair.a3m"]
  else:
    a3m_files = [f"{path}/uniref.a3m"]
    if use_env: a3m_files.append(f"{path}/bfd.mgnify30.metaeuk30.smag30.a3m")

  # extract a3m files
  if any(not os.path.isfile(a3m_file) for a3m_file in a3m_files):
    with tarfile.open(tar_gz_file) as tar_gz:
      tar_gz.extractall(path)

  # templates
  if use_templates:
    templates = {}
    #print("seq\tpdb\tcid\tevalue")
    for line in open(f"{path}/pdb70.m8","r"):
      p = line.rstrip().split()
      M,pdb,qid,e_value = p[0],p[1],p[2],p[10]
      M = int(M)
      if M not in templates: templates[M] = []
      templates[M].append(pdb)
      #if len(templates[M]) <= 20:
      #  print(f"{int(M)-N}\t{pdb}\t{qid}\t{e_value}")

    template_paths = {}
    for k,TMPL in templates.items():
      TMPL_PATH = f"{prefix}_{mode}/templates_{k}"
      if not os.path.isdir(TMPL_PATH):
        os.mkdir(TMPL_PATH)
        TMPL_LINE = ",".join(TMPL[:20])
        response = None
        error_count = 0
        while True:
          try:
            # https://requests.readthedocs.io/en/latest/user/advanced/#advanced
            # "good practice to set connect timeouts to slightly larger than a multiple of 3"
            response = requests.get(f"{host_url}/template/{TMPL_LINE}", stream=True, timeout=6.02, headers=headers)
          except requests.exceptions.Timeout:
            logger.warning("Timeout while submitting to template server. Retrying...")
            continue
          except Exception as e:
            error_count += 1
            logger.warning(f"Error while fetching result from template server. Retrying... ({error_count}/5)")
            logger.warning(f"Error: {e}")
            time.sleep(5)
            if error_count > 5:
              raise
            continue
          break
        with tarfile.open(fileobj=response.raw, mode="r|gz") as tar:
          tar.extractall(path=TMPL_PATH)
        os.symlink("pdb70_a3m.ffindex", f"{TMPL_PATH}/pdb70_cs219.ffindex")
        with open(f"{TMPL_PATH}/pdb70_cs219.ffdata", "w") as f:
          f.write("")
      template_paths[k] = TMPL_PATH

# Initialize
deletekeys = dict.fromkeys(string.ascii_lowercase)
deletekeys["."] = None
deletekeys["*"] = None
translation = str.maketrans(deletekeys)

def remove_insertions(sequence: str):
    """Removes any insertions into the sequence. Needed to load aligned sequences in an MSA."""
    return sequence.translate(translation)


def read_msa(filename: str) -> List[Tuple[str, str]]:
    """Reads all sequences from an MSA file, automatically removes insertions."""
    return [
        (record.description, remove_insertions(str(record.seq)))
        for record in SeqIO.parse(filename, "fasta")
    ]


def hamming_distance(string1, string2):
    return sum(c1 != c2 for c1, c2 in zip(string1, string2))


# Initialize
def merge_and_sort_msas(msa_dir):
    subprocess.run(["mkdir", "-p", f"{msa_dir}_tmp"])

    # Load MSA files
    msa_files = sorted(glob.glob(f"{msa_dir}/*.a3m"))

    # Concatenate MSAs
    msas = []
    for i, _file in enumerate(msa_files):
        msas.append(read_msa(_file))

    assert msas[0][0][0] == msas[1][0][0]
    query = msas[0][0]
    outfile_name = f"{msa_dir}/final_unsorted.a3m"
    outfile = open(outfile_name, "w")
    outfile.write(f">{query[0]}\n")
    outfile.write(f"{query[1]}\n")

    seqs_1 = msas[0][1:]
    for seq in seqs_1:
        outfile.write(f">{seq[0]}\n")
        outfile.write(f"{seq[1]}\n")

    seqs_2 = msas[1][1:]
    for seq in seqs_2:
        outfile.write(f">{seq[0]}\n")
        outfile.write(f"{seq[1]}\n")
    outfile.close()

    # Sort concatenated MSA
    msa = read_msa(outfile_name)
    seqs = [x for x in msa]
    query = seqs[0]
    seqs = seqs[1:]
    ham_dists = np.zeros(len(seqs))

    for j, seq in enumerate(seqs):
        # Aligned sequences must match the query length once insertions are removed
        assert len(query[1]) == len(seq[1])
        ham_dists[j] = hamming_distance(query[1], seq[1])

    # Rank indices
    rank_indices = np.argsort(ham_dists)

    # Remove query duplicates
    if 0 in ham_dists:
        query_idx = np.argwhere(ham_dists == 0)[0]
        rank_indices = np.delete(
            rank_indices, np.argwhere(np.isin(rank_indices, query_idx))
        )

    # Construct new sorted MSA
    seqs_new = []
    for idx in rank_indices:
        seqs_new.append(seqs[idx])

    # Write to new file
    outfile = open(f"{msa_dir}_tmp/final.a3m", "w")
    outfile.write(f">{query[0]}\n")
    outfile.write(f"{query[1]}\n")

    for seq in seqs_new:
        outfile.write(f">{seq[0]}\n")
        outfile.write(f"{seq[1]}\n")
    outfile.close()

    # Replace the original MSA directory with the sorted output
    subprocess.run(["rm", "-r", f"{msa_dir}"])
    subprocess.run(["mv", f"{msa_dir}_tmp", msa_dir])

def forward(
    model_msa,
    model_gvp,
    msa_batch_tokens_masked,
    seq_masked,
    batch,
    mask_pos=None,
    loss_fn=None,
    batch_prots=None,
    get_logits_only=False,
):
    # Make MSA Transformer predictions
    msa_transformer_pred = model_msa(
        msa_batch_tokens_masked, repr_layers=[12], self_row_attn_mask=batch.dist_mask
    )
    msa_emb = msa_transformer_pred["representations"][12][0, 0, 1:, :]

    # Make GVP predictions
    h_V = (batch.node_s, batch.node_v)
    h_E = (batch.edge_s, batch.edge_v)
    logits = model_gvp(h_V, batch.edge_index, h_E, msa_emb, seq_masked)

    if get_logits_only == True:
        # Return logits
        return logits
    else:
        # Compute loss
        logits, seq = logits[mask_pos], batch.seq[mask_pos]
        loss_value = loss_fn(logits, seq)
        loss_value = loss_value / batch_prots

        return loss_value, logits, seq

def loop_pred(
    model_msa,
    model_gvp,
    msa_batch_converter,
    dataloader,
    variant_pos_dict,
    data,
    letter_to_num,
    device=None,
):
    # Initialize
    t = tqdm.tqdm(dataloader)
    pred_list = []
    total_correct, total_count = 0, 0

    # Loop over proteins
    for i, batch in enumerate(t):
        with torch.cuda.amp.autocast(enabled=True):
            # Move data to device
            batch = batch.to(device)

            # Initialize
            variant_wtpos_list = variant_pos_dict[batch.name[0]]
            seq_len = len(batch.seq)

            # Make masked marginal predictions
            for k, variant_wtpos in enumerate(variant_wtpos_list):
                print(
                    f"Computing logits for protein {batch.name[0]} ({i+1}/{len(dataloader)}) at position: {k+1}/{len(variant_wtpos_list)}"
                )

                # Extract variant info and initialize
                wt = letter_to_num[variant_wtpos[0]]
                pos = int(variant_wtpos[1:]) - 1  # Shift from DMS pos to seq idx
                score_ml_pos_ensemble = torch.zeros((len(batch.msa[0]), 20))

                # If protein too long; redo data loading with fragment
                if seq_len > 1024:
                    # Get sliding window
                    window_size = 1024 - 1
                    lower_side = max(pos - window_size // 2, 0)
                    upper_side = min(pos + window_size // 2 + 1, seq_len)
                    lower_bound = lower_side - (pos + window_size // 2 + 1 - upper_side)
                    upper_bound = upper_side + (lower_side - (pos - window_size // 2))

                    # Get fragment
                    data_frag = copy.deepcopy(data[i])
                    data_frag["seq"] = data[i]["seq"][lower_bound:upper_bound]
                    data_frag["coords"] = data[i]["coords"][lower_bound:upper_bound]
                    data_frag["msa"] = [
                        [(seq[0], seq[1][lower_bound:upper_bound]) for seq in msa_sub]
                        for msa_sub in data[i]["msa"]
                    ]
                    batch = models.gvp.data.ProteinGraphData([data_frag])[0]
                    batch = batch.to(device)
                    batch.msa = [batch.msa]
                    batch.name = [batch.name]

                    # Re-map position
                    pos = pos - lower_bound

                # Loop over MSA ensemble
                for j, msa_sub in enumerate(batch.msa[0]):
                    # Tokenize MSA
                    (
                        msa_batch_labels,
                        msa_batch_strs,
                        msa_batch_tokens,
                    ) = msa_batch_converter(msa_sub)
                    msa_batch_tokens = msa_batch_tokens.to(device)

                    # Mask position
                    msa_batch_tokens_masked = msa_batch_tokens.detach().clone()
                    msa_batch_tokens_masked[
                        :, 0, pos + 1
                    ] = 32  # Account for appended <cls> token
                    seq_masked = batch.seq.detach().clone()
                    seq_masked[pos] = 20

                    # Forward pass
                    logits = forward(
                        model_msa,
                        model_gvp,
                        msa_batch_tokens_masked,
                        seq_masked,
                        batch,
                        get_logits_only=True,
                    )
                    logits_pos = logits[pos, :]

                    # Compute accuracy
                    pred = (
                        torch.argmax(logits_pos, dim=-1).detach().cpu().numpy().item()
                    )
                    true = batch.seq[pos].detach().cpu().numpy().item()
                    if pred == true:
                        total_correct += 1 / len(batch.msa[0])

                    # Compute all possible nlls at this position based on known wt
                    nlls_pos = -torch.log(F.softmax(logits_pos, dim=-1))
                    nlls_pos_repeat = nlls_pos.repeat(20, 1)
                    score_ml_pos_ensemble[j, :] = torch.diagonal(
                        nlls_pos_repeat[:, wt] - nlls_pos_repeat[:, torch.arange(20)]
                    )

                # Append to total
                score_ml_pos = torch.mean(score_ml_pos_ensemble[: j + 1, :], axis=0)
                pred_list.append(
                    [
                        batch.name[0],
                        int(variant_wtpos[1:]),
                        score_ml_pos.detach().cpu().tolist(),
                    ]
                )
                total_count += 1

    return pred_list, total_correct / total_count
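
For proteins longer than 1024 residues, `loop_pred` scores each position inside a fragment centred on that position where possible, and shifts the fragment when it would run past either end of the sequence. The arithmetic can be sketched in isolation as below. `window_bounds` is a hypothetical name that does not exist in the notebook, and the sequence length of 2000 is an arbitrary example.

```python
def window_bounds(pos, seq_len, max_len=1024):
    """Return (lower_bound, upper_bound) of a fragment of max_len - 1 residues around pos."""
    window_size = max_len - 1
    half = window_size // 2
    lower_side = max(pos - half, 0)
    upper_side = min(pos + half + 1, seq_len)
    # Whatever is clipped on one side is added back on the other side
    lower_bound = lower_side - (pos + half + 1 - upper_side)
    upper_bound = upper_side + (lower_side - (pos - half))
    return lower_bound, upper_bound


for pos in (0, 1000, 1999):
    lower, upper = window_bounds(pos, seq_len=2000)
    # pos - lower is the index of the scored residue inside the fragment
    print(pos, (lower, upper), "index in fragment:", pos - lower)
```

The fragment always contains 1023 residues, which leaves room for the extra token that the MSA tokenizer places at the start of each row.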

# <b>Upload PDB and build MSA</b>

In [None]:
#@markdown Choose between one of the possible input sources:
#@markdown - AlphaFold2 PDB (v4) via Uniprot ID:
AF_ID ='P68871'#@param {type:"string"}
#@markdown - PDB ID (imported from RCSB PDB):
PDB_ID =''#@param {type:"string"}
#@markdown - Upload custom PDB
PDB_custom =False #@param {type:"boolean"}

#@markdown

#@markdown Select target chain (default A)
chain='A' #@param {type:'string'}

if PDB_custom:
  from google.colab import files  # Colab-only upload helper
  print('Upload PDB file:')
  uploaded_pdb = files.upload()
  for fn in uploaded_pdb.keys():
    os.rename(fn, f"data/structure/query_protein.pdb")
    print('PDB file correctly loaded')
elif (AF_ID !='') and (len(AF_ID)>=6) :
    subprocess.call(['curl','-s','-f',f'https://alphafold.ebi.ac.uk/files/AF-{AF_ID}-F1-model_v4.pdb','-o','data/structure/query_protein.pdb'])
elif (PDB_ID !='') and (len(PDB_ID)==4):
    subprocess.call(['curl','-s','-f',f'https://files.rcsb.org/download/{PDB_ID}.pdb','-o','data/structure/query_protein.pdb'])
else:
  print(f'ERROR: Please select one of the above inputs')

## remove other chains and move to raw folder
! pdb_selchain -"$chain" data/structure/query_protein.pdb | pdb_delhetatm | pdb_delres --999:0:1 | pdb_fixinsert | pdb_tidy  > data/structure/raw/query_protein_"$chain".pdb

print(f"Pre-processing PDBs ...")
# Pre-process PDBs
pdb_dir = "data/structure"
subprocess.run(
    [
        "pdb_parser_scripts/clean_pdbs.sh",
        str(pdb_dir),
    ]
)
parse(pdb_dir)

# Load structure data
print("Loading models and data...")
with open(f"{pdb_dir}/coords.json") as json_file:
    data = json.load(json_file)

#@markdown ****

Pre-processing PDBs ...
Loading models and data...
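
The cell above chooses between an AlphaFold model and an RCSB entry based on which ID field is filled in, and only prints a message if neither is usable. A hedged alternative, shown as a sketch, is to resolve the download URL in one place and raise an error on invalid input. `resolve_structure_url` is a hypothetical helper; the URL patterns and length checks are the ones used in the cell.

```python
def resolve_structure_url(af_id="", pdb_id=""):
    """Return the download URL for the selected structure source.

    The AlphaFold ID takes precedence over the PDB ID, as in the notebook cell.
    """
    if af_id and len(af_id) >= 6:
        return f"https://alphafold.ebi.ac.uk/files/AF-{af_id}-F1-model_v4.pdb"
    if pdb_id and len(pdb_id) == 4:
        return f"https://files.rcsb.org/download/{pdb_id}.pdb"
    raise ValueError("Provide a UniProt accession (AF_ID) or a 4-character PDB ID.")


url = resolve_structure_url(af_id="P68871")
```

Raising an exception stops the cell immediately, so the later pre-processing steps do not run on a missing structure file.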


In [None]:
# Get MSA by querying server
seq = [str(record.seq) for record in SeqIO.parse("data/structure/seqs.fasta",
                                                 "fasta")][0]
run_mmseqs2(seq)

# Merge and sort MSAs
merge_and_sort_msas("data/msa")
msa_filename = "data/msa/final.a3m"
subprocess.run(["mv", f"{msa_filename}",
                f"data/msa/query_protein_{chain}_unfiltered.a3m"])

# Filter MSA
subprocess.run(["mmseqs","filtera3m","--diff","512","--filter-min-enable","64",
                "--max-seq-id","0.90","--cov","0.75",
                f"data/msa/query_protein_{chain}_unfiltered.a3m",
                f"data/msa/query_protein_{chain}.a3m"])

# Load MSA data
msa_filenames = sorted(glob.glob(f"data/msa/*.a3m"))
mave_msa_sub = {}
for i, f in enumerate(msa_filenames):
    name = f.split("/")[-1].split(".")[0]
    mave_msa_sub[name] = []
    for j in range(5):
        msa = read_msa(f)
        msa_sub = [msa[0]]
        k = min(len(msa) - 1, 16 - 1)
        msa_sub += [msa[i] for i in sorted(random.sample(range(1, len(msa)), k))]
        mave_msa_sub[name].append(msa_sub)

# Add MSAs to data
for entry in data:
    entry["msa"] = mave_msa_sub[entry["name"]]

COMPLETE: 100%|██████████| 150/150 [elapsed: 00:01 remaining: 00:00]
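
The last part of the cell above builds an ensemble of five sub-MSAs per protein, each holding the query plus up to 15 randomly chosen sequences kept in their original order. A minimal sketch of that sampling step on toy data follows. `subsample_msa` is a hypothetical name, and the toy MSA of 40 identical sequences is an assumption made only to have something to sample from.

```python
import random


def subsample_msa(msa, n_seqs=16, n_subsamples=5, rng=None):
    """Draw n_subsamples sub-MSAs that each keep the query (row 0) first."""
    if rng is None:
        rng = random.Random()
    subsamples = []
    for _ in range(n_subsamples):
        # Never ask for more sequences than the MSA holds besides the query
        k = min(len(msa) - 1, n_seqs - 1)
        picked = sorted(rng.sample(range(1, len(msa)), k))
        subsamples.append([msa[0]] + [msa[idx] for idx in picked])
    return subsamples


# Toy MSA of (description, aligned sequence) tuples, as returned by read_msa
toy_msa = [(f"seq{i}", "ACDE") for i in range(40)]
ensemble = subsample_msa(toy_msa, rng=random.Random(0))
```

Passing a seeded `random.Random` instance makes the ensemble reproducible between runs, which the notebook cell itself does not do.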


# <b>Load pre-trained model weights</b>

In [None]:
# Convert to graph dataset and dataloader
testset = models.gvp.data.ProteinGraphData(data)
letter_to_num = testset.letter_to_num
test_loader = torch_geometric.loader.DataLoader(
        testset, batch_size=1, shuffle=False
    )

# Make variant pos dict
variant_pos_dict = {}
for entry in data:
    seq = entry["seq"]
    pos = [str(x + 1) for x in range(len(seq))]
    variant_wtpos_list = [[seq[i] + pos[i]] for i in range(len(seq))]
    variant_wtpos_list = [x for sublist in variant_wtpos_list for x in sublist]
    variant_pos_dict[entry["name"]] = variant_wtpos_list

# Load MSA Transformer
with open('data/weights/msa_alphabet.pkl', 'rb') as fh:
    msa_alphabet = pickle.load(fh)
msa_batch_converter = msa_alphabet.get_batch_converter()
model_msa = MSATransformer()

model_dict = OrderedDict()
state_dict_msa = torch.load("data/weights/final_cath_msa_transformer_110.pt")
pattern = re.compile("module.")
for k, v in state_dict_msa.items():
    if re.search("module", k):
        model_dict[re.sub(pattern, "", k)] = v
    else:
        model_dict = state_dict_msa
model_msa.load_state_dict(model_dict)

# Load GVP
node_dim = (256, 64)
edge_dim = (32, 1)
model_gvp = SSEmbGNN((6, 3), node_dim, (32, 1), edge_dim)

model_dict = OrderedDict()
state_dict_gvp = torch.load(f"data/weights/final_cath_gvp_110.pt")
pattern = re.compile("module.")
for k, v in state_dict_gvp.items():
    if k.startswith("module"):
        model_dict[k[7:]] = v
    else:
        model_dict = state_dict_gvp
model_gvp.load_state_dict(model_dict)

<All keys matched successfully>
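
Both loading blocks above strip a leading `module` prefix from the checkpoint keys, which is the prefix PyTorch adds when a model is saved while wrapped in `torch.nn.DataParallel` (whether these checkpoints were saved that way is an assumption). The renaming can be sketched without PyTorch as below. `strip_module_prefix` is a hypothetical helper, and the integer values stand in for tensors.

```python
from collections import OrderedDict


def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of state_dict with a leading prefix removed from each key."""
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        # Only a prefix at the very start of the key is removed
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        cleaned[new_key] = value
    return cleaned


# Placeholder values (assumption); a real state dict maps names to tensors
wrapped = OrderedDict([("module.layer.weight", 1), ("module.layer.bias", 2)])
plain = strip_module_prefix(wrapped)
```

Handling each key independently also covers checkpoints in which only some keys carry the prefix.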

# <b>Make SSEmb predictions and save output</b>
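
The score that `loop_pred` assigns to a substitution is the difference between the negative log-likelihood of the wild-type residue and that of the mutant residue at the masked position, averaged over the MSA subsamples. This equals log p(mutant) - log p(wild type), so the wild-type residue scores zero and less likely substitutions score below zero. A plain-Python sketch of that arithmetic follows. The function names are hypothetical and the three-class logits are toy values, whereas the notebook works on 20 amino-acid classes.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax of a list of floats."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def variant_scores(logits, wt_index):
    """score[mt] = log p(mt) - log p(wt) at one masked position."""
    log_probs = log_softmax(logits)
    return [lp - log_probs[wt_index] for lp in log_probs]


def ensemble_mean(score_lists):
    """Average per-class scores over the MSA subsamples."""
    n = len(score_lists)
    return [sum(col) / n for col in zip(*score_lists)]


# Toy logits (assumption) for one position, from two MSA subsamples
per_subsample = [
    variant_scores([2.0, 1.0, 0.0], wt_index=0),
    variant_scores([3.0, 0.0, 1.0], wt_index=0),
]
position_scores = ensemble_mean(per_subsample)
```

Because the normalisation constant cancels in the difference, each score reduces to the difference between the mutant and wild-type logits.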

In [None]:
# Call test
model_msa.eval()
model_gvp.eval()

with torch.no_grad():
    pred_list, acc_mean = loop_pred(
        model_msa,
        model_gvp,
        msa_batch_converter,
        test_loader,
        variant_pos_dict,
        data,
        letter_to_num,
    )

# Transform results into df
df_ml = pd.DataFrame(pred_list, columns=["protein", "variant_pos", "score_ml_pos"])

# Compute score_ml from nlls
pred_list_scores = []
mt_list = [x for x in sorted(letter_to_num, key=letter_to_num.get)][:-1]

for entry in data:
    protein = entry["name"]
    df_protein = df_ml[df_ml["protein"] == protein]

    wt = [[wt] * 20 for wt in entry["seq"]]
    pos = [[pos] * 20 for pos in list(df_protein["variant_pos"])]
    pos = [item for sublist in pos for item in sublist]
    mt = mt_list * len(wt)
    wt = [item for sublist in wt for item in sublist]
    score_ml = [
        item for sublist in list(df_protein["score_ml_pos"]) for item in sublist
    ]

    rows = [
        [protein, wt[i] + str(pos[i]) + mt[i], score_ml[i]] for i in range(len(mt))
    ]
    pred_list_scores += rows

# Transform results into df
df_ml_scores = pd.DataFrame(
    pred_list_scores, columns=["protein", "variant", "score_ml"]
)

# Save
df_ml_scores.to_csv(f"output/df_ssemb.csv", index=False)
print("Done!")

  0%|          | 0/1 [00:00<?, ?it/s]

Computing logits for protein query_protein_A (1/1) at position: 1/147
Computing logits for protein query_protein_A (1/1) at position: 2/147
Computing logits for protein query_protein_A (1/1) at position: 3/147
Computing logits for protein query_protein_A (1/1) at position: 4/147
Computing logits for protein query_protein_A (1/1) at position: 5/147
Computing logits for protein query_protein_A (1/1) at position: 6/147
Computing logits for protein query_protein_A (1/1) at position: 7/147
Computing logits for protein query_protein_A (1/1) at position: 8/147
Computing logits for protein query_protein_A (1/1) at position: 9/147
Computing logits for protein query_protein_A (1/1) at position: 10/147
Computing logits for protein query_protein_A (1/1) at position: 11/147
Computing logits for protein query_protein_A (1/1) at position: 12/147
Computing logits for protein query_protein_A (1/1) at position: 13/147
Computing logits for protein query_protein_A (1/1) at position: 14/147
Computing logit