## Create Test Set

### Importing Dependencies

We import the required libraries and the project's helper functions. The helpers live in the `utils` package two directories above the notebook, so that directory is first added to `sys.path`.

In [2]:
import os
import sys

import networkx as nx
import torch
import import_ipynb  # lets `import` load Jupyter notebooks (.ipynb) as modules

# Add the directory two levels above the current working directory to the
# import path so that the project's `utils` package can be found.
src_path = os.path.abspath(os.path.join(os.getcwd(), "..", ".."))
if src_path not in sys.path:
    sys.path.append(src_path)

from torch_geometric.data import Batch
from utils.helper_functions.load_graphml_files import load_graphml_files
from utils.data_splits.normalize_feature_attributes import normalize_test_features
from utils.data_splits.print_datasplit_info import print_batch_shape



### Executing the Pipeline for Creating the Test Dataset

This cell defines a `main` function that runs the complete pipeline for generating the test dataset for Graph Neural Networks (GNNs). It calls the imported helper functions in sequence to:

1. **Load the graph data**: The GraphML files for the specified years are loaded into PyTorch Geometric `Data` objects.
2. **Batch the graphs**: The individual graphs are combined into a single `Batch` object with `Batch.from_data_list`.
3. **Normalize the features**: The features are normalized with the scalers fitted earlier on the training data, so no statistics are computed from the test set itself.
4. **Print batch statistics**: To give insight into the structure of the test set, the number of graphs, nodes, and edges is printed, together with the shapes of the feature and edge tensors.

The resulting batch is then saved with `torch.save` so it can be reused for model evaluation.
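Steps 2 and 3 can be illustrated without any framework. The sketch below is a conceptual model under stated assumptions, not PyTorch Geometric's implementation: graphs are plain dicts, and the hypothetical `batch_graphs` mimics how `Batch.from_data_list` concatenates node and edge features, shifts each graph's edge indices by the number of nodes that precede it, and records a `batch` vector mapping every node to its source graph. The hypothetical `normalize_with_train_stats` assumes standardization with a training-set mean and standard deviation; the scaler type actually used by `normalize_test_features` is not shown in this notebook, and all values are made up.

```python
# Conceptual, framework-free sketch of batching and test-time normalization.
# NOT the PyTorch Geometric implementation; graphs are plain dicts here.


def batch_graphs(graphs):
    """Merge graphs into one disjoint graph, in the style of Batch.from_data_list.

    Each graph is a dict with:
      "x":          list of node feature rows
      "edge_index": (sources, targets), two lists of local node ids
      "edge_attr":  list of edge feature rows
    """
    x, src, dst, edge_attr, batch = [], [], [], [], []
    offset = 0
    for graph_id, g in enumerate(graphs):
        num_nodes = len(g["x"])
        x.extend(g["x"])
        # Shift local node ids so they index into the concatenated node list.
        src.extend(s + offset for s in g["edge_index"][0])
        dst.extend(t + offset for t in g["edge_index"][1])
        edge_attr.extend(g["edge_attr"])
        # batch[i] records which input graph node i came from.
        batch.extend([graph_id] * num_nodes)
        offset += num_nodes
    return {"x": x, "edge_index": (src, dst), "edge_attr": edge_attr,
            "batch": batch, "num_graphs": len(graphs)}


def normalize_with_train_stats(x, mean, std):
    """Standardize each feature column with statistics from the TRAINING split.

    Assumption: mean/std come from a scaler fitted on training data only;
    they are never recomputed from the test features.
    """
    return [[(v - m) / s for v, m, s in zip(row, mean, std)] for row in x]


# Two tiny hypothetical graphs with 2 node features and 1 edge feature.
g1 = {"x": [[1.0, 10.0], [3.0, 30.0]], "edge_index": ([0], [1]), "edge_attr": [[0.5]]}
g2 = {"x": [[5.0, 50.0]], "edge_index": ([0], [0]), "edge_attr": [[0.7]]}

b = batch_graphs([g1, g2])
print(b["edge_index"])  # ([0, 2], [1, 2])
print(b["batch"])       # [0, 0, 1]

# Hypothetical training statistics for the two node features.
x_norm = normalize_with_train_stats(b["x"], mean=[2.0, 20.0], std=[2.0, 20.0])
print(x_norm)           # [[-0.5, -0.5], [0.5, 0.5], [1.5, 1.5]]
```

The edge of `g2` becomes `(2, 2)` because one graph with two nodes precedes it; without that offset it would point at a node of `g1`.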


In [3]:
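def main(years=None):
    """Build, normalize, and save the test split from the GraphML files of `years`."""
    # Avoid a mutable default argument; 2024 is the default test year.
    if years is None:
        years = [2024]

    # Output directory for the data splits (created if it does not exist yet).
    save_dir = os.path.join("..", "..", "..", "data", "data_splits")
    os.makedirs(save_dir, exist_ok=True)
    test_save_path = os.path.join(save_dir, "test_data.pt")

    # 1. Load one PyG Data object per GraphML file.
    test_data_list = load_graphml_files(years)
    # 2. Merge the graphs into a single Batch (edge indices are offset per graph).
    test_data_batch = Batch.from_data_list(test_data_list)
    # 3. Normalize features with the scalers fitted on the training data.
    test_data_batch = normalize_test_features(test_data_batch)

    # 4. Report the number of graphs, nodes, and edges and the tensor shapes.
    print("\nTest Data Statistics:")
    print_batch_shape(test_data_batch)

    torch.save(test_data_batch, test_save_path)
    print(f"\nTest data saved to: {test_save_path}")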

main()


Number of loaded graphs: 12

Test Data Statistics:
Number of graphs in batch: 12
Number of nodes: 163599
Node feature shape: torch.Size([163599, 2])
Number of edges: 390360
Edge index shape: torch.Size([2, 390360])
Edge attributes shape: torch.Size([390360, 4])
Node features shape: torch.Size([163599, 2])

Test data saved to: ..\..\..\data\data_splits\test_data.pt
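The printed shapes can be cross-checked against each other: in a consistent batch, the node feature matrix has one row per node, `edge_index` has shape `[2, num_edges]`, and the edge attribute matrix has one row per edge. The sketch below encodes these invariants in plain Python, using the numbers copied from the output above; the dict keys and the `check_batch_stats` function are hypothetical names, not part of the project's helpers.

```python
# Sanity checks on the statistics printed above (numbers copied from the output).
stats = {
    "num_graphs": 12,
    "num_nodes": 163599,
    "x_shape": (163599, 2),
    "num_edges": 390360,
    "edge_index_shape": (2, 390360),
    "edge_attr_shape": (390360, 4),
}


def check_batch_stats(s):
    """Return a list of violated invariants (empty if the batch looks consistent)."""
    problems = []
    if s["x_shape"][0] != s["num_nodes"]:
        problems.append("node feature rows != number of nodes")
    if s["edge_index_shape"] != (2, s["num_edges"]):
        problems.append("edge_index is not [2, num_edges]")
    if s["edge_attr_shape"][0] != s["num_edges"]:
        problems.append("edge_attr rows != number of edges")
    return problems


print(check_batch_stats(stats))  # []
print(f"avg nodes per graph: {stats['num_nodes'] / stats['num_graphs']:.2f}")  # 13633.25
print(f"avg edges per node:  {stats['num_edges'] / stats['num_nodes']:.2f}")   # 2.39
```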
