In [21]:
#import: standard libraries, plus the classes
from model_utils import SingleTaskModelTrainer, MultiTaskModelTrainer
from ray import tune
import torch
import pandas as pd
import sys
sys.path.append('modules')
from utils_data import get_graphs #will be changed later by Riccardo
from sklearn.model_selection import train_test_split

In [2]:
#just an example, we can think about different params/ranges
hp_search_config = {
        "lr": tune.loguniform(1e-5, 1e-2),
        "batch_size": tune.choice([8, 16]),
        "hidden_channels": tune.choice([32, 64, 128]),
        "num_layers": tune.choice([2, 3, 4]),
        "num_timesteps": tune.choice([1, 2, 3]),
        "gamma": tune.loguniform(0.9, 0.99),
        "Scheduler": tune.choice(["ReduceLROnPlateau", "ExponentialLR"]),
    }

In [3]:
#start with getting the data
train = pd.read_csv("train.csv")
train_graphs_DASH_charge_scaled = get_graphs(train,dash_charges=True,scaled =True,save_graphs = True)

Loading previously created graphs


In [None]:
#we need to split double: we use the first validation set to tune our hyperparameters, and then a second one to be used for early stopping of the final model. We could have smaller sets I think
train_data, val_data = train_test_split(train_graphs_DASH_charge_scaled, test_size=0.2, random_state=2000)
val1_data, val2_data = train_test_split(val_data, test_size=0.5, random_state=2000)

In [None]:
#because for some reason I run into memory issues, did not have this before, to fix. now it will give crap results for the full thing because much less data
from random import sample 
train_data_hp_opt = sample(train_data, 1000)
val_data_hp_opt = sample(val1_data, 100)

In [6]:
example_mtl_model = MultiTaskModelTrainer(sandbox=True,verbose=True,name='example_MTL',seed = 18012000,train_data = train_data_hp_opt,val_data=val_data_hp_opt)

In [7]:
example_mtl_model.tune_hyperparameters(config=hp_search_config,num_samples=1,max_num_epochs=1,gpus_per_trial=1,cpus_per_trial=16)

2024-03-21 14:40:53,810	INFO worker.py:1715 -- Started a local Ray instance. View the dashboard at [1m[32mhttp://127.0.0.1:8266 [39m[22m
2024-03-21 14:40:56,056	INFO tune.py:583 -- [output] This uses the legacy output and progress reporter, as Jupyter notebooks are not supported by the new engine, yet. For more information, please see https://github.com/ray-project/ray/issues/36949


0,1
Current time:,2024-03-21 14:41:03
Running for:,00:00:07.63
Memory:,44.0/62.5 GiB

Trial name,status,loc,Scheduler,batch_size,gamma,hidden_channels,lr,num_layers,num_timesteps,iter,total time (s),kendall_tau
train_model_with_ray_a4c0a_00000,TERMINATED,129.132.218.153:2021480,ExponentialLR,8,0.958413,32,0.000606811,2,3,1,4.52228,0.315335


Trial name,kendall_tau,should_checkpoint
train_model_with_ray_a4c0a_00000,0.315335,True


2024-03-21 14:41:03,717	INFO tune.py:1042 -- Total run time: 7.66 seconds (7.63 seconds for the tuning loop).


Best trial config: {'lr': 0.000606810841366676, 'batch_size': 8, 'hidden_channels': 32, 'num_layers': 2, 'num_timesteps': 3, 'gamma': 0.9584126887372598, 'Scheduler': 'ExponentialLR'}
Best trial final kendall_tau: 0.315335176041088


In [8]:
#change the train and val data back first
example_mtl_model.train_data = train_data
example_mtl_model.val_data = val1_data

example_mtl_model.train_and_validate(num_epochs=50, save_models=True, es_patience=10, save_losses=True)

[36m(func pid=2021480)[0m Checkpoint successfully created at: Checkpoint(filesystem=local, path=/home/cschiebroek/ray_results/train_model_with_ray_2024-03-21_14-40-56/train_model_with_ray_a4c0a_00000_0_Scheduler=ExponentialLR,batch_size=8,gamma=0.9584,hidden_channels=32,lr=0.0006,num_layers=2,num_2024-03-21_14-40-56/checkpoint_000000)


Epoch 1: Train Loss: 0.4068, Val Loss: 0.3968
Epoch 2: Train Loss: 0.3910, Val Loss: 0.3897


KeyboardInterrupt: 

In [27]:
#example: how to load and predict
import torch
from torch_geometric.nn.models import AttentiveFP
import pickle
from torch_geometric.loader import DataLoader
import pandas as pd
model = AttentiveFP(in_channels=24, hidden_channels=200, out_channels=14,
                    edge_dim=11, num_layers=2, num_timesteps=2,
                    dropout=0.0)
model.load_state_dict(torch.load("/localhome/cschiebroek/other/Digital_Chemistry/sandbox/models/example_MTL_seed_18012000.pt"))
model.eval()
train = pd.read_csv("train.csv")
train_graphs_DASH_charge_scaled = get_graphs(train,dash_charges=True,scaled =True,save_graphs = True)
train_loader = DataLoader(train_graphs_DASH_charge_scaled, batch_size=32, shuffle=False)
for data in train_loader:
    out = model(data.x, data.edge_index, data.edge_attr,data.batch)
    print(out)
    break


Loading previously created graphs
tensor([[-0.5214, -0.7871, -0.9921, -0.9109, -0.9067, -0.1491, -0.9188, -0.9088,
         -1.0701, -0.9206, -0.9765, -1.0306, -0.0363, -0.8770],
        [-0.4625, -0.4841, -0.8560, -0.9216, -1.0713, -0.1384, -0.9661, -0.9778,
         -0.9552, -0.9406, -0.9719, -1.0371, -0.6547, -0.2935],
        [-0.6098, -0.2845, -0.8240, -0.9208, -1.0263, -0.2171, -0.9865, -0.9963,
         -0.8447, -0.9473, -0.9377, -0.9593, -0.0413, -0.2785],
        [-0.9883, -0.9365, -1.0168, -1.1672, -1.0389, -1.1151, -1.1015, -0.6023,
         -1.0092, -1.0607, -1.0325, -1.1102, -0.0531,  0.5089],
        [-0.1845, -0.6231, -0.5642, -1.0305, -0.9269,  0.4017, -1.0493, -1.0526,
         -0.8965, -1.0288, -0.9095, -1.0694, -0.4504, -0.4449],
        [-0.3976, -0.5278, -0.6276, -1.0880, -1.0510, -0.0174, -1.0960, -1.1113,
         -0.7685, -0.9971, -0.9589, -1.0912, -0.3168, -0.3657],
        [-0.6970, -0.0869, -0.9033, -0.8946, -1.0130, -0.1053, -0.9602, -0.9756,
         -0.913

In [28]:
#save model
torch.save(model,'model_full_test.pt')
#load model
model = torch.load('model_full_test.pt')
train_loader = DataLoader(train_graphs_DASH_charge_scaled, batch_size=32, shuffle=False)
for data in train_loader:
    out = model(data.x, data.edge_index, data.edge_attr,data.batch)
    print(out)
    break


tensor([[-0.5214, -0.7871, -0.9921, -0.9109, -0.9067, -0.1491, -0.9188, -0.9088,
         -1.0701, -0.9206, -0.9765, -1.0306, -0.0363, -0.8770],
        [-0.4625, -0.4841, -0.8560, -0.9216, -1.0713, -0.1384, -0.9661, -0.9778,
         -0.9552, -0.9406, -0.9719, -1.0371, -0.6547, -0.2935],
        [-0.6098, -0.2845, -0.8240, -0.9208, -1.0263, -0.2171, -0.9865, -0.9963,
         -0.8447, -0.9473, -0.9377, -0.9593, -0.0413, -0.2785],
        [-0.9883, -0.9365, -1.0168, -1.1672, -1.0389, -1.1151, -1.1015, -0.6023,
         -1.0092, -1.0607, -1.0325, -1.1102, -0.0531,  0.5089],
        [-0.1845, -0.6231, -0.5642, -1.0305, -0.9269,  0.4017, -1.0493, -1.0526,
         -0.8965, -1.0288, -0.9095, -1.0694, -0.4504, -0.4449],
        [-0.3976, -0.5278, -0.6276, -1.0880, -1.0510, -0.0174, -1.0960, -1.1113,
         -0.7685, -0.9971, -0.9589, -1.0912, -0.3168, -0.3657],
        [-0.6970, -0.0869, -0.9033, -0.8946, -1.0130, -0.1053, -0.9602, -0.9756,
         -0.9136, -0.9233, -0.9744, -0.9624, -0.0

torch_geometric.nn.models.attentive_fp.AttentiveFP