# Inference Sample for ESM2 - Logit Sampling Scan

## Modified for logit/proba extraction
Adrian Lange, A-Alpha Bio

---------------

SPDX-FileCopyrightText: Copyright (c) <year> NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: LicenseRef-NvidiaProprietary

NVIDIA CORPORATION, its affiliates and licensors retain all intellectual property and proprietary rights in and to this material, related documentation and any modifications thereto. Any use, reproduction, disclosure or distribution of this material and related documentation without an express license agreement from NVIDIA CORPORATION or its affiliates is strictly prohibited.

### Prerequisite

- Linux OS
- Pascal, Volta, Turing, or an NVIDIA Ampere architecture-based GPU.
- NVIDIA Driver
- Docker

#### Import

Components for inferencing are part of the BioNeMo ESM source code. This notebook demonstrates the use of these components.

__`ESMInferenceWrapper`__ implements the __`seq_to_embedding`__ function, which returns encoder embeddings for an input protein sequence given as text.


To run this notebook, launch the gRPC inference service in your terminal beforehand.

The following command is an example of launching the gRPC inference service with the esm2-650M checkpoint:

```
python3 -m bionemo.model.protein.esm1nv.grpc.service --model esm2nv_650M
```

Note that gRPC limits request size to 4MB.
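
One way to stay under that limit is to split a long list of sequences into several requests. The sketch below is a hypothetical helper (`chunk_by_payload` is not part of BioNeMo) that groups sequences by their raw UTF-8 size. It ignores gRPC and protobuf framing overhead, so some headroom is advisable in practice.

```python
from typing import Iterable, Iterator, List

GRPC_MAX_BYTES = 4 * 1024 * 1024  # the 4 MB request limit noted above


def chunk_by_payload(seqs: Iterable[str], max_bytes: int = GRPC_MAX_BYTES) -> Iterator[List[str]]:
    """Yield lists of sequences whose combined UTF-8 size stays within max_bytes.

    Hypothetical helper: only the raw sequence bytes are counted, not the
    message framing added by gRPC/protobuf.
    """
    batch: List[str] = []
    batch_bytes = 0
    for seq in seqs:
        n = len(seq.encode("utf-8"))
        if n > max_bytes:
            raise ValueError(f"single sequence of {n} bytes exceeds the limit")
        # Start a new batch when adding this sequence would cross the limit.
        if batch and batch_bytes + n > max_bytes:
            yield batch
            batch, batch_bytes = [], 0
        batch.append(seq)
        batch_bytes += n
    if batch:
        yield batch


# Tiny limit so the splitting is visible on toy data.
toy_seqs = ["MSLKRK", "MIQSQ", "ACDEFGHIK"]
batches = list(chunk_by_payload(toy_seqs, max_bytes=12))
print(batches)
```

Each yielded list can then be sent as its own request.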


In [1]:
import warnings

warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')

In [2]:
import numpy as np
import torch
from torch.cuda.amp import autocast

# For an optional visualization
import matplotlib.pyplot as plt

In [3]:
from pathlib import Path
import os

try:
    BIONEMO_HOME: Path = Path(os.environ['BIONEMO_HOME']).absolute()
except KeyError:
    print("Must have BIONEMO_HOME set in the environment! See docs for instructions.")
    raise

config_path = BIONEMO_HOME / "my_workspace" / "esm2nv_logits" / "conf"
print(f"Using model configuration at: {config_path}")
assert config_path.is_dir()

Using model configuration at: /workspace/bionemo/my_workspace/esm2nv_logits/conf


### Setup and Test Data

In [4]:
seqs = [
    'MSLKRKNIALIPAAGIGVRFGADKPKQYVEIGSKTVLEHVL', 
    'MIQSQINRNIRLDLADAILLSKAKKDLSFAEIADGTGLA',
]

In [5]:
from bionemo.triton.utils import load_model_config

cfg = load_model_config(config_path, config_name="infer.yaml")

NOTE! Installing ujson may make loading annotations faster.


INFO:rdkit:Enabling RDKit 2023.09.1 jupyter extensions


In [6]:
from bionemo.triton.utils import load_model_for_inference
from bionemo.model.protein.esm1nv.infer import ESM1nvInference

inferer = load_model_for_inference(cfg, interactive=True)

print(f"Loaded a {type(inferer)}")
assert isinstance(inferer, ESM1nvInference)

[NeMo I 2024-04-15 20:50:43 utils:426] pytorch DDP is not initialized. Initializing with pytorch-lightening...




MisconfigurationException: `CUDAAccelerator` can not run on your system since the accelerator is not available. The following accelerator(s) is available and can be passed into `accelerator` argument of `Trainer`: ['cpu'].

### Turn off post_process

After loading, we switch off `post_process` so that the inferer still returns hidden states, as the rest of this notebook expects.

In [None]:
inferer.model.model.post_process = False

### Sequence to Hidden States

__`seq_to_hiddens`__ queries the model for the encoder hidden states of the input protein sequences. `pad_masks` is returned alongside `hidden_states` and contains the padding information.

In [None]:
hidden_states, pad_masks = inferer.seq_to_hiddens(seqs)
print(f"{hidden_states.shape=}")
print(f"{pad_masks.shape=}")
assert tuple(hidden_states.shape) == (2, 43, 2560)  # ESM2nv has 2560 dimensions
assert tuple(pad_masks.shape) == (2, 43)

In [None]:
# Check pad_masks
pad_masks
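
To see where the `(2, 43)` shape comes from, here is a minimal sketch that rebuilds a mask with the same layout from the two test sequences. It assumes one start and one end token per sequence and `True` only at residue positions. The real mask comes from `seq_to_hiddens`, and `toy_seqs`, `n_special`, and `toy_pad_masks` are illustrative names.

```python
import numpy as np

toy_seqs = [
    "MSLKRKNIALIPAAGIGVRFGADKPKQYVEIGSKTVLEHVL",
    "MIQSQINRNIRLDLADAILLSKAKKDLSFAEIADGTGLA",
]

# Assumption: one start token and one end token wrap each sequence.
n_special = 2
padded_len = max(len(s) for s in toy_seqs) + n_special

# Toy mask: True at residue positions, False at start/end tokens and padding.
toy_pad_masks = np.zeros((len(toy_seqs), padded_len), dtype=bool)
for i, s in enumerate(toy_seqs):
    toy_pad_masks[i, 1 : 1 + len(s)] = True

print(toy_pad_masks.shape)        # batch size x padded length
print(toy_pad_masks.sum(axis=1))  # number of residues kept per sequence
```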

### Language model head

Helpers for working with the ESM2nv BERT LM head.

The tokenizer's class path is `nemo.collections.common.tokenizers.huggingface.auto_tokenizer.AutoTokenizer`.

Its source code should be available in the BioNeMo Docker image at:
```
/usr/local/lib/python3.10/dist-packages/nemo/collections/common/tokenizers/huggingface/auto_tokenizer.py
```

In [None]:
# Check we have the lm_head
inferer.model.model.lm_head

In [None]:
# Check vocabulary
inferer.tokenizer.vocab

In [None]:
def hidden_states_to_logits(hidden_states):
    """Take hidden_states from inferer.seq_to_hiddens(seqs) and apply the LM head.
    
    hidden_states.shape: <n_batch, n_max_seq_length_for_batch, 2560 hidden state dimensions>
    
    The LM head output has a fixed last dimension of 128, but ESM2nv only uses a vocabulary of 33 tokens.
    So, we keep only the relevant vocabulary classes.
    """
    with autocast(enabled=inferer.model.enable_autocast):
        lm_out = inferer.model.model.lm_head(hidden_states, inferer.model.model.word_embeddings_weight())
    logits = lm_out[:, :, :inferer.tokenizer.vocab_size]
    return logits
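
The shape bookkeeping in this helper can be illustrated on toy arrays. The sketch below is a simplification: it treats the head as a single projection onto a tied embedding matrix, whereas the actual `lm_head` may apply further layers first. All sizes and names (`toy_hidden`, `toy_embedding`) are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

n_batch, n_tokens, n_hidden = 2, 5, 16  # toy sizes, not the real model dimensions
padded_vocab, vocab_size = 128, 33      # padded output width vs. real vocabulary

toy_hidden = rng.normal(size=(n_batch, n_tokens, n_hidden))
# Stand-in for the word embedding weights that the head reuses.
toy_embedding = rng.normal(size=(padded_vocab, n_hidden))

# Simplified head: project hidden states onto every (padded) vocabulary row.
toy_lm_out = toy_hidden @ toy_embedding.T   # (n_batch, n_tokens, 128)
# Keep only the columns that correspond to real tokens.
toy_logits = toy_lm_out[:, :, :vocab_size]  # (n_batch, n_tokens, 33)

print(toy_lm_out.shape, toy_logits.shape)
```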

### Example 1a: Proba heatmap

1. Get logits for a batch of seqs
2. Pick a seq of interest
3. Apply the padding mask, which also drops the SOS and EOS tokens
4. Apply softmax to convert logits to probabilities
5. Plot the probabilities as a heatmap

In [None]:
batch_logits = hidden_states_to_logits(hidden_states)

i_seq_of_interest = 1
logits = batch_logits[i_seq_of_interest]
logits = logits[pad_masks[i_seq_of_interest]]

logits.shape

In [None]:
probas = torch.softmax(logits, dim=-1).detach().cpu().numpy()
plt.matshow(probas.T)
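
For reference, the mask-then-softmax steps look like this on a toy array, using a NumPy softmax written for the example instead of `torch.softmax`. The values are invented and the mask layout (start and end token excluded) is an assumption.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable softmax: subtract the max before exponentiating."""
    shifted = x - x.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


toy_logits = np.array([
    [0.0, 0.0, 0.0],  # start token (masked out)
    [2.0, 1.0, 0.0],
    [0.0, 3.0, 0.0],
    [0.0, 0.0, 0.0],  # end token (masked out)
])
toy_mask = np.array([False, True, True, False])

# Boolean indexing keeps only the rows where the mask is True.
toy_probas = softmax(toy_logits[toy_mask])

print(toy_probas.shape)
print(toy_probas.sum(axis=-1))  # each row sums to 1
```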

### Example 1b: Argmax

1. Translate the logits/probas back into sequence space via argmax

In [None]:
# Argmax and compare prediction/reconstruction to true/input
pred_idx_list = np.argmax(probas, axis=-1).tolist()
true_idx_list = inferer.tokenizer.text_to_ids(seqs[i_seq_of_interest])

pred_seq = inferer.tokenizer.ids_to_text(pred_idx_list).replace(" ", "")
true_seq = inferer.tokenizer.ids_to_text(true_idx_list).replace(" ", "")

display(pred_seq)
display(
    "".join(
        ["." if a == b else "|" for a, b in zip(pred_seq, true_seq)]
    )
)
display(true_seq)

The reconstruction is essentially perfect for NVIDIA's short peptide examples, which is easy when nothing is masked.
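
The same argmax-and-compare logic, reduced to a toy four-letter vocabulary so the output is easy to check by hand. `toy_vocab` and the probabilities are invented, and the identity fraction is an extra summary that the cells above do not compute.

```python
import numpy as np

toy_vocab = ["A", "C", "D", "E"]  # hypothetical vocabulary
toy_probas = np.array([
    [0.7, 0.1, 0.1, 0.1],
    [0.1, 0.2, 0.6, 0.1],
    [0.4, 0.3, 0.2, 0.1],
])
toy_true_seq = "ADC"

# Most probable token at each position, mapped back to letters.
toy_pred_seq = "".join(toy_vocab[i] for i in toy_probas.argmax(axis=-1))

# "." marks agreement with the input, "|" marks a mismatch.
match_line = "".join("." if a == b else "|" for a, b in zip(toy_pred_seq, toy_true_seq))
identity = match_line.count(".") / len(toy_true_seq)

print(toy_pred_seq)
print(match_line)
print(toy_true_seq)
print(f"identity = {identity:.2f}")
```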

### Example 1c: Compute log-proba of the sequence

The reconstruction may not always be perfect, but we do know the exact input, so we pick out the probabilities the model assigns to the input tokens.

1. Extract the input probas
2. Log transform
3. Sum

In [None]:
rows = np.arange(len(true_idx_list))
cols = np.asarray(true_idx_list)
seq_log_proba = np.log(probas[rows, cols]).sum()
seq_log_proba
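
When probabilities get very small, taking `log` after `softmax` can underflow. A common alternative is to compute log-probabilities directly with a log-softmax. The sketch below shows this on invented logits and checks that both routes agree. `log_softmax` here is a local NumPy helper, not a BioNeMo function.

```python
import numpy as np


def log_softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Stable log-softmax: log(exp(x) / sum(exp(x))) without the intermediate division."""
    shifted = x - x.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))


toy_logits = np.array([
    [2.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 3.0],
])
toy_true_idx = np.array([0, 1, 2])  # token id of the input at each position

toy_log_probas = log_softmax(toy_logits)
toy_rows = np.arange(len(toy_true_idx))

# Sum of per-position log-probabilities of the input tokens.
toy_seq_log_proba = toy_log_probas[toy_rows, toy_true_idx].sum()

# Same quantity via softmax followed by log, as in the cell above.
toy_probas = np.exp(toy_log_probas)
toy_seq_log_proba_ref = np.log(toy_probas[toy_rows, toy_true_idx]).sum()

print(toy_seq_log_proba, toy_seq_log_proba_ref)
```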

### Example 2: Mask some individual positions

1. Make some masks
2. Run the whole shebang to get probas of the masked token(s)

In [None]:
MASK_TOKEN = "<mask>"

seqs = [
    "MIQSQINRNIRLDLADAILLSKAKKDLSFAEIADGTGLA",  # original unmasked sequence
    "MIQ" + MASK_TOKEN + "QINRNIRLDLADAILLSKAKKDLSFAEIADGTGLA",
    "MIQSQINRNIRLDLADAILLSKAKKDLSFAEIADGT" + MASK_TOKEN + "LA"
]

hidden_states, pad_masks = inferer.seq_to_hiddens(seqs)
batch_logits = hidden_states_to_logits(hidden_states)

In [None]:
i_seq_of_interest = 1
logits = batch_logits[i_seq_of_interest]
logits = logits[pad_masks[i_seq_of_interest]]

probas = torch.softmax(logits, dim=-1).detach().cpu().numpy()
probas[3]  # position of mask

In [None]:
tmp_sort = sorted(enumerate(probas[3]), key=lambda x: x[1], reverse=True)
tmp = [(inferer.tokenizer.vocab[i], i, p) for i, p in tmp_sort]
tmp

The model ranks the original amino acid (S) only 6th most probable at this position.

In [None]:
i_seq_of_interest = 2
logits = batch_logits[i_seq_of_interest]
logits = logits[pad_masks[i_seq_of_interest]]

probas = torch.softmax(logits, dim=-1).detach().cpu().numpy()
probas[-3]  # position of mask

In [None]:
tmp_sort = sorted(enumerate(probas[-3]), key=lambda x: x[1], reverse=True)
tmp = [(inferer.tokenizer.vocab[i], i, p) for i, p in tmp_sort]
tmp

Here the model ranks the original amino acid (G) as the most probable.
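
Reading the rank off a sorted list by eye gets tedious with more masks. As a minimal sketch, the rank of the original residue can be computed directly with `np.argsort`. The vocabulary and probabilities below are invented for illustration.

```python
import numpy as np

toy_vocab = ["L", "A", "G", "S", "V"]             # hypothetical vocabulary
toy_p = np.array([0.30, 0.25, 0.20, 0.15, 0.10])  # probabilities at one masked position

# Token ids ordered from most to least probable.
order = np.argsort(toy_p)[::-1]
top3 = [(toy_vocab[i], float(toy_p[i])) for i in order[:3]]

# 1-based rank of the residue that was masked out.
original = "S"
rank = int(np.where(order == toy_vocab.index(original))[0][0]) + 1

print(top3)
print(f"{original} is ranked {rank} of {len(toy_vocab)}")
```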

### Example 3: In-paint a bunch of masked tokens all in one go

1. Mask a short contiguous part of a seq
2. Unmask/in-paint all at once by taking the argmax
3. Unmask/in-paint by sampling probabilities

In [None]:
MASK_TOKEN = "<mask>"

before = "M"
n_contig = 8
contig = "".join([MASK_TOKEN] * n_contig)
after = "IRLDLADAILLSKAKKDLSFAEIADGTGLA"

seqs = [
    "MIQSQINRNIRLDLADAILLSKAKKDLSFAEIADGTGLA",  # original unmasked sequence
    before + contig + after
]

hidden_states, pad_masks = inferer.seq_to_hiddens(seqs)
batch_logits = hidden_states_to_logits(hidden_states)
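
The `before` / `contig` / `after` construction can also be written as a small helper that masks any span of a sequence. `mask_span` is a hypothetical convenience function, not part of BioNeMo. With `start=1` and `length=8` it reproduces the masked sequence built above.

```python
MASK_TOKEN = "<mask>"


def mask_span(seq: str, start: int, length: int, mask_token: str = MASK_TOKEN) -> str:
    """Replace seq[start:start + length] with one mask token per residue."""
    if start < 0 or length < 0 or start + length > len(seq):
        raise ValueError("span falls outside the sequence")
    return seq[:start] + mask_token * length + seq[start + length:]


toy_original = "MIQSQINRNIRLDLADAILLSKAKKDLSFAEIADGTGLA"
toy_masked = mask_span(toy_original, start=1, length=8)
print(toy_masked)
```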

In [None]:
i_seq_of_interest = 1
logits = batch_logits[i_seq_of_interest]
logits = logits[pad_masks[i_seq_of_interest]]
probas = torch.softmax(logits, dim=-1).detach().cpu().numpy()

In [None]:
# argmax in-paint: only taken over the masked contig

pred_idx_list = np.argmax(probas[len(before):len(before) + n_contig, :], axis=-1).tolist()
in_paint = inferer.tokenizer.ids_to_text(pred_idx_list).replace(" ", "")
pred_seq = before + in_paint + after

orig_seq = seqs[0]

display(pred_seq)
display(
    "".join(
        ["." if a == b else "|" for a, b in zip(pred_seq, orig_seq)]
    )
)
display(orig_seq)

In [None]:
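# Scratch check: an unseeded draw of one token id per masked position.
# The seeded cell below repeats this and decodes the result.
scratch_idx_list = []
for p in probas[len(before):len(before) + n_contig, :]:
    i = np.random.choice(np.arange(len(p)), p=p)
    scratch_idx_list.append(i)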

In [None]:
# sample in-paint: only taken over the masked contig

np.random.seed(88888)

pred_idx_list = []
for p in probas[len(before):len(before) + n_contig, :]:
    i = np.random.choice(np.arange(len(p)), p=p)
    pred_idx_list.append(i)

in_paint = inferer.tokenizer.ids_to_text(pred_idx_list).replace(" ", "")
pred_seq = before + in_paint + after

orig_seq = seqs[0]

display(pred_seq)
display(
    "".join(
        ["." if a == b else "|" for a, b in zip(pred_seq, orig_seq)]
    )
)
display(orig_seq)
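
Seeding the global NumPy state works, but any earlier call that consumes random numbers changes what comes next. One alternative, sketched here on invented probabilities, is to give each sampling run its own `np.random.default_rng` generator. `sample_tokens` is an illustrative helper name.

```python
import numpy as np


def sample_tokens(probas: np.ndarray, seed: int) -> list:
    """Draw one token id per row of probas using a local, seeded generator."""
    rng = np.random.default_rng(seed)
    return [int(rng.choice(len(p), p=p)) for p in probas]


toy_probas = np.array([
    [0.8, 0.1, 0.1],
    [0.1, 0.8, 0.1],
    [0.0, 0.0, 1.0],  # deterministic position: always token 2
])

draw_a = sample_tokens(toy_probas, seed=88888)
draw_b = sample_tokens(toy_probas, seed=88888)

print(draw_a)
print(draw_a == draw_b)  # same seed, same draw
```

Repeating the call with different seeds gives a set of alternative in-paints for the same masked region.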