In [1]:
import torch
from torch_geometric.nn import TransformerConv
from torch_geometric.data import Data
import torch.nn.functional as F
import networkx as nx
from math import asin, cos, radians, sin, sqrt

import osmium
from osmium import osm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class GTN(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers,
                 dropout, beta=True, heads=1):
        super(GTN, self).__init__()

        # The list of transormer conv layers for the each layer block.
        self.num_layers = num_layers
        conv_layers = [TransformerConv(input_dim, hidden_dim//heads, heads=heads, beta=beta)]
        conv_layers += [TransformerConv(hidden_dim, hidden_dim//heads, heads=heads, beta=beta) for _ in range(num_layers - 2)]
        # In the last layer, we will employ averaging for multi-head output by
        # setting concat to True.
        conv_layers.append(TransformerConv(hidden_dim, output_dim, heads=heads, beta=beta, concat=True))
        self.convs = torch.nn.ModuleList(conv_layers)

        # The list of layerNorm for each layer block.
        norm_layers = [torch.nn.LayerNorm(hidden_dim) for _ in range(num_layers - 1)]
        self.norms = torch.nn.ModuleList(norm_layers)

        # Probability of an element getting zeroed.
        self.dropout = dropout

    def reset_parameters(self):
        """
        Resets the parameters of the convolutional and normalization layers,
        ensuring they are re-initialized when needed.
        """
        for conv in self.convs:
            conv.reset_parameters()
        for norm in self.norms:
            norm.reset_parameters()

    def forward(self, x, edge_index):
        """
        The input features are passed sequentially through the transformer
        convolutional layers. After each convolutional layer (except the last),
        the following operations are applied:
        - Layer normalization (`LayerNorm`).
        - ReLU activation function.
        - Dropout for regularization.
        The final layer is processed without layer normalization and ReLU
        to average the multi-head results for the expected output.

        Params:
        - x: node features x
        - edge_index: edge indices.

        """
        for i in range(self.num_layers - 1):
            # Construct the network as shown in the model architecture.
            x = self.convs[i](x, edge_index)
            x = self.norms[i](x)
            x = F.relu(x)
            # By setting training to self.training, we will only apply dropout
            # during model training.
            x = F.dropout(x, p = self.dropout, training = self.training)

        # Last layer, average multi-head output.
        print(x.shape)
        x = self.convs[-1](x, edge_index)

        return x

In [7]:
RADIUS_EARTH = 6371000

def compute_distance(n1_longitude, n1_latitude, n2_longitude, n2_latitude) -> float:
    lon1, lat1 = radians(n1_longitude), radians(n1_latitude)
    lon2, lat2 = radians(n2_longitude), radians(n2_latitude)

    # Haversine formula
    deltaLon, deltaLat = lon2 - lon1, lat2 - lat1
    haversine = (sin(deltaLat / 2) ** 2) + (cos(lat1) * cos(lat2)) * (
        sin(deltaLon / 2) ** 2
    )

    # Return distance d (factor in radius of earth in meters)
    return 2 * RADIUS_EARTH * asin(sqrt(haversine))

In [12]:
def construct_graph(osmPath: str):
    class MapCreationHandler(osmium.SimpleHandler):
        def __init__(self) -> None:
            super().__init__()
            self.nodes = []
            self.edges = [[], []]
            self.edge_dist = []

            self.node_id_to_idx = {}
            self.idx_to_node_id = {}
            self.id_counter = 0

        def node(self, n: osm.Node) -> None:
            self.nodes.append([n.location.lat, n.location.lon])
            self.node_id_to_idx[n.id] = self.id_counter
            self.idx_to_node_id[self.id_counter] = n.id
            self.id_counter += 1

        def way(self, w):
            node_refs = [node.ref for node in w.nodes]

            for i in range(len(node_refs) - 1):
                node_start = node_refs[i]
                node_end = node_refs[i + 1]
                
                node_1_idx = self.node_id_to_idx[node_start]
                node_2_idx = self.node_id_to_idx[node_end]

                self.edges[0].append(node_1_idx)
                self.edges[1].append(node_2_idx)

                node_1 = self.nodes[node_1_idx]
                node_2 = self.nodes[node_2_idx]

                n1_longitude, n1_latitude = node_1
                n2_longitude, n2_latitude = node_2

                dist = compute_distance(n1_longitude, n1_latitude, n2_longitude, n2_latitude)
                self.edge_dist.append(dist)

    mapCreator = MapCreationHandler()
    mapCreator.apply_file(osmPath, locations=True)

    x = torch.tensor(mapCreator.nodes, dtype=torch.float)
    edge_index = torch.tensor(mapCreator.edges, dtype=torch.long)
    edge_attr = torch.tensor(mapCreator.edge_dist, dtype=torch.long).unsqueeze(1)
    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
    return data

In [13]:
graph = construct_graph("data/stanford.pbf")
print(graph.x)
print(graph.edge_index)
print(graph.edge_attr)

tensor([[  37.4340, -122.1725],
        [  37.4342, -122.1726],
        [  37.4344, -122.1724],
        ...,
        [  37.4312, -122.1713],
        [  37.4312, -122.1712],
        [  37.4314, -122.1710]])
tensor([[13919, 13766, 13981,  ..., 23687, 23688, 23689],
        [13766, 13981, 20493,  ..., 23006, 23689, 19361]])
tensor([[ 8],
        [ 7],
        [52],
        ...,
        [ 2],
        [12],
        [10]])


In [5]:
model = GTN(input_dim=2, hidden_dim=10, output_dim=10, num_layers=2,
            dropout=0.1, beta=True, heads=1)

In [6]:
model(graph.x, graph.edge_index)

torch.Size([23691, 10])


tensor([[-0.4025, -0.1049, -0.4049,  ...,  0.1012, -0.3605, -0.0113],
        [-0.7328, -0.0375, -0.6804,  ...,  0.4153, -0.1687, -0.0946],
        [-0.6782, -0.0241, -0.6366,  ...,  0.4120, -0.0943, -0.1269],
        ...,
        [-0.6764, -0.0318, -0.7375,  ...,  0.3310, -0.1892, -0.1302],
        [-0.6764, -0.0318, -0.7375,  ...,  0.3310, -0.1892, -0.1302],
        [-0.3751, -0.2072, -0.3531,  ..., -0.0116, -0.4586, -0.0183]],
       grad_fn=<AddBackward0>)

In [18]:
import osmium
import torch
import random
from torch_geometric.data import Data

def construct_graph_order(osmPath: str, num_graphs=100, num_nodes=10):
    class MapCreationHandler(osmium.SimpleHandler):
        def __init__(self) -> None:
            super().__init__()
            self.nodes = []
            self.edges = [[], []]
            self.node_id_to_idx = {}
            self.idx_to_node_id = {}
            self.id_counter = 0

        def node(self, n: osmium.osm.Node) -> None:
            self.nodes.append([n.location.lat, n.location.lon])
            self.node_id_to_idx[n.id] = self.id_counter
            self.idx_to_node_id[self.id_counter] = n.id
            self.id_counter += 1

        def way(self, w):
            node_refs = [node.ref for node in w.nodes]
            for i in range(len(node_refs) - 1):
                node_start = node_refs[i]
                node_end = node_refs[i + 1]
                if node_start in self.node_id_to_idx and node_end in self.node_id_to_idx:
                    self.edges[0].append(self.node_id_to_idx[node_start])
                    self.edges[1].append(self.node_id_to_idx[node_end])

    mapCreator = MapCreationHandler()
    mapCreator.apply_file(osmPath, locations=True)
    all_nodes = torch.tensor(mapCreator.nodes, dtype=torch.float)

    dataset = []

    for _ in range(num_graphs):
        sampled_node_indices = random.sample(range(len(all_nodes)), num_nodes)
        x = all_nodes[sampled_node_indices]
        local_node_indices = {global_idx: local_idx for local_idx, global_idx in enumerate(sampled_node_indices)}
        edge_index = torch.tensor([[local_node_indices[i], local_node_indices[i]] for i in sampled_node_indices], dtype=torch.long).t()
        target_order = torch.arange(num_nodes, dtype=torch.long)

        # Each datapoint has the following information: node information, edge information (partial or full?), and a target ordering
        data = Data(x=x, edge_index=edge_index, target_order=target_order)
        dataset.append(data)

    return dataset


In [28]:
# Construct the dataset
dataset = construct_graph_order("data/stanford.pbf")

In [21]:
class GTN(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, dropout, beta=True, heads=1):
        super(GTN, self).__init__()

        self.num_layers = num_layers
        self.dropout = dropout

        conv_layers = [TransformerConv(input_dim, hidden_dim // heads, heads=heads, beta=beta)]
        conv_layers += [TransformerConv(hidden_dim, hidden_dim // heads, heads=heads, beta=beta) for _ in range(num_layers - 2)]
        conv_layers.append(TransformerConv(hidden_dim, output_dim, heads=heads, beta=beta, concat=True))
        self.convs = torch.nn.ModuleList(conv_layers)

        norm_layers = [torch.nn.LayerNorm(hidden_dim) for _ in range(num_layers - 1)]
        self.norms = torch.nn.ModuleList(norm_layers)

        # MLP for ordering scores
        self.fc1 = torch.nn.Linear(output_dim, output_dim // 2)
        self.fc2 = torch.nn.Linear(output_dim // 2, output_dim)
        self.fc3 = torch.nn.Linear(output_dim, 1)

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
        for norm in self.norms:
            norm.reset_parameters()
        for layer in [self.fc1, self.fc2, self.fc3]:
            layer.reset_parameters()

    def forward(self, x, edge_index):

        if edge_index.size(1) == 0: 
            edge_index = torch.stack([torch.arange(x.size(0)), torch.arange(x.size(0))], dim=0).to(x.device)

        # We should add additional information to the network
        for i in range(self.num_layers - 1):
            x = self.convs[i](x, edge_index)
            x = self.norms[i](x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

        x = self.convs[-1](x, edge_index)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        scores = self.fc3(x).squeeze()  
        
        # Since this is an ordering problem, how should we compare our score to ground truth?
        return scores


In [29]:
import torch
import torch.optim as optim
from torch_geometric.loader import DataLoader

# Training parameters
input_dim = 2      
hidden_dim = 64
output_dim = 32
num_layers = 3
dropout = 0.5
learning_rate = 0.01
num_epochs = 20
batch_size = 1 

device = "cpu"
model = GTN(input_dim, hidden_dim, output_dim, num_layers, dropout).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

def loss_(scores, target_order):
   
   # TODO: How should we handle the loss using score and target?
    target_ranks = torch.arange(len(target_order), dtype=torch.float, device=scores.device)
    ordered_scores = scores[target_order]
    loss = F.mse_loss(ordered_scores, target_ranks)

    return loss


def train(model, data_loader, optimizer, num_epochs):
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0.0
        for batch in data_loader:
            batch = batch.to(device)  

            scores = model(batch.x, batch.edge_index)
            loss = loss_(scores, batch.target_order)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        avg_loss = total_loss / len(data_loader)
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {avg_loss:.4f}")

train(model, data_loader, optimizer, num_epochs)


Epoch [1/20], Loss: 9.8694
Epoch [2/20], Loss: 8.8360
Epoch [3/20], Loss: 8.5482
Epoch [4/20], Loss: 8.5193
Epoch [5/20], Loss: 8.2930
Epoch [6/20], Loss: 8.2387
Epoch [7/20], Loss: 8.3118
Epoch [8/20], Loss: 8.2731
Epoch [9/20], Loss: 8.2657
Epoch [10/20], Loss: 8.2668
Epoch [11/20], Loss: 8.2569
Epoch [12/20], Loss: 8.2494
Epoch [13/20], Loss: 8.2760
Epoch [14/20], Loss: 8.2552
Epoch [15/20], Loss: 8.2524
Epoch [16/20], Loss: 8.2500
Epoch [17/20], Loss: 8.2500
Epoch [18/20], Loss: 8.2500
Epoch [19/20], Loss: 8.2500
Epoch [20/20], Loss: 8.2500
