# <a id='toc1_'></a>[Training](#toc0_)

**Table of contents**<a id='toc0_'></a>    
1. [Import dependencies](#toc1)    
2. [Load scaled parquet into Pandas DataFrame](#toc2_)
3. [Dataloader](#toc3_)
4. [Training](#toc4_)
5. [Plotting](#toc5_)

## 1. <a id='toc1'></a>[Import dependencies](#toc1)

In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules (e.g. the local `src` package) are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
##> import libraries
import sys
from pathlib import Path
# The project root is taken to be the parent of the current working directory
# (the notebook is expected to be launched from a sub-folder of the project).
root_dir = Path.cwd().resolve().parent
if not root_dir.exists():
    raise FileNotFoundError('Root directory not found')
# Make the project root importable so `src.*` modules can be resolved.
sys.path.append(str(root_dir))

#> import custom libraries
from src.load import load_df_to_dataset
from src.models import TransformerDenoiseAutoEncoder
from src.train import train_and_evaluate
from src.traj_dataloader import (TrajectoryDataset, 
                                 DenoiseAutoencoderSequencedDataset
                                 )
from src.plot import  plot_losses


#> torch libraries
import torch
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn as nn


#> Plot
import matplotlib.pyplot as plt
import seaborn as sns
# NOTE(review): `scienceplots` is never referenced by name in this notebook;
# presumably it is imported for its side effect of registering the
# 'science' / 'grid' / 'notebook' styles with matplotlib — confirm before
# removing it as an "unused" import.
import scienceplots  # https://github.com/garrettj403/SciencePlots?tab=readme-ov-file
plt.style.use(['science', 'grid', 'notebook'])  # , 'ieee'


# Interactive (widget) figures are enabled; the inline backend is kept
# commented out as the static alternative.
# %matplotlib inline
%matplotlib widget

In [3]:
def _require_dir(path, description):
    """Resolve ``path`` and make sure it points at an existing directory.

    Args:
        path: Directory location (``pathlib.Path`` or string).
        description: Human-readable label used as the prefix of the error
            message, e.g. ``'Assets directory'``.

    Returns:
        The resolved ``Path``.

    Raises:
        FileNotFoundError: If the resolved path does not exist or is not a
            directory.
    """
    resolved = Path(path).resolve()
    if not resolved.is_dir():
        # Several directories are validated in this cell; naming the offending
        # path makes the failure actionable instead of ambiguous.
        raise FileNotFoundError(f'{description} not found: {resolved}')
    return resolved


# Train/validate/test assets live next to the project root, under data/local.
assets_dir = _require_dir(
    root_dir.parent / 'data' / 'local' / 'aistraj' / 'tvt_assets',
    'Assets directory',
)

# Destination handed to the training routine for saved models/pickles.
models_dir = _require_dir(root_dir / 'models' / 'sga', 'Models directory')

# Sub-folders of the assets directory, one per processing stage.
scaled_tvt_data_import_assets_dir = _require_dir(
    assets_dir / 'scaled',
    'Scaled train-validate-test data directory',
)
extend_tvt_data_import_assets_dir = _require_dir(
    assets_dir / 'extended',
    'Extended train-validate-test data directory',
)
tvt_data_import_assets_dir = _require_dir(
    assets_dir / 'original',
    'Original train-validate-test data directory',
)

## 2. <a id='toc2_'></a>[Load scaled parquet into Pandas DataFrame](#toc2_)

In [4]:
# Locations of the scaled train / validate / test splits. The variable names
# say "pickle" for historical reasons, but every file referenced here is a
# parquet file.
train_pickle_path = scaled_tvt_data_import_assets_dir.joinpath('scaled_cleaned_extended_train_df.parquet')
validate_pickle_path = scaled_tvt_data_import_assets_dir.joinpath('scaled_cleaned_extended_validate_df.parquet')
test_pickle_path = scaled_tvt_data_import_assets_dir.joinpath('scaled_cleaned_extended_test_df.parquet')

# Load the splits in train -> validate -> test order, keeping only the `.data`
# attribute of whatever `load_df_to_dataset` returns.
train_df, validate_df, test_df = (
    load_df_to_dataset(split_path).data
    for split_path in (train_pickle_path, validate_pickle_path, test_pickle_path)
)

In [6]:
# Display the column index of the loaded validation frame as a quick sanity check.
validate_df.columns

Index(['epoch', 'datetime', 'obj_id', 'traj_id', 'month_sin', 'month_cos',
       'hour_sin', 'hour_cos', 'season', 'part_of_day', 'aad', 'cdd',
       'dir_ccs', 'cog_c', 'rot_c', 'distance_c', 'dist_ww', 'dist_ra',
       'dist_cl', 'dist_ma', 'speed_c', 'acc_c', 'lon', 'lat'],
      dtype='object')

## 3. <a id='toc3_'></a>[Dataloader](#toc3_)

In [10]:
batch_size = 32

# Columns handed to the dataset as the features to drop.
drop_features_list = ['epoch', 'datetime', 'obj_id', 'traj_id']

# Fixed-length (256-step) sequence datasets for the training and validation splits.
train_dataset_seq = DenoiseAutoencoderSequencedDataset(train_df, drop_features_list, seq_len=256)
val_dataset_seq = DenoiseAutoencoderSequencedDataset(validate_df, drop_features_list, seq_len=256)

# Both loaders share exactly the same settings.
# NOTE(review): shuffle=False is also used for the training loader — confirm
# that iterating the training sequences in a fixed order is intentional.
_loader_options = dict(batch_size=batch_size, num_workers=16, shuffle=False, pin_memory=True)
train_dataloader_seq = DataLoader(train_dataset_seq, **_loader_options)
val_dataloader_seq = DataLoader(val_dataset_seq, **_loader_options)

# One-shot summary of both datasets, one item per output line.
print(
    f'feature_columns: {train_dataset_seq.feature_columns}',
    f'n_features: {train_dataset_seq.n_features}',
    f'row of train_dataset: {train_dataset_seq.total_sequences}',
    f'row of train_df: {train_dataset_seq.l_dataset}',
    f'padding need for train_df: {train_dataset_seq.padding_needed}',
    '-' * 40,
    f'feature_columns: {val_dataset_seq.feature_columns}',
    f'n_features: {val_dataset_seq.n_features}',
    f'row of val_dataset: {val_dataset_seq.total_sequences}',
    f'row of val_df: {val_dataset_seq.l_dataset}',
    f'padding need for val_df: {val_dataset_seq.padding_needed}',
    sep='\n',
)


feature_columns: Index(['month_sin', 'month_cos', 'hour_sin', 'hour_cos', 'season',
       'part_of_day', 'aad', 'cdd', 'dir_ccs', 'cog_c', 'rot_c', 'distance_c',
       'dist_ww', 'dist_ra', 'dist_cl', 'dist_ma', 'speed_c', 'acc_c', 'lon',
       'lat'],
      dtype='object')
n_features: 20
row of train_dataset: 57444
row of train_df: 14705500
padding need for train_df: 164
----------------------------------------
feature_columns: Index(['month_sin', 'month_cos', 'hour_sin', 'hour_cos', 'season',
       'part_of_day', 'aad', 'cdd', 'dir_ccs', 'cog_c', 'rot_c', 'distance_c',
       'dist_ww', 'dist_ra', 'dist_cl', 'dist_ma', 'speed_c', 'acc_c', 'lon',
       'lat'],
      dtype='object')
n_features: 20
row of val_dataset: 14700
row of val_df: 3763005
padding need for val_df: 195


## 4. <a id='toc4_'></a>[Training](#toc4_)

In [1]:
# --- Reproducibility -------------------------------------------------------
# Seed torch's random number generators before the model is built so weight
# initialisation and dropout are repeatable across Restart & Run All.
# NOTE(review): if the dataset or training code draws noise from `random` or
# NumPy, those generators need seeding as well — confirm in `src`.
SEED = 42
torch.manual_seed(SEED)

# --- Model hyper-parameters -------------------------------------------------
input_dim = val_dataset_seq.n_features  # number of feature columns fed to the model
d_model = 4  # Transformer model dimensions
# NOTE(review): d_model == nhead leaves a single dimension per attention head
# if the model uses standard multi-head attention — confirm this is intended.
nhead = 4  # Number of heads of multi-attention mechanisms
num_encoder_layers = 1  # Number of encoder layers
num_decoder_layers = 1  # Number of decoder layers
dim_feedforward = 256  # feedforward network dimension
# Same value as the `seq_len=256` used when the sequence datasets were built;
# keep the two in sync.
max_seq_length = 256
dropout_rate = 0.1

model_transformerDAE = TransformerDenoiseAutoEncoder(
    input_dim,
    d_model,
    nhead,
    num_encoder_layers,
    num_decoder_layers,
    dim_feedforward,
    max_seq_length,
    dropout_rate,
)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

model_transformerDAE = model_transformerDAE.to(device)

# Loss function: mean squared error between the criterion's two inputs.
criterion = nn.MSELoss()

# Optimizer
optimizer = optim.Adam(model_transformerDAE.parameters(), lr=0.001, weight_decay=1e-5)

# Start training. No test loader is supplied (mode='train'); models and
# pickles are both written to `models_dir`.
(
    train_losses_transformerDAE,
    eval_losses_transformerDAE,
    encoded_features_transformerDAE,
    all_inputs_transformerDAE,
    all_reconstructions_transformerDAE,
) = train_and_evaluate(
    model=model_transformerDAE,
    train_dataloader=train_dataloader_seq,
    eval_dataloader=val_dataloader_seq,
    test_dataloader=None,
    optimizer=optimizer,
    criterion=criterion,
    model_save_path=models_dir,
    pickle_save_path=models_dir,
    model_name='transformer_denoiseautoencoder_model_parquet_',
    epochs=100,
    mode='train',
    patience=10,
)

## 5. <a id='toc5_'></a>[Plotting](#toc5_)

In [2]:
# Visualise the training and evaluation losses returned by `train_and_evaluate`.
plot_losses(train_losses_transformerDAE, eval_losses_transformerDAE)
