In [None]:
# Reload imported modules automatically before each cell runs, so edits to project code take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import multiprocessing
multiprocessing.set_start_method("spawn", force=True)

##> import libraries
import sys
from pathlib import Path
import random
import time
from itertools import product
from typing import OrderedDict


# The project root is the parent of the notebook's working directory; put it on
# sys.path so the `src.*` imports below resolve.
root_dir = Path.cwd().resolve().parent
if not root_dir.exists():
    raise FileNotFoundError('Root directory not found')
sys.path.append(str(root_dir))

#> import flower
import flwr as fl
from flwr.common import Context
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import IidPartitioner, DirichletPartitioner
from torch.utils.data import DataLoader
from datasets import Dataset


#> import custom libraries
from src.load import load_df_to_dataset
from src.EAE import EvidentialTransformerDenoiseAutoEncoder, evidential_regression
from src.client import train_and_evaluate_local, evaluate_saved_model
from src.datasets import TrajectoryDataset, clean_outliers_by_quantile, generate_ood_data
from src.plot import plot_tsne_with_uncertainty, visualize_mean_features

#> torch libraries
import torch
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn as nn
import pandas as pd
import numpy as np
import statsmodels.api as sm
import torch.nn.functional as F
from sklearn.metrics import f1_score, roc_auc_score, average_precision_score, roc_curve

#> Plot
import matplotlib.pyplot as plt
import seaborn as sns
import scienceplots  # https://github.com/garrettj403/SciencePlots?tab=readme-ov-file
# plt.style.use(['science', 'grid', 'notebook', 'ieee'])  # , 'ieee'


# %matplotlib inline
# %matplotlib widget


In [None]:
 # Define the dataset catalog relative to the project root.
# NOTE(review): `assets_dir` is reassigned to a hard-coded path in the next cell.
assets_dir = (root_dir.parents[3] / 'aistraj' / 'bin' / 'tvt_assets').resolve()
print(f"Assets Directory: {assets_dir}")
if not assets_dir.exists():
    raise FileNotFoundError('Assets directory not found')

# Directory holding the trained model checkpoints.
saved_model_dir = (root_dir / 'models').resolve()
print(f"Model Directory: {saved_model_dir}")
if not saved_model_dir.exists():
    raise FileNotFoundError('Model directory not found')

In [None]:
 # setup_environment(): make sure worker processes are started with "spawn".
# The import cell already forces this; repeating it keeps this cell self-contained.
if multiprocessing.get_start_method(allow_none=True) != "spawn":
    try:
        multiprocessing.set_start_method("spawn", force=True)
    except RuntimeError as e:
        print(f"Warning: {e}")

# Define the dataset catalog. The default is the shared data volume; set the
# AIS_ASSETS_DIR environment variable to use another location.
assets_dir = Path(os.environ.get("AIS_ASSETS_DIR", "/data1/aistraj/bin/tvt_assets")).resolve()
print(f"Assets Directory: {assets_dir}")
if not assets_dir.exists():
    raise FileNotFoundError('Assets directory not found')

# The 'src' directory contains only the code shipped to Ray workers.
code_dir = (root_dir / 'src').resolve()
print(f"Code Directory: {code_dir}")
if not code_dir.exists():
    raise FileNotFoundError('Code directory not found')

# Ray's `runtime_env["excludes"]` takes .gitignore-style patterns that are matched
# relative to the uploaded module directory; absolute paths (as previously built with
# `code_dir / file`) never match, so the bare patterns are passed instead.
excludes = ["data", "*.pyc", "__pycache__"]

ray_init_args = {
    "runtime_env": {
        # "working_dir": str(code_dir),
        "py_modules": [str(code_dir)],
        "excludes": list(excludes),
    },
    "include_dashboard": False,
    # "num_cpus": 4,
    # "local_mode": True
}

num_clients = 4

# config = {
#     "lambda_reg": 1,
#     "num_epochs": 1,
#     "offset": 2.5,
# }

In [None]:
# Checkpoints evaluated below; all live in the same models directory.
_ONSA_MODELS_DIR = '/data1/sgao/repos/CogSigma/oNSA/models/'
dae_model_save_path = _ONSA_MODELS_DIR + 'dae_model_qt_960.pth'
vae_model_save_path = _ONSA_MODELS_DIR + 'vae_model_qt_960_20.pth'
eae_model_save_path = _ONSA_MODELS_DIR + 'eae_model_qt_lambda05_960.pth'

In [None]:
def load_datasets_eval(assets_dir, seq_len=960, batch_size=32, num_workers=2):
    """Build the in-distribution (ID) and synthetic OOD validation data loaders.

    The ID split is ``extended/cleaned_extended_validate_df.parquet`` under
    ``assets_dir`` after quantile-based outlier cleaning of ``speed_c``, ``lon`` and
    ``lat``, wrapped in a ``TrajectoryDataset`` with ``scaler_method='QuantileTransformer'``.
    The OOD split is produced by ``generate_ood_data(..., ood_mean=1, ood_std=1)`` from
    the same (uncleaned) frame and uses ``scaler_method='No_Scaler'``.

    Args:
        assets_dir: ``pathlib.Path`` of the asset catalog root.
        seq_len: sequence length for both ``TrajectoryDataset`` instances.
        batch_size: batch size of both loaders.
        num_workers: ``DataLoader`` worker processes (default 2, as before).

    Returns:
        ``(val_dataloader, val_ood_dataloader, n_features, val_dataset)`` where
        ``n_features`` is ``val_dataset.n_features``.
    """
    validate_path = assets_dir / 'extended' / 'cleaned_extended_validate_df.parquet'
    validate_df = load_df_to_dataset(validate_path).data
    # NOTE(review): OOD samples are derived from the raw (not outlier-cleaned) frame and
    # get no scaler, whereas the ID samples are QuantileTransformer-scaled; the two sets
    # may then differ by scaling alone. Confirm this is the intended OOD protocol.
    ood_df = generate_ood_data(validate_df, ood_mean=1, ood_std=1)

    # Columns excluded from the model inputs.
    drop_features_list = ['epoch', 'datetime', 'obj_id', 'traj_id', 'stopped', 'curv', 'abs_ccs']
    columns_to_clean = ['speed_c', 'lon', 'lat']
    cleaned_val_data = clean_outliers_by_quantile(validate_df, columns_to_clean, remove_na=False)

    val_dataset_traj = TrajectoryDataset(
        cleaned_val_data,
        seq_len=seq_len,
        mode='ae',
        drop_features_list=drop_features_list,
        scaler_method='QuantileTransformer',
        filter_less_seq_len=6
        # categorical_features=['season']
    )

    val_ood_dataset_traj = TrajectoryDataset(
        ood_df,
        seq_len=seq_len,
        mode='ae',
        drop_features_list=drop_features_list,
        filter_percent=None,
        scaler=None,
        filter_less_seq_len=None,
        scaler_method='No_Scaler'
    )

    def _make_loader(dataset):
        # Deterministic order so ID/OOD scores line up across runs.
        return DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            shuffle=False,
            pin_memory=False
        )

    return (
        _make_loader(val_dataset_traj),
        _make_loader(val_ood_dataset_traj),
        val_dataset_traj.n_features,
        val_dataset_traj,
    )

In [None]:
# Load Dataset
# ID / OOD validation loaders; `input_dim` is the dataset's feature count (n_features).
(val_dataloader_traj,
 val_ood_dataloader_traj,
 input_dim,
 dataset_traj) = load_datasets_eval(assets_dir=assets_dir)

In [None]:
# Shape of the OOD validation inputs.
ood_inputs = val_ood_dataloader_traj.dataset.inputs
ood_inputs.shape

In [None]:
# Shape of the in-distribution validation inputs.
id_inputs = val_dataloader_traj.dataset.inputs
id_inputs.shape

## Models and utilities

In [None]:
import math

def initialize_weights(m):
    """Initialise one module in place; meant to be used as ``model.apply(initialize_weights)``.

    * ``nn.Linear`` / ``nn.Conv1d``: Kaiming-uniform weights (``fan_in``, ReLU gain), zero bias.
    * ``nn.LayerNorm`` / ``nn.BatchNorm1d``: weight 1, bias 0.

    Parameters a layer was built without (``bias=False``, ``elementwise_affine=False``,
    ``affine=False``) are ``None`` and are skipped instead of raising.
    """
    if isinstance(m, (nn.Linear, nn.Conv1d)):
        nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d)):
        # Norm layers without affine parameters expose weight/bias as None.
        if m.weight is not None:
            nn.init.constant_(m.weight, 1)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)


class TransformerAutoEncoder(nn.Module):
    """Denoising Transformer autoencoder for inputs of shape (batch, seq_len, input_dim).

    Gaussian noise is added to the input, which is then projected to ``d_model`` and
    given a sinusoidal positional encoding. It passes through a ``TransformerEncoder``
    and then a ``TransformerDecoder`` that uses the encoder output as both target and
    memory. Finally it is projected back to ``output_dim`` features per step.

    Args:
        input_dim: number of features per time step.
        d_model: Transformer width (PyTorch requires it to be divisible by ``nhead``).
        latent_dim: unused here; accepted so the constructor matches
            ``TransformerVariationalAutoEncoder``.
        nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, dropout_rate:
            forwarded to the PyTorch Transformer encoder/decoder layers.
        max_seq_length: number of positions in the positional-encoding table.
        output_dim: features per reconstructed step; defaults to ``input_dim``.
    """

    def __init__(self, input_dim, d_model, latent_dim, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, max_seq_length, dropout_rate, output_dim=None):
        super(TransformerAutoEncoder, self).__init__()
        self.d_model = d_model
        self.max_seq_length = max_seq_length
        self.output_dim = output_dim or input_dim

        # Fixed (non-trainable) sinusoidal positional encoding, shape (1, max_seq_length, d_model).
        self.positional_encoding = nn.Parameter(self._generate_positional_encoding(self.max_seq_length, d_model), requires_grad=False)

        # Encoder
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout_rate, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)

        # Decoder
        decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout_rate, batch_first=True)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)

        # Per-step projections into and out of the model width.
        self.input_fc = nn.Linear(input_dim, d_model)
        self.output_fc = nn.Linear(d_model, output_dim or input_dim)

        # Not used in forward(). They are registered parameters, so they appear in this
        # model's state_dict; removing them would break strict loading of checkpoints.
        self.bottleneck = nn.Linear(self.max_seq_length*d_model, 10)
        self.reconstruct = nn.Linear(10, self.max_seq_length*d_model)
        self.flatten = nn.Flatten(start_dim=1, end_dim=2)
        self.unflatten = nn.Unflatten(dim=1, unflattened_size=(self.max_seq_length, d_model))

        self.apply(initialize_weights)

    def forward(self, src, padding_mask=None, return_latent=False, noise_factor=0.05):
        """Reconstruct ``src``.

        Args:
            src: tensor (batch, seq_len, input_dim) with ``seq_len <= max_seq_length``.
            padding_mask: optional (batch, seq_len, k) tensor; a time step is kept when any
                of its k entries is truthy and is zeroed otherwise (on input and output).
            return_latent: accepted for API compatibility; ignored.
            noise_factor: std of the Gaussian noise added to ``src``. Noise is added in
                eval mode as well; pass ``0.0`` for deterministic reconstructions.

        Returns:
            ``(reconstruction, encoded_memory)`` of shapes (batch, seq_len, output_dim)
            and (batch, seq_len, d_model).
        """
        noise = torch.randn_like(src) * noise_factor
        noisy_src = self.input_fc(src + noise)  # (batch, seq_len, d_model)
        noisy_src += self.positional_encoding[:, :noisy_src.size(1), :]

        if padding_mask is not None:
            padding_mask_timestep = padding_mask.any(dim=-1)  # (batch, seq_len)
            mask_expanded = padding_mask_timestep.unsqueeze(-1).expand_as(noisy_src).bool()
            noisy_src = torch.where(mask_expanded, noisy_src, torch.tensor(0.0, device=noisy_src.device))

        encoded_memory = self.transformer_encoder(noisy_src)

        decoded_output = self.transformer_decoder(encoded_memory, encoded_memory)
        decoded_output = self.output_fc(decoded_output)  # (batch, seq_len, output_dim)

        if padding_mask is not None:
            mask_expanded = padding_mask_timestep.unsqueeze(-1).expand(-1, -1, decoded_output.size(-1)).float()
            decoded_output = decoded_output * mask_expanded

        return decoded_output, encoded_memory

    def _generate_positional_encoding(self, length, d_model):
        """Sinusoidal encoding of shape (1, length, d_model): sin on even, cos on odd columns.

        Works for odd ``d_model`` too (one fewer cosine column than frequencies).
        """
        position = torch.arange(length).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        pos_encoding = torch.zeros(length, d_model)
        pos_encoding[:, 0::2] = torch.sin(position * div_term)
        pos_encoding[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        return pos_encoding.unsqueeze(0)  # add batch dimension
    
    
    
    
class TransformerVariationalAutoEncoder(nn.Module):
    """Transformer VAE for inputs of shape (batch, max_seq_length, input_dim).

    The noised input is projected to ``d_model``, position-encoded and encoded by a
    ``TransformerEncoder``. The full encoder output is flattened and mapped to the mean
    and log-variance of a ``latent_dim``-dimensional Gaussian. A latent vector is then
    projected back to (batch, max_seq_length, d_model) and decoded by a
    ``TransformerDecoder`` (used as both target and memory).

    Args:
        input_dim: number of features per time step.
        d_model: Transformer width (PyTorch requires it to be divisible by ``nhead``).
        latent_dim: size of the Gaussian latent.
        nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, dropout_rate:
            forwarded to the PyTorch Transformer encoder/decoder layers.
        max_seq_length: exact sequence length the model accepts (the flattened encoder
            output must have ``max_seq_length * d_model`` entries).
        output_dim: features per reconstructed step; defaults to ``input_dim``.
    """

    def __init__(self, input_dim, d_model, latent_dim, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, max_seq_length, dropout_rate, output_dim=None):
        super(TransformerVariationalAutoEncoder, self).__init__()
        self.d_model = d_model
        self.max_seq_length = max_seq_length
        self.output_dim = output_dim or input_dim

        # Fixed (non-trainable) sinusoidal positional encoding, shape (1, max_seq_length, d_model).
        self.positional_encoding = nn.Parameter(self._generate_positional_encoding(self.max_seq_length, d_model), requires_grad=False)

        # Encoder
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout_rate, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)

        # Flattened encoder output -> posterior parameters; latent -> decoder input.
        self.encode_mean = nn.Linear(self.max_seq_length*d_model, latent_dim)
        self.encode_logvar = nn.Linear(self.max_seq_length*d_model, latent_dim)
        self.decode_adapter = nn.Linear(latent_dim, self.max_seq_length*d_model)

        # Decoder
        decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout_rate, batch_first=True)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)

        # Per-step projections into and out of the model width.
        self.input_fc = nn.Linear(input_dim, d_model)
        self.output_fc = nn.Linear(d_model, output_dim or input_dim)

        # Not used in forward(). They are registered parameters, so they appear in this
        # model's state_dict; removing them would break strict loading of checkpoints.
        self.bottleneck = nn.Linear(self.max_seq_length*d_model, 10)
        self.reconstruct = nn.Linear(10, self.max_seq_length*d_model)
        self.flatten = nn.Flatten(start_dim=1, end_dim=2)
        self.unflatten = nn.Unflatten(dim=1, unflattened_size=(self.max_seq_length, d_model))

        self.apply(initialize_weights)

    def forward(self, src, padding_mask=None, return_latent=False, noise_factor=0.05):
        """Encode ``src`` to a Gaussian posterior and decode a latent back to a sequence.

        While training, the reparameterised sample ``z`` is decoded. In eval mode the
        posterior mean ``mu`` is decoded, so reconstructions depend only on the input
        and the input noise.

        Args:
            src: tensor (batch, max_seq_length, input_dim).
            padding_mask: optional (batch, seq_len, k) tensor; a time step is kept when any
                of its k entries is truthy and is zeroed otherwise (on input and output).
            return_latent: accepted for API compatibility; ignored.
            noise_factor: std of the Gaussian noise added to ``src`` (also in eval mode).

        Returns:
            ``(reconstruction, mu, log_var)`` of shapes (batch, seq_len, output_dim),
            (batch, latent_dim) and (batch, latent_dim).
        """
        noise = torch.randn_like(src) * noise_factor
        noisy_src = self.input_fc(src + noise)  # (batch, seq_len, d_model)
        noisy_src += self.positional_encoding[:, :noisy_src.size(1), :]

        if padding_mask is not None:
            padding_mask_timestep = padding_mask.any(dim=-1)  # (batch, seq_len)
            mask_expanded = padding_mask_timestep.unsqueeze(-1).expand_as(noisy_src).bool()
            noisy_src = torch.where(mask_expanded, noisy_src, torch.tensor(0.0, device=noisy_src.device))

        encoded_memory = self.transformer_encoder(noisy_src)

        # Flattening only fits encode_mean/encode_logvar when seq_len == max_seq_length.
        encoded_flat = encoded_memory.reshape(encoded_memory.size(0), -1)
        mu = self.encode_mean(encoded_flat)
        log_var = self.encode_logvar(encoded_flat)

        # The sample is always drawn (keeps random-number consumption unchanged). It is
        # decoded while training (reparameterisation trick); the mean is decoded otherwise.
        z, _ = self.reparameterize(mu, log_var)
        latent = z if self.training else mu
        decoder_input = self.decode_adapter(latent).view_as(encoded_memory)

        decoded_output = self.transformer_decoder(decoder_input, decoder_input)
        decoded_output = self.output_fc(decoded_output)  # (batch, seq_len, output_dim)

        if padding_mask is not None:
            mask_expanded = padding_mask_timestep.unsqueeze(-1).expand(-1, -1, decoded_output.size(-1)).float()
            decoded_output = decoded_output * mask_expanded

        return decoded_output, mu, log_var

    def _generate_positional_encoding(self, length, d_model):
        """Sinusoidal encoding of shape (1, length, d_model): sin on even, cos on odd columns.

        Works for odd ``d_model`` too (one fewer cosine column than frequencies).
        """
        position = torch.arange(length).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        pos_encoding = torch.zeros(length, d_model)
        pos_encoding[:, 0::2] = torch.sin(position * div_term)
        pos_encoding[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        return pos_encoding.unsqueeze(0)  # add batch dimension

    def reparameterize(self, mu, log_var):
        """Draw ``z = mu + eps * std`` with ``std = exp(0.5 * log_var)``; returns ``(z, std)``."""
        std = torch.exp(0.5 * log_var)
        epsilon = torch.randn_like(std)
        z = mu + epsilon * std
        return z, std

In [None]:
def masked_reconstruction_loss(original, reconstructed, mask, offset=0, batch_mean=True):
    """Per-sample MSE averaged over the valid time steps, plus a constant ``offset``.

    Args:
        original, reconstructed: tensors (batch, seq_len, feature_dim).
        mask: (batch, seq_len) validity mask or weights (1/True = valid step). A
            (batch, seq_len, k) mask is reduced with ``any`` over its last axis, and
            ``None`` treats every step as valid.
        offset: constant added to the result.
        batch_mean: if True, return the scalar mean over the batch; otherwise return a
            (batch,) tensor.

    A sample whose mask selects no step gets a loss of ``offset`` instead of NaN.
    """
    # Squared error averaged over features -> (batch, seq_len).
    error = F.mse_loss(reconstructed, original, reduction="none").mean(dim=-1)

    if mask is None:
        mask = torch.ones_like(error)
    else:
        if mask.dim() == error.dim() + 1:
            mask = mask.any(dim=-1)
        mask = mask.to(error.dtype)

    step_count = mask.sum(dim=1)
    # Avoid 0/0 for fully-masked samples; their masked sum is 0, so the loss becomes 0.
    safe_count = torch.where(step_count > 0, step_count, torch.ones_like(step_count))
    loss = (error * mask).sum(dim=1) / safe_count

    if batch_mean:
        loss = loss.mean()

    return loss + offset

In [None]:
def load_model(model, path, device='cpu', strict=True):
    """Load a saved ``state_dict`` from ``path`` into ``model`` and move it to ``device``.

    The checkpoint is read with ``weights_only=True``, so arbitrary pickled objects are
    not executed. On torch versions that lack that keyword it falls back to a plain load.

    Args:
        model: ``nn.Module`` to populate (modified in place).
        path: checkpoint file containing a state_dict.
        device: ``map_location`` for loading and target device for the model.
        strict: forwarded to ``load_state_dict``.

    Returns:
        The same ``model`` object, for convenience.
    """
    try:
        state_dict = torch.load(path, map_location=device, weights_only=True)
    except TypeError:
        # torch < 1.13 does not accept the `weights_only` keyword.
        state_dict = torch.load(path, map_location=device)
    model.load_state_dict(state_dict, strict=strict)
    model.to(device)
    print(f"Model loaded from {path}")
    return model

#### Evaluation function for DAE and VAE

In [None]:
def _split_eval_masks(masks, batch_data):
    """Return ``(step_mask, padding_mask)`` for one evaluation batch.

    * ``step_mask``: (batch, seq_len) mask handed to the reconstruction criterion; all
      ones when the batch carries no mask.
    * ``padding_mask``: mask handed to the model's ``forward``. A 2-D mask is expanded to
      ``batch_data``'s shape, a higher-rank mask is passed through, and ``None`` stays
      ``None`` (the models then skip masking).
    """
    if masks is None:
        return torch.ones(batch_data.shape[:2], device=batch_data.device), None
    masks = masks.to(batch_data.device)
    if masks.dim() == 2:
        return masks, masks.unsqueeze(-1).expand_as(batch_data)
    # Per-feature mask: a step is valid if any entry is set, mirroring the models'
    # own `padding_mask.any(dim=-1)` reduction.
    return masks.any(dim=-1).to(batch_data.dtype), masks


def evaluate_dae_vae(
    model_class,
    model_path,
    device,
    test_dataloader,
    input_dim=20,
    d_model=8,
    latent_dim=8,
    nhead=4,
    num_encoder_layers=2,
    num_decoder_layers=2,
    dim_feedforward=32,
    max_seq_length=960,
    dropout_rate=0.1,
    criterion=masked_reconstruction_loss
):
    """Evaluate a saved (denoising) AE or VAE checkpoint on ``test_dataloader``.

    The model is built from ``model_class`` with the given hyper-parameters, loaded from
    ``model_path`` and run in eval mode without gradients. Each batch must be a mapping
    with ``'inputs'`` and optionally ``'input_masks'``. A model returning two values is
    treated as an AE ``(reconstruction, latent)``; one returning three values is treated
    as a VAE ``(reconstruction, mu, log_var)``.

    Returns:
        ``(latents, latents_std, recon_errors, inputs, outputs)``:
        * ``latents``: stacked AE latents or VAE ``mu``.
        * ``latents_std``: stacked VAE posterior std, or ``None`` for an AE.
        * ``recon_errors``: list of per-sample ``criterion(..., batch_mean=False)`` values.
        * ``inputs`` / ``outputs``: stacked inputs and reconstructions.
    """
    model = model_class(
        input_dim=input_dim,
        d_model=d_model,
        latent_dim=latent_dim,
        nhead=nhead,
        num_encoder_layers=num_encoder_layers,
        num_decoder_layers=num_decoder_layers,
        dim_feedforward=dim_feedforward,
        max_seq_length=max_seq_length,
        dropout_rate=dropout_rate
    ).to(device)

    load_model(model, model_path, device=device)
    model.eval()

    latents = []
    latents_std = []
    recon_errors = []
    all_inputs = []
    all_outputs = []

    with torch.no_grad():
        for batch in test_dataloader:
            batch_data = batch['inputs'].to(device)
            step_mask, padding_mask = _split_eval_masks(batch.get('input_masks', None), batch_data)

            output = model(batch_data, padding_mask)

            if len(output) == 2:
                # AE: (reconstruction, latent)
                recon_data, mu = output
                std = None
            elif len(output) == 3:
                # VAE: (reconstruction, mu, log_var)
                recon_data, mu, log_var = output
                if not hasattr(model, 'reparameterize'):
                    raise AttributeError("The model does not have a reparameterize method, but returned mu and log_var.")
                _, std = model.reparameterize(mu, log_var)
            else:
                raise ValueError("Unexpected number of outputs from the model forward pass. Expected 2 for AE, or 3 for VAE.")

            loss = criterion(batch_data, recon_data, step_mask, batch_mean=False)
            recon_errors.extend(loss.cpu().numpy())

            latents.append(mu.cpu().numpy())
            if std is not None:
                latents_std.append(std.cpu().numpy())

            all_inputs.append(batch_data.cpu().numpy())
            all_outputs.append(recon_data.cpu().numpy())

    latents = np.vstack(latents)
    latents_std = np.vstack(latents_std) if latents_std else None
    inputs = np.vstack(all_inputs)
    outputs = np.vstack(all_outputs)

    return latents, latents_std, recon_errors, inputs, outputs



#### Evaluation function for EAE

In [None]:
def evaluate_eae(
    model_class,
    model_path,
    criterion,
    val_dataloader,
    lambda_reg,
    offset,
    device,
    return_latent=False,
    input_dim=20,
    d_model=8,
    nhead=4,
    num_encoder_layers=2,
    num_decoder_layers=2,
    dim_feedforward=32,
    max_seq_length=960,
    dropout_rate=0.1
):
    """Evaluate a saved evidential autoencoder checkpoint on ``val_dataloader``.

    The model's forward returns ``(mu, v, alpha, beta)``, plus ``latent`` when called
    with ``return_latent=True``. ``mu`` is the reconstruction used for the per-sample
    masked MSE. The sample-weighted mean of
    ``criterion((mu, v, alpha, beta), features, lambda_reg, offset, mask=..., recon_error=...)``
    is reported as the validation loss.

    Returns:
        ``(avg_val_loss, recon_error, inputs, outputs)``, plus the concatenated latent
        tensor as a fifth element when ``return_latent`` is True.

    Raises:
        ValueError: if ``val_dataloader`` yields no batches.
    """
    model = model_class(
        input_dim=input_dim,
        d_model=d_model,
        nhead=nhead,
        num_encoder_layers=num_encoder_layers,
        num_decoder_layers=num_decoder_layers,
        dim_feedforward=dim_feedforward,
        max_seq_length=max_seq_length,
        dropout_rate=dropout_rate
    ).to(device)

    load_model(model, model_path, device=device)
    model.eval()

    total_val_loss = 0.0
    total_samples = 0
    recon_error = []
    latent_representations = []
    all_inputs = []
    all_outputs = []

    with torch.no_grad():
        for batch in val_dataloader:
            features = batch['inputs'].to(device)
            masks = batch.get('input_masks', None)
            if masks is not None:
                masks = masks.to(device)
                if masks.dim() == 2:
                    masks = masks.unsqueeze(-1).expand_as(features)
                step_mask = masks[:, :, 0]
            else:
                # No mask supplied: every time step counts as valid.
                step_mask = torch.ones(features.shape[:2], device=features.device)

            if return_latent:
                mu, v, alpha, beta, latent = model(features, return_latent=True)
                latent_representations.append(latent.cpu())
            else:
                mu, v, alpha, beta = model(features)

            per_sample_reconstruction_error = masked_reconstruction_loss(features, mu, step_mask, batch_mean=False)
            recon_error.extend(per_sample_reconstruction_error.cpu().numpy())

            val_loss = criterion((mu, v, alpha, beta), features, lambda_reg, offset, mask=masks, recon_error=per_sample_reconstruction_error)
            batch_size = features.size(0)
            total_val_loss += val_loss.item() * batch_size
            total_samples += batch_size

            all_inputs.append(features.cpu().numpy())
            all_outputs.append(mu.cpu().numpy())

    if total_samples == 0:
        raise ValueError("val_dataloader produced no batches")

    avg_val_loss = total_val_loss / total_samples
    inputs = np.vstack(all_inputs)
    outputs = np.vstack(all_outputs)

    if return_latent:
        latent_representations = torch.cat(latent_representations, dim=0)
        return avg_val_loss, recon_error, inputs, outputs, latent_representations
    return avg_val_loss, recon_error, inputs, outputs


In [None]:
def calculate_ood_metrics(id_scores, ood_scores, threshold_method='percentile', percentile=95, k=1.0, reduce_method='mean'):
    """Score OOD detection from per-sample scores (higher = more likely OOD).

    Args:
        id_scores, ood_scores: array-likes of per-sample scores. Either may be 2-D
            (n_samples, k), in which case it is reduced along axis 1 with ``reduce_method``
            ('mean', 'max' or 'l2'); each input is reduced independently.
        threshold_method: 'percentile' uses the ``percentile``-th percentile of the ID
            scores; 'mean_std' uses ``mean(id) + k * std(id)``.

    Returns:
        ``(f1, auroc, aupr, detection_error, threshold)``. OOD is the positive class, and
        F1 uses ``score > threshold``. ``detection_error`` is the minimum over ROC
        thresholds of ``0.5 * (FPR + 1 - TPR)``.

    Raises:
        ValueError: on an unknown ``reduce_method`` (only when a reduction is needed) or
        an unknown ``threshold_method``.
    """
    reducers = {
        'mean': lambda s: np.mean(s, axis=1),
        'max': lambda s: np.max(s, axis=1),
        'l2': lambda s: np.linalg.norm(s, axis=1),
    }

    def _to_per_sample(scores):
        scores = np.asarray(scores)
        if scores.ndim <= 1:
            return scores
        if reduce_method not in reducers:
            raise ValueError("Invalid reduce_method. Available options are 'mean', 'max', 'l2'.")
        return reducers[reduce_method](scores)

    id_scores = _to_per_sample(id_scores)
    ood_scores = _to_per_sample(ood_scores)

    if threshold_method == 'percentile':
        threshold = np.percentile(id_scores, percentile)
    elif threshold_method == 'mean_std':
        threshold = np.mean(id_scores) + k * np.std(id_scores)
    else:
        raise ValueError("threshold_method must be 'percentile' or 'mean_std'")

    all_scores = np.concatenate([id_scores, ood_scores], axis=0)
    # ID = 0, OOD = 1
    all_labels = np.concatenate([np.zeros(len(id_scores)), np.ones(len(ood_scores))], axis=0)

    predictions = (all_scores > threshold).astype(int)

    f1 = f1_score(all_labels, predictions)
    auroc = roc_auc_score(all_labels, all_scores)
    aupr = average_precision_score(all_labels, all_scores)

    fpr, tpr, _ = roc_curve(all_labels, all_scores)
    detection_error = np.min(0.5 * (fpr + (1 - tpr)))

    return f1, auroc, aupr, detection_error, threshold


## DAE

In [None]:
# Architecture hyper-parameters; they must match those of the saved DAE/VAE checkpoints
# (load_state_dict requires identical parameter shapes).
input_dim = 20
d_model = 8
nhead = 4
num_encoder_layers = 2
num_decoder_layers = 2
dim_feedforward = 32
max_seq_length = 960
dropout_rate = 0.1

# `input_dim` previously held the dataset's feature count (from load_datasets_eval);
# fail loudly if the data no longer matches the checkpoint width instead of silently
# overriding it.
if dataset_traj.n_features != input_dim:
    raise ValueError(
        f"Dataset provides {dataset_traj.n_features} features but the models expect {input_dim}"
    )

latents_d, _, recon_errors_d, inputs_d, outputs_d = evaluate_dae_vae(
    model_class=TransformerAutoEncoder,
    model_path=dae_model_save_path,
    device='cuda',
    test_dataloader=val_dataloader_traj,
    input_dim=input_dim,
    d_model=d_model,
    nhead=nhead,
    num_encoder_layers=num_encoder_layers,
    num_decoder_layers=num_decoder_layers,
    dim_feedforward=dim_feedforward,
    max_seq_length=max_seq_length,
    dropout_rate=dropout_rate,
    criterion=masked_reconstruction_loss
)

In [None]:
mean_recon_error_d = sum(recon_errors_d) / len(recon_errors_d)
print("Mean reconstruction error:", mean_recon_error_d)


In [None]:
# Evaluate the same DAE checkpoint on the synthetic OOD loader, passing exactly the
# hyper-parameters used for the in-distribution run rather than relying on defaults.
latents_d_ood, _, recon_errors_d_ood, _, _ = evaluate_dae_vae(
    model_class=TransformerAutoEncoder,
    model_path=dae_model_save_path,
    device='cuda',
    test_dataloader=val_ood_dataloader_traj,
    input_dim=input_dim,
    d_model=d_model,
    nhead=nhead,
    num_encoder_layers=num_encoder_layers,
    num_decoder_layers=num_decoder_layers,
    dim_feedforward=dim_feedforward,
    max_seq_length=max_seq_length,
    dropout_rate=dropout_rate,
    criterion=masked_reconstruction_loss
)

print("Mean reconstruction error:", sum(recon_errors_d_ood) / len(recon_errors_d_ood))


In [None]:
# Stack ID samples first, then OOD samples (order matches ood_labels_d below).
combined_latent_representations_d = np.concatenate((latents_d, latents_d_ood))
combined_recon_error_d = [*recon_errors_d, *recon_errors_d_ood]

In [None]:
# 0 = in-distribution, 1 = synthetic OOD, aligned with the combined latents.
n_id_d, n_ood_d = len(latents_d), len(latents_d_ood)
ood_labels_d = [0 for _ in range(n_id_d)] + [1 for _ in range(n_ood_d)]
plot_tsne_with_uncertainty(combined_latent_representations_d, ood_labels_d, uncertainty_type='OOD')

In [None]:
# NOTE: despite its name, this is the 95th percentile of the combined ID+OOD errors.
percentile_98_d = np.percentile(combined_recon_error_d, q=95)
print(percentile_98_d)
plot_tsne_with_uncertainty(
    combined_latent_representations_d,
    combined_recon_error_d,
    uncertainty_type='Reconstruction Error',
    threshold=percentile_98_d,
)


In [None]:
# Compare mean input vs. reconstructed features for the DAE.
visualize_mean_features(inputs_d, outputs_d, num_samples=200, num_features=19)

In [None]:
# Reconstruction error as the OOD score, thresholded at the 95th percentile of ID errors.
f1, auroc, aupr, detection_error, threshold = calculate_ood_metrics(
    recon_errors_d,
    recon_errors_d_ood,
    threshold_method='percentile',
    percentile=95,
)
print(
    f"Reconstruction Error F1 score (DAE): {f1:.4f}, AUROC: {auroc:.4f}, "
    f"AUPR: {aupr}, Detection Error: {detection_error}, Threshold: {threshold:.4f}"
)

## VAE

In [None]:
# VAE (latent_dim=10) on the in-distribution validation loader.
latents_v, latents_std, recon_errors_v, inputs_v, outputs_v = evaluate_dae_vae(
    model_class=TransformerVariationalAutoEncoder,
    model_path=vae_model_save_path,
    device='cuda',
    test_dataloader=val_dataloader_traj,
    input_dim=input_dim,
    d_model=d_model,
    latent_dim=10,
    nhead=nhead,
    num_encoder_layers=num_encoder_layers,
    num_decoder_layers=num_decoder_layers,
    dim_feedforward=dim_feedforward,
    max_seq_length=max_seq_length,
    dropout_rate=dropout_rate,
    criterion=masked_reconstruction_loss,
)

mean_recon_error_v = sum(recon_errors_v) / len(recon_errors_v)
print("Mean reconstruction error:", mean_recon_error_v)

In [None]:
# Compare mean input vs. reconstructed features for the VAE.
visualize_mean_features(inputs_v, outputs_v, num_samples=100, num_features=19)

In [None]:
# Same VAE checkpoint on the synthetic OOD loader.
latents_v_ood, latents_std_ood, recon_errors_v_ood, _, _ = evaluate_dae_vae(
    model_class=TransformerVariationalAutoEncoder,
    model_path=vae_model_save_path,
    device='cuda',
    test_dataloader=val_ood_dataloader_traj,
    input_dim=input_dim,
    d_model=d_model,
    latent_dim=10,
    nhead=nhead,
    num_encoder_layers=num_encoder_layers,
    num_decoder_layers=num_decoder_layers,
    dim_feedforward=dim_feedforward,
    max_seq_length=max_seq_length,
    dropout_rate=dropout_rate,
    criterion=masked_reconstruction_loss,
)

mean_recon_error_v_ood = sum(recon_errors_v_ood) / len(recon_errors_v_ood)
print("Mean reconstruction error:", mean_recon_error_v_ood)

In [None]:
# Stack ID samples first, then OOD samples.
combined_latent_representations_v = np.concatenate((latents_v, latents_v_ood))
combined_recon_error_v = [*recon_errors_v, *recon_errors_v_ood]
combined_latents_std = np.concatenate((latents_std, latents_std_ood))

In [None]:
# 0 = in-distribution, 1 = synthetic OOD.
n_id_v, n_ood_v = len(latents_v), len(latents_v_ood)
ood_labels_v = [0 for _ in range(n_id_v)] + [1 for _ in range(n_ood_v)]
plot_tsne_with_uncertainty(combined_latent_representations_v, ood_labels_v, uncertainty_type='ood label')

In [None]:
# NOTE: despite its name, this is the 95th percentile of the combined ID+OOD errors.
percentile_98_v = np.percentile(combined_recon_error_v, q=95)
print(percentile_98_v)
plot_tsne_with_uncertainty(
    combined_latent_representations_v,
    combined_recon_error_v,
    uncertainty_type='recon_error',
    threshold=percentile_98_v,
)

In [None]:
# Colour each sample by its mean latent std, and take the threshold from those same
# per-sample values (previously the percentile was taken over every latent dimension,
# which is not on the scale of the plotted quantity).
latent_std_per_sample_v = combined_latents_std.mean(axis=1)
plot_tsne_with_uncertainty(
    combined_latent_representations_v,
    latent_std_per_sample_v,
    uncertainty_type='latents_std',
    threshold=np.percentile(latent_std_per_sample_v, 95),
)

In [None]:
# VAE reconstruction error as the OOD score.
f1_v, auroc_v, aupr_v, detection_error_v, threshold_v = calculate_ood_metrics(
    recon_errors_v,
    recon_errors_v_ood,
    threshold_method='percentile',
    percentile=95,
)
print(
    f"Reconstruction Error F1 score (VAE): {f1_v:.4f}, AUROC: {auroc_v:.4f}, "
    f"AUPR: {aupr_v}, Detection Error: {detection_error_v}, Threshold: {threshold_v:.4f}"
)

In [None]:
# Posterior std (mean over latent dims) as the OOD score, 75th-percentile threshold.
f1_std, auroc_std, aupr_std, detection_error_std, threshold_std = calculate_ood_metrics(
    latents_std,
    latents_std_ood,
    threshold_method='percentile',
    percentile=75,
    reduce_method='mean',
)
print(
    f"STD F1 score (VAE): {f1_std:.4f}, AUROC: {auroc_std:.4f}, "
    f"AUPR: {aupr_std}, Detection Error: {detection_error_std}, Threshold: {threshold_std:.4f}"
)

### EAE

In [None]:
# Evidential AE on the in-distribution loader (CPU), without latent collection.
avg_val_loss_e, recon_errors_e, inputs_e, outputs_e = evaluate_eae(
    model_class=EvidentialTransformerDenoiseAutoEncoder,
    model_path=eae_model_save_path,
    criterion=evidential_regression,
    val_dataloader=val_dataloader_traj,
    device='cpu',
    lambda_reg=0.5,
    offset=2.5,
    return_latent=False,
    input_dim=20,
    d_model=8,
    nhead=4,
    dim_feedforward=32,
    num_encoder_layers=2,
    num_decoder_layers=2,
    max_seq_length=960,
    dropout_rate=0.1,
)

mean_recon_error_e = sum(recon_errors_e) / len(recon_errors_e)
print("Mean reconstruction error:", mean_recon_error_e)


In [None]:
# EAE instance handed to `evaluate_saved_model` in the following cells.
eae_model = EvidentialTransformerDenoiseAutoEncoder(
    input_dim=input_dim,
    d_model=8,
    nhead=4,
    dim_feedforward=32,
    num_encoder_layers=2,
    num_decoder_layers=2,
    max_seq_length=960,
    dropout_rate=0.1,
)

In [None]:
# EAE on the in-distribution loader, collecting uncertainties and latents.
eae_eval_id = evaluate_saved_model(
    model_class=eae_model,
    model_path=eae_model_save_path,
    criterion=evidential_regression,
    val_dataloader=val_dataloader_traj,
    device='cuda',
    lambda_reg=0.5,
    offset=2.5,
    return_latent=True,
)
(val_loss_e,
 val_aleatoric_uncertainties_e,
 val_epistemic_uncertainties_e,
 avg_aleatoric_uncertainty_e,
 avg_epistemic_uncertainty_e,
 latent_representations_eval_e,
 recon_error_e) = eae_eval_id

In [None]:
average_recon_error_e = sum(recon_error_e) / len(recon_error_e)
print("Average:", average_recon_error_e)

In [None]:
# EAE on the synthetic OOD loader, collecting uncertainties and latents.
eae_eval_ood = evaluate_saved_model(
    model_class=eae_model,
    model_path=eae_model_save_path,
    criterion=evidential_regression,
    val_dataloader=val_ood_dataloader_traj,
    device='cuda',
    lambda_reg=0.5,
    offset=2.5,
    return_latent=True,
)
(val_ood_loss_e,
 val_ood_aleatoric_uncertainties_e,
 val_ood_epistemic_uncertainties_e,
 avg_ood_aleatoric_uncertainty_e,
 avg_ood_epistemic_uncertainty_e,
 latent_ood_representations_eval_e,
 recon_ood_error_le) = eae_eval_ood

In [None]:
# Stack ID samples first, then OOD samples.
combined_latent_representations_e = np.concatenate(
    (latent_representations_eval_e, latent_ood_representations_eval_e)
)
combined_recon_error_e = recon_error_e + recon_ood_error_le

In [None]:
# ID then OOD, matching the order of combined_latent_representations_e.
combined_val_epistemic_uncertainties_e = np.concatenate(
    (val_epistemic_uncertainties_e, val_ood_epistemic_uncertainties_e)
)
combined_val_aleatoric_uncertainties_e = np.concatenate(
    (val_aleatoric_uncertainties_e, val_ood_aleatoric_uncertainties_e)
)

In [None]:
# 0 = in-distribution, 1 = synthetic OOD.
n_id_e, n_ood_e = len(latent_representations_eval_e), len(latent_ood_representations_eval_e)
ood_labels_e = [0 for _ in range(n_id_e)] + [1 for _ in range(n_ood_e)]
plot_tsne_with_uncertainty(combined_latent_representations_e, ood_labels_e, uncertainty_type='ood label')

In [None]:
# NOTE: despite its name, this is the 95th percentile of the combined ID+OOD errors.
percentile_98_e = np.percentile(combined_recon_error_e, q=95)
print(percentile_98_e)
plot_tsne_with_uncertainty(
    combined_latent_representations_e,
    combined_recon_error_e,
    uncertainty_type='recon_error',
    threshold=percentile_98_e,
)

In [None]:
# Compare mean input vs. reconstructed features for the EAE.
visualize_mean_features(inputs_e, outputs_e, num_samples=100, num_features=19)

In [None]:
# NOTE: despite its name, this is the 94.5th percentile of the combined epistemic uncertainty.
percentile_98_e_uncertainty = np.percentile(combined_val_epistemic_uncertainties_e, q=94.5)
print(percentile_98_e_uncertainty)
plot_tsne_with_uncertainty(
    combined_latent_representations_e,
    combined_val_epistemic_uncertainties_e,
    uncertainty_type='val_epistemic_uncertainties',
    threshold=percentile_98_e_uncertainty,
)

In [None]:
# ID samples only, using the epistemic threshold computed over ID+OOD.
plot_tsne_with_uncertainty(
    latent_representations_eval_e,
    val_epistemic_uncertainties_e,
    uncertainty_type='val_epistemic_uncertainties without ood',
    threshold=percentile_98_e_uncertainty,
)

In [None]:
# NOTE: despite its name, this is the 95th percentile of the combined aleatoric uncertainty.
percentile_98_e_uncertainty_a = np.percentile(combined_val_aleatoric_uncertainties_e, q=95)
print(percentile_98_e_uncertainty_a)
plot_tsne_with_uncertainty(
    combined_latent_representations_e,
    combined_val_aleatoric_uncertainties_e,
    uncertainty_type='val_aleatoric_uncertainties',
    threshold=percentile_98_e_uncertainty_a,
)

In [None]:
# ID samples only, coloured by aleatoric uncertainty. Use the aleatoric threshold; the
# epistemic one (`percentile_98_e_uncertainty`) was used here by mistake.
plot_tsne_with_uncertainty(
    latent_representations_eval_e,
    val_aleatoric_uncertainties_e,
    uncertainty_type='val_aleatoric_uncertainties without ood',
    threshold=percentile_98_e_uncertainty_a,
)

In [None]:
# Take both ID and OOD reconstruction errors from `evaluate_saved_model`, so the two
# score sets come from the same evaluation routine and device (as in
# `combined_recon_error_e`). Previously the ID side used `recon_errors_e` from
# `evaluate_eae` on the CPU.
f1_e, auroc_e, aupr_e, detection_error_e, threshold_e = calculate_ood_metrics(
    recon_error_e,
    recon_ood_error_le,
    threshold_method='percentile',
    percentile=95,
)
print(f"Reconstruction Error F1 score (EAE): {f1_e:.4f}, AUROC: {auroc_e:.4f}, AUPR: {aupr_e}, Detection Error: {detection_error_e}, Threshold: {threshold_e:.4f}")

In [None]:
# Epistemic uncertainty as the OOD score.
f1_eu, auroc_eu, aupr_eu, detection_error_eu, threshold_eu = calculate_ood_metrics(
    val_epistemic_uncertainties_e,
    val_ood_epistemic_uncertainties_e,
    threshold_method='percentile',
    percentile=95,
)
print(
    f"Epistemic Uncertainty F1 score (EAE): {f1_eu:.4f}, AUROC: {auroc_eu:.4f}, "
    f"AUPR: {aupr_eu}, Detection Error: {detection_error_eu}, Threshold: {threshold_eu:.4f}"
)

In [None]:
# Aleatoric uncertainty as the OOD score, 75th-percentile threshold.
f1_au, auroc_au, aupr_au, detection_error_au, threshold_au = calculate_ood_metrics(
    val_aleatoric_uncertainties_e,
    val_ood_aleatoric_uncertainties_e,
    threshold_method='percentile',
    percentile=75,
)
print(
    f"Aleatoric Uncertainty F1 score (EAE): {f1_au:.4f}, AUROC: {auroc_au:.4f}, "
    f"AUPR: {aupr_au}, Detection Error: {detection_error_au}, Threshold: {threshold_au:.4f}"
)

### Analysis

In [None]:
#for key in val_dataset_traj.labels.keys():
# for key in ['epoch', 'stopped', 'cog_c', 'aad', 'rot_c', 'speed_c', 'distance_c',
#        'acc_c', 'cdd', 'dir_ccs', 'dist_ww', 'dist_ra',
#        'dist_cl', 'dist_ma', 'traj_id', 'lon', 'lat', 'obj_id', 'datetime',
#        'season', 'part_of_day', 'month_sin', 'month_cos', 'hour_sin',
#        'hour_cos']:
# for key in ['cog_c', 'aad', 'rot_c', 'speed_c', 'distance_c',
#        'acc_c', 'cdd', 'dir_ccs', 'dist_ww', 'dist_ra',
#        'dist_cl', 'dist_ma', 'lon', 'lat',
#        'season', 'part_of_day', 'month_sin', 'month_cos', 'hour_sin',
#        'hour_cos']:
#     plot_tsne_with_uncertainty(latent_representations_eval_e, dataset_traj.labels[key], uncertainty_type=key)