## Wrapper Method: `transform_networkx_into_pyg`

This wrapper converts a **NetworkX graph** into a **PyTorch Geometric `Data` object** and stacks all of its **edge attributes** into a single edge-feature matrix. The step is necessary because **PyTorch Geometric does not natively support `.graphml` files** and cannot load them directly. The graph is therefore loaded with NetworkX first (for example via `networkx.read_graphml`) and then converted, so that it can be processed and used to train Graph Neural Networks (GNNs) in PyTorch Geometric.

### Purpose and Considerations

NetworkX is a widely used library for creating and manipulating graphs in Python. However, to use these graphs within a PyTorch Geometric pipeline, they must first be converted into the framework’s specific `Data` format. This object typically includes:

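- `x`: Node features  
  *(Not yet implemented; to be added after feature engineering.)*
- `edge_index`: Graph connectivity in COO format, a tensor of shape `[2, num_edges]` (row 0 holds source nodes, row 1 target nodes)
- `edge_attr`: Edge attributes, a tensor of shape `[num_edges, num_edge_features]`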
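
To make the `edge_index`/`edge_attr` layout concrete, here is a minimal pure-Python sketch (no PyTorch involved) of how an attributed edge list maps onto the two structures. The helper `to_coo` and the `lanes` attribute are hypothetical and used only for illustration; `speed_rel` is the attribute the wrapper below treats as float-valued. Like `from_networkx`, the sketch maps arbitrary node ids (GraphML ids are strings) to consecutive integers.

```python
def to_coo(nodes, edges, attr_keys):
    """Return (edge_index, edge_attr) as nested lists.

    edge_index has shape [2, E] (row 0: sources, row 1: targets);
    edge_attr has shape [E, F] with one column per key in attr_keys.
    """
    # Map arbitrary node ids (e.g. GraphML strings) to 0..N-1
    index = {node: i for i, node in enumerate(nodes)}
    sources, targets, edge_attr = [], [], []
    for src, dst, attrs in edges:
        sources.append(index[src])
        targets.append(index[dst])
        # Row i of edge_attr describes the edge in column i of edge_index
        edge_attr.append([float(attrs[key]) for key in attr_keys])
    return [sources, targets], edge_attr


# Hypothetical toy graph: "lanes" is an illustrative attribute
edge_index, edge_attr = to_coo(
    nodes=["n0", "n1", "n2"],
    edges=[("n0", "n1", {"speed_rel": 0.8, "lanes": 2}),
           ("n1", "n2", {"speed_rel": 1.1, "lanes": 1})],
    attr_keys=["speed_rel", "lanes"],
)
print(edge_index)  # [[0, 1], [1, 2]]
print(edge_attr)   # [[0.8, 2.0], [1.1, 1.0]]
```

The column order of `edge_attr` is fixed by `attr_keys`, so every edge contributes its attributes in the same order.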

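The wrapper preserves every edge attribute by stacking them column-wise into `edge_attr`, one column per attribute, so that row *i* of `edge_attr` describes the edge stored in column *i* of `edge_index`.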
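
Keeping `edge_attr` aligned with `edge_index` needs extra care for undirected graphs: `from_networkx` converts them with `to_directed()`, so each undirected edge appears twice in `edge_index`. The following plain-Python sketch uses a hypothetical helper, `expand_undirected`, to emulate that expansion and shows that building both structures from the same directed list keeps them the same length. The edge order NetworkX produces may differ from this emulation, which is why a real implementation should iterate over the directed graph itself instead of reconstructing the order by hand.

```python
def expand_undirected(edges):
    """Emit each undirected edge in both directions, sharing its attributes."""
    directed = []
    for u, v, attrs in edges:
        directed.append((u, v, attrs))
        if u != v:  # a self-loop is a single directed edge
            directed.append((v, u, attrs))
    return directed


undirected = [(0, 1, {"speed_rel": 0.8}), (1, 2, {"speed_rel": 1.1})]
directed = expand_undirected(undirected)

# Build connectivity and features from the SAME directed list
edge_index = [[u for u, _, _ in directed], [v for _, v, _ in directed]]
edge_attr = [[attrs["speed_rel"]] for _, _, attrs in directed]

print(len(undirected), len(edge_index[0]), len(edge_attr))  # 2 4 4
```

Collecting attributes from the original undirected edge list instead would yield only 2 rows for 4 columns of `edge_index`.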


```python
import torch
from torch_geometric.utils import from_networkx

def transform_networkx_into_pyg(G):
    """
    Converts a NetworkX graph to a PyTorch Geometric Data object.
    All edge attributes are stacked column-wise into ``data.edge_attr``.

    Args:
        G (networkx.Graph): The input NetworkX graph. Every edge must carry
            the same set of numeric edge attributes.

    Returns:
        torch_geometric.data.Data: A Data object containing the edge indices
        (``edge_index``) and edge attributes (``edge_attr``). Node features
        (``x``) are not set yet.
    """
    # Convert the NetworkX graph to a PyTorch Geometric Data object.
    # For an undirected graph, from_networkx calls G.to_directed() internally,
    # so each undirected edge appears twice in data.edge_index.
    data = from_networkx(G)

    # Nothing to stack if the graph has no edges
    if G.number_of_edges() == 0:
        return data

    # Get all edge attribute keys from the first edge
    edge_keys = list(next(iter(G.edges(data=True)))[2].keys())

    # Iterate over the same directed edge list that from_networkx used, so that
    # row i of edge_attr belongs to column i of data.edge_index
    D = G if G.is_directed() else G.to_directed()

    edge_features = []
    for key in edge_keys:
        # Store every attribute as float32 so all columns share one dtype;
        # GNN layers expect floating-point edge features ('speed_rel' is float)
        values = [attrs[key] for _, _, attrs in D.edges(data=True)]
        edge_features.append(torch.tensor(values, dtype=torch.float32).unsqueeze(1))

    # Concatenate column-wise: shape [num_edges, num_edge_attributes]
    data.edge_attr = torch.cat(edge_features, dim=1)

    return data
```
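
The wrapper reads attribute names from the first edge only, and `from_networkx` rejects graphs whose edges carry different attribute sets. A small pre-flight check can report exactly which edges deviate before conversion. The helper `find_inconsistent_edges` below is a hypothetical sketch, not part of NetworkX or PyTorch Geometric; it accepts any iterable of `(u, v, attrs)` triples, which is what `G.edges(data=True)` yields for a NetworkX `Graph` or `DiGraph`. The `lanes` and `toll` attributes are illustrative only.

```python
def find_inconsistent_edges(edges):
    """Return (u, v, missing, extra) for each edge whose attribute keys differ
    from those of the first edge; an empty list means the keys are consistent."""
    problems = []
    reference = None
    for u, v, attrs in edges:
        keys = set(attrs)
        if reference is None:
            reference = keys  # the first edge defines the expected key set
            continue
        if keys != reference:
            problems.append((u, v, sorted(reference - keys), sorted(keys - reference)))
    return problems


edges = [
    ("a", "b", {"speed_rel": 0.8, "lanes": 2}),
    ("b", "c", {"speed_rel": 1.1}),                         # missing "lanes"
    ("c", "a", {"speed_rel": 0.9, "lanes": 1, "toll": 1}),  # extra "toll"
]
print(find_inconsistent_edges(edges))
# [('b', 'c', ['lanes'], []), ('c', 'a', [], ['toll'])]
```

With a NetworkX graph the check would be called as `find_inconsistent_edges(G.edges(data=True))` before passing `G` to the wrapper.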
