# EquiPredict: Robust Interaction Modeling for Multi-Agent Motion Forecasting
The ability to understand and predict the motion of multiple agents in dynamic environments is crucial for a range of applications, from autonomous vehicles navigating urban streets to sophisticated pedestrian traffic management systems. Despite considerable advancements, existing motion prediction methods often struggle to capture the intricate interdependencies and the variability these agents exhibit in real-world settings. 
A fundamental challenge is ensuring that predicted trajectories are equivariant under Euclidean transformations (rotating or translating the input rotates or translates the prediction in the same way) and that the inferred interactions among agents are invariant to those transformations.
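
As a minimal illustration of these two properties (not the model itself), the sketch below checks them numerically with NumPy. The constant-velocity predictor `predict_constant_velocity` and the helper `pairwise_distances` are hypothetical stand-ins chosen only for this example: the former should be equivariant under a rotation plus translation, the latter invariant.

```python
import numpy as np

def predict_constant_velocity(past, horizon):
    """Toy predictor: extrapolate the last observed velocity.
    past: (agents, T, 2) -> returns (agents, horizon, 2)."""
    velocity = past[:, -1] - past[:, -2]                  # (agents, 2)
    steps = np.arange(1, horizon + 1)[None, :, None]      # (1, horizon, 1)
    return past[:, -1:, :] + steps * velocity[:, None, :]

def pairwise_distances(positions):
    """Distances between all agents at one time step; positions: (agents, 2)."""
    diff = positions[:, None, :] - positions[None, :, :]
    return np.linalg.norm(diff, axis=-1)

rng = np.random.default_rng(0)
past = rng.normal(size=(3, 8, 2))      # 3 agents, 8 observed steps

# An arbitrary rotation R and translation t (assumed values for the demo)
theta = 0.7
R = np.array([[np.cos(theta), -np.sin(theta)],
              [np.sin(theta),  np.cos(theta)]])
t = np.array([2.0, -1.0])

def transform(x):
    """Apply the Euclidean transformation to every 2D point in x."""
    return x @ R.T + t

# Equivariance: predicting from transformed input equals transforming the prediction
equivariant = np.allclose(predict_constant_velocity(transform(past), 12),
                          transform(predict_constant_velocity(past, 12)))
# Invariance: inter-agent distances do not change under the transformation
invariant = np.allclose(pairwise_distances(transform(past)[:, -1]),
                        pairwise_distances(past[:, -1]))
print(equivariant, invariant)
```

A learned model would replace the toy predictor, but the same two checks could be used as sanity tests.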

Building on the principles established by EqMotion, we propose a refined model, "EquiPredict: Robust Interaction Modeling for Multi-Agent Motion Forecasting," which specifically enhances the prediction of pedestrian trajectories. This re-implementation focuses on achieving high fidelity in motion forecasting by integrating an equivariant geometric feature learning module with an invariant interaction reasoning module. Our model aims to provide a robust and reliable framework specifically tuned to understand and predict pedestrian dynamics, addressing the nuanced movements and interactions typical in crowded urban environments.

## All necessary Imports

In [None]:
! pip install torch-geometric

In [None]:
import argparse
import torch
import os
from torch import nn, optim
import json
import time
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
import math
import random
import sys
import torch.nn.functional as F
import cv2
import glob
import copy
import warnings
import torch.nn.init as init
from sklearn.cluster import KMeans
from torch_geometric.nn import GCNConv, GATConv
import matplotlib.animation as animation
from IPython.display import HTML
import pickle
warnings.filterwarnings("ignore", category=FutureWarning, module="sklearn")
sys.path.append("..")

## Data Handling

In [None]:
# Define the path for storing saved model files.
saved_models = '/kaggle/working/saved_models'

# Create the directory for saved models if it does not exist yet, so that
# trained models can be stored and reloaded later without retraining.
os.makedirs(saved_models, exist_ok=True)

## PreProcessing 

The DataPreprocessor class processes the trajectory data of the agents in one sequence, extracting the relevant movement information and preparing it for analysis. It handles loading, filtering, and processing of the data to generate motion vectors and masks for both past and future frames.  
It is composed of the following methods : 
- *Initialization* : It constructs the full path to the data file and loads the ground truth data from it. Loading also sets the initial frame and the total number of frames
- *Load Ground Truth* : It reads the ground truth data from the specified file, expecting tab-delimited data. It converts the data into a NumPy array, extracts the frame numbers, and calculates the range of frames
- *Get Id and Get Valid Id* : get_id returns the list of agent IDs found in a given data array. get_valid_id uses these IDs and checks them against the past and future data frames to ensure they appear in the minimum required number of frames. Only IDs that meet both conditions are considered valid
- *Filter Data by Frame* : It generates a list of data arrays for the frames around a given frame. Depending on the past flag, it collects either past or future frames. Frame indices are calculated from the frame_skip and num_frames parameters
- *Compute Motion* : It generates motion vectors and corresponding masks for each valid ID. It processes the frames (either past or future), divides the trajectory coordinates by traj_scale, and carries the last known position across gaps
- *Call* : It processes a given frame number to compute motion data for both past and future frames. It filters the data to get the relevant frames, identifies valid IDs that appear in both past and future data, and calculates the motion vectors and masks for these IDs
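
The mask and carry-forward behaviour described under Compute Motion can be illustrated for a single agent. This is a simplified sketch rather than the class's own method: `fill_future_track` is a hypothetical helper that takes one optional (x, z) observation per future step and uses plain NumPy instead of torch tensors.

```python
import numpy as np

def fill_future_track(observations, traj_scale=1.0):
    """observations: list with one (x, z) tuple or None per future step.
    Returns (track, mask): track has shape (steps, 2); mask is 1.0 where the
    agent was actually observed and 0.0 where the value was filled in."""
    steps = len(observations)
    track = np.zeros((steps, 2), dtype=np.float32)
    mask = np.zeros(steps, dtype=np.float32)
    for j, obs in enumerate(observations):
        if obs is not None:
            # Observed step: store the scaled coordinates and mark it in the mask
            track[j] = np.asarray(obs, dtype=np.float32) / traj_scale
            mask[j] = 1.0
        elif j > 0:
            # Missing step: carry the last known position forward
            track[j] = track[j - 1]
        # A gap at j == 0 stays at zeros with mask 0
    return track, mask

track, mask = fill_future_track([(1.0, 2.0), None, (3.0, 4.0), None])
print(track.tolist())
print(mask.tolist())
```

For past frames the class stores steps in reverse order, so the same idea copies from the more recent step instead of the earlier one.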

In [None]:
class DataPreprocessor(object):
    def __init__(self, data_root, config, seq_name, split, log=None):
        '''Initialization of the DataPreprocessor class with essential configurations'''
        self.config = config #configuration settings
        self.data_root = data_root #root directory for dataset files
        self.split = split #indicates the sub-folder to choose (test, train, val)
        self.seq_name = seq_name #name of the file to process
        self.log = log #logger to debug 
        self.label_path = os.path.join(data_root, config['dataset'], self.split, seq_name) #construct the full path to the txt file to consider
        self.gt = self.load_ground_truth() #load ground truth from the specified label path
        self.xind, self.zind = 2, 3  #Index Positions for x and z coordinates in the dataset

    def load_ground_truth(self):
        '''This function reads the ground truth data from the file located at "self.label_path".
           It expects data to be tab-delimited and attempts to load it into a numpy array.'''
        #load ground truth data from a text file with tab as delimiter
        self.gt = np.genfromtxt(self.label_path, delimiter='\t', dtype=str)
        
        if self.gt.ndim == 1:
            #if data is not loaded as a 2D array, log a warning
            print(f"Warning: The data in {self.label_path} is not loaded as a 2D array.")
            print(f"Data: {self.gt}")
            
        self.gt = self.gt.astype('float32') #convert data type for numerical operations
        frames = self.gt[:, 0].astype(np.float32).astype(np.int_)  #extract and convert frame numbers
        fr_start, fr_end = frames.min(), frames.max()  #minimum and maximum among all frames
        self.init_frame = fr_start  #set the initial frame 
        self.num_fr = (fr_end + 1 - fr_start)    #get total number of frames
        return self.gt
        

    def get_id(self, data):
        """This function extracts and returns a list of unique IDs from the given data.
           It takes a 2D numpy array `data` as input, extracts the second column 
           (assumed to be IDs), and returns a copy of this column as a list."""
        return data[:, 1].copy().tolist()

    def filter_data_by_frame(self, frame, frame_skip, num_frames, past=False):
        '''Filters the ground truth data to the frames of a past or future window.
           This function generates a list of data arrays for frames around a given frame. 
           Depending on whether the past flag is set, it will collect frames before or after 
           the given frame. The frame indices are calculated based on the frame_skip and 
           num_frames parameters.'''
        data_list = [] #initialize empty list to hold the filtered data for each frame
        for i in range(num_frames):
            #get frame index based on whether we are looking at past or future frames
            frame_idx = frame - i * frame_skip if past else frame + (i+1) * frame_skip
            
            # Frames before the first annotated frame carry no data
            if frame_idx < self.init_frame:
                data = []
            else:
                #filter the ground truth data for the current frame index
                data = self.gt[self.gt[:, 0] == frame_idx]
            
            data_list.append(data) #add filtered data for the current frame to the data list
        return data_list #return list of filtered data arrays

    def get_valid_id(self, pre_data, fut_data):
        '''This function checks the IDs present in the past data frames (`pre_data`) and the future 
           data frames (`fut_data`). It verifies that each ID appears in a minimum number of past 
           frames and future frames as specified in the configuration. Only IDs that meet both 
           conditions are considered valid.'''
        cur_id = self.get_id(pre_data[0]) #extract current IDs from the first past data frame
        valid_id = [] #initialize an empty list to hold valid IDs
        for idx in cur_id:
            #check if the ID exists in the required number of past frames
            exist_pre = all(idx in data[:, 1] for data in pre_data[:self.config['min_past_frames']] if len(data) > 0)
            #check if the ID exists in the required number of future frames
            exist_fut = all(idx in data[:, 1] for data in fut_data[:self.config['min_future_frames']] if len(data) > 0)
            
            #if the ID exists in both past and future frames, add it to the valid_id list
            if exist_pre and exist_fut:
                valid_id.append(idx)
        return valid_id

    def compute_motion(self, data_tuple, valid_id, past=True):
        '''This function generates motion vectors and corresponding masks for each valid ID 
        in the data. It processes a series of frames, either past or future, depending on 
        the 'past' flag, and computes the motion for each ID by scaling the trajectory 
        coordinates. It handles missing data by carrying forward the last known data.'''
        frames = self.config['past_frames'] if past else self.config['future_frames'] #determine the # of frames to process based on "past" flag
        traj_scale = self.config['traj_scale'] #Trajectory scale factor for normalization
        motion = [] #Initialize a list to hold the motion vectors
        mask = [] #Initialize a list to hold the masks
        
        for identity in valid_id:
            mask_i = torch.zeros(frames)  #Initialize a tensor for the mask of the current ID
            box_3d = torch.zeros([frames, 2])  #Initialize a tensor for the motion vectors of the current ID
            for j in range(frames):
                data = data_tuple[j]
                if len(data) > 0 and identity in data[:, 1]:
                    #extract and scale the coordinates for the current ID
                    found_data = data[data[:, 1] == identity].squeeze()[[self.xind, self.zind]] / traj_scale
                    if past:
                        box_3d[frames - 1 - j, :] = torch.from_numpy(found_data).float()
                        mask_i[frames - 1 - j] = 1.0
                    else:
                        box_3d[j, :] = torch.from_numpy(found_data).float()
                        mask_i[j] = 1.0
                elif j > 0:
                    #handle missing data by carrying forward the last known data
                    if past:
                        box_3d[frames - 1 - j, :] = box_3d[frames - j, :]
                    else:
                        box_3d[j, :] = box_3d[j - 1, :]
                else:
                    # Skip the case where the current ID is missing in the first frame
                    if past:
                        mask_i[frames - 1 - j] = 0.0
                        box_3d[frames - 1 - j, :] = torch.zeros(2)
                    else:
                        mask_i[j] = 0.0
                        box_3d[j,:] = torch.zeros(2)
            motion.append(box_3d)
            mask.append(mask_i)
        return motion, mask

    def __call__(self, frame):
        '''This method processes a given frame number to compute the motion data for both 
           past and future frames. It filters the data to get relevant frames, identifies 
           valid IDs that appear in both past and future data, and calculates the motion 
           vectors and masks for these IDs.'''
        #check if the frame is in the valid range
        if not (0 <= frame - self.init_frame < self.num_fr):
            raise ValueError(f'frame is {frame}, out of range')

        pre_data = self.filter_data_by_frame(frame, self.config['frame_skip'], self.config['past_frames'], past=True) #filter data to get past frames
        fut_data = self.filter_data_by_frame(frame, self.config['frame_skip'], self.config['future_frames'], past=False) #filter data to get future frames

        #identify valid IDs that appear in both past and future frames
        valid_id = self.get_valid_id(pre_data, fut_data)
        if len(pre_data[0]) == 0 or len(fut_data[0]) == 0 or not valid_id:
            #print('None')
            return None #if there's no valid ID then return None
       
        pre_motion_3D, pre_motion_mask = self.compute_motion(pre_data, valid_id, past=True) #compute motion vectors and masks for past frames
        fut_motion_3D, fut_motion_mask = self.compute_motion(fut_data, valid_id, past=False) #compute motion vectors and masks for future frames
        
        #prepare data dictionary with all relevant information
        data = {
            'pre_motion_3D': pre_motion_3D,
            'fut_motion_3D': fut_motion_3D,
            'fut_motion_mask': fut_motion_mask,
            'pre_motion_mask': pre_motion_mask,
            'pre_data': pre_data,
            'fut_data': fut_data,
            'valid_id': valid_id,
            'traj_scale': self.config['traj_scale'],
            'seq': self.seq_name,
            'frame': frame
        }
        
        #return structured data
        return data

In [None]:
#setting up all the necessary parameters to configure the data processing pipeline
config = {
    'dataset': 'hotel',          # Name of the dataset
    'past_frames': 8,          # Number of past frames to consider
    'future_frames': 12,       # Number of future frames to consider
    'frame_skip': 10,          # Number of frames to skip between each step
    'min_past_frames': 8,      # Minimum number of past frames required for a valid ID
    'min_future_frames': 12,   # Minimum number of future frames required for a valid ID
    'traj_scale': 1,           # Scaling factor for trajectory coordinates
    'total_num': 3,            # Number of agents per sample (scenes are padded or cropped to this count)
}
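
To make the interplay of `frame_skip`, `past_frames`, and `future_frames` concrete, the following sketch lists the raw frame numbers that a single sample touches. `sample_frames` is a hypothetical helper, and the anchor frame 1000 is an arbitrary value picked for illustration; the index arithmetic mirrors `filter_data_by_frame`.

```python
def sample_frames(frame, cfg):
    """Return (past, future) raw frame numbers for a sample anchored at `frame`.
    Past frames are listed oldest first and include the anchor frame itself."""
    skip = cfg['frame_skip']
    past = [frame - i * skip for i in range(cfg['past_frames'])][::-1]
    future = [frame + (i + 1) * skip for i in range(cfg['future_frames'])]
    return past, future

cfg = {'frame_skip': 10, 'past_frames': 8, 'future_frames': 12}
past, future = sample_frames(1000, cfg)
print(past)    # 8 frames ending at the anchor frame
print(future)  # 12 frames following the anchor frame
```

With these settings one sample spans 70 raw frames of history before the anchor and 120 raw frames of future after it.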

## Agent PreProcessing 

The AgentPreProcessing class is designed to handle and preprocess trajectory data for agents.
We provide two different implementations of the AgentPreProcessing class.  
The main difference between the two is represented by the handling of "invalid entries".  
As already stated in the DataPreprocessor call method, a given frame number is processed to compute motion data for both past and future frames. In doing so, it filters the data to get relevant frames, identifies valid IDs that appear in both past and future data, and calculates the motion vectors and masks for these IDs.  
As a consequence, every frame without valid IDs is discarded and never reaches the model. For this reason, AgentPreProcessing_with_Invalids handles this scenario explicitly (see the comments in the getitem and reformat_data methods of the next section).

Except for getitem, both classes share the following methods : 
- *Initialization & Length* : The init sets up paths, frame settings, and various parameters. It reads the sequences from the specified directory, processes each one with the DataPreprocessor class, and sums the number of samples across all sequences. The per-sequence sample counts are then rounded up to a multiple of 10 to account for the frame skip interval of 10 frames. The list sample_indices holds the indices of all samples and is used to access and manage individual samples within the dataset. The length of the dataset is the total number of samples divided by the frame skip interval (10) 
- *Locating Seq and Frame* : It determines the specific sequence and frame position for a given dataset index by calculating cumulative positions and adjusting for skipped frames 
- *Processing Data for Varying Agent Counts* : This covers the process_data_for_few_agents and process_data_for_many_agents methods, which handle cases where the number of valid agents differs from the expected total. When there are fewer agents than expected, zero-padded arrays are created for past and future movements to match the expected agent count. Conversely, when there are at least as many agents as expected, distances at the last observed frame are used to select the nearest agents for each focal agent, and any slots left empty by the distance threshold are zero-padded, so that the data meets the expected count 
- *Collecting Data* : It collects and stores past and future movement data along with valid agent counts. It concatenates all collected data into single arrays for unified dataset analysis and model training 
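
The two agent-count cases can be sketched with NumPy as follows. This is a simplified illustration under stated assumptions, not the classes' exact logic: `pad_agents` and `nearest_agents` are hypothetical helper names, and the neighbor selection here ignores the distance threshold that the classes additionally apply.

```python
import numpy as np

def pad_agents(tracks, num_agents):
    """Zero-pad (agents, T, 2) tracks to (num_agents, T, 2)."""
    padded = np.zeros((num_agents,) + tracks.shape[1:], dtype=tracks.dtype)
    padded[:tracks.shape[0]] = tracks
    return padded

def nearest_agents(tracks, focal, num_agents):
    """Indices of the num_agents agents closest to agent `focal`, measured
    at the last observed step (the focal agent itself is at distance 0)."""
    last = tracks[:, -1]                                   # (agents, 2)
    distances = np.linalg.norm(last - last[focal], axis=-1)
    return np.argsort(distances, kind="stable")[:num_agents]

tracks = np.zeros((4, 8, 2), dtype=np.float32)
tracks[:, -1] = [[0, 0], [5, 0], [1, 0], [2, 0]]   # final positions of 4 agents

padded = pad_agents(tracks[:2], 3)                 # 2 agents padded to 3
neighbors = nearest_agents(tracks, focal=0, num_agents=3)
print(padded.shape)
print(neighbors)
```

Repeating the neighbor selection once per focal agent yields one fixed-size scene per agent, which is what allows scenes to be stacked into batches.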


### AgentPreProcessing without Invalids 

Specifically for the AgentPreProcessing_without_Invalids : 
- *Reformat Data* : It reformats the input data into structured arrays for past and future movements, handling scenes with either fewer or more valid agents than the expected total. If data is None, it is skipped and not used for further analysis or training 
- *GetItem* : It is responsible for retrieving and preparing a data sample from the dataset, ensuring that only valid data is returned. It attempts to fetch valid data using a retry mechanism : if the data is invalid, it retries up to a set limit. Once valid data is found, it is reformatted, converted to numpy arrays, and returned. 
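
The retry behaviour can be sketched independently of the dataset classes. Below, `fetch` is a hypothetical callable standing in for a sequence processor that returns `None` for invalid frames; `next_valid` is likewise an illustrative name and not part of the notebook.

```python
def next_valid(fetch, start, max_retries):
    """Return (position, data) for the first position >= start whose data
    is not None, giving up after max_retries attempts."""
    for attempt in range(max_retries):
        data = fetch(start + attempt)
        if data is not None:
            return start + attempt, data
    raise ValueError(f"No valid sample after {max_retries} attempts")

# Toy source: positions 0, 1, and 3 are invalid, position 2 holds data.
samples = {0: None, 1: None, 2: "sample-2", 3: None}
position, data = next_valid(samples.get, start=0, max_retries=4)
print(position, data)
```

Bounding the number of attempts keeps a long run of invalid frames from turning into an endless loop.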

In [None]:
# trajectory dataset skipping None data
class AgentPreProcessing_without_Invalids(Dataset):
    def __init__(self, root_path, settings, subset, history_frames, future_frames):
        """
        Initializes the dataset by setting up paths, loading sequences, and processing metadata. 
        For each sequence in the dataset, the following steps are performed:
        1. A data processor instance is created; 2. The number of samples is calculated from the frame interval and sequence length, then accumulated; 
        3. The sample count and the data processor are stored for the sequence; 4. Finally, sample_indices provides the complete list of indices 
        for all samples, which is used to access and manage individual samples within the dataset """
  
        self.root_path = root_path
        self.settings = settings
        self.num_agents = self.settings['total_num']
        self.subset = subset
        self.directory = os.path.join(self.root_path, self.settings["dataset"], self.subset)
        self.sequences = os.listdir(self.directory)
        self.history_frames = history_frames
        self.future_frames = future_frames
        self.minimum_history = history_frames
        self.minimum_future = future_frames
        self.skip_frames = 10           # Skip frame is set to 10, since frames go 10 by 10 in the provided dataset 
        self.start_frame = 0
        self.total_samples = 0          # Total number of samples, obtained by summing each sequence total number of samples 
        self.samples_per_sequence = []  # List containing total number of samples per sequence 
        self.processed_sequences = []   # List containing processed_sequences 
        self.threshold = 5              # Threshold for distance calculations
        self.previous_data = []         # List to store past data    
        self.future_data = []           # List to store future data    
        self.valid_counts = []          # List to store the count of valid agents     
        self.all_prior_data = []        # List to store all prior data 
        self.all_future_data = []       # List to store all future data 
        self.all_valid_counts = []      # List to store all valid counts 
        self.valid_data_count = 0       # Count of valid data samples 
        self.invalid_data_count = 0     # Count of invalid data samples 
        
        processor_class = DataPreprocessor    # Reference to the data preprocessor class 
        for sequence_name in self.sequences: 
            sequence_processor = processor_class(root_path, settings, sequence_name, subset)   # Create a data processor for each sequence 
            sequence_sample_count = sequence_processor.num_fr + 1 - (self.minimum_history - 1) * self.skip_frames - self.minimum_future * self.skip_frames + 1   # Calculate the number of samples in the sequence 
            self.total_samples += sequence_sample_count   # Update total samples count 
            self.samples_per_sequence.append(sequence_sample_count)   # Store the sample count for this sequence 
            self.processed_sequences.append(sequence_processor)       # Store the processed sequence 
        
        self.sample_indices = list(range(self.total_samples))     # List of sample indices 
        self.current_index = 0    # Initialize current index 
        self.samples_per_sequence = [(x + 9) // 10 * 10 for x in self.samples_per_sequence]     # Round each sequence's sample count up to the nearest multiple of 10 

    def __len__(self):
        """ the length of the dataset is given by the total number of samples divided by 10, since frame skip is 10 """
        return self.total_samples // 10

    def locate_sequence_and_frame(self, index):
        """ locate_sequence_and_frame determines the sequence and the specific frame position within that sequence corresponding to the given dataset index by 
            calculating the cumulative position and adjusting for skipped frames """
        current_position = copy.copy(index) * self.skip_frames
        for seq_id, count in enumerate(self.samples_per_sequence):
            if current_position < count:
                frame_position = current_position + (self.minimum_history - 1) * self.skip_frames + self.processed_sequences[seq_id].init_frame
                return seq_id, frame_position
            current_position -= count
        raise ValueError('Index {} is out of range'.format(index))

    def reformat_data(self, data):
        """ reformat_data reformats the input data into structured arrays for past and future movements. It preprocesses the data, handling scenes 
            with either fewer or more valid agents than the expected total. If data is None, it is skipped and empty lists are returned, so it is 
            not used for further analysis or training. The reformatted data is returned """
        # Initialize the outputs first so that they are defined even when data is None
        prior_data, upcoming_data, valids = [], [], []
        if data is not None:
            prior_movement = np.array(torch.stack(data['pre_motion_3D'], dim=0))
            future_movement = np.array(torch.stack(data['fut_motion_3D'], dim=0))
            agent_count = prior_movement.shape[0]

            if agent_count < self.num_agents:
                self.process_data_for_few_agents(agent_count, prior_movement, future_movement, prior_data, upcoming_data, valids)
            else:
                self.process_data_for_many_agents(agent_count, prior_movement, future_movement, prior_data, upcoming_data, valids)

        return prior_data, upcoming_data, valids

    def process_data_for_few_agents(self, agent_count, prior_movement, future_movement, prior_data, upcoming_data, valids):
        """ process_data_for_few_agents handles cases where the current number of valid agents present in the data is less than the expected total """
        for i in range(agent_count):
            temp = np.zeros((self.num_agents, prior_movement.shape[1], 2))
            temp[:agent_count] = prior_movement
            prior_data.append(temp[None])

            temp = np.zeros((self.num_agents, future_movement.shape[1], 2))
            temp[:agent_count] = future_movement
            upcoming_data.append(temp[None])
            valids.append(agent_count)
        self.previous_data = prior_data
        self.future_data = upcoming_data
        self.valid_counts = valids
        

    def process_data_for_many_agents(self, agent_count, prior_movement, future_movement, prior_data, upcoming_data, valids):
        """ process_data_for_many_agents handles cases where the current number of valid agents present in the data is greater than or equal to the expected total """
        for i in range(agent_count):
            distances = np.linalg.norm(prior_movement[:, -1] - prior_movement[i:i+1, -1], axis=-1)
            close_indices = np.sum((distances < self.threshold).astype(int))

            if close_indices < self.num_agents:
                temp = np.zeros((self.num_agents, prior_movement.shape[1], 2))
                neighbors_idx = np.argsort(distances)
                neighbors_idx = neighbors_idx[:close_indices]
                temp[:close_indices] = prior_movement[neighbors_idx]
                prior_data.append(temp[None])

                temp = np.zeros((self.num_agents, future_movement.shape[1], 2))
                neighbors_idx = neighbors_idx[:close_indices]
                temp[:close_indices] = future_movement[neighbors_idx]
                upcoming_data.append(temp[None])
                valids.append(close_indices)
            else:
                neighbors_idx = np.argsort(distances)
                assert neighbors_idx[0] == i
                neighbors_idx = neighbors_idx[:self.num_agents]
                temp = prior_movement[neighbors_idx]
                prior_data.append(temp[None])
                temp = future_movement[neighbors_idx]
                upcoming_data.append(temp[None])
                valids.append(self.num_agents)
        self.previous_data = prior_data
        self.future_data = upcoming_data
        self.valid_counts = valids
        
    
    def collect_all_data(self, pre_data, fut_data, num_valid):
        """ collect_all_data collects and stores all past and future movement data, along with the count of valid agents, by appending the provided data to the 
            corresponding class attributes """
        self.all_prior_data.append(pre_data)
        self.all_future_data.append(fut_data)
        self.all_valid_counts.extend(num_valid)
    
    def get_concatenated_data(self):
        """ get_concatenated_data concatenates all collected past and future movement data, along with the counts of valid agents, into single arrays. It returns these 
            concatenated arrays """
        if self.all_prior_data:
            all_past_data = np.concatenate(self.all_prior_data, axis=0)
        else:
            all_past_data = np.empty((0, self.num_agents, self.history_frames, 2))

        if self.all_future_data:
            all_future_data = np.concatenate(self.all_future_data, axis=0)
        else:
            all_future_data = np.empty((0, self.num_agents, self.future_frames, 2))

        all_valid_num = np.array(self.all_valid_counts)

        return all_past_data, all_future_data, all_valid_num


    def __getitem__(self, index):
        """ getitem retrieves a data sample from the dataset. It attempts to fetch valid data using a retry mechanism; 
        if the data is invalid, it retries up to a set limit. Once valid data is found, it is reformatted, converted to numpy arrays, and returned """
        retry_count = 0
        max_retries = self.total_samples // 10
        while retry_count < max_retries:
            sample_idx = self.sample_indices[self.current_index]
            sequence_id, frame = self.locate_sequence_and_frame(sample_idx)
            sequence = self.processed_sequences[sequence_id]
            self.current_index += 1

            data = sequence(frame)
            if data is not None:
                self.valid_data_count += 1
                break
            else:
                # Invalid sample: count it and retry with the next index
                self.invalid_data_count += 1
                retry_count += 1

        if retry_count >= max_retries:
            raise ValueError("Too many invalid data samples after {} retries".format(max_retries))

        prepared_data = self.reformat_data(data)
        pre_data, fut_data, num_valid = prepared_data
        pre_data = np.array(pre_data, dtype=np.float32)
        fut_data = np.array(fut_data, dtype=np.float32)
        num_valid = np.array(num_valid)
        
        pre_data = pre_data.reshape(-1, self.num_agents, self.history_frames, 2)
        fut_data = fut_data.reshape(-1, self.num_agents, self.future_frames, 2)
        num_valid = num_valid.reshape(-1)

        self.collect_all_data(pre_data, fut_data, num_valid)

        return pre_data, fut_data, num_valid

### AgentPreProcessing with Invalids 

Specifically for the AgentPreProcessing_with_Invalids :  
- *Reformatting Data* : It reformats input data into structured arrays for past and future movements. If the data is invalid, it falls back on the last valid data, ensuring consistent input for further processing 
- *GetItem* : It just retrieves and processes data for a given index by determining the sample index, locating the corresponding sequence and frame, and reformatting the data into prior and future motion data and valid agent counts
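
The fallback strategy amounts to caching the most recent valid sample. A minimal sketch, assuming a hypothetical `LastValidCache` wrapper that is not a class from the notebook:

```python
class LastValidCache:
    """Return the newest valid item, reusing the previous one when the
    incoming item is None."""

    def __init__(self):
        self.last = None          # most recent valid item seen so far
        self.valid_count = 0
        self.invalid_count = 0

    def __call__(self, item):
        if item is not None:
            self.last = item
            self.valid_count += 1
        else:
            # Invalid input: keep serving the cached item
            self.invalid_count += 1
        return self.last

cache = LastValidCache()
outputs = [cache(x) for x in ["a", None, None, "b", None]]
print(outputs)
print(cache.valid_count, cache.invalid_count)
```

One consequence of this design is that the previous sample is repeated whenever a frame is invalid, which keeps every index usable at the cost of duplicated data. If the very first item is invalid there is nothing to fall back on, a case this sketch signals by returning `None`.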

In [None]:
# trajectory dataset handling None data
class AgentPreProcessing_with_Invalids(Dataset):
    def __init__(self, root_path, settings, subset, history_frames, future_frames):
        
        self.root_path = root_path
        self.settings = settings
        self.num_agents = self.settings['total_num']
        self.subset = subset
        self.directory = os.path.join(self.root_path, self.settings["dataset"], self.subset)
        self.sequences = os.listdir(self.directory)
        self.history_frames = history_frames
        self.future_frames = future_frames
        self.minimum_history = history_frames
        self.minimum_future = future_frames
        self.skip_frames = 10            # Skip frame is set to 10, since frames go 10 by 10 in the provided dataset 
        self.start_frame = 0
        self.total_samples = 0           # Total number of samples, obtained by summing each sequence total number of samples 
        self.samples_per_sequence = []   # List containing total number of samples per sequence 
        self.processed_sequences = []    # List containing processed_sequences 
        self.threshold = 5               # Threshold for distance calculations
        self.previous_data = []          # List to store past data    
        self.future_data = []            # List to store future data
        self.valid_counts = []           # List to store the count of valid agents 
        self.all_prior_data = []         # List to store all prior data 
        self.all_future_data = []        # List to store all future data 
        self.all_valid_counts = []       # List to store all valid counts 
        self.valid_data_count = 0        # Count of valid data samples 
        self.invalid_data_count = 0      # Count of invalid data samples 
        
        processor_class = DataPreprocessor    # Reference to the data preprocessor class 
        for sequence_name in self.sequences:
            sequence_processor = processor_class(root_path, settings, sequence_name, subset)     # Create a data processor for each sequence 
            # calculate the number of samples in the sequence 
            sequence_sample_count = sequence_processor.num_fr + 1 - (self.minimum_history - 1) * self.skip_frames - self.minimum_future * self.skip_frames + 1
            self.total_samples += sequence_sample_count  # Update total samples count
            self.samples_per_sequence.append(sequence_sample_count)   # Store the sample count for this sequence 
            self.processed_sequences.append(sequence_processor)   # Store the processed sequence 
        
        self.sample_indices = list(range(self.total_samples))   # List of sample indices -> [0, 1, ..., self.total_samples - 1] 
        self.current_index = 0   # Initializing current index 
        self.samples_per_sequence = [(x + 9) // 10 * 10 for x in self.samples_per_sequence]    # Round each sequence's sample count up to the nearest multiple of 10 

    def __len__(self):
        """ the length of the dataset is given by the total number of samples divided by 10, since frame skip is 10 """
        return self.total_samples // 10

    def locate_sequence_and_frame(self, index):
        """ locate_sequence_and_frame determines the sequence and the specific frame position within that sequence corresponding to the given dataset index by 
            calculating the cumulative position and adjusting for skipped frames """
        current_position = copy.copy(index) * self.skip_frames
        for seq_id, count in enumerate(self.samples_per_sequence):
            if current_position < count:
                frame_position = current_position + (self.minimum_history - 1) * self.skip_frames + self.processed_sequences[seq_id].init_frame
                return seq_id, frame_position
            current_position -= count
        raise ValueError('Index {} is out of range'.format(index))

    def reformat_data(self, data):
        """ reformat_data restructures the input data into arrays of past and future movements. It handles both cases: 
            fewer valid agents than the expected total, or at least as many. If the data is invalid (None), it falls back on the last valid data, ensuring 
            consistent input for further processing. The reformatted data is returned """
        if data is not None:
            self.valid_data_count += 1
            prior_data, upcoming_data, valids = [], [], []
            prior_movement = np.array(torch.stack(data['pre_motion_3D'], dim=0))
            future_movement = np.array(torch.stack(data['fut_motion_3D'], dim=0))
            agent_count = prior_movement.shape[0]

            if agent_count < self.num_agents:
                self.process_data_for_few_agents(agent_count, prior_movement, future_movement, prior_data, upcoming_data, valids)
            else:
                self.process_data_for_many_agents(agent_count, prior_movement, future_movement, prior_data, upcoming_data, valids)
        else:
            # Fallback to previous valid data
            prior_data, upcoming_data, valids = self.previous_data, self.future_data, self.valid_counts
            self.invalid_data_count += 1

        return prior_data, upcoming_data, valids

    def process_data_for_few_agents(self, agent_count, prior_movement, future_movement, prior_data, upcoming_data, valids):
        """ process_data_for_few_agents handles cases where the current number of valid agents present in the data is less than the expected total """
        for i in range(agent_count):
            temp = np.zeros((self.num_agents, prior_movement.shape[1], 2))
            temp[:agent_count] = prior_movement
            prior_data.append(temp[None])

            temp = np.zeros((self.num_agents, future_movement.shape[1], 2))
            temp[:agent_count] = future_movement
            upcoming_data.append(temp[None])
            valids.append(agent_count)
        self.previous_data = prior_data
        self.future_data = upcoming_data
        self.valid_counts = valids
        

    def process_data_for_many_agents(self, agent_count, prior_movement, future_movement, prior_data, upcoming_data, valids):
        """ process_data_for_many_agents handles cases where the current number of valid agents present in the data is greater than or equal to the expected total """
        for i in range(agent_count):
            distances = np.linalg.norm(prior_movement[:, -1] - prior_movement[i:i+1, -1], axis=-1)
            close_indices = np.sum((distances < self.threshold).astype(int))

            if close_indices < self.num_agents:
                temp = np.zeros((self.num_agents, prior_movement.shape[1], 2))
                neighbors_idx = np.argsort(distances)
                neighbors_idx = neighbors_idx[:close_indices]
                temp[:close_indices] = prior_movement[neighbors_idx]
                prior_data.append(temp[None])

                temp = np.zeros((self.num_agents, future_movement.shape[1], 2))
                neighbors_idx = neighbors_idx[:close_indices]
                temp[:close_indices] = future_movement[neighbors_idx]
                upcoming_data.append(temp[None])
                valids.append(close_indices)
            else:
                neighbors_idx = np.argsort(distances)
                assert neighbors_idx[0] == i
                neighbors_idx = neighbors_idx[:self.num_agents]
                temp = prior_movement[neighbors_idx]
                prior_data.append(temp[None])
                temp = future_movement[neighbors_idx]
                upcoming_data.append(temp[None])
                valids.append(self.num_agents)
        self.previous_data = prior_data
        self.future_data = upcoming_data
        self.valid_counts = valids

    
    def collect_all_data(self, pre_data, fut_data, num_valid):
        """ collect_all_data collects and stores all past and future movement data, along with the count of valid agents, by appending the provided data to the 
            corresponding class attributes """
        self.all_prior_data.append(pre_data)
        self.all_future_data.append(fut_data)
        self.all_valid_counts.extend(num_valid)
    
    def get_concatenated_data(self):
        """ get_concatenated_data concatenates all collected past and future movement data, along with the count of valid agents, into single arrays. It returns these 
            concatenated arrays """
        if self.all_prior_data:
            all_past_data = np.concatenate(self.all_prior_data, axis=0)
        else:
            all_past_data = np.empty((0, self.num_agents, self.history_frames, 2))

        if self.all_future_data:
            all_future_data = np.concatenate(self.all_future_data, axis=0)
        else:
            all_future_data = np.empty((0, self.num_agents, self.future_frames, 2))

        all_valid_num = np.array(self.all_valid_counts)

        return all_past_data, all_future_data, all_valid_num


    def __getitem__(self, index):
        """ __getitem__ retrieves and processes the next sample. The `index` argument is not used directly: the sample index is taken from the internal counter `current_index`. The method then locates the corresponding sequence 
            and frame, retrieves the data for that frame from the identified sequence, and reformats it into prior and future 
            motion data and valid agent counts """
        
        sample_idx = self.sample_indices[self.current_index]
        sequence_id, frame = self.locate_sequence_and_frame(sample_idx)
        sequence = self.processed_sequences[sequence_id]
        self.current_index += 1

        data = sequence(frame)

        prepared_data = self.reformat_data(data)
        pre_data, fut_data, num_valid = prepared_data
        pre_data = np.array(pre_data, dtype=np.float32)
        fut_data = np.array(fut_data, dtype=np.float32)
        num_valid = np.array(num_valid)
        
        pre_data = pre_data.reshape(-1, self.num_agents, self.history_frames, 2)
        fut_data = fut_data.reshape(-1, self.num_agents, self.future_frames, 2)
        num_valid = num_valid.reshape(-1)

        self.collect_all_data(pre_data, fut_data, num_valid)

        return pre_data, fut_data, num_valid

## Create final dataset

This section focuses on creating the final dataset and defining two classes, TrajectoryDataset_without_Invalids and TrajectoryDataset_with_Invalids, which are used to process and manage trajectory data. These classes are designed to work with the AgentPreProcessing pipeline and handle datasets with and without invalid entries, respectively.

The TrajectoryDataset_without_Invalids class processes data excluding any invalid entries, while the TrajectoryDataset_with_Invalids class includes invalid entries in its processing. Both classes convert the trajectory data into a format suitable for model training and evaluation, ensuring that the data is scaled appropriately and organized into past and future sequences.
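
As a rough illustration of the slicing both classes perform in `__getitem__`, the sketch below concatenates past and future arrays along the time axis, divides by a scale, and splits the result back into the two segments. The array shapes, the `traj_scale` value, and the helper name `split_past_future` are assumptions chosen for this example, not values taken from the notebook.

```python
import numpy as np

def split_past_future(all_past, all_future, index, traj_scale):
    """Hypothetical helper mirroring the scale-and-slice logic of __getitem__."""
    history_frames = all_past.shape[2]
    # Join along the time axis: (samples, agents, past + future, 2)
    all_past_future = np.concatenate([all_past, all_future], axis=2)
    seq = all_past_future[index] / traj_scale   # scale a single sample
    past_seq = seq[:, :history_frames]          # (agents, past, 2)
    future_seq = seq[:, history_frames:]        # (agents, future, 2)
    return past_seq, future_seq

# Toy data: 5 samples, 3 agents, 8 past and 12 future frames (assumed sizes)
rng = np.random.default_rng(0)
past = rng.normal(size=(5, 3, 8, 2))
future = rng.normal(size=(5, 3, 12, 2))
past_seq, future_seq = split_past_future(past, future, index=0, traj_scale=2.0)
print(past_seq.shape, future_seq.shape)
```

Because the split point is simply the number of history frames, the same stored tensor can serve any consistent past/future lengths.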

In [None]:
# Class for Trajectory Dataset without Invalids
class TrajectoryDataset_without_Invalids(Dataset):
    """ Defining TrajectoryDataset to use with AgentPreProcessing_without_Invalids """
    def __init__(self, dataset, valids, settings, history_frames, future_frames):
        self.dataset = dataset # Assign the dataset
        self.valids = valids # Number of valid data points
        
        # Index the dataset once per valid sample; each access appends its arrays to the dataset's internal buffers
        for i in range(self.valids):
            prior_data, upcoming_data, valid = dataset[i]
            
        # Concatenate past and future data from the dataset
        self.all_past_data, self.all_future_data, self.all_valid_num = dataset.get_concatenated_data()
        self.settings = settings # Configuration settings
        self.traj_scale = self.settings["traj_scale"] # Scaling factor for trajectory coordinates
        self.history_frames = history_frames # Number of past frames
        self.future_frames = future_frames # Number of future frames
        
        # Concatenate past and future data along the third axis and convert to tensor
        self.all_past_future = np.concatenate([self.all_past_data, self.all_future_data], axis=2)
        self.all_past_future = torch.Tensor(self.all_past_future)
        self.all_valid_num = torch.Tensor(self.all_valid_num)
        
    def __len__(self):
        return self.all_past_future.shape[0] # Return the total number of sequences

    def __getitem__(self, index):
        """ Returns past_seq, future_seq, number of agents """
        # Normalize sequence by trajectory scale
        seq = self.all_past_future[index] / self.traj_scale
        valid_num = self.all_valid_num[index] # Number of valid agents
        past_seq = seq[:, :self.history_frames] # Extract past sequence
        future_seq = seq[:, self.history_frames:] # Extract future sequence
        return past_seq, future_seq, valid_num # Return past sequence, future sequence, and valid number

In [None]:
# Class for Trajectory Dataset with Invalids
class TrajectoryDataset_with_Invalids(Dataset):
    """ Defining TrajectoryDataset to use with AgentPreProcessing_with_Invalids """
    def __init__(self, dataset, settings, history_frames, future_frames):
        self.dataset = dataset # Assign the dataset
        
        # Index the dataset once per sample; each access appends its arrays to the dataset's internal buffers
        for i in range(len(self.dataset)):
            prior_data, upcoming_data, valid = dataset[i]
            
        # Concatenate past and future data from the dataset
        self.all_past_data, self.all_future_data, self.all_valid_num = dataset.get_concatenated_data()
        self.settings = settings # Configuration settings
        self.traj_scale = self.settings["traj_scale"] # Scaling factor for trajectory coordinates
        self.history_frames = history_frames # Number of past frames
        self.future_frames = future_frames # Number of future frames
        
        # Concatenate past and future data along the third axis and convert to tensor
        self.all_past_future = np.concatenate([self.all_past_data, self.all_future_data], axis=2)
        self.all_past_future = torch.Tensor(self.all_past_future)
        self.all_valid_num = torch.Tensor(self.all_valid_num)
        
    def __len__(self):
        return self.all_past_future.shape[0] # Return the total number of sequences

    def __getitem__(self, index):
        """ Returns past_seq, future_seq, number of agents """
        # Normalize sequence by trajectory scale
        seq = self.all_past_future[index] / self.traj_scale
        valid_num = self.all_valid_num[index] # Number of valid agents
        past_seq = seq[:, :self.history_frames] # Extract past sequence
        future_seq = seq[:, self.history_frames:] # Extract future sequence
        return past_seq, future_seq, valid_num # Return past sequence, future sequence, and valid number

## Defining Options to Train and Test

In this section of the notebook, the script prepares the configurations required to train and test the "EquiPredict" model. It defines various parameters, including experiment details, model specifications, and training hyperparameters. These parameters are encapsulated within a dictionary and are further converted into an object for ease of access and manipulation in subsequent code.
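
The dictionary-to-object conversion described above can also be sketched with the standard library's `types.SimpleNamespace`, which gives the same attribute-style access as a hand-written wrapper class. The names `params` and `cfg` and the three keys shown are a trimmed-down illustration, not the full configuration.

```python
from types import SimpleNamespace

# Hypothetical, shortened parameter dictionary for illustration only
params = {"batch_size": 100, "past_length": 8, "future_length": 12}

# SimpleNamespace exposes the dictionary keys as attributes
cfg = SimpleNamespace(**params)

# Attributes can be added after construction, as the notebook does with `cuda`
cfg.total_length = cfg.past_length + cfg.future_length

print(cfg.batch_size, cfg.total_length)
```

Either approach works; a namespace object mainly saves typing (`cfg.lr` instead of `cfg['lr']`) and keeps code compatible with functions written for `argparse` results.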

In [None]:
# Configuration for the model training and evaluation.
# Initialize a dictionary to store all experiment parameters.
args = {
    'exp_name': 'exp_1',  # Name of the experiment.
    'batch_size': 100,  # Number of samples in each batch.
    'epochs': 100,  # Total number of training epochs.
    'past_length': 8,  # Number of past frames to consider for the prediction.
    'future_length': 12,  # Number of future frames to predict.
    'no_cuda': False,  # Flag to disable CUDA even if available.
    'seed': -1,  # Seed for random number generation. -1 means no specific seed.
    'log_interval': 1,  # Frequency of logging training status.
    'test_interval': 1,  # Frequency of testing the model.
    'outf': 'n_body_system/logs',  # Output directory for logs.
    'lr': 1e-6,  # Learning rate for the optimizer.
    'epoch_decay': 2,  # Number of epochs after which learning rate will decay.
    'lr_gamma': 0.8,  # Learning rate decay factor.
    'nf': 64,  # Number of features.
    'model': 'egnn_vel',  # Model type to be used.
    'attention': 0,  # Whether to use attention mechanism.
    'n_layers': 4,  # Number of layers in the neural network.
    'degree': 2,  # Degree parameter for some models.
    'channels': 64,  # Number of channels in models.
    'max_training_samples': 3000,  # Maximum number of training samples to consider.
    'dataset': 'nbody',  # Dataset to use.
    'sweep_training': 0,  # Whether to use parameter sweeping in training.
    'time_exp': 0,  # Flag for time experiment.
    'weight_decay': 1e-12,  # Weight decay to prevent overfitting.
    'div': 1,  # Division factor (its usage is not documented here).
    'norm_diff': False,  # Whether to normalize differences.
    'tanh': False,  # Whether to apply tanh activation.
    'subset': 'eth',  # Subset of data to be used.
    'model_save_dir': '/kaggle/working/saved_models',  # Directory to save trained models.
    'scale': 1,  # Scale factor for inputs/outputs.
    'apply_decay': False,  # Whether to apply decay to learning rate.
    'res_pred': False,  # Whether to use residual predictions.
    'supervise_all': False,  # Whether all layers are supervised.
    'model_name': 'eth_ckpt_best',  # Name to save the model checkpoint.
    'test_scale': 1,  # Scaling for testing phase.
    'test': False,  # Whether to run tests.
    'vis': False  # Whether to enable visualization.
}

In [None]:
# Create a class to mimic argparse's namespace functionality.
class ArgsNamespace:
    def __init__(self, adict):
        self.__dict__.update(adict)  # Update the object's dictionary with the passed dictionary.

# Instantiate the ArgsNamespace class with the args dictionary.
args = ArgsNamespace(args)
args.cuda = torch.cuda.is_available()  # Check if CUDA is available and update the args object.

## Initialize Dataset and Set-up Dataloader - without PreProcessing Invalid Data

In [None]:
#''' 
# Dataset Initialization
dataset_train = AgentPreProcessing_without_Invalids("/kaggle/input/eth-ucy-processed/datasets/", config, "train", history_frames=args.past_length, future_frames=args.future_length)
dataset_val = AgentPreProcessing_without_Invalids("/kaggle/input/eth-ucy-processed/datasets/", config, "val", history_frames=args.past_length, future_frames=args.future_length)
dataset_test = AgentPreProcessing_without_Invalids("/kaggle/input/eth-ucy-processed/datasets/", config, "test", history_frames=args.past_length, future_frames=args.future_length)

# final versions of datasets
final_dataset_train = TrajectoryDataset_without_Invalids(dataset_train, 3241, config, args.past_length, args.future_length)
final_dataset_val = TrajectoryDataset_without_Invalids(dataset_val, 756, config, args.past_length, args.future_length)
final_dataset_test = TrajectoryDataset_without_Invalids(dataset_test, 731,  config, args.past_length, args.future_length)

# Data Loader Setup
loader_train = DataLoader(final_dataset_train, batch_size=args.batch_size, shuffle=True, drop_last=True, num_workers=4)
loader_val = DataLoader(final_dataset_val, batch_size=args.batch_size, shuffle=False, drop_last=False, num_workers=4)
loader_test = DataLoader(final_dataset_test, batch_size=args.batch_size, shuffle=False, drop_last=False, num_workers=4)
#''' 

## Initialize Dataset with invalids and Set-up Dataloader - PreProcessing also Invalid Data

In [None]:
'''
# Dataset Initialization
dataset_train = AgentPreProcessing_with_Invalids("/kaggle/input/eth-ucy-processed/datasets/", config, "train", history_frames=args.past_length, future_frames=args.future_length)
dataset_val = AgentPreProcessing_with_Invalids("/kaggle/input/eth-ucy-processed/datasets/", config, "val", history_frames=args.past_length, future_frames=args.future_length)
dataset_test = AgentPreProcessing_with_Invalids("/kaggle/input/eth-ucy-processed/datasets/", config, "test", history_frames=args.past_length, future_frames=args.future_length)

# final versions of datasets
final_dataset_train = TrajectoryDataset_with_Invalids(dataset_train, config, args.past_length, args.future_length)
final_dataset_val = TrajectoryDataset_with_Invalids(dataset_val, config, args.past_length, args.future_length)
final_dataset_test = TrajectoryDataset_with_Invalids(dataset_test, config, args.past_length, args.future_length)

# Data Loader Setup
loader_train = DataLoader(final_dataset_train, batch_size=args.batch_size, shuffle=True, drop_last=True, num_workers=4)
loader_val = DataLoader(final_dataset_val, batch_size=args.batch_size, shuffle=False, drop_last=False, num_workers=4)
loader_test = DataLoader(final_dataset_test, batch_size=args.batch_size, shuffle=False, drop_last=False, num_workers=4)
'''

## Definition of all model components

This code defines several classes related to the neural network architecture for a machine learning model that deals with motion or trajectory prediction. The model components include feature initialization, interaction graphs, and a feature learning layer, each encapsulated within separate classes. These classes handle different aspects of feature processing and interaction reasoning among agents within the system, crucial for accurately predicting trajectories based on past and current data inputs.

#### **1. Feature Initialization**

The class is designed to initialize and transform input features into a more suitable format for subsequent processing by other parts of a neural network model. This transformation is critical for extracting and enhancing underlying patterns in the data that are not immediately apparent in their raw form.
It includes:
- *Embedding Layers*: Two separate layers transform the input features and velocity angles into a hidden space.
- *Activation Function*: A ReLU module is instantiated to provide non-linearity; as written, the forward pass concatenates the two linear embeddings without applying it.
- *Weight Initialization*: Uses Xavier initialization to optimize training by maintaining stable gradients.
- *Forward Pass*: Combines transformed input and velocity features into a single feature vector for subsequent processing.
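
The shape bookkeeping and the Xavier initialization listed above can be sketched in plain NumPy. This is a simplified stand-in for the PyTorch layers, assuming a gain of 1 and omitting bias terms; the sizes and the helper name `xavier_uniform` are chosen for illustration only.

```python
import numpy as np

def xavier_uniform(fan_in, fan_out, rng):
    """Xavier/Glorot uniform init (gain 1): U(-a, a) with a = sqrt(6 / (fan_in + fan_out))."""
    bound = np.sqrt(6.0 / (fan_in + fan_out))
    return rng.uniform(-bound, bound, size=(fan_in, fan_out))

rng = np.random.default_rng(0)
in_node_nf, hidden_nf = 8, 64                           # assumed sizes
W_h = xavier_uniform(in_node_nf, hidden_nf // 2, rng)   # embeds node features
W_v = xavier_uniform(in_node_nf, hidden_nf // 2, rng)   # embeds velocity angles

h = rng.normal(size=(4, 3, in_node_nf))                 # (batch, agents, features)
vel_angle = rng.normal(size=(4, 3, in_node_nf))

# Each input is projected to hidden_nf // 2 and the two halves are concatenated
combined = np.concatenate([h @ W_h, vel_angle @ W_v], axis=-1)
print(combined.shape)
```

Splitting the hidden width evenly between the two inputs means the concatenated vector has exactly `hidden_nf` features, which is what the downstream layers expect.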

In [None]:
class FeatureInitialization(nn.Module):
    """Class for initializing features of input data"""
    def __init__(self, in_node_nf, hidden_nf, act_fn=nn.ReLU):
        super().__init__()
        # Embedding layers to transform input node features into a hidden feature space.
        self.embedding = nn.Linear(in_node_nf, hidden_nf // 2)
        self.embedding2 = nn.Linear(in_node_nf, hidden_nf // 2)
        
        self.act_fn = act_fn() # Activation function for non-linearity in feature transformation.
        self._initialize_weights() # Initialize weights with a specific strategy for better training performance.

    def _initialize_weights(self):
        """Method to initialize weights of embedding layers using Xavier initialization"""
        init.xavier_uniform_(self.embedding.weight)
        init.xavier_uniform_(self.embedding2.weight)

    def forward(self, h, vel_angle):
        """Forward pass to compute the combined feature vector from input features and velocity angles"""
        h = self.embedding(h)
        vel_angle_embedding = self.embedding2(vel_angle)
        return torch.cat([h, vel_angle_embedding], dim=-1)

#### **2. Interaction Graph**

This class models the interactions among agents by processing and combining their feature vectors through several multilayer perceptrons (MLPs). It also categorizes these interactions into different groups, which can help in segmenting agent behaviors or predicting different interaction outcomes.

It's structured in 2 main parts:
- *Initialization*:
    - Configuration Parameters: Accepts parameters for the sizes of hidden layers (hidden_nf), channel sizes (hid_channel), activation function (act_fn), and the number of categories (category_num).
    - MLPs for Edge, Coordinate, Node, and Category Features:
        - edge_mlp: Processes combined features of pairs of agents to learn about their direct interactions.
        - coord_mlp: Transforms coordinate features into a space that enhances their representation.
        - node_mlp: Processes node features to refine their information after interactions are considered.
        - category_mlp: Determines the category of interactions among agents based on their features.
    
- *Category Calculation*:
    - Feature Preparation: Flattens and concatenates the features of all agents for clustering.
    - Clustering: Uses K-means to categorize the interactions into predefined groups based on feature similarity.
    - Output Formatting: Transforms clustering labels into one-hot encoded format and adjusts them to match the expected dimensions for further processing in the network.
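
The output-formatting step can be illustrated without running a clustering algorithm. The sketch below starts from hypothetical per-agent labels (standing in for what K-means might return) and builds the pairwise one-hot tensor; the helper name `one_hot_pairwise` and the label values are assumptions made for this example.

```python
import numpy as np

def one_hot_pairwise(labels, category_num):
    """Turn per-agent cluster labels of shape (batch, agents) into a
    (batch, agents, agents, category_num) one-hot tensor."""
    batch_size, agent_num = labels.shape
    one_hot = np.eye(category_num)[labels]          # (batch, agents, categories)
    # Repeat each agent's category along a new "neighbour" axis
    return np.broadcast_to(one_hot[:, :, None, :],
                           (batch_size, agent_num, agent_num, category_num))

# Hypothetical labels for 2 scenes with 3 agents each
labels = np.array([[0, 2, 1],
                   [3, 3, 0]])
categories = one_hot_pairwise(labels, category_num=4)
print(categories.shape)
print(categories[0, 1, 0])   # one-hot row for agent 1 in scene 0 (label 2)
```

Note that with this layout every pair `(i, j)` inherits the category of agent `i`, so the category is a property of the source agent rather than of the pair.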

In [None]:
class InteractionGraph(nn.Module):
    """Class to model interaction graphs between agents using learned features"""
    def __init__(self, hidden_nf, hid_channel, act_fn=nn.ReLU, category_num=4):
        super().__init__()
        self.hidden_nf = hidden_nf  # Number of features in hidden layers.
        self.hid_channel = hid_channel  # Number of channels in hidden layers.
        self.category_num = category_num  # Number of categories to classify interactions.
        self.tao = 1  # Parameter, possibly a threshold or scaling factor, to be defined.

        # Define MLPs for processing different aspects of agent interaction.
        self.edge_mlp = nn.Sequential(
            nn.Linear(hidden_nf * 2 + hid_channel * 2, hidden_nf),
            act_fn(),
            nn.Linear(hidden_nf, hidden_nf),
            act_fn()
        )

        self.coord_mlp = nn.Sequential(
            nn.Linear(hid_channel, hidden_nf),
            act_fn(),
            nn.Linear(hidden_nf, hid_channel * 2),
            act_fn()
        )

        self.node_mlp = nn.Sequential(
            nn.Linear(hidden_nf + hidden_nf, hidden_nf),
            act_fn(),
            nn.Linear(hidden_nf, hidden_nf)
        )

        self.category_mlp = nn.Sequential(
            nn.Linear(hidden_nf * 2 + hid_channel * 2, hidden_nf),
            act_fn(),
            nn.Linear(hidden_nf, category_num),
            act_fn()
        )

    def calc_category(self, h, coord, valid_mask):
        """Method to calculate interaction categories based on combined features using clustering"""
        batch_size, agent_num = coord.shape[:2]
        
        # Flatten and concatenate features for clustering
        h_flat = h.view(batch_size * agent_num, -1)
        coord_flat = coord.view(batch_size * agent_num, -1)
        features = torch.cat([h_flat, coord_flat], dim=-1).detach().cpu().numpy()
        
        # Perform K-means clustering to categorize interactions
        kmeans = KMeans(n_clusters=self.category_num)
        kmeans.fit(features)
        cluster_labels = kmeans.labels_
        
        # Reshape cluster labels to match the batch and agent dimensions
        cluster_labels = torch.tensor(cluster_labels, dtype=torch.long, device=h.device)
        cluster_labels = cluster_labels.view(batch_size, agent_num)
        
        # Create one-hot encoding for cluster labels
        interaction_category = F.one_hot(cluster_labels, num_classes=self.category_num).float()
        
        # Expand dimensions to match expected output shape
        interaction_category = interaction_category.unsqueeze(2).expand(-1, -1, agent_num, -1)
        
        return interaction_category

#### **3. Feature Learning Layer**

This class is designed to process the features of agents, update their state based on interactions, and apply attention mechanisms to focus on relevant parts of the data dynamically. It uses multiple neural network layers to compute and refine feature representations, which are critical for tasks like trajectory prediction.

It's organized into 4 main parts:
- *Initialization*:
    - Configuration Parameters: Takes multiple parameters defining the sizes of inputs, outputs, and features, as well as flags for using recurrent updates, attention mechanisms, and reasoning.
    - Neural Network Layers:
        - coord_vel: Linear layer to update velocities.
        - edge_mlp, node_mlp, category_mlps: Multiple MLPs to process edge, node, and categorical features.
        - attention components (query, key): Defined if attention is enabled, for calculating attention weights.
- *MLP Builder*: Constructs MLPs dynamically based on specified input and output sizes and includes nonlinear activation functions, which can be extended with tanh activations if specified.
- *Feature Computation Methods*:
    - compute_edge_features: Calculates features based on the difference between agent coordinates.
    - update_coordinates: Updates agent coordinates based on edge features.
    - compute_node_features: Refines node features by incorporating aggregated edge information.
    - apply_attention: Adjusts coordinates based on attention-weighted features.
- *Forward Pass*: Integrates all the computations to update agent states and apply attention, returning updated feature representations and coordinates.
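
The attention step in particular can be hard to follow from the description alone. Below is a minimal NumPy sketch of a masked attention update over coordinates, assuming the softmax is taken over neighbours and padded agents are zeroed out afterwards. The weight matrices, sizes, and the function name `masked_attention_update` are illustrative assumptions, not the notebook's layers.

```python
import numpy as np

def masked_attention_update(coord, h, W_q, W_k, valid_mask):
    """coord: (B, N, C, D), h: (B, N, F), valid_mask: (B, N) with 1 for real agents."""
    B, N = coord.shape[:2]
    query = h @ W_q                                        # (B, N, F)
    key = h @ W_k                                          # (B, N, F)
    scores = query @ key.transpose(0, 2, 1)                # (B, N, N)
    scores = scores - scores.max(axis=2, keepdims=True)    # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=2, keepdims=True) # softmax over neighbours
    weights = weights * valid_mask[:, None, :]             # zero out padded agents
    flat = coord.reshape(B, N, -1)                         # (B, N, C*D)
    return (weights @ flat).reshape(coord.shape)

rng = np.random.default_rng(0)
B, N, C, D, F = 2, 4, 3, 2, 5                              # assumed sizes
coord = rng.normal(size=(B, N, C, D))
h = rng.normal(size=(B, N, F))
W_q, W_k = rng.normal(size=(F, F)), rng.normal(size=(F, F))
valid_mask = np.array([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=float)
out = masked_attention_update(coord, h, W_q, W_k, valid_mask)
print(out.shape)
```

Because the mask is applied after the softmax, the remaining weights in a row no longer sum to one; masking the scores before the softmax is a common alternative when normalized weights are required.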

In [None]:
class Feature_learning_layer(nn.Module):
    """Class for learning and updating features in an agent-based model"""
    def __init__(self, input_nf, output_nf, hidden_nf, input_c, hidden_c, output_c, edges_in_d=0, nodes_att_dim=0, act_fn=nn.ReLU, recurrent=True, coords_weight=1.0, attention=True, norm_diff=False, tanh=False, apply_reasoning=True, input_reasoning=False, category_num=2):
        super().__init__()
        # Flags for model behavior and feature processing configurations.
        self.recurrent = recurrent  # Determines if layer updates should be recurrent.
        self.attention = attention  # Determines if attention mechanisms should be applied.
        self.apply_reasoning = apply_reasoning  # Enables reasoning within the layer.
        self.category_num = category_num  # Number of interaction categories; one category MLP is built per category.

        # Core layers for feature processing.
        self.coord_vel = nn.Linear(2, 2, bias=False)  # Layer to update velocities.
        self.edge_mlp = self._build_mlp(input_nf * 2 + edges_in_d + hidden_c, hidden_nf, act_fn, layers=3)  # MLP for processing edge features.
        self.node_mlp = self._build_mlp(input_nf + hidden_nf + nodes_att_dim, output_nf, act_fn)  # MLP for node feature aggregation.
        self.category_mlps = nn.ModuleList([self._build_mlp(input_nf * 2 + hidden_c, hidden_c, act_fn, layers=3) for _ in range(category_num)])  # List of MLPs for different categories.
        
        # Attention mechanisms, instantiated if enabled.
        if attention:
            self.query = nn.Linear(hidden_c, hidden_c, bias=False)  # Generates query vectors for the attention mechanism.
            self.key = nn.Linear(hidden_c, hidden_c, bias=False)  # Generates key vectors for the attention mechanism.

        self.inner_attention_mlp = nn.Sequential(nn.Linear(hidden_nf, hidden_c), act_fn())  # MLP to process features within the attention mechanism.

    def _build_mlp(self, input_size, output_size, act_fn, layers=2, add_tanh=False):
        """Function to dynamically build MLP structures"""
        mlp_layers = [nn.Linear(input_size, output_size), act_fn()]
        for _ in range(1, layers):
            mlp_layers.append(nn.Linear(output_size, output_size))
            mlp_layers.append(act_fn())
        if add_tanh:
            mlp_layers.append(nn.Tanh())  # Adds an optional Tanh layer for additional non-linearity.
        return nn.Sequential(*mlp_layers)

    def compute_edge_features(self, h, coord):
        """Computes edge features by considering spatial relationships and feature differences"""
        batch_size, agent_num, coord_dim, _ = coord.shape
        h1 = h.unsqueeze(2).expand(-1, -1, agent_num, -1)  # Expand features for each agent pair.
        h2 = h.unsqueeze(1).expand(-1, agent_num, -1, -1)  # Second expansion for pairwise comparison.
        coord_diff = coord.unsqueeze(2) - coord.unsqueeze(1)  # Compute pairwise coordinate differences.
        coord_dist = coord_diff.norm(dim=-1, keepdim=False)  # Calculate Euclidean distance between coordinates.
        edge_features = torch.cat((h1, h2, coord_dist), dim=-1)  # Concatenate features and distances.
        return self.edge_mlp(edge_features)  # Process concatenated features through an MLP.

    def update_coordinates(self, coord, edge_features):
        """Update agent coordinates based on computed edge features"""
        coord_factors = edge_features.mean(dim=2)  # Average edge features over the neighbour axis (dim 2).
        coord_factors = coord_factors.unsqueeze(-1).expand_as(coord)  # Expand features to match coordinate dimensions.
        return coord + coord_factors  # Update coordinates by adding feature-driven adjustments.

    def compute_node_features(self, h, edge_features, valid_mask):
        """Aggregate and refine node features using the edge features"""
        aggregated_edges = edge_features.sum(dim=2)  # Sum edge features to aggregate information.
        return self.node_mlp(torch.cat((h, aggregated_edges), dim=-1))  # Combine and process node and edge features.

    def apply_attention(self, coord, h, valid_mask_agent, num_valid):
        """Apply attention to refine feature adjustments based on their relevance"""
        query = self.query(h)  # Generate query vectors.
        key = self.key(h).transpose(1, 2)  # Generate and transpose key vectors.
        att_weights = torch.bmm(query, key)  # Compute raw attention weights.
        seq_len = valid_mask_agent.size(1)
        valid_mask_agent = valid_mask_agent.squeeze(-1).transpose(1, 2).expand(-1, seq_len, -1)  # Adjust and expand valid masks.
        att_weights = F.softmax(att_weights, dim=2) * valid_mask_agent  # Apply softmax and mask to normalize attention weights.
        coord_flattened = coord.reshape(coord.size(0), coord.size(1), -1)  # Flatten coordinates for matrix multiplication.
        coord_adjusted = torch.bmm(att_weights, coord_flattened)  # Adjust coordinates based on attention weights.
        coord_final = coord_adjusted.reshape(coord.size(0), coord.size(1), coord.size(2), coord.size(3))  # Reshape back to original dimensions.
        return coord_final  # Return adjusted coordinates.

    def forward(self, h, coord, vel, valid_mask, valid_mask_agent, num_valid, category=None):
        """Forward pass integrates all computations to update agent features and coordinates based on interactions and attention"""
        edge_features = self.compute_edge_features(h, coord)  # Compute edge features.
        coord = self.update_coordinates(coord, edge_features)  # Update coordinates based on edge features.
        coord += self.coord_vel(vel)  # Apply velocity updates.
        h = self.compute_node_features(h, edge_features, valid_mask)  # Compute and update node features.
        coord = self.apply_attention(coord, h, valid_mask_agent, num_valid)  # Refine coordinates using attention.
        return h, coord, category  # Return updated features and coordinates.

### EquiPredict 

EquiPredict is a neural network model designed for motion prediction in a graph-based or relational context. It integrates multiple advanced neural network components, including multi-head attention, graph convolutional layers (GCNs / GATs), and recurrent layers (LSTMs).  
It is used to predict the future motion of nodes (agents) in a dynamic system based on their current features, positions, and velocities.   
Looking closely at the `forward` function of EquiPredict, the following steps are performed:  
- 1. **Embedding and Transforming Features** : 
        - *Node features and Velocity Angles* : Each is projected by a learned linear layer into an embedding of size `hidden_dim // 2`, and the two embeddings are concatenated into a single node feature vector. This lets the model represent the features associated with each agent's state and movement. 
        - *Geometric Coordinates and Velocities* : Transformations are then applied to the geometric coordinates and velocities of the agents, so they are scaled and normalized appropriately. By transforming coordinates and velocities into a higher-dimensional space, the model enhances its ability to capture subtle patterns and interactions.
- 2. **Discrete Cosine Transform** : Coordinates and velocities are transformed into the frequency domain with the DCT, which gives a compact representation of the trajectories. At the end of the forward pass, the inverse DCT (IDCT) converts the predicted output back to the original domain. 
- 3. **Interaction Categories and Feature Learning Layers** : 
This step is essential for understanding how agents interact with one another based on their features and coordinates. 
        - *Interaction categories* : They are computed using message passing and aggregation techniques, which involve evaluating how different agents influence each other. These categories guide the model in learning meaningful interactions between agents, helping refine their representations. 
        - *Feature Learning Layers and Graph Convolutional Layer* : The model uses Feature Learning Layers to process these interactions. These layers update the node features and coordinates based on the computed interaction categories, capturing the relationships and dependencies among agents. A final graph convolutional layer (GCN or GAT) is then applied to the node features.
- 4. **Recurrent Processing for Temporal Dynamics** : If recurrent processing is enabled, the model incorporates an LSTM layer. It processes the node features across time steps, allowing the model to account for how agent interactions and states evolve over time. 
- 5. **Final Predictions** : The final component of the forward pass involves generating predictions using multiple prediction heads. Each head processes the refined node features to produce outputs related to the coordinates of the agents. The use of multiple prediction heads allows the model to generate diverse predictions and aggregate them for improved accuracy. After predictions are made, if DCT was applied earlier, the model transforms the coordinates back from the frequency domain to the spatial domain using the inverse DCT. 
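
Step 2 can be illustrated in isolation. The following is a minimal NumPy sketch, not the model's code path: it builds an orthonormal DCT-II matrix with the same weighting as `compute_dct_matrix` and checks that the inverse transform recovers the input. The helper name `dct_matrix` and the toy trajectory are assumptions made for illustration.

```python
import numpy as np

def dct_matrix(n):
    """Orthonormal DCT-II matrix of size n x n (hypothetical helper for illustration)."""
    k = np.arange(n)[:, None]          # frequency index (rows)
    i = np.arange(n)[None, :]          # time index (columns)
    m = np.sqrt(2.0 / n) * np.cos(np.pi * (i + 0.5) * k / n)
    m[0, :] = np.sqrt(1.0 / n)         # the k = 0 row uses the smaller weight
    return m

# Toy trajectory: 8 time steps of (x, y) positions.
t = np.arange(8, dtype=float)
traj = np.stack([0.5 * t, 0.1 * t ** 2], axis=-1)   # shape (8, 2)

D = dct_matrix(8)
coeffs = D @ traj                 # frequency-domain representation
recovered = D.T @ coeffs          # the matrix is orthonormal, so its inverse is its transpose

print(np.allclose(D @ D.T, np.eye(8)))   # orthonormality check
print(np.allclose(recovered, traj))      # round-trip check
```

In the model itself the forward matrix is sized by `input_dim` and the inverse matrix by `output_dim`, so the two are not inverses of each other unless the past and future lengths match.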

In [None]:
""" MultiHead Attention processes input sequences through multiple parallel attention heads, each learning different  
aspects of the relationships between tokens. The results from each head are combined and projected back to the desired 
output dimension. This approach helps the model capture diverse features and dependencies in the input data """

class MultiHeadAttention(nn.Module):
    def __init__(self, input_dim, hidden_dim, device, num_heads=4):
        super(MultiHeadAttention, self).__init__()
        self.device = device
        
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.head_dim = hidden_dim // num_heads
        
        self.linear_layers = nn.ModuleList([
            nn.Linear(input_dim, hidden_dim),
            nn.Linear(input_dim, hidden_dim),
            nn.Linear(input_dim, hidden_dim),
            nn.Linear(hidden_dim, hidden_dim)
        ])
        
        # Move all linear layers to device
        self.linear_layers.to(self.device)
        self.to(self.device)

    def forward(self, query, key, value, mask=None):
    
        batch_size, seq_length, _, _ = query.shape        
        
        # Move inputs to device
        query = query.to(self.device)
        key = key.to(self.device)
        value = value.to(self.device)
        
        # Project inputs using ModuleList
        Q = self.linear_layers[0](query)
        K = self.linear_layers[1](key)
        V = self.linear_layers[2](value)
        
        # Split heads
        Q = Q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        
        # Calculate attention scores
        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            attention_scores = attention_scores.masked_fill(mask == 0, float('-inf'))
        
        attention_weights = F.softmax(attention_scores, dim=-1)
        attention_output = torch.matmul(attention_weights, V)
        
        attention_output = attention_output.transpose(1, 2).contiguous().view(batch_size, -1, self.hidden_dim)
        attention_output = self.linear_layers[3](attention_output)
        attention_output = attention_output.view(batch_size, seq_length, seq_length, attention_output.size(2))
        
        return attention_output
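
To make the tensor shapes concrete, here is a NumPy sketch of the head split and scaled dot-product step on a pairwise edge tensor. It is a simplified illustration under stated assumptions: the learned linear projections are omitted, and `softmax` and `split_heads` are hypothetical helpers, not part of the class above.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def split_heads(x, num_heads):
    """(batch, seq, hidden) -> (batch, heads, seq, head_dim)."""
    b, s, h = x.shape
    return x.reshape(b, s, num_heads, h // num_heads).transpose(0, 2, 1, 3)

rng = np.random.default_rng(0)
batch, agents, hidden, heads = 2, 3, 8, 4

# Pairwise edge tensor (batch, agents, agents, hidden), flattened to agents * agents "tokens".
edges = rng.normal(size=(batch, agents, agents, hidden))
seq = edges.reshape(batch, agents * agents, hidden)

q = k = v = split_heads(seq, heads)                      # self-attention, projections omitted
scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(hidden // heads)
weights = softmax(scores, axis=-1)                       # each row sums to 1
out = (weights @ v).transpose(0, 2, 1, 3).reshape(batch, agents * agents, hidden)
out = out.reshape(batch, agents, agents, hidden)         # back to the pairwise layout

print(out.shape)   # (2, 3, 3, 8)
```

Each of the `agents * agents` pairs attends to every other pair, which is why the attention matrix grows with the fourth power of the number of agents.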

In [None]:

class EquiPredict(nn.Module):
    def __init__(self, node_features, edge_features, hidden_dim, input_dim, hidden_channel_dim, output_dim, device='cuda', act_fn=nn.SiLU(), layers=4, coords_weight=1.0, use_recurrent=False, normalize_diff=False, use_tanh=False, gnn_variant = 'GCN'):
        super(EquiPredict, self).__init__()
        self.hidden_dim = hidden_dim
        self.device = device
        self.layers = layers
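        self.use_recurrent = use_recurrent
        # forward() calls self.lstm when use_recurrent is True, so the layer must exist.
        # The hidden_dim -> hidden_dim sizing is an assumption; adjust it to the intended design.
        if self.use_recurrent:
            self.lstm = nn.LSTM(hidden_dim, hidden_dim, batch_first=True)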

        self.node_embedding = nn.Linear(node_features, hidden_dim // 2)
        self.angle_embedding = nn.Linear(node_features, hidden_dim // 2)

        self.coord_transform = nn.Linear(input_dim, hidden_channel_dim, bias=False)
        self.velocity_transform = nn.Linear(input_dim, hidden_channel_dim, bias=False)

        self.use_dct = True
        self.validate_reasoning = True
        self.input_dim = input_dim
        self.output_dim = output_dim

        self.num_categories = 4
        self.tao = 1

        self.given_category = False
        if not self.given_category:
            self.edge_network, self.coord_network, self.node_network, self.category_network = self.init_mlps(hidden_dim, hidden_channel_dim, act_fn)
        
        
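        # Choose the GNN variant used for the final graph convolution
        self.gnn_variant = gnn_variant 
        if self.gnn_variant == 'GCN':
            self.gcl = GCNConv(self.hidden_dim, self.hidden_dim)
        elif self.gnn_variant == 'GAT':
            self.gcl = GATConv(self.hidden_dim, self.hidden_dim)
        else:
            raise ValueError(f"Unsupported gnn_variant '{gnn_variant}'; expected 'GCN' or 'GAT'")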
        
        # Feature Learning Layers
        self.gcls = nn.ModuleList([self.create_gcl_layer(edge_features, hidden_dim, input_dim, hidden_channel_dim, output_dim, act_fn = nn.SiLU(), coords_weight = 1.0, recurrent = False, norm_diff = False, tanh = False) for _ in range(layers - 1)])

        # Prediction Heads
        self.predict_heads = nn.ModuleList([self.create_gcl_layer(edge_features, hidden_dim, input_dim, hidden_channel_dim, output_dim, act_fn = nn.SiLU(), coords_weight = 1.0, recurrent = False , norm_diff = False, tanh = False) for _ in range(20)])
        self.predict_heads_linear = nn.ModuleList([nn.Linear(hidden_channel_dim, output_dim, bias=False) for _ in range(20)])

        self.to(self.device)
    
    def init_mlps(self, hidden_dim, hidden_channel_dim, act_fn):
        """ init_mlps builds the edge, coordinate, node, and category MLPs used later by the model """
        edge_network = nn.Sequential(
            nn.Linear(hidden_dim * 2 + hidden_channel_dim * 2, hidden_dim),
            act_fn,
            nn.Linear(hidden_dim, hidden_dim),
            act_fn
        )

        coord_network = nn.Sequential(
            nn.Linear(hidden_channel_dim * 2, hidden_dim),
            act_fn,
            nn.Linear(hidden_dim, hidden_channel_dim * 2),
            act_fn
        )

        node_network = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            act_fn,
            nn.Linear(hidden_dim, hidden_dim),
            act_fn
        )

        category_network = nn.Sequential(
            nn.Linear(hidden_dim * 2 + hidden_channel_dim * 2, hidden_dim),
            act_fn,
            nn.Linear(hidden_dim, self.num_categories),
            act_fn
        )

        return edge_network, coord_network, node_network, category_network
    
    def create_gcl_layer(self, in_edge_nf, hidden_nf, in_channel, hid_channel, out_channel, act_fn, coords_weight, recurrent, norm_diff, tanh):
        return Feature_learning_layer(hidden_nf, hidden_nf, hidden_nf, in_channel, hid_channel, out_channel, edges_in_d=0, nodes_att_dim=0, act_fn=nn.ReLU, recurrent=True, coords_weight=1.0, attention=True, norm_diff=False, tanh=False, apply_reasoning=False, input_reasoning=True, category_num=self.num_categories)
    
    def compute_dct_matrix(self, N, x):
        """ compute_dct_matrix computes the N x N Discrete Cosine Transform (DCT) matrix and its inverse (IDCT). The DCT matrix transforms data into the frequency domain,
            while the IDCT matrix transforms it back to the original domain """
        dct_matrix = np.eye(N)
        for k in range(N):
            for i in range(N):
                weight = np.sqrt(2 / N)
                if k == 0:
                    weight = np.sqrt(1 / N)
                dct_matrix[k, i] = weight * np.cos(np.pi * (i + 0.5) * k / N)
        idct_matrix = np.linalg.inv(dct_matrix)
        dct_matrix = torch.from_numpy(dct_matrix).type_as(x)
        idct_matrix = torch.from_numpy(idct_matrix).type_as(x)
        return dct_matrix, idct_matrix
    
    def apply_dct(self, coords, vel, valid_agent_mask, agent_num, num_valid, batch_size):
        """ apply_dct applies Discrete Cosine Transform (DCT) to coordinates and velocities"""
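        # Centroid of the valid agents; the factor agent_num / num_valid rescales the mean taken over all (padded) agents.
        coords_center = torch.mean(coords * valid_agent_mask, dim=(1, 2), keepdim=True) * (agent_num / num_valid[:, None, None, None])
        coords = coords - coords_center  # Out-of-place, so the caller's tensor is not modified.
        # DCT matrix sized for the observed sequence, IDCT matrix sized for the predicted sequence.
        dct_m, _ = self.compute_dct_matrix(self.input_dim, coords)
        _, idct_m = self.compute_dct_matrix(self.output_dim, coords)
        dct_m = dct_m.repeat(batch_size, agent_num, 1, 1)
        idct_m = idct_m.repeat(batch_size, agent_num, 1, 1)
        coords, vel = torch.matmul(dct_m, coords), torch.matmul(dct_m, vel)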
        return coords, coords_center, vel, idct_m
    

    def compute_interaction_categories(self, node_features, coords, valid_mask):
        """ compute_interaction_categories computes interaction categories between nodes based on their features and coordinates """
        
        batch_size, num_agents, _, _ = coords.shape
        node_features_1 = node_features[:, :, None, :].repeat(1,1,num_agents,1)
        node_features_2 = node_features[:, None, :, :].repeat(1,num_agents,1,1)
        
        # Calculate coordinate differences and distances
        coord_diff = coords[:, :, None, :, :] - coords[:, None, :, :, :]
        distances = torch.norm(coord_diff, dim=-1)
        distances = self.coord_network(distances)
        
        # Initialize edge features
        edge_features = self.message_passing(node_features_1, node_features_2, distances)

        # Compute interaction categories through message passing
        interaction_categories = self.message_aggregation(node_features, edge_features, distances, valid_mask, num_agents, batch_size)

        return interaction_categories

    def message_passing(self, node_features_1, node_features_2, distances):
        """ message_passing performs message passing to compute edge features using multi-head attention """
    
        edge_input = torch.cat([node_features_1, node_features_2, distances], dim=-1)
        
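        # Apply multi-head self-attention. The module is built once and reused: constructing it on
        # every call would re-initialise its weights at each forward pass. Parameters created lazily
        # here are only seen by an optimizer that is constructed after the first forward pass.
        if not hasattr(self, 'edge_attention'):
            self.edge_attention = MultiHeadAttention(input_dim=edge_input.size(-1), hidden_dim=self.hidden_dim, device=self.device)
        edge_features = self.edge_attention(edge_input, edge_input, edge_input)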
        

        return edge_features

    def message_aggregation(self, node_features, computed_edge_features, distances, valid_mask, num_agents, batch_size):
        """ message_aggregation aggregates edge features to update node representations and compute interaction categories """
        # Prepare mask to ignore self-loops
        mask = (torch.ones((num_agents, num_agents)) - torch.eye(num_agents)).type_as(computed_edge_features)
        mask = mask[None, :, :, None].repeat(batch_size, 1, 1, 1)

        # Aggregate edge features and update node representations
        updated_node_features = self.node_network(torch.cat([node_features, torch.sum(valid_mask * mask * computed_edge_features, dim=2)], dim=-1))

        # Prepare updated node features for interaction computation
        updated_node_features_1 = updated_node_features[:, :, None, :].repeat(1,1, num_agents,1)
        updated_node_features_2 = updated_node_features[:, None, :, :].repeat(1, num_agents, 1,1)
        updated_edge_input = torch.cat([updated_node_features_1, updated_node_features_2, distances], dim=-1)

        # Compute interaction categories
        interaction_categories = F.softmax(self.category_network(updated_edge_input) / self.tao, dim=-1)

        return interaction_categories

    def create_valid_mask(self, num_valid, num_agents):
        """ create_valid_mask creates a mask that marks valid agent pairs (an agent-by-agent grid per batch element) """
        batch_size = num_valid.shape[0]
        valid_mask = torch.zeros((batch_size, num_agents, num_agents))
        for i in range(batch_size):
            valid_mask[i, :num_valid[i], :num_valid[i]] = 1
        return valid_mask.unsqueeze(-1)

    def create_valid_mask2(self, num_valid, num_agents):
        """ create_valid_mask2 creates a mask to indicate valid agents in a 1D vector."""
        batch_size = num_valid.shape[0]
        valid_mask = torch.zeros((batch_size, num_agents))
        for i in range(batch_size):
            valid_mask[i, :num_valid[i]] = 1
        return valid_mask.unsqueeze(-1).unsqueeze(-1)

    def forward(self, node_features, coords, velocities, num_valid, edge_attr=None):
        """ forward is the core of the EquiPredict model; its steps are explained in detail above """
        
        # Defining previous velocities, used to compute the cosine of the angle between them and the current velocity vectors 
        velocities_pre = torch.zeros_like(velocities)
        velocities_pre[:, :, 1:] = velocities[:, :, :-1]
        velocities_pre[:, :, 0] = velocities[:, :, 0]
        EPS = 1e-6
        vel_cosangle = torch.sum(velocities_pre * velocities, dim=-1) / ((torch.norm(velocities_pre, dim=-1) + EPS) * (torch.norm(velocities, dim=-1) + EPS))
        vel_angle = torch.acos(torch.clamp(vel_cosangle, -1, 1))

        batch_size, num_agents, _, _ = coords.shape

        valid_agent_mask = self.create_valid_mask2(num_valid, num_agents).type_as(node_features)   # It indicates which agents are valid in the current batch, helping in filtering out invalid data 

        # Applying DCT transform to coordinates and velocities to transform them into the frequency domain 
        if self.use_dct:
            coords, coords_center, velocities, idct_matrix = self.apply_dct(coords, velocities, valid_agent_mask, num_agents, num_valid, batch_size)

        # Creating embedding of node features and velocity angles using learned linear transformations 
        node_features = self.node_embedding(node_features)
        vel_angle_embedding = self.angle_embedding(vel_angle)
        # node_features is the feature vector that will be passed to the feature learning layer. It is a combination of the node features and the velocity angle embeddings 
        node_features = torch.cat([node_features, vel_angle_embedding], dim=-1)

        # Normalizing and transforming the coordinates and velocities to account for batch-wise variations and prepare them for further processing 
        coords_mean = torch.mean(torch.mean(coords * valid_agent_mask, dim=-2, keepdim=True), dim=-3, keepdim=True) * (num_agents / num_valid[:, None, None, None])
        coords = self.coord_transform((coords - coords_mean).transpose(2, 3)).transpose(2, 3) + coords_mean
        velocities = self.velocity_transform(velocities.transpose(2, 3)).transpose(2, 3)
        coord_velocity_combined = torch.cat([coords, velocities], dim=-2)

        valid_mask = self.create_valid_mask(num_valid, num_agents).type_as(node_features)
        
        # Determines the interaction categories for each edge. If categories are predefined, they are processed accordingly; otherwise, they are computed based on node features and edge attributes.
        category = F.one_hot(((edge_attr / 2) + 1).long(), num_classes=self.num_categories) if self.given_category else self.compute_interaction_categories(node_features, coord_velocity_combined, valid_mask)
        
        # Iteratively applying Feature Learning Layers to update the node features and coordinates based on the interaction categories
        category_per_layer = []
        for gcl in self.gcls:
            node_features, coords, _ = gcl(node_features, coords, velocities, valid_mask, valid_agent_mask, num_valid, category=category)
            category_per_layer.append(category)
        
        # Creating an index for all pairs of nodes to define edges and applies the final graph convolution layer to the node features using these edges 
        edge_index = torch.combinations(torch.arange(num_agents), r=2).t().to(self.device)
        edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)
        node_features = self.gcl(node_features, edge_index)
        
        # If recurrent processing is enabled, node features are processed through an LSTM layer to capture temporal dependencies
        if self.use_recurrent:
            node_features = node_features.view(batch_size * num_agents, -1, self.hidden_dim)
            node_features, _ = self.lstm(node_features)
            node_features = node_features.view(batch_size, num_agents, -1, self.hidden_dim)
        
        # Using multiple prediction heads to generate outputs. Each head processes the node features to predict the coordinates, and the results are adjusted for mean and combined 
        all_out = []
        for i, (head, head_linear) in enumerate(zip(self.predict_heads, self.predict_heads_linear)):
            _, out, _ = head(node_features, coords, velocities, valid_mask, valid_agent_mask, num_valid, category=None)
            out_mean = torch.mean(torch.mean(out * valid_agent_mask, dim=-2, keepdim=True), dim=-3, keepdim=True) * (num_agents / num_valid[:, None, None, None])
            out = head_linear((out - out_mean).transpose(2, 3)).transpose(2, 3) + out_mean
            all_out.append(out[:, :, None, :, :])
        
        # Concatenating the outputs from all prediction heads and reshapes them to the final output format 
        coords = torch.cat(all_out, dim=2).view(batch_size, num_agents, 20, self.output_dim, -1)

        # If DCT was applied initially, it performs the inverse DCT to transform the coordinates back to the original domain 
        if self.use_dct:
            idct_matrix = idct_matrix[:, :, None, :, :]
            coords = torch.matmul(idct_matrix, coords)
            coords = coords + coords_center.unsqueeze(2)
        
        # Returning final predicted coordinates. If validate_reasoning is enabled, it also returns the interaction categories computed during the forward pass 
        if self.validate_reasoning:
            return coords, category_per_layer
        else:
            return coords, None 
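
The velocity-angle feature computed at the start of `forward` can be reproduced on a single toy trajectory. The NumPy sketch below mirrors that computation under simplifying assumptions (one agent, no batch axis); `turning_angles` is a hypothetical helper name.

```python
import numpy as np

EPS = 1e-6

def turning_angles(vel):
    """vel: (T, 2) per-step velocities -> (T,) angle in radians between consecutive velocities."""
    prev = np.empty_like(vel)
    prev[1:] = vel[:-1]
    prev[0] = vel[0]                      # the first step has no predecessor, so it is compared with itself
    norms = (np.linalg.norm(prev, axis=-1) + EPS) * (np.linalg.norm(vel, axis=-1) + EPS)
    cos = (prev * vel).sum(axis=-1) / norms
    return np.arccos(np.clip(cos, -1.0, 1.0))

# An agent that walks along +x and then turns 90 degrees towards +y.
vel = np.array([[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]])
angles = turning_angles(vel)
print(np.degrees(angles))
```

Because `EPS` is added to both norms, parallel velocities yield an angle slightly above zero rather than exactly zero, while the 90 degree turn at the third step is recovered exactly.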

### Training 

In [None]:
# Utility Functions
def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True

def lr_decay(optimizer, lr_now, gamma):
    lr_new = lr_now * gamma
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr_new
    return lr_new

# Save Model Checkpoint
def save_checkpoint(epoch, model, optimizer, model_save_dir, subset, best=False):
    state = {'epoch': epoch, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}
    file_name = f"{subset}_ckpt_best.pth.tar" if best else f"{subset}_ckpt_{epoch}.pth.tar"
    file_path = os.path.join(model_save_dir, file_name)
    torch.save(state, file_path)
    
# Mask Function for Training
def get_valid_mask2(num_valid, agent_num):
    batch_size = num_valid.shape[0]
    valid_mask = torch.zeros((batch_size, agent_num))
    for i in range(batch_size):
        valid_mask[i, :num_valid[i]] = 1
    return valid_mask.unsqueeze(-1).unsqueeze(-1)

# Function to clear cache
def clear_cache():
    torch.cuda.empty_cache()
    
def load_and_plot_results(res_path):
    # Load the results from file
    with open(res_path, 'rb') as f:
        data = pickle.load(f)
    
    train_preds = data['results']['train_preds']
    train_gt = data['results']['train_gt']
    train_losses = data['results']['train_losses']
    val_losses = data['results']['val_losses']

    # Plotting losses
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Training and Validation Losses')
    plt.show()

    return train_preds, train_gt, train_losses, val_losses 

# Creating animation to show differences between predictions and ground truth labels 
def create_animation(train_preds, train_gt, train_losses):
    fig = plt.figure(figsize=(12, 8))
    ax = fig.add_subplot()
    plt.axis('off')

    def animate(i):
        ax.clear()
        preds = train_preds[i]
        gt = train_gt[i]
        x_preds = preds[0, 0, :, 0]   # Adjust indices as per your data shape
        y_preds = preds[0, 0, :, 1]
        x_gt = gt[0, 0, :, 0]
        y_gt = gt[0, 0, :, 1]
        ax.scatter(x_preds, y_preds, c='blue', label='Predictions', s=50)
        ax.scatter(x_gt, y_gt, c='red', label='Ground Truth', s=50)
        plt.title(f'Epoch {i} | Train Loss: {train_losses[i]:.5f}', fontsize=18, pad=20)
        plt.legend()
        plt.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)

    anim = animation.FuncAnimation(fig, animate, frames=len(train_losses), interval=800, repeat=True)
    html = HTML(anim.to_html5_video())
    plt.close()
    return html
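
`lr_decay` multiplies the current learning rate by `gamma` each time it is called. Assuming it is invoked every `epoch_decay` epochs, excluding epoch 0, the resulting step schedule can be previewed without an optimizer. The helper `lr_schedule` and the values below are illustrative assumptions, not the notebook's settings.

```python
def lr_schedule(lr0, gamma, epoch_decay, epochs):
    """Learning rate per epoch when the rate is multiplied by gamma every epoch_decay epochs."""
    lr, out = lr0, []
    for epoch in range(epochs):
        if epoch % epoch_decay == 0 and epoch > 0:
            lr *= gamma               # same update rule as lr_decay
        out.append(lr)
    return out

schedule = lr_schedule(lr0=1e-3, gamma=0.5, epoch_decay=2, epochs=6)
print(schedule)
```

With these toy values the rate is halved at epochs 2 and 4 and stays constant in between.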

In [None]:
def train(model, optimizer, epoch, loader, device, backprop=True):
    all_predictions = []
    all_gt = []
    
    start_time = time.time()
    model.train() if backprop else model.eval()
    res = {'epoch': epoch, 'loss': 0, 'counter': 0}
    
    for batch_idx, data in enumerate(loader):
        if data is not None:
            pre_data, fut_data, num_valid = data
            pre_data, fut_data, num_valid = pre_data.to(device), fut_data.to(device), num_valid.to(device).type(torch.int)
            
            vel = torch.zeros_like(pre_data).to(device)
            vel[:, :, 1:] = pre_data[:, :, 1:] - pre_data[:, :, :-1]
            vel[:, :, 0] = vel[:, :, 1]
            
            batch_size, agent_num, length, _ = pre_data.size()
            optimizer.zero_grad()
            
            nodes = torch.sqrt(torch.sum(vel ** 2, dim=-1)).detach()
            loc_pred, category = model(nodes, pre_data.detach(), vel, num_valid, agent_num)
            fut_data = fut_data[:, :, None, :, :]
            
            if args.supervise_all:
                mask = get_valid_mask2(num_valid, pre_data.size(1)).to(device)[:, :, None, :, :]
                loss = torch.mean(torch.min(torch.mean(torch.norm(mask * (loc_pred - fut_data), dim=-1), dim=3), dim=2)[0])
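                # With args.supervise_all, every agent is supervised: the per-step error is averaged over time, the best candidate trajectory is kept per agent, and the result is averaged over agents and batch (padded agents are masked to zero error)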
            else:
                loss = torch.mean(torch.min(torch.mean(torch.norm(loc_pred[:, 0:1] - fut_data[:, 0:1], dim=-1), dim=-1), dim=-1)[0])
                # if not args.supervise_all, the loss is computed only for the first agent (ego agent)
            
            if backprop:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)  # Clip gradients to avoid explosion
                optimizer.step()

            res['loss'] += loss.item() * batch_size
            res['counter'] += batch_size
            
            all_predictions.append(loc_pred.detach().cpu())
            all_gt.append(fut_data.detach().cpu())
    
    avg_loss = res['loss'] / res['counter']
    epoch_time = time.time() - start_time
    
    all_predictions = torch.cat(all_predictions, dim=0)
    all_gt = torch.cat(all_gt, dim=0)

    print(f"{'==> ' if not backprop else ''}epoch {epoch} avg train loss: {avg_loss:.5f}, time taken: {epoch_time:.2f} seconds")
    return avg_loss, all_predictions, all_gt
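
The model inputs built inside the training loop are simple finite differences of the observed positions. The NumPy sketch below reproduces them on a toy tensor with the same `(batch, agents, time, 2)` layout; the trajectory values are made up for illustration.

```python
import numpy as np

# Past positions, shape (batch, agents, time, 2): a single agent moving at constant velocity.
pre_data = np.array([[[[0.0, 0.0], [1.0, 0.5], [2.0, 1.0], [3.0, 1.5]]]])

vel = np.zeros_like(pre_data)
vel[:, :, 1:] = pre_data[:, :, 1:] - pre_data[:, :, :-1]   # displacement between consecutive frames
vel[:, :, 0] = vel[:, :, 1]                                # the first frame copies the next velocity

speed = np.sqrt((vel ** 2).sum(axis=-1))                   # (batch, agents, time), used as node features
print(vel[0, 0])
print(speed[0, 0])
```

The speed has one value per time step, which is why the model is constructed with `node_features` equal to the past length.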


In [None]:
def validate(model, optimizer, epoch, loader, device):
    start_time = time.time()
    model.eval()
    res = {'epoch': epoch, 'loss': 0, 'counter': 0, 'ade': 0}
    
    with torch.no_grad():
        for batch_idx, data in enumerate(loader):
            if data is not None:
                pre_data, fut_data, num_valid = data
                pre_data, fut_data, num_valid = pre_data.to(device), fut_data.to(device), num_valid.to(device).type(torch.int)

                vel = torch.zeros_like(pre_data).to(device)
                vel[:, :, 1:] = pre_data[:, :, 1:] - pre_data[:, :, :-1]
                vel[:, :, 0] = vel[:, :, 1]
                
                batch_size, agent_num, length, _ = pre_data.size()
                
                nodes = torch.sqrt(torch.sum(vel ** 2, dim=-1)).detach()
                loc_pred, category_list = model(nodes, pre_data.detach(), vel, num_valid, agent_num)


                loc_pred = loc_pred.cpu().numpy()
                fut_data = fut_data.cpu().numpy()[:, :, None, :, :]
                ade = np.mean(np.min(np.mean(np.linalg.norm(loc_pred[:, 0:1] - fut_data[:, 0:1], axis=-1), axis=-1), axis=-1))
                # ADE (average displacement error): the Euclidean error of the first agent is averaged over all future time steps, the minimum over the candidate trajectories is kept, and the result is averaged over the batch
                fde = np.mean(np.min(np.mean(np.linalg.norm(loc_pred[:, 0:1, :, -1:] - fut_data[:, 0:1, :, -1:], axis=-1), axis=-1), axis=-1))
                # FDE (final displacement error): the same computation restricted to the last predicted time step
                
                res['loss'] += fde*batch_size
                res['ade'] += ade*batch_size
                res['counter'] += batch_size
                
    res['ade'] *= args.test_scale
    res['loss'] *= args.test_scale
    epoch_time = time.time() - start_time
    print(f"==> epoch {epoch} avg val loss: {res['loss'] / res['counter']:.5f} ade: {res['ade'] / res['counter']:.5f}, time taken: {epoch_time:.2f} seconds")
    
    return res['loss'] / res['counter'], res['ade'] / res['counter']
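
The best-of-K metrics used in validation can be checked by hand on a small case. The sketch below is a simplified NumPy version for a single agent with K candidate trajectories; `min_ade_fde` is a hypothetical helper and the trajectories are toy values.

```python
import numpy as np

def min_ade_fde(pred, gt):
    """pred: (K, T, 2) candidate trajectories, gt: (T, 2) ground truth -> (minADE, minFDE)."""
    dist = np.linalg.norm(pred - gt[None], axis=-1)   # (K, T) per-step Euclidean error
    ade = dist.mean(axis=-1).min()                    # best candidate by time-averaged error
    fde = dist[:, -1].min()                           # best candidate by final-step error
    return ade, fde

gt = np.array([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]])
pred = np.stack([
    gt + np.array([0.0, 1.0]),      # candidate 0: constant lateral offset of 1 unit
    gt * np.array([1.5, 1.0]),      # candidate 1: overshoots along x
])
ade, fde = min_ade_fde(pred, gt)
print(ade, fde)   # 0.5 1.0
```

The two minima are taken independently, so the candidate that gives the best ADE need not be the one that gives the best FDE.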

In [None]:
def test(model, loader, device):
    start_time = time.time()
    model.eval()
    res = {'loss': 0, 'counter': 0, 'ade': 0}

    with torch.no_grad():
        for batch_idx, data in enumerate(loader):
            if data is not None:
                pre_data, fut_data, num_valid = data
                pre_data, fut_data, num_valid = pre_data.to(device), fut_data.to(device), num_valid.to(device).type(torch.int)

                vel = torch.zeros_like(pre_data).to(device)
                vel[:, :, 1:] = pre_data[:, :, 1:] - pre_data[:, :, :-1]
                vel[:, :, 0] = vel[:, :, 1]
                
                batch_size, agent_num, length, _ = pre_data.size()
                #optimizer.zero_grad()
                
                nodes = torch.sqrt(torch.sum(vel ** 2, dim=-1)).detach()
                loc_pred, category_list = model(nodes, pre_data.detach(), vel, num_valid, agent_num)
                
                loc_pred = loc_pred.cpu().numpy()
                fut_data = fut_data.cpu().numpy()[:, :, None, :, :]
                ade = np.mean(np.min(np.mean(np.linalg.norm(loc_pred[:, 0:1] - fut_data[:, 0:1], axis=-1), axis=-1), axis=-1))
                # ADE (average displacement error): the Euclidean error of the first agent is averaged over all future time steps, the minimum over the candidate trajectories is kept, and the result is averaged over the batch
                fde = np.mean(np.min(np.mean(np.linalg.norm(loc_pred[:, 0:1, :, -1:] - fut_data[:, 0:1, :, -1:], axis=-1), axis=-1), axis=-1))
                # FDE (final displacement error): the same computation restricted to the last predicted time step
                
                res['loss'] += fde*batch_size
                res['ade'] += ade*batch_size
                res['counter'] += batch_size
                
    res['ade'] *= args.test_scale
    res['loss'] *= args.test_scale
    epoch_time = time.time() - start_time
    print(f"Test avg loss: {res['loss'] / res['counter']:.5f} ade: {res['ade'] / res['counter']:.5f}, time taken: {epoch_time:.2f} seconds")

    
    return  res['loss'] / res['counter'], res['ade'] / res['counter']

In [None]:
device = torch.device("cuda" if args.cuda else "cpu")

In [None]:
# Final results function
def final_results():
    # Seed setup
    if args.seed >= 0:
        seed = args.seed
        setup_seed(seed)
    else:
        seed = random.randint(0, 1000)
        setup_seed(seed)
    print('The seed is:', seed)

    # Model setup
    model = EquiPredict(
        node_features=args.past_length, 
        edge_features=2, 
        hidden_dim=args.nf, 
        input_dim=args.past_length, 
        hidden_channel_dim=args.channels, 
        output_dim=args.future_length, 
        device=device, 
        act_fn=nn.SiLU(), 
        layers=args.n_layers, 
        coords_weight=1.0, 
        use_recurrent=False, 
        normalize_diff=False, 
        use_tanh=args.tanh
    )

    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    results = {'epochs': [], 'train_losses': [], 'val_losses': [], 'train_preds': [], 'train_gt': []}
    best_val_loss = 1e8
    best_val_ade = 1e8
    best_epoch = 0
    lr_now = args.lr

    for epoch in range(0, args.epochs):
        # Apply learning rate decay if specified
        if args.apply_decay:
            if epoch % args.epoch_decay == 0 and epoch > 0:
                lr_now = lr_decay(optimizer, lr_now, args.lr_gamma)
        
        # Train the model
        train_loss, train_preds, train_gt = train(model, optimizer, epoch, loader_train, device)
        results['train_losses'].append(train_loss)
        results['train_preds'].append(train_preds)
        results['train_gt'].append(train_gt)
        print(f'Epoch {epoch}: Train Loss: {train_loss:.5f}')
        
        # Evaluate on validation set
        val_loss, val_ade = validate(model, optimizer, epoch, loader_val, device)
        results['val_losses'].append(val_loss)
        print(f'Epoch {epoch}: Validation Loss: {val_loss:.5f}, Validation ADE: {val_ade:.5f}')
        
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_val_ade = val_ade
            best_epoch = epoch
            state = {'epoch': epoch,
                     'state_dict': model.state_dict(),
                     'optimizer': optimizer.state_dict()}
            file_path = os.path.join(args.model_save_dir, f'{args.subset}_ckpt_best.pth')
            torch.save(state, file_path)

        clear_cache()

        # Periodically save the results history so progress survives an interrupted run
        if epoch % 5 == 0:  # Adjust frequency as needed
            results_path = os.path.join('/kaggle/working/saved_models', f'training_results_epoch_{epoch}.pkl')
            with open(results_path, 'wb') as f:
                pickle.dump({'results': results}, f)
            clear_cache()

    # Save the final model at the end of the training and validation phases 
    state = {'epoch': epoch, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}
    file_path = os.path.join(args.model_save_dir, f'{args.subset}_final.pth')
    torch.save(state, file_path)
    
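    # Reload the final checkpoint saved above and evaluate it on the test set
    model_trained_path = file_path
    print('Loading model from:', model_trained_path)
    # map_location keeps loading robust if the checkpoint was written on another device
    model_ckpt = torch.load(model_trained_path, map_location=device)
    # The keys come from this same model, so strict loading surfaces any mismatch
    model.load_state_dict(model_ckpt['state_dict'])
    # test() returns (test_loss, ade); this notebook reports the first value as the FDE
    test_loss, ade = test(model, loader_test, device)
    print('ADE final:', ade, 'FDE final:', test_loss)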
    
    # Save the final results dictionary
    results_path = os.path.join('/kaggle/working/saved_models', 'training_results.pkl')
    with open(results_path, 'wb') as f:
        pickle.dump({'results': results}, f)

    clear_cache()
    return results
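
The training loop above lowers the learning rate every `args.epoch_decay` epochs, skipping epoch 0. The sketch below reproduces the schedule this would give, assuming `lr_decay` simply multiplies the current rate by `args.lr_gamma`. That behaviour is an assumption, since `lr_decay` is defined elsewhere in the notebook, and `step_decay_schedule` is a hypothetical helper used only for illustration.

```python
def step_decay_schedule(base_lr, gamma, epoch_decay, epochs):
    """Return the learning rate in effect at each epoch under step decay."""
    lr_now = base_lr
    schedule = []
    for epoch in range(epochs):
        # Same condition as the training loop: decay on multiples of
        # epoch_decay, but never at epoch 0
        if epoch % epoch_decay == 0 and epoch > 0:
            lr_now = lr_now * gamma  # assumed behaviour of lr_decay
        schedule.append(lr_now)
    return schedule


# Illustrative values only; they are not the notebook's hyperparameters
schedule = step_decay_schedule(base_lr=1e-3, gamma=0.5, epoch_decay=2, epochs=6)
print(schedule)
```

Printing the schedule before a long run is a cheap way to check that `epoch_decay` and `lr_gamma` do not shrink the rate to a negligible value too early.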

In [None]:
results = final_results()
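
Running this cell prints the final ADE and FDE. As a reminder of what these metrics measure, here is a minimal sketch using their standard definitions for a single agent: ADE is the mean Euclidean error over all predicted time steps, and FDE is the error at the last step. The helper `ade_fde` and the toy trajectories are illustrative assumptions; the notebook's own `test` function is defined elsewhere and may aggregate over agents and batches differently.

```python
import math


def ade_fde(pred, gt):
    """Compute (ADE, FDE) for one agent.

    pred, gt: equal-length sequences of (x, y) positions, one per time step.
    """
    if len(pred) != len(gt) or len(pred) == 0:
        raise ValueError("pred and gt must be non-empty and the same length")
    # Per-time-step Euclidean distance between prediction and ground truth
    errors = [math.dist(p, g) for p, g in zip(pred, gt)]
    ade = sum(errors) / len(errors)  # average displacement error
    fde = errors[-1]                 # final displacement error
    return ade, fde


# Toy example: the prediction drifts away from the ground truth over time
pred = [(0.0, 0.0), (1.0, 0.0), (2.0, 0.0)]
gt = [(0.0, 0.0), (1.0, 1.0), (2.0, 2.0)]
ade, fde = ade_fde(pred, gt)
print(ade, fde)
```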

In [None]:
# Save the results dictionary returned by final_results()
res_path = os.path.join('/kaggle/working/saved_models', 'training_results_outside.pkl')
with open(res_path, 'wb') as f:
    pickle.dump({'results': results}, f)
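
The pickle written above nests the training history under a `results` key, so any reader has to unwrap one level before reaching the losses and trajectories. The sketch below shows that round trip with dummy data and a temporary file. `load_results` is a hypothetical helper; the notebook's `load_and_plot_results` is defined elsewhere and may do more, such as plotting.

```python
import os
import pickle
import tempfile


def load_results(path):
    """Hypothetical helper: read the pickle and unwrap the 'results' key."""
    with open(path, 'rb') as f:
        payload = pickle.load(f)
    results = payload['results']
    return (results['train_preds'], results['train_gt'],
            results['train_losses'], results['val_losses'])


# Dummy history using the same keys the training function collects
dummy = {'epochs': [], 'train_losses': [0.9, 0.5], 'val_losses': [1.0, 0.6],
         'train_preds': [[0.0], [0.1]], 'train_gt': [[0.0], [0.0]]}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, 'training_results_demo.pkl')
    # Write with the same {'results': ...} wrapper used in the notebook
    with open(path, 'wb') as f:
        pickle.dump({'results': dummy}, f)
    train_preds, train_gt, train_losses, val_losses = load_results(path)

print(train_losses, val_losses)
```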

In [None]:
train_preds, train_gt, train_losses, val_losses = load_and_plot_results(res_path)

In [None]:
def create_animation(train_preds, train_gt, train_losses):
    """Animate predicted vs. ground-truth points for one trajectory, one frame per epoch."""
    fig, ax = plt.subplots(figsize=(12, 8))

    def animate(i):
        # Redraw the frame for epoch i from scratch
        ax.clear()
        preds = train_preds[i]
        gt = train_gt[i]
        # The first two indices select a single trajectory; adjust to your data shape
        x_preds = preds[0, 0, :, 0]
        y_preds = preds[0, 0, :, 1]
        x_gt = gt[0, 0, :, 0]
        y_gt = gt[0, 0, :, 1]
        ax.scatter(x_preds, y_preds, c='blue', label='Predictions', s=50)
        ax.scatter(x_gt, y_gt, c='red', label='Ground Truth', s=50)
        ax.set_title(f'Epoch {i} | Train Loss: {train_losses[i]:.5f}', fontsize=18, pad=20)
        ax.legend()
        # Hide ticks and tick labels; ax.clear() resets them on every frame
        ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)

    anim = animation.FuncAnimation(fig, animate, frames=len(train_losses), interval=800, repeat=True)
    # to_html5_video() relies on the ffmpeg writer by default
    html = HTML(anim.to_html5_video())
    plt.close(fig)
    return html

In [None]:
animation_html = create_animation(train_preds, train_gt, train_losses)
display(animation_html)