In [1]:
import os
import networkx as nx
import torch
import sys
import import_ipynb 

# Put the directory two levels above the current working directory on
# sys.path so that the `utils.*` imports below can be resolved. The guard
# keeps the entry from being appended again when the cell is re-run.
src_path = os.path.abspath(os.path.join(os.getcwd(), "..", ".."))
if src_path not in sys.path:
    sys.path.append(src_path)

from torch_geometric.data import Batch
from utils.wrapper.transform_networkx_into_pyg import transform_networkx_into_pyg
from utils.helper_functions.normalize_feature_attributes import normalize_test_features



Working directory: c:\Users\yanni\OneDrive\Dokumente\Master Wifo\2\Web Mining\Projekt\WebMining_Team6_bikeNetworkAnalysis\src\utils\data_splits


In [2]:
def load_graphml_files(years=(2024,), base_dir="../../../data/graphml"):
    """
    Loads monthly directed graph files in GraphML format and converts them
    into PyTorch Geometric (PyG) Data objects.

    For every requested year, the files
    ``{base_dir}/{year}/bike_network_{year}_{i}.graphml`` with ``i`` in
    ``0..11`` are looked up. Missing files are reported with a warning and
    skipped instead of aborting the whole load.

    Parameters:
    -----------
    years : int or iterable of int, optional (default=(2024,))
        Year(s) for which graph files should be loaded.
        Assumes up to 12 monthly files per year, indexed 0-11.
    base_dir : str, optional (default="../../../data/graphml")
        Directory that contains one sub-directory per year. The default is
        relative to the current working directory.

    Returns:
    --------
    testdata_list : list
        The objects returned by ``transform_networkx_into_pyg`` for each file
        that was found, in (year, month index) order. Empty if no file exists.
    """
    # Accept a single year as a convenience; an immutable default avoids the
    # shared-mutable-default pitfall.
    if isinstance(years, int):
        years = (years,)

    testdata_list = []

    for year in years:
        for i in range(12):
            path = f"{base_dir}/{year}/bike_network_{year}_{i}.graphml"
            if not os.path.exists(path):
                print(f"[WARN] File not found: {path}")
                continue

            # Re-wrap whatever graph class read_graphml returns in a DiGraph
            # before handing it to the PyG converter.
            G_nx = nx.DiGraph(nx.read_graphml(path))
            testdata_list.append(transform_networkx_into_pyg(G_nx))

    print(f"Number of loaded graphs: {len(testdata_list)}")
    return testdata_list


In [3]:
def main(years=(2024,)):
    """
    Builds the test split and writes it to disk.

    Loads the monthly graphs for the given years, collates them into a single
    PyG ``Batch``, normalizes its features with ``normalize_test_features``
    and saves the result to ``../../../data/data_splits/test_data.pt``
    (relative to the current working directory).

    Parameters:
    -----------
    years : iterable of int, optional (default=(2024,))
        Years whose monthly graph files make up the test split.

    Raises:
    -------
    FileNotFoundError
        If not a single graph file could be loaded for the requested years.
    """
    test_data_list = load_graphml_files(years)
    if not test_data_list:
        # Batch.from_data_list cannot collate an empty list; fail early with
        # an actionable message and before creating any output directory.
        raise FileNotFoundError(
            f"No GraphML files could be loaded for years {list(years)}; "
            "nothing to save."
        )

    save_dir = "../../../data/data_splits"
    os.makedirs(save_dir, exist_ok=True)
    test_save_path = os.path.join(save_dir, "test_data.pt")

    test_data_batch = Batch.from_data_list(test_data_list)
    test_data_batch = normalize_test_features(test_data_batch)

    torch.save(test_data_batch, test_save_path)
    print(f"\nTest data saved to: {test_save_path}")

main()


Number of loaded graphs: 12

Test data saved to: ../../../data/data_splits\test_data.pt


In [4]:
def print_graph_info(batch):
    """
    Prints the number of graphs, nodes and edges of a batch together with the
    shapes of its node features, edge index and edge attributes.

    Parameters:
    -----------
    batch : object
        Batch-like object exposing ``num_graphs`` and, optionally, ``x``,
        ``edge_index`` and ``edge_attr`` (each either ``None`` or an object
        with a ``shape``). Missing or ``None`` attributes are reported or
        skipped instead of raising.

    Returns:
    --------
    None
    """
    # Number of graphs in the batch
    print(f"Number of graphs in batch: {batch.num_graphs}")

    # Nodes: check for None *before* touching .shape so that a batch without
    # node features is reported rather than crashing.
    node_features = getattr(batch, "x", None)
    if node_features is not None:
        print(f"Number of nodes: {node_features.shape[0]}")
        print(f"Node feature shape: {node_features.shape}")
    else:
        print("No node features available.")

    # Edges: the edge count is the number of columns of edge_index
    edge_index = getattr(batch, "edge_index", None)
    if edge_index is not None:
        print(f"Number of edges: {edge_index.shape[1]}")
        print(f"Edge index shape: {edge_index.shape}")
    else:
        print("No edge index available.")

    edge_attr = getattr(batch, "edge_attr", None)
    if edge_attr is not None:
        print(f"Edge attributes shape: {edge_attr.shape}")




In [6]:
import torch

# Load the batch file saved earlier.
# NOTE(review): weights_only=False makes torch.load use full unpickling,
# which can execute arbitrary code - only load files from a trusted source.
test_data_batch = torch.load('../../../data/data_splits/test_data.pt', weights_only=False)

# Print the graph statistics of the loaded batch
print_graph_info(test_data_batch)


Number of graphs in batch: 12
Number of nodes: 163599
Node feature shape: torch.Size([163599, 2])
Number of edges: 390360
Edge index shape: torch.Size([2, 390360])
Edge attributes shape: torch.Size([390360, 4])
Node features shape: torch.Size([163599, 2])
