## Create Training and Validation Sets

### Importing Dependencies

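The cell below imports the required libraries and the project's helper functions. Because the helpers live in the `utils` package, the directory two levels above the current working directory is appended to `sys.path` before they are imported.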

In [1]:
import os
import networkx as nx
import sys
import torch
import import_ipynb 

src_path = os.path.abspath(os.path.join(os.getcwd(), "..", ".."))
if src_path not in sys.path:
    sys.path.append(src_path)

from torch_geometric.data import Batch
from utils.data_splits.normalize_feature_attributes import normalize_feature
from utils.helper_functions.load_graphml_files import load_graphml_files
from utils.data_splits.print_datasplit_info import print_batch_shape


### Train-Validation Split

For predicting edge attributes (e.g., `tracks`), an 80/20 train/validation split is applied to the **existing edges within each graph**. This means that the edges are randomly split into training and validation sets, while the nodes remain unchanged during the process.

In our application, the **nodes represent physically existing bike stations**, which typically do not change or only change very infrequently. The aim of the analysis is to model the **connections between stations**, i.e., to understand and predict how many bicycles move along certain routes (in other words: edges with weights).

A **node-level split** (i.e., an 80/20 split of the nodes themselves) would mean that some stations would be completely unseen during training. This would not be meaningful because:

- The **stations themselves are not the prediction target**;
- It is the **relationships or transitions between the stations (edges)** that should be modeled;
- In deployment, **all stations are known** (they are physically installed in the system).
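The difference between the two split strategies can be illustrated on a toy graph. The following is a minimal sketch with made-up data (the four stations and five edges are hypothetical, not taken from the bike-sharing graphs): holding out a node removes the station from training entirely, whereas holding out an edge keeps every station visible.

```python
import torch

# Toy graph (hypothetical data): 4 stations, 5 directed edges.
edge_index = torch.tensor([[0, 0, 1, 2, 3],
                           [1, 2, 2, 3, 0]])

# Node-level split: station 3 is held out, so every edge touching it
# disappears from the training graph.
held_out_nodes = torch.tensor([3])
touches_held_out = torch.isin(edge_index, held_out_nodes).any(dim=0)
node_split_train = edge_index[:, ~touches_held_out]

# Edge-level split: only one edge (column 3) is held out.
edge_mask = torch.ones(edge_index.size(1), dtype=torch.bool)
edge_mask[3] = False
edge_split_train = edge_index[:, edge_mask]

stations_seen_node_split = torch.unique(node_split_train).tolist()
stations_seen_edge_split = torch.unique(edge_split_train).tolist()

print("Node-level split, stations seen in training:", stations_seen_node_split)
print("Edge-level split, stations seen in training:", stations_seen_edge_split)
```

In this toy case the node-level split never shows station 3 during training, while the edge-level split still covers all four stations.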

Initially, we considered PyTorch Geometric's `RandomLinkSplit` transform, but it is designed for classic link prediction, i.e., binary classification: it adds negative examples (non-existing edges) alongside the positive examples (existing edges). Because our task is edge attribute regression, this transform is unsuitable, so we implemented the split manually using a random permutation of each graph's edges. The `edge_attr_key_index` parameter gives the column index within `edge_attr` of the target attribute (e.g., `tracks`); that column is removed from the input edge features and stored as the regression target `y`.
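The permutation-based split described above can be sketched on a single toy graph using plain `torch` tensors. The data, the variable names, and the choice of column 1 as the target are assumptions for illustration only. The sketch also uses a local `torch.Generator` for reproducibility, which is one possible alternative to seeding the global random state.

```python
import torch

# Toy graph (hypothetical data): 5 directed edges, 3 attribute columns per edge.
edge_index = torch.tensor([[0, 0, 1, 2, 3],
                           [1, 2, 2, 3, 0]])
edge_attr = torch.arange(15, dtype=torch.float).reshape(5, 3)
target_col = 1  # assumed position of the regression target

# A local generator keeps the split reproducible without touching global RNG state.
gen = torch.Generator().manual_seed(42)
perm = torch.randperm(edge_index.size(1), generator=gen)

num_val = int(0.2 * edge_index.size(1))
val_idx, train_idx = perm[:num_val], perm[num_val:]

# Every column except the target stays an input feature.
keep = [c for c in range(edge_attr.size(1)) if c != target_col]
train_x, train_y = edge_attr[train_idx][:, keep], edge_attr[train_idx, target_col]
val_x, val_y = edge_attr[val_idx][:, keep], edge_attr[val_idx, target_col]

print("train:", train_x.shape, train_y.shape)
print("val:  ", val_x.shape, val_y.shape)
```

Each edge ends up in exactly one of the two subsets, and the target column no longer appears among the input features.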


In [2]:
def split_train_val(data_list, val_ratio=0.2, seed=42, edge_attr_key_index=4):
    """
    Splits a list of PyTorch Geometric Data objects into training and validation sets 
    for edge regression tasks.

    Parameters:
    -----------
    data_list : list of torch_geometric.data.Data
        List of graphs to be split into training and validation sets.
    
    val_ratio : float, optional (default=0.2)
        Proportion of edges to be used for validation in each graph.
    
    seed : int, optional (default=42)
        Random seed for reproducibility.
    
    edge_attr_key_index : int, optional (default=4)
        The index of the edge attribute that should be predicted (e.g., 'tracks'). 
        This attribute will be used as the target (`y`) for the regression task.
    
    Returns:
    --------
    train_data, val_data : torch_geometric.data.Batch
        Batched training and validation data containing the graphs' edge indices, 
        edge attributes, and the target edge attribute (`y`) for regression.
    """

    torch.manual_seed(seed)

    train_list, val_list = [], []
    total_train_edges = 0
    total_val_edges = 0

    for i, data in enumerate(data_list):
        edge_index = data.edge_index
        edge_attr = data.edge_attr

        num_edges = edge_index.size(1)
        num_val = int(val_ratio * num_edges)
        perm = torch.randperm(num_edges)

        val_idx = perm[:num_val]
        train_idx = perm[num_val:]

        # Training Data
        train_data = data.clone()
        train_data.edge_index = edge_index[:, train_idx]
        train_data.edge_attr = torch.cat([edge_attr[train_idx][:, :edge_attr_key_index], edge_attr[train_idx][:, edge_attr_key_index+1:]], dim=1)
        train_data.y = edge_attr[train_idx][:, edge_attr_key_index]  

        # Validation Data
        val_data = data.clone()
        val_data.edge_index = edge_index[:, val_idx]
        val_data.edge_attr = torch.cat([edge_attr[val_idx][:, :edge_attr_key_index], edge_attr[val_idx][:, edge_attr_key_index+1:]], dim=1)
        val_data.y = edge_attr[val_idx][:, edge_attr_key_index]

        train_list.append(train_data)
        val_list.append(val_data)

        total_train_edges += train_data.edge_index.size(1)
        total_val_edges += val_data.edge_index.size(1)

        print(f"Graph {i}: Train edges = {train_data.edge_index.size(1)}, Val edges = {val_data.edge_index.size(1)}")

    # Batch the split data
    train_data = Batch.from_data_list(train_list)
    val_data = Batch.from_data_list(val_list)

    print(f"\nTotal train edges (batched): {total_train_edges}")
    print(f"Total val edges   (batched): {total_val_edges}")

    return train_data, val_data


### Executing the Pipeline for Creating Training and Validation Data Sets

This cell defines a `main` function that orchestrates the pipeline for generating the training and validation splits for Graph Neural Networks (GNNs). It calls the functions defined and imported above in sequence:

1. **Load the graph data**: The GraphML files for the specified years are loaded into PyTorch Geometric `Data` objects.
2. **Split the edges**: The edges of each graph are randomly split into training and validation subsets, using the specified validation ratio (default 20%).
3. **Normalize the features**: The features of the training and validation sets are normalized.
4. **Print batch statistics**: The number of graphs, nodes, and edges in the training and validation sets is printed to give insight into their structure.

The training and validation data are then saved as `train_data.pt` and `val_data.pt` for later use in model training.
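The implementation of `normalize_feature` is not shown in this notebook, so the following is only a sketch of one common approach, not the helper's actual behavior: z-score normalization, with the statistics computed on the training features and then reused for the validation features. The function names `fit_zscore` and `apply_zscore` and the toy tensors are hypothetical.

```python
import torch

def fit_zscore(train_features, eps=1e-8):
    """Return per-column mean and std computed on the training features only."""
    mean = train_features.mean(dim=0, keepdim=True)
    std = train_features.std(dim=0, keepdim=True)
    # Guard against division by zero for constant columns.
    return mean, torch.clamp(std, min=eps)

def apply_zscore(features, mean, std):
    """Normalize features with previously fitted statistics."""
    return (features - mean) / std

# Toy edge features (hypothetical data): 3 training edges, 1 validation edge.
train_edge_attr = torch.tensor([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]])
val_edge_attr = torch.tensor([[2.0, 40.0]])

mean, std = fit_zscore(train_edge_attr)
train_norm = apply_zscore(train_edge_attr, mean, std)
val_norm = apply_zscore(val_edge_attr, mean, std)

print(train_norm)
print(val_norm)
```

Fitting the statistics on the training subset alone avoids leaking information from the validation edges into the inputs used for training.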


In [3]:
def main(years=None, val_ratio=0.2):
    """
    Main pipeline for loading graph data, preprocessing it, and splitting into train/val sets.

    Parameters:
    -----------
    years : list of int, optional (default=[2021, 2022, 2023])
        The years for which GraphML files will be loaded.

    val_ratio : float, optional (default=0.2)
        Proportion of edges to be used for validation during the train/validation split.

    Returns:
    --------
    None
    """

    # Avoid a mutable default argument; fall back to the default years here.
    if years is None:
        years = [2021, 2022, 2023]

    save_dir = os.path.join("..", "..", "..", "data", "data_splits")
    os.makedirs(save_dir, exist_ok=True)
    train_save_path = os.path.join(save_dir, "train_data.pt")
    val_save_path = os.path.join(save_dir, "val_data.pt")

    # Load GraphML files for the specified years and convert to PyTorch Geometric Data objects
    data_list = load_graphml_files(years)
    
    # Split data into train and validation sets
    train_data, val_data = split_train_val(data_list, val_ratio=val_ratio)
    
    # Normalize features in the training and validation data
    train_data, val_data = normalize_feature(train_data, val_data)
    
    print("\nTrain Data Statistics:")
    print_batch_shape(train_data)
    print("\nValidation Data Statistics:")
    print_batch_shape(val_data)

    # Save the train and validation data
    torch.save(train_data, train_save_path)
    torch.save(val_data, val_save_path)

    print(f"\nTrain data saved to: {train_save_path}")
    print(f"Val data saved to: {val_save_path}")

main()


Number of loaded graphs: 36
Graph 0: Train edges = 12952, Val edges = 3238
Graph 1: Train edges = 15959, Val edges = 3989
Graph 2: Train edges = 19904, Val edges = 4976
Graph 3: Train edges = 20536, Val edges = 5134
Graph 4: Train edges = 31279, Val edges = 7819
Graph 5: Train edges = 38328, Val edges = 9582
Graph 6: Train edges = 32741, Val edges = 8185
Graph 7: Train edges = 30916, Val edges = 7728
Graph 8: Train edges = 30703, Val edges = 7675
Graph 9: Train edges = 21666, Val edges = 5416
Graph 10: Train edges = 17965, Val edges = 4491
Graph 11: Train edges = 16285, Val edges = 4071
Graph 12: Train edges = 17178, Val edges = 4294
Graph 13: Train edges = 17021, Val edges = 4255
Graph 14: Train edges = 23511, Val edges = 5877
Graph 15: Train edges = 25231, Val edges = 6307
Graph 16: Train edges = 35055, Val edges = 8763
Graph 17: Train edges = 41335, Val edges = 10333
Graph 18: Train edges = 35775, Val edges = 8943
Graph 19: Train edges = 33536, Val edges = 8384
Graph 20: Train edges