In [None]:
import os.path as osp
import numpy as np
import torch
from torch_geometric.loader import DataLoader

from XASNet.data import QM9_XAS
from XASNet.data import save_split

from XASNet.models import XASNet_GNN, XASNet_GAT, XASNet_GraphNet

from XASNet.trainer import GNNTrainer

In [None]:
# --- Run configuration -------------------------------------------------------
# Checkpoint file name, looked up under ./best_model by the model cells (placeholder).
model_name = 'model_name.pt'
# number of epochs in training
num_epochs = 100
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# learning rate
lr = 1e-3
# Epochs at which MultiStepLR decays the learning rate. Derived from
# `num_epochs` so the schedule follows the run length
# (for num_epochs=100 this is [10, 20, ..., 90], exactly as before).
milestone_step = 10
milestones = np.arange(milestone_step, num_epochs, milestone_step).tolist()

# Seed the RNGs so shuffling is reproducible across runs: torch drives the
# DataLoader shuffle; NOTE(review): save_split's shuffle presumably uses
# NumPy's global RNG — confirm in XASNet.data.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Loading the QM9-XAS data

In [None]:
# Build the QM9-XAS dataset; `root` is a placeholder for the real dataset path.
root = 'path-to-qm9xas-dataset'
qm9_spec = QM9_XAS(
    root=root,
    raw_dir='./raw/',
    spectra=[],
)

In [None]:
# Number of samples in the loaded QM9-XAS dataset (shown as the cell output).
len(qm9_spec)

In [None]:
# Create the train/val/test index split (test is left empty).
# With save_split=True the split is written to `path` — presumably it is also
# reused if it already exists; NOTE(review): confirm in XASNet.data.save_split.
n_train, n_val, n_test = 40000, 10000, 0
# Fail early with a clear message if the dataset cannot supply the requested split.
assert n_train + n_val + n_test <= len(qm9_spec), (
    f"requested {n_train + n_val + n_test} samples, "
    f"but the dataset only has {len(qm9_spec)}"
)
idxs = save_split(
    path='path-to-split-file',
    ndata=len(qm9_spec),
    ntrain=n_train,
    nval=n_val,
    ntest=n_test,
    save_split=True,
    shuffle=True,
    print_nsample=True,
)

In [None]:
# Materialise the train / validation subsets from the split indices.
# (No test subset: the split is created with ntest=0.)
train_qm9, val_qm9 = (
    [qm9_spec[idx] for idx in idxs[part]] for part in ('train', 'val')
)

In [None]:
# data loaders
batch_size = 100
# Shuffle only the training data; reshuffling every epoch helps optimisation.
train_loader = DataLoader(train_qm9, batch_size=batch_size, shuffle=True)
# Shuffling validation data brings no benefit; a fixed order keeps
# evaluation reproducible and comparable across epochs.
val_loader = DataLoader(val_qm9, batch_size=batch_size, shuffle=False)
# No test loader: the split is created with ntest=0.

In [None]:
# Sanity check: the train/val subsets must fit inside the dataset.
assert len(train_qm9) + len(val_qm9) <= len(qm9_spec)
# Display (dataset size, train size, validation size).
len(qm9_spec), len(train_qm9), len(val_qm9)

# Additional evaluation metrics

In [None]:
def RSE_loss(prediction, target, e_min=270.0, e_max=300.0, n_points=100):
    """Relative spectral error (RSE) between predicted and reference spectra.

    Computes

        RSE = sqrt(sum(dE * (target - prediction)**2)) / sum(dE * target)

    with a constant energy step ``dE = (e_max - e_min) / n_points``. The sums
    run over *all* elements of the inputs, so a whole batch yields one scalar.

    Args:
        prediction: predicted spectra tensor (same shape as ``target``, or
            broadcastable to it).
        target: reference spectra tensor.
        e_min, e_max: bounds of the energy window. The defaults (270, 300)
            reproduce the original hard-coded values (presumably eV).
        n_points: number of steps the window is divided into. The default
            (100) reproduces the original value; presumably the number of
            spectral targets.

    Returns:
        0-dim tensor. It is inf/nan when ``sum(target)`` is 0.

    Raises:
        ValueError: if ``n_points`` is not positive or ``e_max <= e_min``.
    """
    if n_points <= 0:
        raise ValueError(f"n_points must be positive, got {n_points}")
    if e_max <= e_min:
        raise ValueError(f"e_max ({e_max}) must be greater than e_min ({e_min})")
    d_e = (e_max - e_min) / n_points
    nom = torch.sum(d_e * torch.pow(target - prediction, 2))
    denom = torch.sum(d_e * target)
    return torch.sqrt(nom) / denom

In [None]:
def RMSE(prediction, target):
    """Root-mean-square error taken over every element of the two tensors.

    Returns a 0-dim tensor equal to sqrt(mean((target - prediction) ** 2)).
    """
    residual = target - prediction
    mean_square = residual.pow(2).mean()
    return mean_square.sqrt()

# Loading the models

## XASNet_GNN

In [None]:
xasnet_gnn = XASNet_GNN(
    gnn_name='gatv2',
    in_channels=[11, 128, 256, 512],
    out_channels=[128, 256, 512, 600],
    num_targets=100,
    num_layers=4,
    heads=3
).to(device)

# loading the saved model
# NOTE(review): every model cell reads the same `model_name` checkpoint, which
# can match at most one architecture. A mismatch makes the strict
# `load_state_dict` raise RuntimeError; it is reported here instead of
# aborting "Run All", consistent with the missing-file fallback below.
path_to_model = osp.join('./best_model', model_name)

if osp.exists(path_to_model):
    try:
        # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
        # torch.load unpickles the file: only load trusted checkpoints.
        state_dict = torch.load(path_to_model, map_location=device)
        xasnet_gnn.load_state_dict(state_dict)
    except RuntimeError as err:
        print(f'model is not loaded: {err}')
else:
    print('model is not loaded')

In [None]:
# Display the model; for a torch nn.Module the repr lists its submodules.
xasnet_gnn

## XASNet_GAT

In [None]:
xasnet_gat = XASNet_GAT(
    node_features_dim=11,
    in_channels=[128, 128, 128, 128],
    out_channels=[128, 128, 128, 400],
    targets=100,
    n_layers=4,
    n_heads=3,
    gat_type='gatv2_custom',
    use_residuals=True,
    use_jk=True
).to(device)

# loading the saved model
# NOTE(review): every model cell reads the same `model_name` checkpoint, which
# can match at most one architecture. A mismatch makes the strict
# `load_state_dict` raise RuntimeError; it is reported here instead of
# aborting "Run All", consistent with the missing-file fallback below.
path_to_model = osp.join('./best_model', model_name)

if osp.exists(path_to_model):
    try:
        # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
        # torch.load unpickles the file: only load trusted checkpoints.
        state_dict = torch.load(path_to_model, map_location=device)
        xasnet_gat.load_state_dict(state_dict)
    except RuntimeError as err:
        print(f'model is not loaded: {err}')
else:
    print('model is not loaded')

In [None]:
# Display the model; for a torch nn.Module the repr lists its submodules.
xasnet_gat

## XASNet_GraphNet

In [None]:
xasnet_graphnet = XASNet_GraphNet(
    node_dim=14,
    edge_dim=5,
    hidden_channels=512,
    out_channels=200,
    gat_hidd=512,
    gat_out=100,
    n_layers=3,
    n_targets=100
).to(device)

# loading the saved model
# NOTE(review): every model cell reads the same `model_name` checkpoint, which
# can match at most one architecture. A mismatch makes the strict
# `load_state_dict` raise RuntimeError; it is reported here instead of
# aborting "Run All", consistent with the missing-file fallback below.
path_to_model = osp.join('./best_model', model_name)

if osp.exists(path_to_model):
    try:
        # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
        # torch.load unpickles the file: only load trusted checkpoints.
        state_dict = torch.load(path_to_model, map_location=device)
        xasnet_graphnet.load_state_dict(state_dict)
    except RuntimeError as err:
        print(f'model is not loaded: {err}')
else:
    print('model is not loaded')

In [None]:
# Display the model; for a torch nn.Module the repr lists its submodules.
xasnet_graphnet

# Training with the trainer class

In [None]:
# AdamW on the GraphNet parameters, with the learning rate multiplied by 0.8
# at every epoch listed in `milestones`.
optimizer = torch.optim.AdamW(params=xasnet_graphnet.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=milestones,
    gamma=0.8,
)

# Candidate criteria. NOTE(review): neither is passed to `train_val` in the
# cells shown (RMSE is used there instead); remove them if they stay unused.
loss_fn = torch.nn.L1Loss()
loss_fn2 = torch.nn.MSELoss()

In [None]:
# Wrap the GraphNet model in the project's trainer. Metrics presumably go
# under ./metrics (per `metric_path`); confirm in XASNet.trainer.
trainer = GNNTrainer(
    model=xasnet_graphnet,
    model_name="spectragraphnet_test",
    device=device,
    metric_path="./metrics",
)

In [None]:
# Run the training/validation loop. The leading arguments are positional,
# so their order must match GNNTrainer.train_val's signature.
trainer.train_val(
    train_loader,
    val_loader,
    optimizer,
    RMSE,  # presumably the loss the trainer optimises; confirm in XASNet.trainer
    scheduler,
    num_epochs,
    write_every=1,
    train_graphnet=True,
)