# Inference Sample for ESM2

SPDX-FileCopyrightText: Copyright (c) <year> NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: LicenseRef-NvidiaProprietary

NVIDIA CORPORATION, its affiliates and licensors retain all intellectual property and proprietary rights in and to this material, related documentation and any modifications thereto. Any use, reproduction, disclosure or distribution of this material and related documentation without an express license agreement from NVIDIA CORPORATION or its affiliates is strictly prohibited.

### Prerequisites

Before diving in, please ensure that you have completed all steps in the [Getting Started](../../../../docs/bionemo/index.md) section.

Additionally, this notebook assumes you have started a [local inference server](https://docs.nvidia.com/bionemo-framework/latest/inference-triton-fw.html) using a pretrained [ESM-2](https://docs.nvidia.com/bionemo-framework/latest/models/esm2-nv.html) model.

The PyTriton-based model server script must be running before you execute any cell in this notebook. See `bionemo.triton.inference_wrapper` for instructions.

In [None]:
import warnings

from bionemo.triton.inference_wrapper import new_inference_wrapper

# Suppress library warnings to keep the notebook output readable.
warnings.filterwarnings('ignore')

### Setup and Test Data

`new_inference_wrapper` creates a client that communicates with the Triton model server.

Note that the batch size (the number of sequences submitted in a single inference request), and therefore the inference throughput, may be limited by the compute capacity of the node hosting the model.

In [None]:
connection = new_inference_wrapper("grpc://localhost:8001")

# Two example protein sequences (41 and 39 residues).
seqs = [
    'MSLKRKNIALIPAAGIGVRFGADKPKQYVEIGSKTVLEHVL',
    'MIQSQINRNIRLDLADAILLSKAKKDLSFAEIADGTGLA',
]
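
If the full set of sequences is too large for one request, one option is to split it into smaller batches on the client side. The following is a minimal sketch of that idea; `iter_batches` is a hypothetical helper (not part of `bionemo`), and a suitable `batch_size` depends on the hardware hosting the model.

```python
from typing import Iterator, List, Sequence


def iter_batches(sequences: Sequence[str], batch_size: int) -> Iterator[List[str]]:
    """Yield consecutive chunks of at most `batch_size` sequences.

    Hypothetical helper for illustration; not part of the bionemo API.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    for start in range(0, len(sequences), batch_size):
        yield list(sequences[start:start + batch_size])


# Placeholder sequences used only to show the chunking behavior.
example_seqs = ["MSLKRK", "MIQSQ", "MKV", "MAA", "MGS"]
batches = list(iter_batches(example_seqs, batch_size=2))
print([len(b) for b in batches])  # [2, 2, 1]
```

Each batch could then be passed to one of the client methods shown below, with the per-batch results concatenated afterwards.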

### Sequence to Hidden States

__`seqs_to_hidden`__ queries the model to fetch the encoder hidden states for each input protein sequence. A padding mask (`pad_masks` in the code below) is returned along with `hidden_states` and marks which positions hold real tokens and which are padding.

In [None]:
hidden_states, pad_masks = connection.seqs_to_hidden(seqs)
print(f"{hidden_states.shape=}")
print(f"{pad_masks.shape=}")
assert tuple(hidden_states.shape) == (2, 43, 1280)
assert tuple(pad_masks.shape) == (2, 43)
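
The asserted shape `(2, 43, 1280)` can be related to the inputs: the longest sequence has 41 residues, and the remaining two positions are consistent with a start and an end token being added to each sequence. That interpretation is an assumption based on how ESM-2 style tokenizers usually behave, not something this notebook states. The hidden size of 1280 is taken directly from the assertion above. A short sketch of the arithmetic:

```python
seqs = [
    'MSLKRKNIALIPAAGIGVRFGADKPKQYVEIGSKTVLEHVL',
    'MIQSQINRNIRLDLADAILLSKAKKDLSFAEIADGTGLA',
]

# Assumption: two special tokens (start and end) are added per sequence.
NUM_SPECIAL_TOKENS = 2
# Taken from the shape asserted in the notebook cell above.
HIDDEN_SIZE = 1280

# Shorter sequences are padded up to the longest one in the batch.
max_len = max(len(s) for s in seqs)
expected_shape = (len(seqs), max_len + NUM_SPECIAL_TOKENS, HIDDEN_SIZE)
print(expected_shape)  # (2, 43, 1280)
```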

### Hidden States to Embedding

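__`hiddens_to_embedding`__ computes one embedding vector per sequence by averaging `hidden_states` over the sequence positions. The padding mask is passed in together with the hidden states.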

In [None]:
embeddings = connection.hiddens_to_embedding(hidden_states, pad_masks)
print(f"{embeddings.shape=}")
assert tuple(embeddings.shape) == (2, 1280)
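
To illustrate the kind of averaging involved, here is a minimal NumPy sketch of masked mean pooling on toy data. It is not the server's implementation: `masked_mean` is a hypothetical function, and the mask convention (1 for real tokens, 0 for padding) and the treatment of special tokens are assumptions.

```python
import numpy as np


def masked_mean(hidden_states: np.ndarray, pad_masks: np.ndarray) -> np.ndarray:
    """Average hidden states over positions where the mask is nonzero.

    hidden_states: array of shape (batch, seq_len, hidden).
    pad_masks: array of shape (batch, seq_len); 1 marks a real token and
        0 marks padding (assumed convention).
    """
    mask = pad_masks.astype(hidden_states.dtype)[..., None]  # (batch, seq_len, 1)
    summed = (hidden_states * mask).sum(axis=1)              # (batch, hidden)
    counts = np.clip(mask.sum(axis=1), 1.0, None)            # avoid division by zero
    return summed / counts


# Toy data: 2 sequences, 5 positions, hidden size 4.
rng = np.random.default_rng(0)
toy_hidden = rng.normal(size=(2, 5, 4))
toy_mask = np.array([[1, 1, 1, 1, 1],
                     [1, 1, 1, 0, 0]])  # second sequence has 2 padded positions

toy_embeddings = masked_mean(toy_hidden, toy_mask)
print(toy_embeddings.shape)  # (2, 4)
```

Dividing by the number of unmasked positions, rather than by the padded length, keeps shorter sequences from being pulled toward zero by their padding.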

### Sequence to Embedding

__`seqs_to_embedding`__ queries the model to fetch the encoder hidden states and computes the embedding vector for each sequence in a single call.

In [None]:
embeddings = connection.seqs_to_embedding(seqs)
print(f"{embeddings.shape=}")
assert tuple(embeddings.shape) == (2, 1280)