In [1]:
import pandas as pd
import numpy as np
import json
from tqdm import tqdm

In [2]:
# Model-ready splits produced upstream: train / validation / test.
gnn_train_data = pd.read_parquet('../data/curated/ML_data/gnn_train_data.parquet')
gnn_val_data = pd.read_parquet('../data/curated/ML_data/gnn_val_data.parquet')
gnn_test_data = pd.read_parquet('../data/curated/ML_data/gnn_test_data.parquet')

# Inference inputs: SA2-level rows plus station rows. The station file stores the
# truncated column name 'Station_Na', so it is first aligned to 'Station Name'
# (the SA2 file's spelling, presumably) and the combined frame is then renamed to
# the 'Station_Name' convention that DataFactory expects.
SA2_gnn_data = pd.read_parquet('../data/curated/ML_data/SA2_gnn_data.parquet')
station_inference_gnn_data = (
    pd.read_parquet('../data/curated/ML_data/station_inference_gnn_data.parquet')
    .rename(columns={'Station_Na': 'Station Name'})
)
inference_data = (
    pd.concat([SA2_gnn_data, station_inference_gnn_data], axis=0)
    .rename(columns={'Station Name': 'Station_Name'})
)

In [3]:
# Graph weight matrices passed to GNN.fit / GNN.eval: one for the station-only
# graph and one for the graph extended with SA2 nodes (presumably adjacency
# weights — confirm against the feature-engineering notebook).
station_weights_matrix = np.load('../data/curated/ML_data/station_weights_matrix.npy')
SA2_weights_matrix = np.load('../data/curated/ML_data/station_weights_withSA2_matrix.npy')

# Station-name lookups, used as the `stations_index` argument of DataFactory
# (it maps each station name to a row of the batch matrices).
with open('../data/curated/ML_features/station_weights_withSA2.json', 'r') as f:
    station_weights_withSA2 = json.load(f)
with open('../data/curated/ML_features/station_weights.json', 'r') as f:
    station_weights = json.load(f)

In [4]:
# Column(s) fed to the graph (geospatial) branch of the model.
geospatial_features = ['log_Total_Demand']

# Per-station covariates fed to the non-geospatial branch.
# The leading spaces in the ' med_*_c2021' names are deliberate: they must match
# the column names in the parquet files exactly.
non_geospatial_features = [
    'Weekday',
    'PublicHoliday',
    'mean_rainfall_value',
    'has_school',
    'has_sport_facility',
    'has_shopping_centre',
    'has_hospital',
    'total_population',
    ' med_rent_weekly_c2021',
    ' med_mortg_rep_mon_c2021',
    ' med_person_inc_we_c2021',
    ' med_famly_inc_we_c2021',
]

# Prediction target.
# NOTE(review): this is the same column as geospatial_features, so each node's
# input contains its own label. Unless the GNN excludes a node's own feature
# (e.g. a zero diagonal in the weight matrix), the near-perfect R2 in the logs
# may be target leakage — confirm.
label_columns = ['log_Total_Demand']

In [5]:
def DataFactory(raw_dataset, geospatial_features, non_geospatial_features, label_columns, stations_index, inference = False):
    """Build per-group dense node-feature batches for the GNN.

    Rows of ``raw_dataset`` are grouped by ``'Business_Date'`` or, when
    ``inference=True``, by ``'Weekday'``. Groups are iterated in sorted key
    order. For every group, one matrix per feature set is built, with one row
    per entry of ``stations_index``. Rows for stations absent from the group
    stay zero and are 0 in the mask.

    Parameters
    ----------
    raw_dataset : pandas.DataFrame
        Long-format data containing ``'Station_Name'``, the grouping column,
        and every column listed in the three feature/label lists.
    geospatial_features, non_geospatial_features, label_columns : list of str
        Columns copied into the geospatial input, non-geospatial input and
        target matrices, respectively.
    stations_index : Mapping[str, int]
        Station name -> row position in the output matrices. Its length fixes
        the number of rows. The values are assumed to be 0..len-1 (TODO confirm
        for the JSON-loaded weights dicts).
    inference : bool, default False
        If True, group by ``'Weekday'`` and leave every ``y`` all-zero; the
        label columns are not read.

    Returns
    -------
    geospatial_x_batches : list of ndarray, each (n_stations, len(geospatial_features))
    non_geospatial_x_batches : list of ndarray, each (n_stations, len(non_geospatial_features))
    y_batches : list of ndarray, each (n_stations, len(label_columns))
    masks : list of ndarray, each (n_stations,), 1.0 where the station has a row

    Raises
    ------
    KeyError
        If a station in ``raw_dataset`` is missing from ``stations_index``.
    ValueError
        If a station appears more than once within the same group.
    """
    n_stations = len(stations_index)
    groupby_column = 'Weekday' if inference else 'Business_Date'

    geospatial_x_batches = []
    non_geospatial_x_batches = []
    y_batches = []
    masks = []

    # A scalar key (not a one-element list) avoids pandas yielding tuple keys;
    # the groups and their order are the same.
    for _, daily_df in tqdm(raw_dataset.groupby(groupby_column)):
        station_names = daily_df['Station_Name']

        # A repeated station would make a row ambiguous; fail loudly instead of
        # surfacing a numpy broadcast error (or silently keeping one of them).
        duplicated = station_names.duplicated()
        if duplicated.any():
            dupes = sorted(station_names[duplicated].unique())
            raise ValueError(
                f"duplicate stations within one '{groupby_column}' group: {dupes}"
            )

        # Row position of each station (KeyError on names not in the index).
        rows = np.asarray([stations_index[name] for name in station_names])

        geospatial_x = np.zeros([n_stations, len(geospatial_features)])
        non_geospatial_x = np.zeros([n_stations, len(non_geospatial_features)])
        y = np.zeros([n_stations, len(label_columns)])
        mask = np.zeros(n_stations)

        # Vectorised scatter replaces per-station .loc lookups (O(rows) each).
        # TODO(review): in inference mode the geospatial features are still read
        # from raw_dataset — confirm rows without observed demand carry a
        # sensible placeholder value.
        geospatial_x[rows] = daily_df[geospatial_features].to_numpy(dtype=float)
        non_geospatial_x[rows] = daily_df[non_geospatial_features].to_numpy(dtype=float)
        if not inference:
            y[rows] = daily_df[label_columns].to_numpy(dtype=float)
        mask[rows] = 1

        geospatial_x_batches.append(geospatial_x)
        non_geospatial_x_batches.append(non_geospatial_x)
        y_batches.append(y)
        masks.append(mask)

    return geospatial_x_batches, non_geospatial_x_batches, y_batches, masks

In [6]:
# Build per-day batches for each split. All three use the station-only lookup,
# so their rows presumably line up with station_weights_matrix.
(train_geospatial_X_batches, train_non_geospatial_X_batches,
 train_y_batches, train_masks) = DataFactory(
    gnn_train_data,
    geospatial_features=geospatial_features,
    non_geospatial_features=non_geospatial_features,
    label_columns=label_columns,
    stations_index=station_weights,
)
(val_geospatial_X_batches, val_non_geospatial_X_batches,
 val_y_batches, val_masks) = DataFactory(
    gnn_val_data,
    geospatial_features=geospatial_features,
    non_geospatial_features=non_geospatial_features,
    label_columns=label_columns,
    stations_index=station_weights,
)
(test_geospatial_X_batches, test_non_geospatial_X_batches,
 test_y_batches, test_masks) = DataFactory(
    gnn_test_data,
    geospatial_features=geospatial_features,
    non_geospatial_features=non_geospatial_features,
    label_columns=label_columns,
    stations_index=station_weights,
)

# Inference batches (disabled): group inference_data by weekday over the
# station+SA2 lookup.
# inference_geospatial_X_batches, inference_non_geospatial_X_batches, inference_y_batches, inference_masks = DataFactory(inference_data, geospatial_features, non_geospatial_features, label_columns, station_weights_withSA2, inference = True)

100%|██████████| 382/382 [01:18<00:00,  4.86it/s]
100%|██████████| 82/82 [00:16<00:00,  5.04it/s]
100%|██████████| 82/82 [00:16<00:00,  5.07it/s]


In [7]:
import os
import sys

# Both point at the repository root relative to this notebook. The root goes on
# sys.path so the project package `model` is importable; home_directory is
# reused as GNN_config.rootpath.
py_file_location = '../'
home_directory = '../'
sys.path.append(os.path.abspath(py_file_location))

# NOTE(review): star import — presumably the source of `nn` used by the
# GNN_config cells. Replace with explicit imports once the names are confirmed.
from model.model_class.environment import *

In [8]:
from model.model_class import GNN

In [9]:
SEED = 42  # passed to every GNN_config as random_state

In [10]:
class GNN_config:
    """Hyperparameters for the 1-layer GNN experiment, consumed by ``GNN(GNN_config)``."""

    # ---- architecture ----
    d_model = 256
    n_heads = 8
    dropout = 0.1
    n_gnn_layers = 1
    activation = nn.ReLU()
    res_learning = False
    bottleneck = True

    # ---- optimisation ----
    random_state = SEED
    epochs = 32
    lr = 1e-3
    patience = 5
    loss = nn.MSELoss()
    validation_loss = nn.MSELoss()
    alpha = 0.1
    scheduler = True
    grad_clip = False

    # ---- input sizes ----
    spatial_input_dim = 1      # len(geospatial_features)
    nonspatial_input_dim = 12  # len(non_geospatial_features)

    # ---- checkpointing ----
    rootpath = home_directory
    # NOTE(review): every config cell in this notebook uses this same name, so each
    # run presumably overwrites the previous run's checkpoint. Confirm how GNN uses
    # `name` and consider a per-experiment name.
    name = 'GNN'

model = GNN(GNN_config)

# Train. fit() prints per-epoch train/validation metrics and returns the best
# epoch (judging by the logs, the one with the lowest validation loss).
best_epoch = model.fit(
    train_geospatial_X_batches, train_non_geospatial_X_batches, train_y_batches, train_masks,
    val_geospatial_X_batches, val_non_geospatial_X_batches, val_y_batches, val_masks,
    station_weights_matrix,
)
print('\n\n')

# Reload the checkpoint saved during fit(), then report validation and test metrics.
# NOTE(review): the test report is printed with a "Val" label; the label comes
# from GNN.eval, not from this cell.
model.load()
model.eval(
    val_geospatial_X_batches, val_non_geospatial_X_batches, val_y_batches, val_masks,
    station_weights_matrix, best_epoch, evaluation_mode = True,
)
model.eval(
    test_geospatial_X_batches, test_non_geospatial_X_batches, test_y_batches, test_masks,
    station_weights_matrix, best_epoch, evaluation_mode = True,
)

  from .autonotebook import tqdm as notebook_tqdm
100%|██████████| 382/382 [00:17<00:00, 21.50it/s]


 Epoch 1 Train | Loss:  0.0871 | R2:  0.9055| MSE:  0.0867 | RMSE:  0.2944 | MAE:  0.1938 


100%|██████████| 82/82 [00:01<00:00, 47.88it/s]
  val_y_tensor = torch.FloatTensor(val_y).to(self.device)


Epoch 1 Val | Loss:  0.0148 | R2:  0.9852| MSE:  0.0148 | RMSE:  0.1215 | MAE:  0.0969 


100%|██████████| 382/382 [00:15<00:00, 24.82it/s]


 Epoch 2 Train | Loss:  0.0224 | R2:  0.9754| MSE:  0.0224 | RMSE:  0.1497 | MAE:  0.1035 


100%|██████████| 82/82 [00:02<00:00, 34.62it/s]


Epoch 2 Val | Loss:  0.0065 | R2:  0.9935| MSE:  0.0065 | RMSE:  0.0804 | MAE:  0.0644 


100%|██████████| 382/382 [00:17<00:00, 21.72it/s]


 Epoch 3 Train | Loss:  0.0156 | R2:  0.9828| MSE:  0.0157 | RMSE:  0.1252 | MAE:  0.0811 


100%|██████████| 82/82 [00:00<00:00, 106.66it/s]


Epoch 3 Val | Loss:  0.0035 | R2:  0.9965| MSE:  0.0035 | RMSE:  0.0590 | MAE:  0.0442 


100%|██████████| 382/382 [00:15<00:00, 24.49it/s]


 Epoch 4 Train | Loss:  0.0145 | R2:  0.9840| MSE:  0.0146 | RMSE:  0.1208 | MAE:  0.0738 


100%|██████████| 82/82 [00:00<00:00, 87.42it/s] 


Epoch 4 Val | Loss:  0.0036 | R2:  0.9964| MSE:  0.0036 | RMSE:  0.0598 | MAE:  0.0486 


100%|██████████| 382/382 [00:17<00:00, 22.10it/s]


 Epoch 5 Train | Loss:  0.0131 | R2:  0.9855| MSE:  0.0132 | RMSE:  0.1147 | MAE:  0.0696 


100%|██████████| 82/82 [00:00<00:00, 109.52it/s]


Epoch 5 Val | Loss:  0.0034 | R2:  0.9966| MSE:  0.0034 | RMSE:  0.0583 | MAE:  0.0454 


100%|██████████| 382/382 [00:15<00:00, 23.97it/s]


 Epoch 6 Train | Loss:  0.0126 | R2:  0.9861| MSE:  0.0126 | RMSE:  0.1124 | MAE:  0.0676 


100%|██████████| 82/82 [00:00<00:00, 107.66it/s]


Epoch 6 Val | Loss:  0.0052 | R2:  0.9948| MSE:  0.0052 | RMSE:  0.0722 | MAE:  0.0636 


100%|██████████| 382/382 [00:15<00:00, 24.09it/s]


 Epoch 7 Train | Loss:  0.0123 | R2:  0.9864| MSE:  0.0124 | RMSE:  0.1112 | MAE:  0.0668 


100%|██████████| 82/82 [00:00<00:00, 86.31it/s]


Epoch 7 Val | Loss:  0.0030 | R2:  0.9970| MSE:  0.0030 | RMSE:  0.0544 | MAE:  0.0402 


100%|██████████| 382/382 [00:14<00:00, 25.49it/s]


 Epoch 8 Train | Loss:  0.0122 | R2:  0.9866| MSE:  0.0123 | RMSE:  0.1108 | MAE:  0.0659 


100%|██████████| 82/82 [00:00<00:00, 103.42it/s]


Epoch 8 Val | Loss:  0.0020 | R2:  0.9979| MSE:  0.0020 | RMSE:  0.0452 | MAE:  0.0315 


100%|██████████| 382/382 [00:15<00:00, 24.60it/s]


 Epoch 9 Train | Loss:  0.0118 | R2:  0.9870| MSE:  0.0118 | RMSE:  0.1087 | MAE:  0.0650 


100%|██████████| 82/82 [00:00<00:00, 95.79it/s] 


Epoch 9 Val | Loss:  0.0029 | R2:  0.9971| MSE:  0.0029 | RMSE:  0.0539 | MAE:  0.0417 


100%|██████████| 382/382 [00:15<00:00, 24.80it/s]


 Epoch 10 Train | Loss:  0.0115 | R2:  0.9873| MSE:  0.0115 | RMSE:  0.1074 | MAE:  0.0643 


100%|██████████| 82/82 [00:00<00:00, 82.23it/s] 


Epoch 10 Val | Loss:  0.0037 | R2:  0.9963| MSE:  0.0037 | RMSE:  0.0605 | MAE:  0.0371 


100%|██████████| 382/382 [00:14<00:00, 26.13it/s]


 Epoch 11 Train | Loss:  0.0139 | R2:  0.9847| MSE:  0.0139 | RMSE:  0.1181 | MAE:  0.0714 


100%|██████████| 82/82 [00:01<00:00, 80.22it/s]


Epoch 11 Val | Loss:  0.0035 | R2:  0.9965| MSE:  0.0035 | RMSE:  0.0589 | MAE:  0.0446 





100%|██████████| 82/82 [00:00<00:00, 108.01it/s]


Epoch 8 Val | Loss:  0.0020 | R2:  0.9979| MSE:  0.0020 | RMSE:  0.0452 | MAE:  0.0315 


100%|██████████| 82/82 [00:00<00:00, 107.88it/s]


Epoch 8 Val | Loss:  0.0021 | R2:  0.9977| MSE:  0.0021 | RMSE:  0.0456 | MAE:  0.0324 


In [15]:
class GNN_config:
    """Hyperparameters for the 2-layer GNN experiment, consumed by ``GNN(GNN_config)``."""

    # ---- architecture ----
    d_model = 256
    n_heads = 8
    dropout = 0.1
    n_gnn_layers = 2
    activation = nn.ReLU()
    res_learning = False
    bottleneck = True

    # ---- optimisation ----
    random_state = SEED
    epochs = 32
    lr = 1e-3
    patience = 5
    loss = nn.MSELoss()
    validation_loss = nn.MSELoss()
    alpha = 0.1
    scheduler = True
    grad_clip = False

    # ---- input sizes ----
    spatial_input_dim = 1      # len(geospatial_features)
    nonspatial_input_dim = 12  # len(non_geospatial_features)

    # ---- checkpointing ----
    rootpath = home_directory
    # NOTE(review): every config cell in this notebook uses this same name, so each
    # run presumably overwrites the previous run's checkpoint. Confirm how GNN uses
    # `name` and consider a per-experiment name.
    name = 'GNN'

model = GNN(GNN_config)

# Train. fit() prints per-epoch train/validation metrics and returns the best
# epoch (judging by the logs, the one with the lowest validation loss).
best_epoch = model.fit(
    train_geospatial_X_batches, train_non_geospatial_X_batches, train_y_batches, train_masks,
    val_geospatial_X_batches, val_non_geospatial_X_batches, val_y_batches, val_masks,
    station_weights_matrix,
)
print('\n\n')

# Reload the checkpoint saved during fit(), then report validation and test metrics.
# NOTE(review): the test report is printed with a "Val" label; the label comes
# from GNN.eval, not from this cell.
model.load()
model.eval(
    val_geospatial_X_batches, val_non_geospatial_X_batches, val_y_batches, val_masks,
    station_weights_matrix, best_epoch, evaluation_mode = True,
)
model.eval(
    test_geospatial_X_batches, test_non_geospatial_X_batches, test_y_batches, test_masks,
    station_weights_matrix, best_epoch, evaluation_mode = True,
)

  0%|          | 0/382 [00:00<?, ?it/s]

100%|██████████| 382/382 [00:20<00:00, 18.83it/s]


 Epoch 1 Train | Loss:  0.1123 | R2:  0.8780| MSE:  0.1119 | RMSE:  0.3345 | MAE:  0.2205 


100%|██████████| 82/82 [00:01<00:00, 66.16it/s]


Epoch 1 Val | Loss:  0.0171 | R2:  0.9829| MSE:  0.0171 | RMSE:  0.1308 | MAE:  0.1019 


100%|██████████| 382/382 [00:20<00:00, 18.80it/s]


 Epoch 2 Train | Loss:  0.0329 | R2:  0.9641| MSE:  0.0327 | RMSE:  0.1808 | MAE:  0.1275 


100%|██████████| 82/82 [00:01<00:00, 68.74it/s]


Epoch 2 Val | Loss:  0.0101 | R2:  0.9899| MSE:  0.0101 | RMSE:  0.1003 | MAE:  0.0672 


100%|██████████| 382/382 [00:20<00:00, 18.96it/s]


 Epoch 3 Train | Loss:  0.0237 | R2:  0.9740| MSE:  0.0237 | RMSE:  0.1541 | MAE:  0.1045 


100%|██████████| 82/82 [00:01<00:00, 71.90it/s]


Epoch 3 Val | Loss:  0.0098 | R2:  0.9902| MSE:  0.0098 | RMSE:  0.0989 | MAE:  0.0742 


100%|██████████| 382/382 [00:19<00:00, 19.14it/s]


 Epoch 4 Train | Loss:  0.0217 | R2:  0.9761| MSE:  0.0218 | RMSE:  0.1476 | MAE:  0.0982 


100%|██████████| 82/82 [00:01<00:00, 70.62it/s]


Epoch 4 Val | Loss:  0.0080 | R2:  0.9920| MSE:  0.0080 | RMSE:  0.0893 | MAE:  0.0748 


100%|██████████| 382/382 [00:20<00:00, 18.83it/s]


 Epoch 5 Train | Loss:  0.0185 | R2:  0.9796| MSE:  0.0185 | RMSE:  0.1361 | MAE:  0.0890 


100%|██████████| 82/82 [00:01<00:00, 69.90it/s]


Epoch 5 Val | Loss:  0.0087 | R2:  0.9912| MSE:  0.0087 | RMSE:  0.0935 | MAE:  0.0770 


100%|██████████| 382/382 [00:19<00:00, 19.14it/s]


 Epoch 6 Train | Loss:  0.0180 | R2:  0.9801| MSE:  0.0180 | RMSE:  0.1343 | MAE:  0.0871 


100%|██████████| 82/82 [00:01<00:00, 72.27it/s]


Epoch 6 Val | Loss:  0.0091 | R2:  0.9909| MSE:  0.0091 | RMSE:  0.0954 | MAE:  0.0802 


100%|██████████| 382/382 [00:20<00:00, 18.88it/s]


 Epoch 7 Train | Loss:  0.0161 | R2:  0.9822| MSE:  0.0161 | RMSE:  0.1269 | MAE:  0.0814 


100%|██████████| 82/82 [00:01<00:00, 72.38it/s]


Epoch 7 Val | Loss:  0.0070 | R2:  0.9930| MSE:  0.0070 | RMSE:  0.0836 | MAE:  0.0684 


100%|██████████| 382/382 [00:20<00:00, 18.96it/s]


 Epoch 8 Train | Loss:  0.0173 | R2:  0.9809| MSE:  0.0174 | RMSE:  0.1320 | MAE:  0.0854 


100%|██████████| 82/82 [00:01<00:00, 72.76it/s]


Epoch 8 Val | Loss:  0.0084 | R2:  0.9915| MSE:  0.0084 | RMSE:  0.0919 | MAE:  0.0757 


100%|██████████| 382/382 [00:20<00:00, 18.87it/s]


 Epoch 9 Train | Loss:  0.0157 | R2:  0.9828| MSE:  0.0157 | RMSE:  0.1253 | MAE:  0.0792 


100%|██████████| 82/82 [00:01<00:00, 72.67it/s]


Epoch 9 Val | Loss:  0.0096 | R2:  0.9904| MSE:  0.0096 | RMSE:  0.0980 | MAE:  0.0820 


100%|██████████| 382/382 [00:20<00:00, 18.83it/s]


 Epoch 10 Train | Loss:  0.0168 | R2:  0.9815| MSE:  0.0169 | RMSE:  0.1299 | MAE:  0.0822 


100%|██████████| 82/82 [00:01<00:00, 72.56it/s]


Epoch 10 Val | Loss:  0.0062 | R2:  0.9938| MSE:  0.0062 | RMSE:  0.0788 | MAE:  0.0682 


100%|██████████| 382/382 [00:20<00:00, 18.93it/s]


 Epoch 11 Train | Loss:  0.0152 | R2:  0.9833| MSE:  0.0152 | RMSE:  0.1234 | MAE:  0.0782 


100%|██████████| 82/82 [00:01<00:00, 72.49it/s]


Epoch 11 Val | Loss:  0.0071 | R2:  0.9929| MSE:  0.0071 | RMSE:  0.0841 | MAE:  0.0687 





100%|██████████| 82/82 [00:01<00:00, 70.35it/s]


Epoch 10 Val | Loss:  0.0062 | R2:  0.9938| MSE:  0.0062 | RMSE:  0.0788 | MAE:  0.0682 


100%|██████████| 82/82 [00:01<00:00, 72.08it/s]


Epoch 10 Val | Loss:  0.0062 | R2:  0.9931| MSE:  0.0062 | RMSE:  0.0786 | MAE:  0.0685 


In [16]:
class GNN_config:
    """Hyperparameters for the 3-layer GNN experiment, consumed by ``GNN(GNN_config)``."""

    # ---- architecture ----
    d_model = 256
    n_heads = 8
    dropout = 0.1
    n_gnn_layers = 3
    activation = nn.ReLU()
    res_learning = False
    bottleneck = True

    # ---- optimisation ----
    random_state = SEED
    epochs = 32
    lr = 1e-3
    patience = 5
    loss = nn.MSELoss()
    validation_loss = nn.MSELoss()
    alpha = 0.1
    scheduler = True
    grad_clip = False

    # ---- input sizes ----
    spatial_input_dim = 1      # len(geospatial_features)
    nonspatial_input_dim = 12  # len(non_geospatial_features)

    # ---- checkpointing ----
    rootpath = home_directory
    # NOTE(review): every config cell in this notebook uses this same name, so each
    # run presumably overwrites the previous run's checkpoint. Confirm how GNN uses
    # `name` and consider a per-experiment name.
    name = 'GNN'

model = GNN(GNN_config)

# Train. fit() prints per-epoch train/validation metrics and returns the best
# epoch (judging by the logs, the one with the lowest validation loss).
best_epoch = model.fit(
    train_geospatial_X_batches, train_non_geospatial_X_batches, train_y_batches, train_masks,
    val_geospatial_X_batches, val_non_geospatial_X_batches, val_y_batches, val_masks,
    station_weights_matrix,
)
print('\n\n')

# Reload the checkpoint saved during fit(), then report validation and test metrics.
# NOTE(review): the test report is printed with a "Val" label; the label comes
# from GNN.eval, not from this cell.
model.load()
model.eval(
    val_geospatial_X_batches, val_non_geospatial_X_batches, val_y_batches, val_masks,
    station_weights_matrix, best_epoch, evaluation_mode = True,
)
model.eval(
    test_geospatial_X_batches, test_non_geospatial_X_batches, test_y_batches, test_masks,
    station_weights_matrix, best_epoch, evaluation_mode = True,
)

100%|██████████| 382/382 [00:29<00:00, 13.05it/s]


 Epoch 1 Train | Loss:  0.1345 | R2:  0.8540| MSE:  0.1339 | RMSE:  0.3660 | MAE:  0.2439 


100%|██████████| 82/82 [00:01<00:00, 46.28it/s]


Epoch 1 Val | Loss:  0.0384 | R2:  0.9615| MSE:  0.0384 | RMSE:  0.1960 | MAE:  0.1198 


100%|██████████| 382/382 [00:29<00:00, 12.79it/s]


 Epoch 2 Train | Loss:  0.0405 | R2:  0.9554| MSE:  0.0406 | RMSE:  0.2016 | MAE:  0.1431 


100%|██████████| 82/82 [00:01<00:00, 45.11it/s]


Epoch 2 Val | Loss:  0.0124 | R2:  0.9876| MSE:  0.0124 | RMSE:  0.1112 | MAE:  0.0833 


100%|██████████| 382/382 [00:29<00:00, 13.02it/s]


 Epoch 3 Train | Loss:  0.0284 | R2:  0.9689| MSE:  0.0284 | RMSE:  0.1685 | MAE:  0.1169 


100%|██████████| 82/82 [00:01<00:00, 47.65it/s]


Epoch 3 Val | Loss:  0.0125 | R2:  0.9875| MSE:  0.0125 | RMSE:  0.1117 | MAE:  0.0839 


100%|██████████| 382/382 [00:29<00:00, 13.09it/s]


 Epoch 4 Train | Loss:  0.0251 | R2:  0.9724| MSE:  0.0251 | RMSE:  0.1586 | MAE:  0.1063 


100%|██████████| 82/82 [00:01<00:00, 46.79it/s]


Epoch 4 Val | Loss:  0.0105 | R2:  0.9895| MSE:  0.0105 | RMSE:  0.1024 | MAE:  0.0769 


100%|██████████| 382/382 [00:29<00:00, 13.15it/s]


 Epoch 5 Train | Loss:  0.0452 | R2:  0.9497| MSE:  0.0457 | RMSE:  0.2138 | MAE:  0.1382 


100%|██████████| 82/82 [00:01<00:00, 44.89it/s]


Epoch 5 Val | Loss:  0.0327 | R2:  0.9673| MSE:  0.0327 | RMSE:  0.1807 | MAE:  0.0934 


100%|██████████| 382/382 [00:29<00:00, 13.07it/s]


 Epoch 6 Train | Loss:  0.0375 | R2:  0.9585| MSE:  0.0375 | RMSE:  0.1938 | MAE:  0.1305 


100%|██████████| 82/82 [00:01<00:00, 49.94it/s]


Epoch 6 Val | Loss:  0.0150 | R2:  0.9849| MSE:  0.0150 | RMSE:  0.1226 | MAE:  0.0954 


100%|██████████| 382/382 [00:29<00:00, 13.11it/s]


 Epoch 7 Train | Loss:  0.0269 | R2:  0.9703| MSE:  0.0270 | RMSE:  0.1642 | MAE:  0.1109 


100%|██████████| 82/82 [00:01<00:00, 48.88it/s]


Epoch 7 Val | Loss:  0.0071 | R2:  0.9929| MSE:  0.0071 | RMSE:  0.0844 | MAE:  0.0665 


100%|██████████| 382/382 [00:29<00:00, 13.02it/s]


 Epoch 8 Train | Loss:  0.0205 | R2:  0.9775| MSE:  0.0205 | RMSE:  0.1432 | MAE:  0.0967 


100%|██████████| 82/82 [00:01<00:00, 48.92it/s]


Epoch 8 Val | Loss:  0.0032 | R2:  0.9968| MSE:  0.0032 | RMSE:  0.0561 | MAE:  0.0411 


100%|██████████| 382/382 [00:29<00:00, 12.85it/s]


 Epoch 9 Train | Loss:  0.0189 | R2:  0.9792| MSE:  0.0190 | RMSE:  0.1377 | MAE:  0.0911 


100%|██████████| 82/82 [00:01<00:00, 43.83it/s]


Epoch 9 Val | Loss:  0.0051 | R2:  0.9949| MSE:  0.0051 | RMSE:  0.0716 | MAE:  0.0440 


100%|██████████| 382/382 [00:30<00:00, 12.44it/s]


 Epoch 10 Train | Loss:  0.0321 | R2:  0.9649| MSE:  0.0320 | RMSE:  0.1790 | MAE:  0.1181 


100%|██████████| 82/82 [00:02<00:00, 31.96it/s]


Epoch 10 Val | Loss:  0.0062 | R2:  0.9938| MSE:  0.0062 | RMSE:  0.0787 | MAE:  0.0485 





100%|██████████| 82/82 [00:02<00:00, 34.44it/s]


Epoch 8 Val | Loss:  0.0032 | R2:  0.9968| MSE:  0.0032 | RMSE:  0.0561 | MAE:  0.0411 


100%|██████████| 82/82 [00:02<00:00, 35.68it/s]


Epoch 8 Val | Loss:  0.0031 | R2:  0.9966| MSE:  0.0031 | RMSE:  0.0556 | MAE:  0.0413 


In [14]:
class GNN_config:
    """Hyperparameters for the 4-layer GNN experiment, consumed by ``GNN(GNN_config)``."""

    # ---- architecture ----
    d_model = 256
    n_heads = 8
    # NOTE(review): the 1-3 layer configs use dropout 0.1; changing dropout and
    # depth together confounds the layer-count comparison. Confirm this is intended.
    dropout = 0
    n_gnn_layers = 4
    activation = nn.ReLU()
    res_learning = False
    bottleneck = True

    # ---- optimisation ----
    random_state = SEED
    epochs = 32
    lr = 1e-3
    patience = 5
    loss = nn.MSELoss()
    validation_loss = nn.MSELoss()
    alpha = 0.1
    scheduler = True
    grad_clip = False

    # ---- input sizes ----
    spatial_input_dim = 1      # len(geospatial_features)
    nonspatial_input_dim = 12  # len(non_geospatial_features)

    # ---- checkpointing ----
    rootpath = home_directory
    # NOTE(review): every config cell in this notebook uses this same name, so each
    # run presumably overwrites the previous run's checkpoint. Confirm how GNN uses
    # `name` and consider a per-experiment name.
    name = 'GNN'

model = GNN(GNN_config)

# Train. fit() prints per-epoch train/validation metrics and returns the best
# epoch (judging by the logs, the one with the lowest validation loss).
best_epoch = model.fit(
    train_geospatial_X_batches, train_non_geospatial_X_batches, train_y_batches, train_masks,
    val_geospatial_X_batches, val_non_geospatial_X_batches, val_y_batches, val_masks,
    station_weights_matrix,
)
print('\n\n')

# Reload the checkpoint saved during fit(), then report validation and test metrics.
# NOTE(review): the test report is printed with a "Val" label; the label comes
# from GNN.eval, not from this cell.
model.load()
model.eval(
    val_geospatial_X_batches, val_non_geospatial_X_batches, val_y_batches, val_masks,
    station_weights_matrix, best_epoch, evaluation_mode = True,
)
model.eval(
    test_geospatial_X_batches, test_non_geospatial_X_batches, test_y_batches, test_masks,
    station_weights_matrix, best_epoch, evaluation_mode = True,
)

100%|██████████| 382/382 [00:49<00:00,  7.78it/s]


 Epoch 1 Train | Loss:  0.1500 | R2:  0.8370| MSE:  0.1495 | RMSE:  0.3867 | MAE:  0.2570 


100%|██████████| 82/82 [00:02<00:00, 28.88it/s]


Epoch 1 Val | Loss:  0.0254 | R2:  0.9745| MSE:  0.0254 | RMSE:  0.1595 | MAE:  0.1151 


100%|██████████| 382/382 [02:09<00:00,  2.94it/s]


 Epoch 2 Train | Loss:  0.0420 | R2:  0.9538| MSE:  0.0420 | RMSE:  0.2050 | MAE:  0.1468 


100%|██████████| 82/82 [00:02<00:00, 30.37it/s]


Epoch 2 Val | Loss:  0.0178 | R2:  0.9821| MSE:  0.0178 | RMSE:  0.1335 | MAE:  0.0956 


100%|██████████| 382/382 [25:46<00:00,  4.05s/it]   


 Epoch 3 Train | Loss:  0.0357 | R2:  0.9607| MSE:  0.0358 | RMSE:  0.1893 | MAE:  0.1316 


100%|██████████| 82/82 [00:02<00:00, 34.43it/s]


Epoch 3 Val | Loss:  0.0187 | R2:  0.9813| MSE:  0.0187 | RMSE:  0.1366 | MAE:  0.1034 


100%|██████████| 382/382 [00:49<00:00,  7.78it/s]


 Epoch 4 Train | Loss:  0.0243 | R2:  0.9733| MSE:  0.0244 | RMSE:  0.1561 | MAE:  0.1059 


100%|██████████| 82/82 [00:02<00:00, 38.05it/s]


Epoch 4 Val | Loss:  0.0057 | R2:  0.9943| MSE:  0.0057 | RMSE:  0.0753 | MAE:  0.0526 


100%|██████████| 382/382 [00:44<00:00,  8.65it/s]


 Epoch 5 Train | Loss:  0.0255 | R2:  0.9718| MSE:  0.0256 | RMSE:  0.1600 | MAE:  0.1028 


100%|██████████| 82/82 [00:03<00:00, 26.22it/s]


Epoch 5 Val | Loss:  0.0483 | R2:  0.9515| MSE:  0.0483 | RMSE:  0.2199 | MAE:  0.1159 


100%|██████████| 382/382 [00:43<00:00,  8.70it/s]


 Epoch 6 Train | Loss:  0.0456 | R2:  0.9500| MSE:  0.0452 | RMSE:  0.2126 | MAE:  0.1402 


100%|██████████| 82/82 [00:02<00:00, 33.69it/s]


Epoch 6 Val | Loss:  0.0168 | R2:  0.9831| MSE:  0.0168 | RMSE:  0.1297 | MAE:  0.0828 


100%|██████████| 382/382 [00:35<00:00, 10.62it/s]


 Epoch 7 Train | Loss:  0.0478 | R2:  0.9472| MSE:  0.0479 | RMSE:  0.2188 | MAE:  0.1399 


100%|██████████| 82/82 [00:02<00:00, 33.14it/s]


Epoch 7 Val | Loss:  0.0153 | R2:  0.9847| MSE:  0.0153 | RMSE:  0.1236 | MAE:  0.1033 


100%|██████████| 382/382 [00:36<00:00, 10.59it/s]


 Epoch 8 Train | Loss:  0.0200 | R2:  0.9780| MSE:  0.0201 | RMSE:  0.1418 | MAE:  0.0945 


100%|██████████| 82/82 [00:02<00:00, 35.21it/s]


Epoch 8 Val | Loss:  0.0070 | R2:  0.9930| MSE:  0.0070 | RMSE:  0.0835 | MAE:  0.0522 





100%|██████████| 82/82 [00:02<00:00, 35.49it/s]


Epoch 4 Val | Loss:  0.0057 | R2:  0.9943| MSE:  0.0057 | RMSE:  0.0753 | MAE:  0.0526 


100%|██████████| 82/82 [00:02<00:00, 35.24it/s]

Epoch 4 Val | Loss:  0.0056 | R2:  0.9938| MSE:  0.0056 | RMSE:  0.0748 | MAE:  0.0541 



