In [6]:
"""
Title: GAT Regression for Pedestrian Future Position Prediction
Description:
    This script demonstrates how to use a Graph Attention Network (GAT)
    for a regression task over pedestrian trajectory data.

    Each scene is treated as a separate graph. The nodes represent
    pedestrians with features (e.g. current position, previous motion, etc.)
    and the edges represent interactions (or connectivity) between them.

    The model learns to predict the pedestrian's future position, namely
    future_x and future_y one second ahead.

Author: Your Name
Date: 2025-04-13
"""

import os
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import warnings

warnings.filterwarnings("ignore")
np.random.seed(2)

# ------------------------------------------------------------------------------
# Data Loading and Preprocessing
# ------------------------------------------------------------------------------


# Define the dataset directory
dataset_dir = "dataset"


# Function to find all scene IDs in the dataset directory
def find_all_scene_ids(dataset_dir):
    scene_ids = []
    for file in os.listdir(dataset_dir):
        if file.endswith(".edges"):
            scene_id = file.split(".")[0]
            scene_ids.append(scene_id)
    return scene_ids


# Function to load all subgraphs for the found scene IDs
def load_all_subgraphs(dataset_dir):
    scene_ids = find_all_scene_ids(dataset_dir)
    scenes = []

    for scene_id in scene_ids:

        edges_file = os.path.join(dataset_dir, f"{scene_id}.edges")
        nodes_file = os.path.join(dataset_dir, f"{scene_id}.nodes")

        # Check if both files exist
        if not os.path.exists(edges_file) or not os.path.exists(nodes_file):
            print(f"Skipping scene ID {scene_id}: Missing files.")
            continue

        # Load edges
        edges = pd.read_csv(edges_file, sep=",", header=None, names=["target", "source"])

        # Load nodes
        nodes = pd.read_csv(
            nodes_file,
            sep=",",
            header=None,
            names=["node_id", "current_x", "current_y", "previous_x", "previous_y", "future_x", "future_y"],
        )

        for col in nodes.columns:
            nodes[col] = pd.to_numeric(nodes[col], errors="coerce")

        if nodes.isnull().any().any():
            # Step 1: Identify rows with NaN values in nodes_df
            nan_nodes = nodes[nodes.isnull().any(axis=1)]

            # Step 2: Extract the node_id values of those rows
            nan_node_ids = nan_nodes["node_id"].tolist()

            # Step 3: Filter out edges in edges_df where source or target is in nan_node_ids
            # Display the filtered edges
            print(f"Original edges count: {len(edges)}")
            print(f"Original nodes count: {len(nodes)}")
            edges = edges[~edges["source"].isin(nan_node_ids) & ~edges["target"].isin(nan_node_ids)]

            print(f"Filtered edges count: {len(edges)}")
            nodes = nodes.dropna(subset=["future_x", "future_y"])
            print(f"Filtered nodes count: {len(nodes)}")

        # # Filter out edges with -1 as source value
        # edges = edges[edges["source"] != -1]

        # Check if there are any -1 edges
        if (edges["source"] == -1).any() or (edges["target"] == -1).any():
            print(f"Scene ID {scene_id} contains -1 edges. Processing...")

            # Remove edges with -1 as source or target
            edges = edges[(edges["source"] != -1) & (edges["target"] != -1)]

            # Get unique node IDs from the remaining edges
            connected_nodes = pd.unique(edges[["target", "source"]].values.ravel())

            # Filter nodes to keep only those that are connected
            nodes = nodes[nodes["node_id"].isin(connected_nodes)]

        # Store the subgraph
        scenes.append(
            {"scene_id": scene_id, "edges": edges, "nodes": nodes},
        )

    return scenes


# Example usage
scenes = load_all_subgraphs(dataset_dir)
print(f"Loaded {len(scenes)} scenes.")

Original edges count: 15
Original nodes count: 13
Filtered edges count: 14
Filtered nodes count: 13
Scene ID 1352890817715 contains -1 edges. Processing...
Original edges count: 16
Original nodes count: 14
Filtered edges count: 15
Filtered nodes count: 13
Scene ID 1352890814428 contains -1 edges. Processing...
Scene ID 1352890802323 contains -1 edges. Processing...
Original edges count: 23
Original nodes count: 12
Filtered edges count: 22
Filtered nodes count: 11
Scene ID 1352890800322 contains -1 edges. Processing...
Scene ID 1352890875617 contains -1 edges. Processing...
Original edges count: 16
Original nodes count: 13
Filtered edges count: 13
Filtered nodes count: 13
Scene ID 1352890804562 contains -1 edges. Processing...
Original edges count: 14
Original nodes count: 10
Filtered edges count: 13
Filtered nodes count: 9
Scene ID 1352890841688 contains -1 edges. Processing...
Scene ID 1352890837555 contains -1 edges. Processing...
Scene ID 1352890825684 contains -1 edges. Processing.

In [7]:
def aggregate_scenes(scenes):
    """
    Aggregates a list of scene dictionaries into one unified graph.
    Each scene's nodes (and their features/targets) are stacked;
    the edges are adjusted (by offsetting node indices) to create a disjoint graph.

    Returns:
        all_nodes: DataFrame of all nodes (includes features and targets)
        all_edges: np.array of edges as shape (num_edges, 2)
        scene_node_indices: dict mapping scene_id -> array of node indices in all_nodes.
    """
    nodes_list = []
    edges_list = []
    scene_node_indices = {}
    node_offset = 0

    for scene in scenes:
        scene_id = scene["scene_id"]
        nodes_df = scene["nodes"].copy().reset_index(drop=True)
        edges_df = scene["edges"].copy().reset_index(drop=True)

        num_nodes = nodes_df.shape[0]
        # Record indices belonging to this scene (will be used for splitting)
        scene_node_indices[scene_id] = np.arange(node_offset, node_offset + num_nodes)

        # Create a mapping from original node_id to new index
        node_id_to_index = dict(zip(nodes_df["node_id"], range(node_offset, node_offset + num_nodes)))

        # Update edges: replace the original node_id values with the new indices.
        # Note: It is assumed that the edges DataFrame contains columns named "target" and "source".
        def map_id(x):
            return node_id_to_index.get(x, -1)

        edges_df["target"] = edges_df["target"].apply(map_id)
        edges_df["source"] = edges_df["source"].apply(map_id)
        # It is possible that some edges refer to node IDs not included in the nodes DataFrame.
        # Filter out such edges (where mapping returned -1).
        edges_df = edges_df[(edges_df["target"] != -1) & (edges_df["source"] != -1)]

        nodes_list.append(nodes_df)
        edges_list.append(edges_df)

        node_offset += num_nodes

    # Concatenate all nodes
    all_nodes = pd.concat(nodes_list, ignore_index=True)
    # Concatenate and convert edges to numpy array of type int32
    all_edges = pd.concat(edges_list, ignore_index=True).to_numpy().astype(np.int32)
    return all_nodes, all_edges, scene_node_indices


def scene_based_split(scene_node_indices, train_ratio=0.5):
    """
    Splits the scenes into train and test based on scene ids.
    Returns:
        train_indices: numpy array of node indices (all nodes belonging to training scenes)
        test_indices: numpy array of node indices (all nodes belonging to test scenes)
    """
    scene_ids = list(scene_node_indices.keys())
    scene_ids = np.array(scene_ids)
    np.random.shuffle(scene_ids)
    n_train = int(len(scene_ids) * train_ratio)
    train_scenes = scene_ids[:n_train]
    test_scenes = scene_ids[n_train:]

    train_indices = np.concatenate([scene_node_indices[sid] for sid in train_scenes])
    test_indices = np.concatenate([scene_node_indices[sid] for sid in test_scenes])
    return train_indices, test_indices

In [11]:
# ------------------------------------------------------------------------------
# Graph Attention Network (GAT) Model Definition for Regression
# ------------------------------------------------------------------------------


class GraphAttention(layers.Layer):
    def __init__(self, units, kernel_initializer="glorot_uniform", kernel_regularizer=None, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.kernel_initializer = keras.initializers.get(kernel_initializer)
        self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)

    def build(self, input_shape):
        # input_shape[0] is node_features shape; input_shape[1] is edge tensor shape.
        self.kernel = self.add_weight(
            shape=(input_shape[0][-1], self.units),
            trainable=True,
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            name="kernel",
        )
        self.kernel_attention = self.add_weight(
            shape=(self.units * 2, 1),
            trainable=True,
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            name="kernel_attention",
        )
        # super().build(input_shape)
        self.built = True # Original

    def call(self, inputs):
        node_states, edges = inputs
        # Linear transformation of node states
        node_states_transformed = tf.matmul(node_states, self.kernel)

        # (1) Compute pair-wise attention scores for each edge
        # For each edge, gather the target and source node features and concatenate them.
        target_states = tf.gather(node_states_transformed, edges[:, 0])
        source_states = tf.gather(node_states_transformed, edges[:, 1])
        concat_features = tf.concat([target_states, source_states], axis=-1)
        e = tf.nn.leaky_relu(tf.matmul(concat_features, self.kernel_attention))
        e = tf.squeeze(e, axis=-1)

        # (2) Normalize the attention scores per target node.
        e = tf.exp(tf.clip_by_value(e, -2, 2))
        sum_e = tf.math.unsorted_segment_sum(e, edges[:, 0], num_segments=tf.shape(node_states)[0])
        # Repeat the sums to align with edge dimensions.
        sum_e_rep = tf.gather(sum_e, edges[:, 0])
        attention = e / (sum_e_rep + 1e-9)  # add epsilon to avoid division by zero

        # (3) Weighted sum of source node features
        source_transformed = tf.gather(node_states_transformed, edges[:, 1])
        messages = source_transformed * tf.expand_dims(attention, -1)
        output = tf.math.unsorted_segment_sum(messages, edges[:, 0], num_segments=tf.shape(node_states)[0])
        return output


class MultiHeadGraphAttention(layers.Layer):
    def __init__(self, units, num_heads=8, merge_type="concat", **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.merge_type = merge_type
        self.attention_layers = [GraphAttention(units) for _ in range(num_heads)]

    def call(self, inputs):
        node_states, edges = inputs
        # Gather outputs from each head.
        head_outputs = [att([node_states, edges]) for att in self.attention_layers]
        if self.merge_type == "concat":
            output = tf.concat(head_outputs, axis=-1)
        else:
            output = tf.reduce_mean(tf.stack(head_outputs, axis=-1), axis=-1)
        return tf.nn.relu(output)


class GraphAttentionNetwork(keras.Model):
    def __init__(self, node_states, edges, hidden_units, num_heads, num_layers, output_dim, **kwargs):
        super().__init__(**kwargs)
        self.node_states = node_states
        self.edges = edges
        self.preprocess = layers.Dense(hidden_units * num_heads, activation="relu")
        self.attention_layers = [MultiHeadGraphAttention(hidden_units, num_heads) for _ in range(num_layers)]
        self.output_layer = layers.Dense(output_dim)

    def call(self, inputs):
        # node_states, edges = inputs
        x = self.preprocess(self.node_states)
        for att_layer in self.attention_layers:
            x = att_layer([x, self.edges]) + x
        outputs = self.output_layer(x)
        return outputs

    def train_step(self, data):
        indices, labels = data
        with tf.GradientTape() as tape:
            outputs = self(None, training=True)  # call without external inputs
            predictions = tf.gather(outputs, indices)
            loss = self.compiled_loss(labels, predictions, regularization_losses=self.losses)
        grads = tape.gradient(loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.compiled_metrics.update_state(labels, predictions)
        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        indices, labels = data
        outputs = self(None, training=False)
        predictions = tf.gather(outputs, indices)
        loss = self.compiled_loss(labels, predictions, regularization_losses=self.losses)
        self.compiled_metrics.update_state(labels, predictions)
        results = {m.name: m.result() for m in self.metrics}
        results["loss"] = loss
        return results

    def predict_step(self, data):
        # data is just indices
        outputs = self(None, training=False)
        predictions = tf.gather(outputs, data)
        return predictions

In [21]:
# ------------------------------------------------------------------------------
# Main: Data Preparation, Model Training, and Evaluation
# ------------------------------------------------------------------------------

if __name__ == "__main__":
    # Load and aggregate scene data
    # scenes = load_data()
    all_nodes, all_edges, scene_node_indices = aggregate_scenes(scenes)

    # Split node indices by scene (scene-based split)
    train_indices, test_indices = scene_based_split(scene_node_indices, train_ratio=0.8)

    # Select input features and targets:
    # Assume that the input features are all columns except 'node_id', 'future_x', and 'future_y'
    feature_cols = [col for col in all_nodes.columns if col not in ["node_id", "future_x", "future_y"]]
    target_cols = ["future_x", "future_y"]

    # Prepare numpy arrays
    node_features_np = all_nodes[feature_cols].to_numpy().astype(np.float32)
    targets_np = all_nodes[target_cols].to_numpy().astype(np.float32)

    print("Aggregated nodes shape:", node_features_np.shape)
    print("Aggregated edges shape:", all_edges.shape)
    print("Training nodes:", train_indices.shape, "Test nodes:", test_indices.shape)

    # Convert aggregated graph data to tensors
    node_features_tensor = tf.convert_to_tensor(node_features_np)
    edges_tensor = tf.convert_to_tensor(all_edges)

    # Define hyper-parameters
    HIDDEN_UNITS = 100
    NUM_HEADS = 8
    NUM_LAYERS = 3
    OUTPUT_DIM = 2
    NUM_EPOCHS = 100
    BATCH_SIZE = 256  # batch size here relates to how many nodes to sample per update
    VALIDATION_SPLIT = 0.1 # Original
    # LEARNING_RATE = 3e-1  # Original
    LEARNING_RATE = 1e-2
    MOMENTUM = 0.9

    # Build the model
    gat_model = GraphAttentionNetwork(
        node_states=node_features_tensor,
        edges=edges_tensor,
        hidden_units=HIDDEN_UNITS,
        num_heads=NUM_HEADS,
        num_layers=NUM_LAYERS,
        output_dim=OUTPUT_DIM,
    )

    # Compile with Mean Squared Error loss and mean absolute error metric
    gat_model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=LEARNING_RATE),
        # optimizer=keras.optimizers.SGD(LEARNING_RATE, momentum=MOMENTUM), # Original
        loss=keras.losses.MeanSquaredError(),
        metrics=[
            keras.metrics.MeanAbsoluteError(),
            keras.metrics.MeanSquaredError(),
            keras.metrics.R2Score(),
            keras.metrics.CosineSimilarity(axis=1),
        ],
    )

    # For training, we use the node indices and corresponding targets.
    # Create tf.data.Dataset objects.
    train_dataset = tf.data.Dataset.from_tensor_slices((train_indices, targets_np[train_indices]))
    train_dataset = train_dataset.shuffle(buffer_size=len(train_indices)).batch(BATCH_SIZE)

    test_dataset = tf.data.Dataset.from_tensor_slices((test_indices, targets_np[test_indices]))
    test_dataset = test_dataset.batch(BATCH_SIZE)

    early_stopping = keras.callbacks.EarlyStopping(monitor="val_acc", min_delta=1e-5, patience=5, restore_best_weights=True)

    print("Training...")
    gat_model.fit(
        train_dataset,
        # validation_split=VALIDATION_SPLIT,
        batch_size=BATCH_SIZE,
        epochs=NUM_EPOCHS,
        # callbacks=[early_stopping],
        verbose=2,
    )

Aggregated nodes shape: (1669, 4)
Aggregated edges shape: (2721, 2)
Training nodes: (1334,) Test nodes: (335,)
Training...
Epoch 1/100
6/6 - 11s - 2s/step - cosine_similarity: 0.1847 - mean_absolute_error: 70126.7266 - mean_squared_error: 27815884800.0000 - r2_score: -4.3921e+02 - loss: -4.6224e+03
Epoch 2/100
6/6 - 1s - 201ms/step - cosine_similarity: 0.5375 - mean_absolute_error: 18008.6426 - mean_squared_error: 502028064.0000 - r2_score: -5.5334e+00 - loss: 6663.4189
Epoch 3/100
6/6 - 1s - 174ms/step - cosine_similarity: 0.9766 - mean_absolute_error: 7197.7285 - mean_squared_error: 97442416.0000 - r2_score: -1.4452e-01 - loss: 4745.5459
Epoch 4/100
6/6 - 1s - 168ms/step - cosine_similarity: 0.9757 - mean_absolute_error: 4069.2561 - mean_squared_error: 34141792.0000 - r2_score: 0.5856 - loss: 5267.4731
Epoch 5/100
6/6 - 1s - 161ms/step - cosine_similarity: 0.9843 - mean_absolute_error: 3084.0432 - mean_squared_error: 19159504.0000 - r2_score: 0.7684 - loss: 5932.2534
Epoch 6/100
6/6 

In [22]:
print("Evaluating on test set...")
results = gat_model.evaluate(test_dataset, verbose=2)
print(f"\nTest Loss (MSE): {results[0]:.4f}, Test MAE: {results[1]["mean_absolute_error"]:.4f}")

# Run predictions on test nodes
print("\nSample predictions for test nodes:")
predictions = gat_model.predict(tf.convert_to_tensor(test_indices))
for i, idx in enumerate(test_indices[:5]):
    print(
        f"Node {idx}: True future_x={targets_np[idx,0]:.1f}, future_y={targets_np[idx,1]:.1f} | Predicted future_x={predictions[i,0]:.1f}, future_y={predictions[i,1]:.1f}"
    )

Evaluating on test set...
2/2 - 1s - 510ms/step - cosine_similarity: 0.9990 - mean_absolute_error: 595.9611 - mean_squared_error: 621791.4375 - r2_score: 0.9948 - loss: 730466.1250

Test Loss (MSE): 730466.1250, Test MAE: 595.9611

Sample predictions for test nodes:
[1m11/11[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 98ms/step
Node 656: True future_x=24892.0, future_y=-15408.0 | Predicted future_x=25277.5, future_y=-15701.5
Node 657: True future_x=25107.0, future_y=-14769.0 | Predicted future_x=24763.0, future_y=-14495.4
Node 658: True future_x=18820.0, future_y=-12470.0 | Predicted future_x=18283.8, future_y=-12044.5
Node 659: True future_x=21316.0, future_y=-13596.0 | Predicted future_x=21828.3, future_y=-13790.7
Node 660: True future_x=18970.0, future_y=-11871.0 | Predicted future_x=18683.3, future_y=-11388.2
