## Inference

### Unconditional Sampling

In [None]:
import torch
import ipywidgets as widgets
import os 

from inference import Inference
from IPython.display import clear_output 
from rdkit.Chem.Draw.IPythonConsole import drawMol3D
from rdkit.Chem import MolFromSmiles, AllChem


def sampling(smiles:str, model_path:str, N:int=1, 
             save:bool=False, timesteps:int=0):
    """Show an RDKit reference conformer for ``smiles``, then run model sampling.

    The RDKit part parses the SMILES, embeds one 3D conformer and renders it
    with ``drawMol3D``. Afterwards an ``Inference`` object is built from the
    module-level ``device`` and the path ``./model/<model_path>``, and its
    ``sampling`` method is called with the remaining arguments.

    Args:
        smiles: SMILES string of the molecule.
        model_path: Entry name inside ``./model``; joined onto that directory.
        N: Forwarded unchanged to ``Inference.sampling`` (presumably the
            number of samples to generate -- confirm against ``Inference``).
        save: Forwarded unchanged to ``Inference.sampling`` (the notebook
            labels the matching checkbox 'PDB file save').
        timesteps: Forwarded unchanged to ``Inference.sampling``.

    Raises:
        ValueError: If RDKit cannot parse ``smiles``.
    """
    print('rdkit generation')
    mol = MolFromSmiles(smiles)
    if mol is None:
        # MolFromSmiles reports a parse failure by returning None; passing
        # that on to EmbedMolecule would fail with an opaque argument error.
        raise ValueError(f'Invalid SMILES string: {smiles!r}')

    # EmbedMolecule returns the id of the new conformer, or -1 when no 3D
    # embedding could be produced. Without a conformer there is nothing for
    # drawMol3D to render, so the reference view is skipped in that case and
    # the model sampling below still runs.
    if AllChem.EmbedMolecule(mol) < 0:
        print('rdkit could not embed a 3D conformer; skipping the reference view')
    else:
        drawMol3D(mol)

    # `device` is a module-level global assigned elsewhere in this notebook.
    inference = Inference(device, os.path.join('./model', model_path))
    inference.sampling(
        smiles    = smiles, 
        N         = N,
        save      = save,
        timesteps = timesteps
    )
    
        
        
def on_click(change):
    """Run unconditional sampling with the values currently held by the widgets.

    The ``change`` payload handed over by the widget observer is not
    inspected; the callback only reads the current widget values and
    forwards them to ``sampling``.
    """
    settings = {
        'smiles': smiles_widgets.value,
        'model_path': model_widgets.value,
        'N': sample_widgets.value,
        'save': save_widgets.value,
        'timesteps': timesteps_widgets.value,
    }
    sampling(**settings)
 

# Controls for the unconditional sampling cell.

# Free-text SMILES input, pre-filled with an example molecule.
smiles_widgets = widgets.Textarea(
    value='O=C(/C=C/c1ccco1)Nc1ccc(Cl)c(S(=O)(=O)N2CCOCC2)c1',
    placeholder='Type something',
    description='SMILES:',
    disabled=False,
    layout=widgets.Layout(height="100%", width="auto"),
)

# Slider whose value is passed to `sampling` as N.
sample_widgets = widgets.IntSlider(
    value=4, min=1, max=8, step=1,
    description='Sample:',
    orientation='horizontal',
)

# Slider whose value is passed to `sampling` as timesteps.
timesteps_widgets = widgets.IntSlider(
    value=50, min=0, max=500, step=10,
    description='Timesteps:',
    orientation='horizontal',
)

# Toggle that triggers a run; its `value` trait is observed further down.
button_widgets = widgets.ToggleButton(
    value=False,
    description='run',
    disabled=False,
    button_style='',  # 'success', 'info', 'warning', 'danger' or ''
    tooltip='Description',
    icon='check',  # (FontAwesome names without the `fa-` prefix)
)

# One dropdown entry per directory entry of ./model, in sorted order.
model_widgets = widgets.Dropdown(
    options=sorted(os.listdir("./model")),
    description='Model:',
    disabled=False,
)

# Checkbox whose value is passed to `sampling` as save.
save_widgets = widgets.Checkbox(
    value=False,
    description='PDB file save',
    disabled=False,
    indent=False,
)

# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

# Call on_click on every change of the toggle button's value.
button_widgets.observe(on_click, 'value')

# Clear this cell's earlier output, then show the controls.
clear_output()
display(
    smiles_widgets,
    sample_widgets,
    timesteps_widgets,
    model_widgets,
    button_widgets,
    save_widgets,
)

### Conditional Sampling

In [None]:
import torch
import ipywidgets as widgets
import os 

from inference import Inference
from IPython.display import clear_output 

        
def on_click(change):
    """Run conditional sampling with the values currently held by the widgets.

    The ``change`` payload handed over by the widget observer is not
    inspected. An ``Inference`` object is built from the module-level
    ``device`` and the selected entry of ``./model``, then its
    ``conditional_sampling`` method is called with the widget values.
    """
    model_location = os.path.join('./model', model_widgets.value)
    runner = Inference(device, model_location)

    # The key atom names are typed as one whitespace-separated string.
    key_atoms = key_atom_widgets.value.split()

    runner.conditional_sampling(
        pdb_path=path_widgets.value,
        key_atom_list=key_atoms,
        mode=mode_widgets.value,
        N=sample_widgets.value,
        timesteps=timesteps_widgets.value,
        resampling=resampling_widgets.value,
        refix_steps=refix_widgets.value,
        save=save_widgets.value,
    )
    
    
# Controls for the conditional sampling cell.

# Path of the input PDB file, passed to conditional_sampling as pdb_path.
path_widgets = widgets.Textarea(
    value='./test.pdb',
    placeholder='Type something',
    description='Path:',
    disabled=False,
    layout=widgets.Layout(height="100%", width="auto"),
)

# Whitespace-separated atom names; split into a list by on_click.
key_atom_widgets = widgets.Textarea(
    value='O1 O2 C10 C15',
    placeholder='Type something',
    description='Key atom name:',
    disabled=False,
    layout=widgets.Layout(height="100%", width="auto"),
)

# Slider whose value is passed to conditional_sampling as N.
sample_widgets = widgets.IntSlider(
    value=4,
    min=1,
    max=8,
    step=1,
    description='Sample:',
    orientation='horizontal',
)

# Slider whose value is passed to conditional_sampling as timesteps.
timesteps_widgets = widgets.IntSlider(
    value=50,
    min=0,
    max=500,
    step=10,
    description='Timesteps:',
    orientation='horizontal',
)

# Slider whose value is passed to conditional_sampling as resampling.
# The label previously read 'Resamlpe:' (typo).
resampling_widgets = widgets.IntSlider(
    value=10,
    min=1,
    max=10,
    step=1,
    description='Resample:',
    orientation='horizontal',
)

# Slider whose value is passed to conditional_sampling as refix_steps.
refix_widgets = widgets.IntSlider(
    value=100,
    min=0,
    max=500,
    step=1,
    description='Refix steps',
    orientation='horizontal',
)

# Dropdown whose value is passed to conditional_sampling as mode.
mode_widgets = widgets.Dropdown(
    options=['fixed', 'replacement'],
    value='fixed',
    description='Mode',
    disabled=False,
)

# One dropdown entry per directory entry of ./model, in sorted order.
model_widgets = widgets.Dropdown(
    options=sorted(os.listdir("./model")),
    description='Model:',
    disabled=False,
)

# Toggle that triggers a run; its `value` trait is observed further down.
button_widgets = widgets.ToggleButton(
    value=False,
    description='run',
    disabled=False,
    button_style='',  # 'success', 'info', 'warning', 'danger' or ''
    tooltip='Description',
    icon='check',  # (FontAwesome names without the `fa-` prefix)
)

# Checkbox whose value is passed to conditional_sampling as save.
save_widgets = widgets.Checkbox(
    value=False,
    description='PDB file save',
    disabled=False,
    indent=False,
)

# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Call on_click on every change of the toggle button's value.
button_widgets.observe(on_click, 'value')

# Clear this cell's earlier output, then show the controls.
clear_output()
display(
    path_widgets,
    key_atom_widgets,
    sample_widgets,
    timesteps_widgets,
    resampling_widgets,
    refix_widgets,
    mode_widgets,
    model_widgets,
    button_widgets,
    save_widgets,
)