In [2]:
from gru.gru_v3 import EncoderDecoder
from gru.dataset import GRUDataset
from gru.cce import ConsciousCrossEntropy
import torch
import torch.nn as nn
import pandas as pd
from vectorizer import SELFIESVectorizer, determine_alphabet

# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

# Set hyperparameters.
# The architecture sizes are expected to match the checkpoint loaded below;
# load_state_dict raises on missing keys or shape mismatches.
encoding_size = 512
hidden_size = 512
num_layers = 1
learn_rate = 0.0003
dropout = 0 # dropout must be equal 0 if num_layers = 1
teacher_ratio = 0.5

# Trained weights to restore into the freshly built model.
checkpoint_path = './models/v3-revisited/model_epoch_14.pt'

# Init model
model = EncoderDecoder(
    fp_size=4860,  # presumably the input fingerprint length - TODO confirm against the dataset
    encoding_size=encoding_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
    dropout=dropout,
    teacher_ratio=teacher_ratio,
).to(device)

# One token per row, flattened into a 1-D array.
alphabet = pd.read_csv('./GRU_data/alphabet.txt', header=None).values.flatten()

# map_location remaps tensors saved on a GPU onto the selected device, so the
# checkpoint also loads on CPU-only machines.
# NOTE: torch.load unpickles the file - only load checkpoints you trust.
state_dict = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(state_dict)

cuda


<All keys matched successfully>

In [23]:
from torch.utils.data import DataLoader

batch_size = 256

# reset_index(drop=True) discards the stored index in one step. The previous
# reset_index().drop(columns='index') only worked when the index was unnamed
# and the frame had no real column called 'index'.
test_df = pd.read_parquet('./GRU_data/test_dataset.parquet').reset_index(drop=True)
vectorizer = SELFIESVectorizer(alphabet, pad_to_len=128)
test_dataset = GRUDataset(test_df, vectorizer)

print("Test size:", len(test_dataset))

# shuffle=False keeps the batches in dataset order, so the first batch is the
# same on every run.
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size, drop_last=True)

# Only the first batch is used for the visual inspection below.
x, y = next(iter(test_loader))
x = x.to(device)
y = y.to(device)

Test size: 121524


In [24]:
# Inference only: switch layers such as dropout/batch-norm to evaluation
# behaviour and skip building the autograd graph, which saves memory and time.
model.eval()
softmax = nn.Softmax(dim=2)
with torch.no_grad():
    # Softmax over the last axis turns the raw scores into per-position
    # probabilities over the alphabet.
    out = softmax(model(x, y, teacher_forcing=False))

# Move both arrays to host memory as NumPy for decoding.
out = out.cpu().numpy()
target = y.detach().cpu().numpy()

out.shape

(256, 128, 42)

In [25]:
import selfies as sf
import rdkit.Chem as Chem


def _vector_to_mol(vectorized):
    """Decode one vectorized sequence into an RDKit molecule.

    The sequence is devectorized to a SELFIES string (special tokens removed),
    translated to SMILES and parsed by RDKit. MolFromSmiles returns None when
    the SMILES cannot be parsed, so the result may be None.
    """
    selfie = vectorizer.devectorize(vectorized, remove_special=True)
    smiles = sf.decoder(selfie)
    return Chem.MolFromSmiles(smiles)


preds = []
targets = []

# Iterate over the arrays themselves instead of a hard-coded range(256), so the
# cell keeps working when the batch size changes.
for out_row, target_row in zip(out, target):
    preds.append(_vector_to_mol(out_row))
    targets.append(_vector_to_mol(target_row))

In [26]:
from ipywidgets import interact
import ipywidgets as widgets
# Import the drawing submodule explicitly; `import rdkit.Chem` alone does not
# guarantee that `Chem.Draw` is available as an attribute.
from rdkit.Chem import Draw

# Bound the slider by the number of decoded molecules rather than batch_size,
# so it can never index past the end of the lists.
@interact(idx=(0, len(targets) - 1))
def print_at_idx(idx):
    """Render the target molecule (left) next to the prediction (right).

    Args:
        idx: position of the molecule pair within `targets` / `preds`.

    Returns:
        The image produced by RDKit's MolsToImage for the two molecules.
    """
    return Draw.MolsToImage([targets[idx], preds[idx]], subImgSize=(400, 400), legends=None)

interactive(children=(IntSlider(value=127, description='idx', max=255), Output()), _dom_classes=('widget-inter…