# Inference Sample

Copyright (c) 2022, NVIDIA CORPORATION. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0 

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

### Prerequisites
* Linux OS
* An NVIDIA GPU based on the Pascal, Volta, Turing, or Ampere architecture
* NVIDIA driver
* Docker

### Import
The inference components are part of the NeMo-MegaMolBART source code. This notebook demonstrates how to use them.

`MegaMolBARTInferer` implements the following functions:
* `smis_to_hidden`
* `smis_to_embedding`
* `hidden_to_smis`

In [1]:
import logging
import warnings

from infer import InferenceWrapper

# Suppress library warnings to keep the notebook output readable.
warnings.filterwarnings('ignore')

### Setup and Test Data

`InferenceWrapper` is an adapter for interacting with the inference service.

In [2]:
connection = InferenceWrapper()

smis = ['c1cc2ccccc2cc1',
        'COc1cc2nc(N3CCN(C(=O)c4ccco4)CC3)nc(N)c2cc1OC']


### SMILES to hidden state

`smis_to_hidden` queries the model for the latent-space representation of the input SMILES. It returns the hidden states together with the matching padding masks.

In [3]:
hidden_states, pad_masks = connection.smis_to_hidden(smis)
hidden_states.shape, pad_masks.shape

(torch.Size([2, 45, 512]), torch.Size([2, 45]))
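Both outputs share the second dimension (45 here), which suggests that the SMILES in a batch are padded to a common token length and that the mask records which positions hold real tokens. The sketch below illustrates that idea in plain Python. The helper `pad_batch`, the pad id `0`, the example token ids, and the convention that `True` marks a real token are all assumptions made for illustration. It is not the wrapper's actual implementation.

```python
def pad_batch(token_ids, pad_id=0):
    """Pad variable-length token id lists to a common length.

    Returns (padded, mask), where mask[i][j] is True for a real token
    and False for padding (assumed convention, see the text above).
    """
    max_len = max(len(seq) for seq in token_ids)
    padded, mask = [], []
    for seq in token_ids:
        n_pad = max_len - len(seq)
        padded.append(list(seq) + [pad_id] * n_pad)
        mask.append([True] * len(seq) + [False] * n_pad)
    return padded, mask


# Two hypothetical tokenized molecules of different lengths.
padded, mask = pad_batch([[5, 6, 7], [8, 9]])
print(padded)  # [[5, 6, 7], [8, 9, 0]]
print(mask)    # [[True, True, True], [True, True, False]]
```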

### SMILES to Embedding

`smis_to_embedding` queries the model to fetch the encoder embedding for the input SMILES.

In [4]:
embedding = connection.smis_to_embedding(smis)
embedding.shape

torch.Size([2, 512])
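The result holds one 512-dimensional vector per molecule, so the per-token axis has been reduced away. A common way to do this is a masked mean over the token axis. Whether `smis_to_embedding` uses exactly this reduction is an assumption. A minimal plain-Python sketch, using a hypothetical helper `masked_mean` and assuming that `True` in the mask marks a real token:

```python
def masked_mean(hidden, mask):
    """Average token vectors per sequence, ignoring padded positions.

    hidden: nested lists shaped [batch][seq_len][dim]
    mask:   nested lists shaped [batch][seq_len], True for real tokens
    """
    pooled = []
    for seq_vectors, seq_mask in zip(hidden, mask):
        # Keep only the vectors at unmasked (real token) positions.
        kept = [vec for vec, keep in zip(seq_vectors, seq_mask) if keep]
        if not kept:
            raise ValueError("sequence has no unmasked tokens")
        dim = len(kept[0])
        pooled.append([sum(vec[d] for vec in kept) / len(kept) for d in range(dim)])
    return pooled


# Toy batch: 2 sequences, 3 positions, 2 features per position.
hidden = [
    [[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]],  # last position is padding
    [[2.0, 0.0], [4.0, 2.0], [6.0, 4.0]],
]
mask = [[True, True, False], [True, True, True]]
print(masked_mean(hidden, mask))  # [[2.0, 3.0], [4.0, 2.0]]
```

With torch tensors and the same mask convention, the equivalent reduction would be `(hidden * mask.unsqueeze(-1)).sum(dim=1) / mask.sum(dim=1, keepdim=True)`.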

### Hidden state to SMILES

`hidden_to_smis` decodes the latent space representation back to SMILES.

In [6]:
inferred_smis = connection.hidden_to_smis(hidden_states, pad_masks)
inferred_smis

['c1cc2ccccc2cc1', 'COc1cc2nc(N3CCN(C(=O)c4ccco4)CC3)nc(N)c2cc1OC']
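In this run the decoded strings are identical to the inputs. For larger batches, a small helper can report how many molecules survive the round trip. The sketch below uses a hypothetical function `roundtrip_matches` and compares strings exactly. That is stricter than chemical equivalence, because two different SMILES strings can denote the same molecule, so a real evaluation would likely canonicalize both sides with a cheminformatics toolkit first.

```python
def roundtrip_matches(inputs, decoded):
    """Return (match_flags, match_rate) using exact string comparison."""
    if len(inputs) != len(decoded):
        raise ValueError("inputs and decoded must have the same length")
    flags = [a == b for a, b in zip(inputs, decoded)]
    rate = sum(flags) / len(flags) if flags else 0.0
    return flags, rate


# Inputs and decoded outputs copied from the cells above.
smis = ['c1cc2ccccc2cc1',
        'COc1cc2nc(N3CCN(C(=O)c4ccco4)CC3)nc(N)c2cc1OC']
decoded = ['c1cc2ccccc2cc1',
           'COc1cc2nc(N3CCN(C(=O)c4ccco4)CC3)nc(N)c2cc1OC']

flags, rate = roundtrip_matches(smis, decoded)
print(flags, rate)  # [True, True] 1.0
```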