In [3]:
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
import time

# Dataset Class
class PFM_TrajectoryDataset(torch.utils.data.Dataset):
    def __init__(self, file_path, history_len=8, prediction_len=12):
        self.data = self.load_data(file_path)
        self.history_len = history_len
        self.prediction_len = prediction_len
        # Create a list of valid frame indices that have enough history and future data
        self.valid_frames = self._get_valid_frames()

    def load_data(self, file_path):
        data = {}
        with open(file_path, 'r') as file:
            for line in file:
                parts = line.strip().split(',')
                if len(parts) == 4:  # Ensure valid line
                    frame, agent, x, y = map(float, parts)
                    frame, agent = int(frame/10), int(agent)
                    if frame not in data:
                        data[frame] = {}
                    data[frame][agent] = torch.tensor([x, y], dtype=torch.float32)
        return data

    def _get_valid_frames(self):
        """Get frames that have sufficient history and future data"""
        all_frames = sorted(self.data.keys())
        valid_frames = []

        for frame in all_frames:
            # Check if we have enough history and future frames
            history_start = frame - self.history_len + 1
            future_end = frame + self.prediction_len

            # Ensure we have data for the required time range
            if history_start >= min(all_frames) and future_end <= max(all_frames):
                valid_frames.append(frame)

        return valid_frames

    def __len__(self):
        return len(self.valid_frames)

    def __getitem__(self, idx):
        frame = self.valid_frames[idx]

        # Get all agents present at this frame
        if frame not in self.data:
            # Return empty tensors if no data
            return (torch.zeros(0, self.history_len, 2),
                   torch.zeros(0, self.prediction_len, 2),
                   torch.zeros(0, 0, 2),
                   torch.zeros(0, 2))

        agents = list(self.data[frame].keys())
        num_agents = len(agents)

        if num_agents == 0:
            return (torch.zeros(0, self.history_len, 2),
                   torch.zeros(0, self.prediction_len, 2),
                   torch.zeros(0, 0, 2),
                   torch.zeros(0, 2))

        history = torch.zeros(num_agents, self.history_len, 2)
        future = torch.zeros(num_agents, self.prediction_len, 2)
        goals = torch.zeros(num_agents, 2)

        for i, agent in enumerate(agents):
            # Fill history (going backwards from current frame)
            for t in range(self.history_len):
                hist_frame = frame - (self.history_len - 1 - t)  # Fixed indexing
                if hist_frame in self.data and agent in self.data[hist_frame]:
                    history[i, t] = self.data[hist_frame][agent]
                # If no data available, position remains zero (padding)

            # Fill future (going forwards from next frame)
            for t in range(self.prediction_len):
                fut_frame = frame + t + 1  # Start from next frame
                if fut_frame in self.data and agent in self.data[fut_frame]:
                    future[i, t] = self.data[fut_frame][agent]
                # If no data available, position remains zero (padding)

            # Extract goal from the last timestep of future trajectory
            # Find the last non-zero position in future, or use the last timestep
            non_zero_mask = torch.any(future[i] != 0, dim=1)
            if non_zero_mask.any():
                last_valid_idx = torch.where(non_zero_mask)[0][-1]
                goals[i] = future[i, last_valid_idx]
            else:
                # If no future data, use current position as goal
                goals[i] = self.data[frame][agent]

        # Collect neighbors for each agent at the current frame
        neighbors_list = []
        for i, agent in enumerate(agents):
            # Get positions of all other agents at the current frame
            agent_neighbors = []
            for other_agent in self.data[frame]:
                if other_agent != agent:
                    agent_neighbors.append(self.data[frame][other_agent])

            if agent_neighbors:
                neighbors_tensor = torch.stack(agent_neighbors)
            else:
                # If no neighbors, create a dummy neighbor at origin
                neighbors_tensor = torch.zeros(1, 2)

            neighbors_list.append(neighbors_tensor)

        # Pad neighbors to have the same number for all agents
        if neighbors_list:
            max_neighbors = max(n.shape[0] for n in neighbors_list)
            padded_neighbors = torch.zeros(num_agents, max_neighbors, 2)

            for i, neighbor_tensor in enumerate(neighbors_list):
                padded_neighbors[i, :neighbor_tensor.shape[0]] = neighbor_tensor

            neighbors = padded_neighbors
        else:
            neighbors = torch.zeros(num_agents, 1, 2)

        return history, future, neighbors, goals

In [1]:
from torch.nn.utils.rnn import pad_sequence
import torch

def collate_fn(batch):
    """Collate function for batches that include goal"""
    max_agents = max(sample[0].shape[0] for sample in batch)
    # history
    max_neighbors = max(sample[2].shape[1] for sample in batch)  # neighbors
    hist_len = batch[0][0].shape[1]
    fut_len = batch[0][1].shape[1]

    padded_histories = []
    padded_futures = []
    padded_neighbors = []
    padded_goals = []

    for sample in batch:
        history, future, neighbors, goal = sample
        A, H, D = history.shape
        N = neighbors.shape[1]

        padded_hist = torch.zeros(max_agents, hist_len, D)
        padded_fut = torch.zeros(max_agents, fut_len, D)
        padded_neigh = torch.zeros(max_agents, max_neighbors, D)
        padded_goal = torch.zeros(max_agents, D)

        padded_hist[:A] = history
        padded_fut[:A] = future
        padded_neigh[:A, :N] = neighbors
        padded_goal[:A] = goal

        padded_histories.append(padded_hist)
        padded_futures.append(padded_fut)
        padded_neighbors.append(padded_neigh)
        padded_goals.append(padded_goal)

    return (
        torch.stack(padded_histories),   # [B, A, hist_len, D]
        torch.stack(padded_futures),     # [B, A, fut_len, D]
        torch.stack(padded_neighbors),   # [B, A, max_neighbors, D]
        torch.stack(padded_goals)        # [B, A, D]
    )

# FIXED: Added goals to test data (4th element in each tuple)
test_batch = [
    (torch.rand(3, 8, 2), torch.rand(3, 12, 2), torch.rand(3, 5, 2), torch.rand(3, 2)),  # 3 agents with goals
    (torch.rand(2, 8, 2), torch.rand(2, 12, 2), torch.rand(2, 3, 2), torch.rand(2, 2)),  # 2 agents with goals
    (torch.rand(4, 8, 2), torch.rand(4, 12, 2), torch.rand(4, 6, 2), torch.rand(4, 2))   # 4 agents with goals
]

# Now this should work correctly
hist, fut, neigh, goals = collate_fn(test_batch)  # Note: also need to unpack goals here
print(f"History shape: {hist.shape}")     # Will be [3, 4, 8, 2]
print(f"Future shape: {fut.shape}")       # Will be [3, 4, 12, 2]
print(f"Neighbors shape: {neigh.shape}")  # Will be [3, 4, 6, 2]
print(f"Goals shape: {goals.shape}")      # Will be [3, 4, 2]

History shape: torch.Size([3, 4, 8, 2])
Future shape: torch.Size([3, 4, 12, 2])
Neighbors shape: torch.Size([3, 4, 6, 2])
Goals shape: torch.Size([3, 4, 2])


In [2]:
import numpy as np

def compute_average_speed_from_file_ZARA(file_path, delimiter=None):
    """
    Compute average speed across all agents from trajectory data in file.

    Args:
        file_path (str): Path to the file containing data: frame agent x y
        delimiter (str or None): delimiter for np.loadtxt (default None detects spaces).
                                Use "," if your file is comma separated.

    Returns:
        float: average speed (units per frame step)
    """
    # Load dataset: shape (N,4) --> frame, agent, x, y
    data = np.loadtxt(file_path, delimiter=delimiter)

    # Sort by agent then by frame to ensure proper ordering
    data = data[np.lexsort((data[:,0], data[:,1]))]

    total_distance = 0.0
    total_transitions = 0

    # Process each agent separately
    for agent_id in np.unique(data[:,1]):
        agent_data = data[data[:,1] == agent_id]

        # Filter out zero positions (padded)
        mask = ~((agent_data[:,2] == 0) & (agent_data[:,3] == 0))
        agent_data = agent_data[mask]

        if len(agent_data) < 2:
            continue

        frames = agent_data[:,0]
        positions = agent_data[:, 2:4]

        # Calculate Euclidean distances between consecutive positions
        displacements = positions[1:] - positions[:-1]
        distances = np.linalg.norm(displacements, axis=1)

        # Calculate time differences (frame differences)
        delta_times = frames[1:] - frames[:-1]

        # Ignore zero or negative time intervals (avoid division issues)
        valid_mask = delta_times > 0

        total_distance += distances[valid_mask].sum()
        total_transitions += delta_times[valid_mask].sum()  # total time elapsed

    if total_transitions > 0:
        avg_speed = total_distance / total_transitions
    else:
        avg_speed = 0.0

    return avg_speed

file_path = "/content/crowds_zara02_test.txt"
average_speed = compute_average_speed_from_file_ZARA(file_path)
print(f"Average speed: {average_speed:.4f} units per frame")

FileNotFoundError: /content/crowds_zara02_test.txt not found.

In [4]:
def calculate_speed(trajectories):
    """
    Calculate average speed from trajectory data
    Args:
        trajectories: [B, A, T, 2] tensor where T is time steps
    Returns:
        average_speed: scalar tensor
    """
    B, A, T, D = trajectories.shape

    # Filter out padding (0,0) points
    non_zero_mask = ~((trajectories[:, :, :, 0] == 0) & (trajectories[:, :, :, 1] == 0))

    # Calculate displacement between consecutive time steps
    # trajectories[:, :, 1:] - trajectories[:, :, :-1] gives [B, A, T-1, 2]
    displacements = trajectories[:, :, 1:] - trajectories[:, :, :-1]

    # Calculate distances (speeds) for each time step
    distances = torch.norm(displacements, dim=-1)  # [B, A, T-1]

    # Apply mask to ignore padded transitions (both current and next point should be non-zero)
    valid_transitions = non_zero_mask[:, :, :-1] & non_zero_mask[:, :, 1:]

    # Only consider valid (non-padded) transitions
    valid_distances = distances * valid_transitions.float()

    # Calculate average speed (sum of valid distances / number of valid transitions)
    total_distance = valid_distances.sum()
    total_transitions = valid_transitions.sum()

    if total_transitions > 0:
        avg_speed = total_distance / total_transitions
    else:
        avg_speed = torch.tensor(0.0, device=trajectories.device)

    return avg_speed

In [5]:
# === HELPER FUNCTION TO CHECK VIOLATIONS ===
def check_speed_violations(predictions, history, min_speed, max_speed):
    """
    Count the number of speed constraint violations in predictions
    """
    B, A, T, D = predictions.shape
    last_pos = history[:, :, -1, :]  # [B, A, 2]
    current_pos = last_pos.clone()

    violations = 0

    for t in range(T):
        displacement = predictions[:, :, t] - current_pos
        speeds = torch.norm(displacement, dim=-1)

        # Count violations (excluding zero speeds from padding)
        non_zero_mask = speeds > 0
        too_fast = (speeds > max_speed) & non_zero_mask
        too_slow = (speeds < min_speed) & non_zero_mask

        violations += (too_fast | too_slow).sum().item()
        current_pos = predictions[:, :, t].clone()

    return violations

In [6]:
import torch
import torch.nn as nn

# === Trainable Potential Field Module ===
class PotentialField(nn.Module):
    def __init__(self, goal, num_agents=1000, k_init=1.0, repulsion_radius=0.5):
        super().__init__()
        self.register_buffer('goal', torch.tensor(goal, dtype=torch.float32))
        self.repulsion_radius = repulsion_radius
        self.coeff_embedding = nn.Embedding(num_agents, 3)
        self.coeff_embedding.weight.data.fill_(k_init)

    def forward(self, pos, predicted, neighbors, goal, coeffs):
        k1, k2, kr = coeffs[..., 0:1], coeffs[..., 1:2], coeffs[..., 2:3]

        # Attractive force towards goal
        Fg = k1 * (goal - pos)

        # Forward motion force (towards predicted step)
        Fp = k2 * (predicted[:, :, 0, :] - pos)

        # Repulsive force from neighbors
        diffs = pos.unsqueeze(2) - neighbors
        dists = torch.norm(diffs, dim=-1, keepdim=True) + 1e-6
        mask = (dists < self.repulsion_radius).float()
        Fr = (kr.unsqueeze(2) * diffs / dists.pow(2) * mask).sum(dim=2)

        return Fg + Fp + Fr, coeffs


# === Trainable PFM Model (baseline) ===
class PFMOnlyModel(nn.Module):
    def __init__(self, goal=(4.2, 4.2), target_avg_speed=4.087,
                 speed_tolerance=0.15, num_agents=1000, dt=0.1,
                 pred_len=12):
        super().__init__()
        self.pfm = PotentialField(goal, num_agents)
        if target_avg_speed is None:
            raise ValueError("target_avg_speed required")
        self.target_avg_speed = target_avg_speed
        self.speed_tolerance = speed_tolerance
        self.min_speed = target_avg_speed * (1 - speed_tolerance)
        self.max_speed = target_avg_speed * (1 + speed_tolerance)
        self.dt = dt
        self.pred_len = pred_len

    def apply_speed_constraints(self, preds, last_pos):
        B, A, T, _ = preds.shape
        out = preds.clone()
        cur = last_pos.clone()
        for t in range(T):
            disp = out[:, :, t] - cur
            sp = torch.norm(disp, dim=-1, keepdim=True)
            nz = sp > 0
            clipped = torch.clamp(sp, self.min_speed, self.max_speed)
            sp_final = torch.where(nz, clipped, sp)
            dir = disp / (sp + 1e-8)
            out[:, :, t] = cur + dir * sp_final
            cur = out[:, :, t].clone()
        return out

    def forward(self, history, neighbors, goal):
        B, A, H, _ = history.shape
        agent_ids = torch.arange(A).repeat(B, 1).to(history.device)
        coeffs = self.pfm.coeff_embedding(agent_ids)

        preds = torch.zeros(B, A, self.pred_len, 2, device=history.device)
        cur = history[:, :, -1, :].clone()
        coeff_list = []

        for t in range(self.pred_len):
            if t == 0 and H >= 2:
                vel = history[:, :, -1, :] - history[:, :, -2, :]
                pred_slice = (cur + vel).unsqueeze(2)
            elif t == 0:
                pred_slice = cur.unsqueeze(2)
            else:
                pred_slice = preds[:, :, t - 1:t, :].clone()

            forces, cstep = self.pfm(cur, pred_slice, neighbors, goal, coeffs)
            nextp = cur + forces * self.dt
            preds[:, :, t] = nextp
            cur = nextp.clone()
            coeff_list.append(cstep)

        preds = self.apply_speed_constraints(preds, history[:, :, -1, :])
        stack = torch.stack(coeff_list, dim=0)
        return preds, stack.mean(), stack.var(unbiased=False)

In [7]:
import torch
import torch.nn as nn

def train_pfm_model(
    data_path,
    model_save_path,
    model_class,
    dataset_class,
    collate_fn,
    batch_size=32,
    epochs=3,
    learning_rate=0.001,
    weight_decay=0.0,
    device=None
):
    """
    Train a PFM trajectory prediction model with speed constraints.

    Args:
        data_path (str): Path to dataset file.
        model_save_path (str): Where to save the trained model.
        model_class (class): Model class to instantiate (e.g., IntegratedMTAPFMModel).
        dataset_class (class): Dataset loader class (e.g., PFM_TrajectoryDataset).
        collate_fn (function): Collate function for DataLoader.
        batch_size (int): Batch size for DataLoader.
        epochs (int): Number of training epochs.
        learning_rate (float): Learning rate for optimizer.
        weight_decay (float): Weight decay for optimizer.
        device (torch.device): Device to train on. If None, auto-select CUDA if available.
    """

    device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.autograd.set_detect_anomaly(True)

    # === DATA LOADING ===
    dataset = dataset_class(data_path)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

    # === MODEL / OPTIMIZER / LOSS ===
    model = model_class().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    criterion = nn.MSELoss()

    print(f"\n🎯 Speed Constraints Enabled:")
    print(f"   Target Avg Speed: {model.target_avg_speed:.4f}")
    print(f"   Allowed range: [{model.min_speed:.4f}, {model.max_speed:.4f}]")
    print(f"   Tolerance: ±{model.speed_tolerance * 100:.1f}%\n")

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0
        epoch_gt_speeds, epoch_pred_speeds, epoch_hist_speeds, epoch_violations = [], [], [], []

        for batch_idx, (history, future, neighbors, goal) in enumerate(dataloader):
            history = history.to(device)
            future = future.to(device)
            neighbors = neighbors.to(device)
            goal = future[:, :, -1, :].clone()  # final target position

            optimizer.zero_grad()
            pred, coeff_mean, coeff_var = model(history, neighbors, goal)
            loss = criterion(pred, future)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            epoch_loss += loss.item()

            # === SPEED TRACKING ===
            with torch.no_grad():
                gt_speed = calculate_speed(future)
                pred_speed = calculate_speed(pred)
                hist_speed = calculate_speed(history)

                epoch_gt_speeds.append(gt_speed.item())
                epoch_pred_speeds.append(pred_speed.item())
                epoch_hist_speeds.append(hist_speed.item())

                violation_count = check_speed_violations(pred, history, model.min_speed, model.max_speed)
                epoch_violations.append(violation_count)

            if batch_idx % 50 == 0:
                print(f"[{batch_idx}] Loss: {loss:.4f} | k_att1 μ={coeff_mean:.2f}, σ²={coeff_var:.2f}")
                print(f"    Speeds - GT: {gt_speed:.3f}, Pred: {pred_speed:.3f}, Hist: {hist_speed:.3f}")
                print(f"    Speed Violations: {violation_count}")

        # === EPOCH SUMMARY ===
        avg_gt_speed = sum(epoch_gt_speeds) / len(epoch_gt_speeds)
        avg_pred_speed = sum(epoch_pred_speeds) / len(epoch_pred_speeds)
        avg_hist_speed = sum(epoch_hist_speeds) / len(epoch_hist_speeds)
        total_violations = sum(epoch_violations)

        print(f"\n=== EPOCH {epoch+1} SUMMARY ===")
        print(f"Avg Loss: {epoch_loss/len(dataloader):.4f}")
        print(f"Avg Speeds: Hist={avg_hist_speed:.4f}, GT={avg_gt_speed:.4f}, Pred={avg_pred_speed:.4f}")
        print(f"Speed Error: {abs(avg_gt_speed - avg_pred_speed):.4f}")
        print(f"Violations: {total_violations} | Constraint Compliance: {(1 - total_violations/(len(dataloader)*batch_size*5*12))*100:.2f}%")
        print("=" * 40)

    torch.autograd.set_detect_anomaly(False)

    # === SAVE MODEL ===
    print("\n💾 Saving model...")
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epochs,
        'loss': epoch_loss/len(dataloader),
        'speed_constraints': {
            'target_avg_speed': model.target_avg_speed,
            'min_speed': model.min_speed,
            'max_speed': model.max_speed,
            'tolerance': model.speed_tolerance
        },
        'final_avg_speeds': {
            'historical': avg_hist_speed,
            'ground_truth': avg_gt_speed,
            'predicted': avg_pred_speed
        }
    }, model_save_path)
    print(f"✅ Model saved at {model_save_path} with speed constraints!")

if __name__ == "__main__":
    train_pfm_model(
        data_path="/content/crowds_zara02_test_cleaned.txt",
        model_save_path="/content/pfm_trajectory_model_Zara.pth",
        model_class=PFMOnlyModel,
        dataset_class=PFM_TrajectoryDataset,
        collate_fn=collate_fn,
        batch_size=32,
        epochs=5
    )


🎯 Speed Constraints Enabled:
   Target Avg Speed: 4.0870
   Allowed range: [3.4739, 4.7000]
   Tolerance: ±15.0%

[0] Loss: 3.6614 | k_att1 μ=1.00, σ²=0.00
    Speeds - GT: 0.259, Pred: 3.483, Hist: 0.264
    Speed Violations: 1003

=== EPOCH 1 SUMMARY ===
Avg Loss: 3.6460
Avg Speeds: Hist=0.2791, GT=0.2748, Pred=3.4809
Speed Error: 3.2062
Violations: 34975 | Constraint Compliance: 43.07%
[0] Loss: 3.0150 | k_att1 μ=1.00, σ²=0.00
    Speeds - GT: 0.275, Pred: 3.481, Hist: 0.279
    Speed Violations: 1155

=== EPOCH 2 SUMMARY ===
Avg Loss: 3.5991
Avg Speeds: Hist=0.2775, GT=0.2736, Pred=3.4812
Speed Error: 3.2076
Violations: 35165 | Constraint Compliance: 42.77%
[0] Loss: 3.5499 | k_att1 μ=1.00, σ²=0.00
    Speeds - GT: 0.261, Pred: 3.480, Hist: 0.282
    Speed Violations: 1056

=== EPOCH 3 SUMMARY ===
Avg Loss: 3.6214
Avg Speeds: Hist=0.2786, GT=0.2736, Pred=3.4812
Speed Error: 3.2077
Violations: 34916 | Constraint Compliance: 43.17%
[0] Loss: 3.4775 | k_att1 μ=1.00, σ²=0.00
    Speed

## train pfm not learnable

In [None]:
import torch
import torch.nn as nn

def train_pfm_notlearnable_model(
    data_path,
    model_save_path,
    model_class,
    dataset_class,
    collate_fn,
    batch_size=32,
    epochs=3,
    learning_rate=0.001,
    weight_decay=0.0,
    device=None,
    model_kwargs=None
):
    """
    Train a PFM trajectory prediction model with speed constraints.

    Args:
        data_path (str): Path to dataset file.
        model_save_path (str): Where to save the trained model.
        model_class (class): Model class to instantiate (e.g., IntegratedMTAPFMModel).
        dataset_class (class): Dataset loader class (e.g., PFM_TrajectoryDataset).
        collate_fn (function): Collate function for DataLoader.
        batch_size (int): Batch size for DataLoader.
        epochs (int): Number of training epochs.
        learning_rate (float): Learning rate for optimizer.
        weight_decay (float): Weight decay for optimizer.
        device (torch.device): Device to train on. If None, auto-select CUDA if available.
    """

    device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.autograd.set_detect_anomaly(True)

    # === DATA LOADING ===
    dataset = dataset_class(data_path)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

    # === MODEL / OPTIMIZER / LOSS ===
    model = model_class().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    criterion = nn.MSELoss()

    print(f"\n🎯 Speed Constraints Enabled:")
    print(f"   Target Avg Speed: {model.target_avg_speed:.4f}")
    print(f"   Allowed range: [{model.min_speed:.4f}, {model.max_speed:.4f}]")
    print(f"   Tolerance: ±{model.speed_tolerance * 100:.1f}%\n")

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0
        epoch_gt_speeds, epoch_pred_speeds, epoch_hist_speeds, epoch_violations = [], [], [], []

        for batch_idx, (history, future, neighbors, goal) in enumerate(dataloader):
            history = history.to(device)
            future = future.to(device)
            neighbors = neighbors.to(device)
            goal = future[:, :, -1, :].clone()  # final target position

            optimizer.zero_grad()
            pred, coeff_mean, coeff_var = model(history, neighbors, goal)
            loss = criterion(pred, future)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            epoch_loss += loss.item()

            # === SPEED TRACKING ===
            with torch.no_grad():
                gt_speed = calculate_speed(future)
                pred_speed = calculate_speed(pred)
                hist_speed = calculate_speed(history)

                epoch_gt_speeds.append(gt_speed.item())
                epoch_pred_speeds.append(pred_speed.item())
                epoch_hist_speeds.append(hist_speed.item())

                violation_count = check_speed_violations(pred, history, model.min_speed, model.max_speed)
                epoch_violations.append(violation_count)

            if batch_idx % 50 == 0:
                print(f"[{batch_idx}] Loss: {loss:.4f} | k_att1 μ={coeff_mean:.2f}, σ²={coeff_var:.2f}")
                print(f"    Speeds - GT: {gt_speed:.3f}, Pred: {pred_speed:.3f}, Hist: {hist_speed:.3f}")
                print(f"    Speed Violations: {violation_count}")

        # === EPOCH SUMMARY ===
        avg_gt_speed = sum(epoch_gt_speeds) / len(epoch_gt_speeds)
        avg_pred_speed = sum(epoch_pred_speeds) / len(epoch_pred_speeds)
        avg_hist_speed = sum(epoch_hist_speeds) / len(epoch_hist_speeds)
        total_violations = sum(epoch_violations)

        print(f"\n=== EPOCH {epoch+1} SUMMARY ===")
        print(f"Avg Loss: {epoch_loss/len(dataloader):.4f}")
        print(f"Avg Speeds: Hist={avg_hist_speed:.4f}, GT={avg_gt_speed:.4f}, Pred={avg_pred_speed:.4f}")
        print(f"Speed Error: {abs(avg_gt_speed - avg_pred_speed):.4f}")
        print(f"Violations: {total_violations} | Constraint Compliance: {(1 - total_violations/(len(dataloader)*batch_size*5*12))*100:.2f}%")
        print("=" * 40)

    torch.autograd.set_detect_anomaly(False)

    # === SAVE MODEL ===
    print("\n💾 Saving model...")
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epochs,
        'loss': epoch_loss/len(dataloader),
        'speed_constraints': {
            'target_avg_speed': model.target_avg_speed,
            'min_speed': model.min_speed,
            'max_speed': model.max_speed,
            'tolerance': model.speed_tolerance
        },
        'final_avg_speeds': {
            'historical': avg_hist_speed,
            'ground_truth': avg_gt_speed,
            'predicted': avg_pred_speed
        }
    }, model_save_path)
    print(f"✅ Model saved at {model_save_path} with speed constraints!")

if __name__ == "__main__":
    train_pfm_model(
        data_path="/content/crowds_zara02_test_cleaned.txt",
        model_save_path="/content/pfm_trajectory_model_Zara.pth",
        model_class=PFMOnlyModel,
        dataset_class=PFM_TrajectoryDataset,
        collate_fn=collate_fn,
        batch_size=32,
        epochs=5
    )

## Test pfm learnable

In [8]:
import torch.nn.functional as F

def compute_ADE(pred, gt):
    # pred: [A, T, 2], gt: [A, T, 2]
    return torch.norm(pred - gt, dim=-1).mean().item()

def compute_FDE(pred, gt):
    # Final step error averaged over agents
    return torch.norm(pred[:, -1, :] - gt[:, -1, :], dim=-1).mean().item()

def compute_miss_rate(pred, gt, threshold=2.0):
    final_dist = torch.norm(pred[:, -1, :] - gt[:, -1, :], dim=-1)
    misses = (final_dist > threshold).float()
    return misses.mean().item()


In [9]:
def test_with_metrics_pfm(model_path, cleaned_txt_file, model_class):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    dataset = PFM_TrajectoryDataset(cleaned_txt_file)
    print(f"\n📦 Loaded {len(dataset)} frame samples for evaluation.")

    if len(dataset) == 0:
        print("❌ No valid samples available. Check your preprocessing.")
        return

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)

    # === LOAD MODEL ===
    model = model_class().to(device)
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    total_ade, total_fde, total_miss = 0.0, 0.0, 0.0
    count = 0

    with torch.no_grad():
        for history, future, neighbors, goal in dataloader:
            history = history.to(device)
            future = future.to(device)
            neighbors = neighbors.to(device)
            goal = goal.to(device)

            # Forward pass
            pred, _, _ = model(history, neighbors, goal)  # [1, A, T, 2]
            pred = pred[0]      # [A, T, 2]
            future = future[0]  # [A, T, 2]

            # Fix length mismatch
            min_len = min(pred.size(1), future.size(1))
            if pred.size(1) != future.size(1):
                print(f"⚠️ Truncating: pred_len={pred.size(1)}, future_len={future.size(1)} → using {min_len}")
            pred = pred[:, :min_len, :]
            future = future[:, :min_len, :]

            # === METRICS ===
            ade = compute_ADE(pred, future)
            fde = compute_FDE(pred, future)
            miss = compute_miss_rate(pred, future, threshold=2.0)

            total_ade += ade
            total_fde += fde
            total_miss += miss
            count += 1

    if count == 0:
        print("❌ Evaluation aborted: No samples processed.")
        return

    print(f"\n📊 Evaluation Metrics on Test Dataset:")
    print(f"🔹 Average ADE:  {total_ade / count:.4f}")
    print(f"🔹 Average FDE:  {total_fde / count:.4f}")
    print(f"🔹 Miss Rate:    {total_miss / count:.4f} (threshold: 2m)")


In [13]:
test_with_metrics_pfm(
    model_path="/content/pfm_trajectory_model_Zara.pth",
    cleaned_txt_file="/content/crowds_zara02_test_cleaned.txt",
    model_class=PFMOnlyModel
)


📦 Loaded 1009 frame samples for evaluation.

📊 Evaluation Metrics on Test Dataset:
🔹 Average ADE:  4.2697
🔹 Average FDE:  5.9976
🔹 Miss Rate:    0.6975 (threshold: 2m)
