In [None]:
%load_ext autoreload
%autoreload 2

In [12]:
import os
import multiprocessing
multiprocessing.set_start_method("spawn", force=True)

##> import libraries
import sys
from pathlib import Path
import random
import time
from itertools import product

root_dir = Path.cwd().resolve().parent
if root_dir.exists():
    sys.path.append(str(root_dir))
else:
    raise FileNotFoundError('Root directory not found')

#> import flower
import flwr as fl

#> import custom libraries
from src.load import load_df_to_dataset
from src.EAE import EvidentialTransformerDenoiseAutoEncoder, evidential_regression
from src.client import train_and_evaluate_local, evaluate_saved_model, evaluate_local
from src.datasets import TrajectoryDataset, generate_ood_data,clean_outliers_by_quantile
from src.plot import plot_loss, plot_tsne_with_uncertainty, plot_uncertainty

#> torch libraries
import torch
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn as nn
import pandas as pd
import numpy as np
import statsmodels.api as sm
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

from sklearn.model_selection import train_test_split

#> Plot
import matplotlib.pyplot as plt
import seaborn as sns
# import scienceplots  # https://github.com/garrettj403/SciencePlots?tab=readme-ov-file
#plt.style.use(['science', 'grid', 'notebook'])  # , 'ieee'


# %matplotlib inline
#%matplotlib widget
import torch.nn.functional as F
import math

In [None]:
  # Define the dataset catalog
# Dataset assets live four directory levels above root_dir, under aistraj/bin/tvt_assets.
assets_dir = (root_dir.parents[3] / 'aistraj' / 'bin' / 'tvt_assets').resolve()
print(f"Assets Directory: {assets_dir}")
if not assets_dir.exists():
    raise FileNotFoundError('Assets directory not found')

# Directory under root_dir that holds saved model checkpoints.
saved_model_dir = (root_dir / 'models').resolve()
# Label fixed: this line reports the model directory, not the assets directory.
print(f"Model Directory: {saved_model_dir}")
if not saved_model_dir.exists():
    raise FileNotFoundError('Model directory not found')

In [14]:
def load_datasets(assets_dir, seq_len=960, batch_size=32, num_workers=2):
    """Build the train, validation and out-of-distribution (OOD) datasets and loaders.

    Reads the train and validation parquet files from ``assets_dir / 'extended'``,
    derives an OOD frame from the raw validation frame with ``generate_ood_data``,
    cleans the train and validation frames with ``clean_outliers_by_quantile`` and
    wraps the three frames in ``TrajectoryDataset`` objects (``mode='ae'``).

    Args:
        assets_dir: ``pathlib.Path``-like directory (it is joined with ``/``).
        seq_len: Sequence length passed to every ``TrajectoryDataset``; also used as
            ``filter_less_seq_len`` for the train and validation datasets.
        batch_size: Batch size of all three data loaders.
        num_workers: Worker processes per ``DataLoader``. Defaults to 2, the value
            that used to be hard-coded.

    Returns:
        Tuple ``(train_loader, val_loader, ood_loader, n_features, train_dataset,
        val_dataset, ood_dataset)`` where ``n_features`` is
        ``train_dataset.n_features``.
    """
    # Features removed from the model inputs by TrajectoryDataset.
    drop_features_list = ['epoch', 'datetime', 'obj_id', 'traj_id', 'stopped', 'curv', 'abs_ccs']
    # Columns handed to clean_outliers_by_quantile.
    columns_to_clean = ['speed_c', 'lon', 'lat']

    extended_dir = assets_dir / 'extended'
    train_df_extend = load_df_to_dataset(extended_dir / 'cleaned_extended_train_df.parquet').data
    validate_df_extend = load_df_to_dataset(extended_dir / 'cleaned_extended_validate_df.parquet').data

    # The OOD frame is generated from the validation frame *before* outlier cleaning.
    ood_df = generate_ood_data(validate_df_extend, ood_mean=10, ood_std=3)

    cleaned_train_data = clean_outliers_by_quantile(train_df_extend, columns_to_clean, remove_na=False)
    cleaned_val_data = clean_outliers_by_quantile(validate_df_extend, columns_to_clean, remove_na=False)

    # NOTE(review): train and validation are each built with
    # scaler_method='QuantileTransformer'; whether the validation dataset reuses the
    # scaler fitted on the training data cannot be seen from here -- confirm in
    # TrajectoryDataset.
    in_distribution_kwargs = dict(
        seq_len=seq_len,
        mode='ae',
        drop_features_list=drop_features_list,
        scaler_method='QuantileTransformer',
        filter_less_seq_len=seq_len,
    )
    train_dataset_traj = TrajectoryDataset(cleaned_train_data, **in_distribution_kwargs)
    val_dataset_traj = TrajectoryDataset(cleaned_val_data, **in_distribution_kwargs)

    # The OOD dataset is built without scaling or length/percent filtering.
    val_ood_dataset_traj = TrajectoryDataset(
        ood_df,
        seq_len=seq_len,
        mode='ae',
        drop_features_list=drop_features_list,
        filter_percent=None,
        scaler=None,
        filter_less_seq_len=None,
        scaler_method='No_Scaler',
    )

    def _make_loader(dataset, shuffle):
        """Wrap a dataset in a DataLoader with the shared loader settings."""
        return DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            shuffle=shuffle,
            pin_memory=False,
        )

    # Only the training loader shuffles; evaluation loaders keep dataset order.
    train_dataloader_traj = _make_loader(train_dataset_traj, shuffle=True)
    val_dataloader_traj = _make_loader(val_dataset_traj, shuffle=False)
    val_ood_dataloader_traj = _make_loader(val_ood_dataset_traj, shuffle=False)

    return (
        train_dataloader_traj,
        val_dataloader_traj,
        val_ood_dataloader_traj,
        train_dataset_traj.n_features,
        train_dataset_traj,
        val_dataset_traj,
        val_ood_dataset_traj,
    )

In [None]:
# Load the datasets and data loaders using load_datasets' default seq_len and
# batch_size. input_dim receives the fourth returned value
# (train_dataset_traj.n_features).
train_dataloader_traj, val_dataloader_traj, val_ood_dataloader_traj, input_dim, train_dataset_traj, val_dataset_traj, val_ood_dataset_traj = load_datasets(assets_dir)

In [None]:
# Inspect the shape of the validation dataset's inputs.
val_dataset_traj.inputs.shape

In [None]:
# Inspect the shape of the training dataset's inputs.
train_dataset_traj.inputs.shape

In [None]:
# Show the traj_id, lon, lat, obj_id and datetime columns of the validation labels.
val_dataset_traj.labels[['traj_id', 'lon', 'lat', 'obj_id', 'datetime']]

In [None]:
# Show the same columns from the validation dataset's dataframe attribute, for
# comparison with the labels displayed above it in the notebook.
val_dataset_traj.dataframe[['traj_id', 'lon', 'lat', 'obj_id', 'datetime']]

In [None]:
# For every listed column of the training dataframe: draw a histogram, report the
# maximum value (and how many validation rows equal it), and print the rows around
# the maximum for context.
for key in ['epoch', 'stopped', 'cog_c', 'aad', 'rot_c', 'speed_c', 'distance_c',
       'acc_c', 'cdd', 'dir_ccs', 'dist_ww', 'dist_ra',
       'dist_cl', 'dist_ma', 'traj_id', 'lon', 'lat', 'obj_id',
       'season', 'part_of_day', 'month_sin', 'month_cos', 'hour_sin',
       'hour_cos']:
    print(key)
    plt.figure()
    train_column = train_dataset_traj.dataframe[key]
    train_column.hist(bins=30, grid=False, edgecolor='black')
    #val_dataset_traj.labels[key].hist(bins=30, grid=False, edgecolor='black')

    # Position (not label) of the first maximum. Working positionally keeps the
    # neighbour window correct even if the dataframe index has duplicate labels,
    # and avoids materialising the whole index as a Python list.
    max_pos = int(train_column.reset_index(drop=True).idxmax())
    max_index = train_column.index[max_pos]
    max_value = train_column.iloc[max_pos]
    # Third printed value: number of validation rows equal to the training maximum.
    print(f"max index and value: {max_index}", max_value, (val_dataset_traj.dataframe[key] == max_value).sum())

    # Window of up to three rows on each side of the maximum, clamped to the frame.
    start_pos = max(0, max_pos - 3)
    end_pos = min(len(train_column) - 1, max_pos + 3)

    # Extract the corresponding rows together with their identifying columns.
    result = train_dataset_traj.dataframe[[key, 'obj_id', 'traj_id', 'lon', 'lat','datetime']].iloc[start_pos:end_pos + 1]

    print(result)
    plt.show()

In [None]:
class StochasticSampler(nn.Module):
    """Sample from a diagonal Gaussian latent using the reparameterization trick.

    Args:
        deterministic: When True, ``forward`` returns the mean unchanged instead of
            a random sample.
    """

    def __init__(self, deterministic=False):
        super().__init__()
        # Standard normal used as the noise source.
        self.sampler = torch.distributions.Normal(loc=0, scale=1)
        self.deterministic = deterministic

    def forward(self, z_mean, z_log_var):
        """Return ``(z, std)`` for the Gaussian described by mean and log-variance.

        Args:
            z_mean: Tensor of latent means.
            z_log_var: Tensor of latent log-variances, same shape as ``z_mean``.

        Returns:
            Tuple ``(z, std)`` where ``std = exp(0.5 * z_log_var)`` and ``z`` is
            ``z_mean`` when ``deterministic`` is set, otherwise
            ``z_mean + std * eps`` with ``eps`` drawn from N(0, 1).
        """
        std = torch.exp(0.5 * z_log_var)
        if self.deterministic:
            return z_mean, std

        # Draw noise with exactly the shape of z_mean. It must not be squeezed:
        # dropping size-1 dimensions (e.g. a (batch, 1) latent) would make the
        # product below broadcast to the wrong shape.
        eps = self.sampler.sample(z_mean.shape).to(device=z_mean.device, dtype=z_mean.dtype)
        return z_mean + std * eps, std
        
class LSTMVAE(nn.Module):
    """LSTM variational autoencoder for padded, variable-length sequences.

    The encoder LSTM's last-layer final hidden state is mapped to a latent mean
    (followed by BatchNorm1d) and a log-variance. The decoder LSTM is fed an
    all-zero input sequence and is conditioned on the latent only through its
    initial hidden state.

    Args:
        input_dim: Number of features per time step.
        hidden_dim: Hidden size of both LSTMs.
        latent_dim: Size of the latent vector.
        num_layers: Number of stacked LSTM layers in encoder and decoder.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim, num_layers=1):
        super(LSTMVAE, self).__init__()
        # Not used by forward(); kept as an attribute for code that may reference it.
        self.latent_sampler = StochasticSampler(deterministic=False)
        # Encoder
        self.encoder_lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)       # Mean of latent space
        self.fc_log_var = nn.Linear(hidden_dim, latent_dim)  # Log variance of latent space
        # Constant initialisation of the log-variance head; Xavier for the mean head.
        torch.nn.init.constant_(self.fc_log_var.weight, -1.0)
        torch.nn.init.constant_(self.fc_log_var.bias, -1.0)
        torch.nn.init.xavier_uniform_(self.fc_mu.weight)
        # Note: BatchNorm1d in training mode needs more than one sample per batch.
        self.bn_mu = nn.BatchNorm1d(latent_dim)
        # Decoder
        self.decoder_fc = nn.Linear(latent_dim, hidden_dim)
        self.decoder_lstm = nn.LSTM(hidden_dim, hidden_dim, num_layers, batch_first=True)
        self.output_fc = nn.Linear(hidden_dim, input_dim)  # Map hidden_dim back to input_dim

    @staticmethod
    def _cpu_lengths(lengths):
        """Return lengths on the CPU, as pack_padded_sequence requires for tensors."""
        return lengths.cpu() if torch.is_tensor(lengths) else lengths

    def encode(self, x, lengths):
        """Encode a padded batch ``x`` (batch first) into ``(mu, log_var)``.

        ``lengths`` holds the valid length of each sequence in the batch.
        """
        packed_input = pack_padded_sequence(x, self._cpu_lengths(lengths), batch_first=True, enforce_sorted=False)
        _, (h_n, _) = self.encoder_lstm(packed_input)
        h_n_last = h_n[-1]  # Final hidden state of the top LSTM layer

        mu = self.bn_mu(self.fc_mu(h_n_last))
        log_var = self.fc_log_var(h_n_last)
        return mu, log_var

    def reparameterize(self, mu, log_var):
        """Return ``(z, std)`` with ``z = mu + eps * std`` and ``eps ~ N(0, 1)``."""
        std = torch.exp(0.5 * log_var)
        epsilon = torch.randn_like(std)
        z = mu + epsilon * std
        return z, std

    def decode(self, z, max_len, lengths):
        """Decode latent ``z`` into a padded sequence of ``max_len`` time steps.

        Returns a tensor of shape ``(batch, max_len, input_dim)``.
        """
        # The latent becomes the initial hidden state of every decoder layer;
        # the initial cell state is zero.
        hidden = self.decoder_fc(z).unsqueeze(0).repeat(self.decoder_lstm.num_layers, 1, 1)
        batch_size = z.size(0)

        # The decoder receives an all-zero input sequence.
        decoder_input = torch.zeros(batch_size, max_len, hidden.size(2), device=z.device, dtype=z.dtype)

        packed_decoder_input = pack_padded_sequence(
            decoder_input, self._cpu_lengths(lengths), batch_first=True, enforce_sorted=False
        )
        packed_output, _ = self.decoder_lstm(packed_decoder_input, (hidden, torch.zeros_like(hidden)))
        output, _ = pad_packed_sequence(packed_output, batch_first=True, total_length=max_len)

        # Map hidden_dim back to input_dim
        return self.output_fc(output)

    def forward(self, x, lengths):
        """Reconstruct ``x`` and return ``(output, mu, log_var)``.

        In training mode the decoder receives a reparameterized sample, so the
        log-variance head takes part in the reconstruction; in eval mode it
        receives the mean, which makes inference deterministic.
        """
        max_len = x.size(1)
        mu, log_var = self.encode(x, lengths)
        if self.training:
            z, _ = self.reparameterize(mu, log_var)
        else:
            z = mu
        output = self.decode(z, max_len, lengths)
        return output, mu, log_var

In [None]:
class CNNVAE(nn.Module):
    """1-D convolutional variational autoencoder for fixed-length sequences.

    Three stride-2 convolutions shorten the sequence by a factor of about eight;
    the flattened result is projected to a latent mean and log-variance, and three
    stride-2 transposed convolutions restore the original length.
    """

    def __init__(self, input_dim, latent_dim, seq_len):
        """
        Args:
            input_dim: Feature dimension of the time series (feature_dim).
            latent_dim: Dimension of the latent space.
            seq_len: Length of the input sequences. Any positive length is
                supported, not only multiples of eight.
        """
        super(CNNVAE, self).__init__()

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv1d(input_dim, 64, kernel_size=5, stride=2, padding=2),  # (batch_size, 64, ~seq_len/2)
            nn.ReLU(),
            nn.Conv1d(64, 128, kernel_size=5, stride=2, padding=2),        # (batch_size, 128, ~seq_len/4)
            nn.ReLU(),
            nn.Conv1d(128, 256, kernel_size=5, stride=2, padding=2),       # (batch_size, 256, ~seq_len/8)
            nn.ReLU()
        )
        self.seq_len = seq_len
        # Each conv (kernel 5, stride 2, padding 2) maps length L to (L + 1) // 2,
        # i.e. ceil(L / 2). Computing the length this way keeps it an exact integer
        # and correct when seq_len is not a multiple of eight.
        compressed = seq_len
        for _ in range(3):
            compressed = (compressed + 1) // 2
        self.compressed_seq_len = compressed
        flat_dim = latent_dim * compressed

        # Latent mean and log-variance heads
        self.encoder_final = nn.Conv1d(256, latent_dim, kernel_size=1)       # (batch_size, latent_dim, compressed_seq_len)
        self.encode_mean = nn.Linear(flat_dim, latent_dim)
        self.encode_logvar = nn.Linear(flat_dim, latent_dim)
        self.decode_adapter = nn.Linear(latent_dim, flat_dim)

        # Decoder: each transposed conv doubles the length.
        self.decoder = nn.Sequential(
            nn.ConvTranspose1d(latent_dim, 128, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose1d(128, 64, kernel_size=5, stride=2, padding=2, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose1d(64, input_dim, kernel_size=5, stride=2, padding=2, output_padding=1),
            # No output activation: reconstructions are unbounded.
        )

    def reparameterize(self, mu, log_var):
        """Reparameterization trick.

        Args:
            mu: Latent mean, shape (batch_size, latent_dim).
            log_var: Latent log-variance, shape (batch_size, latent_dim).

        Returns:
            Tuple ``(z, std)``: the sampled latent ``mu + eps * std`` and the
            standard deviation ``exp(0.5 * log_var)``.
        """
        std = torch.exp(0.5 * log_var)   # standard deviation
        epsilon = torch.randn_like(std)  # random noise
        z = mu + epsilon * std
        return z, std

    def forward(self, x):
        """Forward pass.

        Args:
            x: Input time series, shape (batch_size, seq_len, feature_dim).

        Returns:
            Tuple ``(reconstructed, mu, log_var)`` where ``reconstructed`` has the
            shape of ``x`` and ``mu`` / ``log_var`` have shape
            (batch_size, latent_dim).
        """
        # Conv1d expects (batch_size, feature_dim, seq_len).
        x = x.permute(0, 2, 1)

        # Encode
        encoded = self.encoder_final(self.encoder(x))  # (batch_size, latent_dim, compressed_seq_len)
        encoded_output = encoded.reshape(encoded.size(0), -1)
        mu = self.encode_mean(encoded_output)
        log_var = self.encode_logvar(encoded_output)

        # Sample while training so the log-variance head takes part in the
        # reconstruction; decode from the mean in eval mode for deterministic output.
        if self.training:
            z, _ = self.reparameterize(mu, log_var)
        else:
            z = mu
        decoder_input = self.decode_adapter(z).view_as(encoded)

        # Decode; the decoder yields 8 * compressed_seq_len steps, so trim any
        # surplus back to the input length.
        reconstructed = self.decoder(decoder_input)[:, :, :x.size(-1)]

        # Restore (batch_size, seq_len, feature_dim).
        reconstructed = reconstructed.permute(0, 2, 1)
        return reconstructed, mu, log_var

In [20]:
def initialize_weights(m):
    """Initialise one module in place; intended for use with ``nn.Module.apply``.

    * ``nn.Linear`` / ``nn.Conv1d``: Kaiming-uniform weights (fan-in, ReLU gain)
      and zero bias when a bias exists.
    * ``nn.LayerNorm`` / ``nn.BatchNorm1d``: weight 1 and bias 0. Layers created
      without affine parameters (weight or bias is ``None``) are left untouched
      instead of raising.

    Any other module type is ignored.
    """
    if isinstance(m, (nn.Linear, nn.Conv1d)):
        nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d)):
        if m.weight is not None:
            nn.init.constant_(m.weight, 1)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)


class TransformerAutoEncoder(nn.Module):
    """Denoising Transformer autoencoder with a variational bottleneck.

    The input is projected to ``d_model``, combined with a fixed sinusoidal
    positional encoding and passed through a Transformer encoder. The whole encoded
    sequence is flattened and mapped to a latent mean and log-variance; the latent
    is expanded back to a ``(seq, d_model)`` tensor that the Transformer decoder
    consumes as both target and memory.

    Args:
        input_dim: Number of input features per time step.
        d_model: Transformer model width.
        latent_dim: Size of the latent vector.
        nhead: Number of attention heads.
        num_encoder_layers: Number of encoder layers.
        num_decoder_layers: Number of decoder layers.
        dim_feedforward: Width of the feed-forward sub-layers.
        max_seq_length: Sequence length the bottleneck layers are sized for;
            inputs must have exactly this many time steps.
        dropout_rate: Dropout used inside the Transformer layers.
        output_dim: Features per time step in the output; defaults to ``input_dim``.
    """

    def __init__(self, input_dim, d_model, latent_dim, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, max_seq_length, dropout_rate, output_dim=None):
        super(TransformerAutoEncoder, self).__init__()
        self.d_model = d_model
        self.max_seq_length = max_seq_length
        # Always set, so the attribute exists even when output_dim is omitted.
        self.output_dim = output_dim or input_dim

        # Fixed (non-trainable) sinusoidal positional encoding.
        self.positional_encoding = nn.Parameter(self._generate_positional_encoding(self.max_seq_length, d_model), requires_grad=False)

        # Encoder
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout_rate, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)

        # Variational bottleneck over the flattened (seq * d_model) encoding.
        self.encode_mean = nn.Linear(self.max_seq_length*d_model, latent_dim)
        self.encode_logvar = nn.Linear(self.max_seq_length*d_model, latent_dim)
        self.decode_adapter = nn.Linear(latent_dim, self.max_seq_length*d_model)

        # Decoder
        decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout_rate, batch_first=True)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)

        # Input / output projections.
        self.input_fc = nn.Linear(input_dim, d_model)
        self.output_fc = nn.Linear(d_model, self.output_dim)

        # The four modules below are not used by forward(); they are kept so the
        # state_dict keys stay compatible with previously saved checkpoints.
        self.bottleneck = nn.Linear(self.max_seq_length*d_model, 10)
        self.reconstruct = nn.Linear(10, self.max_seq_length*d_model)
        self.flatten = nn.Flatten(start_dim=1,end_dim=2)
        self.unflatten = nn.Unflatten(dim=1, unflattened_size=(self.max_seq_length, d_model))
        # Apply weight initialization
        self.apply(initialize_weights)

    def forward(self, src, padding_mask=None, return_latent=False, noise_factor=0.05):
        """Reconstruct ``src`` and return ``(decoded_output, mu, log_var)``.

        Args:
            src: Input of shape (batch_size, seq_length, input_dim).
            padding_mask: Optional validity mask, non-zero/True at valid positions.
                Either (batch_size, seq_length, input_dim), where a time step is
                valid if any feature is marked, or (batch_size, seq_length).
            return_latent: Unused; accepted for interface compatibility.
            noise_factor: Standard deviation of the Gaussian input noise. Noise is
                injected only in training mode, so evaluation is deterministic.

        Returns:
            ``decoded_output`` of shape (batch_size, seq_length, output_dim) with
            masked time steps zeroed, plus ``mu`` and ``log_var`` of shape
            (batch_size, latent_dim).
        """
        # Denoising corruption is a training-time regulariser only.
        if self.training and noise_factor:
            src = src + torch.randn_like(src) * noise_factor
        embedded = self.input_fc(src)  # (batch_size, seq_length, d_model)

        # Add positional encoding (out of place, so the projection output is not mutated).
        embedded = embedded + self.positional_encoding[:, :embedded.size(1), :]

        valid_steps = None
        if padding_mask is not None:
            valid_steps = padding_mask.bool()
            if valid_steps.dim() == 3:
                valid_steps = valid_steps.any(dim=-1)  # (batch_size, seq_length)
            # Zero the embeddings of padded time steps.
            embedded = embedded.masked_fill(~valid_steps.unsqueeze(-1), 0.0)

        # Encode
        encoded_memory = self.transformer_encoder(embedded)

        encoded_output = encoded_memory.reshape(encoded_memory.size(0), -1)
        mu = self.encode_mean(encoded_output)
        log_var = self.encode_logvar(encoded_output)

        # Sample while training so the log-variance head takes part in the
        # reconstruction; decode from the mean in eval mode.
        if self.training:
            z, _ = self.reparameterize(mu, log_var)
        else:
            z = mu
        decoder_input = self.decode_adapter(z).view_as(encoded_memory)

        # Decode: the expanded latent serves as both target and memory.
        decoded_output = self.transformer_decoder(decoder_input, decoder_input)
        decoded_output = self.output_fc(decoded_output)  # (batch_size, seq_length, output_dim)

        if valid_steps is not None:
            # Zero the reconstruction at padded time steps.
            decoded_output = decoded_output * valid_steps.unsqueeze(-1).to(decoded_output.dtype)

        return decoded_output, mu, log_var

    def _generate_positional_encoding(self, length, d_model):
        """Return a sinusoidal positional encoding of shape (1, length, d_model)."""
        position = torch.arange(length).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        pos_encoding = torch.zeros(length, d_model)
        pos_encoding[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one cosine column fewer than sine columns.
        pos_encoding[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        return pos_encoding.unsqueeze(0)  # add batch dimension

    def reparameterize(self, mu, log_var):
        """Return ``(z, std)`` with ``z = mu + eps * std`` and ``eps ~ N(0, 1)``."""
        std = torch.exp(0.5 * log_var)  # Calculate standard deviation
        epsilon = torch.randn_like(std)  # Random noise
        z = mu + epsilon * std  # Reparameterisation
        return z, std

In [21]:
def train(model, dataloader, criterion, optimizer, device, num_epochs=20, kl_weight=1.0):
    """Train a VAE-style model and print the mean losses after every epoch.

    Args:
        model: Module called as ``model(inputs, mask)`` and returning
            ``(output, mu, log_var)``.
        dataloader: Iterable of dict batches with key ``'inputs'`` and, optionally,
            ``'input_masks'`` (non-zero at valid positions; either
            (batch, seq) or (batch, seq, features)).
        criterion: Callable ``criterion(output, inputs, mask)`` returning the
            scalar reconstruction loss; ``mask`` is (batch, seq).
        optimizer: Optimizer updating ``model``'s parameters.
        device: Device the batches are moved to.
        num_epochs: Number of passes over ``dataloader``.
        kl_weight: Multiplier of the KL term. The default 1.0 gives
            ``loss = recon_loss + kl_loss``.

    Returns:
        List with one dict per epoch holding the mean ``'loss'``, ``'recon_loss'``
        and ``'kl_loss'``.
    """
    model.train()
    history = []
    for epoch in range(num_epochs):
        recon_losses = []
        kl_losses = []
        total_loss = 0.0
        for batch in dataloader:
            batch_data = batch['inputs'].to(device)
            masks = batch.get('input_masks', None)
            if masks is None:
                # No mask supplied: every time step counts as valid.
                model_mask = None
                loss_mask = torch.ones(batch_data.shape[:2], dtype=batch_data.dtype, device=batch_data.device)
            else:
                masks = masks.to(device)
                if masks.dim() == 2:
                    # Per-time-step mask: broadcast over features for the model.
                    loss_mask = masks
                    model_mask = masks.unsqueeze(-1).expand_as(batch_data)
                else:
                    # Per-feature mask: a step is valid if any feature is marked.
                    model_mask = masks
                    loss_mask = masks.bool().any(dim=-1).to(batch_data.dtype)

            optimizer.zero_grad()
            output, mu, log_var = model(batch_data, model_mask)
            recon_loss = criterion(output, batch_data, loss_mask)
            # KL divergence to N(0, I), averaged over batch and latent dimensions.
            kl_loss = -0.5 * torch.mean(1 + log_var - mu.pow(2) - log_var.exp())
            loss = recon_loss + kl_weight * kl_loss
            recon_losses.append(recon_loss.item())
            kl_losses.append(kl_loss.item())
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        # Average over the batches actually seen; an empty loader reports NaN
        # instead of dividing by zero.
        n_batches = len(recon_losses)
        avg_loss = total_loss / n_batches if n_batches else float('nan')
        avg_recon = float(np.mean(recon_losses)) if n_batches else float('nan')
        avg_kl = float(np.mean(kl_losses)) if n_batches else float('nan')
        history.append({'loss': avg_loss, 'recon_loss': avg_recon, 'kl_loss': avg_kl})
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, Recon Loss: {avg_recon:.4f}, KL Loss: {avg_kl:.4f}")
    return history

def reconstruction_error_per_sample(x, x_reconstructed, reduction='mean'):
    """Squared reconstruction error of every sample in a batch.

    Args:
        x: Original tensor of shape (batch_size, seq_len, channels).
        x_reconstructed: Reconstruction with the same shape as ``x``.
        reduction: ``'mean'`` or ``'sum'`` over the time and channel axes.

    Returns:
        Tensor of shape (batch_size,) with one error value per sample.

    Raises:
        AssertionError: If the two tensors differ in shape.
        ValueError: If ``reduction`` is neither ``'mean'`` nor ``'sum'``.
    """
    assert x.shape == x_reconstructed.shape, "Input and reconstructed tensors must have the same shape"

    if reduction not in ('mean', 'sum'):
        raise ValueError("Reduction must be 'mean' or 'sum'")

    # Element-wise squared difference, then collapse time and channel axes.
    squared_diff = (x - x_reconstructed).pow(2)
    collapse = squared_diff.mean if reduction == 'mean' else squared_diff.sum
    return collapse(dim=(1, 2))
    
def get_latent_space(model, dataloader, device):
    """Collect latent means, latent standard deviations and reconstruction errors.

    The model is put in eval mode and run without gradients. When a batch carries
    ``'input_masks'`` the mask is passed to the model the same way ``train`` does
    (broadcast over features if it is 2-D); otherwise the model is called with the
    inputs only.

    Args:
        model: Module returning ``(reconstruction, mu, log_var)``.
        dataloader: Iterable of dict batches with key ``'inputs'`` and, optionally,
            ``'input_masks'``.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(latents, latents_std, recon_error)``: two arrays of shape
        (n_samples, latent_dim) holding ``mu`` and ``exp(0.5 * log_var)``, and a
        list with one mean squared reconstruction error per sample.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.eval()
    latents = []
    latents_std = []
    recon_error = []
    with torch.no_grad():
        for batch in dataloader:
            batch_data = batch['inputs'].to(device)
            masks = batch.get('input_masks', None)
            if masks is None:
                recon_data, mu, log_var = model(batch_data)
            else:
                masks = masks.to(device)
                if masks.dim() == 2:
                    masks = masks.unsqueeze(-1).expand_as(batch_data)
                recon_data, mu, log_var = model(batch_data, masks)

            # The standard deviation is a deterministic function of log_var; no
            # random sample is needed to obtain it.
            std = torch.exp(0.5 * log_var)
            latents.append(mu.cpu().numpy())
            latents_std.append(std.cpu().numpy())
            per_sample_error = reconstruction_error_per_sample(batch_data, recon_data)
            recon_error.extend(per_sample_error.cpu().numpy())

    if not latents:
        raise ValueError("dataloader yielded no batches")

    latents = np.vstack(latents)
    latents_std = np.vstack(latents_std)
    print("Get Latent Space Done")
    return latents, latents_std, recon_error

In [22]:
def save_model(model, path):
    """Write ``model``'s state_dict to ``path`` with ``torch.save`` and report it."""
    weights = model.state_dict()
    torch.save(weights, path)
    print(f"Model saved to {path}")
    
def load_model(model, path, device='cpu'):
    """Load a state_dict from ``path`` into ``model`` and move the model to ``device``.

    The checkpoint tensors are mapped to ``device`` while loading. The model is
    modified in place; nothing is returned.
    """
    state = torch.load(path, map_location=device)
    model.load_state_dict(state)
    model.to(device)
    print(f"Model loaded from {path}")
    
def masked_reconstruction_loss(original, reconstructed, mask, offset=2.5, batch_mean=True):
    """Mean squared reconstruction error over the valid time steps of each sample.

    Args:
        original: Target tensor of shape (batch_size, seq_len, feature_dim).
        reconstructed: Reconstruction with the same shape as ``original``.
        mask: Validity mask, non-zero at valid time steps. Shape
            (batch_size, seq_len), or (batch_size, seq_len, feature_dim), in which
            case a time step is valid if any feature is marked.
        offset: Constant added to the final loss. It shifts the reported value
            only and has no effect on gradients.
        batch_mean: If True return the scalar batch average, otherwise the
            per-sample losses of shape (batch_size,).

    Returns:
        The loss plus ``offset``. Samples without any valid time step contribute
        0 instead of NaN.
    """
    # Element-wise squared error, averaged over the feature dimension.
    error = F.mse_loss(reconstructed, original, reduction="none").mean(dim=-1)  # (batch_size, seq_len)

    if mask.dim() == error.dim() + 1:
        mask = mask.bool().any(dim=-1)
    mask = mask.to(error.dtype)

    # Keep only valid time steps.
    masked_error = error * mask  # (batch_size, seq_len)

    # Average over the valid steps of each sample. A fully masked sample has a
    # zero numerator, so substituting 1 for its zero denominator yields 0.
    valid_steps = mask.sum(dim=1)
    valid_steps = torch.where(valid_steps > 0, valid_steps, torch.ones_like(valid_steps))
    loss = masked_error.sum(dim=1) / valid_steps

    if batch_mean:
        loss = loss.mean()  # Average over batch

    return loss + offset

In [23]:
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# NOTE(review): this overrides the input_dim returned by load_datasets
# (train_dataset_traj.n_features); the two values must agree.
input_dim = 20
hidden_dim = 50   # only used by the commented-out LSTMVAE alternative
latent_dim = 10
num_layers = 2    # only used by the commented-out LSTMVAE alternative
batch_size = 32
num_epochs = 20

# Create model
#model = LSTMVAE(input_dim, hidden_dim, latent_dim, num_layers).to(device)
#model = CNNVAE(input_dim, latent_dim, seq_len=960).to(device)
model = TransformerAutoEncoder(
    input_dim=input_dim,
    d_model=8,
    latent_dim = latent_dim,
    nhead=4,
    num_encoder_layers=2,
    num_decoder_layers=2,
    dim_feedforward=32,
    max_seq_length=960,
    dropout_rate=0.1
).to(device)

criterion = masked_reconstruction_loss #nn.MSELoss() #masked_reconstruction_loss #nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)

# saved_model_dir is a pathlib.Path, so the file name is joined with '/'
# (Path + str raises TypeError).
save_model_path = saved_model_dir / 'vae_model_qt_960_20.pth'
# Train the model


In [None]:
# Train on the training loader for num_epochs epochs, then write the model's
# state_dict to save_model_path.
train(model, train_dataloader_traj, criterion, optimizer, num_epochs=num_epochs, device=device)
save_model(model, save_model_path)

In [None]:
# Select the validation split as the loader/dataset pair that the latent-space
# analysis cells (get_latent_space, dataset_traj.labels) operate on.
dataloader_traj = val_dataloader_traj
dataset_traj = val_dataset_traj

In [None]:
# get_latent_space's third parameter is the device the batches are moved to;
# passing the loss function there makes `.to(...)` fail.
latent_space, latents_std, recon_error = get_latent_space(model, dataloader_traj, device)

In [None]:
# Alias used by the plotting cells; refers to the same array as latent_space.
latent_representations = latent_space

In [None]:
# Inspect the shape of the collected latent means.
latent_representations.shape

In [None]:
# Inspect the shape of the collected latent standard deviations.
latents_std.shape

In [None]:
#recon_error

In [None]:
from sklearn.manifold import TSNE

In [None]:
# Histogram (10 bins) of the per-sample reconstruction errors.
plt.hist(recon_error, bins=10, edgecolor='black', alpha=0.7)

In [None]:
# NOTE: despite its name, percentile_99 holds the 99.5th percentile of the
# reconstruction errors; it is used as the plotting threshold in a later cell.
percentile_99 = np.percentile(recon_error, 99.5)
print(percentile_99)

In [None]:
# Largest per-sample reconstruction error.
max(recon_error)

In [None]:
# Pass the latent means to plot_tsne_with_uncertainty with the reconstruction
# error as the uncertainty value; the threshold is the 99.5th-percentile value
# stored in percentile_99.
plot_tsne_with_uncertainty(latent_representations, recon_error, uncertainty_type='recon_error', threshold = percentile_99)

In [None]:
# Use each sample's mean latent standard deviation as the uncertainty value. The
# threshold is the 99.5th percentile of that same per-sample mean, so it is on the
# same scale as the plotted values (consistent with the combined ID + OOD plot).
latents_std_mean = latents_std.mean(axis=1)
plot_tsne_with_uncertainty(latent_representations, latents_std_mean, uncertainty_type='latents_std', threshold = np.percentile(latents_std_mean, 99.5))

In [None]:
# Reduce the (lon, lat) label columns and the (traj_id, obj_id) label columns to
# one t-SNE dimension each; the two 1-D embeddings are passed to
# plot_tsne_with_uncertainty in the following cells.
# NOTE(review): assumes dataset_traj.labels has one row per latent vector -- confirm.
tsne = TSNE(n_components=1, random_state=42)
label_pos = tsne.fit_transform(dataset_traj.labels[['lon', 'lat']])
tsne = TSNE(n_components=1, random_state=42)
label_obj = tsne.fit_transform(dataset_traj.labels[['traj_id', 'obj_id']])

In [None]:
# Plot the latent means against the 1-D embedding of the position labels.
plot_tsne_with_uncertainty(latent_representations, label_pos, uncertainty_type='pos')

In [None]:
# Plot the latent means against the 1-D embedding of the traj_id/obj_id labels.
plot_tsne_with_uncertainty(latent_representations, label_obj, uncertainty_type='obj')

In [None]:
# Run the OOD loader through the model (third argument is the device, not the
# loss function) and stack the results behind the in-distribution ones.
latent_representations_ood, latents_std_ood, recon_error_ood = get_latent_space(model, val_ood_dataloader_traj, device)
combined_latent_representations = np.concatenate([latent_representations, latent_representations_ood], axis = 0)
# list(...) keeps this a concatenation even if either operand is an array.
combined_recon_error = list(recon_error) + list(recon_error_ood)
combined_latents_std = np.concatenate([latents_std, latents_std_ood], axis = 0)

In [None]:
# 98th percentile of the combined (in-distribution + OOD) reconstruction errors,
# used as the threshold when plotting the combined latent means.
percentile_98 = np.percentile(combined_recon_error, 98)
print(percentile_98)
plot_tsne_with_uncertainty(combined_latent_representations, combined_recon_error, uncertainty_type='recon_error', threshold = percentile_98)

In [None]:
# Same plot using the per-sample mean latent standard deviation; the threshold is
# the 99.5th percentile of that per-sample mean.
plot_tsne_with_uncertainty(combined_latent_representations, combined_latents_std.mean(axis=1), uncertainty_type='latents_std', threshold = np.percentile(combined_latents_std.mean(axis=1), 99.5))

In [None]:
# Label vector aligned with combined_latent_representations:
# 0 for in-distribution samples, 1 for OOD samples.
ood_labels = [0] * len(latent_representations) + [1] * len(latent_representations_ood)
plot_tsne_with_uncertainty(combined_latent_representations, ood_labels, uncertainty_type='ood label')

In [None]:
# Reload the saved weights onto the device selected earlier; a hard-coded 'cuda'
# fails on machines without a GPU.
load_model(model, save_model_path, device=device)

In [None]:
#for key in val_dataset_traj.labels.keys():
# One plot per listed label column: the latent means with that column of
# dataset_traj.labels as the uncertainty value.
# NOTE(review): the list includes 'datetime'; whether plot_tsne_with_uncertainty
# accepts non-numeric values cannot be seen from here -- confirm.
for key in ['epoch', 'stopped', 'cog_c', 'aad', 'rot_c', 'speed_c', 'distance_c',
       'acc_c', 'cdd', 'dir_ccs', 'dist_ww', 'dist_ra',
       'dist_cl', 'dist_ma', 'traj_id', 'lon', 'lat', 'obj_id', 'datetime',
       'season', 'part_of_day', 'month_sin', 'month_cos', 'hour_sin',
       'hour_cos']:
    plot_tsne_with_uncertainty(latent_representations, dataset_traj.labels[key], uncertainty_type=key)

In [None]:
# Euclidean norm of the (lon, lat) label columns.
dist = np.sqrt(val_dataset_traj.labels["lon"]**2 + val_dataset_traj.labels["lat"]**2)
# Use an explicit label: `key` would be whatever the previous loop left behind
# (or undefined on a fresh kernel if that cell was skipped).
plot_tsne_with_uncertainty(latent_representations, dist, uncertainty_type='dist')