
Title: Graph attention network (GAT) for node classification
Author: [akensert](https://github.com/akensert)
Date created: 2021/09/13
Last modified: 2021/12/26
Description: An implementation of a Graph Attention Network (GAT) for node classification.
Accelerator: GPU



## Introduction

[Graph neural networks](https://en.wikipedia.org/wiki/Graph_neural_network)
is the preferred neural network architecture for processing data structured as
graphs (for example, social networks or molecule structures), yielding
better results than fully-connected networks or convolutional networks.

In this tutorial, we will implement a specific graph neural network known as a
[Graph Attention Network](https://arxiv.org/abs/1710.10903) (GAT) to predict labels of
scientific papers based on what type of papers cite them (using the
[Cora](https://linqs.soe.ucsc.edu/data) dataset).

### References

For more information on GAT, see the original paper
[Graph Attention Networks](https://arxiv.org/abs/1710.10903) as well as
[DGL's Graph Attention Networks](https://docs.dgl.ai/en/0.4.x/tutorials/models/1_gnn/9_gat.html)
documentation.


In [54]:

### Import packages


import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import pandas as pd
import os
import warnings

warnings.filterwarnings("ignore")
pd.set_option("display.max_columns", 6)
pd.set_option("display.max_rows", 6)
np.random.seed(2)


## Obtain the dataset

In [55]:
# zip_file = keras.utils.get_file(
#     fname="cora.tgz",
#     origin="https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz",
#     extract=True,
# )

In [56]:
# data_dir = os.path.join(os.path.dirname(zip_file), "cora_extracted/cora")
# data_dir

In [57]:
# citations = pd.read_csv(
#     os.path.join(data_dir, "cora.cites"),
#     sep="\t",
#     header=None,
#     names=["target", "source"],
# )

# papers = pd.read_csv(
#     os.path.join(data_dir, "cora.content"),
#     sep="\t",
#     header=None,
#     names=["paper_id"] + [f"term_{idx}" for idx in range(1433)] + ["subject"],
# )

In [58]:
nodes_df.isnull().any().any()

False

In [59]:
# Define the dataset directory
dataset_dir = "dataset"


# Function to find all scene IDs in the dataset directory
def find_all_scene_ids(dataset_dir):
    scene_ids = []
    for file in os.listdir(dataset_dir):
        if file.endswith(".edges"):
            scene_id = file.split(".")[0]
            scene_ids.append(scene_id)
    return scene_ids


# Function to load all subgraphs for the found scene IDs
def load_all_subgraphs(dataset_dir):
    scene_ids = find_all_scene_ids(dataset_dir)
    subgraphs = {}

    for scene_id in scene_ids:
        edges_file = os.path.join(dataset_dir, f"{scene_id}.edges")
        nodes_file = os.path.join(dataset_dir, f"{scene_id}.nodes")

        # Check if both files exist
        if not os.path.exists(edges_file) or not os.path.exists(nodes_file):
            print(f"Skipping scene ID {scene_id}: Missing files.")
            continue

        # Load edges
        edges = pd.read_csv(edges_file, sep=",", header=None, names=["target", "source"])

        # Load nodes
        nodes = pd.read_csv(
            nodes_file,
            sep=",",
            header=None,
            names=["node_id", "current_x", "current_y", "previous_x", "previous_y", "future_x", "future_y"],
        )

        for col in nodes.columns:
            nodes[col] = pd.to_numeric(nodes[col], errors="coerce")

        if nodes.isnull().any().any():
            # Step 1: Identify rows with NaN values in nodes_df
            nan_nodes = nodes[nodes.isnull().any(axis=1)]

            # Step 2: Extract the node_id values of those rows
            nan_node_ids = nan_nodes['node_id'].tolist()

            # Step 3: Filter out edges in edges_df where source or target is in nan_node_ids
            # Display the filtered edges
            print(f"Original edges count: {len(edges)}")
            print(f"Original nodes count: {len(nodes)}")
            edges = edges[
                ~edges['source'].isin(nan_node_ids) & ~edges['target'].isin(nan_node_ids)
            ]

            print(f"Filtered edges count: {len(edges)}")
            nodes = nodes.dropna(subset=
            ["future_x", "future_y"])
            print(f"Filtered nodes count: {len(nodes)}")

        # # Filter out edges with -1 as source value
        # edges = edges[edges["source"] != -1]

        # Check if there are any -1 edges
        if (edges["source"] == -1).any() or (edges["target"] == -1).any():
            print(f"Scene ID {scene_id} contains -1 edges. Processing...")

            # Remove edges with -1 as source or target
            edges = edges[(edges["source"] != -1) & (edges["target"] != -1)]

            # Get unique node IDs from the remaining edges
            connected_nodes = pd.unique(edges[["target", "source"]].values.ravel())

            # Filter nodes to keep only those that are connected
            nodes = nodes[nodes["node_id"].isin(connected_nodes)]

        # Store the subgraph
        subgraphs[scene_id] = {"edges": edges, "nodes": nodes}

    return subgraphs


# Example usage
subgraphs = load_all_subgraphs(dataset_dir)
print(f"Loaded {len(subgraphs)} subgraphs.")
for scene_id, subgraph in list(subgraphs.items())[:3]:  # Display the first 3 subgraphs
    print(f"\nScene ID: {scene_id}")
    print("Edges:")
    print(subgraph["edges"].head())
    print("Nodes:")
    print(subgraph["nodes"].head())

Original edges count: 15
Original nodes count: 13
Filtered edges count: 14
Filtered nodes count: 13
Scene ID 1352890817715 contains -1 edges. Processing...
Original edges count: 16
Original nodes count: 14
Filtered edges count: 15
Filtered nodes count: 13
Scene ID 1352890814428 contains -1 edges. Processing...
Scene ID 1352890802323 contains -1 edges. Processing...
Original edges count: 23
Original nodes count: 12
Filtered edges count: 22
Filtered nodes count: 11
Scene ID 1352890800322 contains -1 edges. Processing...
Scene ID 1352890875617 contains -1 edges. Processing...
Original edges count: 16
Original nodes count: 13
Filtered edges count: 13
Filtered nodes count: 13
Scene ID 1352890804562 contains -1 edges. Processing...
Original edges count: 14
Original nodes count: 10
Filtered edges count: 13
Filtered nodes count: 9
Scene ID 1352890841688 contains -1 edges. Processing...
Scene ID 1352890837555 contains -1 edges. Processing...
Scene ID 1352890825684 contains -1 edges. Processing.

Scene ID 1352890828882 contains -1 edges. Processing...
Scene ID 1352890891802 contains -1 edges. Processing...
Scene ID 1352890801118 contains -1 edges. Processing...
Original edges count: 5
Original nodes count: 6
Filtered edges count: 4
Filtered nodes count: 6
Scene ID 1352890849798 contains -1 edges. Processing...
Original edges count: 7
Original nodes count: 9
Filtered edges count: 4
Filtered nodes count: 9
Scene ID 1352890894347 contains -1 edges. Processing...
Original edges count: 23
Original nodes count: 12
Filtered edges count: 22
Filtered nodes count: 11
Scene ID 1352890800459 contains -1 edges. Processing...
Original edges count: 14
Original nodes count: 13
Filtered edges count: 10
Filtered nodes count: 12
Scene ID 1352890803486 contains -1 edges. Processing...
Original edges count: 35
Original nodes count: 13
Filtered edges count: 32
Filtered nodes count: 13
Scene ID 1352890829713 contains -1 edges. Processing...
Scene ID 1352890828668 contains -1 edges. Processing...
Scen

In [60]:
all_nodes = []
all_edges = []

for scene_id, graph in subgraphs.items():
    nodes = graph["nodes"].copy()
    edges = graph["edges"].copy()

    nodes["scene_id"] = scene_id
    edges["scene_id"] = scene_id

    all_nodes.append(nodes)
    all_edges.append(edges)

nodes_df = pd.concat(all_nodes, ignore_index=True)
edges_df = pd.concat(all_edges, ignore_index=True)

In [61]:
nodes_df

Unnamed: 0,node_id,current_x,current_y,...,future_x,future_y,scene_id
0,19585800,27320.0,-17405.0,...,28033.0,-17874.0,1352890817715
1,19590700,27689.0,-16188.0,...,28276.0,-16736.0,1352890817715
2,19591900,21413.0,-13728.0,...,22412.0,-14133.0,1352890817715
...,...,...,...,...,...,...,...
1666,20011200,-13353.0,15703.0,...,-12540.0,15886.0,1352890910955
1667,20013400,30532.0,-16974.0,...,31451.0,-17345.0,1352890910955
1668,20013402,30190.0,-17870.0,...,31137.0,-18268.0,1352890910955


In [62]:
edges_df

Unnamed: 0,target,source,scene_id
0,19585800,19590700,1352890817715
1,19591900,19592201,1352890817715
2,19591900,19595300,1352890817715
...,...,...,...
2718,20002900,20004700,1352890910955
2719,20010700,20011200,1352890910955
2720,20013400,20013402,1352890910955


In [64]:
import tensorflow as tf
from tensorflow.keras import Model, layers, optimizers, losses
from sklearn.model_selection import train_test_split

np.random.seed(42)
tf.random.set_seed(42)

def process_scene(scene_data):
    try:
        # Hämta dataframes
        df_nodes = scene_data["nodes"].reset_index(drop=True)
        # df_edges = scene_data["edges"].reset_index(drop=True)

        # Konvertera noder till input features och targets
        # Här antar vi att vi använder current_x och current_y som input, och future_x och future_y som target.
        features = df_nodes[["current_x", "current_y"]].to_numpy().astype(np.float32)
        targets = df_nodes[["future_x", "future_y"]].to_numpy().astype(np.float32)

        # num_nodes = features.shape[0]

        # # Skapa adjacency-matris
        # adj = np.zeros((num_nodes, num_nodes), dtype=np.float32)
        # # För varje rad i df_edges, lägg in en 1:a där det finns en edge.
        # # Eftersom det ofta kan vara en riktad edge men vi vill ha en symmetrisk matris (om inte annat önskas)
        # for _, row in df_edges.iterrows():
        #     tgt = int(row["target"])
        #     src = int(row["source"])
        #     if tgt < num_nodes and src < num_nodes:  # säkerställ att index finns
        #         adj[tgt, src] = 1.0
        #         adj[src, tgt] = 1.0  # Gör matrisen symmetrisk, kan uteslutas om det är riktad data.
    except:
        print(f"Scene {scene_id} caused error.")
        # continue

    return features, targets


all_data = []
scene_ids = list(subgraphs.keys())
for scene_id in scene_ids:
    scene = subgraphs[scene_id]
    if "nodes" in scene and "edges" in scene:
        features, targets = process_scene(scene)
        all_data.append((features, targets))
    else:
        print(f"Scene {scene_id} misses 'nodes' or 'edges'.")

train_data, test_data = train_test_split(all_data, test_size=0.3, random_state=42)
print(f"Number of training scenes: {len(train_data)}")
print(f"Number of test scenes: {len(test_data)}")

Number of training scenes: 135
Number of test scenes: 58


In [None]:
train_data[]

In [66]:
# ### Split the dataset


# # Obtain random indices
# random_indices = np.random.permutation(range(len(subgraphs.keys())))

# # 50/50 split
# train_nodes = nodes_df.iloc[random_indices[: len(random_indices) // 2]]
# train_edges = edges_df.iloc[random_indices[: len(random_indices) // 2]]
# test_nodes = nodes_df.iloc[random_indices[len(random_indices) // 2 :]]
# test_edges = edges_df.iloc[random_indices[len(random_indices) // 2 :]]

# ### Prepare the graph data


# # Obtain paper indices which will be used to gather node states
# # from the graph later on when training the model
train_indices = train_data["scene_id"].to_numpy()
test_indices = test_data["scence_id"].to_numpy()

# # Obtain ground truth labels corresponding to each paper_id
# train_labels = train_data["subject"].to_numpy()
# test_labels = test_data["subject"].to_numpy()

# Define graph, namely an edge tensor and a node feature tensor
edges = tf.convert_to_tensor(edges_df[["target", "source"]])
node_states = tf.convert_to_tensor(nodes_df.sort_values("node_id").iloc[:, 1:-1])

# Print shapes of the graph
print("Edges shape:\t\t", edges.shape)
print("Node features shape:", node_states.shape)

Edges shape:		 (2721, 2)
Node features shape: (1669, 6)


In [None]:
class GATLayer(layers.Layer):
    def __init__(self, output_dim, **kwargs):
        super().__init__(**kwargs)
        self.output_dim = output_dim

    def build(self, input_shape):
        # input_shape: (batch_shape kanske ej existerar då vi tränar med en graph per sample) -> (num_nodes, feature_dim)
        self.linear = layers.Dense(self.output_dim)
        super().build(input_shape)

    def call(self, inputs):
        # inputs: (x, adj)
        # x: (num_nodes, feature_dim)
        # adj: (num_nodes, num_nodes)
        x, adj = inputs
        h = self.linear(x)  # (num_nodes, output_dim)

        # Cosine-similaritet: normalisera längden på varje nods representation
        h_norm = tf.math.l2_normalize(h, axis=-1)  # (num_nodes, output_dim)
        # Attention-poäng: beräkna den parvisa cosine likheten
        attn = tf.matmul(h_norm, h_norm, transpose_b=True)  # (num_nodes, num_nodes)
        # Maska bort icke-existerande kanter genom att multiplicera med adjacency
        attn = attn * adj
        # Normalisera attention: dela varje rad med summan över den raden
        attn_sum = tf.reduce_sum(attn, axis=-1, keepdims=True) + 1e-6
        attn_norm = attn / attn_sum
        # Aggregera informationsflödet
        out = tf.matmul(attn_norm, h)
        return out


# %%
class GATModel(Model):
    def __init__(self, hidden_dim=32, **kwargs):
        super().__init__(**kwargs)
        # Första GAT-lagret
        self.gat1 = GATLayer(hidden_dim)
        # Andra GAT-lagret, output dim blir 2 för regression (future_x, future_y)
        self.gat2 = GATLayer(2)

    def call(self, inputs):
        x, adj = inputs
        h = self.gat1((x, adj))
        h = tf.nn.relu(h)
        h = self.gat2((h, adj))
        return h


# Eftersom varje subgraph har olika antal noder kommer vi att träna modellen med ett **custom training loop** där vi itererar över varje graf (subgraph) separat.
#
# Vi använder Mean Squared Error (MSE) som förlustfunktion för regression.

# %%
# Hyperparametrar
EPOCHS = 100
LEARNING_RATE = 1e-3

# Initiera modell, optimizer och loss
model = GATModel(hidden_dim=32)
optimizer = optimizers.Adam(learning_rate=LEARNING_RATE)
loss_fn = losses.MeanSquaredError()


# Funktion för att träna på en epok
def train_one_epoch(data_list):
    total_loss = 0
    for features, adj, targets in data_list:
        # Konvertera till tensor
        features_tf = tf.convert_to_tensor(features, dtype=tf.float32)
        adj_tf = tf.convert_to_tensor(adj, dtype=tf.float32)
        targets_tf = tf.convert_to_tensor(targets, dtype=tf.float32)

        with tf.GradientTape() as tape:
            preds = model((features_tf, adj_tf))
            loss = loss_fn(targets_tf, preds)
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        total_loss += loss.numpy()
    return total_loss / len(data_list)


# %%
# Träna modellen
for epoch in range(EPOCHS):
    loss_value = train_one_epoch(train_data)
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}: Genomsnittlig loss = {loss_value:.4f}")

# %% [markdown]
# ## Utvärdering
#
# Låt oss beräkna MSE på testdata:
#
# Vi går igenom varje subgraph i test_data, gör en prediktion och jämför med de sanna värdena.


# %%
def evaluate(data_list):
    total_loss = 0
    for features, adj, targets in data_list:
        features_tf = tf.convert_to_tensor(features, dtype=tf.float32)
        adj_tf = tf.convert_to_tensor(adj, dtype=tf.float32)
        targets_tf = tf.convert_to_tensor(targets, dtype=tf.float32)

        preds = model((features_tf, adj_tf))
        loss = loss_fn(targets_tf, preds)
        total_loss += loss.numpy()
    return total_loss / len(data_list)


test_loss = evaluate(test_data)
print(f"Test MSE: {test_loss:.4f}")

# %% [markdown]
# ## Predictions
#
# Exempelvis kan vi göra prediktioner på en test-subgraph och visualisera (skriva ut) både sanna och predikterade värden.

# %%
# Välj ett exempelscenario från test_data
sample_features, sample_adj, sample_targets = test_data[0]
sample_features_tf = tf.convert_to_tensor(sample_features, dtype=tf.float32)
sample_adj_tf = tf.convert_to_tensor(sample_adj, dtype=tf.float32)

# Prediktion
sample_preds = model((sample_features_tf, sample_adj_tf)).numpy()

# Visa resultat för varje nod
print("Nod\tTrue (future_x, future_y)\tPred (future_x, future_y)")
for i in range(sample_targets.shape[0]):
    true_vals = sample_targets[i]
    pred_vals = sample_preds[i]
    print(f"{i}\t({true_vals[0]:.2f}, {true_vals[1]:.2f})\t\t({pred_vals[0]:.2f}, {pred_vals[1]:.2f})")

# %% [markdown]
# ## Sammanfattning
#
# Vi har nu:
# - Laddat och bearbetat en dataset med subgrafer (där varje subgraph motsvaras av ett scene id).
# - Delat in subgraferna i tränings- och testdata.
# - Definierat en GAT-modell i Keras modifierad för regression (predikterar `future_x` och `future_y`).
# - Tränat modellen med ett custom training loop där varje subgraph behandlas som en enskild batch.
# - Utvärderat modellen på testdata och visat ett exempel på prediktioner.
#
# Anpassa gärna hyperparametrar och arkitektur efter behov. Lycka till med din modell!

In [22]:
# Merge edges_df with nodes_df for target nodes
merged_target = pd.merge(
    edges_df, nodes_df, left_on=["scene_id", "target"], right_on=["scene_id", "node_id"], how="inner"
)

# Rename columns to distinguish target node attributes
merged_target = merged_target.rename(
    columns={
        "node_id": "target_node_id",
        "current_x": "target_current_x",
        "current_y": "target_current_y",
        "previous_x": "target_previous_x",
        "previous_y": "target_previous_y",
        "future_x": "target_future_x",
        "future_y": "target_future_y",
    }
)

# Merge edges_df with nodes_df for source nodes
merged_source = pd.merge(
    edges_df, nodes_df, left_on=["scene_id", "source"], right_on=["scene_id", "node_id"], how="inner"
)

# Rename columns to distinguish source node attributes
merged_source = merged_source.rename(
    columns={
        "node_id": "source_node_id",
        "current_x": "source_current_x",
        "current_y": "source_current_y",
        "previous_x": "source_previous_x",
        "previous_y": "source_previous_y",
        "future_x": "source_future_x",
        "future_y": "source_future_y",
    }
)

# Combine target and source node attributes into a single DataFrame
merged_df = pd.merge(merged_target, merged_source, on=["scene_id", "target", "source"], how="inner")

# Display the merged DataFrame
print(merged_df)

        target    source       scene_id  ...  source_previous_y  \
0     19585800  19590700  1352890817715  ...           -15576.0   
1     19591900  19592201  1352890817715  ...           -14512.0   
2     19591900  19595300  1352890817715  ...           -12614.0   
...        ...       ...            ...  ...                ...   
2928  20002900  20004700  1352890910955  ...           -22062.0   
2929  20010700  20011200  1352890910955  ...            15355.0   
2930  20013400  20013402  1352890910955  ...           -17719.0   

      source_future_x  source_future_y  
0             28276.0         -16736.0  
1             24269.0         -15188.0  
2             22476.0         -13642.0  
...               ...              ...  
2928          35544.0         -21641.0  
2929         -12540.0          15886.0  
2930          31137.0         -18268.0  

[2931 rows x 17 columns]


In [20]:
merged_df = merged_df[(merged_df["target"] == merged_df["node_id"]) | (merged_df["source"] == merged_df["node_id"])]

In [21]:
merged_df

Unnamed: 0,scene_id,target,source,...,previous_y,future_x,future_y
0,1352890817715,19585800,19590700,...,-16671.0,28033.0,-17874.0
1,1352890817715,19585800,19590700,...,-15576.0,28276.0,-16736.0
11,1352890817715,19591900,19592201,...,-13198.0,22412.0,-14133.0
...,...,...,...,...,...,...,...
29851,1352890910955,20010700,20011200,...,15355.0,-12540.0,15886.0
29858,1352890910955,20013400,20013402,...,-16662.0,31451.0,-17345.0
29859,1352890910955,20013400,20013402,...,-17719.0,31137.0,-18268.0


In [None]:
# loader = DisjointLoader(dataset, batch_size=16, epochs=1, shuffle=True)

# loader

In [80]:
random_indices = np.random.permutation(range(len(subgraphs.keys())))

array([ 15, 129, 135,  75, 189,  27,  63,  62,  26, 183,  74, 122,  22,
       177,  28,  29,  43, 132,   4,  36, 181,   3,  77,  54,  61, 103,
       157,  47, 119, 173, 147,  50,  42,  79,  12, 158, 174, 156, 164,
        83,  46,  73,  18,  10,  98, 149, 120,  88, 170,  60, 165, 155,
        25, 114, 192, 145, 150, 186,  95,  23, 139, 116, 111,  16,   9,
       184, 185,  90,  92,  56,  21, 188, 106,  64,  31,  38, 124, 112,
        91, 180,  39, 168, 125, 102,  59,  34,  17, 167,  72, 104,  67,
        96,  65,  86, 109, 117,  93,  51, 161, 182, 131,  11, 134, 179,
        41, 160,  58,   8, 178,  19, 146,  40, 128, 163, 169, 152, 130,
       108,   5,   1, 143, 191,  94, 176, 100, 137, 154, 148,   2,  30,
        20,  35,  82, 138, 127,  49, 175,   7, 162,  44,  69,  85, 141,
        33,  87, 136, 123,  70, 115, 113, 133,  37, 151,  53, 166, 159,
        68, 153,  57, 187, 101, 171,  55,  13,  76,  32,  45, 107, 121,
       140,  89,  14,   6,  99,   0,  24, 126,  52,  81,  80,  7

In [81]:
train_edges = edges_df.iloc[random_indices[: len(random_indices) // 2]]

train_edges

Unnamed: 0,target,source,scene_id
15,19591900,19595800,1352890814428
129,19585800,19590700,1352890801553
135,19591900,19594200,1352890801553
...,...,...,...
86,19595800,20001800,1352890841688
109,19591900,19595300,1352890825684
117,19592201,20000300,1352890825684


In [82]:
train_nodes = nodes_df.iloc[random_indices[: len(random_indices) // 2]]

train_nodes

Unnamed: 0,node_id,current_x,current_y,...,future_x,future_y,scene_id
15,19590700,24745.0,-14648.0,...,25516.0,-15070.0,1352890814428
129,20000200,35414.0,-18285.0,...,36754.0,-17928.0,1352890832535
135,19585800,35907.0,-21128.0,...,35654.0,-21456.0,1352890829505
...,...,...,...,...,...,...,...
86,19591900,41210.0,-21106.0,...,42093.0,-21566.0,1352890837555
109,19502500,41331.0,-16669.0,...,41147.0,-16802.0,1352890801553
117,19595200,14210.0,-6598.0,...,15379.0,-7696.0,1352890801553


In [83]:
test_nodes = nodes_df.iloc[random_indices[len(random_indices) // 2 :]]

test_nodes

Unnamed: 0,node_id,current_x,current_y,...,future_x,future_y,scene_id
93,20000700,39242.0,-18792.0,...,40368.0,-18522.0,1352890837555
51,19502500,41279.0,-16920.0,...,40423.0,-16728.0,1352890875617
161,19502500,41076.0,-16662.0,...,40962.0,-16618.0,1352890842354
...,...,...,...,...,...,...,...
97,19502500,40099.0,-16680.0,...,40032.0,-16532.0,1352890825684
110,19585800,13864.0,-7959.0,...,14498.0,-8687.0,1352890801553
118,19595300,8839.0,-4141.0,...,9555.0,-4348.0,1352890801553


In [84]:
test_edges = edges_df.iloc[random_indices[len(random_indices) // 2 :]]

test_edges

Unnamed: 0,target,source,scene_id
93,19591900,19595800,1352890837555
51,19592400,19595200,1352890800322
161,19592201,19595800,1352890832535
...,...,...,...
97,19592201,19595800,1352890837555
110,19591900,19595800,1352890825684
118,19592201,20001800,1352890825684


In [None]:
subgraphs_list = list(subgraphs)

In [93]:
from sklearn.model_selection import train_test_split

scene_ids = list(subgraphs.keys())
train_ids, test_ids = train_test_split(scene_ids, test_size=0.2, random_state=42)

train_graphs = [subgraphs[sid] for sid in train_ids]
test_graphs = [subgraphs[sid] for sid in test_ids]

In [96]:
train_graphs

[{'edges':      target    source
  0  19592800  20010000
  1  20002900  20004700
  2  20004900  20005100
  3  20004900  20010000
  4  20005100  20010000,
  'nodes':      node_id  current_x  current_y  ...  previous_y  future_x  future_y
  0   19502500    40118.0   -16503.0  ...    -16689.0   40166.0  -16406.0
  1   19592800     8851.0     1079.0  ...      1410.0    8804.0    1042.0
  2   20002900    34432.0   -21963.0  ...    -22015.0   34578.0  -21851.0
  ..       ...        ...        ...  ...         ...       ...       ...
  7   20010000     7146.0    -1377.0  ...      -763.0    7913.0   -1954.0
  8   20010700   -28932.0     -409.0  ...     -1067.0  -28322.0     352.0
  9   20011200   -32104.0     -991.0  ...     -1480.0  -31149.0    -472.0
  
  [10 rows x 7 columns]},
 {'edges':       target    source
  0   19585800  19590700
  1   19585800  19595200
  2   19590700  19592201
  ..       ...       ...
  18  19594200  19595300
  19  19594200  19595800
  20  19595300  19595800
  
  [2

In [None]:
### Split the dataset


# Obtain random indices
random_indices = np.random.permutation(range(len(subgraphs.keys())))

# 50/50 split
train_nodes = nodes_df.iloc[random_indices[: len(random_indices) // 2]]
train_edges = edges_df.iloc[random_indices[: len(random_indices) // 2]]
test_nodes = nodes_df.iloc[random_indices[len(random_indices) // 2 :]]
test_edges = edges_df.iloc[random_indices[len(random_indices) // 2 :]]

### Prepare the graph data


# Obtain paper indices which will be used to gather node states
# from the graph later on when training the model
train_indices = train_data["paper_id"].to_numpy()
test_indices = test_data["paper_id"].to_numpy()

# Obtain ground truth labels corresponding to each paper_id
train_labels = train_data["subject"].to_numpy()
test_labels = test_data["subject"].to_numpy()

# Define graph, namely an edge tensor and a node feature tensor
edges = tf.convert_to_tensor(citations[["target", "source"]])
node_states = tf.convert_to_tensor(papers.sort_values("paper_id").iloc[:, 1:-1])

# Print shapes of the graph
print("Edges shape:\t\t", edges.shape)
print("Node features shape:", node_states.shape)

      target  source
0          0      21
1          0     905
2          0     906
...      ...     ...
5426    1874    2586
5427    1876    1874
5428    1897    2707

[5429 rows x 2 columns]
      paper_id  term_0  term_1  ...  term_1431  term_1432  subject
0          462       0       0  ...          0          0        2
1         1911       0       0  ...          0          0        5
2         2002       0       0  ...          0          0        4
...        ...     ...     ...  ...        ...        ...      ...
2705      2372       0       0  ...          0          0        1
2706       955       0       0  ...          0          0        0
2707       376       0       0  ...          0          0        2

[2708 rows x 1435 columns]
Edges shape:		 (5429, 2)
Node features shape: (2708, 1433)



## Build the model

GAT takes as input a graph (namely an edge tensor and a node feature tensor) and
outputs \[updated\] node states. The node states are, for each target node, neighborhood
aggregated information of *N*-hops (where *N* is decided by the number of layers of the
GAT). Importantly, in contrast to the
[graph convolutional network](https://arxiv.org/abs/1609.02907) (GCN)
the GAT makes use of attention mechanisms
to aggregate information from neighboring nodes (or *source nodes*). In other words, instead of simply
averaging/summing node states from source nodes (*source papers*) to the target node (*target papers*),
GAT first applies normalized attention scores to each source node state and then sums.



### (Multi-head) graph attention layer

The GAT model implements multi-head graph attention layers. The `MultiHeadGraphAttention`
layer is simply a concatenation (or averaging) of multiple graph attention layers
(`GraphAttention`), each with separate learnable weights `W`. The `GraphAttention` layer
does the following:

Consider inputs node states `h^{l}` which are linearly transformed by `W^{l}`, resulting in `z^{l}`.

For each target node:

1. Computes pair-wise attention scores `a^{l}^{T}(z^{l}_{i}||z^{l}_{j})` for all `j`,
resulting in `e_{ij}` (for all `j`).
`||` denotes a concatenation, `_{i}` corresponds to the target node, and `_{j}`
corresponds to a given 1-hop neighbor/source node.
2. Normalizes `e_{ij}` via softmax, so as the sum of incoming edges' attention scores
to the target node (`sum_{k}{e_{norm}_{ik}}`) will add up to 1.
3. Applies attention scores `e_{norm}_{ij}` to `z_{j}`
and adds it to the new target node state `h^{l+1}_{i}`, for all `j`.


In [67]:
class GraphAttention(layers.Layer):
    def __init__(
        self,
        units,
        kernel_initializer="glorot_uniform",
        kernel_regularizer=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.units = units
        self.kernel_initializer = keras.initializers.get(kernel_initializer)
        self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)

    def build(self, input_shape):
        self.kernel = self.add_weight(
            shape=(input_shape[0][-1], self.units),
            trainable=True,
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            name="kernel",
        )
        self.kernel_attention = self.add_weight(
            shape=(self.units * 2, 1),
            trainable=True,
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            name="kernel_attention",
        )
        self.built = True

    def call(self, inputs):
        node_states, edges = inputs

        # Linearly transform node states
        node_states_transformed = tf.matmul(node_states, self.kernel)

        # (1) Compute pair-wise attention scores
        node_states_expanded = tf.gather(node_states_transformed, edges)
        node_states_expanded = tf.reshape(node_states_expanded, (tf.shape(edges)[0], -1))
        attention_scores = tf.nn.leaky_relu(tf.matmul(node_states_expanded, self.kernel_attention))
        attention_scores = tf.squeeze(attention_scores, -1)

        # (2) Normalize attention scores
        attention_scores = tf.math.exp(tf.clip_by_value(attention_scores, -2, 2))
        attention_scores_sum = tf.math.unsorted_segment_sum(
            data=attention_scores,
            segment_ids=edges[:, 0],
            num_segments=tf.reduce_max(edges[:, 0]) + 1,
        )
        attention_scores_sum = tf.repeat(attention_scores_sum, tf.math.bincount(tf.cast(edges[:, 0], "int32")))
        attention_scores_norm = attention_scores / attention_scores_sum

        # (3) Gather node states of neighbors, apply attention scores and aggregate
        node_states_neighbors = tf.gather(node_states_transformed, edges[:, 1])
        out = tf.math.unsorted_segment_sum(
            data=node_states_neighbors * attention_scores_norm[:, tf.newaxis],
            segment_ids=edges[:, 0],
            num_segments=tf.shape(node_states)[0],
        )
        return out


class MultiHeadGraphAttention(layers.Layer):
    def __init__(self, units, num_heads=8, merge_type="concat", **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.merge_type = merge_type
        self.attention_layers = [GraphAttention(units) for _ in range(num_heads)]

    def call(self, inputs):
        atom_features, pair_indices = inputs

        # Obtain outputs from each attention head
        outputs = [attention_layer([atom_features, pair_indices]) for attention_layer in self.attention_layers]
        # Concatenate or average the node states from each head
        if self.merge_type == "concat":
            outputs = tf.concat(outputs, axis=-1)
        else:
            outputs = tf.reduce_mean(tf.stack(outputs, axis=-1), axis=-1)
        # Activate and return node states
        return tf.nn.relu(outputs)





### Implement training logic with custom `train_step`, `test_step`, and `predict_step` methods

Notice, the GAT model operates on the entire graph (namely, `node_states` and
`edges`) in all phases (training, validation and testing). Hence, `node_states` and
`edges` are passed to the constructor of the `keras.Model` and used as attributes.
The difference between the phases are the indices (and labels), which gathers
certain outputs (`tf.gather(outputs, indices)`).



In [68]:
class GraphAttentionNetwork(keras.Model):
    def __init__(
        self,
        node_states,
        edges,
        hidden_units,
        num_heads,
        num_layers,
        output_dim,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.node_states = node_states
        self.edges = edges
        self.preprocess = layers.Dense(hidden_units * num_heads, activation="relu")
        self.attention_layers = [MultiHeadGraphAttention(hidden_units, num_heads) for _ in range(num_layers)]
        self.output_layer = layers.Dense(output_dim)

    def call(self, inputs):
        node_states, edges = inputs
        x = self.preprocess(node_states)
        for attention_layer in self.attention_layers:
            x = attention_layer([x, edges]) + x
        outputs = self.output_layer(x)
        return outputs

    def train_step(self, data):
        indices, labels = data

        with tf.GradientTape() as tape:
            # Forward pass
            outputs = self([self.node_states, self.edges])
            # Compute loss
            loss = self.compiled_loss(labels, tf.gather(outputs, indices))
        # Compute gradients
        grads = tape.gradient(loss, self.trainable_weights)
        # Apply gradients (update weights)
        optimizer.apply_gradients(zip(grads, self.trainable_weights))
        # Update metric(s)
        self.compiled_metrics.update_state(labels, tf.gather(outputs, indices))

        return {m.name: m.result() for m in self.metrics}

    def predict_step(self, data):
        indices = data
        # Forward pass
        outputs = self([self.node_states, self.edges])
        # Compute probabilities
        return tf.nn.softmax(tf.gather(outputs, indices))

    def test_step(self, data):
        indices, labels = data
        # Forward pass
        outputs = self([self.node_states, self.edges])
        # Compute loss
        loss = self.compiled_loss(labels, tf.gather(outputs, indices))
        # Update metric(s)
        self.compiled_metrics.update_state(labels, tf.gather(outputs, indices))

        return {m.name: m.result() for m in self.metrics}


In [70]:
### Train and evaluate


# Define hyper-parameters
HIDDEN_UNITS = 100
NUM_HEADS = 8
NUM_LAYERS = 3
OUTPUT_DIM = 2

NUM_EPOCHS = 100
BATCH_SIZE = 256
VALIDATION_SPLIT = 0.1
LEARNING_RATE = 3e-1
MOMENTUM = 0.9

loss_fn = keras.losses.MeanSquaredError()
optimizer = keras.optimizers.SGD(LEARNING_RATE, momentum=MOMENTUM)
accuracy_fn = keras.metrics.MeanSquaredError()
early_stopping = keras.callbacks.EarlyStopping(monitor="val_acc", min_delta=1e-5, patience=5, restore_best_weights=True)

# Build model
gat_model = GraphAttentionNetwork(node_states, edges, HIDDEN_UNITS, NUM_HEADS, NUM_LAYERS, OUTPUT_DIM)

# Compile model
gat_model.compile(loss=loss_fn, optimizer=optimizer, metrics=[accuracy_fn])

gat_model.fit(
    x=train_indices,
    y=train_labels,
    validation_split=VALIDATION_SPLIT,
    batch_size=BATCH_SIZE,
    epochs=NUM_EPOCHS,
    callbacks=[early_stopping],
    verbose=2,
)

NameError: name 'train_indices' is not defined

In [None]:
_, test_accuracy = gat_model.evaluate(x=test_indices, y=test_labels, verbose=0)

print("--" * 38 + f"\nTest Accuracy {test_accuracy*100:.1f}%")


### Predict (probabilities)

test_probs = gat_model.predict(x=test_indices)

mapping = {v: k for (k, v) in class_idx.items()}

for i, (probs, label) in enumerate(zip(test_probs[:10], test_labels[:10])):
    print(f"Example {i+1}: {mapping[label]}")
    for j, c in zip(probs, class_idx.keys()):
        print(f"\tProbability of {c: <24} = {j*100:7.3f}%")
    print("---" * 20)



## Conclusions

The results look OK! The GAT model seems to correctly predict the subjects of the papers,
based on what they cite, about 80% of the time. Further improvements could be
made by fine-tuning the hyper-parameters of the GAT. For instance, try changing the number of layers,
the number of hidden units, or the optimizer/learning rate; add regularization (e.g., dropout);
or modify the preprocessing step. We could also try to implement *self-loops*
(i.e., paper X cites paper X) and/or make the graph *undirected*.



Original edges count: 15
Original nodes count: 13
Filtered edges count: 14
Filtered nodes count: 13
Scene ID 1352890817715 contains -1 edges. Processing...
Original edges count: 16
Original nodes count: 14
Filtered edges count: 15
Filtered nodes count: 13
Scene ID 1352890814428 contains -1 edges. Processing...
Scene ID 1352890802323 contains -1 edges. Processing...
Original edges count: 23
Original nodes count: 12
Filtered edges count: 22
Filtered nodes count: 11
Scene ID 1352890800322 contains -1 edges. Processing...
Scene ID 1352890875617 contains -1 edges. Processing...
Original edges count: 16
Original nodes count: 13
Filtered edges count: 13
Filtered nodes count: 13
Scene ID 1352890804562 contains -1 edges. Processing...
Original edges count: 14
Original nodes count: 10
Filtered edges count: 13
Filtered nodes count: 9
Scene ID 1352890841688 contains -1 edges. Processing...
Scene ID 1352890837555 contains -1 edges. Processing...
Scene ID 1352890825684 contains -1 edges. Processing.

In [None]:
"""
Title: GAT Regression for Pedestrian Future Position Prediction
Description:
    This script demonstrates how to use a Graph Attention Network (GAT)
    for a regression task over pedestrian trajectory data.

    Each scene is treated as a separate graph. The nodes represent
    pedestrians with features (e.g. current position, previous motion, etc.)
    and the edges represent interactions (or connectivity) between them.

    The model learns to predict the pedestrian's future position, namely
    future_x and future_y one second ahead.

Author: Your Name
Date: 2025-04-13
"""

import os
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import warnings

warnings.filterwarnings("ignore")
np.random.seed(2)

# ------------------------------------------------------------------------------
# Data Loading and Preprocessing
# ------------------------------------------------------------------------------


# Define the dataset directory
dataset_dir = "dataset"


# Function to find all scene IDs in the dataset directory
def find_all_scene_ids(dataset_dir):
    scene_ids = []
    for file in os.listdir(dataset_dir):
        if file.endswith(".edges"):
            scene_id = file.split(".")[0]
            scene_ids.append(scene_id)
    return scene_ids


# Function to load all subgraphs for the found scene IDs
def load_all_subgraphs(dataset_dir):
    scene_ids = find_all_scene_ids(dataset_dir)
    scenes = []

    for scene_id in scene_ids:

        edges_file = os.path.join(dataset_dir, f"{scene_id}.edges")
        nodes_file = os.path.join(dataset_dir, f"{scene_id}.nodes")

        # Check if both files exist
        if not os.path.exists(edges_file) or not os.path.exists(nodes_file):
            print(f"Skipping scene ID {scene_id}: Missing files.")
            continue

        # Load edges
        edges = pd.read_csv(edges_file, sep=",", header=None, names=["target", "source"])

        # Load nodes
        nodes = pd.read_csv(
            nodes_file,
            sep=",",
            header=None,
            names=["node_id", "current_x", "current_y", "previous_x", "previous_y", "future_x", "future_y"],
        )

        for col in nodes.columns:
            nodes[col] = pd.to_numeric(nodes[col], errors="coerce")

        if nodes.isnull().any().any():
            # Step 1: Identify rows with NaN values in nodes_df
            nan_nodes = nodes[nodes.isnull().any(axis=1)]

            # Step 2: Extract the node_id values of those rows
            nan_node_ids = nan_nodes["node_id"].tolist()

            # Step 3: Filter out edges in edges_df where source or target is in nan_node_ids
            # Display the filtered edges
            print(f"Original edges count: {len(edges)}")
            print(f"Original nodes count: {len(nodes)}")
            edges = edges[~edges["source"].isin(nan_node_ids) & ~edges["target"].isin(nan_node_ids)]

            print(f"Filtered edges count: {len(edges)}")
            nodes = nodes.dropna(subset=["future_x", "future_y"])
            print(f"Filtered nodes count: {len(nodes)}")

        # # Filter out edges with -1 as source value
        # edges = edges[edges["source"] != -1]

        # Check if there are any -1 edges
        if (edges["source"] == -1).any() or (edges["target"] == -1).any():
            print(f"Scene ID {scene_id} contains -1 edges. Processing...")

            # Remove edges with -1 as source or target
            edges = edges[(edges["source"] != -1) & (edges["target"] != -1)]

            # Get unique node IDs from the remaining edges
            connected_nodes = pd.unique(edges[["target", "source"]].values.ravel())

            # Filter nodes to keep only those that are connected
            nodes = nodes[nodes["node_id"].isin(connected_nodes)]

        # Store the subgraph
        scenes.append(
            {"scene_id": scene_id, "edges": edges, "nodes": nodes},
        )

    return scenes


# Example usage
scenes = load_all_subgraphs(dataset_dir)
print(f"Loaded {len(scenes)} scenes.")

In [None]:
def aggregate_scenes(scenes):
    """
    Aggregates a list of scene dictionaries into one unified graph.
    Each scene's nodes (and their features/targets) are stacked;
    the edges are adjusted (by offsetting node indices) to create a disjoint graph.

    Returns:
        all_nodes: DataFrame of all nodes (includes features and targets)
        all_edges: np.array of edges as shape (num_edges, 2)
        scene_node_indices: dict mapping scene_id -> array of node indices in all_nodes.
    """
    nodes_list = []
    edges_list = []
    scene_node_indices = {}
    node_offset = 0

    for scene in scenes:
        scene_id = scene["scene_id"]
        nodes_df = scene["nodes"].copy().reset_index(drop=True)
        edges_df = scene["edges"].copy().reset_index(drop=True)

        num_nodes = nodes_df.shape[0]
        # Record indices belonging to this scene (will be used for splitting)
        scene_node_indices[scene_id] = np.arange(node_offset, node_offset + num_nodes)

        # Create a mapping from original node_id to new index
        node_id_to_index = dict(zip(nodes_df["node_id"], range(node_offset, node_offset + num_nodes)))

        # Update edges: replace the original node_id values with the new indices.
        # Note: It is assumed that the edges DataFrame contains columns named "target" and "source".
        def map_id(x):
            return node_id_to_index.get(x, -1)

        edges_df["target"] = edges_df["target"].apply(map_id)
        edges_df["source"] = edges_df["source"].apply(map_id)
        # It is possible that some edges refer to node IDs not included in the nodes DataFrame.
        # Filter out such edges (where mapping returned -1).
        edges_df = edges_df[(edges_df["target"] != -1) & (edges_df["source"] != -1)]

        nodes_list.append(nodes_df)
        edges_list.append(edges_df)

        node_offset += num_nodes

    # Concatenate all nodes
    all_nodes = pd.concat(nodes_list, ignore_index=True)
    # Concatenate and convert edges to numpy array of type int32
    all_edges = pd.concat(edges_list, ignore_index=True).to_numpy().astype(np.int32)
    return all_nodes, all_edges, scene_node_indices


def scene_based_split(scene_node_indices, train_ratio=0.5):
    """
    Splits the scenes into train and test based on scene ids.
    Returns:
        train_indices: numpy array of node indices (all nodes belonging to training scenes)
        test_indices: numpy array of node indices (all nodes belonging to test scenes)
    """
    scene_ids = list(scene_node_indices.keys())
    scene_ids = np.array(scene_ids)
    np.random.shuffle(scene_ids)
    n_train = int(len(scene_ids) * train_ratio)
    train_scenes = scene_ids[:n_train]
    test_scenes = scene_ids[n_train:]

    train_indices = np.concatenate([scene_node_indices[sid] for sid in train_scenes])
    test_indices = np.concatenate([scene_node_indices[sid] for sid in test_scenes])
    return train_indices, test_indices


# ------------------------------------------------------------------------------
# Graph Attention Network (GAT) Model Definition for Regression
# ------------------------------------------------------------------------------


class GraphAttention(layers.Layer):
    def __init__(self, units, kernel_initializer="glorot_uniform", kernel_regularizer=None, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.kernel_initializer = keras.initializers.get(kernel_initializer)
        self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)

    def build(self, input_shape):
        # input_shape[0] is node_features shape; input_shape[1] is edge tensor shape.
        self.kernel = self.add_weight(
            shape=(input_shape[0][-1], self.units),
            trainable=True,
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            name="kernel",
        )
        self.kernel_attention = self.add_weight(
            shape=(self.units * 2, 1),
            trainable=True,
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            name="kernel_attention",
        )
        super().build(input_shape)

    def call(self, inputs):
        node_states, edges = inputs
        # Linear transformation of node states
        node_states_transformed = tf.matmul(node_states, self.kernel)

        # (1) Compute pair-wise attention scores for each edge
        # For each edge, gather the target and source node features and concatenate them.
        target_states = tf.gather(node_states_transformed, edges[:, 0])
        source_states = tf.gather(node_states_transformed, edges[:, 1])
        concat_features = tf.concat([target_states, source_states], axis=-1)
        e = tf.nn.leaky_relu(tf.matmul(concat_features, self.kernel_attention))
        e = tf.squeeze(e, axis=-1)

        # (2) Normalize the attention scores per target node.
        e = tf.exp(tf.clip_by_value(e, -2, 2))
        sum_e = tf.math.unsorted_segment_sum(e, edges[:, 0], num_segments=tf.shape(node_states)[0])
        # Repeat the sums to align with edge dimensions.
        sum_e_rep = tf.gather(sum_e, edges[:, 0])
        attention = e / (sum_e_rep + 1e-9)  # add epsilon to avoid division by zero

        # (3) Weighted sum of source node features
        source_transformed = tf.gather(node_states_transformed, edges[:, 1])
        messages = source_transformed * tf.expand_dims(attention, -1)
        output = tf.math.unsorted_segment_sum(messages, edges[:, 0], num_segments=tf.shape(node_states)[0])
        return output


class MultiHeadGraphAttention(layers.Layer):
    def __init__(self, units, num_heads=8, merge_type="concat", **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.merge_type = merge_type
        self.attention_layers = [GraphAttention(units) for _ in range(num_heads)]

    def call(self, inputs):
        node_states, edges = inputs
        # Gather outputs from each head.
        head_outputs = [att([node_states, edges]) for att in self.attention_layers]
        if self.merge_type == "concat":
            output = tf.concat(head_outputs, axis=-1)
        else:
            output = tf.reduce_mean(tf.stack(head_outputs, axis=-1), axis=-1)
        return tf.nn.relu(output)


class GraphAttentionNetwork(keras.Model):
    def __init__(self, node_states, edges, hidden_units, num_heads, num_layers, output_dim, **kwargs):
        super().__init__(**kwargs)
        self.node_states = node_states
        self.edges = edges
        self.preprocess = layers.Dense(hidden_units * num_heads, activation="relu")
        self.attention_layers = [MultiHeadGraphAttention(hidden_units, num_heads) for _ in range(num_layers)]
        self.output_layer = layers.Dense(output_dim)

    def call(self, inputs=None, training=None):
        # inputs is not used because self.node_states and self.edges are pre-stored.
        x = self.preprocess(self.node_states)
        for att_layer in self.attention_layers:
            x = att_layer([x, self.edges]) + x  # residual connection
        outputs = self.output_layer(x)
        return outputs

    def train_step(self, data):
        indices, labels = data
        with tf.GradientTape() as tape:
            outputs = self(None, training=True)  # call without external inputs
            predictions = tf.gather(outputs, indices)
            loss = self.compiled_loss(labels, predictions, regularization_losses=self.losses)
        grads = tape.gradient(loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.compiled_metrics.update_state(labels, predictions)
        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        indices, labels = data
        outputs = self(None, training=False)
        predictions = tf.gather(outputs, indices)
        loss = self.compiled_loss(labels, predictions, regularization_losses=self.losses)
        self.compiled_metrics.update_state(labels, predictions)
        results = {m.name: m.result() for m in self.metrics}
        results["loss"] = loss
        return results

    def predict_step(self, data):
        # data is just indices
        outputs = self(None, training=False)
        predictions = tf.gather(outputs, data)
        return predictions

In [22]:


# ------------------------------------------------------------------------------
# Main: Data Preparation, Model Training, and Evaluation
# ------------------------------------------------------------------------------

if __name__ == "__main__":
    # Load and aggregate scene data
    # scenes = load_data()
    all_nodes, all_edges, scene_node_indices = aggregate_scenes(scenes)

    # Split node indices by scene (scene-based split)
    train_indices, test_indices = scene_based_split(scene_node_indices, train_ratio=0.5)

    # Select input features and targets:
    # Assume that the input features are all columns except 'node_id', 'future_x', and 'future_y'
    feature_cols = [col for col in all_nodes.columns if col not in ["node_id", "future_x", "future_y"]]
    target_cols = ["future_x", "future_y"]

    # Prepare numpy arrays
    node_features_np = all_nodes[feature_cols].to_numpy().astype(np.float32)
    targets_np = all_nodes[target_cols].to_numpy().astype(np.float32)

    print("Aggregated nodes shape:", node_features_np.shape)
    print("Aggregated edges shape:", all_edges.shape)
    print("Training nodes:", train_indices.shape, "Test nodes:", test_indices.shape)

    # Convert aggregated graph data to tensors
    node_features_tensor = tf.convert_to_tensor(node_features_np)
    edges_tensor = tf.convert_to_tensor(all_edges)

    # Define hyper-parameters
    HIDDEN_UNITS = 100
    NUM_HEADS = 8
    NUM_LAYERS = 3
    OUTPUT_DIM = 2  # predicting future_x and future_y
    NUM_EPOCHS = 100
    BATCH_SIZE = 16  # batch size here relates to how many nodes to sample per update
    LEARNING_RATE = 1e-2

    # Build the model
    gat_model = GraphAttentionNetwork(
        node_states=node_features_tensor,
        edges=edges_tensor,
        hidden_units=HIDDEN_UNITS,
        num_heads=NUM_HEADS,
        num_layers=NUM_LAYERS,
        output_dim=OUTPUT_DIM,
    )

    # Compile with Mean Squared Error loss and mean absolute error metric
    gat_model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=LEARNING_RATE),
        loss=keras.losses.MeanSquaredError(),
        metrics=[
            keras.metrics.MeanAbsoluteError(),
            keras.metrics.MeanSquaredError(),
            keras.metrics.R2Score(),
            keras.metrics.CosineSimilarity(axis=1),
        ],
    )

    # For training, we use the node indices and corresponding targets.
    # Create tf.data.Dataset objects.
    train_dataset = tf.data.Dataset.from_tensor_slices((train_indices, targets_np[train_indices]))
    train_dataset = train_dataset.shuffle(buffer_size=len(train_indices)).batch(BATCH_SIZE)

    test_dataset = tf.data.Dataset.from_tensor_slices((test_indices, targets_np[test_indices]))
    test_dataset = test_dataset.batch(BATCH_SIZE)

    print("Training...")
    gat_model.fit(train_dataset, epochs=NUM_EPOCHS, verbose=2)


Aggregated nodes shape: (1669, 4)
Aggregated edges shape: (2721, 2)
Training nodes: (818,) Test nodes: (851,)
Training...
Epoch 1/100
52/52 - 14s - 261ms/step - cosine_similarity: 0.7633 - mean_absolute_error: 16672.0645 - mean_squared_error: 8651210752.0000 - r2_score: -6.6249e+01 - loss: 7407.4155
Epoch 2/100
52/52 - 5s - 106ms/step - cosine_similarity: 0.9829 - mean_absolute_error: 2394.4409 - mean_squared_error: 26943594.0000 - r2_score: 0.6792 - loss: 5768.7227
Epoch 3/100
52/52 - 6s - 121ms/step - cosine_similarity: 0.9890 - mean_absolute_error: 1311.3588 - mean_squared_error: 5644909.5000 - r2_score: 0.9478 - loss: 5761.8594
Epoch 4/100
52/52 - 6s - 117ms/step - cosine_similarity: 0.9951 - mean_absolute_error: 968.3707 - mean_squared_error: 2575518.0000 - r2_score: 0.9729 - loss: 5723.0918
Epoch 5/100
52/52 - 6s - 121ms/step - cosine_similarity: 0.9969 - mean_absolute_error: 892.8805 - mean_squared_error: 1608296.3750 - r2_score: 0.9822 - loss: 5783.6216
Epoch 6/100
52/52 - 6s -

In [23]:
print("Evaluating on test set...")
results = gat_model.evaluate(test_dataset, verbose=2)
print(f"\nTest Loss (MSE): {results[0]:.4f}, Test MAE: {results[1]["mean_absolute_error"]:.4f}")

# Run predictions on test nodes
print("\nSample predictions for test nodes:")
predictions = gat_model.predict(tf.convert_to_tensor(test_indices))
for i, idx in enumerate(test_indices[:5]):
    print(
        f"Node {idx}: True future_x={targets_np[idx,0]:.1f}, future_y={targets_np[idx,1]:.1f} | Predicted future_x={predictions[i,0]:.1f}, future_y={predictions[i,1]:.1f}"
    )

Evaluating on test set...
54/54 - 2s - 45ms/step - cosine_similarity: 0.9996 - mean_absolute_error: 285.4763 - mean_squared_error: 289294.4062 - r2_score: 0.9953 - loss: 43330.5234

Test Loss (MSE): 43330.5234, Test MAE: 285.4763

Sample predictions for test nodes:
[1m27/27[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 54ms/step
Node 1216: True future_x=16055.0, future_y=-10549.0 | Predicted future_x=15787.4, future_y=-9927.2
Node 1217: True future_x=15966.0, future_y=-9603.0 | Predicted future_x=15935.6, future_y=-9025.5
Node 1218: True future_x=10766.0, future_y=-5762.0 | Predicted future_x=10590.6, future_y=-5331.0
Node 1219: True future_x=9303.0, future_y=1395.0 | Predicted future_x=9299.6, future_y=1411.1
Node 1220: True future_x=10456.0, future_y=-4786.0 | Predicted future_x=10381.8, future_y=-4295.6
