In [None]:
from weighted_retraining.chem.jtnn import Vocab
from weighted_retraining.chem.chem_model import JTVAE
import torch
from tqdm import tqdm
import numpy as np
import os
import rdkit

In [None]:
# Every .npz results file produced by this notebook is written into this directory.
result_dir = "./sample-results"
# exist_ok=True makes re-running this cell a no-op instead of an error.
os.makedirs(name=result_dir, exist_ok=True)

In [None]:
# Build the vocabulary from the vocab file: one whitespace-stripped entry per line.
vocab_file = "./data/chem/orig_model/vocab.txt"
with open(vocab_file) as vocab_fh:
    vocab_entries = [line.strip() for line in vocab_fh]
vocab = Vocab(vocab_entries)

In [None]:
# Make rdkit be quiet
def rdkit_quiet():
    """Raise RDKit's logger threshold to CRITICAL.

    RDLogger is imported explicitly inside the function: a bare
    ``import rdkit`` does not by itself guarantee that the ``rdkit.RDLogger``
    submodule is bound as an attribute, so ``rdkit.RDLogger`` could raise
    AttributeError unless some other import happened to load it first.
    """
    from rdkit import RDLogger

    lg = RDLogger.logger()
    lg.setLevel(RDLogger.CRITICAL)


rdkit_quiet()

In [None]:
# Draw one fixed set of latent points; every model in this notebook decodes
# this same tensor.
samples_per_iter = 5000
latent_dim = 56  # presumably the JTVAE latent size -- TODO confirm against the checkpoints
sample_seed = 0  # fixed so that Restart & Run All reproduces the same z_sample

# Sample on the CPU from a dedicated, seeded generator (torch's global RNG state
# is left untouched), then move the tensor to the GPU.
sample_rng = torch.Generator().manual_seed(sample_seed)
z_sample = torch.randn(samples_per_iter, latent_dim, generator=sample_rng).to(
    torch.device("cuda")
)

In [None]:
# Set up results tracking
results = dict(
    sample_points=[],
    # Left empty in this cell: nothing is appended to it for this single model.
    sample_versions=[],
)
# Load model
# NOTE(review): the model is used in whatever train/eval mode
# load_from_checkpoint leaves it in -- confirm that is the intended mode.
model = JTVAE.load_from_checkpoint(
    "./data/models/chem.ckpt",
    vocab=vocab
).cuda()

# Decode every latent point in z_sample, batch_size rows at a time.
z_decode = []
batch_size = 1
# The description never changes, so it is passed once at construction instead
# of calling set_description (which refreshes the bar) on every iteration.
with tqdm(total=len(z_sample), desc="decoding with uniform weighted model") as pbar:
    with torch.no_grad():  # inference only; no gradients are needed
        for j in range(0, len(z_sample), batch_size):
            z_batch = z_sample[j : j + batch_size]
            smiles_out = model.decode_deterministic(z_batch)
            z_decode += smiles_out
            pbar.update(z_batch.shape[0])

results["sample_points"].append(z_decode)

# Save results
np.savez_compressed(os.path.join(result_dir, "results-uniform-weight.npz"), **results)

In [None]:
# Must equal the seed used for training: it is interpolated into the
# checkpoint paths that the retrained models are loaded from.
seed = 730_007_773

In [None]:
# For every (pathway model, k) combination, decode the shared z_sample with the
# checkpoint of each retraining iteration and store the decoded outputs.
for pathway_model in ["viable", "modified", "impractical"]:
    for k in [3, 4, 5, 6]:
        # Set up results tracking
        results = dict(
            sample_points=[],
            sample_versions=[],
        )
        # One output file per (pathway_model, k); the path does not depend on
        # the retraining iteration, so it is built once here.
        results_path = os.path.join(result_dir, "results-{}-k-{}.npz".format(pathway_model, k))

        for idx, retrain in enumerate(np.arange(0,500,50)):
            # Load model
            # NOTE(review): the model is used in whatever train/eval mode
            # load_from_checkpoint leaves it in -- confirm that is intended.
            model = JTVAE.load_from_checkpoint(
                "./logs/bo/chem_therapeutic_score_{}/rank/k_1e-{}/r_50/seed{}/retraining/retrain_{}/checkpoints/last.ckpt".format(pathway_model, k, seed, retrain),
                vocab=vocab
            ).cuda()

            # The description is constant for a given checkpoint, so it is
            # built once and handed to tqdm instead of calling set_description
            # (which refreshes the bar) on every decoded batch.
            description = "decoding with retrain iteration {} for {} pathway model and k {}".format(
                idx + 1, pathway_model, k
            )

            # Decode every latent point in z_sample, batch_size rows at a time.
            z_decode = []
            batch_size = 1
            with tqdm(total=len(z_sample), desc=description) as pbar:
                with torch.no_grad():  # inference only; no gradients are needed
                    for j in range(0, len(z_sample), batch_size):
                        z_batch = z_sample[j : j + batch_size]
                        smiles_out = model.decode_deterministic(z_batch)
                        z_decode += smiles_out
                        pbar.update(z_batch.shape[0])

            results["sample_points"].append(z_decode)
            results["sample_versions"].append(retrain)

            # Save results. The file is re-written after every retraining
            # iteration, so it always holds all iterations finished so far.
            np.savez_compressed(results_path, **results)