In [6]:
import torch
from torch_geometric.nn import TransformerConv
from torch_geometric.data import Data
import torch.nn.functional as F

import osmium
from osmium import osm

  from .autonotebook import tqdm as notebook_tqdm


In [15]:
class GTN(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers,
                 dropout, beta=True, heads=1):
        super(GTN, self).__init__()

        # The list of transormer conv layers for the each layer block.
        self.num_layers = num_layers
        conv_layers = [TransformerConv(input_dim, hidden_dim//heads, heads=heads, beta=beta)]
        conv_layers += [TransformerConv(hidden_dim, hidden_dim//heads, heads=heads, beta=beta) for _ in range(num_layers - 2)]
        # In the last layer, we will employ averaging for multi-head output by
        # setting concat to True.
        conv_layers.append(TransformerConv(hidden_dim, output_dim, heads=heads, beta=beta, concat=True))
        self.convs = torch.nn.ModuleList(conv_layers)

        # The list of layerNorm for each layer block.
        norm_layers = [torch.nn.LayerNorm(hidden_dim) for _ in range(num_layers - 1)]
        self.norms = torch.nn.ModuleList(norm_layers)

        # Probability of an element getting zeroed.
        self.dropout = dropout

    def reset_parameters(self):
        """
        Resets the parameters of the convolutional and normalization layers,
        ensuring they are re-initialized when needed.
        """
        for conv in self.convs:
            conv.reset_parameters()
        for norm in self.norms:
            norm.reset_parameters()

    def forward(self, x, edge_index):
        """
        The input features are passed sequentially through the transformer
        convolutional layers. After each convolutional layer (except the last),
        the following operations are applied:
        - Layer normalization (`LayerNorm`).
        - ReLU activation function.
        - Dropout for regularization.
        The final layer is processed without layer normalization and ReLU
        to average the multi-head results for the expected output.

        Params:
        - x: node features x
        - edge_index: edge indices.

        """
        for i in range(self.num_layers - 1):
            # Construct the network as shown in the model architecture.
            x = self.convs[i](x, edge_index)
            x = self.norms[i](x)
            x = F.relu(x)
            # By setting training to self.training, we will only apply dropout
            # during model training.
            x = F.dropout(x, p = self.dropout, training = self.training)

        # Last layer, average multi-head output.
        print(x.shape)
        x = self.convs[-1](x, edge_index)

        return x

In [10]:
def construct_graph(osmPath: str):
    class MapCreationHandler(osmium.SimpleHandler):
        def __init__(self) -> None:
            super().__init__()
            self.nodes = []
            self.edges = [[], []]

            self.node_id_to_idx = {}
            self.idx_to_node_id = {}
            self.id_counter = 0

        def node(self, n: osm.Node) -> None:
            self.nodes.append([n.location.lat, n.location.lon])
            self.node_id_to_idx[n.id] = self.id_counter
            self.idx_to_node_id[self.id_counter] = n.id
            self.id_counter += 1

        def way(self, w):
            node_refs = [node.ref for node in w.nodes]

            for i in range(len(node_refs) - 1):
                node_start = node_refs[i]
                node_end = node_refs[i + 1]
                
                self.edges[0].append(self.node_id_to_idx[node_start])
                self.edges[1].append(self.node_id_to_idx[node_end])

    mapCreator = MapCreationHandler()
    mapCreator.apply_file(osmPath, locations=True)

    x = torch.tensor(mapCreator.nodes, dtype=torch.float)
    edge_index = torch.tensor(mapCreator.edges, dtype=torch.long)
    data = Data(x=x, edge_index=edge_index)
    return data

In [11]:
graph = construct_graph("data/stanford.pbf")
print(graph)

Data(x=[23691, 2], edge_index=[2, 25324])


In [18]:
model = GTN(input_dim=2, hidden_dim=10, output_dim=10, num_layers=2,
            dropout=0.1, beta=True, heads=1)

In [19]:
model(graph.x, graph.edge_index)

torch.Size([23691, 10])


tensor([[-0.2378,  0.0272, -0.6937,  ..., -0.2134,  0.5544, -0.6964],
        [-0.2396,  0.0283, -0.6955,  ..., -0.2065,  0.5538, -0.6931],
        [-0.2188,  0.1258, -0.6812,  ..., -0.2771,  0.5459, -0.7264],
        ...,
        [-0.2496,  0.0200, -0.7001,  ..., -0.2077,  0.5651, -0.6873],
        [-0.2376,  0.0285, -0.6975,  ..., -0.2036,  0.5561, -0.6909],
        [-0.2396,  0.0283, -0.6955,  ..., -0.2065,  0.5538, -0.6931]],
       grad_fn=<AddBackward0>)