# Final notebook DL project 29, group 113

Please run the cell below first so that all required libraries are imported.

In [None]:
import random
import pandas as pd
import numpy as np
import glob
import os
import torch
from torch import dropout
from torch.utils.data import Dataset
import torch.nn as nn
from torch.nn import GRU, Module, Linear, ReLU, Sequential, Dropout
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
import json
from geopy.distance import geodesic
from math import radians, sin, cos, sqrt, atan2

## 1) Preprocessing

In [37]:
def track_filter(
    g,
    min_track_length=256,
    min_sog=1,
    max_sog=50,
    min_timespan_seconds=3600,
):
    """
    Decide whether a vessel track (or track segment) is worth keeping.

    A group passes when all of the following hold:
        - it has strictly more than ``min_track_length`` rows,
        - its maximum SOG lies in the closed range [min_sog, max_sog],
        - its first and last timestamps are at least ``min_timespan_seconds`` apart.

    The thresholds are keyword parameters because this module-level function
    cannot see the constants defined locally inside ``df_create``; the defaults
    mirror those constants (MIN_TRACK_LENGTH, MIN_SOG_FOR_TRACK,
    MAX_SOG_FOR_TRACK, MIN_TRACK_TIMESPAN_SECONDS).

    Args:
        g (pd.DataFrame): Track rows with a numeric "SOG" column and a
            datetime "Timestamp" column.
        min_track_length (int): Rows required (exclusive lower bound).
        min_sog (float): Lower bound for the track's maximum SOG (removes stationary tracks).
        max_sog (float): Upper bound for the track's maximum SOG (removes outliers).
        min_timespan_seconds (float): Minimum duration covered by the track.

    Returns:
        bool: True if the track passes all filters.
    """
    len_filt = len(g) > min_track_length
    sog_filt = min_sog <= g["SOG"].max() <= max_sog
    # NaT on an empty group yields NaN seconds, which compares False.
    timespan = g["Timestamp"].max() - g["Timestamp"].min()
    time_filt = timespan.total_seconds() >= min_timespan_seconds
    return bool(len_filt and sog_filt and time_filt)

In [38]:
def interpolate_linear(df_in, interval="10min", group_col="MMSI", method="linear", linear_cols=None):
    """
    Resample and interpolate linear features for each (group_col, Segment) pair.

    Each group is sorted by its datetime index, resampled to ``interval``
    (averaging rows that fall in the same bin), and the resulting empty bins
    are filled with ``DataFrame.interpolate(method=method)``.

    Args:
        df_in (pd.DataFrame): Input data indexed by timestamp, containing
            ``linear_cols``, ``group_col`` and "Segment".
        interval (str): Resampling interval (e.g. "10min").
        group_col (str): Grouping column, usually "MMSI".
        method (str): Interpolation method passed to ``DataFrame.interpolate``.
        linear_cols (list[str] | None): Columns to interpolate. Defaults to
            ["Latitude", "Longitude", "SOG"].

    Returns:
        pd.DataFrame: Columns [Timestamp, *linear_cols, group_col, Segment].
        An empty frame with those columns is returned when no group remains.
    """
    if linear_cols is None:
        linear_cols = ["Latitude", "Longitude", "SOG"]
    linear_cols = list(linear_cols)

    df = df_in.copy()
    df.index = pd.to_datetime(df.index, errors="coerce")
    df = df.dropna(subset=[group_col, "Segment"])

    groups = []
    # Group by both MMSI and Segment so interpolation never bridges two segments.
    for (mmsi, segment), g in df.groupby([group_col, "Segment"]):
        resampled = (
            g.sort_index()[linear_cols]
            .resample(interval)
            .mean()
            .interpolate(method=method)
        )
        resampled[group_col] = mmsi
        resampled["Segment"] = segment
        groups.append(resampled)

    if not groups:
        # pd.concat([]) raises; return an empty frame with the expected schema.
        return pd.DataFrame(columns=["Timestamp", *linear_cols, group_col, "Segment"])

    resampled_df = pd.concat(groups)
    # An unnamed datetime index becomes "index" after reset_index; normalize it.
    return resampled_df.reset_index().rename(columns={"index": "Timestamp"})

In [39]:
def interpolate_circular(df_in, interval="10min", group_col="MMSI", circular_cols=None):
    """
    Resample and interpolate angular features (e.g. COG) per (group_col, Segment) pair.

    Angles are handled on the unit circle. Each angle is converted to
    (sin, cos). The components are averaged per resampling bin and empty bins
    are filled by linear interpolation. The angle is then recovered with
    arctan2 and wrapped to [0, 360). This avoids the wrap-around artefacts of
    averaging or interpolating raw degrees (e.g. 350° and 10° average to 0°,
    not 180°).

    Args:
        df_in (pd.DataFrame): Input data indexed by timestamp, containing
            ``circular_cols`` (degrees), ``group_col`` and "Segment".
        interval (str): Resampling interval (e.g. "10min").
        group_col (str): Grouping column, usually "MMSI".
        circular_cols (list[str] | None): Angular columns in degrees.
            Defaults to ["COG"].

    Returns:
        pd.DataFrame: Columns [Timestamp, *circular_cols, group_col, Segment].
        An empty frame with those columns is returned when no group remains.
    """
    if circular_cols is None:
        circular_cols = ["COG"]
    circular_cols = list(circular_cols)

    df = df_in.copy()
    df.index = pd.to_datetime(df.index, errors="coerce")
    df = df.dropna(subset=[group_col, "Segment"])

    groups = []
    for (mmsi, segment), g in df.groupby([group_col, "Segment"]):
        g = g.sort_index()
        columns = {}
        for col in circular_cols:
            rad = np.deg2rad(g[col].astype(float))
            # Average and interpolate the unit-vector components, not the raw degrees.
            sin_s = np.sin(rad).resample(interval).mean().interpolate(method="linear")
            cos_s = np.cos(rad).resample(interval).mean().interpolate(method="linear")
            columns[col] = np.rad2deg(np.arctan2(sin_s, cos_s)) % 360
        resampled = pd.DataFrame(columns)
        resampled[group_col] = mmsi
        resampled["Segment"] = segment
        groups.append(resampled)

    if not groups:
        # pd.concat([]) raises; return an empty frame with the expected schema.
        return pd.DataFrame(columns=["Timestamp", *circular_cols, group_col, "Segment"])

    resampled_df = pd.concat(groups)
    # An unnamed datetime index becomes "index" after reset_index; normalize it.
    return resampled_df.reset_index().rename(columns={"index": "Timestamp"})

In [40]:
def df_create(file_path: str) -> pd.DataFrame:
    """
    Create a cleaned (and optionally resampled) DataFrame from a raw AIS CSV file.

    Pipeline:
        1. Read the CSV and convert SOG from knots to m/s.
        2. Basic cleaning: bounding box, vessel type, MMSI length and MID
           range, timestamps, duplicates, zero coordinates, high SOG and
           records before MIN_VALID_YEAR.
        3. Track filtering and time-gap segmentation per MMSI.
        4. Resampling to a regular grid: linear for Latitude/Longitude/SOG,
           circular for COG. Optional Lat/Lon min-max normalization.
        5. Sorting and a printed summary.

    Args:
        file_path (str): Path to the raw AIS CSV file.

    Returns:
        pd.DataFrame: Cleaned rows with MMSI, Segment, Timestamp, Latitude,
        Longitude, SOG (m/s) and COG (degrees). Column order depends on which
        steps are enabled.
    """

    ### ================================================================
    ### --- FLAGS ---
    ### ================================================================
    REMOVE_OUT_OF_BOUNDS = True
    ENABLE_TYPE_FILTER = True
    MMSI_VALIDATION = True
    MMSI_MID_RANGE_VALIDATION = True
    REMOVE_INVALID_TIMESTAMPS = True
    REMOVE_DUPLICATES = True
    REMOVE_ZERO_COORDS = True
    REMOVE_SOG_COG_NAN_IF_NO_INTERPOLATION = True
    REMOVE_HIGH_SOG = True
    REMOVE_MIN_VALID_YEAR = True  # The cut-off year itself is MIN_VALID_YEAR below

    ENABLE_TRACK_FILTER = True
    ENABLE_SEGMENTATION = True

    ENABLE_INTERPOLATION = True
    ENABLE_NORMALIZATION = False

    ENABLE_SORTING = True

    ### ================================================================
    ### --- CONSTANTS ---
    ### ================================================================
    KNOTS_TO_MS = 0.514444  # Conversion factor from knots to meters per second

    BBOX = [60, 0, 50, 20]  # [north, west, south, east] in degrees
    VALID_TYPES = ["Class A", "Class B"]
    MMSI_STANDARD_LENGTH = 9
    VALID_MID_RANGE = (200, 775)
    MAX_SOG_THRESHOLD = 30  # m/s (SOG is converted from knots right after reading)
    MIN_VALID_YEAR = 2015
    # Track-filter thresholds. NOTE(review): track_filter is a module-level
    # function and cannot see these locals -- keep its thresholds in sync.
    MIN_TRACK_LENGTH = 256
    MIN_SOG_FOR_TRACK = 1
    MAX_SOG_FOR_TRACK = 50

    MIN_TRACK_TIMESPAN_SECONDS = 3600
    SEGMENT_GAP_SECONDS = 900
    MIN_SEGMENT_LENGTH = 10

    ROUND_INTERVAL_MIN = 10

    ### ================================================================
    ### STEP 1: READ CSV
    ### ================================================================
    # Define expected data types to optimize memory usage and ensure consistency
    dtypes = {
        "MMSI": "object",
        "SOG": float,
        "COG": float,
        "Longitude": float,
        "Latitude": float,
        "# Timestamp": "object",
        "Type of mobile": "object",
    }
    usecols = list(dtypes.keys())
    df = pd.read_csv(file_path, usecols=usecols, dtype=dtypes)
    df = df.rename(columns={"# Timestamp": "Timestamp"})
    initial_rows = len(df)
    print(f"[INFO] Loaded raw CSV: {initial_rows} rows.")

    ### --- CHECK FOR MISSING EXPECTED COLUMNS ---
    # pd.read_csv already raises ValueError when a usecols entry is absent,
    # so this is only a secondary safety net.
    expected_columns = {"MMSI", "Timestamp", "Latitude", "Longitude", "SOG", "COG"}
    missing_cols = expected_columns - set(df.columns)
    if missing_cols:
        print(f"[WARNING] The following expected columns are missing from the dataframe: {missing_cols}")

    ### --- CONVERT SOG FROM KNOTS TO M/S IMMEDIATELY AFTER READING ---
    df["SOG"] = KNOTS_TO_MS * df["SOG"]  # All further processing is in meters per second

    ### ================================================================
    ### STEP 2: BASIC CLEANING
    ### ================================================================
    ### --- REMOVE OUT-OF-BOUNDS POSITIONS ---
    if REMOVE_OUT_OF_BOUNDS:
        north, west, south, east = BBOX
        initial_rows = len(df)
        df = df[(df["Latitude"] <= north) & (df["Latitude"] >= south) & (df["Longitude"] >= west) & (df["Longitude"] <= east)]
        removed = initial_rows - len(df)
        print(f"[INFO] Removed {removed} rows outside bounding box.")

    ### --- FILTER BY SHIP TYPE ---
    if ENABLE_TYPE_FILTER:
        initial_rows = len(df)
        df = df[df["Type of mobile"].isin(VALID_TYPES)].drop(columns=["Type of mobile"])
        removed = initial_rows - len(df)
        print(f"[INFO] Removed {removed} rows not matching valid vessel types.")

    ### --- MMSI FORMAT VALIDATION ---
    if MMSI_VALIDATION and MMSI_STANDARD_LENGTH:
        initial_rows = len(df)
        df = df[df["MMSI"].str.len() == MMSI_STANDARD_LENGTH]
        removed = initial_rows - len(df)
        print(f"[INFO] Removed {removed} rows failing MMSI length validation.")

    ### --- MMSI MID RANGE VALIDATION ---
    if MMSI_MID_RANGE_VALIDATION and VALID_MID_RANGE is not None:
        initial_rows = len(df)
        # Non-numeric prefixes become NaN (and are dropped) instead of raising.
        mid = pd.to_numeric(df["MMSI"].str[:3], errors="coerce")
        df = df[mid.between(VALID_MID_RANGE[0], VALID_MID_RANGE[1])]
        removed = initial_rows - len(df)
        print(f"[INFO] Removed {removed} rows failing MMSI MID range validation.")

    ### --- TIMESTAMP PARSING ---
    df["Timestamp"] = pd.to_datetime(df["Timestamp"], format="%d/%m/%Y %H:%M:%S", errors="coerce")

    ### --- TIMESTAMP VALIDATION ---
    if REMOVE_INVALID_TIMESTAMPS:
        initial_rows = len(df)
        invalid_ts = df["Timestamp"].isna().sum()
        if invalid_ts > 0:
            print(f"[WARNING] {invalid_ts} / {initial_rows} rows have invalid timestamps and will be removed.")
        df = df.dropna(subset=["Timestamp"])
        removed = initial_rows - len(df)
        print(f"[INFO] Removed {removed} rows with invalid timestamps.")

    ### --- REMOVE DUPLICATE ENTRIES ---
    if REMOVE_DUPLICATES:
        initial_rows = len(df)
        df = df.drop_duplicates(["Timestamp", "MMSI"], keep="first")
        removed = initial_rows - len(df)
        print(f"[INFO] Removed {removed} duplicate rows.")

    ### --- REMOVE INVALID COORDS ---
    if REMOVE_ZERO_COORDS:
        initial_rows = len(df)
        df = df[(df["Latitude"] != 0) & (df["Longitude"] != 0)]
        removed = initial_rows - len(df)
        print(f"[INFO] Removed {removed} rows with zero coordinates.")

    ### --- REMOVE NaN SOG/COG IF NO INTERPOLATION ---
    if REMOVE_SOG_COG_NAN_IF_NO_INTERPOLATION and not ENABLE_INTERPOLATION:
        initial_rows = len(df)
        df = df.dropna(subset=["SOG", "COG"])
        removed = initial_rows - len(df)
        print(f"[INFO] Removed {removed} rows with NaN SOG/COG (no interpolation).")

    ### --- REMOVE HIGH SOG VALUES ---
    if REMOVE_HIGH_SOG and MAX_SOG_THRESHOLD is not None:
        initial_rows = len(df)
        df = df[df["SOG"] < MAX_SOG_THRESHOLD]
        removed = initial_rows - len(df)
        print(f"[INFO] Removed {removed} rows with SOG >= {MAX_SOG_THRESHOLD} m/s.")

    ### --- REMOVE DATA BEFORE MIN VALID YEAR ---
    if REMOVE_MIN_VALID_YEAR and MIN_VALID_YEAR is not None:
        initial_rows = len(df)
        df = df[df["Timestamp"].dt.year >= MIN_VALID_YEAR]
        removed = initial_rows - len(df)
        print(f"[INFO] Removed {removed} rows before year {MIN_VALID_YEAR}.")

    print(f"[INFO] Dataframe after basic cleaning has {len(df)} rows and {df['MMSI'].nunique()} unique MMSIs.")

    ### ================================================================
    ### STEP 3: TRACK FILTERING AND SEGMENTATION
    ### ================================================================
    if ENABLE_TRACK_FILTER:
        initial_rows = len(df)
        df = df.groupby("MMSI").filter(track_filter)
        removed = initial_rows - len(df)
        print(f"[INFO] Removed {removed} rows by track filtering.")

    if ENABLE_SEGMENTATION:
        # The gap detection below diffs consecutive rows of each vessel, so
        # rows must be in chronological order per MMSI (the raw CSV does not
        # guarantee this).
        df = df.sort_values(["MMSI", "Timestamp"])
        df["Segment"] = df.groupby("MMSI")["Timestamp"].transform(
            lambda x: (x.diff().dt.total_seconds().fillna(0) >= SEGMENT_GAP_SECONDS).cumsum()
        )
        initial_rows = len(df)
        df = df.groupby(["MMSI", "Segment"]).filter(track_filter)
        removed = initial_rows - len(df)
        print(f"[INFO] Removed {removed} rows by segment track filtering.")

        initial_rows = len(df)
        df = df.groupby(["MMSI", "Segment"]).filter(lambda g: len(g) >= MIN_SEGMENT_LENGTH)
        removed = initial_rows - len(df)
        print(f"[INFO] Removed {removed} rows from segments shorter than {MIN_SEGMENT_LENGTH} rows.")

        df = df.reset_index(drop=True)
    else:
        df["Segment"] = 0  # Single segment per vessel when segmentation is disabled

    print(f"[INFO] Dataframe after track filtering and segmentation has {len(df)} rows and {df['MMSI'].nunique()} unique MMSIs.")

    ### ================================================================
    ### STEP 4: INTERPOLATION AND NORMALIZATION
    ### ================================================================
    if ENABLE_INTERPOLATION:
        print(f"[INFO] Interpolating dataset at {ROUND_INTERVAL_MIN}-minute intervals using pandas.resample...")
        # Time-based resampling needs the timestamp as index.
        df = df.set_index("Timestamp")

        # --- Linear interpolation for position and speed ---
        linear_cols = ["Latitude", "Longitude", "SOG"]
        df_linear = interpolate_linear(
            df[["MMSI", "Segment"] + linear_cols],
            interval=f"{ROUND_INTERVAL_MIN}min",
            group_col="MMSI",
            method="linear",
            linear_cols=linear_cols,
        )
        print(f"[INFO] Linear interpolated rows: {len(df_linear)}")

        # --- Circular interpolation for angles (handles 359° -> 0° wrap-around) ---
        circular_cols = ["COG"]
        df_cog = interpolate_circular(
            df[["MMSI", "Segment"] + circular_cols],
            interval=f"{ROUND_INTERVAL_MIN}min",
            group_col="MMSI",
            circular_cols=circular_cols,
        )
        print(f"[INFO] Circular interpolated rows: {len(df_cog)}")

        # --- Combine: both grids are keyed by MMSI, Segment and Timestamp ---
        df = df_linear.merge(
            df_cog,
            on=["MMSI", "Segment", "Timestamp"],
            how="left"
        )
        initial_rows = len(df)
        df = df.dropna(subset=["Latitude", "Longitude", "SOG", "COG"])
        removed = initial_rows - len(df)
        print(f"[INFO] Interpolated dataset now has {len(df)} rows (removed {removed} rows with missing critical values).")

    if ENABLE_NORMALIZATION:
        # Min-max normalize Latitude and Longitude to [0, 1].
        print("[INFO] Normalizing Latitude and Longitude columns to [0,1] range.")
        df.loc[:, "Latitude"] = (df["Latitude"] - df["Latitude"].min()) / (df["Latitude"].max() - df["Latitude"].min())
        df.loc[:, "Longitude"] = (df["Longitude"] - df["Longitude"].min()) / (df["Longitude"].max() - df["Longitude"].min())

    # Chronological order per vessel and segment.
    if ENABLE_SORTING:
        df = df.sort_values(["MMSI", "Segment", "Timestamp"]).reset_index(drop=True)

    ### ================================================================
    ### STEP 5: SUMMARY AND FINALIZATION
    ### ================================================================
    print(f"[INFO] Final dataset stats:")
    print(f"  - Total rows: {len(df)}")
    print(f"  - Unique MMSIs: {df['MMSI'].nunique()}")
    print(f"  - Time span: {df['Timestamp'].min()} to {df['Timestamp'].max()}")

    return df

In [41]:
def split_dataset(df: pd.DataFrame, train_frac: float, val_frac: float) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """
    Split a cleaned AIS DataFrame into train/val/test sets by vessel (MMSI).

    Every row of a vessel lands in exactly one split, so no vessel leaks
    between sets. Split sizes are ``int(n_vessels * frac)`` for train and
    validation; the remaining vessels form the test set. Sampling uses the
    ``random`` module over a deterministic candidate order, so results are
    reproducible after ``random.seed(...)`` regardless of PYTHONHASHSEED.

    Args:
        df (pd.DataFrame): Cleaned AIS dataset with an "MMSI" column.
        train_frac (float): Fraction of vessels for the training set.
        val_frac (float): Fraction of vessels for the validation set.

    Returns:
        tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: Train, validation and
        test DataFrames, each with a fresh RangeIndex.

    Raises:
        ValueError: If a fraction is outside [0, 1] or they sum to more than 1.
    """
    if not (0.0 <= train_frac <= 1.0 and 0.0 <= val_frac <= 1.0) or train_frac + val_frac > 1.0 + 1e-9:
        raise ValueError(
            f"Invalid split fractions: train_frac={train_frac}, val_frac={val_frac}"
        )

    # Order of first appearance: deterministic, unlike iterating a set of strings.
    ships = list(df["MMSI"].unique())
    n_ships = len(ships)

    train_ships = random.sample(ships, int(n_ships * train_frac))
    train_set = set(train_ships)
    remaining = [s for s in ships if s not in train_set]

    val_ships = random.sample(remaining, int(n_ships * val_frac))
    val_set = set(val_ships)
    test_ships = [s for s in remaining if s not in val_set]

    df_train = df[df["MMSI"].isin(train_ships)].copy().reset_index(drop=True)
    df_val = df[df["MMSI"].isin(val_ships)].copy().reset_index(drop=True)
    df_test = df[df["MMSI"].isin(test_ships)].copy().reset_index(drop=True)

    return df_train, df_val, df_test

We have already cleaned and split the raw files, so the next cell will not run here (the raw data is not included). It is, however, the exact code that produced the final CSV files we worked with.

In [42]:
# Clean every raw AIS CSV and write its vessel-level train/val/test splits.
raw_dir = "raw data"
splits_dir = "datasplits"

print("Scanning raw data folder for AIS CSV files...")
raw_files = sorted(glob.glob(os.path.join(raw_dir, "*.csv")))

if not raw_files:
    raise FileNotFoundError(f"No CSV files found in {raw_dir}/")

# DataFrame.to_csv does not create missing parent folders, so create every
# split sub-directory up front.
for split in ("train", "val", "test"):
    os.makedirs(os.path.join(splits_dir, split), exist_ok=True)

for file_path in raw_files:
    base = os.path.basename(file_path)
    name, _ = os.path.splitext(base)
    print(f"\n=== Processing file: {base} ===")

    # --- Preprocess file ---
    df = df_create(file_path)

    # --- Split dataset (by vessel) ---
    print("Splitting into train/val/test...")
    train_df, val_df, test_df = split_dataset(df, train_frac=0.7, val_frac=0.15)

    # --- Export CSVs ---
    for split, split_df in (("train", train_df), ("val", val_df), ("test", test_df)):
        split_df.to_csv(os.path.join(splits_dir, split, f"{split}_{name}.csv"), index=False)

    print(f"Saved: train_{name}.csv, val_{name}.csv, test_{name}.csv")

Scanning raw data folder for AIS CSV files...


FileNotFoundError: No CSV files found in raw_data/

In [None]:
class AISDataset(Dataset):
    """
    PyTorch Dataset of fixed-length AIS trajectory windows.

    Each item is an input window of ``seq_input_length`` rows followed
    immediately by a target window of ``seq_output_length`` rows, both taken
    from the same (MMSI, Segment) track.

    Latitude and Longitude are min-max normalized and SOG is divided by its
    maximum. The statistics come from this dataset unless ``stats`` is given,
    e.g. the training set's stats when building validation and test sets.

    Args:
        dataset_path (str or pd.DataFrame): Path to a CSV or a DataFrame with
            MMSI, Segment, Latitude, Longitude, SOG, COG (degrees) and
            optionally Timestamp. A passed DataFrame is copied, never modified.
        seq_input_length (int): Number of time steps in the input sequence.
        seq_output_length (int): Number of time steps in the target sequence
            (must equal ``seq_input_length``).
        stats (tuple, optional): (lat_min, lat_max, lon_min, lon_max, sog_max)
            to apply a consistent normalization.

    Raises:
        ValueError: If ``seq_input_length != seq_output_length``.
    """
    def __init__(self, dataset_path: str, seq_input_length: int = 5, seq_output_length: int = 5, stats=None):
        # Only equal input/output window sizes are supported.
        if seq_input_length != seq_output_length:
            raise ValueError("Seq_input_length must be equal to seq_output_length")

        if isinstance(dataset_path, pd.DataFrame):
            # Copy so that normalization below never mutates the caller's frame.
            self.dataframe = dataset_path.copy()
        else:
            self.dataframe = pd.read_csv(dataset_path)

        if "Timestamp" in self.dataframe.columns:
            self.dataframe["Timestamp"] = pd.to_datetime(self.dataframe["Timestamp"])
            # Sorting by Segment too keeps each segment's rows contiguous, which
            # the positional windows in __getitem__ rely on.
            if "Segment" in self.dataframe.columns:
                sort_cols = ["MMSI", "Segment", "Timestamp"]
            else:
                sort_cols = ["MMSI", "Timestamp"]
            self.dataframe = self.dataframe.sort_values(by=sort_cols)
        # __getitem__ slices with iloc, so index labels must equal row positions.
        self.dataframe = self.dataframe.reset_index(drop=True)

        self.seq_input_length = seq_input_length
        self.seq_output_length = seq_output_length

        # --- 1. INDEXING: start row of every full window inside one segment ---
        window = self.seq_input_length + self.seq_output_length
        self.valid_idxs = []
        for _, vessel_df in self.dataframe.groupby("MMSI", sort=False):
            for _, segment_data in vessel_df.groupby("Segment"):
                num_sequences = len(segment_data) - window + 1
                if num_sequences > 0:
                    self.valid_idxs.extend(segment_data.index[:num_sequences].tolist())

        # --- 2. NORMALIZATION ---
        if stats is None:
            self.lat_min = self.dataframe["Latitude"].min()
            self.lat_max = self.dataframe["Latitude"].max()
            self.lon_min = self.dataframe["Longitude"].min()
            self.lon_max = self.dataframe["Longitude"].max()
            self.sog_max = self.dataframe["SOG"].max()
        else:
            self.lat_min, self.lat_max, self.lon_min, self.lon_max, self.sog_max = stats

        # Exposed so validation/test datasets can reuse the training statistics.
        self.stats = (self.lat_min, self.lat_max, self.lon_min, self.lon_max, self.sog_max)

        # 1e-6 guards against division by zero for constant columns.
        self.dataframe["Latitude"] = (self.dataframe["Latitude"] - self.lat_min) / (self.lat_max - self.lat_min + 1e-6)
        self.dataframe["Longitude"] = (self.dataframe["Longitude"] - self.lon_min) / (self.lon_max - self.lon_min + 1e-6)
        self.dataframe["SOG"] = self.dataframe["SOG"] / (self.sog_max + 1e-6)

    def __len__(self):
        """
        Returns:
            int: Number of valid input-output windows.
        """
        return len(self.valid_idxs)

    def __getitem__(self, idx):
        """
        Build one input/target pair.

        Args:
            idx (int): Index of the window.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                x of shape (seq_input_length, 5) with [Lat, Lon, SOG, sin(COG), cos(COG)];
                y of shape (seq_output_length, 2) with [Lat, Lon].
        """
        start = self.valid_idxs[idx]
        split = start + self.seq_input_length
        end = split + self.seq_output_length
        x_rows = self.dataframe.iloc[start:split]

        # COG (degrees) is encoded as sin/cos so 359° and 0° are close.
        cog_rad = np.deg2rad(x_rows["COG"].to_numpy(dtype=float))
        cog_tensor = torch.tensor(cog_rad, dtype=torch.float32)
        cog_sin = torch.sin(cog_tensor)
        cog_cos = torch.cos(cog_tensor)

        x_data = x_rows[["Latitude", "Longitude", "SOG"]].to_numpy(dtype=float)
        x_tensor = torch.tensor(x_data, dtype=torch.float32)
        x = torch.cat((x_tensor, cog_sin.unsqueeze(-1), cog_cos.unsqueeze(-1)), dim=-1)

        y_data = self.dataframe.iloc[split:end][["Latitude", "Longitude"]].to_numpy(dtype=float)
        y = torch.tensor(y_data, dtype=torch.float32)

        return x, y

## 2) Models

In [None]:
class GRUModel(nn.Module):
    '''
    GRU-based sequence-to-sequence model that maps each input time step to a
    predicted position.

    Args:
        input_size (int): Number of input features per time step (default 5).
        embed_size (int): Size of the optional linear embedding (default 64).
        hidden_size (int): Hidden-state size of the GRU (default 256).
        output_size (int): Number of output features per time step (default 2).
        num_layers (int): Number of stacked GRU layers (default 1).
        dropout (float): Dropout between stacked GRU layers (default 0.1).
            nn.GRU applies it only between layers, so it is ignored (set to 0)
            when ``num_layers == 1``.
        first_linear (bool): Whether to apply a linear embedding before the GRU
            (default True).

    Returns:
        torch.Tensor: Predictions of shape (batch_size, seq_length, output_size).
    '''
    def __init__(self, input_size = 5, embed_size = 64, hidden_size = 256, output_size = 2, num_layers=1, dropout=0.1, first_linear=True):
        super().__init__()
        self.first_linear = first_linear

        if first_linear:
            # Project input_size features into an embed_size-dimensional space.
            self.embedding = nn.Linear(input_size, embed_size)
            gru_input_size = embed_size
        else:
            gru_input_size = input_size

        # A non-zero dropout with a single layer has no effect and only makes
        # PyTorch emit a UserWarning, so pass 0 in that case.
        self.gru = nn.GRU(
            gru_input_size,
            hidden_size,
            num_layers,
            bias=True,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0.0,
            bidirectional=False,
        )

        # Map GRU outputs (hidden_size) to the predicted features (output_size).
        self.fc_out = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # x: (batch, seq, input_size)
        if self.first_linear:
            x = self.embedding(x)    # (batch, seq, embed_size)
        out, _ = self.gru(x)         # (batch, seq, hidden_size)
        return self.fc_out(out)      # (batch, seq, output_size)

In [None]:
class LSTMModel(nn.Module):
    '''
    LSTM-based sequence-to-sequence model that maps each input time step to a
    predicted position.

    Args:
        input_size (int): Number of input features per time step (default 5).
        embed_size (int): Size of the optional linear embedding (default 64).
        hidden_size (int): Hidden-state size of the LSTM (default 256).
        output_size (int): Number of output features per time step (default 2).
        num_layers (int): Number of stacked LSTM layers (default 1).
        dropout (float): Dropout between stacked LSTM layers (default 0.0).
            nn.LSTM applies it only between layers, so it is ignored (set to 0)
            when ``num_layers == 1``.
        first_linear (bool): Whether to apply a linear embedding before the LSTM
            (default True).

    Returns:
        torch.Tensor: Predictions of shape (batch_size, seq_length, output_size).
    '''
    def __init__(self, input_size = 5, embed_size = 64, hidden_size = 256, output_size = 2, num_layers=1, dropout=0.0, first_linear=True):
        super().__init__()
        self.first_linear = first_linear

        if self.first_linear:
            # Project input_size features into an embed_size-dimensional space.
            self.embedding = nn.Linear(input_size, embed_size)
            lstm_input_size = embed_size
        else:
            lstm_input_size = input_size

        # A non-zero dropout with a single layer has no effect and only makes
        # PyTorch emit a UserWarning, so pass 0 in that case.
        self.lstm = nn.LSTM(
            lstm_input_size,
            hidden_size,
            num_layers,
            bias=True,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0.0,
            bidirectional=False,
        )

        # Map LSTM outputs (hidden_size) to the predicted features (output_size).
        self.fc_out = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # x: (batch, seq, input_size)
        if self.first_linear:
            x = self.embedding(x)    # (batch, seq, embed_size)
        out, _ = self.lstm(x)        # (batch, seq, hidden_size)
        return self.fc_out(out)      # (batch, seq, output_size)

## 3) Loss functions

In [None]:
class HaversineLoss(torch.nn.Module):
    """
    Mean great-circle (haversine) distance in km between predicted and true positions.

    Inputs are min-max normalized latitude/longitude. They are mapped back to
    degrees with the stored ranges before the distance is computed.

    Args:
        lat_min, lat_max (float): Latitude range (degrees) used for normalization.
        lon_min, lon_max (float): Longitude range (degrees) used for normalization.
        radius_earth_km (float): Sphere radius in km (default 6371.0).
        eps (float): Constant added inside the square roots to keep gradients
            finite when prediction and target coincide. It also puts a floor on
            the returned distance of roughly 2 * radius * sqrt(eps), about 0.4 km
            for the defaults. Pass a smaller value for evaluation metrics.
    """
    def __init__(self,
                 lat_min, lat_max,
                 lon_min, lon_max,
                 radius_earth_km=6371.0,
                 eps=1e-9):
        super().__init__()
        self.lat_min = lat_min
        self.lat_max = lat_max
        self.lon_min = lon_min
        self.lon_max = lon_max
        self.radius_earth_km = radius_earth_km
        self.eps = eps

    def forward(self, y_pred, y_true):
        """
        Args:
            y_pred: (batch, seq, 2) normalized lat/lon in [0, 1].
            y_true: (batch, seq, 2) normalized lat/lon in [0, 1].

        Returns:
            torch.Tensor: Scalar mean distance in km.
        """
        lat_span = self.lat_max - self.lat_min
        lon_span = self.lon_max - self.lon_min

        # --- De-normalize to degrees ---
        lat_true = y_true[..., 0] * lat_span + self.lat_min
        lon_true = y_true[..., 1] * lon_span + self.lon_min
        lat_pred = y_pred[..., 0] * lat_span + self.lat_min
        lon_pred = y_pred[..., 1] * lon_span + self.lon_min

        # --- Convert to radians ---
        lat1 = torch.deg2rad(lat_true)
        lon1 = torch.deg2rad(lon_true)
        lat2 = torch.deg2rad(lat_pred)
        lon2 = torch.deg2rad(lon_pred)

        # --- Haversine ---
        dlat = lat2 - lat1
        dlon = lon2 - lon1
        a = (torch.sin(dlat / 2) ** 2 +
             torch.cos(lat1) * torch.cos(lat2) *
             torch.sin(dlon / 2) ** 2)
        # Rounding can push a slightly outside [0, 1].
        a = torch.clamp(a, 0.0, 1.0)

        c = 2 * torch.atan2(torch.sqrt(a + self.eps), torch.sqrt(1 - a + self.eps))
        distance = self.radius_earth_km * c  # km

        return torch.mean(distance)



In [None]:
class HybridTrajectoryLoss(torch.nn.Module):
    """
    Weighted combination of position, heading and speed errors for
    trajectories shaped (Batch, Seq_Len, 2) whose last dimension is
    (lat, lon) in radians.

    Each term is squashed into [0, 1]:
      * position: tanh(mean Haversine distance / max_pos_error_km)
      * angle:    (1 - cosine similarity of the first-to-last displacement) / 2
      * speed:    tanh(mean squared error of per-step displacement lengths)

    Args:
        w_pos, w_ang, w_spd: Weights of the position, angle and speed terms.
        radius_km: Sphere radius (km) used by the Haversine distance.
        max_pos_error_km: Distance scale of the position term. Mean errors
            well above it saturate that term towards 1.

    NOTE(review): the angle and speed terms are computed directly in
    (lat, lon) radian space. Longitude steps are not scaled by cos(lat), and
    the speed term is in radians per step rather than km, so for small steps
    its squared error is likely very small. Confirm this approximation is
    intended.
    """

    def __init__(self, w_pos=1.0, w_ang=0.2, w_spd=0.1, radius_km=6371.0, max_pos_error_km=10.0):
        super().__init__()
        self.w_pos = w_pos
        self.w_ang = w_ang
        self.w_spd = w_spd
        self.radius_km = radius_km
        self.max_pos_error_km = max_pos_error_km

    def haversine_dist(self, p1, p2):
        """
        Calculate the great-circle distance (km) between corresponding points.

        Input: (Batch, ..., 2) where 0=Lat, 1=Lon in RADIANS.
        Returns a tensor of distances with the input shape minus the last dim.
        The 1e-9 added under both roots keeps gradients finite at zero
        distance, but it also floors each distance at about
        2 * radius_km * sqrt(1e-9) (~0.40 km for Earth's radius).
        """
        lat1, lon1 = p1[..., 0], p1[..., 1]
        lat2, lon2 = p2[..., 0], p2[..., 1]

        dlat = lat2 - lat1
        dlon = lon2 - lon1

        a = (torch.sin(dlat / 2) ** 2 +
             torch.cos(lat1) * torch.cos(lat2) * torch.sin(dlon / 2) ** 2)

        # Stability clamp: rounding can push `a` slightly outside [0, 1].
        a = torch.clamp(a, 0.0, 1.0)

        c = 2 * torch.atan2(torch.sqrt(a + 1e-9), torch.sqrt(1 - a + 1e-9))
        return self.radius_km * c

    def forward(self, y_pred, y_true):
        """
        Args:
            y_pred: (Batch, Seq_Len, 2) -> Lat/Lon in RADIANS
            y_true: (Batch, Seq_Len, 2) -> Lat/Lon in RADIANS

        Returns:
            tuple: (total loss tensor, dict with the float values of
            "loss_pos", "loss_ang" and "loss_spd").
        """
        # --- 1. Position loss (Haversine -> tanh) ---
        dist_km = self.haversine_dist(y_pred, y_true)

        # Normalise to [0, 1): mean errors well beyond max_pos_error_km saturate towards 1.
        loss_pos = torch.tanh(dist_km.mean() / self.max_pos_error_km)

        if y_pred.shape[1] < 2:
            # Heading and speed need at least two time steps. With a single
            # step, torch.diff is empty and its mean is NaN, which would make
            # the whole loss NaN, so both terms contribute nothing here.
            loss_ang = dist_km.new_zeros(())
            loss_spd = dist_km.new_zeros(())
        else:
            # --- 2. Angle loss (Euclidean vector approximation) ---
            vec_pred = y_pred[:, -1, :] - y_pred[:, 0, :]
            vec_true = y_true[:, -1, :] - y_true[:, 0, :]

            cos_sim = F.cosine_similarity(vec_pred, vec_true, dim=1, eps=1e-8)

            # Cosine lies in [-1, 1], so (1 - cos) is in [0, 2]; halve it to get [0, 1].
            loss_ang = (1.0 - cos_sim).mean() / 2.0

            # --- 3. Speed loss (Euclidean step length -> squared error -> tanh) ---
            vel_pred = torch.diff(y_pred, dim=1)
            vel_true = torch.diff(y_true, dim=1)

            # Length of each per-step displacement vector.
            speed_pred = torch.linalg.norm(vel_pred, dim=2)
            speed_true = torch.linalg.norm(vel_true, dim=2)

            speed_sq_err = (speed_pred - speed_true) ** 2

            # Normalise to [0, 1).
            loss_spd = torch.tanh(speed_sq_err.mean())

        # --- Combine ---
        total_loss = (self.w_pos * loss_pos) + \
                     (self.w_ang * loss_ang) + \
                     (self.w_spd * loss_spd)

        return total_loss, {
            "loss_pos": loss_pos.item(),
            "loss_ang": loss_ang.item(),
            "loss_spd": loss_spd.item()
        }

## 4) Training

We trained the model in four steps:
<br> <br>
1) Tuning, for GRU and LSTM models trained with the Haversine and MAE losses, the **number of hidden units per layer** and the **number of hidden layers** <br>
2) Fine-tuning the best model from step 1: **Batch size**, **Dropout rate** and **Window size** <br>
3) Training the best model from step 2 <br>
4) Training the same model on more data (4 days) <br>
<br>
The cell below contains only step 3: training the best model on 1 day of data. The cell runs as-is but can take a long time, so we recommend running it on a server with GPUs.
    

In [None]:
sequence_input_length = 3
sequence_output_length = 3
batch_size = 32
dropout_num = 0.1  # Dropout rate passed to the model
lr = 0.00001  # Learning rate for the Adam optimizer
num_epochs = 1000  # Upper bound on epochs; early stopping usually ends training first
patience = 5  # Epochs without validation improvement before stopping early

type_of_loss = 'MAE'  # Either 'MAE' (L1 loss) or 'HAVERSINE' (distance in km)
trainset = AISDataset('datasplits/train/train_aisdk-2025-02-27.csv', seq_input_length=sequence_input_length, seq_output_length=sequence_output_length)

# Reuse the training-set normalisation statistics for the validation set so
# that both splits are scaled identically.
train_stats = trainset.stats
valset = AISDataset('datasplits/val/val_aisdk-2025-02-27.csv', seq_input_length=sequence_input_length, seq_output_length=sequence_output_length, stats=train_stats)

# Create data loaders. The averaged validation loss does not depend on
# sample order, so only the training loader is shuffled.
train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=0)
val_loader = DataLoader(valset, batch_size=batch_size, shuffle=False, num_workers=0)

# Coordinate bounds HaversineLoss uses to map normalised outputs back to degrees.
lat_min = min(trainset.lat_min, valset.lat_min)
lat_max = max(trainset.lat_max, valset.lat_max)
lon_min = min(trainset.lon_min, valset.lon_min)
lon_max = max(trainset.lon_max, valset.lon_max)

print(f"Number of training batches: {len(train_loader)}")
print(f"Number of validation batches: {len(val_loader)}")

# Initialize the model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

training_losses = {}
validation_losses = {}
early_stopping_epochs = {}
val_losses_i = {}
train_losses_i = {}

model = GRUModel(input_size=5, embed_size=64, hidden_size=256, output_size=2, num_layers=2, dropout=dropout_num).to(device)

# Define optimizer and loss function
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
if type_of_loss == 'HAVERSINE':
    loss_fn = HaversineLoss(lat_min, lat_max, lon_min, lon_max)
elif type_of_loss == 'MAE':
    loss_fn = torch.nn.L1Loss()
else:
    raise ValueError(f"Unknown type_of_loss {type_of_loss!r}; expected 'MAE' or 'HAVERSINE'.")

# Training loop with early stopping on the validation loss.
best_val_loss = float('inf')
best_state = None  # Snapshot of the weights from the best validation epoch
best_epoch = 0
patience_counter = 0
train_losses_list = []
val_losses_list = []
for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    model.train()
    train_loss = 0

    # Training step with tqdm over batches
    batch_bar = tqdm(train_loader, desc="Training", leave=False)
    for batch_idx, (data, target) in enumerate(batch_bar):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        l = loss_fn(output, target)
        l.backward()
        optimizer.step()
        train_loss += l.item()
        batch_bar.set_postfix(loss=l.item())

    train_loss /= len(train_loader)
    train_losses_list.append(train_loss)
    print(f"Training Loss: {train_loss:.6f}")

    # Validation step with tqdm
    model.eval()
    val_loss = 0
    val_bar = tqdm(val_loader, desc="Validation", leave=False)
    with torch.no_grad():
        for data, target in val_bar:
            data, target = data.to(device), target.to(device)
            output = model(data)
            batch_loss = loss_fn(output, target).item()
            val_loss += batch_loss
            val_bar.set_postfix(loss=batch_loss)

    val_loss /= len(val_loader)
    val_losses_list.append(val_loss)
    print(f"Validation Loss: {val_loss:.6f}")

    # Early stopping check
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        train_loss_model = train_loss
        best_epoch = epoch + 1
        # state_dict() returns live references that later optimizer steps
        # would overwrite, so store detached copies.
        best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        patience_counter = 0
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f"Early stopping triggered at epoch {epoch+1}.")
            break

# The current weights are from the last epoch run, which may be up to
# `patience` epochs past the best one; restore the best weights before saving.
if best_state is not None:
    model.load_state_dict(best_state)
    print(f"Restored best model from epoch {best_epoch} (validation loss {best_val_loss:.6f}).")

# Save the trained model. The evaluation cell loads this file into the GRU model.
os.makedirs('results/models', exist_ok=True)
torch.save(model.state_dict(), 'results/models/mae_final_model.pth')

# Save training and validation losses
training_losses['final_model'] = train_losses_list
validation_losses['final_model'] = val_losses_list
early_stopping_epochs['final_model'] = epoch + 1

with open('results/results_final_model_mae.json', 'w') as f:
    json.dump({
        'training_losses': training_losses,
        'validation_losses': validation_losses,
        'early_stopping_epochs': early_stopping_epochs
    }, f)

print("Training complete. Model and results saved.")

Number of training batches: 2769
Number of validation batches: 554
Using device: cpu

Epoch 1/1000


                                                                          

Training Loss: 0.106579


                                                                         

KeyboardInterrupt: 

## 5) Evaluation

### 5.1) Metrics

In [43]:
def decode_predictions(predictions, lat_min, lat_max, lon_min, lon_max):
    """
    Undo min-max normalisation of the latitude and longitude channels.

    Channel 0 is rescaled from [0, 1] to [lat_min, lat_max] and channel 1 to
    [lon_min, lon_max]; any further channels are copied unchanged. The input
    tensor itself is left untouched because the work is done on a clone.

    Args:
        predictions (Tensor): Normalized tensor of shape (..., 2).
        lat_min (float): Minimum latitude used for normalization.
        lat_max (float): Maximum latitude used for normalization.
        lon_min (float): Minimum longitude used for normalization.
        lon_max (float): Maximum longitude used for normalization.

    Returns:
        Tensor: Decoded predictions in original lat/lon coordinates.
    """
    channel_bounds = ((lat_min, lat_max), (lon_min, lon_max))
    restored = predictions.clone()
    for channel, (low, high) in enumerate(channel_bounds):
        restored[..., channel] = restored[..., channel] * (high - low) + low
    return restored

In [44]:
def haversine_dist(coords1, coords2):
    """
    Compute the average Haversine (great-circle) distance in km between paired
    (lat, lon) coordinates given in degrees.

    Pairs are formed with ``zip``, so any iterables work: lists of tuples,
    generators, or 2-D tensors whose rows are (lat, lon). The average is taken
    over the pairs actually compared.

    Args:
        coords1 (Iterable of (lat, lon)): First set of coordinates.
        coords2 (Iterable of (lat, lon)): Second set of coordinates.

    Returns:
        float: Average Haversine distance across all pairs (in km).

    Raises:
        ValueError: If no coordinate pairs are supplied.
    """
    R = 6371.0  # Earth radius in kilometers
    distance = 0.0
    n_pairs = 0

    for (lat1, lon1), (lat2, lon2) in zip(coords1, coords2):
        dlat = radians(lat2 - lat1)
        dlon = radians(lon2 - lon1)

        a = sin(dlat / 2)**2 + cos(radians(lat1)) * cos(radians(lat2)) * sin(dlon / 2)**2
        # Floating-point rounding can push `a` marginally outside [0, 1]
        # (e.g. near-antipodal points), which would make sqrt raise
        # "math domain error".
        a = min(max(a, 0.0), 1.0)
        c = 2 * atan2(sqrt(a), sqrt(1 - a))

        distance += R * c
        n_pairs += 1

    if n_pairs == 0:
        raise ValueError("haversine_dist needs at least one coordinate pair")
    return distance / n_pairs

In [45]:
def ADE(coords1, coords2):
    """
    Average Displacement Error: the mean Haversine error in km over every
    predicted time step.

    ``haversine_dist`` already averages across the batch for one time step,
    so averaging those per-step values over the horizon gives the overall mean.

    Args:
        coords1 (Tensor): Predicted coordinates in degrees, shape (batch_size, seq_len, 2).
        coords2 (Tensor): Ground-truth coordinates in degrees, shape (batch_size, seq_len, 2).

    Returns:
        float: ADE in km.
    """
    _, horizon, _ = coords1.shape
    per_step_errors = (
        haversine_dist(coords1[:, step, :], coords2[:, step, :])
        for step in range(horizon)
    )
    return sum(per_step_errors, 0.0) / horizon


In [46]:
def FDE(coords1, coords2):
    """
    Final Displacement Error: the mean Haversine error in km at the last
    predicted time step.

    ``haversine_dist`` pairs the rows of its arguments and averages over them,
    so a single call on the final-step slices yields the batch mean directly.

    Args:
        coords1 (Tensor): Predicted coordinates in degrees, shape (batch_size, seq_len, 2).
        coords2 (Tensor): Ground-truth coordinates in degrees, shape (batch_size, seq_len, 2).

    Returns:
        float: FDE in km.
    """
    last_predicted = coords1[:, -1]
    last_observed = coords2[:, -1]
    return haversine_dist(last_predicted, last_observed)

In [47]:
def RMSE(coords1, coords2):
    """
    Compute the Root Mean Square Error between predicted and true coordinates.

    Every (sample, time step) Haversine distance is squared individually, and
    the root of their mean is returned. Squaring a batch-averaged distance
    instead would understate the RMSE, because the mean of squares is at
    least the square of the mean.

    Args:
        coords1 (Tensor): Predicted coordinates in degrees, shape (batch_size, seq_len, 2).
        coords2 (Tensor): Ground-truth coordinates in degrees, shape (batch_size, seq_len, 2).

    Returns:
        float: RMSE in km.
    """
    batch_size, seq_len, _ = coords1.shape
    squared_sum = 0.0
    for i in range(seq_len):
        for pred_point, true_point in zip(coords1[:, i, :], coords2[:, i, :]):
            # Single-pair call: haversine_dist returns this point's distance in km.
            squared_sum += haversine_dist([pred_point], [true_point]) ** 2
    mse = squared_sum / (batch_size * seq_len)
    return sqrt(mse)

### 5.2) Evaluate 
<br> In this notebook we trained on the 1-day dataset, so we evaluate that model here. In the full project, we also evaluated the model trained on the 4-day dataset. 

In [None]:
# Load normalization stats from training/validation to ensure consistent scaling.
# NOTE(review): the training cell normalised validation data with `trainset.stats`;
# confirm that these CSV-derived bounds, and the tuple order expected by
# AISDataset's `stats` argument, match what the models were trained with.
data_dir = "datasplits"
train_df = pd.read_csv(os.path.join(data_dir, 'train', 'train_aisdk-2025-02-27.csv'))
val_df = pd.read_csv(os.path.join(data_dir, 'val', 'val_aisdk-2025-02-27.csv'))
lat_min = min(train_df['Latitude'].min(), val_df['Latitude'].min())
lat_max = max(train_df['Latitude'].max(), val_df['Latitude'].max())
lon_min = min(train_df['Longitude'].min(), val_df['Longitude'].min())
lon_max = max(train_df['Longitude'].max(), val_df['Longitude'].max())
sog_max = max(train_df['SOG'].max(), val_df['SOG'].max())

# Initialize LSTM and GRU models with same architecture as during training.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

model_gru = GRUModel(
    input_size=5,
    embed_size=64,
    hidden_size=256,
    output_size=2,
    num_layers=2,
    dropout=0.1,
).to(device)

model_lstm = LSTMModel(
    input_size=5,
    embed_size=64,
    hidden_size=256,
    output_size=2,
    num_layers=2,
    dropout=0.1,
).to(device)

model_gru.load_state_dict(torch.load('results/models/mae_final_model.pth', map_location=device))
model_lstm.load_state_dict(torch.load('results/models/hav_final_model.pth', map_location=device))

# Load test dataset using same normalization stats as training.
testset = AISDataset(
    os.path.join(data_dir, 'test', 'test_aisdk-2025-02-27.csv'),
    seq_input_length=3,
    seq_output_length=3,
    stats=(lat_min, lat_max, lon_min, lon_max, sog_max)
)
test_loader = torch.utils.data.DataLoader(
    testset, batch_size=32, shuffle=False, num_workers=1
)

# Evaluate both models using standard metrics (ADE, FDE, RMSE) on unnormalized outputs.
# Each metric function returns a per-batch average, so the totals are weighted
# by batch size. Otherwise a smaller final batch would count as much as a full
# one. RMSE is pooled through its squares so the result is the RMSE over all
# test points rather than a mean of per-batch RMSEs.
model_lstm.eval()
model_gru.eval()
n_samples = 0
total_ade_lstm = 0.0
total_fde_lstm = 0.0
total_mse_lstm = 0.0  # Sum over batches of batch_size * RMSE**2
total_ade_gru = 0.0
total_fde_gru = 0.0
total_mse_gru = 0.0  # Sum over batches of batch_size * RMSE**2

print("Evaluating model...")
with torch.no_grad():
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        output_lstm = model_lstm(data)
        output_gru = model_gru(data)

        # Unnormalize predicted and target coordinates for metric calculation.
        output_lstm = decode_predictions(output_lstm, lat_min, lat_max, lon_min, lon_max)
        output_gru = decode_predictions(output_gru, lat_min, lat_max, lon_min, lon_max)
        target = decode_predictions(target, lat_min, lat_max, lon_min, lon_max)
        # Sanity check for both models' outputs
        for model_name, pred in (("LSTM", output_lstm), ("GRU", output_gru)):
            assert pred.shape == target.shape, \
                f"Shape mismatch ({model_name}): {pred.shape} vs {target.shape}"

        batch_n = target.shape[0]
        n_samples += batch_n
        total_ade_lstm += ADE(output_lstm, target) * batch_n
        total_ade_gru += ADE(output_gru, target) * batch_n
        total_fde_lstm += FDE(output_lstm, target) * batch_n
        total_fde_gru += FDE(output_gru, target) * batch_n
        total_mse_lstm += RMSE(output_lstm, target) ** 2 * batch_n
        total_mse_gru += RMSE(output_gru, target) ** 2 * batch_n

average_ade_lstm = total_ade_lstm / n_samples
average_fde_lstm = total_fde_lstm / n_samples
average_rmse_lstm = sqrt(total_mse_lstm / n_samples)
print(f"Average ADE LSTM: {average_ade_lstm:.4f}")
print(f"Average FDE LSTM: {average_fde_lstm:.4f}")
print(f"Average RMSE LSTM: {average_rmse_lstm:.4f}")

average_ade_gru = total_ade_gru / n_samples
average_fde_gru = total_fde_gru / n_samples
average_rmse_gru = sqrt(total_mse_gru / n_samples)

print(f"Average ADE GRU: {average_ade_gru:.4f}")
print(f"Average FDE GRU: {average_fde_gru:.4f}")
print(f"Average RMSE GRU: {average_rmse_gru:.4f}")