In [None]:
import csv
import gc
import json
import pandas as pd
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from sklearn.preprocessing import MinMaxScaler

In [None]:
class TransformerDataset(Dataset):
    """
    Dataset wrapping pre-processed source/target sequences for a Transformer.

    Every item is a dict holding the ``src`` and ``tgt`` slices of one sample,
    cast to half precision (``torch.float16``).
    """

    def __init__(self, data):
        """
        Keep a reference to the pre-processed sequences.

        Args:
            data (dict): Mapping with keys:
                - 'src': source sequences (shape: [num_samples, src_length, num_features]).
                - 'tgt': target sequences (shape: [num_samples, tgt_length, num_features]).
        """
        self.data = data

    def __len__(self):
        """
        Number of samples, read from the leading dimension of ``data['src']``.
        """
        source = self.data["src"]
        return source.shape[0]

    def __getitem__(self, idx):
        """
        Fetch one sample.

        Args:
            idx (int): Index of the requested sample.

        Returns:
            dict: {'src': [src_length, num_features], 'tgt': [tgt_length, num_features]},
            both converted to ``torch.float16``.
        """
        half = torch.float16
        sequences = [self.data[key] for key in ("src", "tgt")]
        sample = {}
        for key, sequence in zip(("src", "tgt"), sequences):
            sample[key] = sequence[idx, :, :].type(half)
        return sample

In [None]:
class STARDataset(Dataset):
    """
    A PyTorch Dataset class for training with preprocessed trajectory and distance data.

    The 'distance' and 'type' tensors are truncated along their last axis to the
    first ``max_num_agents`` entries when an item is retrieved.
    """

    def __init__(self, data, max_num_agents):
        """
        Initialize the dataset.

        Args:
            data (dict): A dictionary containing:
                - 'src': Source trajectory tensor (shape: [num_samples, src_length, num_features]).
                - 'tgt': Target trajectory tensor (shape: [num_samples, tgt_length, num_features]).
                - 'distance': Distances to agents and objects (shape: [num_samples, src_length, num_distances]).
                - 'type': Types of agents and objects (shape: [num_samples, src_length, num_types]).
            max_num_agents (int): Number of leading entries kept along the last axis of
                'distance' and 'type' in every retrieved sample.
        """
        self.data = data
        self.max_num_agents = max_num_agents

    def __len__(self):
        """
        Return the number of samples in the dataset.
        """
        return self.data["src"].shape[0]

    def __getitem__(self, idx):
        """
        Retrieve a single sample from the dataset.

        Args:
            idx (int): The index of the sample to retrieve.

        Returns:
            data (dict): A dictionary containing:
                - 'src': Source trajectory tensor (float32).
                - 'tgt': Target trajectory tensor (float32).
                - 'distance': Distance tensor truncated to ``max_num_agents`` (float32).
                - 'type': Type tensor truncated to ``max_num_agents`` (long).
        """
        # `dist_type` instead of `type` so the builtin is not shadowed.
        src, tgt = self.data['src'], self.data['tgt']
        distance, dist_type = self.data['distance'], self.data['type']
        data = {
            "src": src[idx, :, :].type(torch.float32),
            "tgt": tgt[idx, :, :].type(torch.float32),
            "distance": distance[idx, :, :self.max_num_agents].type(torch.float32),
            "type": dist_type[idx, :, :self.max_num_agents].type(torch.long),
        }
        return data

In [None]:
class SAESTARDataset(Dataset):
    """
    Dataset over pre-processed trajectories plus per-step distance/type features.

    Unlike a truncating variant, every field is returned in full for the sample.
    """

    def __init__(self, data : dict):
        """
        Keep a reference to the pre-processed tensors.

        Args:
            data (dict): Mapping with keys:
                - 'src': source trajectories (shape: [num_samples, src_length, num_features]).
                - 'tgt': target trajectories (shape: [num_samples, tgt_length, num_features]).
                - 'distance': distances to agents/objects (shape: [num_samples, src_length, num_distances]).
                - 'type': types of agents/objects (shape: [num_samples, src_length, num_types]).
        """
        self.data = data

    def __len__(self):
        """
        Number of samples, read from the leading dimension of ``data['src']``.
        """
        source = self.data["src"]
        return source.shape[0]

    def __getitem__(self, idx):
        """
        Fetch one sample.

        Args:
            idx (int): Index of the requested sample.

        Returns:
            dict: 'src', 'tgt' and 'distance' as ``torch.float32``; 'type' as ``torch.long``.
        """
        field_dtypes = (
            ("src", torch.float32),
            ("tgt", torch.float32),
            ("distance", torch.float32),
            ("type", torch.long),
        )
        # Look every field up first so a missing key fails before any slicing happens.
        tensors = [self.data[name] for name, _ in field_dtypes]
        return {
            name: tensor[idx, :, :].type(dtype)
            for (name, dtype), tensor in zip(field_dtypes, tensors)
        }

In [None]:
class TransformerBaseEmbedding(nn.Module):
    """
    A PyTorch module for creating input embeddings for a transformer model.

    This class embeds input sequences using linear projections for specific features and
    adds positional encoding to the embeddings. It supports separate embeddings for
    source (`src`) and target (`tgt`) sequences.

    Only the first two input features are projected (one ``nn.Linear(1, d)`` per feature).
    Any embedding dimensions beyond ``embedding_dims[0] + embedding_dims[1]`` stay zero
    before the positional encoding is added.

    The module is device-agnostic. The positional encodings are registered as
    non-persistent buffers, so they follow the module through ``.to()`` / ``.cuda()`` /
    ``.cpu()`` without being added to the ``state_dict``, and the output is allocated
    on the same device as those buffers.

    Attributes:
        embedding_dims (list[int]): List of embedding dimensions for each feature.
        size (int): Total embedding size (sum of `embedding_dims`).
        linear_layers_tgt (nn.ModuleList): Linear layers for projecting the first two features of `tgt` inputs.
        linear_layers_src (nn.ModuleList): Linear layers for projecting the first two features of `src` inputs.
        positional_encoding_src (torch.Tensor): Buffer of shape (1, sequence_length_src, size).
        positional_encoding_tgt (torch.Tensor): Buffer of shape (1, sequence_length_tgt, size).
    """

    def __init__(self, embedding_dims, sequence_length_src=10, sequence_length_tgt=40):
        """
        Initializes the embedding class with specified dimensions and sequence lengths.

        Args:
            embedding_dims (list[int]): Embedding dimensions for the features.
            sequence_length_src (int, optional): Maximum length of the source sequence. Defaults to 10.
            sequence_length_tgt (int, optional): Maximum length of the target sequence. Defaults to 40.
        """
        super(TransformerBaseEmbedding, self).__init__()
        self.embedding_dims = embedding_dims
        self.size = sum(embedding_dims)

        # Registration order (tgt before src) is kept so parameter initialization and
        # state_dict key order match the original implementation.
        self.linear_layers_tgt = nn.ModuleList([
            nn.Linear(1, embedding_dim)
            for embedding_dim in embedding_dims[:2]
        ])
        self.linear_layers_src = nn.ModuleList([
            nn.Linear(1, embedding_dim)
            for embedding_dim in embedding_dims[:2]
        ])

        # Non-persistent buffers move with the module but are excluded from the
        # state_dict, so checkpoints saved before this change still load.
        self.register_buffer(
            "positional_encoding_src",
            self._generate_positional_encoding(self.size, sequence_length_src),
            persistent=False,
        )
        self.register_buffer(
            "positional_encoding_tgt",
            self._generate_positional_encoding(self.size, sequence_length_tgt),
            persistent=False,
        )

    def _generate_positional_encoding(self, embedding_size, sequence_length):
        """
        Generates sinusoidal positional encoding for a sequence of a given length and embedding size.

        Args:
            embedding_size (int): Dimensionality of the embedding space (may be odd).
            sequence_length (int): Length of the sequence.

        Returns:
            torch.Tensor: A CPU tensor with shape (1, sequence_length, embedding_size).
        """
        positional_encoding = torch.zeros(sequence_length, embedding_size)
        position = torch.arange(0, sequence_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embedding_size, 2).float() * (-torch.log(torch.tensor(10000.0)) / embedding_size))

        # Sine on even columns, cosine on odd columns. When embedding_size is odd
        # there is one fewer odd column, so only the matching prefix of div_term is used.
        positional_encoding[:, 0::2] = torch.sin(position * div_term)
        positional_encoding[:, 1::2] = torch.cos(position * div_term[: embedding_size // 2])

        return positional_encoding.unsqueeze(0)  # Add batch dimension

    def forward(self, input_tensor, is_src=True):
        """
        Forward pass for embedding input sequences with optional source or target embeddings.

        Args:
            input_tensor (torch.Tensor): Input tensor of shape (batch_size, seq_length, num_features).
                ``seq_length`` may be at most the configured sequence length for the chosen
                branch. A shorter sequence uses the matching prefix of the positional encoding.
            is_src (bool, optional): If `True`, processes as source (`src`) sequence; otherwise as target (`tgt`).
                                     Defaults to `True`.

        Returns:
            torch.Tensor: Embedded tensor of shape (batch_size, seq_length, total_embedding_size).
        """
        batch_size, seq_length, _ = input_tensor.size()

        if is_src:
            linear_layers, positional_encoding = self.linear_layers_src, self.positional_encoding_src
        else:
            linear_layers, positional_encoding = self.linear_layers_tgt, self.positional_encoding_tgt

        # Allocate on the module's device (tracked via the buffer); unprojected dims stay zero.
        embedded_tensor = torch.zeros(batch_size, seq_length, self.size, device=positional_encoding.device)

        start_index = 0
        for i, linear_layer in enumerate(linear_layers):
            end_index = start_index + self.embedding_dims[i]
            # Linear acts on the last axis: (B, S, 1) -> (B, S, embedding_dims[i]).
            embedded_tensor[:, :, start_index:end_index] = linear_layer(input_tensor[:, :, i:i + 1])
            start_index = end_index

        # A seq_length longer than the buffer still fails on broadcasting, as before.
        return embedded_tensor + positional_encoding[:, :seq_length]

In [None]:
class TransformerEmbedding(nn.Module):
    """
    A PyTorch module for embedding input sequences with source (`src`) and target (`tgt`) features
    using linear projections for numerical features and embedding layers for categorical features.
    Additionally, positional encodings are added to the embeddings.

    Source inputs: the first ``len(linear_layers_src)`` features (up to 3) are projected
    linearly, and the feature immediately after them is treated as a categorical index
    and embedded. Target inputs: the first ``len(linear_layers_tgt)`` features (up to 2)
    are projected linearly, and any remaining output dimensions stay zero.

    The module is device-agnostic. The positional encodings are non-persistent buffers
    that follow the module through ``.to()`` / ``.cuda()`` and are not written to the
    ``state_dict``.

    Attributes:
        size (int): Total size of embeddings (sum of src_dims).
        src_dims (list[int]): List of embedding dimensions for the source sequence.
        tgt_dims (list[int]): List of embedding dimensions for the target sequence.
        src_len (int): Length of the source sequence.
        tgt_len (int): Length of the target sequence.
        linear_layers_src (nn.ModuleList): List of linear layers for the source sequence features.
        linear_layers_tgt (nn.ModuleList): List of linear layers for the target sequence features.
        embedding_layer_src (nn.Embedding): Embedding layer for the categorical feature of the source sequence.
        positional_encoding_src (torch.Tensor): Buffer of shape (1, src_len, size).
        positional_encoding_tgt (torch.Tensor): Buffer of shape (1, tgt_len, size).
    """

    def __init__(self, src_dims, tgt_dims, num_agents, sequence_length_src=10, sequence_length_tgt=40):
        """
        Initializes the embedding class with specified dimensions and sequence lengths for source
        and target sequences.

        Args:
            src_dims (list[int]): List of embedding dimensions for the features in the source sequence.
            tgt_dims (list[int]): List of embedding dimensions for the features in the target sequence.
            num_agents (int): Number of distinct categories for the categorical source feature.
            sequence_length_src (int, optional): Maximum length of the source sequence. Defaults to 10.
            sequence_length_tgt (int, optional): Maximum length of the target sequence. Defaults to 40.
        """
        super(TransformerEmbedding, self).__init__()

        self.size = sum(src_dims)  # Total embedding size (sum of src_dims)
        self.src_dims = src_dims
        self.tgt_dims = tgt_dims
        self.src_len = sequence_length_src
        self.tgt_len = sequence_length_tgt

        # Linear layers for projecting the first 3 features of the source sequence
        self.linear_layers_src = nn.ModuleList([
            nn.Linear(1, embedding_dim)
            for embedding_dim in src_dims[:3]
        ])

        # Linear layers for projecting the first 2 features of the target sequence
        self.linear_layers_tgt = nn.ModuleList([
            nn.Linear(1, embedding_dim)
            for embedding_dim in tgt_dims[:2]
        ])

        # Embedding layer for the categorical source feature (placement follows the module's device).
        self.embedding_layer_src = nn.Embedding(num_agents, src_dims[-1])

        # Non-persistent buffers: move with the module, excluded from the state_dict.
        self.register_buffer(
            "positional_encoding_src",
            self._generate_positional_encoding(self.size, sequence_length_src),
            persistent=False,
        )
        self.register_buffer(
            "positional_encoding_tgt",
            self._generate_positional_encoding(self.size, sequence_length_tgt),
            persistent=False,
        )

    def _generate_positional_encoding(self, embedding_size, sequence_length):
        """
        Generates a sinusoidal positional encoding for a given sequence length and embedding size.

        Args:
            embedding_size (int): Dimensionality of the embedding space (may be odd).
            sequence_length (int): Length of the sequence for which to generate the encoding.

        Returns:
            torch.Tensor: CPU tensor with shape (1, sequence_length, embedding_size).
        """
        positional_encoding = torch.zeros(sequence_length, embedding_size)
        position = torch.arange(0, sequence_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embedding_size, 2).float() * (-torch.log(torch.tensor(10000.0)) / embedding_size))

        # Even columns get sine, odd columns get cosine. With an odd embedding_size there is
        # one fewer odd column, so only the matching prefix of div_term is used.
        positional_encoding[:, 0::2] = torch.sin(position * div_term)
        positional_encoding[:, 1::2] = torch.cos(position * div_term[: embedding_size // 2])

        return positional_encoding.unsqueeze(0)  # Add batch dimension and return

    def forward(self, input_tensor, is_src=True):
        """
        Forward pass for embedding input sequences, with separate handling for source and target sequences.

        Args:
            input_tensor (torch.Tensor): Input tensor of shape (batch_size, seq_length, num_features).
                ``seq_length`` may not exceed the configured length for the chosen branch. A shorter
                sequence uses the matching prefix of the positional encoding.
            is_src (bool, optional): If `True`, processes the source sequence; otherwise processes the target sequence. Defaults to `True`.

        Returns:
            torch.Tensor: The final embedded tensor of shape (batch_size, seq_length, total_embedding_size).
        """
        batch_size, seq_length, _ = input_tensor.size()

        if is_src:
            linear_layers, embedding_dims = self.linear_layers_src, self.src_dims
            positional_encoding = self.positional_encoding_src
        else:
            linear_layers, embedding_dims = self.linear_layers_tgt, self.tgt_dims
            positional_encoding = self.positional_encoding_tgt

        # Allocate on the module's device (tracked via the buffer).
        embedded_tensor = torch.zeros(batch_size, seq_length, self.size, device=positional_encoding.device)

        start_index = 0
        for i, linear_layer in enumerate(linear_layers):
            end_index = start_index + embedding_dims[i]
            # Linear acts on the last axis: (B, S, 1) -> (B, S, embedding_dims[i]).
            embedded_tensor[:, :, start_index:end_index] = linear_layer(input_tensor[:, :, i:i + 1])
            start_index = end_index

        if is_src:
            # The categorical feature directly follows the linearly projected ones.
            categorical_index = len(linear_layers)
            embedded_tensor[:, :, start_index:] = self.embedding_layer_src(input_tensor[:, :, categorical_index].long())

        # A seq_length longer than the buffer still fails on broadcasting, as before.
        return embedded_tensor + positional_encoding[:, :seq_length]

In [None]:
class STAREmbedding(nn.Module):
    """
    A PyTorch module for embedding input sequences with source (`src`) and distance-related features.

    This class embeds source and distance sequences using linear projections for numerical features,
    embedding layers for categorical features, and applies positional encoding to the resulting embeddings.

    Source inputs: the first ``len(src_dims) - 1`` features are projected linearly, and the
    feature right after them is a categorical index embedded by ``embedding_layer_src``.
    Distance inputs: for each distance feature ``i``, the embedding of ``type_tensor[..., i]``
    is written first, followed by the linear projection of ``input_tensor[..., i]``.

    The module is device-agnostic. The positional encodings are non-persistent buffers that
    follow the module through ``.to()`` / ``.cuda()`` and are not written to the ``state_dict``.

    Attributes:
        num_features (int): Number of features in the source (`src`) sequence.
        distance_len (int): Number of distance-related features.
        type_len (int): Number of categorical features in the distance sequence.
        src_dims (list[int]): Embedding dimensions for the source sequence.
        dist_dims (list[int]): Embedding dimensions for the distance-related sequence.
        type_dims (list[int]): Embedding dimensions for categorical features in the distance sequence.
        size (int): Total embedding size (sum of source feature dimensions).
        linear_layers_src (nn.ModuleList): Linear layers for embedding source sequence features.
        linear_layers_distance (nn.ModuleList): Linear layers for embedding distance-related features.
        type_layers_distance (nn.ModuleList): Embedding layers for categorical features in the distance sequence.
        embedding_layer_src (nn.Embedding): Embedding layer for the categorical feature of the source sequence.
        positional_encoding_src (torch.Tensor): Buffer of shape (1, sequence_length, size).
        positional_encoding_distance (torch.Tensor): Buffer of shape (1, sequence_length, size).
    """

    def __init__(self, src_dims, dist_dims, type_dims, num_types, sequence_length=10):
        """
        Initializes the embedding class with specified dimensions and sequence length.

        Args:
            src_dims (list[int]): List of embedding dimensions for the features in the source sequence.
            dist_dims (list[int]): List of embedding dimensions for the features in the distance sequence.
            type_dims (list[int]): List of embedding dimensions for categorical features.
            num_types (int): Number of distinct categories for the categorical embeddings.
            sequence_length (int, optional): Maximum length of the source and distance sequences. Defaults to 10.
        """
        super(STAREmbedding, self).__init__()

        # Initialize attributes for feature counts and dimensions
        self.num_features = len(src_dims)
        self.distance_len = len(dist_dims)
        self.type_len = len(type_dims)
        self.size = sum(src_dims)

        self.src_dims = src_dims
        self.dist_dims = dist_dims
        self.type_dims = type_dims

        # Linear layers for all source features except the last (categorical) one.
        self.linear_layers_src = nn.ModuleList([
            nn.Linear(1, embedding_dim)
            for embedding_dim in src_dims[:self.num_features - 1]
        ])

        # Linear layers for projecting distance-related features
        self.linear_layers_distance = nn.ModuleList([
            nn.Linear(1, embedding_dim)
            for embedding_dim in dist_dims
        ])

        # Embedding layers for categorical features in the distance sequence
        self.type_layers_distance = nn.ModuleList([
            nn.Embedding(num_types, embedding_dim)
            for embedding_dim in type_dims
        ])

        # Embedding layer for the last categorical feature of the source sequence
        self.embedding_layer_src = nn.Embedding(num_types, src_dims[-1])

        # Both encodings are identical. They are kept as separate non-persistent buffers
        # so the existing attribute names stay available.
        positional_encoding = self._generate_positional_encoding(self.size, sequence_length)
        self.register_buffer("positional_encoding_src", positional_encoding, persistent=False)
        self.register_buffer("positional_encoding_distance", positional_encoding.clone(), persistent=False)

    def _generate_positional_encoding(self, embedding_size, sequence_length):
        """
        Generates a sinusoidal positional encoding for a given sequence length and embedding size.

        Args:
            embedding_size (int): Dimensionality of the embedding space (may be odd).
            sequence_length (int): Length of the sequence for which to generate the encoding.

        Returns:
            torch.Tensor: CPU tensor with shape (1, sequence_length, embedding_size).
        """
        positional_encoding = torch.zeros(sequence_length, embedding_size)
        position = torch.arange(0, sequence_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embedding_size, 2).float() * (-torch.log(torch.tensor(10000.0)) / embedding_size))

        # Even columns get sine, odd columns get cosine. With an odd embedding_size there is
        # one fewer odd column, so only the matching prefix of div_term is used.
        positional_encoding[:, 0::2] = torch.sin(position * div_term)
        positional_encoding[:, 1::2] = torch.cos(position * div_term[: embedding_size // 2])

        return positional_encoding.unsqueeze(0)  # Add batch dimension

    def forward(self, input_tensor, is_src=True, src_tensor=None, type_tensor=None):
        """
        Forward pass for embedding input sequences, with separate handling for source and distance sequences.

        Args:
            input_tensor (torch.Tensor): Input tensor of shape (batch_size, seq_length, num_features).
                ``seq_length`` may not exceed the configured ``sequence_length``. A shorter sequence
                uses the matching prefix of the positional encoding.
            is_src (bool, optional): If `True`, processes the source sequence; otherwise processes the distance sequence. Defaults to `True`.
            src_tensor (torch.Tensor, optional): Unused; accepted for backward compatibility. Defaults to `None`.
            type_tensor (torch.Tensor, optional): Categorical indices for the distance sequence, indexed as
                ``type_tensor[:, :, i]``. Required when ``is_src`` is `False`. Defaults to `None`.

        Returns:
            torch.Tensor: The final embedded tensor of shape (batch_size, seq_length, total_embedding_size).

        Raises:
            TypeError: If ``is_src`` is `False` and ``type_tensor`` is `None`.
        """
        batch_size, seq_length, _ = input_tensor.size()

        if is_src:
            positional_encoding = self.positional_encoding_src
        else:
            if type_tensor is None:
                raise TypeError("STAREmbedding.forward: type_tensor is required when is_src=False")
            positional_encoding = self.positional_encoding_distance

        # Allocate on the module's device (tracked via the buffer).
        embedded_tensor = torch.zeros(batch_size, seq_length, self.size, device=positional_encoding.device)
        start_index = 0

        if is_src:
            for i, linear_layer in enumerate(self.linear_layers_src):
                end_index = start_index + self.src_dims[i]
                # Linear acts on the last axis: (B, S, 1) -> (B, S, src_dims[i]).
                embedded_tensor[:, :, start_index:end_index] = linear_layer(input_tensor[:, :, i:i + 1])
                start_index = end_index

            # The categorical feature directly follows the linearly projected ones
            # (index len(src_dims) - 1, rather than a hard-coded 3).
            categorical_index = len(self.linear_layers_src)
            embedded_tensor[:, :, start_index:] = self.embedding_layer_src(input_tensor[:, :, categorical_index].long())
        else:
            pairs = zip(self.linear_layers_distance, self.type_layers_distance)
            for i, (linear_layer, type_layer) in enumerate(pairs):
                # Type embedding first ...
                type_output = type_layer(type_tensor[:, :, i].long())
                end_index = start_index + type_output.shape[-1]
                embedded_tensor[:, :, start_index:end_index] = type_output
                start_index = end_index

                # ... then the linear projection of the matching distance feature.
                end_index = start_index + self.dist_dims[i]
                embedded_tensor[:, :, start_index:end_index] = linear_layer(input_tensor[:, :, i:i + 1])
                start_index = end_index

        # A seq_length longer than the buffer still fails on broadcasting, as before.
        return embedded_tensor + positional_encoding[:, :seq_length]

In [None]:
class SAESTAREmbedding(nn.Module):
    """
    A PyTorch module for embedding input sequences with source and distance-related features.
    
    This class embeds the source (`src`) and distance-related (`dist`) sequences using 
    separate linear projections for the numerical features, embedding layers for categorical 
    features, and adds positional encoding for both types of inputs.

    Attributes:
        num_features (int): Number of features in the source (`src`) input.
        dist_len (int): Number of distance-related features.
        type_len (int): Number of types in the categorical input.
        size (int): Total embedding size (sum of source feature dimensions).
        src_dims (list[int]): Embedding dimensions for the source sequence.
        dist_dims (list[int]): Embedding dimensions for the distance sequence.
        linear_layers_src (nn.ModuleList): Linear layers for embedding source features.
        linear_layers_distance (nn.ModuleList): Linear layers for embedding distance features.
        type_layers_distance (nn.ModuleList): Embedding layers for categorical features in the distance sequence.
        embedding_layer_src (nn.Embedding): Embedding layer for the last categorical feature of source sequence.
        positional_encoding_src (torch.Tensor): Positional encoding for the source sequence.
        positional_encoding_distance (torch.Tensor): Positional encoding for the distance sequence.
    """
    
    def __init__(self, src_dims, dist_dims, type_dims, num_types, num_types_dist, sequence_length_src=10):
        """
        Initializes the embedding class with specified dimensions and sequence lengths.
        
        Args:
            src_dims (list[int]): List of embedding dimensions for the features in the source sequence.
            dist_dims (list[int]): List of embedding dimensions for the features in the distance sequence.
            type_dims (list[int]): List of embedding dimensions for categorical features.
            num_types (int): Number of distinct categories for the last feature of the source sequence.
            num_types_dist (int): Number of distinct categories for the categorical features in the distance sequence.
            sequence_length_src (int, optional): Length of the source sequence. Defaults to 10.
        """
        super(SAESTAREmbedding, self).__init__()
        self.num_features = len(src_dims)
        self.dist_len = len(dist_dims)
        self.type_len = len(type_dims)
        self.size = sum(src_dims)
        self.src_dims = src_dims
        self.dist_dims = dist_dims

        # Linear layers for projecting the first features of the source sequence
        self.linear_layers_src = nn.ModuleList([
            nn.Linear(1, embedding_dim)
            for embedding_dim in src_dims[:self.num_features - 1]  # Exclude the last feature for different treatment
        ])

        # Linear layers for projecting distance-related features
        self.linear_layers_distance = nn.ModuleList([
            nn.Linear(1, embedding_dim)
            for embedding_dim in dist_dims[:]
        ])
        
        # Embedding layers for categorical features in the distance sequence
        self.type_layers_distance = nn.ModuleList([
            nn.Embedding(num_types_dist + num_types , embedding_dim).cuda()
            for embedding_dim in type_dims[:]
        ])

        # Embedding layer for the last categorical feature of the source sequence
        self.embedding_layer_src = nn.Embedding(num_types, src_dims[-1]).cuda()

        # Generate positional encoding for source and distance sequences
        self.positional_encoding_src = self._generate_positional_encoding(self.size, sequence_length_src).cuda()
        self.positional_encoding_distance = self._generate_positional_encoding(self.size, sequence_length_src).cuda()

    def _generate_positional_encoding(self, embedding_size, sequence_length):
        """
        Generates a sinusoidal positional encoding for a given sequence length and embedding size.

        Args:
            embedding_size (int): The dimensionality of the embedding space.
            sequence_length (int): The length of the sequence for which to generate the encoding.

        Returns:
            torch.Tensor: A tensor containing the positional encoding for the sequence.
        """
        positional_encoding = torch.zeros(sequence_length, embedding_size).cuda()
        position = torch.arange(0, sequence_length, dtype=torch.float).unsqueeze(1).cuda()
        div_term = torch.exp(torch.arange(0, embedding_size, 2).float().cuda() * (-torch.log(torch.tensor(10000.0)) / embedding_size)).cuda()
        
        # Apply sine and cosine functions to alternate dimensions of the encoding
        positional_encoding[:, 0::2] = torch.sin(position * div_term).cuda()
        positional_encoding[:, 1::2] = torch.cos(position * div_term).cuda()

        return positional_encoding.unsqueeze(0).cuda()  # Add batch dimension

    def forward(self, input_tensor, is_src=True, src_tensor=None, type_tensor=None):
        """
        Forward pass for embedding input sequences, with separate handling for source and distance sequences.

        Args:
            input_tensor (torch.Tensor): Input tensor of shape (batch_size, seq_length, num_features).
            is_src (bool, optional): If `True`, processes the source sequence; otherwise processes the distance sequence. Defaults to `True`.
            src_tensor (torch.Tensor, optional): Tensor containing source sequence features. Defaults to `None`.
            type_tensor (torch.Tensor, optional): Tensor containing categorical features for the distance sequence. Defaults to `None`.

        Returns:
            torch.Tensor: The final embedded tensor of shape (batch_size, seq_length, total_embedding_size).
        """
        batch_size, seq_length, num_features = input_tensor.size()

        # Select the embedding dimensions based on whether the sequence is source or distance
        embedding_dims = self.src_dims if is_src else self.dist_dims

        # Initialize an empty tensor for the embedded output
        embedded_tensor = torch.zeros(batch_size, seq_length, self.size).cuda()

        # Handle source sequence embedding
        if is_src:
            start_index = 0
            for i, linear_layer in enumerate(self.linear_layers_src):
                # Linearly project each feature of the source sequence
                linear_output = linear_layer(input_tensor[:, :, i].view(-1, 1)).cuda()

                # Get the embedding dimension for the current feature
                embedding_dim = embedding_dims[i]
                end_index = start_index + embedding_dim

                # Place the linearly projected feature at the correct position in the embedded tensor
                embedded_tensor[:, :, start_index:end_index] = linear_output.view(batch_size, seq_length, embedding_dim).cuda()
                start_index = end_index

            # Embed the last feature of the source sequence (categorical)
            embedded_tensor[:, :, start_index:] = self.embedding_layer_src(input_tensor[:, :, 3].long()).cuda()

        # Handle distance sequence embedding
        else:
            embedded_tensor = torch.zeros(batch_size, seq_length, self.size).cuda()
            start_index = 0
            for i, (linear_layer, embedding_layer) in enumerate(zip(self.linear_layers_distance, self.type_layers_distance)):
                # Linearly project each feature of the distance sequence
                linear_output = linear_layer(input_tensor[:, :, i].view(-1, 1)).cuda()

                # Get the embedding dimension for the current feature
                embedding_dim = embedding_dims[i]
                
                # Embed the categorical features in the distance sequence
                embedding_output = embedding_layer(type_tensor[:, :, i].long()).cuda()
                embedded_tensor[:, :, start_index:start_index + embedding_output.shape[-1]] = embedding_output
                start_index = start_index + embedding_output.shape[-1]

                # Place the linearly projected distance feature into the tensor
                end_index = start_index + embedding_dim
                embedded_tensor[:, :, start_index:end_index] = linear_output.view(batch_size, seq_length, embedding_dim).cuda()
                start_index = end_index

        # Add positional encoding to the embeddings
        positional_encoding = self.positional_encoding_src if is_src else self.positional_encoding_distance
        embedded_tensor = embedded_tensor + positional_encoding

        # Reshape the output tensor and return the final result
        embedded_tensor = embedded_tensor.view(batch_size, seq_length, -1).cuda()
        return embedded_tensor.cuda()

In [None]:
class MLPDecoder(nn.Module):
    """
    Two-layer perceptron head mapping feature vectors to output values.

    The layers act on the last dimension of the input, so any leading dimensions
    (e.g. `(batch, input_size)` or `(batch, steps, input_size)`) are kept. The pipeline is:

        dropout -> ReLU -> Linear(input_size, hidden_size)
                -> dropout -> ReLU -> Linear(hidden_size, output_size)

    Note that the raw input itself goes through dropout and ReLU before the first
    projection. Every intermediate result is moved to the GPU with `.cuda()`.

    Args:
        input_size (int): Width of the incoming features.
        hidden_size (int): Width of the hidden layer.
        output_size (int): Width of the produced output.
        dropout (float): Drop probability shared by both dropout layers.
    """

    def __init__(self, input_size, hidden_size, output_size, dropout):
        """
        Create the two projections and the regularisation/activation modules.

        Args:
            input_size (int): Width of the incoming features.
            hidden_size (int): Width of the hidden layer.
            output_size (int): Width of the produced output.
            dropout (float): Drop probability for both dropout layers.
        """
        super().__init__()

        # Projections. The creation order fixes parameter initialisation and state_dict key order.
        self.fc1 = nn.Linear(input_size, hidden_size).cuda()
        self.fc2 = nn.Linear(hidden_size, output_size).cuda()

        # Regularisation + activation applied to the input, before fc1.
        self.dropout = nn.Dropout(p=dropout)
        self.relu = nn.ReLU()

        # Regularisation + activation applied to the hidden features, before fc2.
        self.dropout1 = nn.Dropout(p=dropout)
        self.relu1 = nn.ReLU()

    def forward(self, x):
        """
        Run `x` through the decoder pipeline.

        Args:
            x (torch.Tensor): Features whose last dimension is `input_size`.

        Returns:
            torch.Tensor: Same leading dimensions as `x`, last dimension `output_size`.
        """
        stages = (self.dropout, self.relu, self.fc1, self.dropout1, self.relu1, self.fc2)
        for stage in stages:
            # Each stage's result is placed on the GPU before the next stage runs.
            x = stage(x).cuda()
        return x

In [None]:
class TransformerBase(nn.Module):
    """
    Encoder-decoder Transformer for sequence-to-sequence prediction.

    The model has four stages:
    - Source sequences (and target sequences, during teacher-forced training) are embedded by
      `TransformerBaseEmbedding`.
    - The source is encoded by an `nn.TransformerEncoder` (batch_first).
    - Decoding uses an `nn.TransformerDecoder` (batch_first).
    - An `MLPDecoder` maps each decoded step to `len(embedding_dims)` outputs.

    Args:
        embedding_dims (list[int]): Embedding width per input feature. Their sum is the
            transformer model width and their count is the output width.
        src_len (int): Source sequence length (forwarded to the embedding layer).
        tgt_len (int): Target sequence length, i.e. the number of decoded steps.
        hidden (int): Hidden layer size of the MLP decoder.
        num_layers (int): Number of encoder layers and of decoder layers.
        num_heads (int): Attention heads per layer. PyTorch requires it to divide
            `sum(embedding_dims)`.
        dropout (float): Dropout applied to the source embedding and inside the MLP
            decoder (default 0.1). NOTE(review): it is not forwarded to the
            encoder/decoder layers, which keep PyTorch's default dropout.

    Methods:
        forward(src, tgt=None, src_mask=None, tgt_mask=None, training=True):
            - When `training` is True, decodes teacher-forced from `tgt`, which is then required.
            - Otherwise, decodes in a single pass from all-zero queries.
    """
    def __init__(self, embedding_dims, src_len, tgt_len, hidden, num_layers, num_heads, dropout=0.1):
        """
        Initialize the TransformerBase model.

        Args:
            embedding_dims (list[int]): Embedding dimensions for each input feature.
            src_len (int): Source sequence length.
            tgt_len (int): Target sequence length.
            hidden (int): Hidden size for the MLP decoder.
            num_layers (int): Number of transformer encoder/decoder layers.
            num_heads (int): Number of attention heads in each transformer layer.
            dropout (float): Dropout rate for the embedding and the MLP decoder.
        """
        super(TransformerBase, self).__init__()
        self.src_len = src_len
        self.tgt_len = tgt_len
        self.embedding_size = sum(embedding_dims)  # Transformer model width.
        self.hidden_size = hidden
        self.output_size = len(embedding_dims)  # One output value per input feature.

        # Embedding layer for source and target sequences.
        self.embedding_layer = TransformerBaseEmbedding(
            embedding_dims, sequence_length_src=src_len, sequence_length_tgt=tgt_len
        ).cuda()

        # Transformer encoder and decoder layers (cloned num_layers times below).
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=self.embedding_size, nhead=num_heads, batch_first=True).cuda()
        self.decoder_layer = nn.TransformerDecoderLayer(d_model=self.embedding_size, nhead=num_heads, batch_first=True).cuda()

        # Full encoder and decoder stacks.
        self.encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers).cuda()
        self.decoder = nn.TransformerDecoder(self.decoder_layer, num_layers=num_layers).cuda()

        # NOTE(review): `linear_output` is never used in forward(). It is kept so that
        # parameter registration and existing checkpoints stay compatible.
        self.linear_output = nn.Linear(self.embedding_size, self.output_size).cuda()

        # MLP head mapping decoder states to outputs.
        self.mlpdecoder = MLPDecoder(self.embedding_size, self.hidden_size, self.output_size, dropout=dropout).cuda()

        # Regularisation applied to the source embedding.
        self.dropout = nn.Dropout(p=dropout).cuda()
        self.relu = nn.ReLU().cuda()

    def forward(self, src, tgt=None, src_mask=None, tgt_mask=None, training=True):
        """
        Forward pass through the TransformerBase model.

        Args:
            src (torch.Tensor): Raw source features passed to the embedding layer.
                Presumably of shape (batch_size, src_len, len(embedding_dims)); confirm
                against `TransformerBaseEmbedding`.
            tgt (torch.Tensor, optional): Raw target features for teacher forcing.
                Required when `training` is True.
            src_mask (torch.Tensor, optional): Attention mask for the encoder.
            tgt_mask (torch.Tensor, optional): Self-attention mask for the decoder.
            training (bool, optional): Selects teacher-forced decoding. This is independent of
                `self.training` / `model.eval()`; pass `training=False` explicitly for inference.

        Returns:
            torch.Tensor: Tensor of shape (batch_size, tgt_len, output_size).

        Raises:
            ValueError: If `training` is True but `tgt` is None.
        """
        if training and tgt is None:
            raise ValueError("`tgt` is required when training=True; pass training=False for inference.")

        # Embed the source sequence, then regularise it.
        src_embedded = self.embedding_layer(src, is_src=True).cuda()
        src_embedded = self.dropout(src_embedded).cuda()
        src_embedded = self.relu(src_embedded).cuda()

        # Encode the source.
        memory = self.encoder(src_embedded, mask=src_mask).cuda()

        if training:
            # Teacher forcing: decode from the embedded target sequence.
            tgt_embedded = self.embedding_layer(tgt, is_src=False).cuda()
            output = self.decoder(tgt_embedded, memory, tgt_mask=tgt_mask).cuda()
        else:
            # Inference: one all-zero query per target step for every batch element,
            # decoded in a single pass.
            start_token = torch.zeros(src.shape[0], self.tgt_len, src_embedded.shape[2]).cuda()
            output = self.decoder(start_token, memory, tgt_mask=tgt_mask).cuda()

        # Map decoder states to the output features.
        output = self.mlpdecoder(output).cuda()

        return output.view(src.shape[0], self.tgt_len, self.output_size).cuda()


In [None]:
class Transformer(nn.Module):
    """
    Encoder-decoder Transformer for spatial-temporal sequence-to-sequence prediction.

    The model has four stages:
    - Source sequences (and target sequences, during teacher-forced training) are embedded by
      `TransformerEmbedding`.
    - The source is encoded by an `nn.TransformerEncoder` (batch_first).
    - Decoding uses an `nn.TransformerDecoder` (batch_first).
    - An `MLPDecoder` maps each decoded step to `len(tgt_dim)` outputs.

    Args:
        src_dim (list): Embedding widths of the source features. Their sum is the
            transformer model width.
        tgt_dim (list): Embedding widths of the target features. Their count is the
            output width.
        num_agents (int): Number of agents (forwarded to the embedding layer).
        num_layers (int): Number of encoder layers and of decoder layers.
        num_heads (int): Attention heads per layer. PyTorch requires it to divide `sum(src_dim)`.
        hidden (int): Hidden layer size of the MLP decoder.
        src_len (int): Length of the source sequence (default=10).
        tgt_len (int): Length of the target sequence, i.e. the number of decoded steps (default=40).
        dropout (float): Dropout applied to the source embedding and inside the MLP
            decoder (default=0.1). NOTE(review): it is not forwarded to the
            encoder/decoder layers, which keep PyTorch's default dropout.

    Methods:
        forward(src, tgt=None, src_mask=None, tgt_mask=None, training=True):
            - When `training` is True, decodes teacher-forced from `tgt`, which is then required.
            - Otherwise, decodes in a single pass from all-zero queries.
    """
    def __init__(self, src_dim, tgt_dim, num_agents, num_layers, num_heads, hidden, src_len=10, tgt_len=40, dropout=0.1):
        """
        Initialize the Transformer model.

        Args:
            src_dim (list): Source input feature dimensions.
            tgt_dim (list): Target output feature dimensions.
            num_agents (int): Number of agents in the dataset.
            num_layers (int): Number of encoder and decoder layers.
            num_heads (int): Number of attention heads in transformer layers.
            hidden (int): Hidden size for the MLP decoder.
            src_len (int): Length of the source sequence.
            tgt_len (int): Length of the target sequence.
            dropout (float): Dropout probability for the embedding and the MLP decoder.
        """
        super(Transformer, self).__init__()

        # Key parameters for the architecture.
        self.src_len = src_len  # Source sequence length
        self.tgt_len = tgt_len  # Target sequence length
        self.size = sum(src_dim)  # Transformer model width
        self.hidden = hidden  # Hidden size for the MLP decoder
        self.output_size = len(tgt_dim)  # One output value per target feature

        # Embedding layer for source and target data.
        self.embedding_layer = TransformerEmbedding(
            src_dims=src_dim,
            tgt_dims=tgt_dim,
            num_agents=num_agents,
            sequence_length_src=src_len,
            sequence_length_tgt=tgt_len
        ).cuda()

        # Transformer encoder and decoder layers (cloned num_layers times below).
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=self.size, nhead=num_heads, batch_first=True).cuda()
        self.decoder_layer = nn.TransformerDecoderLayer(d_model=self.size, nhead=num_heads, batch_first=True).cuda()

        # Full encoder and decoder stacks.
        self.encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers).cuda()
        self.decoder = nn.TransformerDecoder(self.decoder_layer, num_layers=num_layers).cuda()

        # MLP head mapping decoder states to outputs.
        self.mlpdecoder = MLPDecoder(self.size, self.hidden, self.output_size, dropout=dropout).cuda()

        # Regularisation applied to the source embedding.
        self.dropout = nn.Dropout(p=dropout).cuda()
        self.relu = nn.ReLU().cuda()

    def forward(self, src, tgt=None, src_mask=None, tgt_mask=None, training=True):
        """
        Forward pass of the Transformer model.

        Args:
            src (torch.Tensor): Raw source features passed to the embedding layer.
                Presumably of shape (batch_size, src_len, num_source_features); confirm
                against `TransformerEmbedding`.
            tgt (torch.Tensor, optional): Raw target features for teacher forcing.
                Required when `training` is True.
            src_mask (torch.Tensor, optional): Attention mask for the encoder (default=None).
            tgt_mask (torch.Tensor, optional): Self-attention mask for the decoder (default=None).
            training (bool): Selects teacher-forced decoding (default=True). This is independent of
                `self.training` / `model.eval()`; pass `training=False` explicitly for inference.

        Returns:
            torch.Tensor: Tensor of shape (batch_size, tgt_len, output_size).

        Raises:
            ValueError: If `training` is True but `tgt` is None.
        """
        if training and tgt is None:
            raise ValueError("`tgt` is required when training=True; pass training=False for inference.")

        # Embed the source sequence, then regularise it.
        src_embedded = self.embedding_layer(src, is_src=True).cuda()
        src_embedded = self.dropout(src_embedded).cuda()
        src_embedded = self.relu(src_embedded).cuda()

        # Encode the source.
        memory = self.encoder(src_embedded, mask=src_mask).cuda()

        if training:
            # Teacher forcing: decode from the embedded target sequence.
            tgt_embedded = self.embedding_layer(tgt, is_src=False).cuda()
            output = self.decoder(tgt_embedded, memory, tgt_mask=tgt_mask).cuda()
        else:
            # Inference: one all-zero query per target step for every batch element,
            # decoded in a single pass.
            start_token = torch.zeros(src.shape[0], self.tgt_len, src_embedded.shape[2]).cuda()
            output = self.decoder(start_token, memory, tgt_mask=tgt_mask).cuda()

        # Map decoder states to the output features.
        output = self.mlpdecoder(output).cuda()

        return output.view(src.shape[0], self.tgt_len, self.output_size).cuda()

In [None]:
class STAR(torch.nn.Module):
    """
    STAR: spatial-temporal attention model for trajectory-style sequence prediction.

    Overview:
    - Two embedding branches (temporal from `src`, spatial from `distance` + `type`) are
      each encoded by their own Transformer encoder.
    - The two results are concatenated and fused back to the model width by a linear layer.
    - The fused features are re-encoded by a second spatial encoder and then a second
      temporal encoder.
    - Four `MLPDecoder` heads read the result.
    - Each head emits `output_size` (= 2) values per input step. The four heads'
      outputs are concatenated per step and reshaped to (batch, tgt_len, 2).
    - For that reshape to succeed, `tgt_len` must equal 4 * (input sequence length).

    Args:
        embedding_dims (np.array): Embedding widths; their sum is the model width.
        dist_dims (np.array): Embedding widths for the distance features.
        type_dims (np.array): Embedding widths for the type features.
        num_types (int): Number of distinct type ids.
        hidden (int): Hidden size of each MLP decoder (default=256).
        num_layers (int): Layers per Transformer encoder (default=16).
        num_heads (int): Attention heads per encoder layer (default=8).
        src_len (int): Source sequence length (default=10).
        tgt_len (int): Target sequence length (default=40).
        dropout (float): Dropout probability for embeddings and decoders (default=0.1).

    Methods:
        forward(src, distance, type, src_mask=None, dist_key_padding_mask=None):
            Embed, encode, fuse and decode; returns (batch, tgt_len, 2).
    """
    def __init__(
        self, 
        embedding_dims: np.array, 
        dist_dims: np.array, 
        type_dims: np.array, 
        num_types: int, 
        hidden: int = 256, 
        num_layers: int = 16, 
        num_heads: int = 8, 
        src_len: int = 10, 
        tgt_len: int = 40, 
        dropout: float = 0.1
    ):
        """
        Build the embedding, encoder stacks, fusion layer and decoder heads.

        Args:
            embedding_dims (np.array): Embedding widths; summed to give the model width.
            dist_dims (np.array): Embedding widths for the distance features.
            type_dims (np.array): Embedding widths for the type features.
            num_types (int): Number of distinct type ids.
            hidden (int): Hidden size of each decoder head.
            num_layers (int): Layers in every Transformer encoder.
            num_heads (int): Attention heads per encoder layer.
            src_len (int): Source sequence length.
            tgt_len (int): Target sequence length.
            dropout (float): Dropout probability.
        """
        super().__init__()

        width = sum(embedding_dims)
        self.embedding_size = width  # Model width after concatenating feature embeddings.
        self.output_size = 2  # Values emitted per step by each decoder head.
        self.dropout_prob = dropout
        self.src_len = src_len
        self.tgt_len = tgt_len
        self.hidden_size = hidden

        # Shared embedding module for the temporal (src) and spatial (distance/type) inputs.
        self.embedding_layer = STAREmbedding(embedding_dims, dist_dims, type_dims, num_types, src_len).cuda()

        # Prototype layers: each TransformerEncoder below deep-copies its prototype,
        # so the four stacks do not share weights.
        self.temporal_encoder_layer = nn.TransformerEncoderLayer(d_model=width, nhead=num_heads, batch_first=True).cuda()
        self.spatial_encoder_layer = nn.TransformerEncoderLayer(d_model=width, nhead=num_heads, batch_first=True).cuda()

        encoder_specs = (
            ("spatial_encoder_1", self.spatial_encoder_layer),
            ("spatial_encoder_2", self.spatial_encoder_layer),
            ("temporal_encoder_1", self.temporal_encoder_layer),
            ("temporal_encoder_2", self.temporal_encoder_layer),
        )
        for attr_name, prototype in encoder_specs:
            setattr(self, attr_name, nn.TransformerEncoder(prototype, num_layers).cuda())

        # Four independent decoder heads: decoder1 .. decoder4.
        for head_index in range(1, 5):
            setattr(
                self,
                f"decoder{head_index}",
                MLPDecoder(width, self.hidden_size, self.output_size, dropout=dropout).cuda(),
            )

        # Maps the concatenated [temporal, spatial] features back to the model width.
        self.fusion_layer = nn.Linear(width * 2, width).cuda()

        # Regularisation for the temporal (dropout/relu) and spatial (dropout1/relu1) embeddings.
        self.dropout = nn.Dropout(p=dropout)
        self.relu = nn.ReLU()
        self.dropout1 = nn.Dropout(p=dropout)
        self.relu1 = nn.ReLU()

    def forward(self, src, distance, type, src_mask=None, dist_key_padding_mask=None):
        """
        Run the STAR model.

        Args:
            src (torch.Tensor): Source features for the temporal branch.
            distance (torch.Tensor): Distance features for the spatial branch.
            type (torch.Tensor): Type ids accompanying `distance`.
            src_mask (torch.Tensor, optional): Attention mask used by both first-stage encoders.
            dist_key_padding_mask (torch.Tensor, optional): Accepted but currently not used.

        Returns:
            torch.Tensor: Predictions of shape (batch_size, tgt_len, 2).
        """
        # Temporal branch: embed the source, then dropout + ReLU.
        temporal_in = self.embedding_layer(src, is_src=True).cuda()
        temporal_in = self.relu(self.dropout(temporal_in).cuda()).cuda()

        # Spatial branch: embed distance/type, then dropout + ReLU.
        spatial_in = self.embedding_layer(distance, is_src=False, src_tensor=src, type_tensor=type).cuda()
        spatial_in = self.relu1(self.dropout1(spatial_in).cuda()).cuda()

        # First-stage encoders (spatial first, then temporal).
        spatial_feat = self.spatial_encoder_1(spatial_in, mask=src_mask).cuda()
        temporal_feat = self.temporal_encoder_1(temporal_in, mask=src_mask).cuda()

        # Fuse [temporal, spatial] along the feature axis.
        joined = torch.cat((temporal_feat, spatial_feat), dim=2).cuda()
        fused = self.fusion_layer(joined).cuda()

        # Second-stage encoders, applied in sequence.
        encoded = self.spatial_encoder_2(fused).cuda()
        encoded = self.temporal_encoder_2(encoded).cuda()

        # Flatten (batch, steps, width) -> (batch * steps, width) for the decoder heads.
        batch, steps, width = encoded.shape
        flat = encoded.reshape(batch * steps, width)

        heads = (self.decoder1, self.decoder2, self.decoder3, self.decoder4)
        per_head = [head(flat).cuda() for head in heads]
        stacked = torch.cat(per_head, dim=1)

        return stacked.reshape(src.shape[0], self.tgt_len, self.output_size).cuda()


In [None]:
class SAESTAR(torch.nn.Module):
    """
    SAESTAR: spatial-temporal transformer for sequence (trajectory) prediction.

    Two parallel transformer-encoder streams sit on top of a shared
    ``SAESTAREmbedding``:

    * a temporal stream fed with the embedded source sequence, and
    * a spatial stream fed with the embedded distance data (conditioned on src/type).

    Their first-stage outputs are concatenated on the feature axis and projected
    back to the model width by ``fusion_layer``. The result then passes through a
    second spatial encoder and a second temporal encoder. Finally, four independent
    ``MLPDecoder`` heads are applied; their outputs are concatenated and reshaped
    to ``(batch, tgt_len, 2)``.

    Args:
        src_dims (list[int]): per-feature source embedding widths; their sum is the model width.
        dist_dims: distance embedding widths (forwarded to SAESTAREmbedding).
        type_dims: type embedding widths (forwarded to SAESTAREmbedding).
        num_types (int): forwarded to SAESTAREmbedding as ``num_types``.
        num_types_dist (int): forwarded to SAESTAREmbedding as ``num_types_dist``.
        hidden (int): hidden width of each MLPDecoder (default 256).
        num_layers (int): layers in each of the four TransformerEncoder stacks (default 16).
        num_heads (int): attention heads per encoder layer (default 8).
        src_len (int): source sequence length (default 10).
        tgt_len (int): target sequence length used to shape the output (default 40).
        dropout (float): probability for the two embedding Dropout layers and the
            MLP decoders (default 0.1). NOTE(review): the TransformerEncoderLayers
            are constructed without it, so they use PyTorch's own default dropout.
    """

    def __init__(self, src_dims, dist_dims, type_dims, num_types, num_types_dist, hidden=256, num_layers=16, num_heads=8, src_len=10, tgt_len=40, dropout=0.1):
        """Build every sub-module on the GPU (see the class docstring for the arguments)."""
        super().__init__()

        # Plain hyper-parameters.
        self.embedding_size = sum(src_dims)   # model width d_model
        self.output_size = 2                  # values predicted per target step
        self.dropout_prob = dropout
        self.src_len = src_len
        self.tgt_len = tgt_len
        self.hidden_size = hidden
        width = self.embedding_size

        # Shared embedding for both the source and the distance inputs.
        self.embedding_layer = SAESTAREmbedding(src_dims, dist_dims, type_dims, num_types, num_types_dist, src_len).cuda()

        # Template layers; nn.TransformerEncoder clones them, so the four stacks do not share weights.
        self.temporal_encoder_layer = nn.TransformerEncoderLayer(d_model=width, nhead=num_heads, batch_first=True).cuda()
        self.spatial_encoder_layer = nn.TransformerEncoderLayer(d_model=width, nhead=num_heads, batch_first=True).cuda()

        self.spatial_encoder_1 = nn.TransformerEncoder(self.spatial_encoder_layer, num_layers).cuda()
        self.spatial_encoder_2 = nn.TransformerEncoder(self.spatial_encoder_layer, num_layers).cuda()
        self.temporal_encoder_1 = nn.TransformerEncoder(self.temporal_encoder_layer, num_layers).cuda()
        self.temporal_encoder_2 = nn.TransformerEncoder(self.temporal_encoder_layer, num_layers).cuda()

        # Four independent output heads: decoder1 .. decoder4 (registered in that order).
        for head_no in range(1, 5):
            setattr(
                self,
                f"decoder{head_no}",
                MLPDecoder(width, self.hidden_size, self.output_size, dropout=dropout).cuda(),
            )

        # Projects the concatenated [temporal | spatial] features back to the model width.
        self.fusion_layer = nn.Linear(width * 2, width).cuda()

        # Embedding regularisation: one Dropout/ReLU pair per stream.
        self.dropout = nn.Dropout(p=dropout)
        self.relu = nn.ReLU()
        self.dropout1 = nn.Dropout(p=dropout)
        self.relu1 = nn.ReLU()

    def forward(self, src, distance, type, src_mask=None, dist_key_padding_mask=None):
        """
        Run the spatial-temporal encoder/decoder.

        Args:
            src (torch.Tensor): source sequence; dim 0 is the batch
                (presumably ``(batch, src_len, features)`` -- TODO confirm).
            distance (torch.Tensor): distance input, embedded by ``SAESTAREmbedding``.
            type (torch.Tensor): type input, forwarded to the embedding with ``distance``.
            src_mask (torch.Tensor, optional): attention mask applied to both
                first-stage encoders; the second-stage encoders run unmasked.
            dist_key_padding_mask (torch.Tensor, optional): accepted for interface
                compatibility but not used by this method.

        Returns:
            torch.Tensor: predictions reshaped to ``(batch, tgt_len, output_size)``.
        """
        batch_size = src.shape[0]

        # Temporal stream: embed the source, then dropout -> ReLU.
        temporal = self.embedding_layer(src, is_src=True).cuda()
        temporal = self.relu(self.dropout(temporal))

        # Spatial stream: embed distances (conditioned on src/type), then dropout -> ReLU.
        spatial = self.embedding_layer(distance, is_src=False, src_tensor=src, type_tensor=type).cuda()
        spatial = self.relu1(self.dropout1(spatial))

        # First encoder stage, both streams under the same mask.
        spatial = self.spatial_encoder_1(spatial, mask=src_mask)
        temporal = self.temporal_encoder_1(temporal, mask=src_mask)

        # Fuse on the feature axis and project back to the model width.
        fused = self.fusion_layer(torch.cat((temporal, spatial), dim=2))

        # Second stage runs sequentially: spatial encoder first, then temporal encoder.
        encoded = self.temporal_encoder_2(self.spatial_encoder_2(fused))
        flat = encoded.reshape(encoded.shape[0] * encoded.shape[1], encoded.shape[2])

        # Apply the four heads in order and stitch their outputs side by side.
        heads = (self.decoder1, self.decoder2, self.decoder3, self.decoder4)
        decoded = torch.cat([head(flat).cuda() for head in heads], dim=1)
        return decoded.reshape(batch_size, self.tgt_len, self.output_size).cuda()

In [None]:
class Scaler:
    """
    Min-Max scale / unscale helper for the src, tgt and distance data.

    One scikit-learn ``MinMaxScaler`` is fitted per data kind on the training split:

    * ``'src'``  -- the first 3 feature columns of ``train_data['src']``;
    * ``'tgt'``  -- the first 2 feature columns of ``train_data['tgt']``;
    * ``'dist'`` (alias ``'distance'``) -- ``train_data['distance']``. It is fitted
      only when ``spatial=True``. For ``model_name == "STAR"`` only the first
      ``size`` feature columns are used.

    Every leading axis is flattened, so each scaler sees a 2-D
    ``(rows, features)`` matrix. When fitting and scaling, values are first
    clipped to ``[-1e6, 1e6]`` so that extreme outliers do not distort the range.
    """

    # Accepted ``scaler_type`` spellings -> attribute prefix of the fitted scaler.
    _SCALER_KEYS = {"src": "src", "tgt": "tgt", "dist": "dist", "distance": "dist"}

    def __init__(self, train_data : dict, model_name : str=None, spatial=False, size : int=None):
        """
        Fit the per-kind scalers on ``train_data``.

        Parameters:
        train_data (dict): training data under 'src', 'tgt' and, when ``spatial`` is
            True, 'distance'. Each value is indexed as a 3-D ``[:, :, features]`` array/tensor.
        model_name (str): when "STAR", only the first ``size`` distance features are fitted.
        spatial (bool): also fit a distance scaler (default: False).
        size (int): number of distance feature columns kept for STAR (None keeps all).
        """
        # Whether a distance scaler was fitted; reflects the ``spatial`` argument.
        self.spatial = spatial
        self.src_scaler = self.create_scaler(train_data['src'][:, :, :3])
        self.tgt_scaler = self.create_scaler(train_data['tgt'][:, :, :2])
        # Stays None unless spatial scaling is requested; _get_scaler reports that clearly.
        self.dist_scaler = None
        if spatial:
            distance = train_data['distance']
            if model_name == "STAR":
                distance = distance[:, :, :size]
            self.dist_scaler = self.create_scaler(distance)

    @staticmethod
    def _to_2d_numpy(data):
        """Return ``data`` (tensor or array-like) as a NumPy array of shape ``(-1, n_features)``."""
        if torch.is_tensor(data):
            # detach(): tensors that require grad cannot be converted directly.
            # cpu(): GPU tensors must be copied to the host first.
            array = data.detach().cpu().numpy()
        else:
            array = np.asarray(data)
        # reshape (not view) also works for non-contiguous inputs and any rank >= 2.
        return array.reshape(-1, array.shape[-1])

    def _get_scaler(self, scaler_type: str):
        """
        Return the fitted scaler for ``scaler_type``.

        Raises:
            ValueError: for an unknown ``scaler_type``, or when the distance scaler
                is requested but was never fitted.
        """
        key = self._SCALER_KEYS.get(scaler_type)
        if key is None:
            raise ValueError(
                f"Unknown scaler_type {scaler_type!r}; expected one of {sorted(self._SCALER_KEYS)}."
            )
        scaler = getattr(self, f"{key}_scaler", None)
        if scaler is None:
            raise ValueError(
                f"No {key!r} scaler has been fitted; create the Scaler with spatial=True to scale distances."
            )
        return scaler

    def create_scaler(self, data):
        """
        Create and fit a MinMaxScaler on ``data``.

        Parameters:
        data (numpy.ndarray | torch.Tensor): values whose last axis holds the features;
            all leading axes are flattened into rows.

        Returns:
        MinMaxScaler: the fitted scaler.
        """
        clipped_data = np.clip(self._to_2d_numpy(data), -1e6, 1e6)
        return MinMaxScaler().fit(clipped_data)

    def scale(self, data, scaler_type: str, device="cuda"):
        """
        Scale ``data`` with the scaler selected by ``scaler_type``.

        Parameters:
        data (torch.Tensor | array-like): values to scale; last axis = features.
        scaler_type (str): 'src', 'tgt', 'dist' or 'distance'.
        device: device of the returned tensor (default "cuda"); None keeps it on the CPU.

        Returns:
        torch.Tensor: scaled values with the same shape as ``data``.

        Raises:
        ValueError: for an unknown or unfitted ``scaler_type``.
        """
        scaler = self._get_scaler(scaler_type)
        original_shape = data.shape
        clipped = np.clip(self._to_2d_numpy(data), -1e6, 1e6)
        scaled = torch.tensor(scaler.transform(clipped))
        if device is not None:
            scaled = scaled.to(device)
        return scaled.reshape(original_shape)

    def unscale(self, data, scaler_type: str, original_shape, device="cuda"):
        """
        Invert ``scale`` for the scaler selected by ``scaler_type``.

        Parameters:
        data (torch.Tensor | array-like): scaled values; last axis = features.
        scaler_type (str): 'src', 'tgt', 'dist' or 'distance'.
        original_shape (tuple): shape of the returned tensor.
        device: device of the returned tensor (default "cuda"); None keeps it on the CPU.

        Returns:
        torch.Tensor: values in the original units, reshaped to ``original_shape``.

        Raises:
        ValueError: for an unknown or unfitted ``scaler_type``.
        """
        scaler = self._get_scaler(scaler_type)
        restored = torch.tensor(scaler.inverse_transform(self._to_2d_numpy(data)))
        if device is not None:
            restored = restored.to(device)
        return restored.reshape(original_shape)
        

In [None]:
class Experiment:
    """
    A class to handle the training and evaluation of various models for time series prediction.

    Args:
        scaler (Scaler): A Scaler object to normalize input data.
        model_name (str): The model architecture to use. Choices: 'Base', 'Transformer', 'STAR', 'SAESTAR'.
        src_len (int): The length of the source sequence.
        tgt_len (int): The length of the target sequence.
        num_types (int, optional): Number of types (used in SAESTAR model). Defaults to None.
        graph_dims (int, optional): Dimension of the graph (used in STAR and SAESTAR models). Defaults to None.
        layers (int, optional): Number of layers in the model. Defaults to 16.
        heads (int, optional): Number of attention heads. Defaults to 8.
        hidden_size (int, optional): The hidden layer size. Defaults to 256.
        dropout (float, optional): The dropout rate for regularization. Defaults to 0.1.
        lr (float, optional): The learning rate for the optimizer. Defaults to 0.000015.

    Attributes:
        scaler (Scaler): The Scaler object for input normalization.
        model_type (str): The type of model to use ('Base', 'Transformer', 'STAR', 'SAESTAR').
        model (nn.Module): The model object (e.g., Transformer, STAR, etc.).
        optimizer (Adam): The optimizer used for training.
        criterion (nn.Module): The loss function (MSELoss).
    """

    def __init__(
        self,
        scaler: Scaler,
        model_name: str, 
        src_len: int, 
        tgt_len: int, 
        num_types: int = None, 
        graph_dims: int = None, 
        layers: int = 16, 
        heads: int = 8, 
        hidden_size: int = 256, 
        dropout: float = 0.1, 
        lr: float = 0.000015,
        epochs : int = 10,
        location_name : str='',
        earlystopping : int=30
    ):
        """
        Initializes the Experiment class, selects the appropriate model, and prepares the optimizer.

        Args are as described in the class docstring.
        """
        self.scaler = scaler
        self.epochs = epochs
        self.model_type = model_name
        self.heads = heads
        self.criterion = nn.MSELoss()
        self._name = location_name
        self.early_stopping_patience = earlystopping
        self.optimizer = None
        self.model = None
        
        # Select the model architecture based on the model_name
        if model_name == "Base":
            self.model = self.base_model(src_len, tgt_len, hidden_size, layers, heads, dropout)
        if model_name == "Transformer":
            self.model = self.transformer_model(src_len, tgt_len, hidden_size, layers, heads, dropout)
        if model_name == "STAR":
            self.model = self.star_model(src_len, tgt_len, graph_dims, hidden_size, layers, heads, dropout)
        if model_name == "SAESTAR":
            self.model = self.saestar_model(src_len, tgt_len, graph_dims, num_types, hidden_size, layers, heads, dropout)
        
        # Initialize the optimizer
        self.optimizer = Adam(self.model.parameters(), lr=lr)

    def create_mask(self, batch_size, sequence_length):
        """
        Creates a look-ahead mask for the sequence to prevent attending to future tokens.
        
        Args:
            batch_size (int): The size of the batch (number of sequences).
            sequence_length (int): The length of the sequence to create the mask for.
        
        Returns:
            torch.Tensor: The look-ahead mask with shape (batch_size, sequence_length, sequence_length).
        """
        
        # Create a look-ahead mask (upper triangular matrix with 1s above the diagonal)
        look_ahead_mask = torch.triu(torch.ones(sequence_length, sequence_length), diagonal=1).bool()
        
        # Expand the mask to match the batch size
        look_ahead_mask = look_ahead_mask.unsqueeze(0).expand(batch_size * self.heads, -1, -1)  # Shape: (batch_size, sequence_length, sequence_length)
        return look_ahead_mask
    
   
    def generate_embedded_dim(self, total, array_length):
        """
        Generates a list of dimensions that sum up to a total value, ensuring the sum of the list equals the total.
        
        Args:
            total (int): The total value to split into an array.
            array_length (int): The length of the array to generate.
        
        Returns:
            list: A list of dimensions.
        """
        initial_value = total // array_length
        remainder = total % array_length
        array = [initial_value] * array_length
        array[-1] += remainder
        
        return torch.tensor(array, dtype=torch.int).cuda()
    
    def base_model(self, src, tgt, hidden, layers, heads, dropout) -> nn.Module:
        """
        Instantiate the ``TransformerBase`` baseline.

        The model is built with two 128-wide embedding slots (presumably one per
        coordinate; TODO confirm against TransformerBase).

        Args:
            src (int): source sequence length.
            tgt (int): target sequence length.
            hidden (int): hidden layer size.
            layers (int): number of encoder/decoder layers.
            heads (int): number of attention heads.
            dropout (float): dropout probability.

        Returns:
            nn.Module: the constructed TransformerBase.
        """
        return TransformerBase(
            [128, 128],
            dropout=dropout,
            num_heads=heads,
            num_layers=layers,
            hidden=hidden,
            tgt_len=tgt,
            src_len=src,
        )
    
    def transformer_model(self, src, tgt, hidden, layers, heads, dropout) -> nn.Module:
        """
        Instantiate the encoder-decoder ``Transformer`` model.

        Source features are embedded with widths ``[94, 94, 64, 4]``, targets with
        ``[128, 128]``, and ``num_agents`` is fixed to 4.

        Args:
            src (int): source sequence length.
            tgt (int): target sequence length.
            hidden (int): hidden layer size.
            layers (int): number of layers.
            heads (int): number of attention heads.
            dropout (float): dropout probability.

        Returns:
            nn.Module: the constructed Transformer.
        """
        src_widths = [94, 94, 64, 4]
        tgt_widths = [128, 128]
        settings = {
            "num_agents": 4,
            "num_layers": layers,
            "num_heads": heads,
            "hidden": hidden,
            "src_len": src,
            "tgt_len": tgt,
            "dropout": dropout,
        }
        return Transformer(src_widths, tgt_widths, **settings)

    def star_model(self, src, tgt, graph_dims, hidden, layers, heads, dropout) -> nn.Module:
        """
        Instantiate the ``STAR`` spatial-temporal model.

        The source width ``sum([94, 94, 64, 4])`` is halved, and that half is split
        into ``graph_dims`` slots twice: once for the distance embedding and once for
        the type embedding.

        Args:
            src (int): source sequence length.
            tgt (int): target sequence length.
            graph_dims (int): number of slots for the distance/type embeddings.
            hidden (int): hidden layer size.
            layers (int): number of layers.
            heads (int): number of attention heads.
            dropout (float): dropout probability.

        Returns:
            nn.Module: the constructed STAR model.
        """
        widths = [94, 94, 64, 4]
        half = sum(widths) / 2
        dist_dims, type_dims = [self.generate_embedded_dim(half, graph_dims) for _ in range(2)]
        # Positional arguments in STAR's constructor order; the 4 is a fixed constant passed through.
        return STAR(widths, dist_dims, type_dims, 4, hidden, layers, heads, src, tgt, dropout)

    def saestar_model(self, src, tgt, graph_dims, num_types, hidden, layers, heads, dropout) -> nn.Module:
        """
        Instantiate the ``SAESTAR`` model.

        Embedding widths are derived exactly as for STAR: half of
        ``sum([94, 94, 64, 4])`` is split into ``graph_dims`` slots, once for
        distances and once for types.

        Args:
            src (int): source sequence length.
            tgt (int): target sequence length.
            graph_dims (int): number of slots for the distance/type embeddings.
            num_types (int): passed to SAESTAR in its ``num_types_dist`` position.
            hidden (int): hidden layer size.
            layers (int): number of layers.
            heads (int): number of attention heads.
            dropout (float): dropout probability.

        Returns:
            nn.Module: the constructed SAESTAR model.
        """
        widths = [94, 94, 64, 4]
        half = sum(widths) / 2
        dist_dims, type_dims = [self.generate_embedded_dim(half, graph_dims) for _ in range(2)]
        # NOTE(review): SAESTAR's 4th/5th positional parameters are
        # (num_types, num_types_dist). Here num_types is fixed to 4, and this method's
        # ``num_types`` argument fills num_types_dist -- confirm this is intended.
        return SAESTAR(
            widths, dist_dims, type_dims,
            4, num_types,
            hidden, layers, heads,
            src, tgt,
            dropout,
        )
    
    def train(self, train_loader: DataLoader, epoch: int) -> float:
        """
        Run one optimisation epoch over ``train_loader``.

        Each batch is moved to the GPU. The first source/target columns are
        normalised with ``self.scaler``, model-specific inputs and look-ahead masks
        are built, and one optimizer step is taken on ``self.criterion`` (MSE)
        between the prediction and the scaled target.

        Args:
            train_loader (DataLoader): yields dict batches with 'src' and 'tgt'
                (plus 'distance' and 'type' for STAR/SAESTAR).
            epoch (int): zero-based epoch index, used only for the progress-bar label.

        Returns:
            float: sum of the per-batch (batch-averaged) losses over the epoch.

        Raises:
            ValueError: if ``self.model_type`` is not a supported architecture.
        """
        supported = ('Base', 'Transformer', 'STAR', 'SAESTAR')
        if self.model_type not in supported:
            raise ValueError(f"Unsupported model_type {self.model_type!r}; expected one of {supported}.")

        self.model.train()
        train_loss = 0.0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{self.epochs}")

        for batch_idx, batch in enumerate(progress_bar):
            self.optimizer.zero_grad()

            if self.model_type == 'Base':
                inputs, targets = batch['src'].cuda(), batch['tgt'].cuda()
                src_mask = self.create_mask(inputs.shape[0], inputs.shape[1])
                tgt_mask = self.create_mask(targets.shape[0], targets.shape[1])
                inputs = self.scaler.scale(inputs[:, :, :3], "src")
                targets = self.scaler.scale(targets[:, :, :2], "tgt")
                # Base consumes only the first two scaled source features.
                outputs = self.model(inputs[:, :, :2], targets, src_mask=src_mask.cuda(), tgt_mask=tgt_mask.cuda())

            elif self.model_type == 'Transformer':
                inputs, targets = batch['src'].cuda(), batch['tgt'].cuda()
                src_mask = self.create_mask(inputs.shape[0], inputs.shape[1])
                tgt_mask = self.create_mask(targets.shape[0], targets.shape[1])
                scaled_inputs = self.scaler.scale(inputs[:, :, :3], "src")
                targets = self.scaler.scale(targets[:, :, :2], "tgt")
                # Scaled columns 0-2 are joined with raw columns 4+; NOTE(review): column 3 is not passed on.
                inputs = torch.cat((scaled_inputs, inputs[:, :, 4:].cuda()), dim=2)
                outputs = self.model(inputs, targets, src_mask=src_mask.cuda(), tgt_mask=tgt_mask.cuda())

            else:  # 'STAR' / 'SAESTAR'
                inputs, targets, distances, distance_types = batch['src'].cuda(), batch['tgt'].cuda(), batch['distance'].cuda(), batch['type'].cuda()
                src_mask = self.create_mask(inputs.shape[0], inputs.shape[1])
                scaled_inputs = self.scaler.scale(inputs[:, :, :3], "src")
                targets = self.scaler.scale(targets[:, :, :2], "tgt")
                distances = self.scaler.scale(distances, "dist")
                inputs = torch.cat((scaled_inputs, inputs[:, :, 4:].cuda()), dim=2)
                outputs = self.model(inputs.type(torch.float32), distances.type(torch.float32), distance_types, src_mask=src_mask.cuda())

            loss = self.criterion(outputs.type(torch.float32), targets.type(torch.float32))
            loss.backward()
            self.optimizer.step()

            train_loss += loss.item()
            # MSELoss already averages within a batch, so show the running mean over batches
            # (not additionally divided by the current batch size).
            progress_bar.set_postfix({'Training Loss': train_loss / (batch_idx + 1)})

        return train_loss

    
    def eval(self, valid_loader: DataLoader) -> float:
        """
        Compute the validation loss without updating the model.

        The model is put in eval mode. Every batch is prepared exactly as in
        ``train`` (scaling via ``self.scaler``, look-ahead masks, model-specific
        inputs). ``self.criterion`` is evaluated per batch, and the per-batch values
        are summed. The sum is then divided by the number of *samples* in
        ``valid_loader.dataset``, not by the number of batches.

        Args:
            valid_loader (DataLoader): yields dict batches with 'src' and 'tgt'
                (plus 'distance' and 'type' for STAR/SAESTAR).

        Returns:
            float: summed per-batch loss divided by the validation-set size.
        """
        self.model.eval()
        total_loss = 0.0
        kind = self.model_type

        with torch.no_grad():
            for batch in valid_loader:
                if kind in ('Base', 'Transformer'):
                    src, tgt = batch['src'].cuda(), batch['tgt'].cuda()
                    src_mask = self.create_mask(src.shape[0], src.shape[1]).cuda()
                    tgt_mask = self.create_mask(tgt.shape[0], tgt.shape[1]).cuda()
                    scaled_src = self.scaler.scale(src[:, :, :3], "src")
                    tgt = self.scaler.scale(tgt[:, :, :2], "tgt")
                    if kind == 'Base':
                        model_src = scaled_src[:, :, :2]
                    else:
                        model_src = torch.cat((scaled_src, src[:, :, 4:].cuda()), dim=2)
                    outputs = self.model(model_src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)

                elif kind in ('STAR', 'SAESTAR'):
                    src, tgt = batch['src'].cuda(), batch['tgt'].cuda()
                    dist, dist_types = batch['distance'].cuda(), batch['type'].cuda()
                    src_mask = self.create_mask(src.shape[0], src.shape[1]).cuda()
                    scaled_src = self.scaler.scale(src[:, :, :3], "src")
                    tgt = self.scaler.scale(tgt[:, :, :2], "tgt")
                    dist = self.scaler.scale(dist, "dist")
                    model_src = torch.cat((scaled_src, src[:, :, 4:].cuda()), dim=2)
                    outputs = self.model(model_src.float(), dist.float(), dist_types, src_mask=src_mask)

                total_loss += self.criterion(outputs.float(), tgt.float()).item()

        return total_loss / len(valid_loader.dataset)

        
    def train_model(self, train_loader, valid_loader) -> None:
        """
        Train for up to ``self.epochs`` epochs with checkpointing, early stopping and
        TensorBoard logging.

        Each epoch:

        * runs ``train``;
        * saves the weights to ``./data/Weights/<model>/<location>/current/checkpoint.pth``;
        * evaluates with ``eval``;
        * logs both losses under ``./data/Results/<model>/<location>/logs/{Train,Valid}``.

        Whenever the validation loss improves, the weights are also saved to
        ``.../best_chkp/checkpoint.pth``. Training stops after
        ``self.early_stopping_patience`` epochs without improvement.

        Args:
            train_loader (DataLoader): training data.
            valid_loader (DataLoader): validation data.
        """
        import os

        self.model.cuda()
        best_valid_loss = float('inf')
        early_stopping_counter = 0

        log_dir = f'./data/Results/{self.model_type}/{self._name}/logs'
        weights_dir = f'./data/Weights/{self.model_type}/{self._name}/'
        current_ckpt = os.path.join(weights_dir, 'current', 'checkpoint.pth')
        best_ckpt = os.path.join(weights_dir, 'best_chkp', 'checkpoint.pth')
        # torch.save does not create missing directories.
        for ckpt_path in (current_ckpt, best_ckpt):
            os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)

        train_writer = SummaryWriter(log_dir= log_dir + '/Train')
        valid_writer = SummaryWriter(log_dir= log_dir + '/Valid')
        try:
            for epoch in range(self.epochs):
                train_loss = self.train(train_loader, epoch)
                torch.save(self.model.state_dict(), current_ckpt)

                valid_loss = self.eval(valid_loader)

                # train() returns a raw sum; eval() already divides by the validation-set
                # size, so only the training loss is normalised here (same scale for both).
                train_loss /= len(train_loader.dataset)
                print(f"Epoch {epoch + 1}/{self.epochs}, Train Loss: {train_loss:.4f}, Valid Loss: {valid_loss:.4f}")

                train_writer.add_scalar('Loss', train_loss, epoch)
                valid_writer.add_scalar('Loss', valid_loss, epoch)

                if valid_loss < best_valid_loss:
                    best_valid_loss = valid_loss
                    torch.save(self.model.state_dict(), best_ckpt)
                    print("Checkpoint saved.")
                    early_stopping_counter = 0
                else:
                    early_stopping_counter += 1

                if early_stopping_counter >= self.early_stopping_patience:
                    print(f"Validation loss hasn't decreased for {self.early_stopping_patience} epochs. Stopping training.")
                    break
        finally:
            # Flush and release the event files even if training raises.
            train_writer.close()
            valid_writer.close()


    def calculate_ADE(self, predicted_trajectories, true_trajectories):
        """
        Calculates the Average Displacement Error (ADE) between predicted and true trajectories.

        Args:
            predicted_trajectories (np.array): The predicted trajectory points.
            true_trajectories (np.array): The true trajectory points.

        Returns:
            float: The Average Displacement Error (ADE).
        """
        errors = np.linalg.norm(predicted_trajectories - true_trajectories, axis=-1)
        ADE = np.mean(errors, axis=-1)
        return ADE


    def calculate_FDE(self, predicted_trajectories, true_trajectories):
        """
        Calculates the Final Displacement Error (FDE) between predicted and true trajectories.

        Args:
            predicted_trajectories (np.array): The predicted trajectory points.
            true_trajectories (np.array): The true trajectory points.

        Returns:
            np.array: The Final Displacement Error (FDE).
        """
        errors = np.linalg.norm(predicted_trajectories[:, -1] - true_trajectories[:, -1], axis=-1)
        return errors


    def calculate_FDE(self, predicted_trajectories, true_trajectories):
        """
        Calculates the Final Displacement Error (FDE) between predicted and true trajectories.
        The FDE measures the Euclidean distance between the predicted and actual positions 
        at the final time step of a trajectory.

        Args:
            predicted_trajectories (np.array): The predicted trajectory points.
            true_trajectories (np.array): The true trajectory points.

        Returns:
            np.array: The Final Displacement Error (FDE) for each prediction.
        """
        # Calculate the Euclidean distance between predicted and true trajectories at the final timestep
        errors = np.linalg.norm(predicted_trajectories[:, -1] - true_trajectories[:, -1], axis=-1)
        return errors


    def validation_metrics(self, predicted_trajectories, true_trajectories):
        """
        Gather the per-trajectory displacement errors together with their smallest values.

        Args:
            predicted_trajectories (np.array): The predicted trajectory points.
            true_trajectories (np.array): The ground-truth trajectory points.

        Returns:
            tuple: ``(ADE, FDE, minADE, minFDE)`` where
                - ADE (np.array): result of ``self.calculate_ADE`` for the inputs.
                - FDE (np.array): result of ``self.calculate_FDE`` for the inputs.
                - minADE: the entry of ADE located by ``np.argmin``.
                - minFDE: the entry of FDE located by ``np.argmin``.
        """
        ade_values = self.calculate_ADE(predicted_trajectories, true_trajectories)
        fde_values = self.calculate_FDE(predicted_trajectories, true_trajectories)

        # Locate each minimum by index first, then read that entry back out
        best_ade = ade_values[np.argmin(ade_values)]
        best_fde = fde_values[np.argmin(fde_values)]

        return ade_values, fde_values, best_ade, best_fde


    def test_model(self, test_loader):
        """
        Evaluates the trained model on the test dataset, measuring prediction accuracy
        with the Average and Final Displacement Errors (ADE / FDE).

        Every batch is processed and stored exactly once. Per-batch metrics are logged
        to TensorBoard (one step per batch). The sources, unscaled predictions and
        targets are written to ``./data/Results/<model_type>/<_name>/`` as
        ``src_pred_tgt_results.csv`` and as ``src.pt``, ``tgt.pt`` and ``pred.pt``.
        The aggregate metrics are written to ``metrics.txt`` in the same directory.

        Args:
            test_loader (DataLoader): The DataLoader for the test dataset.

        Raises:
            ValueError: If ``self.model_type`` is not a supported model type, or if
                ``test_loader`` yields no batches.
        """
        import os

        # Set model to evaluation mode and move it to GPU
        self.model.cuda()
        self.model.eval()

        # Per-batch arrays are collected in lists and concatenated once at the end;
        # np.append inside the loop would copy everything gathered so far on every batch.
        src_batches, tgt_batches, pred_batches = [], [], []
        ade_batches, fde_batches = [], []

        writer = SummaryWriter()
        try:
            with torch.no_grad():
                for batch_idx, batch in enumerate(test_loader):
                    src, tgt, outputs = self._predict_test_batch(batch)

                    # Unscale the model outputs back to the original (x, y) values
                    true_xy = tgt[:, :, :2]
                    predictions = self.scaler.unscale(outputs, "tgt", true_xy.shape).cpu().numpy()

                    # Calculate validation metrics (ADE, FDE) on unscaled values
                    ADE, FDE, minADE, minFDE = self.validation_metrics(predictions, true_xy.cpu().numpy())

                    src_batches.append(src.cpu().numpy())
                    tgt_batches.append(tgt.cpu().numpy())
                    pred_batches.append(predictions)
                    ade_batches.append(ADE)
                    fde_batches.append(FDE)

                    # global_step=batch_idx gives each batch its own point on the TensorBoard curve
                    writer.add_scalar('ADE/mean', ADE.mean(), batch_idx)
                    writer.add_scalar('FDE/mean', FDE.mean(), batch_idx)
                    writer.add_scalar('ADE/Min_mean', minADE.mean(), batch_idx)
                    writer.add_scalar('FDE/Min_mean', minFDE.mean(), batch_idx)
        finally:
            # Release the TensorBoard writer even if evaluation fails midway
            writer.close()

        if not pred_batches:
            raise ValueError("test_loader yielded no batches; nothing to evaluate")

        src_tensors = np.concatenate(src_batches, axis=0)
        tgt_tensor = np.concatenate(tgt_batches, axis=0)
        pred_tensor = np.concatenate(pred_batches, axis=0)
        all_ADE = np.concatenate(ade_batches, axis=0)
        all_FDE = np.concatenate(fde_batches, axis=0)

        results_dir = f"./data/Results/{self.model_type}/{self._name}"
        # Make sure the output directory exists so saving cannot fail after a full evaluation pass
        os.makedirs(results_dir, exist_ok=True)

        # Save the per-step results to CSV
        self._write_test_csv(f"{results_dir}/src_pred_tgt_results.csv", src_tensors, pred_tensor, tgt_tensor)

        # Save tensors for later analysis
        torch.save(torch.tensor(src_tensors), f"{results_dir}/src.pt")
        torch.save(torch.tensor(tgt_tensor), f"{results_dir}/tgt.pt")
        torch.save(torch.tensor(pred_tensor), f"{results_dir}/pred.pt")

        # Overall metrics across every test trajectory
        mean_ADE = np.mean(all_ADE)
        mean_FDE = np.mean(all_FDE)
        min_ADE = all_ADE[np.argmin(all_ADE)]
        min_FDE = all_FDE[np.argmin(all_FDE)]

        # Save the overall metrics to a file
        save_file = f"{results_dir}/metrics.txt"
        with open(save_file, "w") as file:
            file.write(f"Mean ADE: {mean_ADE}\n")
            file.write(f"Mean FDE: {mean_FDE}\n")
            file.write(f"Min ADE: {min_ADE}\n")
            file.write(f"Min FDE: {min_FDE}\n")

        print("Validation results saved to", save_file)

    def _predict_test_batch(self, batch):
        """
        Run the model on one test batch.

        Args:
            batch (dict): A batch from the test DataLoader ('src', 'tgt', and for
                STAR/SAESTAR also 'distance' and 'type').

        Returns:
            tuple: ``(src, tgt, outputs)`` — the raw GPU source/target tensors and the
                model outputs (still in the scaler's scaled space).

        Raises:
            ValueError: If ``self.model_type`` is not a supported model type.
        """
        model_type = self.model_type
        if model_type == 'Base':
            src, tgt = batch['src'].cuda(), batch['tgt'].cuda()
            src_mask = self.create_mask(src.shape[0], src.shape[1])
            tgt_mask = self.create_mask(tgt.shape[0], tgt.shape[1])
            inputs = self.scaler.scale(src[:, :, :3], "src")
            targets = self.scaler.scale(tgt[:, :, :2], "tgt")
            outputs = self.model(inputs[:, :, :2], targets, src_mask=src_mask.cuda(), tgt_mask=tgt_mask.cuda())

        elif model_type == 'Transformer':
            src, tgt = batch['src'].cuda(), batch['tgt'].cuda()
            src_mask = self.create_mask(src.shape[0], src.shape[1])
            tgt_mask = self.create_mask(tgt.shape[0], tgt.shape[1])
            scaled_inputs = self.scaler.scale(src[:, :, :3], "src")
            targets = self.scaler.scale(tgt[:, :, :2], "tgt")
            # NOTE(review): feature index 3 is neither scaled nor appended — confirm this is intended
            inputs = torch.cat((scaled_inputs, src[:, :, 4:]), dim=2)
            outputs = self.model(inputs, targets, src_mask=src_mask.cuda(), tgt_mask=tgt_mask.cuda())

        elif model_type in ('STAR', 'SAESTAR'):
            src, tgt = batch['src'].cuda(), batch['tgt'].cuda()
            distances, distance_types = batch['distance'].cuda(), batch['type'].cuda()
            src_mask = self.create_mask(src.shape[0], src.shape[1])
            scaled_inputs = self.scaler.scale(src[:, :, :3], "src")
            distances = self.scaler.scale(distances, "dist")
            # NOTE(review): feature index 3 is neither scaled nor appended — confirm this is intended
            inputs = torch.cat((scaled_inputs, src[:, :, 4:]), dim=2)
            outputs = self.model(inputs.type(torch.float32), distances.type(torch.float32), distance_types, src_mask=src_mask.cuda())

        else:
            raise ValueError(f"Unsupported model_type: {model_type!r}")

        return src, tgt, outputs

    @staticmethod
    def _write_test_csv(file_path, src_tensors, pred_tensor, tgt_tensor):
        """
        Write one CSV row per (sample, time step) holding the source, prediction and target values.

        Source and prediction/target sequences may differ in length (e.g. a 10-step
        input against a 40-step horizon), so rows run to the longest of the three and
        a sequence that has already ended contributes an empty cell.

        Args:
            file_path (str): Destination CSV path (overwritten).
            src_tensors, pred_tensor, tgt_tensor (np.array): Arrays indexed as
                [sample, step, values] with the same number of samples.
        """
        num_steps = max(src_tensors.shape[1], pred_tensor.shape[1], tgt_tensor.shape[1])
        with open(file_path, mode='w', newline='') as file:
            csv_writer = csv.writer(file)
            csv_writer.writerow(["Source", "Prediction", "Target"])
            for i in range(src_tensors.shape[0]):
                for j in range(num_steps):
                    csv_writer.writerow([
                        src_tensors[i, j, :] if j < src_tensors.shape[1] else "",
                        pred_tensor[i, j, :] if j < pred_tensor.shape[1] else "",
                        tgt_tensor[i, j, :] if j < tgt_tensor.shape[1] else "",
                    ])

In [None]:
class ExperimentManager:
    """
    Manages the lifecycle of experiments, including data preparation, model training, and evaluation.

    Every ``experiment_*`` method loads the Train/Val/Test splits stored under
    ``./data/Dataset/<location_name>/``, wraps them in the Dataset class for that model,
    builds a Scaler and DataLoaders, and hands everything to ``Experiment`` for
    training followed by testing.
    """

    def __init__(
        self,
        location_name: str,
        epochs: int,
        learning_rate: float,
        num_layers: int,
        num_heads: int,
        dropout: float,
        src_len: int,
        tgt_len: int,
        batch_size: int,
        hidden_size: int,
        earlystopping: int,
        num_agents: "int | None" = None,
    ):
        """
        Store the experiment configuration.

        Parameters:
        - location_name: Name of the location; selects ``./data/Dataset/<location_name>/``
          and ``./data/CombinedData/<location_name>/``.
        - epochs: Number of training epochs.
        - learning_rate: Learning rate for optimization.
        - num_layers: Number of layers in the model.
        - num_heads: Number of attention heads in the model.
        - dropout: Dropout rate for regularization.
        - src_len: Source sequence length for the model.
        - tgt_len: Target sequence length for the model.
        - batch_size: Batch size for all DataLoaders.
        - hidden_size: Hidden layer size in the model.
        - earlystopping: Early-stopping setting forwarded to ``Experiment``
          (presumably a patience in epochs — confirm in ``Experiment``).
        - num_agents: Number of distance columns used by the STAR experiment (passed to
          ``STARDataset`` as ``max_num_agents``). None (default) means the full width of
          the distance tensor.
        """
        self.location_name = location_name
        self.epochs = epochs
        self.lr = learning_rate
        self.layers = num_layers
        self.heads = num_heads
        self.dropout = dropout
        self.batch_size = batch_size
        self.hidden_size = hidden_size
        self.earlystopping = earlystopping
        self.src_len = src_len
        self.tgt_len = tgt_len
        # Read by experiment_star; previously never set, which made that experiment crash
        self.num_agents = num_agents

    def load_tensors(self, model_name: str, data_type: str, spatial: bool = False):
        """
        Loads tensors from disk for a given model and data type, optionally including spatial data.

        Args:
            model_name (str): The name of the model (used for logging).
            data_type (str): The split directory to load (e.g. 'Train', 'Val', 'Test').
            spatial (bool): Whether to also load 'dist.pt' and 'dist_type.pt'.

        Returns:
            dict: Loaded tensors under 'src' and 'tgt', plus 'distance' and 'type' when ``spatial`` is True.
        """
        print(f"Loading {data_type} for model {model_name} at Location: {self.location_name}")
        split_dir = f"./data/Dataset/{self.location_name}/{data_type}"

        # Loading mandatory tensors
        src = torch.load(f"{split_dir}/src.pt", weights_only=True)
        tgt = torch.load(f"{split_dir}/tgt.pt", weights_only=True)
        if not spatial:
            return {'src': src, 'tgt': tgt}

        # Loading additional spatial data if requested
        dist = torch.load(f"{split_dir}/dist.pt", weights_only=True)
        d_type = torch.load(f"{split_dir}/dist_type.pt", weights_only=True)
        return {'src': src, 'tgt': tgt, 'distance': dist, 'type': d_type}

    @staticmethod
    def _announce(model_name: str):
        """Print the banner that opens every experiment."""
        print(f"{model_name} Experiment Starting")
        print("-" * 30)

    def _load_splits(self, model_name: str, spatial: bool):
        """Load the Train, Val and Test splits (in that order) via ``load_tensors``."""
        return [
            self.load_tensors(model_name=model_name, data_type=split, spatial=spatial)
            for split in ("Train", "Val", "Test")
        ]

    def _resolve_num_agents(self, distance):
        """
        Return the configured ``num_agents``, or the size of axis 2 of ``distance`` when unset.

        Assumes axis 2 of the distance tensor indexes agents, as the ``[0, 0, :num_agents]``
        slice in ``experiment_star`` implies — confirm against the dataset layout.
        """
        if self.num_agents is not None:
            return self.num_agents
        return distance.shape[2]

    def _execute(self, model_name, scaler, datasets, num_types=None, graph_dims=None):
        """Wrap the (train, val, test) datasets in DataLoaders and run the experiment on them."""
        train_dataloader, val_dataloader, test_dataloader = self.get_data_loaders(*datasets)
        self.run_experiment(
            model_name, scaler, train_dataloader, val_dataloader, test_dataloader,
            num_types=num_types, graph_dims=graph_dims,
        )
        del train_dataloader, val_dataloader, test_dataloader

    def _experiment_without_spatial(self, model_name: str):
        """Shared pipeline of the Base and Transformer experiments (no spatial data)."""
        self._announce(model_name)
        train, val, test = self._load_splits(model_name, spatial=False)
        datasets = [TransformerDataset(split) for split in (train, val, test)]
        scaler = Scaler(train, False)
        self._execute(model_name, scaler, datasets, num_types=None, graph_dims=None)

        # Cleaning up memory
        del scaler, train, val, test, datasets
        gc.collect()

    def experiment_base(self):
        """
        Runs the base experiment, handling data preparation and experiment execution.
        """
        self._experiment_without_spatial('Base')

    def experiment_transformer(self):
        """
        Runs the Transformer experiment, handling data preparation and experiment execution.
        """
        self._experiment_without_spatial('Transformer')

    def experiment_star(self):
        """
        Runs the STAR experiment, handling data preparation and experiment execution.
        """
        model_name = 'STAR'
        self._announce(model_name)
        train, val, test = self._load_splits(model_name, spatial=True)
        num_agents = self._resolve_num_agents(train['distance'])
        graph_dims = len(train['distance'][0, 0, :num_agents])
        datasets = [STARDataset(split, num_agents) for split in (train, val, test)]
        # NOTE(review): Scaler gets (data, model_name, True, graph_dims) here but (data, False)
        # in the non-spatial experiments — confirm both calls match Scaler's signature.
        scaler = Scaler(train, model_name, True, graph_dims)
        self._execute(model_name, scaler, datasets, num_types=None, graph_dims=graph_dims)

        # Cleaning up memory
        del scaler, train, val, test, datasets
        gc.collect()

    def experiment_saestar(self):
        """
        Runs the SAESTAR experiment, handling data preparation and experiment execution.
        """
        model_name = 'SAESTAR'
        self._announce(model_name)
        train, val, test = self._load_splits(model_name, spatial=True)
        df_env = pd.read_csv(f"./data/CombinedData/{self.location_name}/env_df.csv", sep=',')
        # 4 fixed type ids plus one per unique environment object ID
        # (NOTE(review): what the 4 fixed types are is not visible here)
        num_types = 4 + len(df_env['ID'].unique())
        graph_dims = len(train['distance'][0, 0, :])
        datasets = [SAESTARDataset(split) for split in (train, val, test)]
        scaler = Scaler(train, model_name, True)
        self._execute(model_name, scaler, datasets, num_types=num_types, graph_dims=graph_dims)

        # Cleaning up memory
        del scaler, train, val, test, datasets
        gc.collect()

    def run_experiment(self, model_name, scaler, train, val, test, num_types, graph_dims):
        """
        Executes the training and testing of the specified model.

        Args:
            model_name (str): The name of the model being tested.
            scaler (Scaler): An object to scale and normalize the data.
            train, val, test (DataLoader): DataLoaders for the respective datasets.
            num_types (int): The number of distance types (SAESTAR only; otherwise None).
            graph_dims (int): The dimensionality of graph data (spatial models only; otherwise None).
        """
        print(f"{model_name} Experiment Initiating")

        # Initializing the experiment with the provided configuration
        experiment = Experiment(
            scaler=scaler,
            model_name=model_name,
            src_len=self.src_len,
            tgt_len=self.tgt_len,
            num_types=num_types,
            graph_dims=graph_dims,
            layers=self.layers,
            heads=self.heads,
            hidden_size=self.hidden_size,
            dropout=self.dropout,
            lr=self.lr,
            epochs=self.epochs,
            location_name=self.location_name,
            earlystopping=self.earlystopping
        )

        # Training and testing the model
        experiment.train_model(train, val)
        experiment.test_model(test)

        # Cleaning up memory
        del experiment
        gc.collect()

    def get_data_loaders(self, train: Dataset, val: Dataset, test: Dataset):
        """
        Creates DataLoader objects for training, validation, and testing datasets.

        Only the training loader is shuffled; validation and test batches keep a fixed
        order so evaluation is deterministic.

        Args:
            train, val, test (Dataset): Datasets for training, validation, and testing.

        Returns:
            tuple[DataLoader]: DataLoader objects for training, validation, and testing.
        """
        train_dataloader = DataLoader(train, batch_size=self.batch_size, shuffle=True)
        val_dataloader = DataLoader(val, batch_size=self.batch_size, shuffle=False)
        test_dataloader = DataLoader(test, batch_size=self.batch_size, shuffle=False)

        return train_dataloader, val_dataloader, test_dataloader

In [None]:
# Check if a GPU is available; the experiment code calls .cuda() unconditionally
gpu = torch.cuda.is_available()
print("GPU Available: ", gpu)

# Use deterministic cuDNN convolution algorithms
torch.backends.cudnn.deterministic = True
# Disable cuDNN benchmark mode (its autotuner may choose different kernels between runs)
torch.backends.cudnn.benchmark = False

# Seed the random number generators for reproducible runs
# (torch.manual_seed seeds the CPU and all CUDA devices)
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Hyperparameters shared by every experiment
epochs = 1
num_layers = 16
num_heads = 8
dropout = 0.3
learning_rate = 0.000015
src_len = 10   # source (input) sequence length
tgt_len = 40   # target (prediction) sequence length
batch_size = 32
hidden_size = 512
earlystopping = 30

experiment_manager = ExperimentManager(
    location_name="Torpagatan",
    epochs=epochs,
    learning_rate=learning_rate,
    num_layers=num_layers,
    num_heads=num_heads,
    dropout=dropout,
    src_len=src_len,
    tgt_len=tgt_len,
    batch_size=batch_size,
    hidden_size=hidden_size,
    earlystopping=earlystopping,
)

experiment_manager.experiment_base()
experiment_manager.experiment_transformer()
experiment_manager.experiment_star()
experiment_manager.experiment_saestar()