In [1]:
from preprocessing import preprocess_flowdata, preprocess_graph
from dataset import WaterFlowDataSet
from model import Model
import torch
import hyperparameters as hp
from torch_geometric.loader import DataLoader
import matplotlib.pyplot as plt

In [2]:
criterion = torch.nn.MSELoss()

def train(model, optimiser, loader):
    model.train()
    total_loss = 0

    for batch in loader:
        optimiser.zero_grad()
        out = model(batch.x, batch.edge_index) 
        loss = criterion(out[batch.mask], batch.y[batch.mask].squeeze())
        loss.backward()
        optimiser.step()
        total_loss += loss.item() * batch.num_graphs 

    return total_loss / len(loader.dataset)

def calc_index_of_agreement(y_pred, y_actual, eps=1e-8):
    y_actual_mean = torch.mean(y_actual)

    numerator = torch.sum((y_pred - y_actual) ** 2)
    denominator = torch.sum((torch.abs(y_pred - y_actual_mean) + torch.abs(y_actual - y_actual_mean)) ** 2)
    
    index_of_agreement = 1 - numerator / (denominator + eps)
    return index_of_agreement

def test(model, loader):
    model.eval()
    total_loss = 0.0

    all_y_pred = []
    all_y_actual = []

    with torch.no_grad():
        for batch in loader:
            out = model(batch.x, batch.edge_index)
            mask = ~batch.mask 
            y_actual = batch.y[mask]

            if mask.sum() == 0:
                continue

            y_pred = out[mask]

            loss = criterion(y_pred, y_actual)
            total_loss += loss.item() * batch.num_graphs

            all_y_pred.append(y_pred)
            all_y_actual.append(y_actual)

    average_loss = total_loss / len(loader.dataset)

    all_y_pred = torch.cat(all_y_pred, dim=0)
    all_y_actual = torch.cat(all_y_actual, dim=0)

    index_of_agreement = calc_index_of_agreement(all_y_pred, all_y_actual)

    return average_loss, index_of_agreement

In [3]:
histories = [96, 144, 192]
horizons = [24, 48, 96]
NB_EPOCHS = 10

results = {}

for context_window in histories:
    results[context_window] = {}

    for forecast_window in horizons:

        print(f"Context window: {context_window}, Forecast window: {forecast_window}")

        total_window = context_window + forecast_window
        df_node_features, df_node_features_strata = preprocess_flowdata(hp.FLOWDATA_PATH, total_window)
        edge_indices, edge_weights = preprocess_graph(hp.SUBSETGRAPH_PATH)

        train_dataset = WaterFlowDataSet(df_node_features[0], df_node_features_strata[0], edge_indices, edge_weights, forecast_window)
        train_loader = DataLoader(train_dataset, batch_size=hp.BATCH_SIZE, shuffle=True)
        test_dataset = WaterFlowDataSet(df_node_features[2], df_node_features_strata[2], edge_indices, edge_weights, forecast_window)
        test_loader = DataLoader(test_dataset, batch_size=16, shuffle=True)

        model = Model(input_channels = total_window, 
                      hidden_channels = total_window, 
                      output_channels = total_window, 
                      num_heads = 4, 
                      embed_dim = 32, 
                      context_window=context_window, 
                      forecast_window=forecast_window
                      )
        
        optimiser = torch.optim.Adam(model.parameters(), lr=0.005)

        for epoch in range(NB_EPOCHS):
            epoch_train_loss = train(model, optimiser, train_loader)
        
        average_loss, index_of_agreement = test(model, test_loader)
        results[context_window][forecast_window] = average_loss

        print(f"Loss: {average_loss}")
        

Context window: 96, Forecast window: 24
Loss: 0.5990758577982584
Context window: 96, Forecast window: 48
Loss: 0.6641614635785421
Context window: 96, Forecast window: 96
Loss: 0.5779445171356201
Context window: 144, Forecast window: 24
Loss: 0.5956897366614569
Context window: 144, Forecast window: 48
Loss: 0.6775455474853516
Context window: 144, Forecast window: 96
Loss: 0.7218520045280457
Context window: 192, Forecast window: 24
Loss: 0.8559219241142273
Context window: 192, Forecast window: 48
Loss: 0.603398859500885
Context window: 192, Forecast window: 96


ValueError: The least populated classes in y have only 1 member, which is too few. The minimum number of groups for any class cannot be less than 2. Classes with too few members are: ['weekday_spring']