In [1]:
import pandas as pd
import numpy as np
from tqdm import tqdm

In [2]:
daily_trains_demand_post_covid = pd.read_csv('../data/curated/train_demand/daily_trains_demand_post_covid.csv')

In [3]:
stations_list = [x for x in daily_trains_demand_post_covid['Station_Name'].unique()]
stations_index = {stations_list[i]:i for i in range(len(stations_list))}
reverse_stations_index = {v: k for (k, v) in stations_index.items()}

In [4]:
geospatial_features = ['log_Total_Demand']
non_geospatial_features = ['Weekday', 'PublicHoliday']
label_columns = ['log_Total_Demand']

In [5]:
def DataFactory(raw_dataset, geospatial_features, non_geospatial_features, label_columns):

    """ Data Factory of GNN """
    
    geospatial_x_batches = []
    non_geospatial_x_batches = []
    y_batches = []
    masks = []

    for day, daily_df in tqdm(raw_dataset.groupby(['Business_Date'])):

        geospatial_x = np.zeros([len(stations_index), len(geospatial_features)])
        y = np.zeros([len(stations_index), len(label_columns)])
        mask = np.zeros([len(stations_index), 1])

        daily_df.set_index('Station_Name', inplace=True)

        for station in daily_df.index:

            geospatial_x[stations_index[station]] = daily_df.loc[station][geospatial_features]
            y[stations_index[station]] = daily_df.loc[station][label_columns]
            mask[stations_index[station]] = 1
                
        geospatial_x_batches.append(geospatial_x)
        y_batches.append(y)
        masks.append(mask)

        non_geospatial_x = daily_df[non_geospatial_features].values

        non_geospatial_x_batches.append(non_geospatial_x)

        
    return geospatial_x_batches, non_geospatial_x_batches, y_batches, masks

In [6]:
geospatial_X_batches, non_geospatial_X_batches, y_batches, masks = DataFactory(daily_trains_demand_post_covid, geospatial_features, non_geospatial_features, label_columns)

100%|██████████| 546/546 [00:58<00:00,  9.28it/s]
