In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import pandas as pd
import os
import warnings
import glob

warnings.filterwarnings("ignore")
pd.set_option("display.max_columns", 6)
pd.set_option("display.max_rows", 6)
np.random.seed(2)

2025-04-15 16:08:21.375976: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


## Load the data and process it

In [43]:
"""
As in the tutorial, the provided dataset consists of two files for each traffic scene:

<scene_id>.edges
two columns containing node IDs
target, source
Note: The tutorial models directed edges with source -> target.
You can either use undirected edges by changing the implementation or adding the missing entries to the edges file,
e.g., to the line target, source, you add the line source target. If you want to be more fancy, you could also try to infer
which other pedestrians the source node can see in their field of view and only add those (this would model that the movement
decisions are based only on the pedestrians in the field of view.)
<scene_id>.nodes
seven columns with node properties and target values, which should be predicted 
node id, current x, current y, previous x, previous y, future x, future y
the previous x and y represents the location of the pedestrian 1 second ago (you can use those values directly or infer the
movement direction and some speed estimate yourself)
the future x and y represents the target value, i.e., the location where the pedestrian will be in 1 second
Note: Some pedestrians do not have a future x and y coordinate, so you need to filter those for prediction. However, you can
still use their current and previous location when predicting the future location of other pedestrians.
"""
"""
Sample data: 
file: dataset/13528908058.edges:
contains data: target, source
19585800, 19590700
19585800, 19595200
19585800, 20000100
19590700, 19595200
19590700, 20000100
19591900, 19594200
19591900, 19595300
19591900, 19595800
19592201, 19595800
19592400, 20000200
19592800, 20000200
19592800, 20000300
19594200, 19595300
19594200, 19595800
19595200, 20000100
19595300, 19595800
20000200, 20000300
19502500, -1

corresponding nodes file: 
file: dataset/13528908058.nodes
contains data: node id, current x, current y, previous x, previous y, future x, future y
19502500,40050.0,-16544.0,40176.0,-16619.0,40205.0,-16357.0
19585800,16802.0,-11108.0,16140.0,-10573.0,17831.0,-11792.0
19590700,16846.0,-10526.0,16079.0,-9694.0,17528.0,-11131.0
19591900,11346.0,-6253.0,10833.0,-5840.0,12184.0,-6936.0
19592201,14232.0,-8556.0,13610.0,-7856.0,14867.0,-9359.0
19592400,5649.0,191.0,6542.0,-779.0,4809.0,1245.0
19592800,9097.0,1278.0,9323.0,1412.0,9221.0,1551.0
19594200,11262.0,-5387.0,10468.0,-4863.0,12387.0,-6358.0
19595200,18425.0,-11390.0,17495.0,-10254.0,20034.0,-12398.0
19595300,11060.0,-5800.0,10335.0,-5510.0,11890.0,-6411.0
19595800,12432.0,-7962.0,_,_,12267.0,-7515.0
20000100,18149.0,-10095.0,18159.0,-9697.0,17792.0,-10597.0
20000200,7989.0,-70.0,7759.0,1511.0,8475.0,-1480.0
20000300,8677.0,41.0,8353.0,1315.0,9280.0,-1068.0
""" 

"""
Sample data: 
file: dataset/13528908058.edges:
contains data: target, source
19585800, 19590700
19585800, 19595200
19585800, 20000100
19590700, 19595200
19590700, 20000100
19591900, 19594200
19591900, 19595300
19591900, 19595800
19592201, 19595800
19592400, 20000200
19592800, 20000200
19592800, 20000300
19594200, 19595300
19594200, 19595800
19595200, 20000100
19595300, 19595800
20000200, 20000300
19502500, -1

corresponding nodes file: 
file: dataset/13528908058.nodes
contains data: node id, current x, current y, previous x, previous y, future x, future y
19502500,40050.0,-16544.0,40176.0,-16619.0,40205.0,-16357.0
19585800,16802.0,-11108.0,16140.0,-10573.0,17831.0,-11792.0
19590700,16846.0,-10526.0,16079.0,-9694.0,17528.0,-11131.0
19591900,11346.0,-6253.0,10833.0,-5840.0,12184.0,-6936.0
19592201,14232.0,-8556.0,13610.0,-7856.0,14867.0,-9359.0
19592400,5649.0,191.0,6542.0,-779.0,4809.0,1245.0
19592800,9097.0,1278.0,9323.0,1412.0,9221.0,1551.0
19594200,11262.0,-5387.0,10468.0,-4863.0,12387.0,-6358.0
19595200,18425.0,-11390.0,17495.0,-10254.0,20034.0,-12398.0
19595300,11060.0,-5800.0,10335.0,-5510.0,11890.0,-6411.0
19595800,12432.0,-7962.0,_,_,12267.0,-7515.0
20000100,18149.0,-10095.0,18159.0,-9697.0,17792.0,-10597.0
20000200,7989.0,-70.0,7759.0,1511.0,8475.0,-1480.0
20000300,8677.0,41.0,8353.0,1315.0,9280.0,-1068.0
""" 

dataset_path = "dataset/"

# Function to load all scenes while keeping each scene as a separate graph
def load_all_scenes():
    scene_ids = [file.split("/")[-1].split(".")[0] for file in glob.glob(f"{dataset_path}*.nodes")]
    
    all_scenes = []
    for scene_id in scene_ids:
        edges_file = os.path.join(dataset_path, f"{scene_id}.edges")
        nodes_file = os.path.join(dataset_path, f"{scene_id}.nodes")

        edges_df = pd.read_csv(edges_file, header=None, names=["target", "source"], na_values="_")
        nodes_df = pd.read_csv(
            nodes_file,
            header=None,
            names=["node_id", "current_x", "current_y", "prev_x", "prev_y", "future_x", "future_y"],
            na_values="_"
        )
        
        # Convert "_" to NaN
        nodes_df = nodes_df.replace('_', np.nan)
        edges_df = edges_df.replace('_', np.nan)
        
        # Filter out nodes with missing future positions
        nodes_df = nodes_df.dropna(subset=["future_x", "future_y"])
        
        # Create mapping for node ids within this scene
        node_id_to_index = {node_id: idx for idx, node_id in enumerate(nodes_df["node_id"].values)}
        
        # Process edges using the node id mapping
        edges_df = edges_df.dropna()
        edges_df['target'] = edges_df['target'].apply(lambda x: node_id_to_index.get(x, -1))
        edges_df['source'] = edges_df['source'].apply(lambda x: node_id_to_index.get(x, -1))
        edges_df = edges_df[(edges_df['target'] != -1) & (edges_df['source'] != -1)]

        # Calculate motion features
        # nodes_df["dir_x"] = nodes_df["current_x"] - nodes_df["prev_x"]
        # nodes_df["dir_y"] = nodes_df["current_y"] - nodes_df["prev_y"]
        # nodes_df["speed"] = np.sqrt(nodes_df["dir_x"]**2 + nodes_df["dir_y"]**2)
        
        # Extract features and labels
        node_features = nodes_df[["current_x", "current_y", "prev_x", "prev_y"]].values
        labels = nodes_df[["future_x", "future_y"]].values
        edges = edges_df[["target", "source"]].values
        
        # Store as a structured scene
        all_scenes.append({
            "scene_id": scene_id,
            "node_features": tf.convert_to_tensor(node_features, dtype=tf.float32),
            "edges": tf.convert_to_tensor(edges, dtype=tf.int32),
            "labels": tf.convert_to_tensor(labels, dtype=tf.float32),
            "node_indices": np.arange(len(nodes_df)),
        })
    
    return all_scenes

scenes = load_all_scenes()

# np.random.seed(2)
np.random.shuffle(scenes)

""" Om man vill ha 70% Training set """
train_scenes = scenes[:int(0.7 * len(scenes))]
val_scenes = scenes[int(0.7 * len(scenes)):int(0.9 * len(scenes))]
test_scenes = scenes[int(0.9 * len(scenes)):]

""" Om man vill ha 50% Training set """
# train_scenes = scenes[:int(0.5 * len(scenes))]
# val_scenes = scenes[int(0.5 * len(scenes)):int(0.9 * len(scenes))]
# test_scenes = scenes[int(0.9 * len(scenes)):]

print(f"Total scenes: {len(scenes)}")
print(f"Training scenes: {len(train_scenes)}")
print(f"Validation scenes: {len(val_scenes)}")
print(f"Test scenes: {len(test_scenes)}")

Total scenes: 193
Training scenes: 135
Validation scenes: 38
Test scenes: 20


## GAT model implementation

In [44]:
class GraphAttention(layers.Layer):
    def __init__(
        self,
        units,
        kernel_initializer="glorot_uniform",
        kernel_regularizer=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.units = units
        self.kernel_initializer = keras.initializers.get(kernel_initializer)
        self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)

    def build(self, input_shape):
        self.kernel = self.add_weight(
            shape=(input_shape[0][-1], self.units),
            trainable=True,
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            name="kernel",
        )
        self.kernel_attention = self.add_weight(
            shape=(self.units * 2, 1),
            trainable=True,
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            name="kernel_attention",
        )
        self.built = True

    def call(self, inputs):
        node_states, edges = inputs

        # Linearly transform node states
        node_states_transformed = tf.matmul(node_states, self.kernel)

        # (1) Compute pair-wise attention scores
        node_states_expanded = tf.gather(node_states_transformed, edges)
        node_states_expanded = tf.reshape(
            node_states_expanded, (tf.shape(edges)[0], -1)
        )
        attention_scores = tf.nn.leaky_relu(
            tf.matmul(node_states_expanded, self.kernel_attention)
        )
        attention_scores = tf.squeeze(attention_scores, -1)

        # (2) Normalize attention scores
        attention_scores = tf.math.exp(tf.clip_by_value(attention_scores, -2, 2))
        attention_scores_sum = tf.math.unsorted_segment_sum(
            data=attention_scores,
            segment_ids=edges[:, 0],
            num_segments=tf.reduce_max(edges[:, 0]) + 1,
        )
        attention_scores_sum = tf.repeat(
            attention_scores_sum, tf.math.bincount(tf.cast(edges[:, 0], "int32"))
        )
        attention_scores_norm = attention_scores / attention_scores_sum

        # (3) Gather node states of neighbors, apply attention scores and aggregate
        node_states_neighbors = tf.gather(node_states_transformed, edges[:, 1])
        out = tf.math.unsorted_segment_sum(
            data=node_states_neighbors * attention_scores_norm[:, tf.newaxis],
            segment_ids=edges[:, 0],
            num_segments=tf.shape(node_states)[0],
        )
        return out


class MultiHeadGraphAttention(layers.Layer):
    def __init__(self, units, num_heads=8, merge_type="concat", **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.merge_type = merge_type
        self.attention_layers = [GraphAttention(units) for _ in range(num_heads)]

    def call(self, inputs):
        atom_features, pair_indices = inputs

        # Obtain outputs from each attention head
        outputs = [
            attention_layer([atom_features, pair_indices])
            for attention_layer in self.attention_layers
        ]
        # Concatenate or average the node states from each head
        if self.merge_type == "concat":
            outputs = tf.concat(outputs, axis=-1)
        else:
            outputs = tf.reduce_mean(tf.stack(outputs, axis=-1), axis=-1)
        # Activate and return node states
        return tf.nn.relu(outputs)


class GraphAttentionNetwork(keras.Model):
    def __init__(
        self,
        node_states,
        edges,
        hidden_units,
        num_heads,
        num_layers,
        output_dim,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.node_states = node_states
        self.edges = edges
        self.preprocess = layers.Dense(hidden_units * num_heads, activation="relu")
        self.attention_layers = [
            MultiHeadGraphAttention(hidden_units, num_heads) for _ in range(num_layers)
        ]
        self.output_layer = layers.Dense(output_dim)

    def call(self, inputs):
        node_states, edges = inputs
        x = self.preprocess(node_states)
        for attention_layer in self.attention_layers:
            x = attention_layer([x, edges]) + x
        outputs = self.output_layer(x)
        return outputs

    def train_step(self, data):
        indices, labels = data

        with tf.GradientTape() as tape:
            # Forward pass
            outputs = self([self.node_states, self.edges])
            # Compute loss
            loss = self.compiled_loss(labels, tf.gather(outputs, indices))
        # Compute gradients
        grads = tape.gradient(loss, self.trainable_weights)
        # Apply gradients (update weights)
        optimizer.apply_gradients(zip(grads, self.trainable_weights))
        # Update metric(s)
        self.compiled_metrics.update_state(labels, tf.gather(outputs, indices))

        return {m.name: m.result() for m in self.metrics}

    def predict_step(self, data):
        # indices = data
        # # Forward pass
        # outputs = self([self.node_states, self.edges])
        # # Compute probabilities
        # return tf.nn.softmax(tf.gather(outputs, indices))
        indices = data
        
        outputs = self([self.node_states, self.edges])
        return tf.gather(outputs, indices)

    def test_step(self, data):
        indices, labels = data
        # Forward pass
        outputs = self([self.node_states, self.edges])
        # Compute loss
        loss = self.compiled_loss(labels, tf.gather(outputs, indices))
        # Update metric(s)
        self.compiled_metrics.update_state(labels, tf.gather(outputs, indices))

        return {m.name: m.result() for m in self.metrics}


## Test and evaluation 1

In [79]:

# Define hyper-parameters
HIDDEN_UNITS = 100
NUM_HEADS = 8
NUM_LAYERS = 3
OUTPUT_DIM = 2 # For position prediction, output_dim should be 2 (x and y coordinates)

# The rest of your model definition and training code can remain the same
# Just remember that you're now predicting continuous values (positions)
# instead of class labels

NUM_EPOCHS = 2
BATCH_SIZE = 256
VALIDATION_SPLIT = 0.2
LEARNING_RATE = 1e-6

# MOMENTUM = 0.9
# loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
# optimizer = keras.optimizers.SGD(LEARNING_RATE, momentum=MOMENTUM)
# accuracy_fn = keras.metrics.SparseCategoricalAccuracy(name="acc")


loss_fn = keras.losses.MeanSquaredError(name="mean_squared_error")
optimizer = keras.optimizers.Adam(
    learning_rate=LEARNING_RATE, 
    clipnorm=1.0,
    epsilon=1e-8
)
accuracy_fn = keras.metrics.MeanSquaredError(name="mean_squared_error")

# Prepare data for training and testing
# Let's use the scenes we've loaded earlier

# Helper function to prepare data from scenes
# Modified prepare_data_from_scenes function to handle NaN values
def prepare_data_from_scenes(scenes):
    all_node_features = []
    all_edges = []
    all_labels = []
    all_indices = []
    
    total_nodes = 0
    for scene in scenes:
        # Get data from scene
        node_features = scene["node_features"]
        edges = scene["edges"]
        labels = scene["labels"]
        
        # Check for NaN values
        valid_mask = tf.math.logical_not(tf.math.reduce_any(tf.math.is_nan(node_features), axis=1))
        #if tf.reduce_sum(tf.cast(valid_mask, tf.int32)) < len(node_features):
            #print(f"Found {len(node_features) - tf.reduce_sum(tf.cast(valid_mask, tf.int32))} NaN values in scene {scene['scene_id']}")
        
        # Filter out nodes with NaN values and adjust edges
        valid_indices = tf.where(valid_mask)[:, 0]
        
        # Create index mapping for valid nodes
        old_to_new = {}
        for new_idx, old_idx in enumerate(valid_indices.numpy()):
            old_to_new[old_idx] = new_idx
        
        # Filter node features and labels
        node_features_filtered = tf.gather(node_features, valid_indices)
        labels_filtered = tf.gather(labels, valid_indices)
        
        # Adjust edges to account for removed nodes
        valid_edges = []
        for edge in edges.numpy():
            source, target = edge
            if source in old_to_new and target in old_to_new:
                valid_edges.append([old_to_new[source], old_to_new[target]])
        
        # Skip if no valid edges remain
        if not valid_edges:
            # print(f"Skipping scene {scene['scene_id']} - no valid edges after filtering")
            continue
            
        # Convert to tensor
        edges_filtered = tf.constant(valid_edges, dtype=tf.int32)
        
        # Adjust edge indices to account for the total number of nodes so far
        adjusted_edges = edges_filtered + total_nodes
        
        # Create node indices for this scene
        indices = tf.range(len(node_features_filtered)) + total_nodes
        
        # Append to our lists
        all_node_features.append(node_features_filtered)
        all_edges.append(adjusted_edges)
        all_labels.append(labels_filtered)
        all_indices.append(indices)
        
        # Update total node count
        total_nodes += len(node_features_filtered)
    
    # Concatenate all data
    node_states = tf.concat(all_node_features, axis=0)
    edges_list = tf.concat(all_edges, axis=0)
    labels = tf.concat(all_labels, axis=0)
    indices = tf.concat(all_indices, axis=0)
    
    return node_states, edges_list, indices, labels
# Prepare training data
node_states, edges, train_indices, train_labels = prepare_data_from_scenes(train_scenes)

print("edges: ", edges[:10])
print("node_states: ", node_states[:10])
# Prepare validation data (optional)
val_node_states, val_edges, val_indices, val_labels = prepare_data_from_scenes(val_scenes)

# Prepare test data
test_node_states, test_edges, test_indices, test_labels = prepare_data_from_scenes(test_scenes)

# Print dataset statistics
print(f"Training: {len(train_indices)} nodes")
print(f"Validation: {len(val_indices)} nodes")
print(f"Testing: {len(test_indices)} nodes")

# Build model
gat_model = GraphAttentionNetwork(
    node_states, edges, HIDDEN_UNITS, NUM_HEADS, NUM_LAYERS, OUTPUT_DIM
)

# Compile model
gat_model.compile(loss=loss_fn, optimizer=optimizer, metrics=[accuracy_fn])

# Define callbacks for better training
callbacks = [
    keras.callbacks.EarlyStopping(
        monitor="val_mean_squared_error",
        patience=10,
        restore_best_weights=True
    ),
]

# Train the model with TensorFlow indices
history = gat_model.fit(
    x=train_indices,
    y=train_labels,
    validation_data=(val_indices, val_labels),
    batch_size=128,
    epochs=10,
    verbose=1,
    callbacks=callbacks
)

# Evaluate on test set
_, test_accuracy = gat_model.evaluate(x=test_indices, y=test_labels, verbose=1)
print(f"Test MSE: {test_accuracy}")

# Make predictions on test data
test_predictions = gat_model.predict(test_indices)

# Denormalize predictions and ground truth for visualization
def denormalize(normalized_coords, min_vals, max_vals):
    return normalized_coords * (max_vals - min_vals) + min_vals

# Visualize some predictions (first test scene as example)
test_scene = test_scenes[0]
# Convert to tensor-compatible indexing
scene_indices = tf.where(test_indices < len(test_scene["node_features"]))[:, 0]

if len(scene_indices) > 0:
    # Convert to numpy array for easier handling
    scene_indices_np = scene_indices.numpy()
    scene_predictions = np.array(test_predictions)[scene_indices_np]
    scene_ground_truth = np.array(test_labels)[scene_indices_np]
    
    # Check for valid min/max values
    if np.any(np.isnan(test_scene["min_vals"])) or np.any(np.isnan(test_scene["max_vals"])):
        print("Warning: NaN values found in min_vals or max_vals")
        # Use reasonable defaults if min/max values are NaN
        min_vals = np.nanmin(scene_ground_truth, axis=0) if np.any(np.isnan(test_scene["min_vals"])) else test_scene["min_vals"]
        max_vals = np.nanmax(scene_ground_truth, axis=0) if np.any(np.isnan(test_scene["max_vals"])) else test_scene["max_vals"]
    else:
        min_vals = test_scene["min_vals"]
        max_vals = test_scene["max_vals"]
    
    # Print min/max values
    print(f"min_vals: {min_vals}")
    print(f"max_vals: {max_vals}")
    
    # Make sure we don't have division by zero in the original normalization
    range_vals = max_vals - min_vals
    if np.any(range_vals == 0):
        print("Warning: Range of values is zero, adding small epsilon")
        range_vals = np.where(range_vals == 0, 1e-8, range_vals)
        
    # Denormalize with safeguards
    pred_denorm = scene_predictions * range_vals + min_vals
    truth_denorm = scene_ground_truth * range_vals + min_vals
    
    # Print some predictions
    print("Predictions vs Ground Truth (first 5):")
    for i in range(min(5, len(scene_predictions))):
        print(f"Pred: [{pred_denorm[i][0]:.2f}, {pred_denorm[i][1]:.2f}], " 
              f"True: [{truth_denorm[i][0]:.2f}, {truth_denorm[i][1]:.2f}]")

edges:  tf.Tensor(
[[ 0  9]
 [ 1 11]
 [ 2  3]
 [ 2  5]
 [ 2  6]
 [ 2 10]
 [ 3  5]
 [ 3  6]
 [ 3 10]
 [ 5  6]], shape=(10, 2), dtype=int32)
node_states:  tf.Tensor(
[[0.8478331  0.28035972 0.8478601  0.2801659 ]
 [0.7390847  0.08876313 0.7381145  0.08124346]
 [0.9218413  0.07608822 0.89852846 0.09841467]
 [0.912624   0.01845033 0.8942702  0.03449746]
 [0.01347564 1.         0.01732967 0.9844955 ]
 [0.9331339  0.09062367 0.90631735 0.1182604 ]
 [0.9389284  0.04961433 0.91248924 0.06938253]
 [0.96172917 0.20512423 0.92731243 0.2168301 ]
 [0.9600582  0.23252839 0.93394244 0.2434978 ]
 [0.8798512  0.22415598 0.85007006 0.21384549]], shape=(10, 4), dtype=float32)
Training: 931 nodes
Validation: 295 nodes
Testing: 194 nodes
Epoch 1/10
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m29s[0m 628ms/step - mean_squared_error: 0.2627 - loss: 0.0701 - val_loss: 0.0703
Epoch 2/10
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 177ms/step - mean_squared_error: 0.2620 - loss: 0.07

## Test and evaluation 2

In [42]:
# Prepare data for training and testing
# Let's use the scenes we've loaded earlier

# Helper function to prepare data from scenes
# Modified prepare_data_from_scenes function to handle NaN values
def prepare_data_from_scenes(scenes):
    all_node_features = []
    all_edges = []
    all_labels = []
    all_indices = []
    
    total_nodes = 0
    for scene in scenes:
        # Get data from scene
        node_features = scene["node_features"]
        edges = scene["edges"]
        labels = scene["labels"]
        
        # Check for NaN values
        valid_mask = tf.math.logical_not(tf.math.reduce_any(tf.math.is_nan(node_features), axis=1))
        #if tf.reduce_sum(tf.cast(valid_mask, tf.int32)) < len(node_features):
            #print(f"Found {len(node_features) - tf.reduce_sum(tf.cast(valid_mask, tf.int32))} NaN values in scene {scene['scene_id']}")
        
        # Filter out nodes with NaN values and adjust edges
        valid_indices = tf.where(valid_mask)[:, 0]
        
        # Create index mapping for valid nodes
        old_to_new = {}
        for new_idx, old_idx in enumerate(valid_indices.numpy()):
            old_to_new[old_idx] = new_idx
        
        # Filter node features and labels
        node_features_filtered = tf.gather(node_features, valid_indices)
        labels_filtered = tf.gather(labels, valid_indices)
        
        # Adjust edges to account for removed nodes
        valid_edges = []
        for edge in edges.numpy():
            source, target = edge
            if source in old_to_new and target in old_to_new:
                valid_edges.append([old_to_new[source], old_to_new[target]])
        
        # Skip if no valid edges remain
        if not valid_edges:
            # print(f"Skipping scene {scene['scene_id']} - no valid edges after filtering")
            continue
            
        # Convert to tensor
        edges_filtered = tf.constant(valid_edges, dtype=tf.int32)
        
        # Adjust edge indices to account for the total number of nodes so far
        adjusted_edges = edges_filtered + total_nodes
        
        # Create node indices for this scene
        indices = tf.range(len(node_features_filtered)) + total_nodes
        
        # Append to our lists
        all_node_features.append(node_features_filtered)
        all_edges.append(adjusted_edges)
        all_labels.append(labels_filtered)
        all_indices.append(indices)
        
        # Update total node count
        total_nodes += len(node_features_filtered)
    
    # Concatenate all data
    node_states = tf.concat(all_node_features, axis=0)
    edges_list = tf.concat(all_edges, axis=0)
    labels = tf.concat(all_labels, axis=0)
    indices = tf.concat(all_indices, axis=0)
    
    return node_states, edges_list, indices, labels

# Prepare training data
node_states, edges, train_indices, train_labels = prepare_data_from_scenes(train_scenes)
train_labels = train_labels

# print("edges: ", edges[:10])
# print("node_states: ", node_states[:10])

# Prepare validation data (optional)
val_node_states, val_edges, val_indices, val_labels = prepare_data_from_scenes(val_scenes)

# Prepare test data
test_node_states, test_edges, test_indices, test_labels = prepare_data_from_scenes(test_scenes)

# Print dataset statistics
print(f"Training: {len(train_indices)} nodes")
print(f"Validation: {len(val_indices)} nodes")
print(f"Testing: {len(test_indices)} nodes")


# Define hyper-parameters
HIDDEN_UNITS = 100
NUM_HEADS = 8
NUM_LAYERS = 3
OUTPUT_DIM = 2


NUM_EPOCHS = 25
BATCH_SIZE = 128
LEARNING_RATE = 1e-5

loss_fn = keras.losses.MeanAbsoluteError(name="mean_absolute_error")
optimizer = keras.optimizers.Adam(
    learning_rate=LEARNING_RATE, 
    #clipnorm=1.0,
    epsilon=1e-8
)
accuracy_fn = keras.metrics.MeanAbsoluteError(name="mean_absolute_error")

# Build model
gat_model = GraphAttentionNetwork(
    node_states, edges, HIDDEN_UNITS, NUM_HEADS, NUM_LAYERS, OUTPUT_DIM
)

# Compile model
gat_model.compile(loss=loss_fn, optimizer=optimizer, metrics=[accuracy_fn])

# Train the model with TensorFlow indices
history = gat_model.fit(
    x=train_indices,
    y=train_labels,
    validation_data=(val_indices, val_labels),
    batch_size=BATCH_SIZE,
    epochs=NUM_EPOCHS,
    verbose=1,
)

# Evaluate on test set
_, test_accuracy = gat_model.evaluate(x=test_indices, y=test_labels, verbose=1)
print(f"Test Mean Absolute Error: {test_accuracy}")

# Visualize some predictions (first test scene as example)
test_scene = test_scenes[0]
# Convert to tensor-compatible indexing
scene_indices = tf.where(test_indices < len(test_scene["node_features"]))[:, 0]

# Make predictions on test data
test_predictions = gat_model.predict(test_indices)

def check_scene_validity(scene):
    """Check if a scene has valid normalization values and features."""
    has_nan_features = tf.math.reduce_any(tf.math.is_nan(scene["node_features"]))
    
    return not (has_nan_features)

# First, filter out test scenes with NaN values
valid_test_scenes = [scene for scene in test_scenes if check_scene_validity(scene)]
print(f"Filtered out {len(test_scenes) - len(valid_test_scenes)} test scenes with NaN values")
print(f"Remaining valid test scenes: {len(valid_test_scenes)}")

# Make predictions on test data
test_predictions = gat_model.predict(test_indices)

# Visualize predictions for the first valid test scene
if valid_test_scenes:
    test_scene = valid_test_scenes[0]
    print(f"Selected scene {test_scene['scene_id']} for visualization")
    
    # Get indices for this scene
    scene_node_count = len(test_scene["node_features"])
    # Find which indices in the test set correspond to this scene
    scene_global_indices = []
    current_offset = 0
    
    for i, s in enumerate(valid_test_scenes):
        if i == 0:  # First scene matches our selected test_scene
            scene_global_indices = np.arange(current_offset, current_offset + scene_node_count)
            break
        current_offset += len(s["node_features"])
    
    # Get predictions for these indices
    scene_predictions = np.array(test_predictions)[scene_global_indices]
    scene_ground_truth = np.array(test_labels)[scene_global_indices]
    
    # Print some predictions in the original coordinate space
    print("\nPredictions vs Ground Truth (first 5):")
    for i in range(min(5, len(scene_predictions))):
        print(f"Pred: [{scene_predictions[i][0]:.2f}, {scene_predictions[i][1]:.2f}], " 
              f"True: [{scene_ground_truth[i][0]:.2f}, {scene_ground_truth[i][1]:.2f}]")
    
    # Calculate mean absolute error
    mae = np.mean(np.abs(scene_predictions - scene_ground_truth))
    print(f"Mean Absolute Error: {mae:.2f}")
    
    # Calculate Euclidean distance error
    euclidean_errors = np.sqrt(np.sum((scene_predictions - scene_ground_truth)**2, axis=1))
    mean_euclidean_error = np.mean(euclidean_errors)
    print(f"Mean Euclidean Distance Error: {mean_euclidean_error:.2f} units")



Training: 970 nodes
Validation: 285 nodes
Testing: 165 nodes
Epoch 1/25
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m30s[0m 639ms/step - mean_absolute_error: 18741.5293 - loss: 422.2069 - val_loss: 445.0100
Epoch 2/25
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 178ms/step - mean_absolute_error: 16971.0137 - loss: 476.4238 - val_loss: 571.2065
Epoch 3/25
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 172ms/step - mean_absolute_error: 15897.3115 - loss: 568.1819 - val_loss: 671.7411
Epoch 4/25
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 172ms/step - mean_absolute_error: 14188.2725 - loss: 727.1000 - val_loss: 817.9021
Epoch 5/25
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 170ms/step - mean_absolute_error: 12481.5254 - loss: 939.2967 - val_loss: 1075.7410
Epoch 6/25
[1m8/8[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 175ms/step - mean_absolute_error: 11351.4365 - loss: 1326.4062 - val_loss: 1799.7566