## Specify hyperparameter grid

In [15]:
import os
import itertools

prop_dict = {
#    0: 'parp1_qed_sa',
#    1: 'fa7_qed_sa',
    2: '5ht1b_qed_sa',
#    3: 'braf_qed_sa',
#    4: 'jak2_qed_sa',
}

guidance_weight_dict = {
    0: 0.5,
    1: 0.7,
    2: 0.3,
}

hyp_prior_likelihood_cov_scale = [1e-5, 1e-3, 1e-1, 1e0, 1e1, 1e3]
hyp_prior_likelihood_cov_diag = [1e-5, 1e-3, 1e-1, 1e0, 1e1, 1e3]
n_context_points = [256, 512]

hyper_id_list = [
    {
        "prior_likelihood_cov_scale": plcs,
        "prior_likelihood_cov_diag": plcd,
        "n_context_points": ncp,
    }
    for plcs, plcd, ncp in itertools.product(
        hyp_prior_likelihood_cov_scale,
        hyp_prior_likelihood_cov_diag,
        n_context_points,
    )
    if not (
        plcs != 0 and plcd / plcs < 1e-3
    )  # runs with low diag to scale ration always fail
]

print(f"Number of hyperparameter combinations: {len(hyper_id_list) - 1}")

Number of hyperparameter combinations: 57


## Remove any existing configs

In [18]:
prop_train_configs = [f for f in os.listdir(os.getcwd()) if f.endswith('.yaml')]
for f in prop_train_configs:
    os.remove(f)

## Generate new configs

In [17]:
keys = prop_dict.keys()
for i in list(keys):
    for j, hypers in enumerate(hyper_id_list):
        config = f"""data: 
    data: ZINC250k
    context: ZINC500k
    dir: './data'
    batch_size: 1024
    context_size: {hypers['n_context_points']}
    max_node_num: 38
    max_feat_num: 9

sde:
    x:
        type: VP
        beta_min: 0.1
        beta_max: 1.0
        num_scales: 1000
    adj:
        type: VE
        beta_min: 0.2
        beta_max: 1.0
        num_scales: 1000

model:
    model: Regressor
    depth: 3
    nhid: 16
    dropout: 0

train:
    prop: {prop_dict[i]}
    num_epochs: 10
    lr: 0.001
    lr_schedule: False
    reg_type: fseb
    weight_decay: 0
    prior_likelihood_cov_diag: {hypers['prior_likelihood_cov_diag']:1.1e}
    prior_likelihood_cov_scale: {hypers['prior_likelihood_cov_scale']:1.1e}
    lr_decay: 0.999
    eps: 1.0e-5"""

        with open(f'prop_train_{i}_{j}.yaml', 'w') as f:
            f.write(config)
    
    for k in range(len(guidance_weight_dict)):
        sample_config=f"""data:
  data: ZINC250k
  dir: './data'

model:
  diff:
    ckpt: gdss_zinc250k_v2
    predictor: Reverse
    corrector: Langevin
    snr: 0.2
    scale_eps: 0.8
    n_steps: 1
  prop:
    ckpt: fseb_ZINC500k/prop_{prop_dict[i]}
    weight_x: {guidance_weight_dict[k]}   # 0.5, 0.4, 0.6, 0.7, and 0.6 for parp1, fa7, 5ht1b, braf, and jak2, respectively
    weight_adj: 0

sample:
  noise_removal: True
  probability_flow: False
  eps: 1.0e-3
  num_samples: 3000
  ood: 0.04
        """

        with open(f'sample_{i}_{k}.yaml', 'w') as f:
            f.write(sample_config)