# Dependencies

In [15]:
import numpy as np
import torch_geometric as pyg
import networkx as nx

from typing import Tuple

# Overview 
These functions convert an OSMnx graph (which is a NetworkX graph) into a PyTorch Geometric `Data` object.

In [21]:
def clear_node_features(G: nx.Graph, ignore: Tuple[str, ...] = ('x', 'y')) -> None:
    """
    Strip the attribute dict of every node of G in place, keeping only the keys in ``ignore``.

    Despite its name, ``ignore`` lists the attributes to *keep* (they are ignored by
    the clearing). Kept keys end up in the order given by ``ignore``. A node that
    lacks one of these keys raises ``KeyError``. Nodes visited before it have
    already been stripped at that point.
    """
    for _, node_attrs in G.nodes(data=True):
        # Read the surviving values before clearing, so a missing key fails
        # before this node's dict is touched.
        survivors = [(key, node_attrs[key]) for key in ignore]
        node_attrs.clear()
        node_attrs.update(survivors)


def clear_edge_features(G: nx.Graph, ignore: Tuple[str, ...] = ('length', )) -> None:
    """
    Remove edge features from the graph G in place, except those given in ``ignore``.

    Despite its name, ``ignore`` lists the edge attributes to *keep*. Kept keys end
    up in the order given by ``ignore``. An edge that lacks one of these keys raises
    ``KeyError``. Edges visited before it have already been stripped at that point.
    """
    # The edge endpoints are not needed; only the attribute dict is modified.
    for _, _, attr_dict in G.edges(data=True):
        # Collect the kept values first so a missing key fails before clearing.
        kept = {f: attr_dict[f] for f in ignore}
        attr_dict.clear()
        attr_dict.update(kept)


def extract_pos(data: pyg.data.Data, x_idx: int, y_idx: int) -> None:
    """
    Split the coordinate columns out of the feature tensor ``data.x`` in place.

    Columns ``x_idx`` and ``y_idx`` of ``data.x`` are stored, in that order, as the
    new ``data.pos`` tensor. ``data.x`` is then replaced by the remaining columns,
    which may leave it with zero columns when nothing else is left.
    """
    coord_cols = [x_idx, y_idx]
    remaining_cols = np.delete(np.arange(data.x.shape[1]), coord_cols)
    data.pos = data.x[:, coord_cols]
    data.x = data.x[:, remaining_cols]


def osmnx_to_pyg(G: nx.Graph, node_features: Tuple[str, ...] = ('x', 'y'),
                 edge_features: Tuple[str, ...] = ('length', ), pos: bool = True,
                 inplace: bool = False) -> pyg.data.Data:
    """
    Convert a networkx (e.g. osmnx) graph G into a PyTorch Geometric Data object.

    Only the node attributes named in ``node_features`` and the edge attributes
    named in ``edge_features`` are kept. They are grouped by
    ``torch_geometric.utils.from_networkx`` into ``data.x`` and ``data.edge_attr``.

    Args:
        G: Source graph.
        node_features: Node attribute names to keep. Every node must have all of them.
        edge_features: Edge attribute names to keep. Every edge must have all of them.
        pos: If true and ``node_features`` contains both 'x' and 'y', move those two
            columns out of ``data.x`` into ``data.pos`` (in x, y order). If either
            key is absent, no extraction happens.
        inplace: If false (default), a copy of G is stripped so the caller's graph
            is left untouched. If true, G's own attribute dicts are stripped.

    Returns:
        The Data object produced by ``from_networkx``.

    Raises:
        KeyError: if a node or edge lacks one of the requested features.
    """
    if not inplace:
        G = G.copy()

    clear_node_features(G, ignore=node_features)
    clear_edge_features(G, ignore=edge_features)

    # from_networkx takes a list of attribute names or None. With no requested
    # features there is nothing to group, so pass None instead of an empty list.
    data = pyg.utils.from_networkx(
        G,
        group_node_attrs=list(node_features) or None,
        group_edge_attrs=list(edge_features) or None,
    )

    # Only extract coordinates when both were requested; otherwise
    # node_features.index('x') / .index('y') would raise ValueError.
    if pos and 'x' in node_features and 'y' in node_features:
        extract_pos(data, x_idx=node_features.index('x'), y_idx=node_features.index('y'))

    return data

# Example

In [1]:
import osmnx as ox 

In [23]:
# Download the OpenStreetMap graph for New Delhi (requires network access) and
# convert it with the default features: node 'x'/'y' and edge 'length'.
G = ox.graph_from_place("New Delhi")
data = osmnx_to_pyg(G)

(tensor([], size=(16802, 0)),
 tensor([[77.1644, 28.5384],
         [77.1646, 28.5390],
         [77.1647, 28.5394],
         ...,
         [77.1253, 28.5861],
         [77.1110, 28.5855],
         [77.1134, 28.5811]]))

In [27]:
# Show that the default node features ('x', 'y') were all moved into data.pos.
print(
    f"data.x={data.x!r} is empty because all features were extracted to \n"
    f"data.pos={data.pos!r}"
)

data.x=tensor([], size=(16802, 0)) is empty because all features were extracted to 
data.pos=tensor([[77.1644, 28.5384],
        [77.1646, 28.5390],
        [77.1647, 28.5394],
        ...,
        [77.1253, 28.5861],
        [77.1110, 28.5855],
        [77.1134, 28.5811]])
