In [None]:
import sys
from pyprojroot import here as project_root

sys.path.insert(0, str(project_root()))

from fs_mol.data.cnp import get_cnp_batcher
from fs_mol.utils.torch_utils import torchify

from bayes_opt.bo_utils import load_cep_dataset, run_gp_ei_bo, min_so_far, task_to_batches, CNPModelFeatureExtractor
import numpy as np
import torch
import matplotlib.pyplot as plt
from tqdm import tqdm

# Pick the compute device once: CUDA when PyTorch reports it as available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Arguments for load_cep_dataset, named for readability. The second path is
# relative to the notebook's working directory.
# NOTE(review): how load_cep_dataset uses each argument is defined in
# bayes_opt.bo_utils and is not visible here -- confirm there.
cep_csv_name = "cep-dataset-subsampled.csv"
helper_files_dir = "../../fs_mol/preprocessing/utils/helper_files/"

task = load_cep_dataset(cep_csv_name, helper_files_dir)

In [None]:
# Build the CNP batcher (max_num_graphs=100), split the task into batches with
# it, then convert those batches with torchify onto `device`.
batcher = get_cnp_batcher(max_num_graphs=100)
task_batches = task_to_batches(task, batcher)
cnp_batches = torchify(task_batches, device=device)

In [None]:
# Checkpoint of the CNP model used as feature extractor (classification model).
model_weights_file = "../../outputs/FSMol_CNPModel_gnn+ecfp+fc_2022-04-11_16-46-27/best_validation.pt"

# Restore the feature extractor from the checkpoint and place it on `device`.
cnp_model = CNPModelFeatureExtractor.build_from_model_file(model_weights_file, device=device)
cnp_model = cnp_model.to(device)

# Switch to evaluation mode; as the last expression of the cell, the returned
# value is also what the notebook displays.
cnp_model.eval()

In [None]:
# Compute one representation per batch with the CNP feature extractor.
#
# The representations are only used as fixed inputs afterwards, so autograd is
# switched off: without torch.no_grad() each stored output can keep its whole
# computation graph (and the activations behind it) alive for as long as the
# list exists.
with torch.no_grad():
    representations = [
        cnp_model.get_representation(features) for features in cnp_batches
    ]

# The extractor is not needed once the features exist; drop the reference so
# its memory can be reclaimed. Note that this cell cannot be re-run on its own
# afterwards -- the model-loading cell has to be executed again first.
del cnp_model

In [None]:
# The individual samples of the task; each one exposes a `numeric_label`.
dataset = task.samples

# Join the per-batch representations along dim 0 into a single tensor.
# NOTE(review): this assumes the batches preserve the order of task.samples,
# so that row i of x_all belongs to dataset[i] -- confirm against the batcher.
x_all = torch.cat(representations, dim=0)

# Regression targets: every sample's numeric label as a float tensor on `device`.
numeric_labels = [float(sample.numeric_label) for sample in dataset]
y_all = torch.FloatTensor(numeric_labels).to(device)

In [None]:
# Settings handed positionally to run_gp_ei_bo. Their exact semantics live in
# bayes_opt.bo_utils; the grouping below is inferred from the names only.

# Search budget. The plot further down uses
# query_batch_size * num_bo_iters + 1 points on its x axis.
num_bo_iters = 40
query_batch_size = 1
num_init_points = 16
init_from = 1200  # NOTE(review): meaning not visible in this notebook -- confirm

# Surrogate-model settings.
kernel_type = "matern"
noise_prior = True
noise_init = 0.01

# How many times the BO run is repeated, and the list the per-run results are
# collected in.
num_repeats = 20
bo_records = []

In [None]:
# Run the BO procedure `num_repeats` times and keep min_so_far(...) of each run.
#
# The result list is reset at the top of this cell so that re-running the cell
# replaces the previous results; appending to a list created in another cell
# would silently accumulate duplicate runs and skew the mean/std computed later.
# An explicit loop is kept (rather than a comprehension) so that the runs
# finished so far survive a manual interrupt.
bo_records = []

for repeat in tqdm(range(num_repeats)):
    bo_record = run_gp_ei_bo(
        dataset,
        x_all,
        y_all,
        num_init_points,
        query_batch_size,
        num_bo_iters,
        kernel_type,
        device,
        init_from,
        noise_init,
        noise_prior,
    )
    bo_records.append(min_so_far(bo_record))

In [None]:
# x positions of the curve: query_batch_size * num_bo_iters + 1 points
# (presumably the starting point plus one per queried molecule -- confirm
# against what min_so_far returns).
x_axis = np.arange(query_batch_size*num_bo_iters+1)

# bo_records is a plain Python list (one entry per repeat) and therefore has no
# .mean()/.std(); stack it into a 2-D array first. A separate name is used so
# the list itself is left untouched and this cell can be re-run safely.
# Assumes every entry has the same length and holds CPU-side numeric values.
# NOTE(review): an earlier version of this cell negated the values
# (-1.0 * y_all[i]) before plotting; check whether min_so_far already returns
# values with the sign expected by the y axis below.
bo_records_array = np.asarray(bo_records, dtype=float)

# Mean and standard deviation across repeats (axis 0) at every x position.
bo_records_mean = bo_records_array.mean(axis=0)
bo_records_std = bo_records_array.std(axis=0)

plt.figure(figsize=(5,5))

# Mean curve with a +/- one standard deviation band around it.
plt.plot(x_axis, bo_records_mean)
plt.fill_between(x_axis, bo_records_mean-bo_records_std, bo_records_mean+bo_records_std, alpha=0.4)

plt.xlabel("Number of molecules queried")
plt.ylabel("Top-1 power conversion efficiency")
plt.ylim(0.0, 12.0)

In [None]:
import pickle
from pathlib import Path

# Persist the collected BO records for later use.
records_path = Path("outputs/cnp_bo_records.pkl")

# open(..., "wb") does not create missing directories, so make sure the target
# folder exists; otherwise a fresh checkout fails here after all the compute.
records_path.parent.mkdir(parents=True, exist_ok=True)

with open(records_path, "wb") as output_file:
    pickle.dump(bo_records, output_file)