In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import multiprocessing
multiprocessing.set_start_method("spawn", force=True)

##> import libraries
import sys
from pathlib import Path
import random
import time
from itertools import product

# The project root is the parent of the notebook's working directory; it is put
# on sys.path so the local `src` package imported below can be resolved.
root_dir = Path.cwd().resolve().parent
if not root_dir.exists():
    raise FileNotFoundError('Root directory not found')
sys.path.append(str(root_dir))

#> import flower
import flwr as fl

#> import custom libraries
from src.load import load_df_to_dataset
from src.EAE import EvidentialTransformerDenoiseAutoEncoder, evidential_regression
from src.client import train_and_evaluate_local, evaluate_saved_model, evaluate_local
from src.datasets import TrajectoryDataset, generate_ood_data,clean_outliers_by_quantile
from src.plot import plot_loss, plot_tsne_with_uncertainty, plot_uncertainty

#> torch libraries
import torch
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn as nn
import pandas as pd
import numpy as np
import statsmodels.api as sm

from sklearn.model_selection import train_test_split

#> Plot
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.manifold import TSNE
# import scienceplots  # https://github.com/garrettj403/SciencePlots?tab=readme-ov-file
#plt.style.use(['science', 'grid', 'notebook'])  # , 'ieee'
import torch.nn.functional as F
import math
# %matplotlib inline
#%matplotlib widget


In [None]:
  # Define the dataset catalog
# Directory holding the parquet dataset assets (located outside this repository).
# NOTE(review): `parents[3]` hard-codes the on-disk layout relative to the project
# root and raises IndexError if the root is fewer than four levels deep.
assets_dir = (root_dir.parents[3] / 'aistraj' / 'bin' / 'tvt_assets').resolve()
print(f"Assets Directory: {assets_dir}")
if not assets_dir.exists():
    raise FileNotFoundError(f'Assets directory not found: {assets_dir}')

# Directory where model checkpoints are saved to / loaded from.
saved_model_dir = (root_dir / 'models').resolve()
print(f"Model Directory: {saved_model_dir}")
if not saved_model_dir.exists():
    raise FileNotFoundError(f'Model directory not found: {saved_model_dir}')

In [4]:
def load_datasets(assets_dir, seq_len=960, batch_size=32, num_workers=2):
    """Build the train, validation and synthetic-OOD trajectory datasets and their DataLoaders.

    Args:
        assets_dir: pathlib.Path containing ``extended/cleaned_extended_train_df.parquet``
            and ``extended/cleaned_extended_validate_df.parquet``.
        seq_len: Sequence length passed to ``TrajectoryDataset``; for the train and
            validation sets it is also passed as ``filter_less_seq_len``.
        batch_size: Batch size of all three DataLoaders.
        num_workers: Worker processes per DataLoader (defaults to the previous fixed value 2).

    Returns:
        ``(train_loader, val_loader, val_ood_loader, n_features,
        train_dataset, val_dataset, val_ood_dataset)`` where ``n_features`` is
        ``train_dataset.n_features``.
    """
    extended_dir = assets_dir / 'extended'

    # Raw train / validation frames.
    train_df_extend = load_df_to_dataset(extended_dir / 'cleaned_extended_train_df.parquet').data
    validate_df_extend = load_df_to_dataset(extended_dir / 'cleaned_extended_validate_df.parquet').data

    # Synthetic OOD frame derived from the raw (not outlier-cleaned) validation frame.
    ood_df = generate_ood_data(validate_df_extend, ood_mean=10, ood_std=3)

    # Columns excluded from the model inputs.
    drop_features_list = ['epoch', 'datetime', 'obj_id', 'traj_id', 'stopped', 'curv', 'abs_ccs']

    # Quantile-based outlier cleaning of selected columns; NaNs are kept (remove_na=False).
    columns_to_clean = ['speed_c', 'lon', 'lat']
    cleaned_train_data = clean_outliers_by_quantile(train_df_extend, columns_to_clean, remove_na=False)
    cleaned_val_data = clean_outliers_by_quantile(validate_df_extend, columns_to_clean, remove_na=False)

    # NOTE(review): train and validation each get their own QuantileTransformer, so the
    # validation features are not scaled with the training statistics. Confirm whether
    # TrajectoryDataset can reuse the training scaler through its `scaler=` argument.
    train_dataset_traj = TrajectoryDataset(
        cleaned_train_data,
        seq_len=seq_len,
        mode='ae',
        drop_features_list=drop_features_list,
        scaler_method='QuantileTransformer',
        filter_less_seq_len=seq_len,
    )
    val_dataset_traj = TrajectoryDataset(
        cleaned_val_data,
        seq_len=seq_len,
        mode='ae',
        drop_features_list=drop_features_list,
        scaler_method='QuantileTransformer',
        filter_less_seq_len=seq_len,
    )
    # The OOD set is built without scaling ('No_Scaler') and without length filtering.
    # NOTE(review): its feature scale therefore differs from the transformed splits — confirm intended.
    val_ood_dataset_traj = TrajectoryDataset(
        ood_df,
        seq_len=seq_len,
        mode='ae',
        drop_features_list=drop_features_list,
        filter_percent=None,
        scaler=None,
        filter_less_seq_len=None,
        scaler_method='No_Scaler',
    )

    def _make_loader(dataset, shuffle):
        """DataLoader with the settings shared by all three splits."""
        return DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            shuffle=shuffle,
            pin_memory=False,
        )

    train_dataloader_traj = _make_loader(train_dataset_traj, shuffle=True)
    val_dataloader_traj = _make_loader(val_dataset_traj, shuffle=False)
    val_ood_dataloader_traj = _make_loader(val_ood_dataset_traj, shuffle=False)

    return (
        train_dataloader_traj,
        val_dataloader_traj,
        val_ood_dataloader_traj,
        train_dataset_traj.n_features,
        train_dataset_traj,
        val_dataset_traj,
        val_ood_dataset_traj,
    )

In [None]:
# Build the datasets / dataloaders; `input_dim` is the number of input features per time step.
(
    train_dataloader_traj,
    val_dataloader_traj,
    val_ood_dataloader_traj,
    input_dim,
    train_dataset_traj,
    val_dataset_traj,
    val_ood_dataset_traj,
) = load_datasets(assets_dir)

In [None]:
# Sanity check: smallest value present in the training input masks.
train_dataset_traj.input_masks.min()

In [None]:
# Shape of the validation model inputs.
val_dataset_traj.inputs.shape

In [None]:
# Shape of the synthetic-OOD model inputs.
val_ood_dataset_traj.inputs.shape

In [None]:
# Shape of the training model inputs.
train_dataset_traj.inputs.shape

In [None]:
# Identifier / position columns of the validation labels.
val_dataset_traj.labels[['traj_id', 'lon', 'lat', 'obj_id', 'datetime']]

In [None]:
# Identifier / position columns of the OOD dataframe.
val_ood_dataset_traj.dataframe[['traj_id', 'lon', 'lat', 'obj_id', 'datetime']]

In [None]:
# Identifier / position columns of the validation dataframe.
val_dataset_traj.dataframe[['traj_id', 'lon', 'lat', 'obj_id', 'datetime']]

In [None]:
# Training-set columns whose distribution and maximum are inspected below.
INSPECT_COLUMNS = ['epoch', 'stopped', 'cog_c', 'aad', 'rot_c', 'speed_c', 'distance_c',
                   'acc_c', 'cdd', 'dir_ccs', 'dist_ww', 'dist_ra',
                   'dist_cl', 'dist_ma', 'traj_id', 'lon', 'lat', 'obj_id',
                   'season', 'part_of_day', 'month_sin', 'month_cos', 'hour_sin',
                   'hour_cos']


def max_context_rows(df, key, window=3, context_cols=('obj_id', 'traj_id', 'lon', 'lat', 'datetime')):
    """Return the rows of ``df`` within ``window`` positions of the maximum of column ``key``.

    Selection is positional (``Series.argmax`` + ``iloc``), so at most
    ``2 * window + 1`` rows are returned even when the index has duplicate labels
    (a label-based lookup would match every row sharing a label).
    """
    max_pos = int(df[key].argmax())  # position of the first maximum; NaNs are skipped
    start = max(0, max_pos - window)
    return df[[key, *context_cols]].iloc[start:max_pos + window + 1]


for key in INSPECT_COLUMNS:
    print(key)
    train_col = train_dataset_traj.dataframe[key]

    fig, ax = plt.subplots()
    train_col.hist(bins=30, grid=False, edgecolor='black', ax=ax)
    ax.set(title=f'Training distribution of {key}', xlabel=key, ylabel='count')

    max_index = train_col.idxmax()
    max_value = train_col.max()
    # The last printed value counts *validation* rows equal to the training maximum.
    print(f"max index and value: {max_index}", max_value, (val_dataset_traj.dataframe[key] == max_value).sum())
    print(max_context_rows(train_dataset_traj.dataframe, key))
    plt.show()
    plt.close(fig)  # avoid accumulating one open figure per column

In [6]:
def initialize_weights(m):
    """Initialise a single module in place; intended for ``nn.Module.apply``.

    - ``nn.Linear`` / ``nn.Conv1d``: Kaiming-uniform weights (fan_in, ReLU gain), zero bias.
    - ``nn.LayerNorm`` / ``nn.BatchNorm1d``: weight 1, bias 0.

    Other module types are left untouched. Missing parameters (``bias=False``, or
    norm layers built with ``elementwise_affine=False`` / ``affine=False``, whose
    ``weight``/``bias`` are None) are skipped instead of raising.
    """
    if isinstance(m, (nn.Linear, nn.Conv1d)):
        nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d)):
        if m.weight is not None:
            nn.init.constant_(m.weight, 1)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)


class TransformerAutoEncoder(nn.Module):
    """Denoising Transformer autoencoder for multivariate sequences.

    Pipeline: optional training-time Gaussian noise -> ``input_fc`` projection to
    ``d_model`` -> fixed sinusoidal positional encoding -> zeroing of padded steps ->
    ``nn.TransformerEncoder`` -> ``nn.TransformerDecoder`` (the encoder output is used
    as both target and memory) -> ``output_fc`` projection -> zeroing of padded steps.

    Args:
        input_dim: Number of features per time step.
        d_model: Transformer hidden size (must be divisible by ``nhead``).
        nhead: Number of attention heads.
        num_encoder_layers: Depth of the encoder stack.
        num_decoder_layers: Depth of the decoder stack.
        dim_feedforward: Hidden size of the Transformer feed-forward blocks.
        max_seq_length: Longest sequence covered by the positional encoding.
        dropout_rate: Dropout inside the Transformer layers.
        output_dim: Features per reconstructed time step; defaults to ``input_dim``.
    """

    def __init__(self, input_dim, d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, max_seq_length, dropout_rate, output_dim=None):
        super().__init__()
        self.d_model = d_model
        self.max_seq_length = max_seq_length
        # Always defined (it used to exist only when output_dim was passed explicitly).
        self.output_dim = output_dim or input_dim

        # Fixed (non-trainable) positional encoding, shape (1, max_seq_length, d_model).
        self.positional_encoding = nn.Parameter(self._generate_positional_encoding(self.max_seq_length, d_model), requires_grad=False)

        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout_rate, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)

        decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout_rate, batch_first=True)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)

        # Projections between feature space and model space.
        self.input_fc = nn.Linear(input_dim, d_model)
        self.output_fc = nn.Linear(d_model, self.output_dim)

        # NOTE(review): the bottleneck path below is not used by forward(); it is kept so
        # existing checkpoints (state_dict keys) still load. Its parameter count scales
        # with max_seq_length * d_model.
        self.bottleneck = nn.Linear(self.max_seq_length * d_model, 10)
        self.reconstruct = nn.Linear(10, self.max_seq_length * d_model)
        self.flatten = nn.Flatten(start_dim=1, end_dim=2)
        self.unflatten = nn.Unflatten(dim=1, unflattened_size=(self.max_seq_length, d_model))

        self.apply(initialize_weights)

    def forward(self, src, padding_mask=None, return_latent=False, noise_factor=0.05):
        """Reconstruct ``src``.

        Args:
            src: Input of shape (batch, seq_len, input_dim); seq_len must not exceed
                ``max_seq_length``.
            padding_mask: Optional mask of valid positions (truthy / non-zero = keep),
                either per time step (batch, seq_len) or per feature
                (batch, seq_len, input_dim). In the per-feature form a time step is
                valid if any of its features is non-zero.
            return_latent: Unused; kept for backward compatibility.
            noise_factor: Std of the Gaussian noise added to ``src``. Noise is only
                applied in training mode, so evaluation is deterministic.

        Returns:
            ``(reconstruction, encoded_memory)`` with shapes
            (batch, seq_len, output_dim) and (batch, seq_len, d_model).
        """
        x = src
        if self.training and noise_factor:
            x = x + torch.randn_like(x) * noise_factor
        x = self.input_fc(x)
        x = x + self.positional_encoding[:, :x.size(1), :]

        step_mask = None
        if padding_mask is not None:
            step_mask = padding_mask.bool()
            if step_mask.dim() == 3:
                # Per-feature mask -> per-time-step mask.
                step_mask = step_mask.any(dim=-1)
            x = x.masked_fill(~step_mask.unsqueeze(-1), 0.0)

        # NOTE(review): no src_key_padding_mask is given, so padded (zeroed) steps still
        # take part in attention.
        encoded_memory = self.transformer_encoder(x)
        decoded_output = self.transformer_decoder(encoded_memory, encoded_memory)
        decoded_output = self.output_fc(decoded_output)

        if step_mask is not None:
            decoded_output = decoded_output * step_mask.unsqueeze(-1).to(decoded_output.dtype)

        return decoded_output, encoded_memory

    def _generate_positional_encoding(self, length, d_model):
        """Sinusoidal positional encoding of shape (1, length, d_model); works for odd d_model."""
        position = torch.arange(length, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32) * -(math.log(10000.0) / d_model))
        pos_encoding = torch.zeros(length, d_model)
        pos_encoding[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine columns.
        pos_encoding[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        return pos_encoding.unsqueeze(0)  # add batch dimension

In [7]:
def train(model, dataloader, criterion, optimizer, device, num_epochs=20):
    """Train ``model`` for ``num_epochs`` epochs and return the mean loss of each epoch.

    Args:
        model: Called as ``model(inputs, padding_mask)`` and expected to return
            ``(reconstruction, latent)``.
        dataloader: Iterable of mappings with ``'inputs'`` and, optionally,
            ``'input_masks'`` marking valid time steps — either (batch, seq_len) or
            per feature (batch, seq_len, n_features).
        criterion: Called as ``criterion(original, reconstructed, step_mask)`` with a
            (batch, seq_len) mask; must return a scalar loss tensor.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device the batches are moved to.
        num_epochs: Number of passes over ``dataloader``.

    Returns:
        List of per-epoch mean losses (0.0 for an epoch with no batches).
    """
    model.train()
    history = []
    for epoch in range(num_epochs):
        total_loss = 0.0
        n_batches = 0
        for batch in dataloader:
            batch_data = batch['inputs'].to(device)
            masks = batch.get('input_masks')
            if masks is None:
                # No padding information: every time step is valid.
                model_masks = None
                step_masks = torch.ones(batch_data.shape[:2], device=device, dtype=batch_data.dtype)
            elif masks.dim() == 2:
                step_masks = masks.to(device)
                # The model reduces its mask over the last axis, so give it a per-feature mask.
                model_masks = step_masks.unsqueeze(-1).expand_as(batch_data)
            else:
                model_masks = masks.to(device)
                step_masks = model_masks.bool().any(dim=-1).to(batch_data.dtype)

            optimizer.zero_grad()
            output, _ = model(batch_data, model_masks)
            loss = criterion(batch_data, output, step_masks)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            n_batches += 1

        avg_loss = total_loss / max(n_batches, 1)
        history.append(avg_loss)
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")
    return history

In [12]:
def save_model(model, path):
    """Write ``model``'s ``state_dict`` (weights only, not the module) to ``path`` and report it."""
    weights = model.state_dict()
    torch.save(weights, path)
    print(f"Model saved to {path}")
    
def load_model(model, path, device='cpu'):
    """Load a ``state_dict`` stored at ``path`` into ``model`` and move the model to ``device``.

    The model is updated in place; nothing is returned.
    """
    state = torch.load(path, map_location=device)
    model.load_state_dict(state)
    model.to(device)
    print(f"Model loaded from {path}")

def masked_reconstruction_loss(original, reconstructed, mask, offset=2.5, batch_mean=True):
    """Mean squared reconstruction error restricted to valid time steps.

    Args:
        original: Original time series, shape (batch_size, seq_len, feature_dim).
        reconstructed: Reconstructed time series, same shape as ``original``.
        mask: Mask of valid time steps, shape (batch_size, seq_len). A per-feature
            mask of shape (batch_size, seq_len, feature_dim) is also accepted; a time
            step is then valid if any of its features is non-zero.
        offset: Constant added to the returned loss (does not affect gradients).
            NOTE(review): the purpose of the 2.5 default is undocumented — confirm.
        batch_mean: If true, average over the batch and return a scalar; otherwise
            return one value per sample.

    Returns:
        Scalar tensor, or a (batch_size,) tensor when ``batch_mean`` is false.
        Samples without any valid time step yield ``0 + offset`` instead of NaN.
    """
    # Squared error averaged over the feature axis -> (batch_size, seq_len).
    error = F.mse_loss(reconstructed, original, reduction="none").mean(dim=-1)

    if mask.dim() == error.dim() + 1:
        mask = mask.bool().any(dim=-1)
    mask = mask.to(error.dtype)

    # Weighted average over valid steps; the clamp only matters for fully padded
    # samples, whose numerator is 0, and avoids 0/0 = NaN (in value and gradient).
    valid_steps = mask.sum(dim=1).clamp_min(torch.finfo(error.dtype).eps)
    loss = (error * mask).sum(dim=1) / valid_steps
    if batch_mean:
        loss = loss.mean()

    return loss + offset


def reconstruction_error_per_sample(x, x_reconstructed, reduction='mean'):
    """Squared reconstruction error collapsed to one value per sample.

    Args:
        x: Original tensor; reduced over axes 1 and 2, so assumed to be
            (batch_size, seq_len, channels).
        x_reconstructed: Reconstruction with exactly the same shape as ``x``.
        reduction: ``'mean'`` or ``'sum'`` over the seq_len and channel axes.

    Returns:
        Tensor of shape (batch_size,).

    Raises:
        AssertionError: if the shapes differ.
        ValueError: for any other ``reduction``.
    """
    assert x.shape == x_reconstructed.shape, "Input and reconstructed tensors must have the same shape"

    squared_diff = (x - x_reconstructed) ** 2

    for name, reducer in (('mean', torch.mean), ('sum', torch.sum)):
        if reduction == name:
            return reducer(squared_diff, dim=(1, 2))
    raise ValueError("Reduction must be 'mean' or 'sum'")
    
def get_latent_space(model, dataloader, device):
    """Run ``model`` over ``dataloader`` and collect latents plus per-sample reconstruction errors.

    Args:
        model: Called as ``model(inputs, padding_mask)`` and expected to return
            ``(reconstruction, latent)``.
        dataloader: Iterable of mappings with ``'inputs'`` and, optionally,
            ``'input_masks'`` marking valid time steps — either (batch, seq_len) or
            per feature (batch, seq_len, n_features).
        device: torch device (or device string) to run on.

    Returns:
        ``(latents, recon_error)``: the per-batch latents stacked with ``np.vstack``,
        and a Python list with one ``masked_reconstruction_loss(..., batch_mean=False)``
        value per sample.
    """
    model.eval()
    latents = []
    recon_error = []
    with torch.no_grad():
        for batch in dataloader:
            batch_data = batch['inputs'].to(device)
            masks = batch.get('input_masks')
            if masks is None:
                # No padding information: every time step is valid.
                model_masks = None
                step_masks = torch.ones(batch_data.shape[:2], device=device, dtype=batch_data.dtype)
            elif masks.dim() == 2:
                step_masks = masks.to(device)
                # The model reduces its mask over the last axis, so a (batch, seq_len)
                # mask must be expanded per feature before it is passed in.
                model_masks = step_masks.unsqueeze(-1).expand_as(batch_data)
            else:
                model_masks = masks.to(device)
                step_masks = model_masks.bool().any(dim=-1).to(batch_data.dtype)

            recon_data, latent = model(batch_data, model_masks)
            latents.append(latent.cpu().numpy())
            loss = masked_reconstruction_loss(batch_data, recon_data, step_masks, batch_mean=False)
            recon_error.extend(loss.cpu().numpy())

    latents = np.vstack(latents)
    print("Get Latent Space Done")
    return latents, recon_error

In [13]:
# --- Reproducibility ----------------------------------------------------------
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# --- Model / training configuration -------------------------------------------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Number of input features comes from the training data (returned by `load_datasets`)
# instead of a hard-coded 20 that breaks whenever the feature set changes.
input_dim = train_dataset_traj.n_features
hidden_dim = 50   # NOTE(review): not used by TransformerAutoEncoder
latent_dim = 8    # NOTE(review): not used by TransformerAutoEncoder
num_layers = 2    # NOTE(review): not used by TransformerAutoEncoder
batch_size = 32   # NOTE(review): not passed to load_datasets (it uses its own default)
num_epochs = 20
max_seq_length = 960  # must match the `seq_len` used to build the datasets

model = TransformerAutoEncoder(
    input_dim=input_dim,
    d_model=8,
    nhead=4,
    num_encoder_layers=2,
    num_decoder_layers=2,
    dim_feedforward=32,
    max_seq_length=max_seq_length,
    dropout_rate=0.1,
).to(device)
criterion = masked_reconstruction_loss
optimizer = optim.Adam(model.parameters(), lr=0.0001)

# `saved_model_dir` is a pathlib.Path, so join with `/` (`+ str` raised TypeError).
# NOTE(review): the file name says "50e" while num_epochs is 20; kept unchanged so
# an existing checkpoint with this name still loads.
save_model_path = saved_model_dir / 'dae_model_qt_960_50e.pth'
# Train the model


In [None]:
# Train from the current weights, then persist them to `save_model_path`.
train(model, train_dataloader_traj, criterion, optimizer, device=device, num_epochs=num_epochs)
save_model(model, save_model_path)

In [None]:
# Validation latents and per-sample reconstruction errors. The third argument is the
# device (passing `criterion` there made the `.to()` calls fail).
latent_representations, recon_error = get_latent_space(model, val_dataloader_traj, device)

In [None]:
# Shape of the stacked validation latents.
latent_representations.shape

In [None]:
# Number of per-sample reconstruction errors.
len(recon_error)

In [None]:
# Reconstruction error of the first validation sample.
recon_error[0]

In [None]:
# 1-D t-SNE embeddings of label columns, used as colour scales in the latent-space plots:
# `label_pos` from (lon, lat) and `label_obj` from (traj_id, obj_id).
label_pos, label_obj = (
    TSNE(n_components=1, random_state=42).fit_transform(val_dataset_traj.labels[columns])
    for columns in (['lon', 'lat'], ['traj_id', 'obj_id'])
)

In [None]:
# Distribution of per-sample validation reconstruction errors.
plt.hist(recon_error, bins=10, edgecolor='black', alpha=0.7)

In [None]:
# 99th percentile of the validation reconstruction error; passed below as the plot threshold.
percentile_99 = np.percentile(recon_error, 99)
print(percentile_99)

In [None]:
# Largest validation reconstruction error.
max(recon_error)

In [None]:
# Validation latents coloured by reconstruction error, thresholded at the 99th percentile.
plot_tsne_with_uncertainty(latent_representations, recon_error, uncertainty_type='recon_error', threshold = percentile_99)

In [None]:
# Validation latents coloured by the 1-D t-SNE of (lon, lat).
plot_tsne_with_uncertainty(latent_representations, label_pos, uncertainty_type='pos')

In [None]:
# Validation latents coloured by the 1-D t-SNE of (traj_id, obj_id).
plot_tsne_with_uncertainty(latent_representations, label_obj, uncertainty_type='obj')

In [None]:
# Latents / errors of the synthetic OOD set (third argument is the device, not the
# criterion), then combined with the validation results.
latent_representations_ood, recon_error_ood = get_latent_space(model, val_ood_dataloader_traj, device)
combined_latent_representations = np.concatenate([latent_representations, latent_representations_ood], axis=0)
combined_recon_error = recon_error + recon_error_ood  # list concatenation: get_latent_space returns lists


In [None]:
# 99th-percentile threshold recomputed over validation + OOD errors, then plotted.
percentile_99 = np.percentile(combined_recon_error, q=99)
print(percentile_99)
plot_tsne_with_uncertainty(
    combined_latent_representations,
    combined_recon_error,
    uncertainty_type='recon_error',
    threshold=percentile_99,
)

In [None]:
# Membership labels for the combined set: 0 = validation sample, 1 = synthetic OOD sample.
ood_labels = [0 for _ in range(len(latent_representations))] + [1 for _ in range(len(latent_representations_ood))]
plot_tsne_with_uncertainty(
    combined_latent_representations,
    ood_labels,
    uncertainty_type='ood label',
)

In [None]:
# Reload the saved checkpoint onto the active device (a hard-coded 'cuda' failed on CPU-only machines).
load_model(model, save_model_path, device=device)

In [None]:
# Validation latents coloured by each label column in turn.
label_columns = ['epoch', 'stopped', 'cog_c', 'aad', 'rot_c', 'speed_c', 'distance_c',
                 'acc_c', 'cdd', 'dir_ccs', 'dist_ww', 'dist_ra',
                 'dist_cl', 'dist_ma', 'traj_id', 'lon', 'lat', 'obj_id', 'datetime',
                 'season', 'part_of_day', 'month_sin', 'month_cos', 'hour_sin',
                 'hour_cos']
for key in label_columns:
    plot_tsne_with_uncertainty(latent_representations, val_dataset_traj.labels[key], uncertainty_type=key)

In [None]:
# Euclidean norm of (lon, lat) in the label units, i.e. distance from coordinate (0, 0).
# NOTE(review): this is not a geodesic distance — confirm it is the intended colour scale.
dist = np.sqrt(val_dataset_traj.labels["lon"]**2 + val_dataset_traj.labels["lat"]**2)
# Explicit label; the previous code reused the leftover loop variable `key` from another cell.
plot_tsne_with_uncertainty(latent_representations, dist, uncertainty_type='dist')