In [1]:
# Changes working directory to root to allow imports to work

import os

def change_directory_to_root(root_name: str = "GDProject"):
    """Changes the directory to the root of the project"""
    current_folder = os.getcwd().split('/')[-1]
    if current_folder != root_name:
        os.chdir('..')

    print(f"New Current Directory is {os.getcwd()}")

change_directory_to_root()

New Current Directory is /home/alden/Research/GDProject


In [2]:
# Supress pytorch pickle load warnings
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

# Torch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

# Logging
from tqdm import tqdm
import matplotlib.pyplot as plt
import pickle

# Library imports
import gdiffusion as gd
import util
import util.chem as chem
import util.visualization as vis
import util.stats as gdstats

from gdiffusion.classifier.logp_predictor import LogPPredictor
# Paths for all of the models:
from util.model_paths import *

# Import guacamole tasks:
import util.chem.guacamole as guac

device = util.get_device(print_device=True)

Device: cuda


In [3]:
# Load Diffusion / VAE
molecule_diffusion_model = gd.MoleculeDiffusionModel(unet_state_dict_path=MOLECULE_DIFFUSION_MODEL_PATH, device=device)
diffusion = gd.DDIMSampler(diffusion_model=molecule_diffusion_model, sampling_timesteps=100)
vae = gd.MoleculeVAE()

[Molecule Diffusion]: UNet Successfully loaded
- Total parameters: 57,314,049
- Trainable parameters: 57,314,049
- Model size: 218.6 MB

 Loading Molecule VAE:
------------------------------------------------
Loaded VAE Vocab from saved_models/selfies_vae/vocab.json
Getting State Dict...
Loading model from saved_models/selfies_vae/selfies-vae.ckpt
Enc params: 1,994,592
Dec params: 277,346
------------------------------------------------



Vanilla BO:

In [39]:
from botorch.optim import optimize_acqf
from botorch.acquisition import qLogExpectedImprovement, qExpectedImprovement
from gpytorch.mlls import ExactMarginalLogLikelihood
from botorch.models import SingleTaskGP
from botorch.utils.transforms import normalize, unnormalize
from torch.optim import SGD
from botorch import fit_gpytorch_mll

# Loads 20,000 guacamol molecule latents
LATENT_DATASET_LOCATION = "data/latents_pair_dataset_1"
X_all_data = torch.load(LATENT_DATASET_LOCATION, weights_only=True)

def print_dataset_stats(X, Y):
    print(f"--- Dataset Stats ---")
    print(f"Min Obj Value: {Y.min():.4f}")
    print(f"Max Obj Value: {Y.max():.4f}")

In [40]:
def obj_fun(z: torch.Tensor):
    selfies = vae.decode(z.float())
    smiles = chem.selfies_to_smiles(selfies)
    guac_score = guac.smiles_to_desired_scores(smiles, task_id="rano")
    return torch.tensor(guac_score).reshape(-1, 1)


In [41]:
latent_dim = 128

bounds = torch.tensor([[-3.0] * latent_dim, [3.0] * latent_dim], device=device, dtype=torch.float64)
unit_bounds = torch.tensor([[0.0] * latent_dim, [1.0] * latent_dim], device=device, dtype=torch.float64)
def get_initial_data(num_points: int = 100):
    Xs = X_all_data[:num_points].reshape(-1, latent_dim).clone().double().to(device)
    Ys = obj_fun(Xs.float()).reshape(-1, 1).clone().double().to(device)
    return Xs, Ys

In [42]:
BATCH_SIZE = 5
NUM_RESTARTS = 10
RAW_SAMPLES = 256
N_BATCH = 100

def get_fitted_model(train_x, train_obj, state_dict=None):
    train_x = torch.clip(train_x, min=-3.0, max=3.0)
    model = SingleTaskGP(
        train_X=normalize(train_x, bounds),
        train_Y=train_obj,
    )
    if state_dict is not None:
        model.load_state_dict(state_dict)
    mll = ExactMarginalLogLikelihood(model.likelihood, model)
    mll.to(train_x)
    fit_gpytorch_mll(mll)
    return model


def optimize_acqf_and_get_observation(acq_func):
    candidates, _ = optimize_acqf(
        acq_function=acq_func,
        bounds=unit_bounds,
        q=BATCH_SIZE,
        num_restarts=NUM_RESTARTS,
        raw_samples=RAW_SAMPLES
    )

    new_x = unnormalize(candidates.detach(), bounds=bounds).to(device)
    new_obj = obj_fun(new_x).to(device)
    return new_x, new_obj



In [43]:
train_x, train_obj = get_initial_data(num_points=10)

print_dataset_stats(train_x, train_obj)

state_dict = None

num_oracle_calls_list = []
best_value_list = []

print(f"\nRunning BO ")

for iteration in range(N_BATCH):
    model = get_fitted_model(
        train_x=train_x,
        train_obj=train_obj,
        state_dict=state_dict,
    )

    # define the qNEI acquisition function
    qEI = qLogExpectedImprovement(
        model=model, best_f=train_obj.max()
    )

    # optimize and get new observation
    new_x, new_obj = optimize_acqf_and_get_observation(qEI)

    # update training points
    train_x = torch.cat((train_x, new_x))
    train_obj = torch.cat((train_obj, new_obj))

    # update progress
    best_value = train_obj.max().item()
    num_oracle_calls_list.append(train_x.shape[0])
    best_value_list.append(best_value)

    print(f"[{train_x.shape[0]}] Best Value: {best_value:.2f}")
    state_dict = model.state_dict()

    


--- Dataset Stats ---
Min Obj Value: 0.0056
Max Obj Value: 0.4235

Running BO 
[15] Best Value: 0.42
[20] Best Value: 0.52


KeyboardInterrupt: 

BO Guided Diffusion:


In [46]:
def optimize_acqf_and_get_observation_with_diffusion(acq_func):
    def log_prob_fn_acq_func(x):
        x = x.clamp(-3.0, 3.0)
        return acq_func(x)

    cond_fn_acq_func = gd.get_cond_fn(
        log_prob_fn=log_prob_fn_acq_func,
        clip_grad=False,
        latent_dim=128
    )

    candidates = diffusion.sample(batch_size=BATCH_SIZE, guidance_scale=8.0, cond_fn=cond_fn_acq_func)

    new_x = unnormalize(candidates.detach(), bounds=bounds).to(device)
    new_obj = obj_fun(new_x).to(device).nan_to_num(nan=0.0)
    return new_x, new_obj

In [47]:
# Diffusion with BO:
train_x, train_obj = get_initial_data(num_points=10)

print_dataset_stats(train_x, train_obj)

state_dict = None

num_oracle_calls_list = []
best_value_list = []

print(f"\nRunning BO ")

for iteration in range(N_BATCH):
    model = get_fitted_model(
        train_x=train_x,
        train_obj=train_obj,
        state_dict=state_dict,
    )

    # define the qNEI acquisition function
    qEI = qLogExpectedImprovement(
        model=model, best_f=train_obj.max()
    )

    # optimize and get new observation
    new_x, new_obj = optimize_acqf_and_get_observation_with_diffusion(qEI)

    # update training points
    train_x = torch.cat((train_x, new_x))
    train_obj = torch.cat((train_obj, new_obj))

    # update progress
    best_value = train_obj.max().item()
    num_oracle_calls_list.append(train_x.shape[0])
    best_value_list.append(best_value)

    print(f"[{train_x.shape[0]}] Best Value: {best_value:.2f}")
    state_dict = model.state_dict()

    


--- Dataset Stats ---
Min Obj Value: 0.0056
Max Obj Value: 0.4235

Running BO 


DDIM Sampling Loop Time Step: 100%|██████████| 100/100 [00:01<00:00, 55.92it/s]


[15] Best Value: 0.42


DDIM Sampling Loop Time Step: 100%|██████████| 100/100 [00:01<00:00, 55.99it/s]


[20] Best Value: 0.42


DDIM Sampling Loop Time Step: 100%|██████████| 100/100 [00:01<00:00, 55.86it/s]


[25] Best Value: 0.42


DDIM Sampling Loop Time Step: 100%|██████████| 100/100 [00:01<00:00, 56.45it/s]


[30] Best Value: 0.45


DDIM Sampling Loop Time Step: 100%|██████████| 100/100 [00:01<00:00, 54.83it/s]


[35] Best Value: 0.45


DDIM Sampling Loop Time Step: 100%|██████████| 100/100 [00:01<00:00, 54.79it/s]


[40] Best Value: 0.45


DDIM Sampling Loop Time Step: 100%|██████████| 100/100 [00:01<00:00, 54.40it/s]


[45] Best Value: 0.62


DDIM Sampling Loop Time Step: 100%|██████████| 100/100 [00:01<00:00, 54.90it/s]


[50] Best Value: 0.62


DDIM Sampling Loop Time Step: 100%|██████████| 100/100 [00:01<00:00, 55.19it/s]


[55] Best Value: 0.62


DDIM Sampling Loop Time Step: 100%|██████████| 100/100 [00:01<00:00, 55.05it/s]


[60] Best Value: 0.62


DDIM Sampling Loop Time Step: 100%|██████████| 100/100 [00:01<00:00, 55.63it/s]


[65] Best Value: 0.62


DDIM Sampling Loop Time Step: 100%|██████████| 100/100 [00:01<00:00, 55.37it/s]


[70] Best Value: 0.62


DDIM Sampling Loop Time Step: 100%|██████████| 100/100 [00:01<00:00, 55.63it/s]


[75] Best Value: 0.62


DDIM Sampling Loop Time Step: 100%|██████████| 100/100 [00:01<00:00, 56.25it/s]


[80] Best Value: 0.62


DDIM Sampling Loop Time Step: 100%|██████████| 100/100 [00:01<00:00, 58.30it/s]


[85] Best Value: 0.62


DDIM Sampling Loop Time Step: 100%|██████████| 100/100 [00:01<00:00, 56.41it/s]


[90] Best Value: 0.62


DDIM Sampling Loop Time Step: 100%|██████████| 100/100 [00:01<00:00, 55.20it/s]


[95] Best Value: 0.62


KeyboardInterrupt: 