# Inference Sample for ESM1nv

SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: LicenseRef-Apache2

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

### Prerequisites

Before diving in, please ensure that you have completed all steps in the [Getting Started](../../../../docs/bionemo/index.md) section.

Additionally, this notebook assumes you have started a [local inference server](https://docs.nvidia.com/bionemo-framework/latest/inference-triton-fw.html) using a pretrained [ESM-1nv](https://docs.nvidia.com/bionemo-framework/latest/models/esm1-nv.html) model.


### Setup


In [None]:
import warnings

from bionemo.triton.inference_wrapper import new_inference_wrapper

# Suppress all warnings to keep the notebook output readable.
warnings.filterwarnings('ignore')

In [None]:
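# Connect to the local inference server over gRPC.
connection = new_inference_wrapper("grpc://localhost:8001")

# Two example protein sequences, used by every cell below.
seqs = [
    'MSLKRKNIALIPAAGIGVRFGADKPKQYVEIGSKTVLEHVL',
    'MIQSQINRNIRLDLADAILLSKAKKDLSFAEIADGTGLA',
]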

### Sequence to Hidden States

__`seqs_to_hidden`__ queries the model for the encoder hidden states of the input protein sequences. A mask (`enc_mask`, stored as `pad_masks` below) is returned with the hidden states (`hiddens`) and contains padding information.

In [None]:
hidden_states, pad_masks = connection.seqs_to_hidden(seqs)
print(f"{hidden_states.shape=}")
print(f"{pad_masks.shape=}")
assert tuple(hidden_states.shape) == (2, 43, 768)
assert tuple(pad_masks.shape) == (2, 43)
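
To make "padding information" concrete, here is a minimal sketch of how a padding mask can be built for sequences of different lengths. It is an illustration only, not the server's implementation: `build_pad_mask` and its `extra_tokens` parameter are hypothetical names, and the convention that 1 marks a real token and 0 marks padding is an assumption. The `extra_tokens=2` default is likewise an assumption, picked only because the padded length of 43 asserted above is two more than the longest input (41 residues); how the real tokenizer fills those positions is not stated here.

```python
def build_pad_mask(seqs, extra_tokens=2):
    """Return a [batch][position] mask: 1 marks a real token, 0 marks padding.

    extra_tokens is a hypothetical per-sequence allowance for special tokens;
    the real tokenizer's behaviour is not specified in this notebook.
    """
    lengths = [len(seq) + extra_tokens for seq in seqs]
    padded_len = max(lengths)  # every row is padded to the longest sequence
    return [[1] * n + [0] * (padded_len - n) for n in lengths]


# Toy inputs (not the notebook's sequences): lengths 4 and 2.
toy_mask = build_pad_mask(["MSLK", "MI"])
for row in toy_mask:
    print(row)
```

In this toy batch the shorter sequence ends with two zeros, which are the positions a downstream step would want to ignore.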

### Hidden States to Embedding

__`hiddens_to_embedding`__ computes one embedding vector per sequence by averaging its hidden states (`hiddens`), taking the padding mask as a second argument.

In [None]:
embeddings = connection.hiddens_to_embedding(hidden_states, pad_masks)
print(f"{embeddings.shape=}")
assert tuple(embeddings.shape) == (2, 768)
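
The averaging step can be sketched in plain Python as a masked mean over positions. This is a sketch under stated assumptions rather than the library's actual code: `masked_mean` is a hypothetical helper, the inputs are nested lists instead of the arrays the server returns, and the mask is assumed to be truthy for real tokens and falsy for padding.

```python
def masked_mean(hiddens, mask):
    """Average hidden vectors over non-padded positions, one result per sequence.

    hiddens: nested list shaped [batch][position][feature]
    mask:    nested list shaped [batch][position]; truthy = real token
             (this polarity is an assumption made for the sketch)
    """
    embeddings = []
    for seq_hiddens, seq_mask in zip(hiddens, mask):
        # Keep only the vectors at positions the mask marks as real tokens.
        kept = [vec for vec, keep in zip(seq_hiddens, seq_mask) if keep]
        if not kept:
            raise ValueError("sequence has no unmasked positions")
        n_features = len(kept[0])
        embeddings.append(
            [sum(vec[j] for vec in kept) / len(kept) for j in range(n_features)]
        )
    return embeddings


# Toy batch: 2 sequences, 3 positions, 2 features.
# The last position of the second sequence is padding and must be ignored.
toy_hiddens = [
    [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
    [[2.0, 0.0], [4.0, 2.0], [9.0, 9.0]],
]
toy_mask = [[1, 1, 1], [1, 1, 0]]
toy_embeddings = masked_mean(toy_hiddens, toy_mask)
print(toy_embeddings)  # [[3.0, 4.0], [3.0, 1.0]]
```

The position axis is averaged away, so a `[batch][position][feature]` input becomes `[batch][feature]`, which mirrors the change from `(2, 43, 768)` to `(2, 768)` in the cell above.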

### Sequence to Embedding

__`seqs_to_embedding`__ queries the model for the encoder hidden states and computes the embedding vectors in a single call.

In [None]:
embeddings = connection.seqs_to_embedding(seqs)
print(f"{embeddings.shape=}")
assert tuple(embeddings.shape) == (2, 768)