In [1]:
from gru.vae_gru import EncoderDecoder
from gru.dataset import GRUDataset
import selfies as sf
import rdkit.Chem as Chem
import torch
import numpy as np
import torch.nn as nn
import pandas as pd
from tqdm import tqdm
from vectorizer import SELFIESVectorizer, determine_alphabet
import gc

# --- Model hyperparameters -------------------------------------------------
# The encoding and hidden sizes are configured to the same value.
encoding_size = hidden_size = 512
num_layers = 1
dropout = 0  # must stay 0 while num_layers == 1
teacher_ratio = 0.5

# --- Paths -----------------------------------------------------------------
model_path = './models/vae_gru/epoch_150.pt'
val_dataset_path = './models/vae_gru/val_dataset.parquet'

# --- Generation settings ---------------------------------------------------
# How many molecules to generate.
n_molecules = 1000

In [4]:
# Release unreferenced Python objects and cached CUDA memory left over from
# earlier cell runs before the model is (re)created.
gc.collect()
torch.cuda.empty_cache()

# Use the GPU when one is available, otherwise fall back to the CPU.
# (The result of this check must not be overridden with a hard-coded 'cuda',
# or the notebook fails on machines without a GPU.)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Init model with the hyperparameters from the configuration cell.
model = EncoderDecoder(
    fp_size=4860,
    encoding_size=encoding_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
    dropout=dropout,
    teacher_ratio=teacher_ratio).to(device)

# Load the trained weights; map_location remaps the checkpoint's tensors onto
# the device selected above, so a GPU-saved checkpoint also loads on CPU.
model.load_state_dict(torch.load(model_path, map_location=torch.device(device)))

def get_predictions(val_dataset_path, n_molecules, shuffle=False, return_smiles=False):
    """Run the global ``model`` on one batch of the validation set.

    Args:
        val_dataset_path: Path to the validation parquet file. It must contain
            an ``fps`` column (read below) plus whatever ``GRUDataset`` needs.
        n_molecules: Number of molecules to run through the model; used as the
            batch size of the single batch that is drawn.
        shuffle: If True, the batch is drawn from a shuffled loader instead of
            being the first ``n_molecules`` rows.
        return_smiles: If True, additionally decode targets and predictions to
            SMILES strings and append both lists to the returned tuple.

    Returns:
        ``(targets, preds, fps)`` where ``targets`` and ``preds`` are numpy
        arrays for the sampled batch and ``fps`` is the evaluated ``fps``
        column of the *whole* validation frame (in file order, so it is NOT
        aligned row-by-row with ``targets``/``preds``).
        With ``return_smiles=True`` the tuple is
        ``(targets, preds, fps, smiles_targets, smiles_preds)``.

    Raises:
        ValueError: If ``n_molecules`` is not between 1 and the dataset size.
    """
    from torch.utils.data import DataLoader

    # Follow the device the model actually lives on rather than re-deriving it,
    # so inputs and weights can never end up on different devices.
    device = next(model.parameters()).device

    alphabet = pd.read_csv('./GRU_data/alphabet.txt', header=None)[0].values.tolist()
    vectorizer = SELFIESVectorizer(alphabet, pad_to_len=128)

    val_df = pd.read_parquet(val_dataset_path).reset_index(drop=True)
    val_dataset = GRUDataset(val_df, vectorizer)

    # With drop_last=True an oversized batch would leave the loader empty and
    # next(iter(...)) would fail with a bare StopIteration; fail clearly instead.
    n_available = len(val_dataset)
    if n_molecules < 1 or n_molecules > n_available:
        raise ValueError(
            f'n_molecules must be between 1 and {n_available} '
            f'(size of the validation set), got {n_molecules}'
        )

    val_loader = DataLoader(val_dataset, shuffle=shuffle, batch_size=n_molecules, drop_last=True)

    x, y = next(iter(val_loader))
    x = x.to(device)
    y = y.to(device)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        preds = model(x, y, teacher_forcing=False)
    preds = preds.detach().cpu().numpy()
    targets = y.detach().cpu().numpy()

    # NOTE(review): eval() executes whatever text is stored in the `fps`
    # column. Only use this on trusted files; if the column holds plain Python
    # literals, ast.literal_eval would be the safe replacement - confirm format.
    fps = val_df.fps.apply(eval).tolist()

    if not return_smiles:
        return targets, preds, fps

    smiles_targets = []
    smiles_preds = []
    for pred_vec, target_vec in zip(preds, targets):
        selfie_pred = vectorizer.devectorize(pred_vec, remove_special=True)
        selfie_target = vectorizer.devectorize(target_vec, remove_special=True)
        # A prediction may not decode; fall back to 'C' so both lists keep
        # one entry per molecule.
        try:
            smiles_pred = sf.decoder(selfie_pred)
        except Exception:
            smiles_pred = 'C'
        smiles_preds.append(smiles_pred)
        # Targets come from the dataset, so a decoding failure here is a real
        # error and is allowed to propagate.
        smiles_targets.append(sf.decoder(selfie_target))

    return targets, preds, fps, smiles_targets, smiles_preds

In [5]:
# Draw one shuffled batch of n_molecules validation molecules and predict them.
# No random seed is set in the visible cells, so the sampled batch may differ
# between runs unless one is set elsewhere.
targets, preds, fps = get_predictions(
    val_dataset_path=val_dataset_path,
    n_molecules=n_molecules,
    shuffle=True,
)
