In [2]:
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
import time

# Dataset Class
import torch

import pandas as pd

import numpy as np
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from matplotlib.widgets import Slider
import numpy as np
import gc
import torch.nn as nn
import numpy as np
from torch.utils.data import DataLoader
import torch.utils.checkpoint as checkpoint
import json
import matplotlib.pyplot as plt
from matplotlib import cm
import gc
import torch.nn.functional as F


In [3]:
class PFM_TrajectoryDataset(torch.utils.data.Dataset):
    def __init__(self, file_path, history_len=8, prediction_len=12):
        self.data, self.uses_scene_id = self.load_data(file_path)
        self.history_len = history_len
        self.prediction_len = prediction_len
        self.valid_frames = self._get_valid_frames()

    def load_data(self, file_path):
        data = {}
        uses_scene_id = False
        with open(file_path, 'r') as file:
            for line_idx, line in enumerate(file):
                parts = line.strip().split(',')
                if len(parts) == 4:
                    c1, agent, x, y = map(float, parts)
                    agent = int(agent)

                    # Detect format:
                    # If c1 grows smoothly like 10,20,30 => Zara format
                    # Else if c1 is same for many lines => scene ID format
                    if line_idx == 0:
                        first_val = c1
                    if line_idx == 1 and c1 == first_val:
                        uses_scene_id = True

                    frame = int(c1) if not uses_scene_id else line_idx  # if Zara -> frame, else use line index as frame
                    if frame not in data:
                        data[frame] = {}
                    data[frame][agent] = torch.tensor([x, y], dtype=torch.float32)

        return data, uses_scene_id

    def _get_valid_frames(self):
        all_frames = sorted(self.data.keys())
        valid_frames = []
        for frame in all_frames:
            history_start = frame - self.history_len + 1
            future_end = frame + self.prediction_len
            if history_start >= min(all_frames) and future_end <= max(all_frames):
                valid_frames.append(frame)
        return valid_frames

    def __len__(self):
        return len(self.valid_frames)

    def __getitem__(self, idx):
        frame = self.valid_frames[idx]
        agents = list(self.data[frame].keys())
        num_agents = len(agents)

        history = torch.zeros(num_agents, self.history_len, 2)
        future = torch.zeros(num_agents, self.prediction_len, 2)
        goals = torch.zeros(num_agents, 2)

        # Build history & future
        for i, agent in enumerate(agents):
            for t in range(self.history_len):
                hist_frame = frame - (self.history_len - 1 - t)
                if hist_frame in self.data and agent in self.data[hist_frame]:
                    history[i, t] = self.data[hist_frame][agent]
            for t in range(self.prediction_len):
                fut_frame = frame + t + 1
                if fut_frame in self.data and agent in self.data[fut_frame]:
                    future[i, t] = self.data[fut_frame][agent]

            # goal = last non-zero future point
            non_zero_mask = torch.any(future[i] != 0, dim=1)
            if non_zero_mask.any():
                last_valid_idx = torch.where(non_zero_mask)[0][-1]
                goals[i] = future[i, last_valid_idx]
            else:
                goals[i] = self.data[frame][agent]

        # Build neighbors
        neighbors_list = []
        for i, agent in enumerate(agents):
            agent_neighbors = [self.data[frame][other] for other in self.data[frame] if other != agent]
            if agent_neighbors:
                neighbors_tensor = torch.stack(agent_neighbors)
            else:
                neighbors_tensor = torch.zeros(1, 2)
            neighbors_list.append(neighbors_tensor)

        max_neighbors = 12
        padded_neighbors = torch.zeros(num_agents, max_neighbors, 2)
        for i, neighbor_tensor in enumerate(neighbors_list):
            if neighbor_tensor.shape[0] > 0:
                ego_pos = history[i, -1].unsqueeze(0)
                dists = torch.norm(neighbor_tensor - ego_pos, dim=1)
                sorted_idx = torch.argsort(dists)
                sorted_neighbors = neighbor_tensor[sorted_idx]
                top_neighbors = sorted_neighbors[:max_neighbors]
                padded_neighbors[i, :top_neighbors.shape[0]] = top_neighbors
        neighbors = padded_neighbors

        # Mask invalid agents
        mask = torch.ones(history.shape[0], dtype=torch.bool)
        for i in range(history.shape[0]):
            if torch.all(history[i] == 0) or torch.all(future[i] == 0):
                mask[i] = False
        history = history[mask]
        future = future[mask]
        neighbors = neighbors[mask]
        goals = goals[mask]

        return history, future, neighbors, goals

In [4]:
def collate_fn(batch):
    """Collate function for batches that include goal"""
    max_agents = max(sample[0].shape[0] for sample in batch)
    # history
    max_neighbors = max(sample[2].shape[1] for sample in batch)  # neighbors
    hist_len = batch[0][0].shape[1]
    fut_len = batch[0][1].shape[1]

    padded_histories = []
    padded_futures = []
    padded_neighbors = []
    padded_goals = []

    for sample in batch:
        history, future, neighbors, goal = sample
        A, H, D = history.shape
        N = neighbors.shape[1]

        padded_hist = torch.zeros(max_agents, hist_len, D)
        padded_fut = torch.zeros(max_agents, fut_len, D)
        padded_neigh = torch.zeros(max_agents, max_neighbors, D)
        padded_goal = torch.zeros(max_agents, D)

        padded_hist[:A] = history
        padded_fut[:A] = future
        padded_neigh[:A, :N] = neighbors
        padded_goal[:A] = goal

        padded_histories.append(padded_hist)
        padded_futures.append(padded_fut)
        padded_neighbors.append(padded_neigh)
        padded_goals.append(padded_goal)

    return (
        torch.stack(padded_histories),   # [B, A, hist_len, D]
        torch.stack(padded_futures),     # [B, A, fut_len, D]
        torch.stack(padded_neighbors),   # [B, A, max_neighbors, D]
        torch.stack(padded_goals)        # [B, A, D]
    )

# FIXED: Added goals to test data (4th element in each tuple)
test_batch = [
    (torch.rand(3, 8, 2), torch.rand(3, 12, 2), torch.rand(3, 5, 2), torch.rand(3, 2)),  # 3 agents with goals
    (torch.rand(2, 8, 2), torch.rand(2, 12, 2), torch.rand(2, 3, 2), torch.rand(2, 2)),  # 2 agents with goals
    (torch.rand(4, 8, 2), torch.rand(4, 12, 2), torch.rand(4, 6, 2), torch.rand(4, 2))   # 4 agents with goals
]

# Now this should work correctly
hist, fut, neigh, goals = collate_fn(test_batch)  # Note: also need to unpack goals here
print(f"History shape: {hist.shape}")     # Will be [3, 4, 8, 2]
print(f"Future shape: {fut.shape}")       # Will be [3, 4, 12, 2]
print(f"Neighbors shape: {neigh.shape}")  # Will be [3, 4, 6, 2]
print(f"Goals shape: {goals.shape}")      # Will be [3, 4, 2]

History shape: torch.Size([3, 4, 8, 2])
Future shape: torch.Size([3, 4, 12, 2])
Neighbors shape: torch.Size([3, 4, 6, 2])
Goals shape: torch.Size([3, 4, 2])


In [5]:

def compute_average_speed_from_file_ZARA(file_path, delimiter=None):
    """
    Compute average speed across all agents from trajectory data in file.

    Args:
        file_path (str): Path to the file containing data: frame agent x y
        delimiter (str or None): delimiter for np.loadtxt (default None detects spaces).
                                Use "," if your file is comma separated.

    Returns:
        float: average speed (units per frame step)
    """
    # Load dataset: shape (N,4) --> frame, agent, x, y
    data = np.loadtxt(file_path, delimiter=delimiter)

    # Sort by agent then by frame to ensure proper ordering
    data = data[np.lexsort((data[:,0], data[:,1]))]

    total_distance = 0.0
    total_transitions = 0

    # Process each agent separately
    for agent_id in np.unique(data[:,1]):
        agent_data = data[data[:,1] == agent_id]

        # Filter out zero positions (padded)
        mask = ~((agent_data[:,2] == 0) & (agent_data[:,3] == 0))
        agent_data = agent_data[mask]

        if len(agent_data) < 2:
            continue

        frames = agent_data[:,0]
        positions = agent_data[:, 2:4]

        # Calculate Euclidean distances between consecutive positions
        displacements = positions[1:] - positions[:-1]
        distances = np.linalg.norm(displacements, axis=1)

        # Calculate time differences (frame differences)
        delta_times = frames[1:] - frames[:-1]

        # Ignore zero or negative time intervals (avoid division issues)
        valid_mask = delta_times > 0

        total_distance += distances[valid_mask].sum()
        total_transitions += delta_times[valid_mask].sum()  # total time elapsed

    if total_transitions > 0:
        avg_speed = total_distance / total_transitions
    else:
        avg_speed = 0.0

    return avg_speed

file_path = "/content/crowds_zara02_test.txt"
average_speed = compute_average_speed_from_file_ZARA(file_path)
print(f"Average speed: {average_speed:.4f} units per frame")

Average speed: 0.0278 units per frame


In [6]:
def calculate_speed(trajectory):
    """Calculate average speed from trajectory tensor [B, A, T, 2]"""
    diffs = trajectory[:, :, 1:, :] - trajectory[:, :, :-1, :]  # Frame-to-frame differences
    speeds = torch.norm(diffs, dim=-1)  # [B, A, T-1]
    valid_mask = speeds > 0
    if valid_mask.sum() > 0:
        avg_speed = speeds[valid_mask].mean()
    else:
        avg_speed = torch.tensor(0.0)
    return avg_speed

def check_speed_violations(predictions, history, min_speed, max_speed):
    """Count speed violations in predicted trajectories"""
    # Get last history position
    last_pos = history[:, :, -1:, :]  # [B, A, 1, 2]

    # Combine last history with predictions for speed calculation
    full_traj = torch.cat([last_pos, predictions], dim=2)  # [B, A, T+1, 2]

    # Calculate speeds
    diffs = full_traj[:, :, 1:, :] - full_traj[:, :, :-1, :]
    speeds = torch.norm(diffs, dim=-1)  # [B, A, T]

    # Count violations
    violations = ((speeds < min_speed) | (speeds > max_speed)).sum().item()
    return violations

In [10]:


class PotentialField(nn.Module):
    def __init__(self, goal, num_agents=1000, k_init=1.0, repulsion_radius=0.5):
        super().__init__()
        self.register_buffer('goal', torch.tensor(goal, dtype=torch.float32))
        self.repulsion_radius = repulsion_radius
        self.coeff_embedding = nn.Embedding(num_agents, 3)
        self.coeff_embedding.weight.data.fill_(k_init)

    def forward(self, pos, predicted, neighbors, goal, coeffs):
        k_att1 = coeffs[..., 0:1]
        k_att2 = coeffs[..., 1:2]
        k_rep = coeffs[..., 2:3]

        F_goal = k_att1 * (goal - pos)
        F_pred = k_att2 * (predicted[:, :, 0, :] - pos)

        diffs = pos.unsqueeze(2) - neighbors
        dists = torch.norm(diffs, dim=-1, keepdim=True) + 1e-6
        mask = (dists < self.repulsion_radius).float()

        F_rep = (k_rep.unsqueeze(2) * diffs / dists.pow(2) * mask).sum(dim=2)

        total_force = F_goal + F_pred + F_rep

        return total_force, coeffs


class CheckpointedIntegratedMTAPFMModel(nn.Module):
    def __init__(self, input_size=2, hidden_size=64, num_layers=2,
                 goal=(4.2,4.2), target_avg_speed=0.027,
                 speed_tolerance=0.15, num_agents=1000):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size

        self.input_embed = nn.Linear(input_size, hidden_size)
        self.encoder = nn.LSTM(hidden_size, hidden_size, num_layers, batch_first=True)
        self.decoder = nn.LSTM(hidden_size, hidden_size, num_layers, batch_first=True)

        # If postprocess is unnecessary, replace with identity
        # self.postprocess = nn.Identity()
        self.postprocess = nn.Linear(input_size, input_size)


        self.output = nn.Linear(hidden_size, input_size)
        self.pfm = PotentialField(goal=goal, num_agents=num_agents)

        if target_avg_speed is None:
            raise ValueError("target_avg_speed must be computed from dataset and passed here.")

        self.target_avg_speed = target_avg_speed
        self.speed_tolerance = speed_tolerance
        self.min_speed = self.target_avg_speed * (1 - self.speed_tolerance)
        self.max_speed = self.target_avg_speed * (1 + self.speed_tolerance)

    def apply_speed_constraints(self, predictions, last_positions):
        B, A, T, D = predictions.shape
        constrained_preds = predictions.clone()
        current_pos = last_positions.clone()
        for t in range(T):
            displacement = predictions[:, :, t] - current_pos
            speeds = torch.norm(displacement, dim=-1, keepdim=True)
            non_zero_mask = speeds > 0
            clipped_speeds = torch.clamp(speeds, self.min_speed, self.max_speed)
            final_speeds = torch.where(non_zero_mask, clipped_speeds, speeds)
            direction = displacement / (speeds + 1e-8)
            clipped_displacement = direction * final_speeds
            constrained_preds[:, :, t] = current_pos + clipped_displacement
            current_pos = constrained_preds[:, :, t].clone()
        return constrained_preds

    def forward_encoder(self, x):
        return self.encoder(x)

    def forward_decoder_step(self, input_step, hx):
        return self.decoder(input_step, hx)

    def forward(self, history, neighbors, goal):
        B, A, H, D = history.shape
        device = history.device
        agent_ids = torch.arange(A).repeat(B, 1).to(device)

        hist_flat = history.reshape(B * A, H, D)
        hist_embedded = self.input_embed(hist_flat)
        # print('history',history)
        # print('neighbours',neighbors)
        # print('goal', goal)

        # Gradient checkpoint on encoder
        _, (h_n, c_n) = checkpoint.checkpoint(self.forward_encoder, hist_embedded)

        pred_flat = torch.zeros(B * A, 12, D, device=device)

        h_n = h_n[-1].unsqueeze(0).repeat(self.num_layers, 1, 1)
        c_n = c_n[-1].unsqueeze(0).repeat(self.num_layers, 1, 1)
        hx = (h_n, c_n)

        for t in range(12):
            current_pred = pred_flat[:, t:t+1].clone()
            pred_embedded = self.input_embed(current_pred)
            out, hx = checkpoint.checkpoint(self.forward_decoder_step, pred_embedded, hx)
            step_output = self.output(out.squeeze(1))
            pred_flat[:, t] = self.postprocess(step_output)

        predictions = pred_flat.view(B, A, 12, D)
        adjusted_preds = torch.zeros_like(predictions)
        current_pos = history[:, :, -1, :].clone()
        coeff_list = []

        coeffs = self.pfm.coeff_embedding(agent_ids)

        for t in range(12):
            pred_slice = predictions[:, :, t:t+1].clone()
            forces, coeff_step = self.pfm(current_pos, pred_slice, neighbors.clone(), goal, coeffs)
            if t == 0:
                adjusted_preds[:, :, t] = current_pos + forces
            else:
                adjusted_preds[:, :, t] = adjusted_preds[:, :, t-1] + forces
            coeff_list.append(coeff_step)
            current_pos = adjusted_preds[:, :, t].clone()

        last_known_pos = history[:, :, -1, :]

        constrained_preds = self.apply_speed_constraints(adjusted_preds, last_known_pos)
        # print(constrained_preds)
        coeff_stack = torch.stack(coeff_list, dim=0)
        coeff_mean = coeff_stack.mean()
        coeff_var = coeff_stack.var(unbiased=False)

        return constrained_preds, coeff_mean, coeff_var

In [8]:
# Define test dimensions
batch_size = 32
num_agents = 1000
history_len = 8
num_neighbors = 10
goal_dim = 2

# Initialize the model
model = CheckpointedIntegratedMTAPFMModel(input_size=2, hidden_size=64)

# Generate random input tensors
test_input = torch.rand(batch_size, num_agents, history_len, goal_dim)      # [B, A, H, 2]
test_neighbors = torch.rand(batch_size, num_agents, num_neighbors, goal_dim)  # [B, A, N, 2]

# Generate random goals for each agent
test_goal = torch.rand(batch_size, num_agents, goal_dim)  # [B, A, 2]

# Forward pass
output, coeff_mean, coeff_var = model(test_input, test_neighbors, test_goal)

# Output result shapes
print("Output shape:", output.shape)        # Expected: [32, 5, 12, 2]
print("Coeff mean:", coeff_mean.item())
print("Coeff var:", coeff_var.item())

history tensor([[[[2.6057e-01, 1.4399e-02],
          [7.6890e-01, 9.9851e-01],
          [5.5962e-01, 8.5919e-01],
          ...,
          [5.4017e-01, 3.6295e-01],
          [7.2765e-01, 2.7785e-02],
          [7.7101e-02, 5.4937e-01]],

         [[8.6097e-01, 9.6594e-02],
          [5.4383e-01, 3.6959e-01],
          [1.7857e-01, 2.1652e-01],
          ...,
          [5.9590e-02, 7.1777e-02],
          [1.1094e-01, 2.4175e-01],
          [5.1470e-01, 9.9765e-01]],

         [[4.7902e-01, 4.5142e-01],
          [8.7579e-01, 7.0143e-01],
          [6.0260e-01, 2.4974e-01],
          ...,
          [9.2126e-01, 7.3852e-01],
          [7.5591e-01, 4.8239e-02],
          [5.1443e-01, 3.1856e-01]],

         ...,

         [[9.8892e-02, 1.6657e-01],
          [9.4136e-01, 7.1780e-01],
          [4.6063e-01, 2.9086e-01],
          ...,
          [9.2423e-01, 6.7055e-01],
          [1.4595e-01, 8.9663e-01],
          [5.0335e-01, 6.7908e-01]],

         [[9.5884e-01, 1.8369e-01],
         

  return fn(*args, **kwargs)


KeyboardInterrupt: 

In [11]:


def train_pfm_model(
    data_path,
    model_save_path,
    model_class,
    dataset_class,
    collate_fn,
    batch_size=32,
    epochs=1,
    learning_rate=0.001,
    weight_decay=0.0,
    device=None
):
    """
    Train a PFM trajectory prediction model with speed constraints.

    Args:
        data_path (str): Path to dataset file.
        model_save_path (str): Where to save the trained model.
        model_class (class): Model class to instantiate (e.g., IntegratedMTAPFMModel).
        dataset_class (class): Dataset loader class (e.g., PFM_TrajectoryDataset).
        collate_fn (function): Collate function for DataLoader.
        batch_size (int): Batch size for DataLoader.
        epochs (int): Number of training epochs.
        learning_rate (float): Learning rate for optimizer.
        weight_decay (float): Weight decay for optimizer.
        device (torch.device): Device to train on. If None, auto-select CUDA if available.
    """
    device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.autograd.set_detect_anomaly(True)

    # === DATA LOADING ===
    dataset = dataset_class(data_path)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

    # === MODEL / OPTIMIZER / LOSS ===
    model = model_class().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    criterion = nn.MSELoss()

    print(f"\n🎯 Speed Constraints Enabled:")
    print(f"   Target Avg Speed: {model.target_avg_speed:.4f}")
    print(f"   Allowed range: [{model.min_speed:.4f}, {model.max_speed:.4f}]")
    print(f"   Tolerance: ±{model.speed_tolerance * 100:.1f}%\n")

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0
        epoch_gt_speeds, epoch_pred_speeds, epoch_hist_speeds, epoch_violations = [], [], [], []

        for batch_idx, (history, future, neighbors, goal) in enumerate(dataloader):
            history = history.to(device)
            future = future.to(device)
            neighbors = neighbors.to(device)
            goal = future[:, :, -1, :].clone()  # final target position

            optimizer.zero_grad()
            pred, coeff_mean, coeff_var = model(history, neighbors, goal)
            loss = criterion(pred, future)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            epoch_loss += loss.item()

            # === SPEED TRACKING ===
            with torch.no_grad():
                gt_speed = calculate_speed(future)
                pred_speed = calculate_speed(pred)
                hist_speed = calculate_speed(history)



                epoch_gt_speeds.append(gt_speed.item())
                epoch_pred_speeds.append(pred_speed.item())
                epoch_hist_speeds.append(hist_speed.item())

                violation_count = check_speed_violations(pred, history, model.min_speed, model.max_speed)
                epoch_violations.append(violation_count)

            if batch_idx % 50 == 0:
                print(f"[{batch_idx}] Loss: {loss:.4f} | k_att1 μ={coeff_mean:.2f}, σ²={coeff_var:.2f}")
                print(f"    Speeds - GT: {gt_speed:.3f}, Pred: {pred_speed:.3f}, Hist: {hist_speed:.3f}")
                print(f"    Speed Violations: {violation_count}")

        # === EPOCH SUMMARY ===
        avg_gt_speed = sum(epoch_gt_speeds) / len(epoch_gt_speeds)
        avg_pred_speed = sum(epoch_pred_speeds) / len(epoch_pred_speeds)
        avg_hist_speed = sum(epoch_hist_speeds) / len(epoch_hist_speeds)
        total_violations = sum(epoch_violations)

        print(f"\n=== EPOCH {epoch+1} SUMMARY ===")
        print(f"Avg Loss: {epoch_loss/len(dataloader):.4f}")
        print(f"Avg Speeds: Hist={avg_hist_speed:.4f}, GT={avg_gt_speed:.4f}, Pred={avg_pred_speed:.4f}")
        print(f"Speed Error: {abs(avg_gt_speed - avg_pred_speed):.4f}")
        print(f"Violations: {total_violations} | Constraint Compliance: {(1 - total_violations/(len(dataloader)*batch_size*5*12))*100:.2f}%")
        print("=" * 40)

        # === CLEAR CUDA CACHE AND COLLECT GARBAGE AFTER EACH EPOCH ===
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

    torch.autograd.set_detect_anomaly(False)

    # === SAVE MODEL ===
    print("\n💾 Saving model...")
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epochs,
        'loss': epoch_loss/len(dataloader),
        'speed_constraints': {
            'target_avg_speed': model.target_avg_speed,
            'min_speed': model.min_speed,
            'max_speed': model.max_speed,
            'tolerance': model.speed_tolerance
        },
        'final_avg_speeds': {
            'historical': avg_hist_speed,
            'ground_truth': avg_gt_speed,
            'predicted': avg_pred_speed
        }
    }, model_save_path)
    print(f"✅ Model saved at {model_save_path} with speed constraints!")


train_pfm_model(
        data_path="/content/crowds_zara02_test_cleaned.txt",
        model_save_path="/content/mta_pfm_trajectory_model_zara.pth",
        model_class=CheckpointedIntegratedMTAPFMModel,
        dataset_class=PFM_TrajectoryDataset,
        collate_fn=collate_fn,
        batch_size=32,
        epochs=5
    )


🎯 Speed Constraints Enabled:
   Target Avg Speed: 0.0270
   Allowed range: [0.0229, 0.0310]
   Tolerance: ±15.0%

[0] Loss: 17.3595 | k_att1 μ=1.00, σ²=0.00
    Speeds - GT: 8.935, Pred: 0.030, Hist: 8.964
    Speed Violations: 1605

=== EPOCH 1 SUMMARY ===
Avg Loss: 19.7292
Avg Speeds: Hist=9.0538, GT=9.0336, Pred=0.0310
Speed Error: 9.0026
Violations: 59019 | Constraint Compliance: 6.85%
[0] Loss: 17.3800 | k_att1 μ=0.99, σ²=0.00
    Speeds - GT: 8.996, Pred: 0.031, Hist: 9.038
    Speed Violations: 2049

=== EPOCH 2 SUMMARY ===
Avg Loss: 21.0473
Avg Speeds: Hist=9.1360, GT=9.1171, Pred=0.0310
Speed Error: 9.0861
Violations: 58408 | Constraint Compliance: 7.82%
[0] Loss: 17.2901 | k_att1 μ=0.99, σ²=0.00
    Speeds - GT: 9.118, Pred: 0.031, Hist: 9.100
    Speed Violations: 1966

=== EPOCH 3 SUMMARY ===
Avg Loss: 19.7568
Avg Speeds: Hist=9.0571, GT=9.0336, Pred=0.0310
Speed Error: 9.0026
Violations: 60147 | Constraint Compliance: 5.07%
[0] Loss: 20.8516 | k_att1 μ=0.99, σ²=0.00
    S

In [12]:
def test_with_metrics_mta_pfm(model_path, cleaned_txt_file):   #load different models accordingly
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    dataset = PFM_TrajectoryDataset(cleaned_txt_file)
    print(f"\n📦 Loaded {len(dataset)} frame samples for evaluation.")

    if len(dataset) == 0:
        print("❌ No valid samples available. Check your preprocessing.")
        return

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)

    model = CheckpointedIntegratedMTAPFMModel().to(device)
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    total_ade, total_fde, total_miss = 0.0, 0.0, 0.0
    count = 0

    with torch.no_grad():
        for history, future, neighbors, goal in dataloader:
            history = history.to(device)
            future = future.to(device)
            neighbors = neighbors.to(device)
            goal = goal.to(device)

            pred, _, _ = model(history, neighbors, goal)  # [1, A, 12, 2]
            pred = pred[0]    # [A, 12, 2]
            future = future[0]  # [A, T, 2]

            # Fix length mismatch
            min_len = min(pred.size(1), future.size(1))
            if pred.size(1) != future.size(1):
                print(f"⚠️ Truncating: pred_len={pred.size(1)}, future_len={future.size(1)} → using {min_len}")
            pred = pred[:, :min_len, :]
            future = future[:, :min_len, :]

            ade = torch.norm(pred - future, dim=-1).mean().item()
            fde = torch.norm(pred[:, -1, :] - future[:, -1, :], dim=-1).mean().item()
            miss = (torch.norm(pred[:, -1, :] - future[:, -1, :], dim=-1) > 2.0).float().mean().item()

            total_ade += ade
            total_fde += fde
            total_miss += miss
            count += 1

    if count == 0:
        print("❌ Evaluation aborted: No samples processed.")
        return

    print(f"\n📊 Evaluation Metrics on Test Dataset:")
    print(f"🔹 Average ADE:  {total_ade / count:.4f}")
    print(f"🔹 Average FDE:  {total_fde / count:.4f}")
    print(f"🔹 Miss Rate:    {total_miss / count:.4f} (threshold: 2m)")

In [13]:
test_with_metrics_mta_pfm(
        '/content/mta_pfm_trajectory_model_zara.pth',
        '/content/crowds_zara02_test_cleaned.txt'
    )


📦 Loaded 1025 frame samples for evaluation.


  coeff_var = coeff_stack.var(unbiased=False)



📊 Evaluation Metrics on Test Dataset:
🔹 Average ADE:  nan
🔹 Average FDE:  nan
🔹 Miss Rate:    nan (threshold: 2m)
