In [1]:
import itertools
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from sklearn.model_selection import StratifiedKFold
from torch.utils.data import DataLoader, SubsetRandomSampler

from library.GCN import *

In [2]:
#### Fix seeds
# Seed the global NumPy and PyTorch RNGs so repeated runs are comparable.
np.random.seed(10)
torch.manual_seed(10)
use_GPU = torch.cuda.is_available()

#### Inputs
max_atoms = 30 # fixed value
node_vec_len = 16 # fixed value
batch_size = 256
hidden_nodes = 64
n_conv_layers = 3
n_hidden_layers = 2
learning_rate = 0.005
n_epochs = 30

#### Start by creating dataset
# Paths are resolved relative to the parent of the current working directory,
# so the notebook must be launched from a direct sub-folder of the project root.
main_path = Path.cwd().parent
data_path = main_path / "data" / "RDKit" / "rdkit_only_valid_smiles_qm9.pkl"
dataset = GraphData(
    dataset_path=data_path,
    max_atoms=max_atoms,
    node_vec_len=node_vec_len,
)

# Positional index of every sample; the CV splitter partitions these indices.
dataset_indices = np.arange(len(dataset))

#### Split data into training and test sets

# Element [1] of each dataset item, converted to a plain float.
# Presumably the regression target (it is what the Standardizer is fitted on) -
# TODO confirm against GraphData.__getitem__.
y = np.array([float(dataset[i][1]) for i in range(len(dataset))])

# StratifiedKFold needs discrete labels, so the continuous values are binned
# into quantiles. duplicates="drop" merges bins whose edges coincide instead of
# raising ValueError; it has no effect when all edges are unique.
num_bins = 10
gap_bins_outer = pd.qcut(y, q=num_bins, labels=False, duplicates="drop")

outer_seed = 42
inner_seed = 123  # not used in this cell; presumably for an inner CV loop elsewhere
n_outer_folds = 5

outer_cv = StratifiedKFold(
    n_splits=n_outer_folds, shuffle=True, random_state=outer_seed
)

# Materialise the (train_val_indices, test_indices) pairs once so that every
# later use sees exactly the same folds.
outer_fold_splits = list(outer_cv.split(dataset_indices, gap_bins_outer))

#### Train one model per outer fold and save its weights
# Every fold writes to the same directory, so create it once up front.
save_dir = main_path / 'data' / 'GCN' / 'GCN_outer_fold_weights'
save_dir.mkdir(parents=True, exist_ok=True)   # create directory if missing

# Iterate over the precomputed splits directly so the number of folds is
# defined in exactly one place (the StratifiedKFold configuration).
for outer_fold_idx, (train_val_idx, test_idx) in enumerate(outer_fold_splits):

    # Create dataloaders
    train_loader = DataLoader(dataset, batch_size=batch_size,
                              sampler=SubsetRandomSampler(train_val_idx),
                              collate_fn=collate_graph_dataset)
    test_loader = DataLoader(dataset, batch_size=batch_size,
                             sampler=SubsetRandomSampler(test_idx),
                             collate_fn=collate_graph_dataset)

    #### Initialize model, standardizer, optimizer, and loss function
    # Model (a fresh, untrained network for every fold)
    model = ChemGCN(node_vec_len=node_vec_len, node_fea_len=hidden_nodes,
                    hidden_fea_len=hidden_nodes, n_conv=n_conv_layers,
                    n_hidden=n_hidden_layers, n_outputs=1, p_dropout=0.1)
    # Transfer to GPU if needed
    if use_GPU:
        model.cuda()

    # Standardizer
    # Fit on this fold's training targets only. Fitting on the whole dataset
    # would let statistics of the held-out test fold leak into training, and it
    # needlessly re-indexed every dataset item once per fold. `y` already holds
    # the same values as floats, so it is reused here.
    # NOTE: anything that later evaluates the saved weights must rebuild the
    # standardizer from the same training indices (reproducible via outer_seed).
    train_targets = torch.as_tensor(y[train_val_idx], dtype=torch.float32)
    standardizer = Standardizer(train_targets)

    # Optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Loss function
    loss_fn = torch.nn.MSELoss()

    #### Train the model
    # Per-epoch history; these lists are reset for every fold, so after the
    # loop they hold the history of the last fold only.
    loss = []
    mae = []
    epoch = []
    for i in range(n_epochs):
        epoch_loss, epoch_mae = train_model(
            i,
            model,
            train_loader,
            optimizer,
            loss_fn,
            standardizer,
            use_GPU,
            max_atoms,
            node_vec_len,
        )
        loss.append(epoch_loss)
        mae.append(epoch_mae)
        epoch.append(i)

    #### Save the model weights
    save_path = save_dir / f'GCN_weight_outer_fold_{outer_fold_idx}.pth'

    # Move model to CPU before saving.
    # Assuming ChemGCN is a torch.nn.Module, .to() moves the module in place and
    # returns it, so `model` itself is also on the CPU afterwards.
    model_cpu = model.to('cpu')
    torch.save(model_cpu.state_dict(), save_path)

    print(f"Saved model weights for fold {outer_fold_idx} to {save_path}")
    

Epoch: [0]	Training Loss: [0.48]	Training MAE: [0.52]
Epoch: [1]	Training Loss: [0.20]	Training MAE: [0.33]
Epoch: [2]	Training Loss: [0.17]	Training MAE: [0.31]
Epoch: [3]	Training Loss: [0.16]	Training MAE: [0.29]
Epoch: [4]	Training Loss: [0.15]	Training MAE: [0.29]
Epoch: [5]	Training Loss: [0.14]	Training MAE: [0.28]
Epoch: [6]	Training Loss: [0.14]	Training MAE: [0.27]
Epoch: [7]	Training Loss: [0.13]	Training MAE: [0.27]
Epoch: [8]	Training Loss: [0.13]	Training MAE: [0.26]
Epoch: [9]	Training Loss: [0.13]	Training MAE: [0.26]
Epoch: [10]	Training Loss: [0.12]	Training MAE: [0.26]
Epoch: [11]	Training Loss: [0.12]	Training MAE: [0.26]
Epoch: [12]	Training Loss: [0.12]	Training MAE: [0.25]
Epoch: [13]	Training Loss: [0.12]	Training MAE: [0.25]
Epoch: [14]	Training Loss: [0.11]	Training MAE: [0.25]
Epoch: [15]	Training Loss: [0.11]	Training MAE: [0.25]
Epoch: [16]	Training Loss: [0.11]	Training MAE: [0.24]
Epoch: [17]	Training Loss: [0.11]	Training MAE: [0.24]
Epoch: [18]	Training