# Inference Sample

SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: LicenseRef-Apache2

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

### Setup

Before diving in, please ensure that you have completed all steps in the [Getting Started](../../../../docs/bionemo/index.md) section.

Additionally, this notebook assumes you have started a [local inference server](https://docs.nvidia.com/bionemo-framework/latest/inference-triton-fw.html) using a pretrained [MegaMolBART](https://docs.nvidia.com/bionemo-framework/latest/models/megamolbart.html) model.

In [None]:
import warnings

from rdkit import Chem

from bionemo.triton.inference_wrapper import new_inference_wrapper

# Silence all warnings to keep the notebook output readable.
warnings.filterwarnings('ignore')

### Client and Test Data

`new_inference_wrapper` creates a client that communicates with the Triton model server. The list `smis` holds the two SMILES strings used as test input throughout this notebook.

In [None]:
connection = new_inference_wrapper("grpc://localhost:8001")

smis = [
    'c1ccc2ccccc2c1',
    'COc1cc2nc(N3CCN(C(=O)c4ccco4)CC3)nc(N)c2cc1OC',
]
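
If later calls fail or hang, a quick check that the server port accepts connections can narrow down the cause. The following is a minimal sketch using only the Python standard library; `server_is_reachable` is a hypothetical helper, not part of BioNeMo, and a successful TCP connection does not confirm that the model itself is loaded.

```python
import socket
from urllib.parse import urlparse


def server_is_reachable(url: str, timeout: float = 2.0) -> bool:
    """Return True if a TCP connection to the URL's host and port succeeds.

    Hypothetical helper: it only tests the socket, not the Triton model state.
    """
    parsed = urlparse(url)
    if parsed.hostname is None or parsed.port is None:
        raise ValueError(f"expected host and port in URL, got {url!r}")
    try:
        with socket.create_connection((parsed.hostname, parsed.port), timeout=timeout):
            return True
    except OSError:
        # Connection refused, timed out, or host not resolvable.
        return False


print(server_is_reachable("grpc://localhost:8001"))
```

A result of `False` suggests the inference server is not running or is listening on a different port.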

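### SMILES to Hidden State

`seqs_to_hidden` queries the model for the latent-space representation of each SMILES string. It returns the per-token hidden states together with the matching padding masks; for the two test molecules the expected shapes are `(2, 45, 512)` and `(2, 45)`.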

In [None]:
hidden_states, pad_masks = connection.seqs_to_hidden(smis)
print(f"{hidden_states.shape=}")
print(f"{pad_masks.shape=}")

assert tuple(hidden_states.shape) == (2, 45, 512)
assert tuple(pad_masks.shape) == (2, 45)

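### Hidden States to Embeddings

`hiddens_to_embedding` converts the per-token hidden states and their padding masks into one fixed-size embedding per input sequence.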

In [None]:
embeddings = connection.hiddens_to_embedding(hidden_states, pad_masks)
print(f"{embeddings.shape=}")
assert tuple(embeddings.shape) == (2, 512)
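
One common way to reduce per-token hidden states to a single vector per sequence is a masked mean: average the token vectors while skipping padded positions. The sketch below illustrates that idea on a toy batch in plain Python. Whether `hiddens_to_embedding` uses exactly this pooling, and whether the mask marks real tokens with truthy values, are assumptions; `masked_mean` is a hypothetical helper for illustration only.

```python
from typing import List, Sequence


def masked_mean(hidden: Sequence[Sequence[Sequence[float]]],
                mask: Sequence[Sequence[bool]]) -> List[List[float]]:
    """Average token vectors per sequence, skipping padded positions.

    Assumption: mask[i][t] is truthy for real tokens and falsy for padding.
    """
    pooled = []
    for seq_vectors, seq_mask in zip(hidden, mask):
        kept = [vec for vec, keep in zip(seq_vectors, seq_mask) if keep]
        if not kept:
            raise ValueError("sequence has no unmasked tokens")
        dim = len(kept[0])
        pooled.append([sum(vec[d] for vec in kept) / len(kept) for d in range(dim)])
    return pooled


# Toy batch: 2 sequences, 3 token positions, hidden size 2.
toy_hidden = [
    [[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]],  # last position is padding
    [[2.0, 0.0], [4.0, 2.0], [6.0, 4.0]],
]
toy_mask = [
    [True, True, False],
    [True, True, True],
]
toy_embeddings = masked_mean(toy_hidden, toy_mask)
print(toy_embeddings)
```

The toy input has shape `(2, 3, 2)` and the pooled output has shape `(2, 2)`, mirroring the reduction from `(2, 45, 512)` to `(2, 512)` in the cell above.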

### SMILES to Embedding

`seqs_to_embedding` queries the model to fetch the encoder embedding for the input SMILES.

In [None]:
embedding = connection.seqs_to_embedding(smis)
print(f"{embedding.shape=}")
assert tuple(embedding.shape) == (2, 512)

### Hidden State to SMILES

`hidden_to_seqs` decodes the latent space representation back to SMILES.

In [None]:
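from typing import Optional

def canonicalize_smiles(smiles: str) -> Optional[str]:
    """Return the canonical form of `smiles`, or None if RDKit cannot parse it."""
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # Decoded SMILES are not guaranteed to be valid molecules.
        return None
    return Chem.MolToSmiles(mol, canonical=True)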

In [None]:
inferred_smis = connection.hidden_to_seqs(hidden_states, pad_masks)
canon_inferred_smis = list(map(canonicalize_smiles, inferred_smis))
print(f"Reconstructed SMILES:\n{canon_inferred_smis}")
assert len(canon_inferred_smis) == 2
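
To judge how well the round trip preserves the inputs, the canonical input SMILES can be compared with the canonical reconstructions. The sketch below computes the fraction of exact matches; `exact_match_rate` is a hypothetical helper, the strings are toy values rather than model output, and a `None` entry stands for a reconstruction that could not be parsed.

```python
from typing import Optional, Sequence


def exact_match_rate(expected: Sequence[str],
                     reconstructed: Sequence[Optional[str]]) -> float:
    """Return the fraction of positions where both SMILES strings are identical.

    Both inputs should already be canonicalized; None never matches.
    """
    if len(expected) != len(reconstructed):
        raise ValueError("both sequences must have the same length")
    if not expected:
        return 0.0
    matches = sum(e == r for e, r in zip(expected, reconstructed))
    return matches / len(expected)


# Toy values for illustration only (not actual model output).
toy_expected = ["c1ccc2ccccc2c1", "CCO"]
toy_reconstructed = ["c1ccc2ccccc2c1", None]
print(exact_match_rate(toy_expected, toy_reconstructed))
```

Exact string comparison is only meaningful when both sides went through the same canonicalization.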

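### Sampling: Generate SMILES

`sample_seqs` generates new SMILES using the input SMILES as starting points. It returns one list of samples per input; in this example each list is expected to contain a single sample.
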
In [None]:
samples = connection.sample_seqs(smis)
print(f"Generated {len(samples)} samples")
assert len(samples) == 2
for i, s in enumerate(samples, start=1):
    print(f"Sample #{i} (length: {len(s)}):\n{s}\n-----------------------")
    assert len(s) == 1