# Inference Sample for ESM2

SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: LicenseRef-Apache2

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

### Setup

Before diving in, please ensure that you have completed all steps in the [Getting Started](../../../../docs/bionemo/index.md) section.

Additionally, this notebook assumes you have started a [local inference server](https://docs.nvidia.com/bionemo-framework/latest/inference-triton-fw.html) using a pretrained [ESM-1](https://docs.nvidia.com/bionemo-framework/latest/models/esm1-nv.html) model.


In [None]:
import warnings

warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')

In [None]:
from pathlib import Path
import os

try:
    BIONEMO_HOME: Path = Path(os.environ['BIONEMO_HOME']).absolute()
except KeyError:
    print("Must have BIONEMO_HOME set in the environment! See docs for instructions.")
    raise

config_path = BIONEMO_HOME / "examples" / "protein" / "esm2nv" / "conf"
print(f"Using model configuration at: {config_path}")
assert config_path.is_dir()

### Setup and Test Data

In [None]:
seqs = [
    'MSLKRKNIALIPAAGIGVRFGADKPKQYVEIGSKTVLEHVL', 
    'MIQSQINRNIRLDLADAILLSKAKKDLSFAEIADGTGLA',
]

In [None]:
from bionemo.utils.hydra import load_model_config

cfg = load_model_config(config_name="infer.yaml", config_path=config_path)

In [None]:
from bionemo.triton.utils import load_model_for_inference
from bionemo.model.protein.esm1nv.infer import ESM1nvInference

inferer = load_model_for_inference(cfg, interactive=True)

print(f"Loaded a {type(inferer)}")
# The esm2nv configuration is expected to load the shared ESM1nvInference class.
assert isinstance(inferer, ESM1nvInference)

### Sequence to Hidden States

__`seq_to_hiddens`__ queries the model for the encoder hidden states of each input protein sequence. It returns `hidden_states` together with `pad_masks`, a mask that records which positions are padding.

In [None]:
hidden_states, pad_masks = inferer.seq_to_hiddens(seqs)
print(f"{hidden_states.shape=}")
print(f"{pad_masks.shape=}")
assert tuple(hidden_states.shape) == (2, 43, 1280)
assert tuple(pad_masks.shape) == (2, 43)
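
The second dimension of 43 asserted above is longer than either input sequence (41 and 39 residues). A plausible reading, stated here as an assumption rather than confirmed from the library, is that the tokenizer adds two special tokens per sequence and the batch is padded to the longest tokenized length. The sketch below only reproduces that arithmetic and builds a toy mask; it does not load a model, and the real `pad_masks` may follow a different convention for special tokens.

```python
# Hypothetical illustration: no model is loaded here.
seqs = [
    "MSLKRKNIALIPAAGIGVRFGADKPKQYVEIGSKTVLEHVL",
    "MIQSQINRNIRLDLADAILLSKAKKDLSFAEIADGTGLA",
]

# Assumption: one start token and one end token are added to each sequence.
NUM_SPECIAL_TOKENS = 2

# Tokenized length of each sequence under the assumption above.
token_lengths = [len(s) + NUM_SPECIAL_TOKENS for s in seqs]

# The batch is padded to the longest tokenized sequence.
padded_length = max(token_lengths)

# Toy padding mask: True for real positions, False for padding.
toy_pad_masks = [
    [pos < n for pos in range(padded_length)] for n in token_lengths
]

print(token_lengths)  # [43, 41]
print(padded_length)  # 43
```

Under this assumption the shorter sequence ends with two padded positions, which is the kind of information `pad_masks` carries into the pooling step.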

### Hidden States to Embedding

__`hiddens_to_embedding`__ computes one embedding vector per sequence by averaging `hidden_states` over the sequence positions, using `pad_masks` to account for padding.

In [None]:
embeddings = inferer.hiddens_to_embedding(hidden_states, pad_masks)
print(f"{embeddings.shape=}")
assert tuple(embeddings.shape) == (2, 1280)
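
To make the averaging concrete, here is a minimal pure-Python sketch of mask-aware mean pooling on toy data. The function name `masked_mean` and the toy values are hypothetical, and the sketch assumes that a `True` mask entry marks a position to keep; the library's own implementation operates on tensors and may differ in detail.

```python
def masked_mean(hidden_states, pad_masks):
    """Average each sequence's hidden vectors over positions where the mask is True.

    Assumes every sequence has at least one unmasked position.
    """
    embeddings = []
    for vectors, mask in zip(hidden_states, pad_masks):
        # Keep only the vectors at non-padded positions.
        kept = [vec for vec, keep in zip(vectors, mask) if keep]
        dim = len(vectors[0])
        embeddings.append(
            [sum(vec[d] for vec in kept) / len(kept) for d in range(dim)]
        )
    return embeddings


# Toy batch: 2 sequences, 3 positions, hidden size 2.
toy_hidden = [
    [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
    [[2.0, 0.0], [4.0, 2.0], [100.0, 100.0]],  # last position is padding
]
toy_mask = [
    [True, True, True],
    [True, True, False],
]

toy_embeddings = masked_mean(toy_hidden, toy_mask)
print(toy_embeddings)  # [[3.0, 4.0], [3.0, 1.0]]
```

The padded vector `[100.0, 100.0]` does not affect the second embedding, which is why the mask is passed alongside the hidden states.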

### Sequence to Embedding

__`seq_to_embeddings`__ queries the model for the encoder hidden states and computes the embedding vectors in a single call.

In [None]:
embeddings = inferer.seq_to_embeddings(seqs)
print(f"{embeddings.shape=}")
assert tuple(embeddings.shape) == (2, 1280)
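
Since the one-step call is described as doing the same work as the two-step path, one way to sanity-check a run is to compare both results element-wise within a tolerance. The helper below is a hypothetical sketch that works on plain nested lists; `embeddings_match` and the toy values are not part of the library, and the tolerances are arbitrary choices.

```python
import math


def embeddings_match(a, b, rel_tol=1e-5, abs_tol=1e-6):
    """Return True if two lists of vectors have the same shape and agree within tolerance."""
    if len(a) != len(b):
        return False
    for row_a, row_b in zip(a, b):
        if len(row_a) != len(row_b):
            return False
        if not all(
            math.isclose(x, y, rel_tol=rel_tol, abs_tol=abs_tol)
            for x, y in zip(row_a, row_b)
        ):
            return False
    return True


# Toy stand-ins for the two-step and one-step results.
two_step = [[0.5, 1.0], [2.0, 3.0]]
one_step = [[0.5, 1.0 + 1e-9], [2.0, 3.0]]

print(embeddings_match(two_step, one_step))  # True
```

In the notebook, the same comparison could be applied to the real outputs after converting them to nested lists, assuming the returned objects support such a conversion (for example a `.tolist()` method).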