In [1]:
import torch
from torch.utils.data import Dataset


class PFM_TrajectoryDataset_neighbours(Dataset):
    def __init__(self, file_path, history_len=8, prediction_len=12, max_neighbors=12):
        self.data = self.load_data(file_path)
        self.history_len = history_len
        self.prediction_len = prediction_len
        self.max_neighbors = max_neighbors
        self.valid_frames = self._get_valid_frames()

    def load_data(self, file_path):
        data = {}
        with open(file_path, "r") as file:
            for line in file:
                parts = line.strip().split(",")
                if len(parts) == 4:
                    frame, agent, x, y = map(float, parts)
                    frame, agent = int(frame), int(agent)
                    if frame not in data:
                        data[frame] = {}
                    data[frame][agent] = torch.tensor([x, y], dtype=torch.float32)
        return data

    def _get_valid_frames(self):
        all_frames = sorted(self.data.keys())
        valid_frames = []
        for frame in all_frames:
            history_start = frame - self.history_len + 1
            future_end = frame + self.prediction_len
            if history_start >= min(all_frames) and future_end <= max(all_frames):
                valid_frames.append(frame)
        return valid_frames

    def __len__(self):
        return len(self.valid_frames)

    def __getitem__(self, idx):
        frame = self.valid_frames[idx]
        if frame not in self.data:
            print("Dataset __getitem__: invalid frame, returning empty tensors")
            return (torch.zeros(0, self.history_len, 2),
                    torch.zeros(0, self.prediction_len, 2),
                    torch.zeros(0, self.max_neighbors, self.history_len, 2),
                    torch.zeros(0, 2))

        agents = list(self.data[frame].keys())
        num_agents = len(agents)
        history = torch.zeros(num_agents, self.history_len, 2)
        future = torch.zeros(num_agents, self.prediction_len, 2)
        goals = torch.zeros(num_agents, 2)

        for i, agent in enumerate(agents):
            for t in range(self.history_len):
                hist_frame = frame - (self.history_len - 1 - t)
                if hist_frame in self.data and agent in self.data[hist_frame]:
                    history[i, t] = self.data[hist_frame][agent]
            for t in range(self.prediction_len):
                fut_frame = frame + t + 1
                if fut_frame in self.data and agent in self.data[fut_frame]:
                    future[i, t] = self.data[fut_frame][agent]
            non_zero_mask = torch.any(future[i] != 0, dim=1)
            if non_zero_mask.any():
                last_valid_idx = torch.where(non_zero_mask)[0][-1]
                goals[i] = future[i, last_valid_idx]
            else:
                goals[i] = self.data[frame][agent]

        neighbor_histories = torch.zeros(num_agents, self.max_neighbors, self.history_len, 2)

        for i, agent in enumerate(agents):
            other_agents = [a for a in agents if a != agent][: self.max_neighbors]
            for n_idx, neighbor in enumerate(other_agents):
                for t in range(self.history_len):
                    hist_frame = frame - (self.history_len - 1 - t)
                    if hist_frame in self.data and neighbor in self.data[hist_frame]:
                        neighbor_histories[i, n_idx, t] = self.data[hist_frame][neighbor]

        mask = torch.ones(history.shape[0], dtype=torch.bool)
        for i in range(history.shape[0]):
            if not torch.any(history[i]): mask[i] = False
            if not torch.any(future[i]): mask[i] = False

        history = history[mask]
        future = future[mask]
        neighbor_histories = neighbor_histories[mask]
        goals = goals[mask]

        print("[DATASET] __getitem__ output shapes:",
              "history", history.shape,
              "future", future.shape,
              "neighbor_histories", neighbor_histories.shape,
              "goals", goals.shape)

        return history, future, neighbor_histories, goals

In [3]:
import torch

def collate_fn(batch, history_len=None, prediction_len=None, max_neighbors=None):
    histories, futures, neighbor_histories, goals = zip(*batch)
    batch_size = len(histories)
    max_agents = max([h.shape[0] for h in histories])
    max_neighbors = max_neighbors or neighbor_histories[0].shape[1]
    history_len = history_len or histories[0].shape[1]
    prediction_len = prediction_len or futures[0].shape[1]

    # Initialize batch tensors
    history_batch = torch.zeros(batch_size, max_agents, history_len, 2)
    future_batch = torch.zeros(batch_size, max_agents, prediction_len, 2)
    neighbor_histories_batch = torch.zeros(batch_size, max_agents, max_neighbors, history_len, 2)
    goals_batch = torch.zeros(batch_size, max_agents, 2)

    for i in range(batch_size):
        num_agents = histories[i].shape[0]
        history_batch[i, :num_agents] = histories[i]
        future_batch[i, :num_agents] = futures[i]
        neighbor_histories_batch[i, :num_agents] = neighbor_histories[i]
        goals_batch[i, :num_agents] = goals[i]

    print("[COLLATE] output shapes:",
          "history_batch", history_batch.shape,
          "future_batch", future_batch.shape,
          "neighbor_histories_batch", neighbor_histories_batch.shape,
          "goals_batch", goals_batch.shape)
    return history_batch, future_batch, neighbor_histories_batch, goals_batch

In [2]:
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint


class PotentialField(nn.Module):
    def __init__(self, num_agents=1000, k_init=1.0, repulsion_radius=0.5):
        super().__init__()
        self.repulsion_radius = repulsion_radius
        self.coeff_embedding = nn.Embedding(num_agents, 3)
        self.coeff_embedding.weight.data.fill_(k_init)

    def forward(self, pos, predicted, neighbors, goal, coeffs):
        print("[PFM] pos.shape:", pos.shape)
        print("[PFM] predicted.shape:", predicted.shape)
        print("[PFM] neighbors.shape:", neighbors.shape)
        print("[PFM] goal.shape:", goal.shape)
        print("[PFM] coeffs.shape:", coeffs.shape)

        k_att1 = coeffs[..., 0:1]
        k_att2 = coeffs[..., 1:2]
        k_rep = coeffs[..., 2:3]

        F_goal = k_att1 * (goal - pos)

        # Handle predicted dim cases for ego or neighbors
        if predicted.dim() == 3:
            F_pred = k_att2 * (predicted - pos)
        else:
            F_pred = k_att2 * (predicted[:, :, 0, :] - pos)

        # ROBUST repulsion handling - completely avoid computation when no neighbors
        if neighbors.size(1) == 0:
            F_rep = torch.zeros_like(pos)
            print("[PFM] No neighbors - F_rep set to zero with shape:", F_rep.shape)
        else:
            # Only compute when neighbors actually exist
            diffs = pos.unsqueeze(2) - neighbors  # [N, 1, 2] - [N, M, 2] = [N, M, 2]
            dists = torch.norm(diffs, dim=-1, keepdim=True) + 1e-6  # [N, M, 1]
            mask = (dists < self.repulsion_radius).float()  # [N, M, 1]

            # Ensure k_rep broadcasting matches diffs dimensions exactly
            k_rep_expanded = k_rep.unsqueeze(2)  # [N, 1, 1]
            repulsion_per_neighbor = k_rep_expanded * diffs / dists.pow(2) * mask  # [N, M, 2]
            F_rep = repulsion_per_neighbor.sum(dim=2)  # [N, 2]

            print("[PFM] Computed F_rep with shape:", F_rep.shape)

        total_force = F_goal + F_pred + F_rep
        print("[PFM] Total force shape:", total_force.shape)

        return total_force, coeffs


class CheckpointedIntegratedMTAPFMModel_neighbours(nn.Module):
    def __init__(self, input_size=2, hidden_size=64, num_layers=2,
                 target_avg_speed=4.087, speed_tolerance=0.15, num_agents=1000):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size

        self.input_embed = nn.Linear(input_size, hidden_size)
        self.encoder = nn.LSTM(hidden_size, hidden_size, num_layers, batch_first=True)
        self.decoder = nn.LSTM(hidden_size, hidden_size, num_layers, batch_first=True)

        self.postprocess = nn.Linear(input_size, input_size)
        self.coeff_projector_ego = nn.Linear(hidden_size, 3)
        self.coeff_projector_neighbors = nn.Linear(hidden_size, 3)

        self.output = nn.Linear(hidden_size, input_size)
        self.pfm = PotentialField(num_agents=num_agents)

        if target_avg_speed is None:
            raise ValueError("target_avg_speed must be computed from dataset and passed here.")

        self.target_avg_speed = target_avg_speed
        self.speed_tolerance = speed_tolerance
        self.min_speed = self.target_avg_speed * (1 - self.speed_tolerance)
        self.max_speed = self.target_avg_speed * (1 + self.speed_tolerance)
        self.dt = 0.1  # timestep used for potential field motion update

    def apply_speed_constraints(self, predictions, last_positions):
        B, A, T, D = predictions.shape
        constrained_preds = predictions.clone()
        current_pos = last_positions.clone()
        for t in range(T):
            displacement = predictions[:, :, t] - current_pos
            speeds = torch.norm(displacement, dim=-1, keepdim=True)
            non_zero_mask = speeds > 0
            clipped_speeds = torch.clamp(speeds, self.min_speed, self.max_speed)
            final_speeds = torch.where(non_zero_mask, clipped_speeds, speeds)
            direction = displacement / (speeds + 1e-8)
            clipped_displacement = direction * final_speeds
            constrained_preds[:, :, t] = current_pos + clipped_displacement
            current_pos = constrained_preds[:, :, t].clone()
        return constrained_preds

    def forward_encoder(self, x):
        return self.encoder(x)

    def forward_decoder_step(self, input_step, hx):
        return self.decoder(input_step, hx)

    def forward(self, history_neighbors, goal):
        """
        history_neighbors: [B, A, max_neighbors+1, H, 2]
           neighbors dim 0 = ego, others = neighbors
        goal: [B, A, 2]
        """
        print("[MODEL] forward: history_neighbors.shape =", history_neighbors.shape)
        print("[MODEL] forward: goal.shape =", goal.shape)

        B, A, max_nb_plus_1, H, D = history_neighbors.shape
        device = history_neighbors.device

        histories_flat = history_neighbors.reshape(B * A * max_nb_plus_1, H, D)
        print("[MODEL] histories_flat.shape after reshape:", histories_flat.shape)

        embedded = self.input_embed(histories_flat)
        print("[MODEL] embedded.shape after input_embed:", embedded.shape)

        _, (h_n, c_n) = checkpoint.checkpoint(self.forward_encoder, embedded)

        print("[MODEL] h_n.shape:", h_n.shape)
        print("[MODEL] c_n.shape:", c_n.shape)

        h_top = h_n[-1].view(B, A, max_nb_plus_1, self.hidden_size)
        print("[MODEL] h_top.shape after reshape:", h_top.shape)

        h_ego = h_top[:, :, 0, :]
        h_neighbors = h_top[:, :, 1:, :]

        print("[MODEL] h_ego.shape:", h_ego.shape)
        print("[MODEL] h_neighbors.shape:", h_neighbors.shape)

        coeffs_ego = self.coeff_projector_ego(h_ego)
        coeffs_neighbors = self.coeff_projector_neighbors(h_neighbors)

        print("[MODEL] coeffs_ego.shape:", coeffs_ego.shape)
        print("[MODEL] coeffs_neighbors.shape:", coeffs_neighbors.shape)

        pred_flat = torch.zeros(B * A, 12, D, device=device)
        h_n_decoder = h_n[:, :B * A, :].contiguous()
        c_n_decoder = c_n[:, :B * A, :].contiguous()
        hx = (h_n_decoder, c_n_decoder)

        for t in range(12):
            current_pred = pred_flat[:, t:t + 1].clone()
            print(f"[MODEL] Step {t} - current_pred.shape:", current_pred.shape)

            pred_embedded = self.input_embed(current_pred)
            print(f"[MODEL] Step {t} - pred_embedded.shape:", pred_embedded.shape)

            h, c = hx
            hx = (h.contiguous(), c.contiguous())
            print(f"[MODEL] Step {t} - hx shapes: h {h.shape}, c {c.shape}")

            out, hx = checkpoint.checkpoint(self.forward_decoder_step, pred_embedded, hx)
            print(f"[MODEL] Step {t} - decoder out.shape:", out.shape)

            step_output = self.output(out.squeeze(1))
            print(f"[MODEL] Step {t} - step_output.shape:", step_output.shape)

            pred_flat[:, t] = self.postprocess(step_output)
            print(f"[MODEL] Step {t} - pred_flat state updated")

        predictions = pred_flat.view(B, A, 12, D)
        print("[MODEL] predictions.shape after view:", predictions.shape)
        adjusted_preds = torch.zeros_like(predictions)
        print("[MODEL] initialized adjusted_preds.shape:", adjusted_preds.shape)

        current_pos_ego = history_neighbors[:, :, 0, -1, :].clone()
        current_pos_neighbors = history_neighbors[:, :, 1:, -1, :].clone()

        print("[MODEL] current_pos_ego.shape:", current_pos_ego.shape)
        print("[MODEL] current_pos_neighbors.shape:", current_pos_neighbors.shape)

        coeff_list = []

        for t in range(12):
            pred_slice_ego = predictions[:, :, t:t + 1].clone()
            print(f"[MODEL] neighbor loop step {t} - pred_slice_ego.shape:", pred_slice_ego.shape)

            forces_ego, coeff_step_ego = self.pfm(current_pos_ego, pred_slice_ego,
                                                 current_pos_neighbors, goal, coeffs_ego)
            print(f"[MODEL] neighbor loop step {t} - forces_ego.shape:", forces_ego.shape)
            print(f"[MODEL] neighbor loop step {t} - coeff_step_ego.shape:", coeff_step_ego.shape)

            if t == 0:
                adjusted_preds[:, :, t] = current_pos_ego + forces_ego
            else:
                adjusted_preds[:, :, t] = adjusted_preds[:, :, t - 1] + forces_ego

            current_pos_ego = adjusted_preds[:, :, t].clone()
            print(f"[MODEL] neighbor loop step {t} - updated current_pos_ego.shape:", current_pos_ego.shape)

            B_, A_, N_ = current_pos_neighbors.shape[:3]
            cur_neighbors_flat = current_pos_neighbors.reshape(B_ * A_ * N_, 2)
            pred_slice_neighbors = cur_neighbors_flat.unsqueeze(1)
            coeffs_neighbors_flat = coeffs_neighbors.reshape(B_ * A_ * N_, 3)

            print(f"[MODEL] neighbor loop step {t} - current_pos_neighbors.shape:", current_pos_neighbors.shape)
            print(f"[MODEL] neighbor loop step {t} - cur_neighbors_flat.shape:", cur_neighbors_flat.shape)
            print(f"[MODEL] neighbor loop step {t} - pred_slice_neighbors.shape:", pred_slice_neighbors.shape)
            print(f"[MODEL] neighbor loop step {t} - coeffs_neighbors_flat.shape:", coeffs_neighbors_flat.shape)

            neighbors_neighbors = torch.empty(cur_neighbors_flat.shape[0], 0, 2, device=device)
            goal_expanded_neighbors = goal.unsqueeze(2).expand(B_, A_, N_, 2).contiguous().reshape(B_ * A_ * N_, 2)

            print(f"[MODEL] neighbor loop step {t} - neighbors_neighbors.shape:", neighbors_neighbors.shape)
            print(f"[MODEL] neighbor loop step {t} - goal_expanded_neighbors.shape:", goal_expanded_neighbors.shape)

            forces_neighbors, _ = self.pfm(
                pos=cur_neighbors_flat,
                predicted=pred_slice_neighbors,
                neighbors=neighbors_neighbors,
                goal=goal_expanded_neighbors,
                coeffs=coeffs_neighbors_flat
            )

            print(f"[MODEL] neighbor loop step {t} - forces_neighbors.shape after PFM:", forces_neighbors.shape)
            print(f"[MODEL] neighbor loop step {t} - cur_neighbors_flat.numel():", cur_neighbors_flat.numel())
            print(f"[MODEL] neighbor loop step {t} - forces_neighbors.numel():", forces_neighbors.numel())

            # Force correct shape if PFM returns unexpected dimensions
            expected_shape = cur_neighbors_flat.shape  # [N, 2]
            if forces_neighbors.shape != expected_shape:
                print(f"[FIX] forces_neighbors shape mismatch! Expected: {expected_shape}, Got: {forces_neighbors.shape}")
                forces_neighbors = torch.zeros_like(cur_neighbors_flat)
                print(f"[FIX] Forced forces_neighbors to correct shape: {forces_neighbors.shape}")

            nextp_neighbors = cur_neighbors_flat + forces_neighbors * self.dt

            print(f"[MODEL] neighbor loop step {t} - nextp_neighbors.shape before view:", nextp_neighbors.shape)
            print(f"[MODEL] neighbor loop step {t} - expected shape for view: ({B_}, {A_}, {N_}, 2)")

            try:
                current_pos_neighbors = nextp_neighbors.view(B_, A_, N_, 2).clone().detach()
                print(f"[MODEL] neighbor loop step {t} - successfully reshaped current_pos_neighbors")
            except Exception as e:
                print(f"[ERROR] neighbor loop step {t} - Failed to view nextp_neighbors:", e)
                raise

            coeff_list.append(coeff_step_ego)

        last_known_pos_ego = history_neighbors[:, :, 0, -1, :]
        constrained_preds = self.apply_speed_constraints(adjusted_preds, last_known_pos_ego)

        coeff_stack = torch.stack(coeff_list, dim=0)
        coeff_mean = coeff_stack.mean()
        coeff_var = coeff_stack.var(unbiased=False)

        print("[MODEL] forward completed: returning predictions and coeff stats")

        return constrained_preds, coeff_mean, coeff_var

In [4]:
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint


class PotentialField(nn.Module):
    def __init__(self, num_agents=1000, k_init=1.0, repulsion_radius=0.5):
        super().__init__()
        self.repulsion_radius = repulsion_radius
        self.coeff_embedding = nn.Embedding(num_agents, 3)
        self.coeff_embedding.weight.data.fill_(k_init)

    def forward(self, pos, predicted, neighbors, goal, coeffs):
        print("[PFM] pos.shape:", pos.shape)
        print("[PFM] predicted.shape:", predicted.shape)
        print("[PFM] neighbors.shape:", neighbors.shape)
        print("[PFM] goal.shape:", goal.shape)
        print("[PFM] coeffs.shape:", coeffs.shape)

        k_att1 = coeffs[..., 0:1]
        k_att2 = coeffs[..., 1:2]
        k_rep = coeffs[..., 2:3]

        F_goal = k_att1 * (goal - pos)

        # Handle predicted dim cases for ego or neighbors
        if predicted.dim() == 3:
            F_pred = k_att2 * (predicted - pos)
        else:
            F_pred = k_att2 * (predicted[:, :, 0, :] - pos)

        # ROBUST repulsion handling - completely avoid computation when no neighbors
        if neighbors.size(1) == 0:
            F_rep = torch.zeros_like(pos)
            print("[PFM] No neighbors - F_rep set to zero with shape:", F_rep.shape)
        else:
            # Only compute when neighbors actually exist
            diffs = pos.unsqueeze(2) - neighbors  # [N, 1, 2] - [N, M, 2] = [N, M, 2]
            dists = torch.norm(diffs, dim=-1, keepdim=True) + 1e-6  # [N, M, 1]
            mask = (dists < self.repulsion_radius).float()  # [N, M, 1]

            # Ensure k_rep broadcasting matches diffs dimensions exactly
            k_rep_expanded = k_rep.unsqueeze(2)  # [N, 1, 1]
            repulsion_per_neighbor = k_rep_expanded * diffs / dists.pow(2) * mask  # [N, M, 2]
            F_rep = repulsion_per_neighbor.sum(dim=2)  # [N, 2]

            print("[PFM] Computed F_rep with shape:", F_rep.shape)

        total_force = F_goal + F_pred + F_rep
        print("[PFM] Total force shape:", total_force.shape)

        return total_force, coeffs


class CheckpointedIntegratedMTAPFMModel_neighbours(nn.Module):
    def __init__(self, input_size=2, hidden_size=64, num_layers=2,
                 target_avg_speed=4.087, speed_tolerance=0.15, num_agents=1000):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size

        self.input_embed = nn.Linear(input_size, hidden_size)
        self.encoder = nn.LSTM(hidden_size, hidden_size, num_layers, batch_first=True)
        self.decoder = nn.LSTM(hidden_size, hidden_size, num_layers, batch_first=True)

        self.postprocess = nn.Linear(input_size, input_size)
        self.coeff_projector_ego = nn.Linear(hidden_size, 3)
        self.coeff_projector_neighbors = nn.Linear(hidden_size, 3)

        self.output = nn.Linear(hidden_size, input_size)
        self.pfm = PotentialField(num_agents=num_agents)

        if target_avg_speed is None:
            raise ValueError("target_avg_speed must be computed from dataset and passed here.")

        self.target_avg_speed = target_avg_speed
        self.speed_tolerance = speed_tolerance
        self.min_speed = self.target_avg_speed * (1 - self.speed_tolerance)
        self.max_speed = self.target_avg_speed * (1 + self.speed_tolerance)
        self.dt = 0.1  # timestep used for potential field motion update

    def apply_speed_constraints(self, predictions, last_positions):
        B, A, T, D = predictions.shape
        constrained_preds = predictions.clone()
        current_pos = last_positions.clone()
        for t in range(T):
            displacement = predictions[:, :, t] - current_pos
            speeds = torch.norm(displacement, dim=-1, keepdim=True)
            non_zero_mask = speeds > 0
            clipped_speeds = torch.clamp(speeds, self.min_speed, self.max_speed)
            final_speeds = torch.where(non_zero_mask, clipped_speeds, speeds)
            direction = displacement / (speeds + 1e-8)
            clipped_displacement = direction * final_speeds
            constrained_preds[:, :, t] = current_pos + clipped_displacement
            current_pos = constrained_preds[:, :, t].clone()
        return constrained_preds

    def forward_encoder(self, x):
        return self.encoder(x)

    def forward_decoder_step(self, input_step, hx):
        return self.decoder(input_step, hx)

    def forward(self, history_neighbors, goal):
        """
        history_neighbors: [B, A, max_neighbors+1, H, 2]
           neighbors dim 0 = ego, others = neighbors
        goal: [B, A, 2]
        """
        print("[MODEL] forward: history_neighbors.shape =", history_neighbors.shape)
        print("[MODEL] forward: goal.shape =", goal.shape)

        B, A, max_nb_plus_1, H, D = history_neighbors.shape
        device = history_neighbors.device

        histories_flat = history_neighbors.reshape(B * A * max_nb_plus_1, H, D)
        print("[MODEL] histories_flat.shape after reshape:", histories_flat.shape)

        embedded = self.input_embed(histories_flat)
        print("[MODEL] embedded.shape after input_embed:", embedded.shape)

        _, (h_n, c_n) = checkpoint.checkpoint(self.forward_encoder, embedded)

        print("[MODEL] h_n.shape:", h_n.shape)
        print("[MODEL] c_n.shape:", c_n.shape)

        h_top = h_n[-1].view(B, A, max_nb_plus_1, self.hidden_size)
        print("[MODEL] h_top.shape after reshape:", h_top.shape)

        h_ego = h_top[:, :, 0, :]
        h_neighbors = h_top[:, :, 1:, :]

        print("[MODEL] h_ego.shape:", h_ego.shape)
        print("[MODEL] h_neighbors.shape:", h_neighbors.shape)

        coeffs_ego = self.coeff_projector_ego(h_ego)
        coeffs_neighbors = self.coeff_projector_neighbors(h_neighbors)

        print("[MODEL] coeffs_ego.shape:", coeffs_ego.shape)
        print("[MODEL] coeffs_neighbors.shape:", coeffs_neighbors.shape)

        pred_flat = torch.zeros(B * A, 12, D, device=device)
        h_n_decoder = h_n[:, :B * A, :].contiguous()
        c_n_decoder = c_n[:, :B * A, :].contiguous()
        hx = (h_n_decoder, c_n_decoder)

        for t in range(12):
            current_pred = pred_flat[:, t:t + 1].clone()
            print(f"[MODEL] Step {t} - current_pred.shape:", current_pred.shape)

            pred_embedded = self.input_embed(current_pred)
            print(f"[MODEL] Step {t} - pred_embedded.shape:", pred_embedded.shape)

            h, c = hx
            hx = (h.contiguous(), c.contiguous())
            print(f"[MODEL] Step {t} - hx shapes: h {h.shape}, c {c.shape}")

            out, hx = checkpoint.checkpoint(self.forward_decoder_step, pred_embedded, hx)
            print(f"[MODEL] Step {t} - decoder out.shape:", out.shape)

            step_output = self.output(out.squeeze(1))
            print(f"[MODEL] Step {t} - step_output.shape:", step_output.shape)

            pred_flat[:, t] = self.postprocess(step_output)
            print(f"[MODEL] Step {t} - pred_flat state updated")

        predictions = pred_flat.view(B, A, 12, D)
        print("[MODEL] predictions.shape after view:", predictions.shape)
        adjusted_preds = torch.zeros_like(predictions)
        print("[MODEL] initialized adjusted_preds.shape:", adjusted_preds.shape)

        current_pos_ego = history_neighbors[:, :, 0, -1, :].clone()
        current_pos_neighbors = history_neighbors[:, :, 1:, -1, :].clone()

        print("[MODEL] current_pos_ego.shape:", current_pos_ego.shape)
        print("[MODEL] current_pos_neighbors.shape:", current_pos_neighbors.shape)

        coeff_list = []

        for t in range(12):
            pred_slice_ego = predictions[:, :, t:t + 1].clone()
            print(f"[MODEL] neighbor loop step {t} - pred_slice_ego.shape:", pred_slice_ego.shape)

            forces_ego, coeff_step_ego = self.pfm(current_pos_ego, pred_slice_ego,
                                                 current_pos_neighbors, goal, coeffs_ego)
            print(f"[MODEL] neighbor loop step {t} - forces_ego.shape:", forces_ego.shape)
            print(f"[MODEL] neighbor loop step {t} - coeff_step_ego.shape:", coeff_step_ego.shape)

            if t == 0:
                adjusted_preds[:, :, t] = current_pos_ego + forces_ego
            else:
                adjusted_preds[:, :, t] = adjusted_preds[:, :, t - 1] + forces_ego

            current_pos_ego = adjusted_preds[:, :, t].clone()
            print(f"[MODEL] neighbor loop step {t} - updated current_pos_ego.shape:", current_pos_ego.shape)

            B_, A_, N_ = current_pos_neighbors.shape[:3]
            cur_neighbors_flat = current_pos_neighbors.reshape(B_ * A_ * N_, 2)
            pred_slice_neighbors = cur_neighbors_flat.unsqueeze(1)
            coeffs_neighbors_flat = coeffs_neighbors.reshape(B_ * A_ * N_, 3)

            print(f"[MODEL] neighbor loop step {t} - current_pos_neighbors.shape:", current_pos_neighbors.shape)
            print(f"[MODEL] neighbor loop step {t} - cur_neighbors_flat.shape:", cur_neighbors_flat.shape)
            print(f"[MODEL] neighbor loop step {t} - pred_slice_neighbors.shape:", pred_slice_neighbors.shape)
            print(f"[MODEL] neighbor loop step {t} - coeffs_neighbors_flat.shape:", coeffs_neighbors_flat.shape)

            neighbors_neighbors = torch.empty(cur_neighbors_flat.shape[0], 0, 2, device=device)
            goal_expanded_neighbors = goal.unsqueeze(2).expand(B_, A_, N_, 2).contiguous().reshape(B_ * A_ * N_, 2)

            print(f"[MODEL] neighbor loop step {t} - neighbors_neighbors.shape:", neighbors_neighbors.shape)
            print(f"[MODEL] neighbor loop step {t} - goal_expanded_neighbors.shape:", goal_expanded_neighbors.shape)

            forces_neighbors, _ = self.pfm(
                pos=cur_neighbors_flat,
                predicted=pred_slice_neighbors,
                neighbors=neighbors_neighbors,
                goal=goal_expanded_neighbors,
                coeffs=coeffs_neighbors_flat
            )

            print(f"[MODEL] neighbor loop step {t} - forces_neighbors.shape after PFM:", forces_neighbors.shape)
            print(f"[MODEL] neighbor loop step {t} - cur_neighbors_flat.numel():", cur_neighbors_flat.numel())
            print(f"[MODEL] neighbor loop step {t} - forces_neighbors.numel():", forces_neighbors.numel())

            # Force correct shape if PFM returns unexpected dimensions
            expected_shape = cur_neighbors_flat.shape  # [N, 2]
            if forces_neighbors.shape != expected_shape:
                print(f"[FIX] forces_neighbors shape mismatch! Expected: {expected_shape}, Got: {forces_neighbors.shape}")
                forces_neighbors = torch.zeros_like(cur_neighbors_flat)
                print(f"[FIX] Forced forces_neighbors to correct shape: {forces_neighbors.shape}")

            nextp_neighbors = cur_neighbors_flat + forces_neighbors * self.dt

            print(f"[MODEL] neighbor loop step {t} - nextp_neighbors.shape before view:", nextp_neighbors.shape)
            print(f"[MODEL] neighbor loop step {t} - expected shape for view: ({B_}, {A_}, {N_}, 2)")

            try:
                current_pos_neighbors = nextp_neighbors.view(B_, A_, N_, 2).clone().detach()
                print(f"[MODEL] neighbor loop step {t} - successfully reshaped current_pos_neighbors")
            except Exception as e:
                print(f"[ERROR] neighbor loop step {t} - Failed to view nextp_neighbors:", e)
                raise

            coeff_list.append(coeff_step_ego)

        last_known_pos_ego = history_neighbors[:, :, 0, -1, :]
        constrained_preds = self.apply_speed_constraints(adjusted_preds, last_known_pos_ego)

        coeff_stack = torch.stack(coeff_list, dim=0)
        coeff_mean = coeff_stack.mean()
        coeff_var = coeff_stack.var(unbiased=False)

        print("[MODEL] forward completed: returning predictions and coeff stats")

        return constrained_preds, coeff_mean, coeff_var

In [7]:
import torch
import torch.nn as nn

def calculate_speed(trajectory):
    """Calculate average speed from trajectory tensor [B, A, T, 2]"""
    diffs = trajectory[:, :, 1:, :] - trajectory[:, :, :-1, :]  # Frame-to-frame differences
    speeds = torch.norm(diffs, dim=-1)  # [B, A, T-1]
    valid_mask = speeds > 0
    if valid_mask.sum() > 0:
        avg_speed = speeds[valid_mask].mean()
    else:
        avg_speed = torch.tensor(0.0)
    return avg_speed

def check_speed_violations(predictions, history, min_speed, max_speed):
    """Count speed violations in predicted trajectories"""
    # Get last history position
    last_pos = history[:, :, -1:, :]  # [B, A, 1, 2]

    # Combine last history with predictions for speed calculation
    full_traj = torch.cat([last_pos, predictions], dim=2)  # [B, A, T+1, 2]

    # Calculate speeds
    diffs = full_traj[:, :, 1:, :] - full_traj[:, :, :-1, :]
    speeds = torch.norm(diffs, dim=-1)  # [B, A, T]

    # Count violations
    violations = ((speeds < min_speed) | (speeds > max_speed)).sum().item()
    return violations

In [9]:
import torch
import gc
from torch import nn
from torch.utils.data import DataLoader, random_split
# from datasets.pfm_trajectory_dataset_neighbours import PFM_TrajectoryDataset_neighbours
# from utils.collate_mta_pfm_neighbours import collate_fn
# from utils.speed_utils import calculate_speed, check_speed_violations
# from models.mta_pfm_model_neighbours import CheckpointedIntegratedMTAPFMModel_neighbours


def train_mta_pfm_model(
    data_path,
    model_save_path,
    model_class,
    dataset_class,
    collate_fn,
    batch_size=1,  # lowered for memory stability
    epochs=80,
    learning_rate=0.001,
    weight_decay=0.0,
    patience=7,
    accumulation_steps=4,
    device=None,
):
    device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.autograd.set_detect_anomaly(True)

    dataset = dataset_class(data_path)
    val_size = int(0.2 * len(dataset))
    train_size = len(dataset) - val_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=lambda b: collate_fn(
            b,
            history_len=dataset.history_len,
            prediction_len=dataset.prediction_len,
            max_neighbors=dataset.max_neighbors,
        ),
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=lambda b: collate_fn(
            b,
            history_len=dataset.history_len,
            prediction_len=dataset.prediction_len,
            max_neighbors=dataset.max_neighbors,
        ),
    )

    model = model_class().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=3)
    criterion = nn.MSELoss()

    print(f"\n🎯 Speed Constraints Enabled:")
    print(f"   Target Avg Speed: {model.target_avg_speed:.4f}")
    print(f"   Allowed range: [{model.min_speed:.4f}, {model.max_speed:.4f}]")
    print(f"   Tolerance: ±{model.speed_tolerance * 100:.1f}%\n")

    best_val_loss = float("inf")
    patience_counter = 0

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0
        epoch_gt_speeds, epoch_pred_speeds, epoch_hist_speeds, epoch_violations = [], [], [], []

        optimizer.zero_grad()

        for batch_idx, (history, future, neighbor_histories, _) in enumerate(train_loader):
            history = history.to(device)
            future = future.to(device)
            neighbor_histories = neighbor_histories.to(device)

            # Dynamic goal from last timestep in future trajectory
            goal = future[:, :, -1, :].clone()

            # Build input tensor [B, A, max_neighbors+1, H, 2]
            if neighbor_histories.dim() == 4:  # Ensure last dim=2 for coords
                neighbor_histories = neighbor_histories.unsqueeze(-1).repeat(1, 1, 1, 1, 2)
            history_neighbors = torch.cat([history.unsqueeze(2), neighbor_histories], dim=2)

            pred, coeff_mean, coeff_var = model(history_neighbors, goal)
            loss = criterion(pred, future)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()

            epoch_loss += loss.item()

            with torch.no_grad():
                gt_speed = calculate_speed(future)
                pred_speed = calculate_speed(pred)
                hist_speed = calculate_speed(history)

                epoch_gt_speeds.append(gt_speed.item())
                epoch_pred_speeds.append(pred_speed.item())
                epoch_hist_speeds.append(hist_speed.item())
                violation_count = check_speed_violations(pred, history, model.min_speed, model.max_speed)
                epoch_violations.append(violation_count)

            if batch_idx % 50 == 0:
                print(
                    f"[{batch_idx}] Loss: {loss:.4f} | k_att1 μ={coeff_mean:.2f}, σ²={coeff_var:.2f}"
                )
                print(
                    f"    Speeds - GT: {gt_speed:.3f}, Pred: {pred_speed:.3f}, Hist: {hist_speed:.3f}"
                )
                print(f"    Speed Violations: {violation_count}")

        avg_train_loss = epoch_loss / len(train_loader)

        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for history, future, neighbors, _ in val_loader:
                history = history.to(device)
                future = future.to(device)
                neighbors = neighbors.to(device)
                goal = future[:, :, -1, :].clone()

                pred, _, _ = model(history, goal)
                loss = criterion(pred, future)
                val_loss += loss.item()
        avg_val_loss = val_loss / len(val_loader)

        print(
            f"\n=== Epoch {epoch + 1}/{epochs} SUMMARY ==="
            f"\nTrain Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}\n"
        )
        avg_gt_speed = sum(epoch_gt_speeds) / len(epoch_gt_speeds)
        avg_pred_speed = sum(epoch_pred_speeds) / len(epoch_pred_speeds)
        avg_hist_speed = sum(epoch_hist_speeds) / len(epoch_hist_speeds)
        total_violations = sum(epoch_violations)
        samples = len(train_loader) * batch_size * 5 * 12
        constraint_compliance = (1 - total_violations / samples) * 100 if samples > 0 else 0

        print(f"Avg Speeds: Hist={avg_hist_speed:.4f}, GT={avg_gt_speed:.4f}, Pred={avg_pred_speed:.4f}")
        print(f"Speed Error: {abs(avg_gt_speed - avg_pred_speed):.4f}")
        print(f"Violations: {total_violations} | Constraint Compliance: {constraint_compliance:.2f}%")
        print("=" * 40)

        scheduler.step(avg_val_loss)

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0
            torch.save({"model_state_dict": model.state_dict()}, model_save_path)
            print(f"✅ New best model saved at epoch {epoch + 1} with Val Loss {avg_val_loss:.4f}")
        else:
            patience_counter += 1
            print(f"⚠️ EarlyStopping counter: {patience_counter}/{patience}")

        if patience_counter >= patience:
            print("⏹ Early stopping triggered.")
            break

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

    torch.autograd.set_detect_anomaly(False)
    print(f"Training completed. Best model saved to {model_save_path}")


if __name__ == "__main__":
    train_mta_pfm_model(
        data_path="/content/combined_annotations.csv",
        model_save_path="/content/mta_pfm_trajectory_model_own.pth",
        model_class=CheckpointedIntegratedMTAPFMModel_neighbours,
        dataset_class=PFM_TrajectoryDataset_neighbours,
        collate_fn=collate_fn,
        batch_size=1,
        epochs=80,
        patience=7,
        accumulation_steps=4,
    )



🎯 Speed Constraints Enabled:
   Target Avg Speed: 4.0870
   Allowed range: [3.4739, 4.7000]
   Tolerance: ±15.0%

[DATASET] __getitem__ output shapes: history torch.Size([809, 8, 2]) future torch.Size([809, 12, 2]) neighbor_histories torch.Size([809, 12, 8, 2]) goals torch.Size([809, 2])
[COLLATE] output shapes: history_batch torch.Size([1, 809, 8, 2]) future_batch torch.Size([1, 809, 12, 2]) neighbor_histories_batch torch.Size([1, 809, 12, 8, 2]) goals_batch torch.Size([1, 809, 2])
[MODEL] forward: history_neighbors.shape = torch.Size([1, 809, 13, 8, 2])
[MODEL] forward: goal.shape = torch.Size([1, 809, 2])
[MODEL] histories_flat.shape after reshape: torch.Size([10517, 8, 2])
[MODEL] embedded.shape after input_embed: torch.Size([10517, 8, 64])


  return fn(*args, **kwargs)


[MODEL] h_n.shape: torch.Size([2, 10517, 64])
[MODEL] c_n.shape: torch.Size([2, 10517, 64])
[MODEL] h_top.shape after reshape: torch.Size([1, 809, 13, 64])
[MODEL] h_ego.shape: torch.Size([1, 809, 64])
[MODEL] h_neighbors.shape: torch.Size([1, 809, 12, 64])
[MODEL] coeffs_ego.shape: torch.Size([1, 809, 3])
[MODEL] coeffs_neighbors.shape: torch.Size([1, 809, 12, 3])
[MODEL] Step 0 - current_pred.shape: torch.Size([809, 1, 2])
[MODEL] Step 0 - pred_embedded.shape: torch.Size([809, 1, 64])
[MODEL] Step 0 - hx shapes: h torch.Size([2, 809, 64]), c torch.Size([2, 809, 64])
[MODEL] Step 0 - decoder out.shape: torch.Size([809, 1, 64])
[MODEL] Step 0 - step_output.shape: torch.Size([809, 2])
[MODEL] Step 0 - pred_flat state updated
[MODEL] Step 1 - current_pred.shape: torch.Size([809, 1, 2])
[MODEL] Step 1 - pred_embedded.shape: torch.Size([809, 1, 64])
[MODEL] Step 1 - hx shapes: h torch.Size([2, 809, 64]), c torch.Size([2, 809, 64])
[MODEL] Step 1 - decoder out.shape: torch.Size([809, 1, 64]

KeyboardInterrupt: 