#CafChem tools for progressive unmasking of proteins

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/MauricioCafiero/CafChem/blob/main/notebooks/ProteinProgUnmask_CafChem.ipynb)

## This notebook allows you to:
- load a protein sequence
- specify which residues to mask
- ESM model unmasks residues; code then chooses the unmasked residue with the highest probability and adds it to the chain.
- chain with newly unmasked residue is passed through the model again, and the new unmasked residue with the highest probability os added to the chain.
- etc.
- tools to compare old an new chains.

## Requirements:

- Runs quickly on an L4 GPU

## Install and import libraries

In [1]:
!pip install -q py3Dmol
!pip install -q "fair-esm[esmfold]"

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/510.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m510.3/510.3 kB[0m [31m17.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.2/3.2 MB[0m [31m91.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m93.1/93.1 kB[0m [31m10.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m76.7/76.7 kB[0m [31m8.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m849.5/849.5 kB[0m [31m62.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m983.2/983.2 kB[0m [31m57.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m54.0/54.0 kB[0m [31m4.8 MB/s[0m eta [36m0:0

In [2]:
from transformers import AutoTokenizer, EsmModel, EsmForMaskedLM, EsmForSequenceClassification
import torch
from torch import inf
import py3Dmol
import requests
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import pandas as pd
import esm

## Define functions

In [107]:
def one_to_three(one_seq):
  '''
  Convert one-letter code to three-letter code
  Input: one-letter code
  Output: three-letter code
  '''
  rev_aa_hash = {
      'A': 'ALA',
      'R': 'ARG',
      'N': 'ASN',
      'D': 'ASP',
      'C': 'CYS',
      'Q': 'GLN',
      'E': 'GLU',
      'G': 'GLY',
      'H': 'HIS',
      'I': 'ILE',
      'L': 'LEU',
      'K': 'LYS',
      'M': 'MET',
      'F': 'PHE',
      'P': 'PRO',
      'S': 'SER',
      'T': 'THR',
      'W': 'TRP',
      'Y': 'TYR',
      'V': 'VAL'
  }

  try:
    three_seq = rev_aa_hash[one_seq]
  except:
    three_seq = 'X'

  return three_seq

class gen_mask_fill():
  '''
  Class to generate masks and fill them with ESM predictions
  '''
  def __init__(self, checkpoint: str, seq: list, res_to_mask: list[str]):
    '''
    Constructor for mask filling
    Input:
    - checkpoint: path to ESM model
    - seq: sequence to mask
    - res_to_mask: list of residues to mask
    '''
    self.checkpoint = checkpoint
    self.seq = seq
    self.res_to_mask = res_to_mask
    self.natural_residues = ['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L',
                             'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y']

  def start_model(self):
    '''
    Start ESM model and tokenizer
    '''
    self.tokenizer = AutoTokenizer.from_pretrained(self.checkpoint)
    self.model = EsmForMaskedLM.from_pretrained((self.checkpoint))

  def mask_tokens(self):
    '''
    Mask tokens in sequence
    Output:
    - seq_ids: sequence of tokens
    - masked_chain: masked sequence
    - masked_chain_ids: masked sequence of tokens
    '''
    self.seq_ids = self.tokenizer(''.join(self.seq))['input_ids']

    masked_chain = []
    num_masked = 0
    for i, token in enumerate(self.seq):
      if token in self.res_to_mask:
        masked_chain.append('<mask>')
        num_masked += 1
      else:
        masked_chain.append(token)

    self.num_masked = num_masked
    self.masked_chain = masked_chain
    self.masked_chain_ids = self.tokenizer(''.join(masked_chain))['input_ids']

    return self.seq_ids, self.masked_chain, self.masked_chain_ids

  def unmask(self):
    '''
    Unmask tokens in sequence; fills in mask with highest probability, then re-runs
    ummasking model on remaining masked tokens and repeats until all are unmasked.
    Output:
    - masked_chain: unmasked sequence
    '''
    for k in range(self.num_masked):
      model_out = self.model(**self.tokenizer(text = ''.join(self.masked_chain), return_tensors='pt'))

      masked_probs = []
      for i, row in enumerate(model_out.logits[0][1:-1]):
        if self.masked_chain[i] == '<mask>':
          probs = torch.softmax(row.detach().clone(), dim=0)
          best_prob = torch.argmax(probs).detach().clone().item()
          masked_probs.append((probs[best_prob],i+1))

      masked_probs.sort(key=lambda x: x[0], reverse=True)

      new_token = self.tokenizer.decode([masked_probs[0][1]])
      if new_token not in self.natural_residues:
        new_token = 'G'
      self.masked_chain[masked_probs[0][1]-1] = new_token

    return self.masked_chain

  def compare_seqs(self):
    '''
    Compare original and new sequences
    Output:
    - chain: original sequence
    - new_seq: new sequence
    '''
    self.new_seq = ''.join(self.masked_chain)
    self.chain = ''.join(self.seq).replace('<cls>','').replace('<eos>','')
    print(f"Original: {self.chain}")
    print(f"Novel   : {self.new_seq}")

    i = 1
    for char_o, char_n in zip(self.seq,self.new_seq):
      if char_o != char_n:
        print(f"Residue {i} changed {one_to_three(char_o)} --> {one_to_three(char_n)}.")
      i += 1

    return self.chain, self.new_seq

  def compare_seqs_naive(self):
    '''
    Compare original and new sequences by % of differences
    Output:
    - chain: original sequence
    - new_seq: new sequence
    '''
    self.new_seq = ''.join(self.masked_chain)

    self.chain = ''.join(self.seq).replace('<cls>','').replace('<eos>','')
    print(f"Original: {self.chain}")
    print(f"Novel   : {self.new_seq}")

    num_diff = 0
    for char_o, char_n in zip(self.seq,self.new_seq):

      if char_o != char_n:
        num_diff += 1

    print(f"Number of differences: {num_diff} out of {len(self.seq)}")
    print(f"Percentage of differences: {num_diff/len(self.seq):.3f}")

    return self.chain, self.new_seq

## Unmask proteins

In [108]:
sgt_mask = gen_mask_fill(checkpoint = 'facebook/esm2_t33_650M_UR50D', seq = 'HXEGTFTSDVSSYLEGQAAKEFIAWLVRGRG', res_to_mask = ['G'])
sgt_mask.start_model()

In [109]:
seq_ids, masked_chain, masked_chain_ids = sgt_mask.mask_tokens()
print(seq_ids)
print(masked_chain_ids)

[0, 21, 24, 9, 6, 11, 18, 11, 8, 13, 7, 8, 8, 19, 4, 9, 6, 16, 5, 5, 15, 9, 18, 12, 5, 22, 4, 7, 10, 6, 10, 6, 2]
[0, 21, 24, 9, 32, 11, 18, 11, 8, 13, 7, 8, 8, 19, 4, 9, 32, 16, 5, 5, 15, 9, 18, 12, 5, 22, 4, 7, 10, 32, 10, 32, 2]


In [110]:
print(masked_chain)

['H', 'X', 'E', '<mask>', 'T', 'F', 'T', 'S', 'D', 'V', 'S', 'S', 'Y', 'L', 'E', '<mask>', 'Q', 'A', 'A', 'K', 'E', 'F', 'I', 'A', 'W', 'L', 'V', 'R', '<mask>', 'R', '<mask>']


In [111]:
new_chain = sgt_mask.unmask()
print(new_chain)

['H', 'X', 'E', 'L', 'T', 'F', 'T', 'S', 'D', 'V', 'S', 'S', 'Y', 'L', 'E', 'Q', 'Q', 'A', 'A', 'K', 'E', 'F', 'I', 'A', 'W', 'L', 'V', 'R', 'G', 'R', 'G']


In [112]:
orig, new = sgt_mask.compare_seqs()

Original: HXEGTFTSDVSSYLEGQAAKEFIAWLVRGRG
Novel   : HXELTFTSDVSSYLEQQAAKEFIAWLVRGRG
Residue 4 changed GLY --> LEU.
Residue 16 changed GLY --> GLN.


In [113]:
orig, new = sgt_mask.compare_seqs_naive()

Original: HXEGTFTSDVSSYLEGQAAKEFIAWLVRGRG
Novel   : HXELTFTSDVSSYLEQQAAKEFIAWLVRGRG
Number of differences: 2 out of 31
Percentage of differences: 0.065
