## Wrapper Method: `transform_networkx_into_pyg`

This wrapper converts a **NetworkX graph** into a **PyTorch Geometric `Data` object** and carries over its **node and edge attributes**. The conversion is necessary because **PyTorch Geometric cannot load `.graphml` files directly**, so the graph is read with NetworkX first. Once converted, the graph can be processed further and used to train Graph Neural Networks (GNNs) in PyTorch Geometric.
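
Because PyG has no GraphML reader, the file is loaded with NetworkX. The following is a minimal sketch that uses NetworkX's `parse_graphml` (the in-memory counterpart of `read_graphml`) on a small, made-up GraphML document. The attribute names `lon`, `lat` and `speed_rel` match the ones used in this notebook, but their declared types here are assumptions. The sketch shows that each attribute comes back with the type declared in its `<key attr.type=...>` element, so values declared as `string` must be converted to numbers before they can go into a tensor:

```python
import networkx as nx

# Made-up GraphML document: two nodes with coordinates and one edge.
# "lon"/"lat" are declared as doubles, "speed_rel" (hypothetically) as a string.
GRAPHML = """<graphml xmlns="http://graphml.graphdrawing.org/xmlns">
  <key id="lon" for="node" attr.name="lon" attr.type="double"/>
  <key id="lat" for="node" attr.name="lat" attr.type="double"/>
  <key id="speed_rel" for="edge" attr.name="speed_rel" attr.type="string"/>
  <graph id="G" edgedefault="directed">
    <node id="a"><data key="lon">13.40</data><data key="lat">52.52</data></node>
    <node id="b"><data key="lon">13.41</data><data key="lat">52.53</data></node>
    <edge source="a" target="b"><data key="speed_rel">0.8</data></edge>
  </graph>
</graphml>"""

# For a file on disk, nx.read_graphml("path/to/graph.graphml") does the same.
G = nx.parse_graphml(GRAPHML)

lon = G.nodes["a"]["lon"]               # declared as double -> Python float
speed = G.edges["a", "b"]["speed_rel"]  # declared as string -> Python str
print(type(lon).__name__, type(speed).__name__)  # float str
print(float(speed))                              # 0.8
```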


### Purpose and Considerations

NetworkX is a widely used library for creating and manipulating graphs in Python. However, to use these graphs within a PyTorch Geometric pipeline, they must first be converted into the framework’s specific `Data` format. This object typically includes:

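- `x`: Node features, i.e. the geographic coordinates (longitude and latitude) of each node, plus any additional node features added during feature engineering
- `edge_index`: Graph connectivity in COO format, a `[2, num_edges]` tensor of source and target node indices
- `edge_attr`: Edge features, one row per column of `edge_index`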

The wrapper stores the coordinates in `x` and the edge attributes in `edge_attr`. Any other attributes picked up by `from_networkx` remain available as separate fields on the `Data` object.

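```python
import torch
from torch_geometric.utils import from_networkx


def transform_networkx_into_pyg(G):
    """
    Converts a NetworkX graph to a PyTorch Geometric Data object.
    The node coordinates ("lon", "lat") become the node feature matrix `data.x`,
    and all edge attributes are stacked column-wise into `data.edge_attr`.

    Args:
        G (networkx.Graph): The input NetworkX graph (Graph, DiGraph, MultiGraph or
            MultiDiGraph). Every edge should carry the same set of numeric attributes.

    Returns:
        torch_geometric.data.Data: A PyTorch Geometric Data object containing the graph's
        node features, edge indices, and edge attributes.
    """

    # Convert the NetworkX graph to a PyTorch Geometric Data object.
    # from_networkx() turns an undirected graph into a directed one (every edge appears
    # once per direction), so edge features must be read from the same directed view,
    # in the same order, to line up with the columns of data.edge_index.
    data = from_networkx(G)
    G_dir = G if G.is_directed() else G.to_directed()

    # --- 1. Node features (lon, lat) -> data.x, in G.nodes() order ---
    node_features = []
    for _, attrs in G.nodes(data=True):
        lon = float(attrs.get("lon", 0.0))  # missing coordinates silently default to 0.0
        lat = float(attrs.get("lat", 0.0))
        node_features.append([lon, lat])

    data.x = torch.tensor(node_features, dtype=torch.float)

    # --- 2. Edge features -> data.edge_attr, one column per edge attribute ---
    first_edge = next(iter(G_dir.edges(data=True)), None)
    if first_edge is None:
        return data  # the graph has no edges, so there is nothing to attach
    edge_keys = list(first_edge[2].keys())

    # Iterating with data=True also works for multigraphs, where G.edges[u, v]
    # would need an edge key; float() also accepts numbers stored as GraphML strings.
    edge_rows = [
        [float(attrs[key]) for key in edge_keys]
        for _, _, attrs in G_dir.edges(data=True)
    ]
    # All columns are stored as float32 so they can be fed to layers that apply
    # linear maps to edge features.
    data.edge_attr = torch.tensor(edge_rows, dtype=torch.float)

    return data
```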
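
The alignment between `edge_index` and `edge_attr` depends on iterating over edges in the same order that `from_networkx` uses. The following NetworkX-only sketch, built on hypothetical toy graphs, illustrates two pitfalls an edge loop has to handle. First, undirected graphs gain one directed edge per direction. Second, multigraphs (graphs with parallel edges) cannot be indexed with a plain two-element `G.edges[u, v]`:

```python
import networkx as nx

# Undirected toy graph: two edges with one attribute each.
G = nx.Graph()
G.add_edge(0, 1, speed_rel=0.8)
G.add_edge(1, 2, speed_rel=0.5)

# PyG's from_networkx() applies this conversion to undirected input before
# building edge_index, so edge features must come from the directed view.
G_dir = G.to_directed()
print(G.number_of_edges(), G_dir.number_of_edges())            # 2 4
print([d["speed_rel"] for _, _, d in G_dir.edges(data=True)])  # [0.8, 0.8, 0.5, 0.5]

# Multigraph with two parallel edges between the same pair of nodes.
M = nx.MultiDiGraph()
M.add_edge("a", "b", speed_rel=0.8)
M.add_edge("a", "b", speed_rel=0.3)

# Iterating with data=True yields one entry per parallel edge.
print([d["speed_rel"] for _, _, d in M.edges(data=True)])  # [0.8, 0.3]

# A two-element lookup is rejected: multigraph edges are addressed by (u, v, key).
try:
    M.edges["a", "b"]
    lookup_ok = True
except (ValueError, nx.NetworkXException):
    lookup_ok = False
print(lookup_ok)  # False
```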
