In [None]:
import torch
import matplotlib.pyplot as plt

import time

import os.path as osp
import numpy as np

from dataset import make_dataset
from train import make_data_loader, train_step, test_evaluations, save_model_GCN
from utils import get_device, plot_training_progress
from model import GCNNetwork, CEALNetwork

from args import *
from utils import *

Prepare the dataset

In [None]:
# Build the train / validation / test splits, then wrap each split in a loader.
train_dataset, validation_dataset, test_dataset = make_dataset()

# Disabled step, kept for reference: attach an `edge_attr` computed from each
# graph's `edge_weight` in the three splits.
# for i, j, k in zip(train_dataset, validation_dataset, test_dataset):
#     i.edge_attr = edge_weight_to_edge_attr(i.edge_weight)
#     j.edge_attr = edge_weight_to_edge_attr(j.edge_weight)
#     k.edge_attr = edge_weight_to_edge_attr(k.edge_weight)

train_loader, val_loader, test_loader = make_data_loader(
    train_dataset,
    validation_dataset,
    test_dataset,
)

Model training parameters

In [None]:
# Name of the model to train; it is also the key used to look up this model's
# settings in `args`.
model_name = "CEAL"
# Derived class name, i.e. "<model_name>Network".
model_network = f"{model_name}Network"
model_args = args[model_name]

In [None]:
# Report the number of samples in each split.
print(len(train_dataset), len(validation_dataset), len(test_dataset))

# The run is pinned to the CPU; get_device() is the commented-out alternative.
# device = get_device()
device = torch.device("cpu")

# Size of the last axis of the first training sample's `x` tensor.
in_dim = train_dataset[0].x.shape[-1]
# NOTE(review): presumably a node-degree statistic over the training set that
# the network needs at construction time; confirm against generate_deg.
deg = generate_deg(train_dataset)

# The network class is instantiated directly; the name-based lookup through
# globals() is kept below, disabled.
# model = globals()[model_network](deg, in_dim)
model = CEALNetwork(deg, in_dim)
model.to(device)

# Adam optimizer plus a plateau-based learning-rate scheduler, both configured
# from the per-model settings in `model_args`.
optimizer = torch.optim.Adam(model.parameters(), lr=model_args["learning_rate"])
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode=model_args["sche_mode"],
    factor=model_args["sche_factor"],
    patience=model_args["sche_patience"],
    min_lr=model_args["sche_min_lr"],
)

In [None]:
# Output location for this run, under ./results/<model_name>.
# NOTE(review): createResultFolder presumably creates the folder and returns
# its path; confirm in utils.
result_path = createResultFolder(osp.join("./results", model_name))

# Nothing has been trained yet: no best loss recorded and no epoch completed.
# The training cell checks `epoch is None` to decide where to start counting.
test_best_loss = epoch = None

# Per-epoch loss histories, appended to by the training loop.
train_losses, test_losses, val_losses = [], [], []

# Prepare the loss-curve figure.
plt.figure(figsize=(8, 6))
plt.title("Loss vs. Epoch during Training")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.grid(True)

Model training

In [None]:
# Continue after the last epoch reached when this cell is re-run; a fresh run
# (epoch is None) starts counting from 0.
epoch_start = epoch if epoch is not None else 0
# Keeps `epoch` an integer even when the loop below has nothing left to run.
epoch = epoch_start
epochs = model_args["epochs"]

for epoch in range(epoch_start + 1, epochs + 1):

    # One training pass, then evaluation on the validation and test splits.
    model, train_loss = train_step(model, train_loader, train_dataset, optimizer, device)
    val_loss, _, _ = test_evaluations(model, val_loader, validation_dataset, device, ret_data=False)
    test_loss, _, _ = test_evaluations(model, test_loader, test_dataset, device, ret_data=False)

    # The learning-rate scheduler is stepped on the validation loss.
    scheduler.step(val_loss)
    current_lr = optimizer.param_groups[0]["lr"]

    # Record this epoch's losses and redraw the progress plot.
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    test_losses.append(test_loss)

    plot_training_progress(len(train_losses), train_losses, val_losses, test_losses, split=15)

    # Save a checkpoint whenever the test loss improves on the best seen so far.
    # NOTE(review): the checkpoint is selected on the test loss, so the reported
    # best test loss is optimistically biased; consider selecting on val_loss.
    if test_best_loss is None or test_loss < test_best_loss:
        test_best_loss = test_loss
        save_model_GCN(epoch, model, optimizer, scheduler, result_path)

    # One-line summary for this epoch.
    progress_msg = ", ".join(
        [
            "Epoch " + str(epoch),
            "train loss(MAE)=" + str(round(train_loss, 4)),
            "valid loss(MAE)=" + str(round(val_loss, 4)),
            "test loss(MAE)=" + str(round(test_loss, 4)),
            "lr=" + str(round(current_lr, 8)),
            "best_test=" + str(round(test_best_loss, 4)),
        ]
    )

    print(progress_msg)

Load the saved checkpoint

In [None]:
# Reload the checkpoint file from this run's result folder.
# It is mapped onto `device`, the same device the model was trained on and the
# evaluation below is given, instead of onto get_device(), which may pick a
# different device and leave the restored model and the evaluation on
# different devices.
# NOTE(review): checkpoint["model"] is used as a model object further down; on
# PyTorch releases where torch.load defaults to weights_only=True this call may
# need weights_only=False. Confirm against the installed version.
checkpoint = torch.load(osp.join(result_path, "checkpoint.pt"),
                        map_location=device)

Show results

In [None]:
# The model to evaluate is read from the checkpoint's "model" entry.
best_model = checkpoint["model"]

# Persist the run configuration and the recorded loss histories.
save_hyper_parameter(args, result_path)
save_train_progress(epoch - 1, train_losses, val_losses, test_losses, result_path)

# Evaluate on the test split, this time also returning predictions and targets.
test_loss, test_out, test_y = test_evaluations(best_model, test_loader, test_dataset, device, ret_data=True)

# Reverse normalization of test_out and y.
# The scale bounds get their own names so the builtins min() and max() are not
# rebound in the notebook's global namespace.
scale_min, scale_max = get_data_scale(args)
test_y = reverse_min_max_scalar_1d(test_y, scale_min, scale_max)
test_out = reverse_min_max_scalar_1d(test_out, scale_min, scale_max)

# Mean absolute error on the de-normalized values.
loss = (test_out.squeeze() - test_y).abs().mean()
print("MAE loss of formation energy is: ", loss.item())

# save results
plot_training_progress(len(train_losses), train_losses, val_losses, test_losses, res_path=result_path)
save_regression_result(test_out, test_y, result_path)
# NOTE(review): the first argument is hard-coded to "GCN" although model_name
# is "CEAL"; confirm whether it should follow model_name.
plot_regression_result("GCN", result_path, plotfilename="regression_figure.jpeg")