In [3]:
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
import pickle
with open('/kaggle/input/graph-dataset/graph_sequences.pkl', 'rb') as f:
    loaded_graph_sequences = pickle.load(f)


In [8]:
import torch
from torch_geometric.data import Data, Dataset
import numpy as np

class GraphTrajectoryDataset(Dataset):
    def __init__(self, raw_data):
        self.data = raw_data  # list of (feature_seq, adj_seq, future_coords)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        feature_seq, adj_seq, future_coords = self.data[idx]
        graph_seq = []
    
        for x, a in zip(feature_seq, adj_seq):
            x_tensor = torch.tensor(x, dtype=torch.float32)  # [N, F]
    
            # Detect valid (non-padded) nodes
            valid_mask = ~(x_tensor == 0).all(dim=1)
            valid_indices = valid_mask.nonzero(as_tuple=True)[0]
            num_valid = valid_indices.shape[0]
    
            # Edge filtering
            a = np.array(a)
            a = a[:num_valid, :num_valid]  # crop adjacency to valid nodes only
            edge_index = torch.tensor(np.array((a > 0).nonzero())).long()  # [2, E]
    
            graph_seq.append(Data(x=x_tensor[:num_valid], edge_index=edge_index))
        # Ensure it's exactly 4 points
        while len(future_coords) < 4:
            future_coords.append(future_coords[-1])  # Repeat last coordinate or use (0,0)

        future_coords = torch.tensor(future_coords, dtype=torch.float32)  # [4, 2]
        return graph_seq, future_coords


In [16]:
def graph_sequence_collate(batch):
    """
    Custom collate function to support list of sequences of torch_geometric Data.
    """
    batched_graph_seqs = []
    future_coords_list = []

    for graph_seq, future_coords in batch:
        batched_graph_seqs.append(graph_seq)
        future_coords_list.append(future_coords)

    return batched_graph_seqs, torch.stack(future_coords_list)


In [17]:
import torch
from torch.utils.data import Dataset, DataLoader

dataset = GraphTrajectoryDataset(loaded_graph_sequences)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=graph_sequence_collate)


In [18]:
def constant_velocity_baseline(graph_seq, future_steps=4):
    last_pos = graph_seq[-1].x[0][:2]        # ego agent position at T
    prev_pos = graph_seq[-2].x[0][:2]        # ego agent position at T-1
    velocity = last_pos - prev_pos

    preds = [last_pos + (i + 1) * velocity for i in range(future_steps)]
    return torch.stack(preds)  # [future_steps, 2]


In [11]:
import torch.nn as nn

class SimpleGRUBaseline(nn.Module):
    def __init__(self, input_dim=2, hidden_dim=64, future_steps=4):
        super().__init__()
        self.gru = nn.GRU(input_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, 2)
        self.future_steps = future_steps

    def forward(self, traj_seq):  # traj_seq: [1, T, 2]
        _, hidden = self.gru(traj_seq)
        decoder_input = traj_seq[:, -1:]  # start from last known point
        outputs = []

        for _ in range(self.future_steps):
            out, hidden = self.gru(decoder_input, hidden)
            pred = self.fc(out[:, -1])
            outputs.append(pred)
            decoder_input = pred.unsqueeze(1)

        return torch.stack(outputs, dim=1).squeeze(0)  # [future_steps, 2]


In [19]:
class MLPBaseline(nn.Module):
    def __init__(self, input_steps=5, future_steps=4):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(input_steps * 2, 128),
            nn.ReLU(),
            nn.Linear(128, future_steps * 2)
        )

    def forward(self, traj_seq):  # traj_seq: [1, T, 2]
        x = traj_seq.view(1, -1)  # [1, T*2]
        out = self.mlp(x).view(-1, 2)  # [future_steps, 2]
        return out


In [20]:
def compute_ade_fde(pred, target):
    """
    pred, target: [future_steps, 2]
    Returns: ADE, FDE
    """
    ade = ((pred - target) ** 2).sum(dim=1).sqrt().mean()
    fde = ((pred[-1] - target[-1]) ** 2).sum().sqrt()
    return ade.item(), fde.item()


In [21]:
def evaluate_models(graph_seq, target_seq, models):
    """
    graph_seq: List[graph at each past timestep]
    target_seq: Tensor of shape [future_steps, 2] (ground truth)
    models: Dict with model name as key and model/callback as value
    """
    results = {}
    traj_seq = torch.stack([g.x[0][:2] for g in graph_seq]).unsqueeze(0)  # [1, T, 2]

    for name, model in models.items():
        if name == "ConstantVelocity":
            pred = model(graph_seq)
        elif name == "SimpleGRU" or name == "MLP":
            pred = model(traj_seq)
        else:  # Your GATGRUWithAttention
            pred = model(graph_seq, target_seq.unsqueeze(0))

        ade, fde = compute_ade_fde(pred, target_seq)
        results[name] = {"ADE": ade, "FDE": fde}

    return results


In [24]:
# Assume graph_seq and target_seq are defined as before
# Instantiate models
gru_baseline = SimpleGRUBaseline()
mlp_baseline = MLPBaseline()
# gat_gru_model = GATGRUWithAttention(in_channels=4, hidden_dim=64)

# Load trained weights if available
# gru_baseline.load_state_dict(...)
# mlp_baseline.load_state_dict(...)
# gat_gru_model.load_state_dict(...)

models = {
    "ConstantVelocity": constant_velocity_baseline,
    "SimpleGRU": gru_baseline,
    "MLP": mlp_baseline
    # "GATGRUWithAttention": gat_gru_model
}

# Ensure models are in eval mode
for model in models.values():
    if isinstance(model, torch.nn.Module):
        model.eval()

for graph_seqs, future_coords_batch in dataloader:
    graph_seq = [g for g in graph_seqs[0]]             # list of graphs
    future_coords = future_coords_batch[0].unsqueeze(0)  # [1, 4, 2]

    results = evaluate_models(graph_seq, future_coords.squeeze(0), models)

for name, metrics in results.items():
    print(f"{name}: ADE = {metrics['ADE']:.4f}, FDE = {metrics['FDE']:.4f}")


ConstantVelocity: ADE = 2738.9756, FDE = 4692.5039
SimpleGRU: ADE = 1077.0840, FDE = 450.3465
MLP: ADE = 1023.8361, FDE = 704.5679
