In [None]:
import random
import time
from copy import copy, deepcopy
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau

from functions import *
from graph_tools import *
from processed_datasets import *
from post_training import *
from nets import SantyxNet

# All HetGraphDataset groups pooled into a single nested cross-validation run.
gnn_dataset = (group2_dataset, group2b_dataset,
               aromatics_dataset, aromatics2_dataset,
               amides_dataset, amidines_dataset,
               oximes_dataset, carbamate_esters_dataset,
               group3S_dataset, group3N_dataset,
               group4_dataset, gas_amides_dataset,
               gas_amidines_dataset, gas_aromatics_dataset,
               gas_aromatics2_dataset, gas_carbamate_esters_dataset,
               gas_group2_dataset, gas_group2b_dataset,
               gas_group3N_dataset, gas_group3S_dataset,
               gas_group4_dataset, gas_oximes_dataset)

# Seed Python, NumPy and PyTorch RNGs so that fold shuffling, DataLoader
# shuffling and weight initialisation are reproducible across kernel restarts.
# NOTE(review): some CUDA kernels are non-deterministic, so GPU runs may still
# differ slightly between executions.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and every CUDA device

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Current device: {}".format(device))
if device == "cuda":
    print("Device name: {}".format(torch.cuda.get_device_name(0)))
    print("CUDA Version: {}".format(torch.version.cuda))
    print("CuDNN Version: {}".format(torch.backends.cudnn.version()))

# Hyperparameters
DIM = 128            # passed to SantyxNet as `dim` (presumably the hidden width — confirm in nets.py)
EPOCHS = 300         # training epochs for each (test, validation) fold pair
LOSS_FN = F.l1_loss  # training loss handed to train_loop (L1 / mean absolute error)
BATCH_SIZE = 32      # batch size of every dataloader
SPLITS = 5           # number of chunks; nested CV trains SPLITS * (SPLITS - 1) models
LR = 0.001           # initial Adam learning rate
PATIENCE = 5         # ReduceLROnPlateau: epochs without validation improvement before decaying LR
FACTOR = 0.7         # ReduceLROnPlateau: multiplicative learning-rate decay factor
MIN_LR = 1e-7        # ReduceLROnPlateau: lower bound on the learning rate

In [None]:

def split_list(a: list, n: int):
    """
    Lazily cut a sliceable sequence into ``n`` contiguous, near-equal pieces.

    The first ``len(a) % n`` pieces hold one extra element. When ``n`` exceeds
    ``len(a)``, the trailing pieces are empty.

    Args:
        a(list): any sliceable sequence (list, dataset, ...).
        n(int): number of pieces; must be non-zero.
    Returns:
        (generator): the ``n`` slices of ``a``, in order.
    """
    base, extra = divmod(len(a), n)

    def _start(idx):
        # Offset of piece `idx`: `idx` full-size pieces plus one extra element
        # for each of the earlier pieces that received a leftover.
        return idx * base + min(idx, extra)

    return (a[_start(idx):_start(idx + 1)] for idx in range(n))

def create_loaders_nested(datasets, split=5, batch_size=32):
    """
    Generate train/validation/test dataloaders for nested cross-validation.

    Each dataset is shuffled and cut into ``split`` near-equal pieces. The pieces
    are distributed over ``split`` chunks, so every chunk holds a share of every
    dataset. Each chunk in turn serves as the test set (outer loop). Within that,
    each remaining chunk in turn serves as the validation set (inner loop), and
    the rest are flattened into the training set.

    Args:
        datasets(tuple): tuple containing the HetGraphDataset objects.
        split(int): number of chunks (folds) used to build the train/val/test sets.
        batch_size(int): batch size of every dataloader.
    Yields:
        (tuple): ``(train_loader, val_loader, test_loader)``; ``split * (split - 1)``
        tuples in total, ordered by test chunk first, then by validation chunk.
    """
    chunks = [[] for _ in range(split)]
    for dataset in datasets:
        # PyG's Dataset.shuffle() returns a shuffled copy instead of shuffling in
        # place, so its result must be used. Fall back to the original object in
        # case a custom dataset shuffles in place and returns None.
        shuffled = dataset.shuffle()
        if shuffled is None:
            shuffled = dataset
        for index, piece in enumerate(split_list(shuffled, split)):
            chunks[index] += piece
        # split_list hands out its larger pieces first. Re-sorting ascending by size
        # gives the next dataset's larger pieces to the currently smallest chunks,
        # which keeps the chunk sizes balanced.
        chunks = sorted(chunks, key=len)

    for test_index, test_chunk in enumerate(chunks):
        remaining = chunks[:test_index] + chunks[test_index + 1:]
        test_loader = DataLoader(test_chunk, batch_size=batch_size, shuffle=False)
        for val_index, val_chunk in enumerate(remaining):
            val_loader = DataLoader(val_chunk, batch_size=batch_size, shuffle=False)
            train_data = [graph
                          for chunk_index, chunk in enumerate(remaining)
                          if chunk_index != val_index
                          for graph in chunk]
            train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
            # All folds share the same graph objects, so hand out an independent
            # deep copy. Otherwise in-place changes made by the caller (presumably
            # target scaling in scale_target — verify) would leak into later folds.
            yield deepcopy((train_loader, val_loader, test_loader))

In [None]:

# Nested cross-validation: the outer loop fixes the test chunk, the inner loop the
# validation chunk, and every (test, validation) pair trains a freshly initialised model.
iterator = create_loaders_nested(gnn_dataset, split=SPLITS, batch_size=BATCH_SIZE)
OUTPUT_CSV = Path("./NestedCrossValidation/Test.csv")
# open(..., mode="a") raises FileNotFoundError when the folder is missing.
# NOTE(review): append mode means re-running the notebook adds duplicate rows.
OUTPUT_CSV.parent.mkdir(parents=True, exist_ok=True)
MAE_outer = []
counter = 0
tot_runs = SPLITS * (SPLITS - 1)
for outer in range(SPLITS):
    MAE_inner = []
    for inner in range(SPLITS - 1):
        counter += 1
        train_loader, val_loader, test_loader = next(iterator)
        train_loader, val_loader, test_loader, mean_tv, std_tv = scale_target(train_loader, val_loader, test_loader,
                                                                              mode="std", verbose=False)
        model = SantyxNet(dim=DIM, node_features=node_features).to(device)
        optimizer = Adam(model.parameters(), lr=LR)
        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=FACTOR, patience=PATIENCE, min_lr=MIN_LR)
        # Per-epoch learning curves, passed to test_performance after training.
        loss_list = []
        train_list = []
        val_list = []
        test_list = []
        t0 = time.time()
        for epoch in range(1, EPOCHS+1):
            lr = scheduler.optimizer.param_groups[0]['lr']
            loss, train_MAE = train_loop(model, device, train_loader, optimizer, LOSS_FN)
            val_MAE = test_loop(model, val_loader, device, std_tv)
            scheduler.step(val_MAE)
            test_MAE = test_loop(model, test_loader, device, std_tv)
            print('{}/{}-Epoch {:03d}: LR={:.7f}  Loss={:.6f}  Validation MAE: {:.6f} eV, '
                'Test MAE: {:.6f} eV'.format(counter, tot_runs, epoch, lr, loss, val_MAE, test_MAE))
            # Record every epoch. The appends previously sat after this loop, so
            # each "curve" held only the final epoch.
            loss_list.append(loss)
            train_list.append(train_MAE * std_tv)
            val_list.append(val_MAE)
            test_list.append(test_MAE)
        print("Training time: {:.2f} s".format(time.time() - t0))
        # The run is scored by the test MAE of the fully trained model (last epoch).
        MAE_inner.append(test_MAE)
        x, y = test_performance("{}-{}".format(counter, tot_runs), model,
                               train_loader, val_loader, test_loader,
                               mean_tv, std_tv,
                               SPLITS, EPOCHS, BATCH_SIZE,
                               lr, MIN_LR, train_list, val_list, test_list)
        with open(OUTPUT_CSV, mode="a") as outfile:
            for a, b in zip(x, y):
                outfile.write(f"name,{a},ener{counter},{b}\n")
        # The optimizer (parameter refs + Adam state) and the scheduler (which holds
        # the optimizer) keep the model's tensors alive. Release all three before
        # emptying the CUDA cache.
        del model, optimizer, scheduler
        if device == "cuda":
            torch.cuda.empty_cache()
    MAE_outer.append(np.mean(MAE_inner))
MAE = np.mean(MAE_outer)
print("NESTED CROSS VALIDATION: MAE = {:.2f}".format(MAE))