In [26]:
import torch
import os
import sys

import numpy as np
from tqdm import tqdm
cwd = os.getcwd()
parent_dir = os.path.dirname(cwd)
sys.path.append(parent_dir)

from DataPipeline.preprocessing import node_encoder, tensor_to_smiles
from generation import Sampling_Path_Batch
from models import Model_GNNs

from dataclasses import dataclass

In [27]:
@dataclass
class Experiment:
    exp_name: str
    encod: str
    keku: bool
    train: bool
    encoding_size: int = 13
    edge_size: int = 3
    encoding_option: str = 'charged'
    compute_lambdas: bool = False

exp = Experiment('GNN_baseline_3_modif', 'charged', True, False, 13, 3, 'charged', True)
exp2 = Experiment('GNN_baseline_3_modif_debiased', 'charged', True, False, 13, 3, 'charged', True)

In [28]:
from rdkit import Chem
from rdkit.Chem import Descriptors

def logP(smiles_list):
    logP_values = []
    for smiles in smiles_list:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            logP_values.append(0)
        else: 
            logp = Descriptors.MolLogP(mol)
            if logp > 2.0 and logp < 2.5:
                logP_values.append(1)
            else:
                logP_values.append(0)
    return torch.tensor(logP_values)

def QED(smiles_list):
    qed_values = []
    for smiles in smiles_list:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            qed_values.append(0)
        else: 
            qed = Descriptors.qed(mol)
            if qed > 0.90:
                qed_values.append(1)
            else:
                qed_values.append(0)
    return torch.tensor(qed_values)

In [29]:
GNNs_q = Model_GNNs(exp)
GNNs_a = Model_GNNs(exp2)
GNNs_pi = Model_GNNs(exp2)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
Module_Gen = Sampling_Path_Batch(GNNs_q, GNNs_a, GNNs_pi, features = {'logP' : logP, 'QED' : QED}, lambdas = torch.Tensor([1.0, 1.0]), device = device, args=exp)


Loading best checkpoint number 1 of the epoch 2050.0 with a loss of 0.5107812829972882
Loading best checkpoint number 1 of the epoch 2550.0 with a loss of 0.1984780294417413
Loading best checkpoint number 2 of the epoch 2700.0 with a loss of 10.3904066518488
..\trained_models\GNN_baseline_3_modif\GNN1_baseline\history_training\checkpoint_1.pt ..\trained_models\GNN_baseline_3_modif\GNN2_baseline\history_training\checkpoint_1.pt ..\trained_models\GNN_baseline_3_modif\GNN3_split_two_without_node_embedding\history_training\checkpoint_2.pt


Loading best checkpoint number 1 of the epoch 2950.0 with a loss of 0.5103764619447081
Loading best checkpoint number 2 of the epoch 2150.0 with a loss of 0.19804639930989767
Loading best checkpoint number 2 of the epoch 2700.0 with a loss of 10.3904066518488
..\trained_models\GNN_baseline_3_modif_debiased\GNN1_charged_baseline_debiased\history_training\checkpoint_1.pt ..\trained_models\GNN_baseline_3_modif_debiased\GNN2_charged_baseline_debiased\history_training\checkpoint_2.pt ..\trained_models\GNN_baseline_3_modif_debiased\GNN3_split_two_without_node_embedding\history_training\checkpoint_2.pt
Loading best checkpoint number 1 of the epoch 2950.0 with a loss of 0.5103764619447081
Loading best checkpoint number 2 of the epoch 2150.0 with a loss of 0.19804639930989767
Loading best checkpoint number 2 of the epoch 2700.0 with a loss of 10.3904066518488
..\trained_models\GNN_baseline_3_modif_debiased\GNN1_charged_baseline_debiased\history_training\checkpoint_1.pt ..\trained_models\GNN_bas

In [30]:
def compute_optimal_lambdas(Module_Gen, desired_moments, sample_size=5, n_iters=1000, lr=.5, min_nabla_lambda = 0.001, batch_size = 1000): #how do they define the learning rate and sample size maybe do more
        """
        This performs the first step: Constraints --> EBM through self-normalized importance sampling. 
        Args:
            sample_size: total number of samples to use for lambda computation
        Returns:
            dicitonary of optimal lambdas per constraint: {'black': lambda_1, 'positive': lambda_2}
        """


        print("Computing Optimal Lambdas for desired moments...")

        max_n_iters = n_iters

        feature_names = list(Module_Gen.features.keys())
        mu_star = desired_moments #name mu_bar in the pseudo code

        mu_star = torch.tensor([mu_star[f] for f in feature_names])
        lambdas = Module_Gen.lambdas.cpu()

        # Collect sample_size samples for this:
        list_feature_tensor = []
        for i in tqdm(range(sample_size)):
            #if we do multi processing, i think here the best 
            #put pi without grad for lambdas

            Module_Gen.full_generation(batch_size = batch_size)
            Module_Gen.convert_to_smiles()
            Module_Gen.compute_features()

            q_value = Module_Gen.q_value
            a_value = Module_Gen.a_value
            pi_value = Module_Gen.pi_value

            batch_features_values = Module_Gen.all_features_values

            Module_Gen.clean_memory()
            
            list_feature_tensor.append(batch_features_values)

        all_feature_tensor = torch.cat(list_feature_tensor, dim=0)  # [sample_sz*size_batch x F]

        print("mean logP : ", all_feature_tensor[:, 0].mean())
        print("mean QED : ", all_feature_tensor[:, 1].mean())

        #### check for zero-occuring features. 
        # If a constraint has not occurred in your sample, no lambdas will be learned for that constraint, so we must check.

        for i, feature  in enumerate(feature_names):
            assert all_feature_tensor[:, i].sum().item() > 0, "constraint {feature} hasn't occurred in the samples, use a larger sample size"

        for step in range(max_n_iters): #SGD for finding lambdas

            # 1. calculate P_over_q batch wise with current lambdas which will be name w
            ## compute new exponents

            w = torch.exp(torch.matmul(all_feature_tensor, lambdas.to(all_feature_tensor.device)))
            print(w.shape)
            print(w)
            print(all_feature_tensor.shape)
            print(all_feature_tensor)
            print(lambdas.shape)
            print(lambdas)

            # 2. compute mu (mean) of features given the current lambda using SNIS
            mu_lambda_numerator = w.view(1, -1).matmul(all_feature_tensor).squeeze(0) # F
            print(mu_lambda_numerator.shape)
            print(mu_lambda_numerator)
            mu_lambda_denominator = w.sum()
            mu_lambda = mu_lambda_numerator / mu_lambda_denominator # F

            # 3. Update current Lambdas
            nabla_lambda = mu_star - mu_lambda.cpu()
            err = np.linalg.norm(nabla_lambda.cpu().numpy())
            print("step: %s \t ||nabla_lambda|| = %.6f" %(step, err))
            lambdas = lambdas + lr * nabla_lambda
            print("\tlambdas : {} ".format(Module_Gen.lambdas))
            print("\tμ: {}".format(mu_lambda))
            print("\tμ*: {}".format(mu_star))

            Module_Gen.lambdas = lambdas
            
            ## Check if error is less than tolerance, then break.
            if err < min_nabla_lambda: 
                break

In [31]:
compute_optimal_lambdas(Module_Gen, {'logP': 1.0, 'QED': 1.0}, sample_size=5, n_iters=10000, lr=.2, min_nabla_lambda = 0.0001, batch_size=1000)

Computing Optimal Lambdas for desired moments...


  0%|          | 0/5 [00:00<?, ?it/s]

[16:54:26] Explicit valence for atom # 17 O, 3, is greater than permitted
[16:54:27] Explicit valence for atom # 17 O, 3, is greater than permitted
 40%|████      | 2/5 [02:07<03:14, 64.70s/it][16:56:44] Explicit valence for atom # 18 O, 3, is greater than permitted
[16:56:45] Explicit valence for atom # 18 O, 3, is greater than permitted
100%|██████████| 5/5 [05:59<00:00, 71.93s/it]


mean logP :  tensor(0.1244, device='cuda:0')
mean QED :  tensor(0.0676, device='cuda:0')
torch.Size([5000])
tensor([1., 1., 1.,  ..., 1., 1., 1.], device='cuda:0')
torch.Size([5000, 2])
tensor([[0., 0.],
        [0., 0.],
        [0., 0.],
        ...,
        [0., 0.],
        [0., 0.],
        [0., 0.]], device='cuda:0')
torch.Size([2])
tensor([1., 1.])
torch.Size([2])
tensor([1989.7008, 1217.7089], device='cuda:0')
step: 0 	 ||nabla_lambda|| = 1.085504
	lambdas : tensor([1., 1.], device='cuda:0') 
	μ: tensor([0.2910, 0.1781], device='cuda:0')
	μ*: tensor([1., 1.])
torch.Size([5000])
tensor([1., 1., 1.,  ..., 1., 1., 1.], device='cuda:0')
torch.Size([5000, 2])
tensor([[0., 0.],
        [0., 0.],
        [0., 0.],
        ...,
        [0., 0.],
        [0., 0.],
        [0., 0.]], device='cuda:0')
torch.Size([2])
tensor([1.1418, 1.1644])
torch.Size([2])
tensor([2390.2107, 1520.1995], device='cuda:0')
step: 1 	 ||nabla_lambda|| = 1.042485
	lambdas : tensor([1.1418, 1.1644]) 
	μ: tensor

In [32]:
tensor([7.6847, 8.2535])  5000
tensor([7.6169, 8.3329])   6000
lambdas : tensor([7.8733, 8.4059])  1000

SyntaxError: invalid syntax (794465100.py, line 1)