#### Final Model Architecture

In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops

In [None]:
# Positional Encoding
class PositionalEncoding(nn.Module):
    """Learned embedding of each row's index along the first axis of the input.

    Only the length of the first dimension and the device of the input are
    read; the values stored in the input do not influence the result.
    """

    def __init__(self, d_model):
        super().__init__()
        # Maps a scalar row index to a d_model-dimensional vector.
        self.linear = nn.Linear(1, d_model)

    def forward(self, x):
        """Return a ``(x.shape[0], d_model)`` tensor of index embeddings."""
        row_count = x.size(0)
        indices = torch.arange(row_count, device=x.device, dtype=torch.float32)
        # Add a trailing feature axis so the linear layer sees one scalar per row.
        return self.linear(indices[:, None])

# Temporal Encoding
class TemporalEncoding(nn.Module):
    """Learned linear embedding of scalar time values."""

    def __init__(self, d_model):
        super().__init__()
        # Maps one scalar time value to a d_model-dimensional vector.
        self.linear = nn.Linear(1, d_model)

    def forward(self, time_steps):
        """Embed every entry of ``time_steps``.

        A feature axis is inserted at position 1 and the values are cast to
        float32 before the linear projection.
        """
        as_column = torch.unsqueeze(time_steps, dim=1)
        return self.linear(as_column.to(torch.float32))

# Custom GNN Layer
class CustomGNNLayer(MessagePassing):
    """Message-passing layer combining projected node and edge features.

    Node and edge features are each projected to ``out_channels``; the
    per-edge message is their sum, messages are mean-aggregated
    (``aggr='mean'``) and the aggregate is layer-normalised.

    Args:
        in_channels: Width of the incoming node features.
        out_channels: Width of the produced node features.
        edge_in_chanel: Width of the incoming edge features (default 5).
    """

    def __init__(self, in_channels, out_channels, edge_in_chanel=5):
        super().__init__(aggr="mean")
        self.node_fc = nn.Linear(in_channels, out_channels)
        self.edge_fc = nn.Linear(edge_in_chanel, out_channels)
        self.norm = nn.LayerNorm(out_channels)

    def forward(self, x, edge_index, edge_attr):
        """Project nodes and edges, propagate messages, then normalise."""
        node_hidden = self.node_fc(x)
        edge_hidden = self.edge_fc(edge_attr)
        aggregated = self.propagate(edge_index, x=node_hidden, edge_attr=edge_hidden)
        return self.norm(aggregated)

    def message(self, x_j, edge_attr):
        """Build the per-edge message from neighbour and edge features.

        ``x_j`` is gathered per edge by ``MessagePassing.propagate``; the
        parameter names are part of that contract and must not be renamed.
        """
        return torch.add(x_j, edge_attr)

# Transformer Layer
class TransformerLayer(nn.Module):
    """Multi-head self-attention across the rows of a node-feature matrix.

    The rows of ``x`` are treated as one sequence so that every row can
    attend to every other row. (Previously each row was placed in its own
    length-1 sequence, which reduces attention to a per-row linear map.)

    Args:
        in_dim: Feature width of the input; must be divisible by the 4
            attention heads.
        out_dim: Feature width of the output. When it differs from
            ``in_dim`` a linear projection is inserted before the norm.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(embed_dim=in_dim, num_heads=4, batch_first=True)
        # Attention emits in_dim features. nn.Identity has no parameters, so
        # the state_dict is unchanged for the in_dim == out_dim case.
        self.proj = nn.Identity() if in_dim == out_dim else nn.Linear(in_dim, out_dim)
        self.norm = nn.LayerNorm(out_dim)

    def forward(self, x, edge_index=None):
        """Apply self-attention over the rows of ``x``.

        Args:
            x: Tensor of shape ``(num_rows, in_dim)``.
            edge_index: Accepted for call-site compatibility; not used.

        Returns:
            Tensor of shape ``(num_rows, out_dim)``.
        """
        # batch_first=True expects (batch, seq, dim): one batch entry whose
        # sequence is the full set of rows. Cost is quadratic in num_rows.
        sequence = x.unsqueeze(0)
        attn_output, _ = self.self_attn(sequence, sequence, sequence, need_weights=False)
        return self.norm(self.proj(attn_output.squeeze(0)))

In [None]:
# GNN + GRU Model
class GNN_GRU_Model(nn.Module):
    """Spatio-temporal model: stacked GNN layers per time step, then a GRU.

    For every past time step the model builds an initial node embedding as
    positional encoding + temporal encoding, refines it with ``num_layers``
    ``CustomGNNLayer`` modules using that step's edge attributes, and stacks
    the per-step results along a time axis. A GRU consumes the resulting
    per-node sequences and a linear head maps the last GRU output to
    ``output_dim`` values per node.

    Args:
        node_feat_dim: Width of the raw node features.
        edge_feat_dim: Width of the edge attributes. Also used as the GRU
            hidden size.
        hidden_dim: Width of the node embeddings inside the GNN stack.
        output_dim: Number of values predicted per node.
        num_layers: Number of stacked ``CustomGNNLayer`` modules.

    NOTE(review): ``node_feat_dim`` is accepted but not used anywhere in the
    model. The values of ``x_seq`` never reach a learnable layer, because
    ``PositionalEncoding`` only reads the length of its input.
    NOTE(review): ``self.transformer`` is constructed but not called in
    ``forward``; it is kept so existing ``state_dict`` keys stay valid.
    """

    def __init__(self, node_feat_dim, edge_feat_dim, hidden_dim, output_dim, num_layers=3):
        super().__init__()
        self.pos_enc = PositionalEncoding(hidden_dim)
        self.temp_enc = TemporalEncoding(hidden_dim)
        self.gnn_layers = nn.ModuleList(
            [CustomGNNLayer(hidden_dim, hidden_dim, edge_feat_dim) for _ in range(num_layers)]
        )
        self.transformer = TransformerLayer(hidden_dim, hidden_dim)
        # The GRU hidden size is edge_feat_dim, so fc_out reads edge_feat_dim features.
        self.gru = nn.GRU(hidden_dim, edge_feat_dim, batch_first=True)
        self.fc_out = nn.Linear(edge_feat_dim, output_dim)

    def forward(self, x_seq, edge_index, edge_attr_seq, time_steps_seq):
        """Predict ``output_dim`` values per node from a window of snapshots.

        Args:
            x_seq: Node features indexed by time step on axis 0
                (assumed ``(past_steps, num_nodes, node_feat_dim)`` - TODO confirm
                against the dataset).
            edge_index: Graph connectivity shared by all time steps.
            edge_attr_seq: Edge attributes indexed by time step on axis 0.
            time_steps_seq: Per-step time values indexed on axis 0
                (assumed one value per node - TODO confirm).

        Returns:
            Tensor with one row per node and ``output_dim`` columns.
        """
        past_steps = x_seq.shape[0]

        per_step = []
        for t in range(past_steps):
            x_t = self.pos_enc(x_seq[t]) + self.temp_enc(time_steps_seq[t])
            for gnn in self.gnn_layers:
                x_t = gnn(x_t, edge_index, edge_attr_seq[t])
            per_step.append(x_t)

        # Stack on axis 1 so each node is a batch entry with a time sequence,
        # matching the GRU's batch_first=True layout.
        x_out = torch.stack(per_step, dim=1)
        x_out, _ = self.gru(x_out)

        # Only the GRU output at the last time step feeds the prediction head.
        return self.fc_out(x_out[:, -1, :])

In [None]:
# Sample Graph Generation
def generate_sample_graphs(num_snapshots, num_nodes, node_feat_dim, edge_feat_dim, past_steps=12):
    """Generate random sliding-window graph samples for smoke testing.

    Random node features, edge attributes and integer time values in
    ``[0, 24)`` are drawn for ``num_snapshots`` snapshots, then cut into
    overlapping windows of ``past_steps`` consecutive snapshots. A single
    random ``edge_index`` with ``2 * num_nodes`` edges is shared by every
    window, so the overlapping edge-attribute slices always refer to the
    same edges. (Previously a new random ``edge_index`` was drawn per window.)

    Args:
        num_snapshots: Total number of snapshots to synthesise.
        num_nodes: Number of nodes per snapshot.
        node_feat_dim: Width of the node features.
        edge_feat_dim: Width of the edge attributes.
        past_steps: Window length in snapshots.

    Returns:
        A list of ``num_snapshots - past_steps`` tuples
        ``(x_seq, edge_index, edge_attr_seq, time_seq)``; empty when there
        are not more snapshots than ``past_steps``. The final snapshot is
        never part of a window.
    """
    num_edges = num_nodes * 2
    all_x = torch.randn(num_snapshots, num_nodes, node_feat_dim)
    all_edge_attr = torch.randn(num_snapshots, num_edges, edge_feat_dim)
    all_time_steps = torch.randint(0, 24, (num_snapshots, num_nodes))

    num_windows = num_snapshots - past_steps
    if num_windows <= 0:
        return []

    # Static connectivity: drawn once and shared (same tensor object) by all samples.
    edge_index = torch.randint(0, num_nodes, (2, num_edges))

    return [
        (
            all_x[start:start + past_steps],  # (past_steps, num_nodes, node_feat_dim)
            edge_index,
            all_edge_attr[start:start + past_steps],  # (past_steps, num_edges, edge_feat_dim)
            all_time_steps[start:start + past_steps],  # (past_steps, num_nodes)
        )
        for start in range(num_windows)
    ]

In [12]:
# Evaluation
def evaluate_model(model, graphs, save_path, details,
                   targets_path="../train_data/rio_data/rio_data_target.pth",
                   target_offset=3010):
    """Evaluate ``model`` on ``graphs`` and append the averaged losses to a file.

    Each sample's model output is compared with the matching entry of the
    targets loaded from ``targets_path`` (after dropping the first
    ``target_offset`` entries). The per-sample MSE and MAE are averaged over
    all samples, printed, and appended to ``save_path``.

    Args:
        model (nn.Module): The trained model to evaluate; switched to eval mode.
        graphs (list): Sized sequence of tuples
            ``(x, edge_index, edge_attr, time_steps)``, one per sample.
        save_path (str): Text file the results are appended to. Its parent
            directory is created if missing.
        details (str): Free-text label written in the results header.
        targets_path (str): File holding the targets for the whole dataset.
        target_offset (int): Index of the first evaluation target, i.e. the
            number of leading targets that belong to the training split.

    Returns:
        None. Prints the average MSE and MAE losses and appends them to
        ``save_path``.

    Raises:
        ValueError: If ``graphs`` is empty or there are fewer evaluation
            targets than graphs.
    """
    import os

    num_graphs = len(graphs)
    if num_graphs == 0:
        raise ValueError("evaluate_model() needs at least one graph sample")

    # SECURITY: torch.load unpickles the file; only load targets from a trusted source.
    targets = torch.load(targets_path)[target_offset:]
    if len(targets) < num_graphs:
        raise ValueError(
            f"{num_graphs} graphs but only {len(targets)} targets after offset {target_offset}"
        )

    criterion_mse = nn.MSELoss()
    criterion_mae = nn.L1Loss()  # L1 loss corresponds to MAE
    mse_loss_total = 0.0
    mae_loss_total = 0.0

    model.eval()
    with torch.no_grad():
        for (x, edge_index, edge_attr, time_steps), target in zip(graphs, targets):
            output = model(x, edge_index, edge_attr, time_steps)
            mse_loss_total += criterion_mse(output, target).item()
            mae_loss_total += criterion_mae(output, target).item()

    avg_mse_loss = mse_loss_total / num_graphs
    avg_mae_loss = mae_loss_total / num_graphs
    print(f"Evaluation MSE Loss: {avg_mse_loss:.4f}, MAE Loss: {avg_mae_loss:.4f}")

    # Avoid losing the results of a long evaluation to a missing directory.
    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(save_path, "a", encoding="utf-8") as f:
        f.write(f"Evaluation Results {details}\n")
        f.write(f"MSE Loss: {avg_mse_loss:.4f}\n")
        f.write(f"MAE Loss: {avg_mae_loss:.4f}\n")
        f.write("=" * 30 + "\n")

In [None]:
# Training Loop
def train_model(model, graphs, epochs=10, lr=0.001,
                targets_path="../train_data/rio_data/rio_data_target.pth",
                num_train_targets=3010):
    """Train ``model`` on ``graphs`` with Adam and an MSE loss.

    The model is updated once per graph sample (no mini-batching). The
    target for sample ``i`` is entry ``i`` of the first ``num_train_targets``
    targets loaded from ``targets_path``. The mean loss over all samples is
    printed after every epoch.

    Args:
        model (torch.nn.Module): The network to train; switched to train mode.
        graphs (list): Sized, re-iterable sequence of tuples
            ``(x_seq, edge_index, edge_attr_seq, time_steps_seq)``:
            node features over time, graph connectivity, edge features over
            time, and the temporal sequence.
        epochs (int): Number of passes over ``graphs``.
        lr (float): Learning rate for the Adam optimizer.
        targets_path (str): File holding the targets for the whole dataset.
        num_train_targets (int): Number of leading targets that belong to
            the training split.

    Returns:
        None. Prints the average loss per epoch.

    Raises:
        ValueError: If ``graphs`` is empty or there are fewer training
            targets than graphs.
    """
    num_graphs = len(graphs)
    if num_graphs == 0:
        raise ValueError("train_model() needs at least one graph sample")

    # SECURITY: torch.load unpickles the file; only load targets from a trusted source.
    targets = torch.load(targets_path)[:num_train_targets]
    if len(targets) < num_graphs:
        raise ValueError(
            f"{num_graphs} graphs but only {len(targets)} training targets available"
        )

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()

    # evaluate_model() leaves the model in eval mode; make the mode explicit.
    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        for (x_seq, edge_index, edge_attr_seq, time_steps_seq), target in zip(graphs, targets):
            optimizer.zero_grad()
            output = model(x_seq, edge_index, edge_attr_seq, time_steps_seq)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1}, Loss: {total_loss / num_graphs:.4f}")



# --- Experiment configuration ---
# NOTE: num_snapshots, num_nodes and batch_size are not read by the run below;
# num_snapshots / num_nodes only appear in the commented synthetic-data line.
num_snapshots = 4300
num_nodes = 196
node_feat_dim = 8
edge_feat_dim = 2
hidden_dim = 16
output_dim = 12  # Predict next 12 time steps
batch_size = 32

# Index separating training samples from evaluation samples. It must match
# the 3010 split that train_model / evaluate_model apply to the target file.
train_split = 3010

# SECURITY: torch.load unpickles the file; only load data from a trusted source.
sample_graphs = loaded_tensor = torch.load("../train_data/rio_data/rio_data.pth")
train = sample_graphs[:train_split]
test = sample_graphs[train_split:]
# Synthetic alternative to the dataset above:
# sample_graphs = generate_sample_graphs(num_snapshots, num_nodes, node_feat_dim, edge_feat_dim)
model = GNN_GRU_Model(node_feat_dim, edge_feat_dim, hidden_dim, output_dim)
train_model(model, train, epochs=10)
evaluate_model(model, test, save_path='../train_results/rio/evaluation.txt', details='10 epochs')

  sample_graphs = loaded_tensor = torch.load("../train_data/rio_data/rio_data.pth")
  targets = torch.load("../train_data/rio_data/rio_data_target.pth")


Evaluation MSE Loss: 2871.0100, MAE Loss: 15.2254
