In [1]:
import pandas as pd
import numpy as np
import json
from tqdm import tqdm

In [2]:
# Model-ready GNN splits (one row per station per Business_Date).
gnn_train_data = pd.read_parquet('../data/curated/ML_data/gnn_train_data.parquet')
gnn_val_data = pd.read_parquet('../data/curated/ML_data/gnn_val_data.parquet')
gnn_test_data = pd.read_parquet('../data/curated/ML_data/gnn_test_data.parquet')

# Inference nodes: SA2 regions plus stations. The station file's name column appears
# truncated ('Station_Na'), so it is first aligned with the SA2 frame's 'Station Name',
# and the stacked frame is then renamed to 'Station_Name', the column DataFactory uses.
SA2_gnn_data = pd.read_parquet('../data/curated/ML_data/SA2_gnn_data.parquet')
station_inference_gnn_data = pd.read_parquet(
    '../data/curated/ML_data/station_inference_gnn_data.parquet'
).rename(columns={'Station_Na': 'Station Name'})

inference_data = pd.concat(
    [SA2_gnn_data, station_inference_gnn_data], axis=0
).rename(columns={'Station Name': 'Station_Name'})

In [3]:
# stations_list = [x for x in gnn_train_data['Station_Name'].unique()]
# stations_index = {stations_list[i]:i for i in range(len(stations_list))}
# reverse_stations_index = {v: k for (k, v) in stations_index.items()}

# Graph weight matrices. station_weights_matrix is the one passed to model.fit / model.eval
# below; the *_withSA2 variant also covers SA2 regions (presumably for inference — confirm).
station_weights_matrix, SA2_weights_matrix = (
    np.load('../data/curated/ML_data/station_weights_matrix.npy'),
    np.load('../data/curated/ML_data/station_weights_withSA2_matrix.npy'),
)

# JSON mappings handed to DataFactory as `stations_index` (station name -> array row):
# station_weights for train/val/test, station_weights_withSA2 for the disabled inference call.
with open('../data/curated/ML_features/station_weights_withSA2.json', 'r') as f:
    station_weights_withSA2 = json.load(f)
with open('../data/curated/ML_features/station_weights.json', 'r') as f:
    station_weights = json.load(f)

In [4]:
# Columns copied into DataFactory's geospatial_x (the spatial input of the GNN).
# NOTE(review): this is the same column as the label below — confirm the model cannot see
# a node's own target (e.g. through self-loops in station_weights_matrix).
geospatial_features = ['log_Total_Demand']

# Columns copied into DataFactory's non_geospatial_x. The leading space on the last four
# names is part of the stored column names (the batch-building cell runs without KeyError),
# so it must be kept.
non_geospatial_features = [
    'Weekday',
    'PublicHoliday',
    'mean_rainfall_value',
    'has_school',
    'has_sport_facility',
    'has_shopping_centre',
    'has_hospital',
    'total_population',
    ' med_rent_weekly_c2021',
    ' med_mortg_rep_mon_c2021',
    ' med_person_inc_we_c2021',
    ' med_famly_inc_we_c2021',
]

# Regression target copied into y.
label_columns = ['log_Total_Demand']

In [5]:
def DataFactory(raw_dataset, geospatial_features, non_geospatial_features, label_columns, stations_index, inference = False):
    """Build dense per-snapshot node arrays for the GNN.

    ``raw_dataset`` is split into snapshots by ``Business_Date`` (default) or by
    ``Weekday`` (``inference=True``). For every snapshot one row per station in
    ``stations_index`` is allocated. Stations present in the snapshot get their
    feature values (and, unless ``inference``, label values) copied into their row
    and a mask value of 1. Absent stations keep all-zero rows and a mask value of 0.

    Parameters
    ----------
    raw_dataset : pandas.DataFrame
        Long-format table with a ``Station_Name`` column, the grouping column
        (``Business_Date`` or ``Weekday``) and every column listed below.
    geospatial_features : list of str
        Columns copied into each ``geospatial_x`` array.
    non_geospatial_features : list of str
        Columns copied into each ``non_geospatial_x`` array.
    label_columns : list of str
        Columns copied into each ``y`` array; not read when ``inference`` is True.
    stations_index : Mapping[str, int]
        Station name -> row index. ``len(stations_index)`` is the row count of every
        output array, so indices are assumed to lie in ``range(len(stations_index))``
        — TODO confirm for the JSON mappings loaded in this notebook.
    inference : bool, default False
        Group by ``Weekday`` instead of ``Business_Date`` and leave ``y`` as zeros.

    Returns
    -------
    tuple of four lists of numpy.ndarray
        ``(geospatial_x_batches, non_geospatial_x_batches, y_batches, masks)``, one
        entry per snapshot in sorted group-key order, shaped
        ``(n, len(geospatial_features))``, ``(n, len(non_geospatial_features))``,
        ``(n, len(label_columns))`` and ``(n,)`` with ``n = len(stations_index)``.

    Raises
    ------
    KeyError
        If a station in ``raw_dataset`` is not a key of ``stations_index``.
    ValueError
        If a station occurs more than once within the same snapshot.
    """
    n_stations = len(stations_index)
    groupby_column = 'Weekday' if inference else 'Business_Date'

    geospatial_x_batches = []
    non_geospatial_x_batches = []
    y_batches = []
    masks = []

    for _, daily_df in tqdm(raw_dataset.groupby([groupby_column])):
        station_names = daily_df['Station_Name']

        # Each station owns exactly one row per snapshot; a repeated station would make
        # its row value ambiguous, so reject it explicitly.
        repeated = station_names[station_names.duplicated()].unique()
        if len(repeated) > 0:
            raise ValueError(
                f"Station(s) {list(repeated)} appear more than once in a {groupby_column} group"
            )

        # Target row of every record; an unknown station raises KeyError here.
        rows = np.array([stations_index[station] for station in station_names], dtype=int)

        # Scatter all present stations at once (integer-array assignment) instead of
        # repeated per-station .loc lookups; the frame itself is left unmodified.
        geospatial_x = np.zeros([n_stations, len(geospatial_features)])
        geospatial_x[rows] = daily_df[geospatial_features].to_numpy(dtype=float)

        non_geospatial_x = np.zeros([n_stations, len(non_geospatial_features)])
        non_geospatial_x[rows] = daily_df[non_geospatial_features].to_numpy(dtype=float)

        y = np.zeros([n_stations, len(label_columns)])
        if not inference:
            y[rows] = daily_df[label_columns].to_numpy(dtype=float)

        mask = np.zeros(n_stations)
        mask[rows] = 1

        geospatial_x_batches.append(geospatial_x)
        non_geospatial_x_batches.append(non_geospatial_x)
        y_batches.append(y)
        masks.append(mask)

    return geospatial_x_batches, non_geospatial_x_batches, y_batches, masks

In [6]:
# One DataFactory call per split; each returns lists (one entry per Business_Date) of
# geospatial X, non-geospatial X, y and mask arrays, with rows laid out by `station_weights`.
(train_geospatial_X_batches, train_non_geospatial_X_batches,
 train_y_batches, train_masks) = DataFactory(
    gnn_train_data, geospatial_features, non_geospatial_features, label_columns,
    stations_index=station_weights,
)
(val_geospatial_X_batches, val_non_geospatial_X_batches,
 val_y_batches, val_masks) = DataFactory(
    gnn_val_data, geospatial_features, non_geospatial_features, label_columns,
    stations_index=station_weights,
)
(test_geospatial_X_batches, test_non_geospatial_X_batches,
 test_y_batches, test_masks) = DataFactory(
    gnn_test_data, geospatial_features, non_geospatial_features, label_columns,
    stations_index=station_weights,
)

# Inference batches (SA2 + station nodes, grouped by Weekday) are currently disabled:
# inference_geospatial_X_batches, inference_non_geospatial_X_batches, inference_y_batches, inference_masks = DataFactory(inference_data, geospatial_features, non_geospatial_features, label_columns, station_weights_withSA2, inference = True)

100%|██████████| 382/382 [01:09<00:00,  5.48it/s]
100%|██████████| 82/82 [00:15<00:00,  5.32it/s]
100%|██████████| 82/82 [00:14<00:00,  5.68it/s]


In [7]:
import sys
import os

# Project root, one level up. It is put on sys.path so the local `model` package imports,
# and it is reused below as the model's save root (GNN_config.rootpath).
py_file_location = home_directory = '../'
sys.path += [os.path.abspath(py_file_location)]

# NOTE(review): star import — it presumably provides `nn` used by GNN_config; prefer
# importing the needed names explicitly.
from model.model_class.environment import *

In [8]:
from model.model_class import GNN

In [9]:
SEED: int = 42  # random seed, handed to the model as GNN_config.random_state

In [13]:
class GNN_config:
    """Hyperparameters for the GNN model.

    The class itself (not an instance) is passed to ``GNN(GNN_config)``, so the
    settings live as class attributes.
    """
    # ----------------- architectural hyperparameters ----------------- #
    d_model = 256
    n_heads = 8
    dropout = 0.1
    n_gnn_layers = 1
    activation = nn.ReLU()
    res_learning = False
    # ----------------- optimisation hyperparameters ----------------- #
    random_state = SEED
    epochs = 32
    lr = 1e-3
    patience = 10
    loss = nn.MSELoss()
    validation_loss = nn.MSELoss()
    alpha = 0.1
    scheduler = True
    grad_clip = False
    # ----------------- operation hyperparameters ----------------- #
    # Input widths are derived from the feature lists that DataFactory used to build
    # the batches. Adding or removing a feature then cannot silently desynchronise the
    # model from its input arrays.
    spatial_input_dim = len(geospatial_features)         # currently 1
    nonspatial_input_dim = len(non_geospatial_features)  # currently 12
    # ----------------- saving hyperparameters ----------------- #
    rootpath = home_directory
    name = 'GNN'

# Build the model from the config class defined above.
model = GNN(GNN_config)

# Train on the train split while monitoring the validation split. This cell prints the
# per-epoch training and validation log; `fit` returns the best epoch, which (per the
# original note) the model checkpoints automatically.
best_epoch = model.fit(
    train_geospatial_X_batches, train_non_geospatial_X_batches, train_y_batches, train_masks,
    val_geospatial_X_batches, val_non_geospatial_X_batches, val_y_batches, val_masks,
    station_weights_matrix,
)
print('\n\n')

# Reload the best-epoch checkpoint, then report validation metrics followed by test metrics.
# NOTE(review): in the saved output the test run is also labelled "Val"; that label is
# presumably printed inside GNN.eval, not by this cell.
model.load()
model.eval(
    val_geospatial_X_batches, val_non_geospatial_X_batches, val_y_batches, val_masks,
    station_weights_matrix, best_epoch, evaluation_mode=True,
)
model.eval(
    test_geospatial_X_batches, test_non_geospatial_X_batches, test_y_batches, test_masks,
    station_weights_matrix, best_epoch, evaluation_mode=True,
)

  0%|          | 0/382 [00:00<?, ?it/s]

100%|██████████| 382/382 [00:13<00:00, 28.85it/s]


 Epoch 1 Train | Loss:  0.0368 | R2:  0.9598| MSE:  0.0369 | RMSE:  0.1921 | MAE:  0.1140 


100%|██████████| 82/82 [00:00<00:00, 89.37it/s]


Epoch 1 Val | Loss:  0.0066 | R2:  0.9934| MSE:  0.0066 | RMSE:  0.0811 | MAE:  0.0679 


100%|██████████| 382/382 [00:13<00:00, 28.32it/s]


 Epoch 2 Train | Loss:  0.0152 | R2:  0.9833| MSE:  0.0152 | RMSE:  0.1233 | MAE:  0.0768 


100%|██████████| 82/82 [00:00<00:00, 94.74it/s]


Epoch 2 Val | Loss:  0.0079 | R2:  0.9920| MSE:  0.0079 | RMSE:  0.0892 | MAE:  0.0550 


100%|██████████| 382/382 [00:14<00:00, 27.07it/s]


 Epoch 3 Train | Loss:  0.0149 | R2:  0.9836| MSE:  0.0149 | RMSE:  0.1223 | MAE:  0.0758 


100%|██████████| 82/82 [00:00<00:00, 98.99it/s] 


Epoch 3 Val | Loss:  0.0068 | R2:  0.9931| MSE:  0.0068 | RMSE:  0.0827 | MAE:  0.0583 


100%|██████████| 382/382 [00:14<00:00, 26.25it/s]


 Epoch 4 Train | Loss:  0.0142 | R2:  0.9844| MSE:  0.0142 | RMSE:  0.1193 | MAE:  0.0727 


100%|██████████| 82/82 [00:00<00:00, 88.25it/s]


Epoch 4 Val | Loss:  0.0106 | R2:  0.9893| MSE:  0.0106 | RMSE:  0.1032 | MAE:  0.0656 


100%|██████████| 382/382 [00:14<00:00, 27.16it/s]


 Epoch 5 Train | Loss:  0.0134 | R2:  0.9852| MSE:  0.0134 | RMSE:  0.1160 | MAE:  0.0697 


100%|██████████| 82/82 [00:00<00:00, 97.51it/s] 


Epoch 5 Val | Loss:  0.0050 | R2:  0.9950| MSE:  0.0050 | RMSE:  0.0704 | MAE:  0.0479 


100%|██████████| 382/382 [00:14<00:00, 27.04it/s]


 Epoch 6 Train | Loss:  0.0132 | R2:  0.9854| MSE:  0.0132 | RMSE:  0.1150 | MAE:  0.0695 


100%|██████████| 82/82 [00:00<00:00, 106.34it/s]


Epoch 6 Val | Loss:  0.0021 | R2:  0.9979| MSE:  0.0021 | RMSE:  0.0463 | MAE:  0.0380 


100%|██████████| 382/382 [00:16<00:00, 22.97it/s]


 Epoch 7 Train | Loss:  0.0127 | R2:  0.9859| MSE:  0.0128 | RMSE:  0.1129 | MAE:  0.0679 


100%|██████████| 82/82 [00:00<00:00, 106.30it/s]


Epoch 7 Val | Loss:  0.0021 | R2:  0.9979| MSE:  0.0021 | RMSE:  0.0460 | MAE:  0.0343 


100%|██████████| 382/382 [00:13<00:00, 27.69it/s]


 Epoch 8 Train | Loss:  0.0123 | R2:  0.9865| MSE:  0.0123 | RMSE:  0.1111 | MAE:  0.0658 


100%|██████████| 82/82 [00:00<00:00, 111.26it/s]


Epoch 8 Val | Loss:  0.0057 | R2:  0.9943| MSE:  0.0057 | RMSE:  0.0756 | MAE:  0.0504 


100%|██████████| 382/382 [00:13<00:00, 27.54it/s]


 Epoch 9 Train | Loss:  0.0120 | R2:  0.9868| MSE:  0.0121 | RMSE:  0.1099 | MAE:  0.0649 


100%|██████████| 82/82 [00:00<00:00, 101.61it/s]


Epoch 9 Val | Loss:  0.0069 | R2:  0.9931| MSE:  0.0069 | RMSE:  0.0828 | MAE:  0.0724 


100%|██████████| 382/382 [00:13<00:00, 29.10it/s]


 Epoch 10 Train | Loss:  0.0118 | R2:  0.9870| MSE:  0.0119 | RMSE:  0.1089 | MAE:  0.0645 


100%|██████████| 82/82 [00:00<00:00, 91.76it/s] 


Epoch 10 Val | Loss:  0.0033 | R2:  0.9966| MSE:  0.0033 | RMSE:  0.0578 | MAE:  0.0403 


100%|██████████| 382/382 [00:13<00:00, 28.37it/s]


 Epoch 11 Train | Loss:  0.0119 | R2:  0.9869| MSE:  0.0120 | RMSE:  0.1093 | MAE:  0.0654 


100%|██████████| 82/82 [00:00<00:00, 89.69it/s]


Epoch 11 Val | Loss:  0.0057 | R2:  0.9943| MSE:  0.0057 | RMSE:  0.0757 | MAE:  0.0504 


100%|██████████| 382/382 [00:14<00:00, 27.14it/s]


 Epoch 12 Train | Loss:  0.0115 | R2:  0.9873| MSE:  0.0116 | RMSE:  0.1076 | MAE:  0.0634 


100%|██████████| 82/82 [00:00<00:00, 98.12it/s] 


Epoch 12 Val | Loss:  0.0043 | R2:  0.9957| MSE:  0.0043 | RMSE:  0.0658 | MAE:  0.0424 


100%|██████████| 382/382 [00:14<00:00, 26.94it/s]


 Epoch 13 Train | Loss:  0.0112 | R2:  0.9876| MSE:  0.0113 | RMSE:  0.1061 | MAE:  0.0623 


100%|██████████| 82/82 [00:00<00:00, 89.80it/s] 


Epoch 13 Val | Loss:  0.0043 | R2:  0.9957| MSE:  0.0043 | RMSE:  0.0652 | MAE:  0.0536 


100%|██████████| 382/382 [00:15<00:00, 25.05it/s]


 Epoch 14 Train | Loss:  0.0104 | R2:  0.9886| MSE:  0.0105 | RMSE:  0.1023 | MAE:  0.0585 


100%|██████████| 82/82 [00:00<00:00, 94.12it/s] 


Epoch 14 Val | Loss:  0.0023 | R2:  0.9977| MSE:  0.0023 | RMSE:  0.0475 | MAE:  0.0325 


100%|██████████| 82/82 [00:00<00:00, 86.65it/s]


Epoch 7 Val | Loss:  0.0021 | R2:  0.9979| MSE:  0.0021 | RMSE:  0.0460 | MAE:  0.0343 


100%|██████████| 82/82 [00:00<00:00, 104.26it/s]


Epoch 7 Val | Loss:  0.0019 | R2:  0.9979| MSE:  0.0019 | RMSE:  0.0438 | MAE:  0.0332 
