## Wrapper Method: `networkx_to_pyg`

This wrapper converts a **NetworkX graph** into a **PyTorch Geometric `Data` object** while preserving all **edge attributes**. The conversion is needed before graph data can be used to train Graph Neural Networks (GNNs) in PyTorch Geometric.

### Purpose and Considerations

NetworkX is a widely used Python library for building and manipulating graphs. To use such a graph in a PyTorch Geometric model, it must first be converted into the **`Data` object** that PyTorch Geometric expects. The `Data` object contains:
- Node features (`x`): not implemented yet; to be added after feature engineering
- Edge indices (`edge_index`)
- Edge attributes (`edge_attr`)
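
To make the layout of `edge_index` and `edge_attr` concrete, here is a minimal sketch in plain Python lists, with no PyTorch dependency. The helper name `build_edge_arrays` and the `lanes` attribute are hypothetical and only serve the illustration; `speed_rel` is the float attribute used later in this notebook. Node ids are assigned in first-seen order, which is a simplification of what a real converter does.

```python
def build_edge_arrays(edges, keys):
    """Lay out (u, v, attrs) triples as COO-style edge_index and edge_attr lists.

    Hypothetical helper for illustration only; not part of PyTorch Geometric.
    """
    # Map arbitrary node labels to consecutive integer ids in first-seen order
    node_id = {}
    for u, v, _ in edges:
        node_id.setdefault(u, len(node_id))
        node_id.setdefault(v, len(node_id))

    # edge_index has shape [2, num_edges]: one row of sources, one row of targets
    sources = [node_id[u] for u, _, _ in edges]
    targets = [node_id[v] for _, v, _ in edges]
    edge_index = [sources, targets]

    # edge_attr has shape [num_edges, num_keys]: row i describes edge i
    edge_attr = [[attrs[key] for key in keys] for _, _, attrs in edges]
    return edge_index, edge_attr


edges = [
    ("a", "b", {"speed_rel": 0.5, "lanes": 2}),  # 'lanes' is a made-up attribute
    ("b", "c", {"speed_rel": 0.8, "lanes": 1}),
]
edge_index, edge_attr = build_edge_arrays(edges, ["speed_rel", "lanes"])
```

The point to take away is that column `i` of `edge_index` and row `i` of `edge_attr` must describe the same edge, so both have to be built from the same edge ordering.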

In [None]:
import torch
from torch_geometric.utils import from_networkx

def networkx_to_pyg(G):
    """
    Converts a NetworkX graph to a PyTorch Geometric Data object.
    All edge attributes are preserved and added as edge features.

    Args:
        G (networkx.Graph): The input NetworkX graph. It must contain at least one edge,
            and every edge must carry the same attribute keys.

    Returns:
        torch_geometric.data.Data: A PyTorch Geometric Data object containing the graph's
        edge indices and edge attributes (node features `x` are not set here).
    """
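    # from_networkx stores each undirected edge in both directions, so convert to a
    # directed graph first to keep the rows of edge_attr aligned with edge_index
    if not G.is_directed():
        G = G.to_directed()

    # Convert the NetworkX graph to a PyTorch Geometric Data object
    data = from_networkx(G)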
    
    # Get the edge attributes (keys) from the first edge
    edge_keys = list(next(iter(G.edges(data=True)))[2].keys())  # Get all edge attribute keys
    
    # Prepare a list to store the edge features (attributes)
    edge_features = []

    # Loop through each edge attribute key
    for key in edge_keys:
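        # 'speed_rel' is stored as float32, every other attribute as an integer (long);
        # when both kinds are present, recent PyTorch versions promote the
        # concatenated result to float32
        dtype = torch.float32 if key == 'speed_rel' else torch.long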
        
        # Collect this attribute from every edge, in the same order as G.edges()
        values = [attrs[key] for _, _, attrs in G.edges(data=True)]
        edge_features.append(torch.tensor(values, dtype=dtype).unsqueeze(1))
    
    # Concatenate all edge features along the second dimension (column-wise)
    data.edge_attr = torch.cat(edge_features, dim=1)
    
    return data
