# Inference Sample

SPDX-FileCopyrightText: Copyright (c) <year> NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: LicenseRef-NvidiaProprietary

NVIDIA CORPORATION, its affiliates and licensors retain all intellectual property and proprietary rights in and to this material, related documentation and any modifications thereto. Any use, reproduction, disclosure or distribution of this material and related documentation without an express license agreement from NVIDIA CORPORATION or its affiliates is strictly prohibited.

### Prerequisites

**Before diving in, ensure you have all [necessary prerequisites](https://docs.nvidia.com/bionemo-framework/latest/pre-reqs.html). If this is your first time using BioNeMo, we recommend following the [quickstart guide](https://docs.nvidia.com/bionemo-framework/latest/quickstart-fw.html) first.** 

Additionally, this notebook assumes you have started a [local inference server](https://docs.nvidia.com/bionemo-framework/latest/inference-triton-fw.html) using a pretrained [MegaMolBART](https://docs.nvidia.com/bionemo-framework/latest/models/megamolbart.html) model.

In [None]:
import warnings

warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')

In [None]:
from typing import List
from pathlib import Path
import os

try:
    BIONEMO_HOME: Path = Path(os.environ['BIONEMO_HOME']).absolute()
except KeyError:
    print("Must have BIONEMO_HOME set in the environment! See docs for instructions.")
    raise

config_path = BIONEMO_HOME / "examples" / "molecule" / "megamolbart" / "conf"
print(f"Using model configuration at: {config_path}")
assert config_path.is_dir()
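
The lookup above can also be wrapped in a small helper, which makes it easy to exercise without touching the real environment. This is only a sketch: `resolve_config_dir` is a hypothetical name, not part of BioNeMo, and the relative path is simply the one used in the cell above.

```python
import os
from pathlib import Path
from typing import Mapping, Optional


def resolve_config_dir(env: Optional[Mapping[str, str]] = None) -> Path:
    """Return the MegaMolBART config directory under BIONEMO_HOME.

    Hypothetical helper for illustration; not a BioNeMo API.
    """
    env = os.environ if env is None else env
    try:
        home = Path(env["BIONEMO_HOME"]).absolute()
    except KeyError as err:
        raise KeyError("BIONEMO_HOME is not set in the environment") from err
    # Same relative location as in the cell above.
    return home / "examples" / "molecule" / "megamolbart" / "conf"
```

Passing a plain dictionary such as `{"BIONEMO_HOME": "/opt/bionemo"}` lets the helper be checked in isolation; calling it with no argument reads `os.environ`.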

### Setup and Test Data

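The cells below define two test SMILES strings, load the `infer.yaml` configuration, and use `load_model_for_inference` to create a `MegaMolBARTInference` object, the adaptor used to interact with the model in the rest of this notebook.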

In [None]:
smis = [
    'c1ccc2ccccc2c1',
    'COc1cc2nc(N3CCN(C(=O)c4ccco4)CC3)nc(N)c2cc1OC',
]

In [None]:
from bionemo.utils.hydra import load_model_config

cfg = load_model_config(config_name="infer.yaml", config_path=config_path)

In [None]:
from bionemo.triton.utils import load_model_for_inference
from bionemo.model.molecule.megamolbart.infer import MegaMolBARTInference

inferer = load_model_for_inference(cfg, interactive=True)

print(f"Loaded a {type(inferer)}")
assert isinstance(inferer, MegaMolBARTInference)

### SMILES to hidden state

`seq_to_hiddens` obtains the model's latent space representation of the SMILES, returning per-token hidden states together with the padding masks for the batch.

In [None]:
hidden_states, pad_masks = inferer.seq_to_hiddens(smis)
print(f"{hidden_states.shape=}")
print(f"{pad_masks.shape=}")

assert tuple(hidden_states.shape) == (2, 45, 512)
assert tuple(pad_masks.shape) == (2, 45)
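
The asserted shapes can be read as (batch size, padded sequence length, hidden size) for `hidden_states` and (batch size, padded sequence length) for `pad_masks`. The sketch below illustrates the padding idea on plain Python lists. It assumes a character-level split purely for illustration; MegaMolBART uses its own tokenizer, and the notebook does not state whether the real mask marks tokens or padding with `True`.

```python
from typing import List, Tuple


def pad_batch(seqs: List[str], pad: str = "<PAD>") -> Tuple[List[List[str]], List[List[bool]]]:
    """Pad character-split sequences to a common length and build a mask.

    Character-level splitting is an illustrative assumption, not the
    MegaMolBART tokenizer.
    """
    tokens = [list(s) for s in seqs]
    max_len = max(len(t) for t in tokens)
    padded = [t + [pad] * (max_len - len(t)) for t in tokens]
    # True marks a real token, False marks padding (a convention chosen here).
    mask = [[i < len(t) for i in range(max_len)] for t in tokens]
    return padded, mask


smis = ["c1ccc2ccccc2c1", "COc1cc2nc(N3CCN(C(=O)c4ccco4)CC3)nc(N)c2cc1OC"]
padded, mask = pad_batch(smis)
print(len(padded), len(padded[0]))  # 2 45
print(sum(mask[0]), sum(mask[1]))   # 14 45
```

For these two inputs the character count of the longer string happens to equal the padded length of 45 asserted above, but that should not be relied on for other molecules.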

In [None]:
embeddings = inferer.hiddens_to_embedding(hidden_states, pad_masks)
print(f"{embeddings.shape=}")
assert tuple(embeddings.shape) == (2, 512)

### SMILES to Embedding

`seq_to_embeddings` queries the model to fetch the encoder embedding for the input SMILES.

In [None]:
embedding = inferer.seq_to_embeddings(smis)
print(f"{embedding.shape=}")
assert tuple(embedding.shape) == (2, 512)

Note that this is equivalent to calling `seq_to_hiddens` to produce the hidden representation, then passing the hidden states and padding masks to `hiddens_to_embedding`, as done in the previous section.
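
One common way to collapse per-token hidden states into a single vector per sequence is a masked mean over the sequence dimension. The sketch below shows that technique on plain lists. It is an assumption that `hiddens_to_embedding` pools this way; the notebook only confirms the input and output shapes.

```python
from typing import List, Sequence


def masked_mean(hiddens: Sequence[Sequence[Sequence[float]]],
                mask: Sequence[Sequence[bool]]) -> List[List[float]]:
    """Average hidden vectors over unmasked positions, one vector per sequence."""
    embeddings = []
    for seq_hiddens, seq_mask in zip(hiddens, mask):
        # Keep only the positions flagged as real tokens.
        kept = [h for h, keep in zip(seq_hiddens, seq_mask) if keep]
        if not kept:
            raise ValueError("sequence has no unmasked positions")
        dim = len(kept[0])
        embeddings.append([sum(h[d] for h in kept) / len(kept) for d in range(dim)])
    return embeddings


# Toy batch: 2 sequences, 3 positions, hidden size 2.
hiddens = [
    [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]],  # last position is padding
    [[0.0, 0.0], [2.0, 2.0], [4.0, 4.0]],
]
mask = [[True, True, False], [True, True, True]]
print(masked_mean(hiddens, mask))  # [[2.0, 3.0], [2.0, 2.0]]
```

The padded position in the first sequence is ignored, which is why its large values do not affect the result.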

### Hidden state to SMILES

`hiddens_to_seq` decodes the latent space representation back to SMILES.

In [None]:
from rdkit import Chem


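def canonicalize_smiles(smiles: str) -> str:
    """Return the RDKit canonical form of a SMILES string."""
    mol = Chem.MolFromSmiles(smiles)
    # MolFromSmiles returns None when the input cannot be parsed.
    if mol is None:
        raise ValueError(f"Could not parse SMILES: {smiles!r}")
    return Chem.MolToSmiles(mol, canonical=True)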

In [None]:
inferred_smis = inferer.hiddens_to_seq(hidden_states, pad_masks)
canon_inferred_smis = list(map(canonicalize_smiles, inferred_smis))
print(f"Reconstructed SMILES:\n{canon_inferred_smis}")
assert len(canon_inferred_smis) == 2
for i, (original, reconstructed) in enumerate(zip(smis, canon_inferred_smis)):
    assert original == reconstructed, f"Failure to reconstruct on #{i+1}: {original=}, {reconstructed=}"
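
Exact reconstruction may not hold for every molecule, so outside of a smoke test it can be more informative to report a match rate than to assert on each pair. The helper below is a hypothetical sketch, not a BioNeMo function. It takes the canonicalizer as a parameter; in this notebook that would be `canonicalize_smiles`, while the example uses `str.strip` as a stand-in so it runs without RDKit.

```python
from typing import Callable, Sequence


def reconstruction_rate(originals: Sequence[str],
                        reconstructed: Sequence[str],
                        canonicalize: Callable[[str], str]) -> float:
    """Fraction of pairs that match after canonicalizing both sides."""
    if len(originals) != len(reconstructed):
        raise ValueError("inputs must have the same length")
    if not originals:
        return 0.0
    matches = sum(
        canonicalize(a) == canonicalize(b)
        for a, b in zip(originals, reconstructed)
    )
    return matches / len(originals)


# Stand-in canonicalizer so the example runs without RDKit.
rate = reconstruction_rate(["CCO", "c1ccccc1"], ["CCO ", "CCN"], canonicalize=str.strip)
print(rate)  # 0.5
```

Canonicalizing both sides also avoids false mismatches when an input SMILES is valid but not already in canonical form.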

### Sampling: Generate SMILES

`sample` generates `num_samples` SMILES for each input sequence using the given `sampling_method`.

In [None]:
samples = inferer.sample(num_samples=3, return_embedding=False, sampling_method="greedy-perturbate", seqs=smis)
print(f"Generated {len(samples)} samples")

assert len(samples) == 2
for i, s in enumerate(samples):
    print(f"Sample #{i+1} (length: {len(s)}):\n{s}\n-----------------------")
    assert len(s) == 3
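
The assertions above imply that `samples` holds one list per input sequence, each with `num_samples` entries. A small helper along these lines pairs each input with its de-duplicated samples, which can be handy before downstream filtering. `group_samples` is a hypothetical name, not a BioNeMo function, and the sample strings in the example are made up.

```python
from typing import Dict, List, Sequence


def group_samples(inputs: Sequence[str], samples: Sequence[Sequence[str]]) -> Dict[str, List[str]]:
    """Map each input SMILES to its unique samples, preserving order."""
    if len(inputs) != len(samples):
        raise ValueError("expected one sample list per input")
    # dict.fromkeys drops duplicates while keeping first-seen order.
    return {smi: list(dict.fromkeys(per_input)) for smi, per_input in zip(inputs, samples)}


# Made-up sample strings, for illustration only.
grouped = group_samples(["CCO", "CCN"], [["CCO", "CCO", "CO"], ["CCN", "CN", "CCCN"]])
print(grouped)  # {'CCO': ['CCO', 'CO'], 'CCN': ['CCN', 'CN', 'CCCN']}
```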