In [1]:
import os
import re
import json
from pathlib import Path

import pandas as pd
import torch

from botorch import fit_gpytorch_mll
from gpytorch.likelihoods import FixedNoiseGaussianLikelihood
from gpytorch.mlls import ExactMarginalLogLikelihood
from botorch.acquisition.logei import qLogNoisyExpectedImprovement
from botorch.sampling import SobolQMCNormalSampler
from botorch.optim import optimize_acqf_discrete

from gskgpr import GaussianStringKernelGP
from seq2ascii import Seq2Ascii

In [2]:
# Torch device that every tensor built in this notebook is moved to via `.to(device)`.
device = "cpu"

In [3]:
def load_json_res(data_dir):
    """Read one result JSON file and return its fields as one-element columns.

    Parameters
    ----------
    data_dir : str or path-like
        Path of the JSON file to read (despite the name, this is a file path,
        it is handed straight to ``open``).

    Returns
    -------
    dict
        ``{"PCC": [...], "F": [...], "F_err": [...]}``.  Each value is a
        one-element list: the ``"PCC2"`` entry as stored in the file, and the
        ``"FE"`` / ``"FE_error"`` entries converted to ``float``.

    Raises
    ------
    KeyError
        If ``"FE"``, ``"FE_error"`` or ``"PCC2"`` is missing from the file.
    """
    with open(data_dir) as handle:
        record = json.load(handle)

    # Look both numeric fields up before converting, so a missing key is
    # reported before any conversion error.
    fe_value, fe_error = record["FE"], record["FE_error"]

    return dict(PCC=[record["PCC2"]], F=[float(fe_value)], F_err=[float(fe_error)])

def load_data(data_dir):
    """Collect every ``XXXXX.JSON`` result file of a directory into one frame.

    Only files whose complete name is five upper-case letters followed by the
    literal extension ``.JSON`` are read.  Each one is parsed with
    ``load_json_res`` and contributes its rows to the result.

    Parameters
    ----------
    data_dir : str or pathlib.Path
        Directory that contains the result files.

    Returns
    -------
    pandas.DataFrame
        Concatenation of all parsed files with a fresh ``0..n-1`` index.
        Files are read in sorted name order so repeated runs give the same
        row order.

    Raises
    ------
    ValueError
        If the directory holds no matching result file.
    """
    data_dir = Path(data_dir)  # accept plain strings as well as Path objects

    # The dot is escaped and the whole name must match, so look-alikes such as
    # "ABCDEXJSON" or "ABCDE.JSON.bak" are not picked up.
    name_pattern = re.compile(r"[A-Z]{5}\.JSON")

    # os.listdir returns entries in arbitrary order; sort for reproducibility.
    result_files = sorted(
        name for name in os.listdir(data_dir) if name_pattern.fullmatch(name)
    )
    if not result_files:
        raise ValueError(f"No '<PCC>.JSON' result files found in {data_dir}")

    frames = [pd.DataFrame(load_json_res(data_dir / name)) for name in result_files]
    return pd.concat(frames, ignore_index=True)

def initialize_model(train_x, train_y, err_y, translator, L=5):
    """Build the string-kernel GP together with its exact marginal log-likelihood.

    Parameters
    ----------
    train_x, train_y
        Training inputs and targets, forwarded unchanged to
        ``GaussianStringKernelGP``.
    err_y
        Passed as ``noise`` to ``FixedNoiseGaussianLikelihood``.
    translator
        Forwarded unchanged to ``GaussianStringKernelGP``.
    L : int, default 5
        Forwarded to ``GaussianStringKernelGP`` as ``L``.

    Returns
    -------
    tuple
        ``(model, mll)`` where ``mll`` is an ``ExactMarginalLogLikelihood``
        built from the model and the model's own likelihood.
    """
    fixed_noise = FixedNoiseGaussianLikelihood(noise=err_y)
    gp = GaussianStringKernelGP(
        train_x=train_x,
        train_y=train_y,
        likelihood=fixed_noise,
        translator=translator,
        L=L,
    )
    # Set explicitly on the instance; presumably required by the BoTorch
    # acquisition code used later in the notebook (TODO confirm).
    gp.num_outputs = 1

    return gp, ExactMarginalLogLikelihood(gp.likelihood, gp)

def opt_qlogEI_get_obs(model, choices, train_x, sampler, q=3, max_batch_size=100):
    """Select the next batch of candidates by maximising qLogNEI over a discrete set.

    Despite the name, nothing is observed here: the function only proposes the
    points to evaluate next.

    Parameters
    ----------
    model
        Fitted model handed to ``qLogNoisyExpectedImprovement``.
    choices
        Candidate points to choose from, passed to ``optimize_acqf_discrete``.
    train_x
        Already evaluated inputs, used as ``X_baseline`` of the acquisition
        function.
    sampler
        Monte-Carlo sampler used by the acquisition function.
    q : int, default 3
        Number of candidates to return.
    max_batch_size : int, default 100
        Forwarded to ``optimize_acqf_discrete`` as ``max_batch_size``.

    Returns
    -------
    torch.Tensor
        The selected candidates, detached from the autograd graph.
    """
    acq_func = qLogNoisyExpectedImprovement(
        model=model,
        X_baseline=train_x,
        sampler=sampler,
    )

    candidates, _ = optimize_acqf_discrete(
        acq_function=acq_func,
        q=q,
        choices=choices,
        max_batch_size=max_batch_size,
    )
    # Detach so the proposal carries no gradient history.
    return candidates.detach()

In [4]:
# Raw results exactly as read from disk; kept unmodified.
dataset_raw = load_data(Path("/Users/arminsh/Documents/FEN-HTVS/results_dec"))

# Flip the sign of F (presumably so that larger values are better for the
# maximising acquisition function used later -- confirm).
dataset = dataset_raw.assign(F=lambda df: -df["F"])

# [mean, std] of the sign-flipped F; needed to map predictions back to the
# original scale.
norm_transform = [dataset["F"].mean(), dataset["F"].std()]

# Standardise F to zero mean / unit std; F_err is only rescaled by the same
# std (no shift).
dataset = dataset.assign(
    F=lambda df: (df["F"] - norm_transform[0]) / norm_transform[1],
    F_err=lambda df: df["F_err"] / norm_transform[1],
)

In [5]:
# Summary statistics of the normalised frame; after the standardisation above
# F should show mean ~0 and std 1.
dataset.describe()

Unnamed: 0,F,F_err
count,35.0,35.0
mean,6.399643e-16,0.268254
std,1.0,0.170349
min,-1.654453,0.085149
25%,-0.5386596,0.161007
50%,-0.1667077,0.214549
75%,0.5625199,0.299655
max,2.649232,0.786974


In [6]:
translator = Seq2Ascii("/Users/arminsh/Documents/FEN-HTVS/MFMOBO/AA.blosum62.pckl")

# Read the full search space: the first whitespace-separated token of each line.
fspace = []
with open("/Users/arminsh/Documents/FEN-HTVS/gen_input_space/full_space.txt") as f:
    for line in f:
        tokens = line.split()
        if tokens:  # skip blank lines; indexing [0] on them would raise IndexError
            fspace.append(tokens[0])

# Fit the translator on every sequence of the search space, then collect the
# keys of its int -> str mapping as a tensor on the working device.
translator.fit(fspace)
full_space = torch.as_tensor(list(translator.int2str.keys())).to(device)

In [7]:
# Training inputs: the measured sequences, encoded by the translator.
encoded_x = translator.encode_to_int(dataset["PCC"].tolist()).to(device)

# Targets and their errors as float32 tensors on the working device.
train_y = torch.tensor(dataset["F"].to_numpy(), dtype=torch.float32, device=device)
err_y = torch.tensor(dataset["F_err"].to_numpy(), dtype=torch.float32, device=device)

# The squared errors are what gets passed on as the fixed observation noise.
model, mll = initialize_model(encoded_x, train_y, err_y.pow(2), translator)

In [8]:
# Candidate pool = every key of the translator's int -> str mapping that is
# not already in the training set.  A set lookup keeps this linear in the size
# of the search space and tolerates a sequence appearing more than once in
# `dataset` (repeated list.remove would raise ValueError on the second hit).
train_ids = {translator.str2int[seq] for seq in dataset.PCC}
candidate_ids = [idx for idx in translator.int2str.keys() if idx not in train_ids]

# Column tensor (one candidate per row) in torch's default floating dtype.
choices = torch.Tensor(candidate_ids).view(-1, 1).to(device)

In [9]:
# Fit the model hyperparameters through the marginal log-likelihood object.
mll.train()
model.train()
fit_gpytorch_mll(mll)

# Put both modules into eval mode for prediction, mirroring the two train()
# calls above (mll.eval() was previously called twice and model.eval() never).
model.eval()
mll.eval()

ExactMarginalLogLikelihood(
  (likelihood): FixedNoiseGaussianLikelihood(
    (noise_covar): FixedGaussianNoise()
  )
  (model): GaussianStringKernelGP(
    (likelihood): FixedNoiseGaussianLikelihood(
      (noise_covar): FixedGaussianNoise()
    )
    (mean_module): ConstantMean()
    (covar_module): GenericStringKernel(
      (raw_sigma1_constraint): Positive()
      (raw_sigma2_constraint): Positive()
    )
  )
)

In [14]:
# Seed the quasi-Monte-Carlo sampler so the proposed batch is reproducible
# after a kernel restart.
# NOTE(review): 1028 samples looks like a typo for 1024 (a power of two) --
# left unchanged here, confirm with the author.
sampler = SobolQMCNormalSampler(sample_shape=torch.Size([1028]), seed=0)

# Propose the next batch; the training inputs are reshaped to one column to
# match the layout of `choices`.
new_x = opt_qlogEI_get_obs(
    model=model,
    choices=choices,
    train_x=encoded_x.reshape(-1, 1),
    sampler=sampler,
)

In [15]:
# Translate the selected candidates back with the translator's decode method.
new_x_seq = translator.decode(new_x)

In [16]:
# Show the decoded batch proposed for the next round of evaluations.
print(new_x_seq)

['YRWWW', 'HLWWW', 'YWPWF']
