In [None]:
import logging
import sys, os
import time
import warnings
import torch
import argparse
from pathlib import Path
import numpy as np
from datetime import datetime
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.nn import MSELoss, L1Loss
from torch.serialization import save
from gnn.model.metric import EarlyStopping
from gnn.model.gated_solv_network import CrossAttention
from gnn.data.dataset import SolvationDataset, train_validation_test_split, load_mols_labels
from gnn.data.dataloader import DataLoaderSolvation
from gnn.data.grapher import HeteroMoleculeGraph
from gnn.data.featurizer import (SolventAtomFeaturizer, BondAsNodeFeaturizerFull, SolventGlobalFeaturizer)
from gnn.utils import (load_checkpoints,save_checkpoints,seed_torch, pickle_dump, yaml_dump)
from sklearn.metrics import mean_squared_error


  from .autonotebook import tqdm as notebook_tqdm
Using backend: pytorch


In [None]:
# ---------------------------------------------------------------------------
# Run configuration. Every later cell reads these module-level names.
# ---------------------------------------------------------------------------

# Seed used for torch seeding, the train/val/test split and the results filename.
# It must be defined here: leaving it commented out makes a fresh-kernel
# "Run All" fail with NameError in the first cell that uses it.
random_seed = 2

# Input / output locations
dataset_file = "./input_data/FreeSolv.csv"
save_dir = "./outputnew"
output_file = "results.pkl"
dataset_state_dict_filename = None

# Optional global features for the molecule graphs
dielectric_constants = None
molecular_volume = False
molecular_refractivity = False

# Data handling
feature_scaling = True
solvent_split = None
element_split = None
scaffold_split = False

# Run mode
attention_map = False
distributed = 0
restore = 0
batch_size = 100
embedding_size = 24

# gated layer
gated_num_layers = 3
gated_hidden_size = [192, 192, 192]
gated_num_fc_layers = 2
gated_graph_norm = 1
gated_batch_norm = 1
gated_activation = "LeakyReLU"
gated_residual = 1
gated_dropout = 0.4

# readout layer
num_lstm_iters = 6   # number of iterations for the LSTM in set2set readout layer
num_lstm_layers = 3  # number of layers for the LSTM in set2set readout layer

# fc layer
fc_num_layers = 2
fc_hidden_size = [64, 32]
fc_batch_norm = 0
fc_activation = "LeakyReLU"
fc_dropout = 0.2

# Optimisation schedule
start_epoch = 0
epochs = 1000
lr = 0.001
weight_decay = 0.0

In [18]:
def grapher(dielectric_constant=None, mol_volume=False, mol_refract=False):
    """Build a ``HeteroMoleculeGraph`` from atom, bond and global featurizers.

    Args:
        dielectric_constant: forwarded unchanged to ``SolventGlobalFeaturizer``.
        mol_volume: forwarded unchanged to ``SolventGlobalFeaturizer``.
        mol_refract: forwarded unchanged to ``SolventGlobalFeaturizer``.

    Returns:
        A ``HeteroMoleculeGraph`` created with ``self_loop=True``.
    """
    # Featurizers are constructed in atom -> bond -> global order, then handed
    # straight to the graph builder.
    return HeteroMoleculeGraph(
        SolventAtomFeaturizer(),
        BondAsNodeFeaturizerFull(length_featurizer=None, dative=False),
        SolventGlobalFeaturizer(
            dielectric_constant=dielectric_constant,
            mol_volume=mol_volume,
            mol_refract=mol_refract,
        ),
        self_loop=True,
    )

In [None]:
# Seed before any data handling so the run is repeatable.
seed_torch(random_seed)

mols, labels = load_mols_labels(dataset_file)

# Solute and solvent each get their own grapher instance, configured identically.
dataset = SolvationDataset(
    solute_grapher=grapher(
        mol_volume=molecular_volume,
        mol_refract=molecular_refractivity,
    ),
    solvent_grapher=grapher(
        mol_volume=molecular_volume,
        mol_refract=molecular_refractivity,
    ),
    molecules=mols,
    labels=labels,
    solute_extra_features=None,
    solvent_extra_features=None,
    feature_transformer=False,
    label_transformer=False,
    state_dict_filename=dataset_state_dict_filename,
)

In [20]:
# Save the solute and solvent graphers for loading datasets later.
# The output directory is created here because this cell runs before the cell
# that otherwise creates `save_dir`; `exist_ok=True` makes the call harmless
# when the directory is already present.
os.makedirs(save_dir, exist_ok=True)
pickle_dump([dataset.solute_grapher, dataset.solvent_grapher], os.path.join(save_dir, "graphers.pkl"))

In [21]:
# Bare last expression: the notebook renders the dataset's repr as the cell output.
dataset

Dataset SolvationDataset
Length: 642
Solvent feature: atom, size: 28
Solvent feature: bond, size: 11
Solvent feature: global, size: 3
Solute feature: atom, size: 28
Solute feature: bond, size: 11
Solute feature: global, size: 3
Solute feature: atom, name: ['total degree', 'partial/formal charge', 'is aromatic', 'is in ring', 'num lone pairs', 'num total H', 'H bond acceptor', 'H bond donor', 'chemical symbol', 'chemical symbol', 'chemical symbol', 'chemical symbol', 'chemical symbol', 'chemical symbol', 'chemical symbol', 'chemical symbol', 'chemical symbol', 'chemical symbol', 'hybridization', 'hybridization', 'hybridization', 'hybridization', 'ring size', 'ring size', 'ring size', 'ring size', 'ring size']
Solute feature: bond, name: ['in_ring', 'conjugated', 'ring size', 'ring size', 'ring size', 'ring size', 'ring size', 'single', 'double', 'triple', 'aromatic']
Solute feature: global, name: ['num atoms', 'num bonds', 'molecule weight']
Solvent feature: atom, name: ['total degree',

In [None]:
# Best validation score seen so far; starts at the largest float32 so that any
# real score compares as an improvement (lower is better in the training loop).
best = np.finfo(np.float32).max
os.makedirs(save_dir, exist_ok=True)

if (solvent_split is None) and (element_split is None) and (scaffold_split is False):
    print(f'Splitting data using random seed {random_seed}')
    trainset, valset, testset = train_validation_test_split(dataset, validation=0.1, test=0.1, random_seed=random_seed)
else:
    # Only the random split is implemented in this notebook. Fail here with a
    # clear message instead of leaving trainset/valset/testset undefined and
    # hitting an unrelated NameError in a later cell.
    raise NotImplementedError(
        "Only the random train/validation/test split is implemented in this notebook; "
        "set solvent_split=None, element_split=None and scaffold_split=False."
    )


Splitting data using random seed 2


In [23]:
# Scale features: fit the scalers on the training set, then apply those same
# scalers to the validation and test sets.
if feature_scaling:
    solute_features_scaler, solvent_features_scaler = trainset.normalize_features()
    valset.normalize_features(solute_features_scaler, solvent_features_scaler)
    testset.normalize_features(solute_features_scaler, solvent_features_scaler)
else:
    solute_features_scaler, solvent_features_scaler = None, None

# Labels are normalised on the training set only; the returned scaler is passed
# to evaluate() later to inverse-transform predictions.
label_scaler = trainset.normalize_labels()

# Write the dataset state next to every other artifact of this run. The path
# was previously hard-coded to "./output/", which is not `save_dir` and is
# never created by this notebook.
os.makedirs(save_dir, exist_ok=True)
torch.save(dataset.state_dict(), os.path.join(save_dir, "dataset_state_dict.pkl"))
print( "Trainset size: {}, valset size: {}: testset size: {}.".format(len(trainset), len(valset), len(testset)))

# No distributed sampler is used in this notebook; None lets the training
# loader fall back to plain shuffling.
train_sampler = None

Standard deviation for feature 1 is 0.0, smaller than 0.001. You may want to exclude this feature.
Standard deviation for feature 2 is 0.0, smaller than 0.001. You may want to exclude this feature.
Standard deviation for feature 3 is 0.0, smaller than 0.001. You may want to exclude this feature.
Standard deviation for feature 6 is 0.0, smaller than 0.001. You may want to exclude this feature.
Standard deviation for feature 7 is 0.0, smaller than 0.001. You may want to exclude this feature.
Standard deviation for feature 8 is 0.0, smaller than 0.001. You may want to exclude this feature.
Standard deviation for feature 9 is 0.0, smaller than 0.001. You may want to exclude this feature.
Standard deviation for feature 10 is 0.0, smaller than 0.001. You may want to exclude this feature.
Standard deviation for feature 11 is 0.0, smaller than 0.001. You may want to exclude this feature.
Standard deviation for feature 13 is 0.0, smaller than 0.001. You may want to exclude this feature.
Standar

Trainset size: 514, valset size: 64: testset size: 64.


In [None]:
# Training loader: shuffle only when no explicit sampler is supplied.
train_loader = DataLoaderSolvation(
    trainset,
    batch_size=batch_size,
    shuffle=(train_sampler is None),
    sampler=train_sampler,
)

# Validation and test loaders use roughly ten batches each (at least one
# sample per batch) and keep the original sample order.
bs = max(len(valset) // 10, 1)
val_loader = DataLoaderSolvation(valset, batch_size=bs, shuffle=False)

bs = max(len(testset) // 10, 1)
test_loader = DataLoaderSolvation(testset, batch_size=bs, shuffle=False)

### model
# Node types whose features are fed to the network, and the node types passed
# to the model as `set2set_ntypes_direct`.
feature_names = ["atom", "bond", "global"]
set2set_ntypes_direct = ["global"]

# feature_sizes[0] is used for the solute, feature_sizes[1] for the solvent.
solute_feature_size, solvent_feature_size = dataset.feature_sizes[0], dataset.feature_sizes[1]

In [None]:
# Instantiate the network from the hyper-parameters in the configuration cell.
model = CrossAttention(
    solute_in_feats=solute_feature_size,
    solvent_in_feats=solvent_feature_size,
    embedding_size=embedding_size,
    # gated graph layers
    gated_num_layers=gated_num_layers,
    gated_hidden_size=gated_hidden_size,
    gated_num_fc_layers=gated_num_fc_layers,
    gated_graph_norm=gated_graph_norm,
    gated_batch_norm=gated_batch_norm,
    gated_activation=gated_activation,
    gated_residual=gated_residual,
    gated_dropout=gated_dropout,
    # set2set readout
    num_lstm_iters=num_lstm_iters,
    num_lstm_layers=num_lstm_layers,
    set2set_ntypes_direct=set2set_ntypes_direct,
    # fully connected head
    fc_num_layers=fc_num_layers,
    fc_hidden_size=fc_hidden_size,
    fc_batch_norm=fc_batch_norm,
    fc_activation=fc_activation,
    fc_dropout=fc_dropout,
    outdim=1,
    conv="GatedGCNConv",
)

Cross interaction map


In [27]:
# Optimiser over all model parameters.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=lr,
    weight_decay=weight_decay,
)

# MSE (mean) is the training loss; L1 with "sum" reduction is accumulated per
# batch and divided by the sample count in train()/evaluate().
loss_func = MSELoss(reduction="mean")
metric = L1Loss(reduction="sum")

# Scale the learning rate by 0.4 after 50 scheduler steps without improvement
# of the monitored value; early stopping is configured with patience 150.
scheduler = ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.4,
    patience=50,
    verbose=True,
)
stopper = EarlyStopping(patience=150)

# Objects handed to the checkpoint save/load helpers.
state_dict_objs = {
    "model": model,
    "optimizer": optimizer,
    "scheduler": scheduler,
}

In [28]:
def train(optimizer, model, nodes, data_loader, loss_fn, metric_fn):
    """Run one training epoch over ``data_loader``.

    Args:
        optimizer: optimizer whose ``zero_grad``/``step`` are called once per batch.
        model: network called as ``model(solute_graph, solvent_graph, solute_feats,
            solvent_feats, solute_norm_atom, solute_norm_bond, solvent_norm_atom,
            solvent_norm_bond)``.
        nodes: node-type names whose ``"feat"`` data is read from each batched graph.
        data_loader: iterable yielding ``(solute_graph, solvent_graph, label)``.
            ``label`` is indexed with ``"value"``, ``"solute_norm_atom"``,
            ``"solute_norm_bond"``, ``"solvent_norm_atom"`` and ``"solvent_norm_bond"``.
        loss_fn: callable ``(pred, target)`` returning the loss that is back-propagated.
        metric_fn: callable ``(pred, target)`` whose result is summed over batches.

    Returns:
        Tuple ``(epoch_loss, accuracy)``: the mean of the per-batch losses, and the
        summed metric divided by the total number of targets.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    # Put every batch on the device that already holds the model's weights, so
    # a model left on the CPU still works on a CUDA-capable machine. The old
    # "CUDA if available" rule is kept only for a model without parameters.
    try:
        device = next(model.parameters()).device
    except StopIteration:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model.train()
    epoch_loss = 0.0
    accuracy = 0.0
    count = 0
    num_batches = 0

    for solute_batched_graph, solvent_batched_graph, label in data_loader:
        solute_feats = {nt: solute_batched_graph.nodes[nt].data["feat"].to(device) for nt in nodes}
        solvent_feats = {nt: solvent_batched_graph.nodes[nt].data["feat"].to(device) for nt in nodes}
        target = torch.squeeze(label["value"]).to(device)

        # Normalisation tensors follow the features onto the device, matching
        # what evaluate() does with the same label entries.
        solute_norm_atom = label["solute_norm_atom"].to(device)
        solute_norm_bond = label["solute_norm_bond"].to(device)
        solvent_norm_atom = label["solvent_norm_atom"].to(device)
        solvent_norm_bond = label["solvent_norm_bond"].to(device)

        pred = model(solute_batched_graph, solvent_batched_graph, solute_feats, solvent_feats,
                     solute_norm_atom, solute_norm_bond, solvent_norm_atom, solvent_norm_bond)
        # Flatten both sides; squeeze() yields a 0-d target for a batch of one.
        pred = pred.view(-1)
        target = target.view(-1)

        loss = loss_fn(pred, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.detach().item()
        accuracy += metric_fn(pred, target).detach().item()
        count += len(target)
        num_batches += 1

    if num_batches == 0:
        # Previously this surfaced as an unbound loop variable.
        raise ValueError("train() received an empty data_loader; nothing to train on.")

    epoch_loss /= num_batches
    accuracy /= count

    return epoch_loss, accuracy


In [29]:
def evaluate(model, nodes, data_loader, metric_fn, scaler = None, return_preds=False):
    """Evaluate ``model`` on ``data_loader`` without tracking gradients.

    Args:
        model: network called with the same eight positional arguments as in training
            (two batched graphs, two feature dicts, four normalisation tensors).
        nodes: node-type names whose ``"feat"`` data is read from each batched graph.
        data_loader: iterable yielding ``(solute_graph, solvent_graph, label)``.
            ``label`` is indexed with ``"value"`` and the four ``*_norm_*`` keys.
        metric_fn: callable ``(pred, target)`` whose result is summed over batches.
        scaler: optional object with ``inverse_transform``; when given, predictions are
            passed through it before the metric is computed. Targets are used as-is.
        return_preds: when True, return the collected targets and predictions
            instead of the averaged metric.

    Returns:
        ``(y_true, preds)`` as flat Python lists when ``return_preds`` is True,
        otherwise the summed metric divided by the total number of targets.

    Raises:
        ValueError: if ``return_preds`` is False and ``data_loader`` yields no samples.
    """
    # Use the device that already holds the model's weights, so a model left
    # on the CPU still works on a CUDA-capable machine. The old
    # "CUDA if available" rule is kept only for a model without parameters.
    try:
        device = next(model.parameters()).device
    except StopIteration:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model.eval()

    accuracy = 0.0
    count = 0
    preds = []
    y_true = []

    with torch.no_grad():
        for solute_batched_graph, solvent_batched_graph, label in data_loader:
            solute_feats = {nt: solute_batched_graph.nodes[nt].data["feat"].to(device) for nt in nodes}
            solvent_feats = {nt: solvent_batched_graph.nodes[nt].data["feat"].to(device) for nt in nodes}
            target = torch.squeeze(label["value"]).to(device)
            solute_norm_atom = label["solute_norm_atom"].to(device)
            solute_norm_bond = label["solute_norm_bond"].to(device)
            solvent_norm_atom = label["solvent_norm_atom"].to(device)
            solvent_norm_bond = label["solvent_norm_bond"].to(device)

            pred = model(solute_batched_graph, solvent_batched_graph, solute_feats,
                         solvent_feats, solute_norm_atom, solute_norm_bond,
                         solvent_norm_atom, solvent_norm_bond)
            # Flatten both sides; squeeze() yields a 0-d target for a batch of one.
            pred = pred.view(-1)
            target = target.view(-1)

            # Inverse scale: the scaler is applied on CPU, then the result is
            # moved back so it sits on the same device as the target.
            if scaler is not None:
                pred = scaler.inverse_transform(pred.cpu())
                pred = pred.to(device)

            accuracy += metric_fn(pred, target).detach().item()
            count += len(target)

            preds.extend(pred.tolist())
            y_true.extend(target.tolist())

    if return_preds:
        return y_true, preds

    if count == 0:
        # Previously this surfaced as a bare ZeroDivisionError.
        raise ValueError("evaluate() received an empty data_loader; cannot average the metric.")

    return accuracy / count


In [16]:
# Scaler statistics are fixed once the datasets have been normalised, so the
# checkpoint payload describing them is built once instead of every epoch.
scaler_objs = {
    'label_scaler': (
        {'means': label_scaler.mean, 'stds': label_scaler.std}
        if label_scaler is not None else None
    ),
    'solute_features_scaler': (
        {'means': solute_features_scaler.mean, 'stds': solute_features_scaler.std}
        if solute_features_scaler is not None else None
    ),
    'solvent_features_scaler': (
        {'means': solvent_features_scaler.mean, 'stds': solvent_features_scaler.std}
        if solvent_features_scaler is not None else None
    ),
}

# NOTE(review): the epoch limit is hard-coded to 100 although the configuration
# cell defines `epochs = 1000` — confirm which limit is intended.
for epoch in range(start_epoch, 100):
    ti = time.time()
    loss, train_acc = train(optimizer, model, feature_names, train_loader, loss_func, metric)

    if np.isnan(loss):
        print("\n\nBad, we get nan for loss. Exiting")
        sys.stdout.flush()
        sys.exit(1)

    # Validation score: lower is better. It drives early stopping, the LR
    # scheduler and the "best" checkpoint flag.
    val_acc = evaluate(model, feature_names, val_loader, metric, label_scaler)
    if stopper.step(val_acc):
        break
    scheduler.step(val_acc)

    is_best = val_acc < best
    if is_best:
        best = val_acc

    # save checkpoint
    misc_objs = {"best": best, "epoch": epoch}
    save_checkpoints(state_dict_objs, misc_objs, scaler_objs, is_best, msg=f"epoch: {epoch}, score {val_acc}", save_dir=save_dir)

    tt = time.time() - ti
    # Print the table header on the first executed epoch, which is
    # `start_epoch` rather than necessarily epoch 0.
    if epoch == start_epoch:
        print("{:5}   {:12}   {:12}   {:12}   {:5}".format("Epoch", "MSE Loss", "Train MAE", "Val MAE", "Time (s)"))
    print("{:5d}   {:12.6e}   {:12.6e}   {:12.6e}   {:.2f}".format(epoch, loss, train_acc, val_acc, tt))
    if epoch % 10 == 0:
        sys.stdout.flush()

# Save results for hyperparam tune. This is written whether training stopped
# early or ran through every epoch; before, it was written only on early stop.
pickle_dump(best, os.path.join(save_dir, output_file))

Epoch   MSE Loss       Train MAE      Val MAE        Time (s)
    0   1.232183e+00   7.850856e-01   2.512466e+00   18.36
    1   9.359962e-01   7.269607e-01   2.439717e+00   15.99
    2   7.265916e-01   6.413350e-01   2.333654e+00   15.67
    3   6.031480e-01   6.020138e-01   2.174176e+00   15.39
    4   4.287728e-01   4.902507e-01   1.615106e+00   13.57
    5   4.718710e-01   4.963723e-01   1.933856e+00   13.30
    6   3.546169e-01   4.421367e-01   1.471781e+00   13.42
    7   3.176426e-01   4.287195e-01   1.318893e+00   13.69
    8   3.033718e-01   4.154650e-01   1.208500e+00   13.63
    9   3.263239e-01   3.721361e-01   1.399044e+00   13.78
   10   2.616111e-01   3.825664e-01   1.112234e+00   13.97
   11   2.553513e-01   3.836095e-01   1.352343e+00   14.60
   12   2.901275e-01   3.904938e-01   1.026652e+00   13.81
   13   2.465743e-01   3.777791e-01   1.115096e+00   13.48
   14   2.085585e-01   3.424566e-01   1.083888e+00   13.86
   15   2.551448e-01   3.576358e-01   1.158770e+00   

In [17]:
# Evaluate the model in its current (end-of-training) state on the test set:
# once for the averaged metric and once to collect per-sample predictions.
test_acc = evaluate(model, feature_names, test_loader, metric, label_scaler)
y_true, y_pred = evaluate(model, feature_names, test_loader, metric, label_scaler, return_preds=True)

# RMSE is the square root of the plain MSE. The `squared=False` keyword of
# mean_squared_error is gone from newer scikit-learn releases, whereas this
# form works on every version.
test_rmse = np.sqrt(mean_squared_error(y_true, y_pred))

print("\n#Test MAE: {:12.6e} \n".format(test_acc))
print("\n#Test RMSE: {:12.6e} \n".format(test_rmse))
print("\nFinish training at:", datetime.now())

# Persist targets and predictions for later analysis.
results_dict = {'y_true': y_true, 'y_pred': y_pred}
pickle_dump(results_dict, os.path.join(save_dir, f'seed_{random_seed}_test_results.pkl'))


#Test MAE: 8.393481e-01 


#Test RMSE: 1.179982e+00 


Finish training at: 2025-03-10 01:32:57.467913


Test the best model saved during training and save the test results

In [None]:
# Load "best_checkpoint.pkl" from save_dir through load_checkpoints.
# NOTE(review): presumably this restores the model/optimizer/scheduler held in
# state_dict_objs in place, so `model` below is the best checkpoint — confirm.
checkpoint = load_checkpoints(state_dict_objs, save_dir, filename="best_checkpoint.pkl")

In [31]:
# Evaluate on the test set after the best checkpoint has been loaded: once for
# the averaged metric and once to collect per-sample predictions.
test_acc = evaluate(model, feature_names, test_loader, metric, label_scaler)
y_true, y_pred = evaluate(model, feature_names, test_loader, metric, label_scaler, return_preds=True)

# RMSE is the square root of the plain MSE. The `squared=False` keyword of
# mean_squared_error is gone from newer scikit-learn releases, whereas this
# form works on every version.
test_rmse = np.sqrt(mean_squared_error(y_true, y_pred))

print("\n#Test MAE: {:12.6e} \n".format(test_acc))
print("\n#Test RMSE: {:12.6e} \n".format(test_rmse))
print("\nFinish training at:", datetime.now())

# Persist targets and predictions. This writes to the same file as the
# end-of-training evaluation cell and therefore replaces its contents.
results_dict = {'y_true': y_true, 'y_pred': y_pred}
pickle_dump(results_dict, os.path.join(save_dir, f'seed_{random_seed}_test_results.pkl'))


#Test MAE: 6.119657e-01 


#Test RMSE: 8.546843e-01 


Finish training at: 2025-03-10 01:37:52.723112


In [32]:
# Path of the per-seed test-results pickle written by the evaluation cells.
file_path = os.path.join(save_dir, f'seed_{random_seed}_test_results.pkl')

In [26]:
import os
import pickle

# Read back the stored test results. Only unpickle files produced by this
# notebook: pickle can execute arbitrary code from an untrusted file.
with open(file_path, 'rb') as file:
    results_dict = pickle.load(file)

# Pull out the stored targets and predictions.
y_true, y_pred = results_dict['y_true'], results_dict['y_pred']

# Show both lists.
print("y_true:", y_true)
print("y_pred:", y_pred)

y_true: [-2.3399999141693115, -4.849999904632568, -1.4500000476837158, -2.440000057220459, 1.5800000429153442, -2.4000000953674316, -4.289999961853027, -8.149999618530273, -5.820000171661377, -3.680000066757202, -3.0399999618530273, -5.659999847412109, -0.23000000417232513, -10.210000038146973, -4.150000095367432, 1.6799999475479126, -4.119999885559082, -9.649999618530273, -3.809999942779541, 0.009999999776482582, -9.729999542236328, -3.130000114440918, -0.5, -5.03000020980835, -2.009999990463257, -2.809999942779541, -2.549999952316284, -16.43000030517578, 0.4000000059604645, -2.9200000762939453, -4.059999942779541, -2.2100000381469727, -2.259999990463257, -3.450000047683716, -4.090000152587891, -9.520000457763672, 1.2300000190734863, -8.260000228881836, -5.739999771118164, -5.900000095367432, -2.869999885559082, 0.07999999821186066, -0.8199999928474426, -4.429999828338623, -1.9500000476837158, 1.309999942779541, -1.590000033378601, -4.690000057220459, -0.9900000095367432, -0.449999988