# This is a notebook to show cases of ESM_if

In [1]:
import esm
#load the model
model, alphabet = esm.pretrained.esm_if1_gvp4_t16_142M_UR50()
model = model.eval()  

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# download the example structure
!wget https://files.rcsb.org/download/7mmo.cif -P data/

--2025-06-13 12:39:53--  https://files.rcsb.org/download/7mmo.cif
Connecting to 128.59.114.167:3128... connected.
Proxy request sent, awaiting response... 200 OK
Length: unspecified [application/octet-stream]
Saving to: ‘data/7mmo.cif.1’

7mmo.cif.1              [ <=>                ] 968.49K  --.-KB/s    in 0.05s   

2025-06-13 12:39:54 (19.8 MB/s) - ‘data/7mmo.cif.1’ saved [991730]



In [3]:
# information on structure to be designed
fpath = 'data/7mmo.cif' # .pdb format is also acceptable
chain_id = 'A'
structure = esm.inverse_folding.util.load_structure(fpath, chain_id)
coords, native_seq = esm.inverse_folding.util.extract_coords_from_structure(structure)
print('Native sequence:')
print(native_seq)

Native sequence:
ITLKESGPTLVKPTQTLTLTCTFSGFSLSISGVGVGWLRQPPGKALEWLALIYWDDDKRYSPSLKSRLTISKDTSKNQVVLKMTNIDPVDTATYYCAHHSISTIFDHWGQGTLVTVSSASTKGPSVFPLAPCTAALGCLVKDYFPEPVTVSWNSGALTSGVHTFPAVLQSSGLYSLSSVVTVPSSSLGTKTYTCNVDHKPSNTKVDKRVHH


# Application 1. sample new seqence based on the backbone structure

The sample strategy follows a multinomial sampling fashion, which samples based on probility distribution and temperature. (lower probability will be also sampled, espically at high temperature)

In [4]:
import numpy as np

sampled_seq = model.sample(coords, temperature=1)
print('Sampled sequence:', sampled_seq)

recovery = np.mean([(a==b) for a, b in zip(native_seq, sampled_seq)])
print('Sequence recovery:', recovery)

Sampled sequence: STNTVSGPTVVKPTDTLTLQCTYSGFSLNTTGVGLGWSRTPQGSTWEELALIHWNNEKTYKPSEKADLFVSKNDVKQTVTLKYTHLQPSDTAVYYCAYYTGETVMTNFSAGLQVTVSSATIQGPSVLPLTPGTVVVGCAINDFYPQPVTVTVNSGSLTSGVTVYPQVLLPSGLYKRNGLVTVPTSSCNTTTITFNVAHKPSDTTVNQVVNC
Sequence recovery: 0.5308056872037915


In [5]:
# we can also sample sequences conditioned on a partial structure
from copy import deepcopy
masked_coords = deepcopy(coords)
print('Masked coordinates shape:', masked_coords.shape)
masked_coords[:15] = np.inf # mask the first 15 residues
ll_fullseq, ll_withcoord = esm.inverse_folding.util.score_sequence(model, alphabet, masked_coords, native_seq)

print(f'average log-likelihood on entire sequence: {ll_fullseq:.2f} (perplexity {np.exp(-ll_fullseq):.2f})')
print(f'average log-likelihood excluding missing coordinates: {ll_withcoord:.2f} (perplexity {np.exp(-ll_withcoord):.2f})')

sampled_seq = model.sample(masked_coords, temperature=1)
print('Sampled sequence with masked coordinates:', sampled_seq)

recovery = np.mean([(a==b) for a, b in zip(native_seq, sampled_seq)])
print('Sequence recovery:', recovery)

# we can also sample sequences conditioned on a partial sequence
# masked_seq = deepcopy(native_seq)
# masked_seq[:15] = '-' # mask the first 15 residues
# ll_fullseq, ll_withseq = esm.inverse_folding.util.score_sequence(model, alphabet, coords, masked_seq)
# print(f'average log-likelihood on entire sequence: {ll_fullseq:.2f} (perplexity {np.exp(-ll_fullseq):.2f})')
# print(f'average log-likelihood excluding missing sequence: {ll_withseq:.2f} (perplexity {np.exp(-ll_withseq):.2f})')

Masked coordinates shape: (211, 3, 3)
average log-likelihood on entire sequence: -1.44 (perplexity 4.22)
average log-likelihood excluding missing coordinates: -1.33 (perplexity 3.77)
Sampled sequence with masked coordinates: GPADAPQVQLVVRQGQLSLLCEYSGFELKTAGVGIGFRWRPPGKSEEGLALILHDDTTYYNPSLRSRLAVSSNVQKKTVTLVMSEVVPQDTATYYCGFVTGTHQVTSWSPGVLVIVSKAKPTGPKVMPLKPGELTLGCSIESYWPESVTVSWNSGTTARGVTIQPSQLLPTGLYKRDGTVTVPKSRCDTVSYTCNVLHVPTATQVALLVHC
Sequence recovery: 0.45023696682464454


In [7]:
print(model.forward)

<bound method GVPTransformerModel.forward of GVPTransformerModel(
  (encoder): GVPTransformerEncoder(
    (dropout_module): Dropout(p=0.1, inplace=False)
    (embed_tokens): Embedding(35, 512, padding_idx=1)
    (embed_positions): SinusoidalPositionalEmbedding()
    (embed_gvp_input_features): Linear(in_features=15, out_features=512, bias=True)
    (embed_confidence): Linear(in_features=16, out_features=512, bias=True)
    (embed_dihedrals): DihedralFeatures(
      (node_embedding): Linear(in_features=6, out_features=512, bias=True)
      (norm_nodes): Normalize()
    )
    (gvp_encoder): GVPEncoder(
      (embed_graph): GVPGraphEmbedding(
        (embed_node): Sequential(
          (0): GVP(
            (wh): Linear(in_features=3, out_features=256, bias=False)
            (ws): Linear(in_features=263, out_features=1024, bias=True)
            (wv): Linear(in_features=256, out_features=256, bias=False)
          )
          (1): LayerNorm(
            (scalar_norm): LayerNorm((1024,), 

# Application 2: generate structure embedding

In [6]:

rep = esm.inverse_folding.util.get_encoder_output(model, alphabet, coords)
len(coords), rep.shape

(211, torch.Size([211, 512]))

# Application 3: evaluate squence fitness 

with the backbone info Y, we evulate the fitness of the sequence X by comparing target = sequence[1:] and predicted ligits based on input seqences =sequence[:-1]


In [9]:
ll_fullseq, ll_withcoord = esm.inverse_folding.util.score_sequence(model, alphabet, coords, native_seq)
print(f'average log-likelihood on entire sequence: {ll_fullseq:.2f} (perplexity {np.exp(-ll_fullseq):.2f})')
print(f'average log-likelihood excluding missing coordinates: {ll_withcoord:.2f} (perplexity {np.exp(-ll_withcoord):.2f})')

average log-likelihood on entire sequence: -1.32 (perplexity 3.75)
average log-likelihood excluding missing coordinates: -1.32 (perplexity 3.75)


In [13]:
import os
import sys
dir_path = '/home/yunyao/structural-evolution/bin'

# Add the directory containing the modules to sys.path
if dir_path not in sys.path:
    sys.path.append(dir_path)

In [17]:
# we can test the fitnesss of various sequences generated by dms

import argparse
from dms_utils import deep_mutational_scan
from pathlib import Path
import numpy as np

import esm
from util import load_structure, extract_coords_from_structure
import biotite.structure
from collections import defaultdict


def get_native_seq(pdbfile, chain):
    structure = load_structure(pdbfile, chain)
    _ , native_seq = extract_coords_from_structure(structure)
    return native_seq
    
def write_dms_lib(args):
    '''Writes a deep mutational scanning library, including the native/wildtype (wt) of the 
    indicated target chain in the structure to an output Fasta file'''

    sequence = get_native_seq(args.pdbfile, args.chain)
    Path(args.dmspath).parent.mkdir(parents=True, exist_ok=True)
    with open(args.dmspath, 'w') as f:
        f.write('>wt\n')
        f.write(sequence+'\n')
        for pos, wt, mt in deep_mutational_scan(sequence):
            assert(sequence[pos] == wt)
            mut_seq = sequence[:pos] + mt + sequence[(pos + 1):]

            f.write('>' + str(wt) + str(pos+1) + str(mt) + '\n')
            f.write(mut_seq + '\n')

pdbfile_path= 'data/7mmo.cif'
chain_id='A'
outpath='/home/yunyao/structural_evo_breakdown/dms_lib'
dmspath = f'{outpath}/7mmo_dms_lib.fasta'
write_dms_lib(
    argparse.Namespace(
        pdbfile=pdbfile_path,
        chain=chain_id,
        dmspath=dmspath,
        outpath=outpath
    )
)

In [18]:
  # now we can scan through the library and score each sequence
dms_lib = defaultdict(list)
with open(dmspath, 'r') as f:
    for line in f:
        if line.startswith('>'):
            seq_id = line.strip()[1:]  # remove '>'
        else:
            sequence = line.strip()
            dms_lib[seq_id].append(sequence)
# score each sequence in the library
scores = {}
for seq_id, sequences in dms_lib.items():
    for sequence in sequences:
        ll_fullseq, ll_withcoord = esm.inverse_folding.util.score_sequence(model, alphabet, coords, sequence)
        scores[seq_id] = (ll_fullseq, ll_withcoord)