## Graph Attention Network (GAT)


### Importing Dependencies

We import the required libraries and add the project's `src` directory to `sys.path` so that the notebook can import the project's helper modules.

```python
import os
import sys

import networkx as nx
from torch_geometric.data import Data
from torch_geometric.utils import from_networkx

# Go up two levels from the notebook directory: gat → models → src
src_path = os.path.abspath(os.path.join(os.getcwd(), "..", ".."))
if src_path not in sys.path:
    sys.path.append(src_path)
```


## Implementing a GAT Model

### From GAT to GATv2

We started with a vanilla GAT model, a Graph Neural Network (GNN) that uses an **attention mechanism** to weight each neighbor's contribution during message passing. However, the original GAT formulation computes attention from node features alone and does not use **edge attributes**, which are essential for our task. This limitation is especially critical because our original dataset contains no **node attributes** at all.

We therefore adopted **GATv2**, whose PyTorch Geometric implementation (`GATv2Conv`) accepts edge attributes through its `edge_dim` argument and uses them when computing the attention weight of each edge. This makes it a better fit for our purposes.
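For reference, the PyTorch Geometric documentation for `GATv2Conv` describes the layer roughly as follows (notation adapted from those docs). Here $\mathbf{e}_{i,j}$ is the attribute vector of the edge from node $j$ to node $i$, $\mathcal{N}(i)$ is the neighborhood of $i$, and the self-loop $\{i\}$ is included by default. The formulas show that edge attributes enter only the attention score $\alpha_{i,j}$, while the aggregated message is built from the transformed neighbor features $\mathbf{\Theta}_{t}\mathbf{x}_j$. Unlike the original GAT, which applies $\mathbf{a}$ before the LeakyReLU, GATv2 applies it after the nonlinearity, which is what makes its attention "dynamic". With several heads, this is computed per head and the outputs are concatenated.

$$
\alpha_{i,j} = \frac{\exp\left(\mathbf{a}^{\top}\,\mathrm{LeakyReLU}\left(\mathbf{\Theta}_{s}\mathbf{x}_i + \mathbf{\Theta}_{t}\mathbf{x}_j + \mathbf{\Theta}_{e}\mathbf{e}_{i,j}\right)\right)}{\sum_{k \in \mathcal{N}(i) \cup \{i\}} \exp\left(\mathbf{a}^{\top}\,\mathrm{LeakyReLU}\left(\mathbf{\Theta}_{s}\mathbf{x}_i + \mathbf{\Theta}_{t}\mathbf{x}_k + \mathbf{\Theta}_{e}\mathbf{e}_{i,k}\right)\right)}
$$

$$
\mathbf{x}'_i = \sum_{j \in \mathcal{N}(i) \cup \{i\}} \alpha_{i,j}\,\mathbf{\Theta}_{t}\mathbf{x}_j
$$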

The GATv2 model expects the following **inputs**:
- `x`: Node features → `data.x`
- `edge_index`: Edge list → `data.edge_index`
- `edge_attr`: Edge attributes → `data.edge_attr`

Each of these tensors is stored as an attribute of a PyTorch Geometric `Data` object and passed explicitly to the model's `forward` method.

**Output:**  
The model returns **node representations (embeddings)**: a tensor of shape `[num_nodes, out_channels]`, with one row per node and one column per output feature.



```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import GATv2Conv


class GATv2(torch.nn.Module):
    """Two-layer GATv2 encoder that returns node embeddings of size out_channels."""

    def __init__(self, in_channels, hidden_channels, out_channels, edge_dim, heads=1):
        super().__init__()

        # First GATv2 layer with edge attributes; with concat=True (default)
        # its output has hidden_channels * heads features
        self.gat1 = GATv2Conv(in_channels, hidden_channels, heads=heads, edge_dim=edge_dim)

        # Second GATv2 layer with a single head, so the output dimension is out_channels
        self.gat2 = GATv2Conv(hidden_channels * heads, out_channels, heads=1, edge_dim=edge_dim)

    def forward(self, x, edge_index, edge_attr):
        # Apply the first GATv2 layer; edge attributes enter the attention computation
        x = self.gat1(x, edge_index, edge_attr)  # [num_nodes, hidden_channels * heads]
        x = F.elu(x)

        # Apply the second GATv2 layer to obtain the final node embeddings
        x = self.gat2(x, edge_index, edge_attr)  # [num_nodes, out_channels]
        return x
```


### Advancing to an Encoder-Decoder Architecture

The GATv2 model above, however, is designed to learn **node embeddings**. These are useful for tasks such as node classification but do **not directly predict edge attributes** such as our target edge weight `tracks`.

Since our objective is to **predict edge values**, using a node-only model is insufficient. To address this, we extend the GATv2 architecture by incorporating a **decoder module** that transforms node embeddings into edge-level predictions.

Our final model follows a typical **encoder-decoder architecture**:

- **Encoder:**  
  We use the GATv2 model as the encoder to compute **informative node embeddings** based on the graph structure, edge attributes, and (if available) node features.

- **Decoder:**  
  As the decoder, we use a **small multilayer perceptron (MLP)** that takes as input the **concatenated embeddings** of each edge's source and target nodes.  
  This MLP outputs a **single scalar value per edge**, which serves as the prediction for the edge attribute `tracks`.
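The decoder step can be traced on a toy example. The sketch below uses hand-picked embeddings for four nodes and three edges (all values are illustrative assumptions, not model outputs) to show how indexing with `edge_index` gathers one source and one target embedding per edge before a small MLP maps each concatenated pair to a scalar.

```python
import torch
import torch.nn as nn

num_nodes, emb_dim = 4, 3
# Toy node embeddings: node i has values [3i, 3i + 1, 3i + 2]
z = torch.arange(num_nodes * emb_dim, dtype=torch.float32).view(num_nodes, emb_dim)

# Three directed edges 0 -> 1, 1 -> 2, 2 -> 3 (row 0 = sources, row 1 = targets)
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
src, dst = edge_index

# Concatenate source and target embeddings per edge: [num_edges, 2 * emb_dim]
edge_inputs = torch.cat([z[src], z[dst]], dim=1)
print(edge_inputs[0].tolist())  # [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]

# Small MLP decoder mapping each edge representation to one scalar
decoder = nn.Sequential(nn.Linear(2 * emb_dim, emb_dim), nn.ReLU(), nn.Linear(emb_dim, 1))
pred = decoder(edge_inputs)
print(pred.shape)  # torch.Size([3, 1])
```

Because the concatenation order matters, this decoder can produce different predictions for `(u, v)` and `(v, u)`; for an undirected graph stored with both directions, this is worth keeping in mind when comparing predictions to a single `tracks` value per edge.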


```python
import torch
import torch.nn as nn


class GATv2EdgePredictor(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, edge_dim, heads=1):
        super().__init__()

        # 1. Encoder: GATv2 model for computing node embeddings
        self.gnn = GATv2(in_channels, hidden_channels, out_channels, edge_dim, heads)

        # 2. Decoder: edge MLP to predict an edge attribute (e.g., "tracks")
        self.edge_mlp = nn.Sequential(
            nn.Linear(out_channels * 2, out_channels),
            nn.ReLU(),
            nn.Linear(out_channels, 1),  # Output: a single scalar per edge
        )

    def forward(self, data):
        """
        Args:
            data: PyTorch Geometric Data object with attributes:
                  - x: node features
                  - edge_index: edge connectivity (COO format)
                  - edge_attr: edge attributes

        Returns:
            pred: Tensor of shape [num_edges, 1] with predicted edge weights (e.g., "tracks")
        """
        # Compute node embeddings using the GATv2 encoder
        x = self.gnn(data.x, data.edge_index, data.edge_attr)  # [num_nodes, out_channels]

        # Build edge representations by concatenating source and target node embeddings
        row, col = data.edge_index  # source and target node indices for each edge
        edge_inputs = torch.cat([x[row], x[col]], dim=1)  # [num_edges, out_channels * 2]

        # Predict edge weights
        pred = self.edge_mlp(edge_inputs)  # [num_edges, 1]
        return pred
```
