In [46]:
import torch
from torch_geometric.nn import TransformerConv
from torch_geometric.data import Data
import torch.nn.functional as F
import torch.nn as nn
import networkx as nx
from math import asin, cos, radians, sin, sqrt

import osmium
from osmium import osm

In [7]:
RADIUS_EARTH = 6371000

def compute_distance(n1_longitude, n1_latitude, n2_longitude, n2_latitude) -> float:
    lon1, lat1 = radians(n1_longitude), radians(n1_latitude)
    lon2, lat2 = radians(n2_longitude), radians(n2_latitude)

    # Haversine formula
    deltaLon, deltaLat = lon2 - lon1, lat2 - lat1
    haversine = (sin(deltaLat / 2) ** 2) + (cos(lat1) * cos(lat2)) * (
        sin(deltaLon / 2) ** 2
    )

    # Return distance d (factor in radius of earth in meters)
    return 2 * RADIUS_EARTH * asin(sqrt(haversine))

In [33]:
def construct_graph(osmPath: str):
    class MapCreationHandler(osmium.SimpleHandler):
        def __init__(self) -> None:
            super().__init__()
            self.nodes = []
            self.edges = [[], []]
            self.edge_dist = []

            self.node_id_to_idx = {}
            self.idx_to_node_id = {}
            self.id_counter = 0

        def node(self, n: osm.Node) -> None:
            self.nodes.append([n.location.lat, n.location.lon])
            self.node_id_to_idx[n.id] = self.id_counter
            self.idx_to_node_id[self.id_counter] = n.id
            self.id_counter += 1

        def way(self, w):
            node_refs = [node.ref for node in w.nodes]

            for i in range(len(node_refs) - 1):
                node_start = node_refs[i]
                node_end = node_refs[i + 1]
                
                node_1_idx = self.node_id_to_idx[node_start]
                node_2_idx = self.node_id_to_idx[node_end]

                self.edges[0].append(node_1_idx)
                self.edges[1].append(node_2_idx)

                node_1 = self.nodes[node_1_idx]
                node_2 = self.nodes[node_2_idx]

                n1_longitude, n1_latitude = node_1
                n2_longitude, n2_latitude = node_2

                dist = compute_distance(n1_longitude, n1_latitude, n2_longitude, n2_latitude)
                self.edge_dist.append(dist)

    mapCreator = MapCreationHandler()
    mapCreator.apply_file(osmPath, locations=True)

    x = torch.tensor(mapCreator.nodes, dtype=torch.float)
    edge_index = torch.tensor(mapCreator.edges, dtype=torch.long)
    edge_attr = torch.tensor(mapCreator.edge_dist, dtype=torch.float).unsqueeze(1)
    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr,)
    return data

In [34]:
graph = construct_graph("data/stanford.pbf")
print(graph.x)
print(graph.edge_index)
print(graph.edge_attr)

tensor([[  37.4340, -122.1725],
        [  37.4342, -122.1726],
        [  37.4344, -122.1724],
        ...,
        [  37.4312, -122.1713],
        [  37.4312, -122.1712],
        [  37.4314, -122.1710]])
tensor([[13919, 13766, 13981,  ..., 23687, 23688, 23689],
        [13766, 13981, 20493,  ..., 23006, 23689, 19361]])
tensor([[ 8.3168],
        [ 7.1985],
        [52.1610],
        ...,
        [ 2.1193],
        [12.8045],
        [10.2401]])


In [38]:
class GTN(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers,
                 dropout, beta=True, heads=1):
        super(GTN, self).__init__()

        self.num_layers = num_layers

        # Initialize transformer convolution layers with edge attributes
        conv_layers = [TransformerConv(input_dim, hidden_dim // heads, heads=heads, edge_dim=1, beta=beta)]
        conv_layers += [TransformerConv(hidden_dim, hidden_dim // heads, heads=heads, edge_dim=1, beta=beta) for _ in range(num_layers - 2)]
        conv_layers.append(TransformerConv(hidden_dim, output_dim, heads=heads, edge_dim=1, beta=beta, concat=True))
        self.convs = torch.nn.ModuleList(conv_layers)

        # Initialize LayerNorm layers for normalization
        norm_layers = [torch.nn.LayerNorm(hidden_dim) for _ in range(num_layers - 1)]
        self.norms = torch.nn.ModuleList(norm_layers)

        self.dropout = dropout

    def reset_parameters(self):
        """Resets parameters for the convolutional and normalization layers."""
        for conv in self.convs:
            conv.reset_parameters()
        for norm in self.norms:
            norm.reset_parameters()

    def forward(self, x, edge_index, edge_attr):
        """
        Forward pass with edge attributes.
        - x: Node features
        - edge_index: Edge indices
        - edge_attr: Edge attributes
        """
        for i in range(self.num_layers - 1):
            x = self.convs[i](x, edge_index, edge_attr)  # Include edge_attr
            x = self.norms[i](x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

        # Last layer, average multi-head output.
        x = self.convs[-1](x, edge_index, edge_attr)  # Include edge_attr

        return x

In [39]:
model = GTN(input_dim=2, hidden_dim=10, output_dim=10, num_layers=2,
            dropout=0.1, beta=True, heads=1)

In [43]:
node_embeddings = model(graph.x, graph.edge_index, graph.edge_attr)
print(node_embeddings.shape)

torch.Size([23691, 10])


In [73]:
class NodeTransformer(nn.Module):
    def __init__(self, embed_dim, num_heads, num_layers, max_seq_len=48, ff_dim=2048, dropout=0.1):
        super(NodeTransformer, self).__init__()
        
        self.max_seq_len = max_seq_len
        self.embed_dim = embed_dim
        
        # Learnable special embeddings for fixed start and end node
        self.start_node_embed_tag = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.end_node_embed_tag = nn.Parameter(torch.randn(1, 1, embed_dim))
        
        # Transformer encoder layers
        self.encoder_layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=embed_dim, 
                nhead=num_heads, 
                dim_feedforward=ff_dim, 
                dropout=dropout, 
                activation='gelu'
            ) 
            for _ in range(num_layers)
        ])
        
        # LayerNorm to stabilize the output
        self.norm = nn.LayerNorm(embed_dim)

        self.head = nn.Linear(embed_dim, 1)

    def forward(self, waypoint_node_embeds, start_node_embed, end_node_embed):
        """
        Args:
            waypoint_node_embeds: Tensor of shape (batch_size, seq_len, embed_dim), where seq_len <= max_seq_len.
            start_node_embed: Tensor of shape (batch_size, 1, embed_dim) representing the first fixed node embedding.
            end_node_embed: Tensor of shape (batch_size, 1, embed_dim) representing the second fixed node embedding.

        Returns:
            Tensor of shape (batch_size, seq_len + 2, embed_dim).
        """
        batch_size, seq_len, embed_dim = waypoint_node_embeds.shape
        
        assert seq_len <= self.max_seq_len, f"Sequence length should be <= {self.max_seq_len}"
        assert embed_dim == self.embed_dim, f"Embedding dimension mismatch: {embed_dim} != {self.embed_dim}"

        # Add learnable tags to fixed nodes
        start_node_embed = start_node_embed + self.start_node_embed_tag  # Shape: (batch_size, 1, embed_dim)
        end_node_embed = end_node_embed + self.end_node_embed_tag  # Shape: (batch_size, 1, embed_dim)

        # Concatenate fixed nodes with the variable-length sequence
        fixed_nodes = torch.cat([start_node_embed, end_node_embed], dim=1)  # Shape: (batch_size, 2, embed_dim)
        full_sequence = torch.cat([fixed_nodes, waypoint_node_embeds], dim=1)  # Shape: (batch_size, seq_len+2, embed_dim)
        
        # Pass through the Transformer encoder layers
        x = full_sequence
        for layer in self.encoder_layers:
            x = layer(x)
        
        # Apply LayerNorm
        x = self.norm(x)
        # x = self.head(x)
        
        return x

In [74]:
batch_size = 32
seq_len = 48
embed_dim = 10
node_transformer_model = NodeTransformer(embed_dim=embed_dim, num_heads=1, num_layers=4)

In [76]:
start_node_idx = 13
end_node_idx = 14
waypoint_node_indices = [10, 20, 30]

start_node_embed = node_embeddings[start_node_idx].unsqueeze(0).unsqueeze(0)
end_node_embed = node_embeddings[end_node_idx].unsqueeze(0).unsqueeze(0)
waypoint_node_embeds = node_embeddings[waypoint_node_indices].unsqueeze(0)

In [78]:
output = node_transformer_model(waypoint_node_embeds, start_node_embed, end_node_embed)
output.shape

torch.Size([1, 5, 10])

In [29]:
import torch
import torch.optim as optim
from torch_geometric.loader import DataLoader

# Training parameters
input_dim = 2      
hidden_dim = 64
output_dim = 32
num_layers = 3
dropout = 0.5
learning_rate = 0.01
num_epochs = 20
batch_size = 1 

device = "cpu"
model = GTN(input_dim, hidden_dim, output_dim, num_layers, dropout).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

def loss_(scores, target_order):
   
   # TODO: How should we handle the loss using score and target?
    target_ranks = torch.arange(len(target_order), dtype=torch.float, device=scores.device)
    ordered_scores = scores[target_order]
    loss = F.mse_loss(ordered_scores, target_ranks)

    return loss


def train(model, data_loader, optimizer, num_epochs):
    model.train()
    for epoch in range(num_epochs):
        total_loss = 0.0
        for batch in data_loader:
            batch = batch.to(device)  

            scores = model(batch.x, batch.edge_index)
            loss = loss_(scores, batch.target_order)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        avg_loss = total_loss / len(data_loader)
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {avg_loss:.4f}")

train(model, data_loader, optimizer, num_epochs)


Epoch [1/20], Loss: 9.8694
Epoch [2/20], Loss: 8.8360
Epoch [3/20], Loss: 8.5482
Epoch [4/20], Loss: 8.5193
Epoch [5/20], Loss: 8.2930
Epoch [6/20], Loss: 8.2387
Epoch [7/20], Loss: 8.3118
Epoch [8/20], Loss: 8.2731
Epoch [9/20], Loss: 8.2657
Epoch [10/20], Loss: 8.2668
Epoch [11/20], Loss: 8.2569
Epoch [12/20], Loss: 8.2494
Epoch [13/20], Loss: 8.2760
Epoch [14/20], Loss: 8.2552
Epoch [15/20], Loss: 8.2524
Epoch [16/20], Loss: 8.2500
Epoch [17/20], Loss: 8.2500
Epoch [18/20], Loss: 8.2500
Epoch [19/20], Loss: 8.2500
Epoch [20/20], Loss: 8.2500
