In [1]:
import sys
sys.path.append("../")
from qmg.generator import MoleculeGenerator
from qmg.utils import ConditionalWeightsGenerator, FitnessCalculatorWrapper
from rdkit import RDLogger
import numpy as np
import pandas as pd

from ax.service.ax_client import AxClient
from ax.service.utils.instantiation import ObjectiveProperties
from ax.modelbridge.generation_strategy import GenerationStrategy, GenerationStep
from ax.modelbridge.registry import Models
from ax import SearchSpace, ParameterType, RangeParameter
from ax.core.observation import ObservationFeatures
from ax.core.arm import Arm
import torch

In [21]:
# Task configuration: the metrics to optimise, their directions, and the size
# of the search space (parameters x1 ... xN).
task_name = "unconditional_5"
task_list = ["uniqueness", "validity"]
objective_list = ["maximize", "maximize"]
number_flexible_parameters = 44

# Results of the previous BO run for this task. The file name is derived from
# task_name so the two cannot drift apart when the task is changed.
previous_csv_path = f"../results_chemistry_constraint_bo/{task_name}.csv"
previous_data = pd.read_csv(previous_csv_path)

# Fail early if the CSV does not match the configuration above: the warm-start
# loop looks up rows by exactly these column names.
expected_columns = task_list + [f"x{k+1}" for k in range(number_flexible_parameters)]
missing_columns = [col for col in expected_columns if col not in previous_data.columns]
assert not missing_columns, f"previous CSV is missing columns: {missing_columns}"

previous_data

Unnamed: 0,trial_index,arm_name,trial_status,generation_method,uniqueness,validity,x1,x2,x3,x4,...,x35,x36,x37,x38,x39,x40,x41,x42,x43,x44
0,0,0_0,COMPLETED,Sobol,0.045676,0.3897,0.997513,0.104366,0.822979,0.419432,...,0.017679,0.318375,0.428213,0.382799,0.169766,0.086547,0.408415,0.358370,0.470249,0.487025
1,1,1_0,COMPLETED,Sobol,0.030492,0.9609,0.055046,0.653015,0.345471,0.708022,...,0.526224,0.932867,0.688492,0.720725,0.620063,0.645194,0.558199,0.899542,0.573572,0.573467
2,2,2_0,COMPLETED,Sobol,0.096405,0.4896,0.444051,0.301033,0.535844,0.068764,...,0.785372,0.680072,0.181371,0.232924,0.370444,0.252204,0.152236,0.672294,0.108700,0.171115
3,3,3_0,COMPLETED,Sobol,0.048666,0.5733,0.503528,0.939648,0.015434,0.810632,...,0.293916,0.067512,0.951904,0.886209,0.918409,0.951613,0.752246,0.070024,0.942933,0.772176
4,4,4_0,COMPLETED,Sobol,0.028777,0.7645,0.668296,0.408161,0.397002,0.941006,...,0.719341,0.823247,0.615294,0.840088,0.426103,0.813692,0.908571,0.853561,0.316958,0.913298
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
200,200,200_0,COMPLETED,GPEI,0.188904,0.9084,0.154644,0.932217,0.272021,0.257247,...,0.665534,0.770794,0.609409,0.417269,0.094133,0.000000,0.155573,1.000000,0.821068,1.000000
201,201,201_0,COMPLETED,GPEI,0.058273,0.9610,0.000000,0.735314,0.506411,0.572102,...,0.669469,0.864354,0.886785,1.000000,0.631551,0.434430,0.002487,1.000000,0.707513,1.000000
202,202,202_0,COMPLETED,GPEI,0.133911,0.9230,0.162387,0.830371,0.278899,0.337194,...,0.667949,0.778031,0.942415,0.000000,0.021712,0.000000,0.205354,1.000000,0.804467,1.000000
203,203,203_0,COMPLETED,GPEI,0.198453,0.9050,0.157859,0.844276,0.275767,0.340950,...,0.677747,0.784094,1.000000,0.000000,0.022551,0.000000,0.202811,1.000000,0.797313,1.000000


In [19]:
# Model used for every trial generated from here on. There is no Sobol
# initialisation step: the experiment is warm-started with the trials attached
# from `previous_data` in the next cell.
# NOTE(review): two objectives are configured below, but GPEI is a
# single-objective acquisition, so Ax presumably combines the objectives rather
# than searching for a Pareto front. 'MOO' may be the intended model. Confirm
# before changing, since the previous run (generation_method column) also used GPEI.
bo_model_name = "GPEI"
model_dict = {'MOO': Models.MOO, 'GPEI': Models.GPEI, 'SAASBO': Models.SAASBO,}

# zip() below would silently drop unmatched entries, and any direction other
# than "minimize" would silently become "maximize". Check both up front.
assert len(task_list) == len(objective_list), "task_list and objective_list must have the same length"
invalid_directions = sorted(set(objective_list) - {"minimize", "maximize"})
assert not invalid_directions, f"unknown objective direction(s): {invalid_directions}"

gs = GenerationStrategy(
    steps=[
        GenerationStep(
            model=model_dict[bo_model_name],
            num_trials=-1,  # no limit on how many trials this step may produce
            max_parallelism=1,  # at most one trial in flight at a time (sequential BO)
            model_kwargs={
                "torch_dtype": torch.float64,
                "torch_device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
            },
        ),
    ]
)
ax_client = AxClient(random_seed=42, generation_strategy=gs)  # fixed seed for reproducible BO
ax_client.create_experiment(
    name=task_name,
    parameters=[
        {
            "name": f"x{i+1}",
            "type": "range",
            "bounds": [0.0, 1.0],
            "value_type": "float",
        }
        for i in range(number_flexible_parameters)
    ],
    objectives={
        task: ObjectiveProperties(minimize=(objective == "minimize"))
        for task, objective in zip(task_list, objective_list)
    },
    overwrite_existing_experiment=True,
    is_test=True,
)
ax_client.get_trials_data_frame()

[INFO 11-19 00:48:31] ax.service.ax_client: Starting optimization with verbose logging. To disable logging, set the `verbose_logging` argument to `False`. Note that float values in the logs are rounded to 6 decimal points.
[INFO 11-19 00:48:31] ax.service.utils.instantiation: Due to non-specification, we will use the heuristic for selecting objective thresholds.
[INFO 11-19 00:48:31] ax.service.utils.instantiation: Created search space: SearchSpace(parameters=[RangeParameter(name='x1', parameter_type=FLOAT, range=[0.0, 1.0]), RangeParameter(name='x2', parameter_type=FLOAT, range=[0.0, 1.0]), RangeParameter(name='x3', parameter_type=FLOAT, range=[0.0, 1.0]), RangeParameter(name='x4', parameter_type=FLOAT, range=[0.0, 1.0]), RangeParameter(name='x5', parameter_type=FLOAT, range=[0.0, 1.0]), RangeParameter(name='x6', parameter_type=FLOAT, range=[0.0, 1.0]), RangeParameter(name='x7', parameter_type=FLOAT, range=[0.0, 1.0]), RangeParameter(name='x8', parameter_type=FLOAT, range=[0.0, 1.0]),

Unnamed: 0,mean,sem,arm_name,metric_name


In [20]:
# Warm-start the experiment with every point evaluated in the previous run.
# Each CSV row becomes a manually attached trial, completed with its recorded
# objective values.
parameter_names = [f"x{k+1}" for k in range(number_flexible_parameters)]
for _, row in previous_data.iterrows():
    input_parameters = row[parameter_names].to_dict()
    output_response = row[task_list].to_dict()
    # attach_trial returns the index Ax assigned to the new trial. Use it rather
    # than the DataFrame row label, which only coincides when the experiment
    # starts empty and the CSV index is exactly 0..N-1.
    _, attached_trial_index = ax_client.attach_trial(input_parameters)
    ax_client.complete_trial(trial_index=attached_trial_index, raw_data=output_response)

ax_client.get_trials_data_frame()

[INFO 11-19 00:48:33] ax.core.experiment: Attached custom parameterizations [{'x1': 0.997513, 'x2': 0.104366, 'x3': 0.822979, 'x4': 0.419432, 'x5': 0.528322, 'x6': 0.056241, 'x7': 0.455292, 'x8': 0.749203, 'x9': 0.221074, 'x10': 0.617698, 'x11': 0.810095, 'x12': 0.212291, 'x13': 0.038436, 'x14': 0.881824, 'x15': 0.65063, 'x16': 0.666585, 'x17': 0.820932, 'x18': 0.596579, 'x19': 0.236308, 'x20': 0.806974, 'x21': 0.167144, 'x22': 0.767845, 'x23': 0.603483, 'x24': 0.932567, 'x25': 0.537682, 'x26': 0.398814, 'x27': 0.361339, 'x28': 0.289725, 'x29': 0.653567, 'x30': 0.234766, 'x31': 0.37597, 'x32': 0.322637, 'x33': 0.992174, 'x34': 0.348451, 'x35': 0.017679, 'x36': 0.318375, 'x37': 0.428213, 'x38': 0.382799, 'x39': 0.169766, 'x40': 0.086547, 'x41': 0.408415, 'x42': 0.35837, 'x43': 0.470249, 'x44': 0.487025}] as trial 0.
[INFO 11-19 00:48:33] ax.service.ax_client: Completed trial 0 with data: {'uniqueness': (0.045676, None), 'validity': (0.3897, None)}.
[INFO 11-19 00:48:33] ax.core.experime

[INFO 11-19 00:48:33] ax.core.experiment: Attached custom parameterizations [{'x1': 0.0, 'x2': 0.673957, 'x3': 0.304884, 'x4': 0.251053, 'x5': 0.261592, 'x6': 0.89568, 'x7': 0.487664, 'x8': 0.107787, 'x9': 0.717469, 'x10': 0.448258, 'x11': 0.085986, 'x12': 0.853217, 'x13': 0.838008, 'x14': 0.266435, 'x15': 0.0, 'x16': 0.097061, 'x17': 0.105365, 'x18': 0.539015, 'x19': 0.703425, 'x20': 0.426149, 'x21': 1.0, 'x22': 0.125257, 'x23': 0.443643, 'x24': 0.104994, 'x25': 0.483207, 'x26': 0.691284, 'x27': 0.950942, 'x28': 0.938622, 'x29': 0.0, 'x30': 0.901001, 'x31': 0.543153, 'x32': 1.0, 'x33': 0.160302, 'x34': 0.732708, 'x35': 0.637128, 'x36': 0.985485, 'x37': 0.662617, 'x38': 0.556142, 'x39': 0.632157, 'x40': 0.756546, 'x41': 0.459599, 'x42': 0.814234, 'x43': 0.451828, 'x44': 0.434732}] as trial 12.
[INFO 11-19 00:48:33] ax.service.ax_client: Completed trial 12 with data: {'uniqueness': (0.089136, None), 'validity': (0.9278, None)}.
[INFO 11-19 00:48:33] ax.core.experiment: Attached custom p

Unnamed: 0,trial_index,arm_name,trial_status,generation_method,uniqueness,validity,x1,x2,x3,x4,...,x35,x36,x37,x38,x39,x40,x41,x42,x43,x44
0,0,0_0,COMPLETED,Manual,0.045676,0.3897,0.997513,0.104366,0.822979,0.419432,...,0.017679,0.318375,0.428213,0.382799,0.169766,0.086547,0.408415,0.358370,0.470249,0.487025
1,1,1_0,COMPLETED,Manual,0.030492,0.9609,0.055046,0.653015,0.345471,0.708022,...,0.526224,0.932867,0.688492,0.720725,0.620063,0.645194,0.558199,0.899542,0.573572,0.573467
2,2,2_0,COMPLETED,Manual,0.096405,0.4896,0.444051,0.301033,0.535844,0.068764,...,0.785372,0.680072,0.181371,0.232924,0.370444,0.252204,0.152236,0.672294,0.108700,0.171115
3,3,3_0,COMPLETED,Manual,0.048666,0.5733,0.503528,0.939648,0.015434,0.810632,...,0.293916,0.067512,0.951904,0.886209,0.918409,0.951613,0.752246,0.070024,0.942933,0.772176
4,4,4_0,COMPLETED,Manual,0.028777,0.7645,0.668296,0.408161,0.397002,0.941006,...,0.719341,0.823247,0.615294,0.840088,0.426103,0.813692,0.908571,0.853561,0.316958,0.913298
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
200,200,200_0,COMPLETED,Manual,0.188904,0.9084,0.154644,0.932217,0.272021,0.257247,...,0.665534,0.770794,0.609409,0.417269,0.094133,0.000000,0.155573,1.000000,0.821068,1.000000
201,201,201_0,COMPLETED,Manual,0.058273,0.9610,0.000000,0.735314,0.506411,0.572102,...,0.669469,0.864354,0.886785,1.000000,0.631551,0.434430,0.002487,1.000000,0.707513,1.000000
202,202,202_0,COMPLETED,Manual,0.133911,0.9230,0.162387,0.830371,0.278899,0.337194,...,0.667949,0.778031,0.942415,0.000000,0.021712,0.000000,0.205354,1.000000,0.804467,1.000000
203,203,203_0,COMPLETED,Manual,0.198453,0.9050,0.157859,0.844276,0.275767,0.340950,...,0.677747,0.784094,1.000000,0.000000,0.022551,0.000000,0.202811,1.000000,0.797313,1.000000


In [22]:
# Ask the generation strategy for the next point to evaluate.
# `parameters` maps each x-parameter name to its proposed value, and
# `trial_index` is the index Ax assigned to the newly created trial.
(parameters, trial_index) = ax_client.get_next_trial()

[INFO 11-19 00:58:46] ax.service.ax_client: Generated new trial 205 with parameters {'x1': 0.181411, 'x2': 0.803432, 'x3': 0.268268, 'x4': 0.364369, 'x5': 0.286539, 'x6': 0.71173, 'x7': 0.401039, 'x8': 0.025733, 'x9': 0.584227, 'x10': 0.638681, 'x11': 0.0, 'x12': 0.840831, 'x13': 0.84978, 'x14': 0.005248, 'x15': 0.0, 'x16': 0.164393, 'x17': 0.559678, 'x18': 0.543222, 'x19': 0.859382, 'x20': 0.501981, 'x21': 1.0, 'x22': 0.0, 'x23': 0.866622, 'x24': 0.178469, 'x25': 1.0, 'x26': 0.828297, 'x27': 1.0, 'x28': 0.0, 'x29': 0.083502, 'x30': 0.583558, 'x31': 0.53998, 'x32': 1.0, 'x33': 0.311643, 'x34': 0.728216, 'x35': 0.661124, 'x36': 0.756491, 'x37': 1.0, 'x38': 0.0, 'x39': 0.0, 'x40': 0.0, 'x41': 0.304406, 'x42': 1.0, 'x43': 0.812148, 'x44': 1.0} using model GPEI.


In [23]:
# Echo the proposed parameterization and its trial index, one per line.
print(parameters, trial_index, sep="\n")

{'x1': 0.18141052270579935, 'x2': 0.8034324904497651, 'x3': 0.26826832763401554, 'x4': 0.36436946408574733, 'x5': 0.2865386872817858, 'x6': 0.7117302621912257, 'x7': 0.40103904330271084, 'x8': 0.025733064781228568, 'x9': 0.5842265181705965, 'x10': 0.6386806151272885, 'x11': 0.0, 'x12': 0.8408313379124558, 'x13': 0.8497801082983985, 'x14': 0.0052481531559081165, 'x15': 0.0, 'x16': 0.16439276659174346, 'x17': 0.5596782792900187, 'x18': 0.5432224112177131, 'x19': 0.8593819711430474, 'x20': 0.50198116295897, 'x21': 1.0, 'x22': 0.0, 'x23': 0.8666217759782995, 'x24': 0.17846908614861848, 'x25': 1.0, 'x26': 0.8282973482748601, 'x27': 1.0, 'x28': 0.0, 'x29': 0.08350155833726731, 'x30': 0.5835577865998803, 'x31': 0.5399798317387333, 'x32': 1.0, 'x33': 0.31164295997967906, 'x34': 0.7282156741508381, 'x35': 0.6611238704433996, 'x36': 0.7564905454970923, 'x37': 1.0, 'x38': 0.0, 'x39': 0.0, 'x40': 0.0, 'x41': 0.3044056674308849, 'x42': 1.0, 'x43': 0.8121482376628346, 'x44': 1.0}
205
