In [None]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from statistics import mean

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import seaborn as sns

from himodule.custom_classes import NasaDataset, SimpleAE, LossAndMetric
from himodule.ae_metrics import MAPE
from himodule.normalisation import StandardScaler, MinMaxScaler
from himodule.secondary_funcs import save_object, check_path, split_dataset, seed_everything

import os

sns.set_theme(style='whitegrid', font_scale=1.2)

In [None]:
# Select the compute device: CUDA when PyTorch reports an available GPU, CPU otherwise.

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f'{device=}')

### Training

In [None]:
seed = 37
batch_size = 20


def seed_worker(worker_id: int) -> None:
    '''Seeds NumPy and `random` inside a DataLoader worker process.

    DataLoader invokes `worker_init_fn(worker_id)` in each worker subprocess,
    so the argument must be a callable (the seed integer itself is not).
    The per-worker seed is derived from the seed PyTorch assigned to the worker.
    '''
    import random

    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


# Whole dataset loading
train_dataset = NasaDataset('../datasets/clean_train_data.csv')
seed_everything(seed)
# Get datasets for training and validation (30% goes to validation)
train_dataset, val_dataset = split_dataset(train_dataset, test_size=0.3)

# Test dataset loading
test_dataset = NasaDataset('../datasets/clean_test_data.csv')

# Optional normalisation object (StandardScaler / MinMaxScaler are imported above);
# None disables scaling.
scaler = None
# Folder name derived from the scaler's repr: for a default object repr such as
# `<pkg.module.Class object at 0x...>` this yields `Class`.
# NOTE(review): with scaler = None this evaluates to 'None' - str.split never returns
# an empty list, so the 'no_scaling' fallback below is not reached. Kept unchanged
# because the name is part of the model / plot output paths.
try:
    norm_name = repr(scaler).split(' ', maxsplit=2)[0].split('.')[-1]
except IndexError:
    norm_name = 'no_scaling'
for idx, dtset in enumerate((train_dataset, val_dataset, test_dataset)):
    dtset.to(device)
    if scaler:
        # Fit on the training split only (idx == 0), then apply to every split.
        if idx == 0:
            scaler.fit(dtset.dataset)
        dtset.dataset = scaler.transform(dtset.dataset)

# Save trained scaler to use it in other files (an existing file is not overwritten)
scaler_path = os.path.join('../scalers/', f'{norm_name}.pkl')
if scaler and not os.path.exists(scaler_path):
    save_object(scaler, scaler_path)

seed_everything(seed)
g = torch.Generator()
g.manual_seed(seed)
# Each loader wraps its own split: train -> train_dataset, val -> val_dataset,
# test -> test_dataset. Only the training loader shuffles; evaluation batches keep
# a fixed order so per-batch losses/metrics are reproducible.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                          worker_init_fn=seed_worker, generator=g)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                        worker_init_fn=seed_worker, generator=g)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                         worker_init_fn=seed_worker, generator=g)

print(f'Train: {len(train_dataset)}\nValidation: {len(val_dataset)}\nTest: {len(test_dataset)}')

input_shape = train_dataset.get_input_shape()
layers_sizes = (8, 4, 2)

# Model creating (re-seed so weight initialisation does not depend on the code above)
seed_everything(seed)
model_ae = SimpleAE(input_shape, layers_sizes).to(device)
loss_func = nn.MSELoss()
metric_func = MAPE()
# Callable returning (loss, metric) for a (reconstruction, target) pair; it also receives the scaler.
get_loss_and_metric = LossAndMetric(loss_func, metric_func, scaler)
optimiser = optim.AdamW(model_ae.parameters(),
                       lr=1e-3)
# First token of the optimiser repr, i.e. its class name; used in output paths.
optimiser_name = repr(optimiser).split(' ', maxsplit=1)[0]

In [None]:
def evaluate_model(model: nn.Module, loader: DataLoader):

    '''Returns per-batch losses and metrics of a model over a data loader.

    The model is switched to evaluation mode for the duration of the call and
    its previous training/evaluation mode is restored afterwards. Gradients are
    not tracked.

    Args:
        model: network whose forward pass returns a pair; the second element
            is used as the reconstruction of the input.
        loader: yields mapping-like batches with a 'sensors' entry.

    Returns:
        (losses, metrics): two lists of Python floats, one value per batch.

    Relies on the notebook-level `device` and `get_loss_and_metric`.
    '''

    losses = []
    metrics = []

    # Remember the current mode so the caller's training loop is not affected.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for dta in loader:
                sample = dta['sensors'].to(device)
                _, reconstruction = model(sample)

                loss, metric = get_loss_and_metric(reconstruction, sample)

                losses.append(loss.item())
                metrics.append(metric.item())
    finally:
        model.train(was_training)
    return losses, metrics


def plot_history(loss_history_df: pd.DataFrame, metric_history_df: pd.DataFrame,
                 test_losses: list, test_metrics: list,
                 ylabel_loss: str = None, ylabel_metric: str = None, save_path: str = None,
                 metric_ylim: tuple = (0, 20), test_metric_ylim: tuple = (0, 100)):

    '''Plots training history next to the distribution of test results.

    Panels:
        top-left     - loss history, one line per column of `loss_history_df`;
        top-right    - box plot of `test_losses`;
        bottom-left  - metric history, one line per column of `metric_history_df`;
        bottom-right - box plot of `test_metrics`.

    Args:
        loss_history_df: wide frame indexed by epoch, one column per loss series.
        metric_history_df: wide frame indexed by epoch, one column per metric series.
        test_losses: sequence of test loss values.
        test_metrics: sequence of test metric values.
        ylabel_loss: y-axis label of the loss history panel.
        ylabel_metric: y-axis label of the metric history panel.
        save_path: if given, the figure is written to this path before being shown.
        metric_ylim: (low, high) y-limits of the metric history panel; None keeps autoscaling.
        test_metric_ylim: (low, high) y-limits of the test metric panel; None keeps autoscaling.
    '''

    def _to_long(history_df: pd.DataFrame) -> pd.DataFrame:
        '''Drops the first epoch of every series and converts the frame to long format.'''
        # Slice BEFORE melting: after melt the rows of all series are stacked, so
        # slicing afterwards would remove the first epoch of the first series only.
        if len(history_df) > 1:
            history_df = history_df.iloc[1:]
        return history_df.melt(ignore_index=False)

    loss_history_df = _to_long(loss_history_df)
    metric_history_df = _to_long(metric_history_df)

    # Plain figure: the layout is finalised by tight_layout() below. Requesting a
    # constrained layout here as well would be overridden and trigger a warning.
    fig = plt.figure()
    fig.set_size_inches(10, 10)

    # 2 rows x 5 columns: history panels take 4 columns, box plots take 1.
    gs = GridSpec(2, 5, figure=fig)
    ax1 = fig.add_subplot(gs[0, 0:4])
    ax2 = fig.add_subplot(gs[0, 4])

    ax3 = fig.add_subplot(gs[1, 0:4])
    ax4 = fig.add_subplot(gs[1, 4])

    sns.lineplot(data=loss_history_df,
                x=loss_history_df.index,
                y='value',
                hue='variable',
                ax=ax1)
    ax1.set_ylabel(ylabel_loss)

    sns.boxplot(x=['test_loss']*len(test_losses),
                y=test_losses,
                color='g',
                ax=ax2)
    ax2.set_ylabel(None)

    sns.lineplot(data=metric_history_df,
                x=metric_history_df.index,
                y='value',
                hue='variable',
                ax=ax3)
    ax3.set_ylabel(ylabel_metric)
    if metric_ylim is not None:
        ax3.set_ylim(*metric_ylim)

    sns.boxplot(x=['test_metric']*len(test_metrics),
                y=test_metrics,
                color='g',
                ax=ax4)
    ax4.set_ylabel(None)
    if test_metric_ylim is not None:
        ax4.set_ylim(*test_metric_ylim)

    fig.tight_layout()
    if save_path:
        fig.savefig(save_path)
    plt.show()

In [None]:
epochs = 100
history = []

# Model training on normal and anomaly data.
# Every epoch: one optimisation pass over train_loader, then an evaluation pass
# over val_loader; the epoch means are appended to `history`.

for epoch in range(epochs):
    train_losses, train_metrics = [], []
    for batch in train_loader:
        sensors = batch['sensors'].to(device)
        _, reconstruction = model_ae(sensors)

        loss, metric = get_loss_and_metric(reconstruction, sensors)

        # Standard update step: clear old gradients, backpropagate, apply the update.
        optimiser.zero_grad()
        loss.backward()
        optimiser.step()

        train_losses.append(loss.item())
        train_metrics.append(metric.item())

    val_losses, val_metrics = evaluate_model(model_ae, val_loader)

    # Collapse the per-batch values into one number per epoch.
    # Note: train_metrics / val_metrics are rebound from lists to their means here.
    train_loss = mean(train_losses)
    val_loss = mean(val_losses)
    train_metrics = mean(train_metrics)
    val_metrics = mean(val_metrics)
    history.append((epoch, train_loss, val_loss, train_metrics, val_metrics))

    # Report every 10th epoch and always the final one.
    if epoch == epochs - 1 or (epoch + 1) % 10 == 0:
        print(f'{epoch+1:>3}/{epochs:>3}: {train_loss=:.4f}, {val_loss=:.4f}, {train_metrics=:.4f}%, {val_metrics=:.4f}%')

# Final evaluation; the per-batch lists are kept for the box plots.
test_losses, test_metrics = evaluate_model(model_ae, test_loader)
test_loss = mean(test_losses)
test_metric = mean(test_metrics)
print(f'\n{test_loss=:.4f}, {test_metric=:.4f}%')


# -------------------------------------------------- #
# Artefacts are named after the layer sizes and stored under
# <root>/<optimiser name>/<normalisation name>/.
name = str(layers_sizes)
models_path = '../Models/'
plot_path = '../Plots/history/'


# Save model weights (state dict only)
pth = os.path.join(models_path, optimiser_name, norm_name)
check_path(pth)  # presumably makes sure the directory exists - see himodule.secondary_funcs
model_file = os.path.join(pth, f'{name}.pth')
torch.save(model_ae.state_dict(), model_file)

# Make plots to evaluate model performance
columns = ('epoch', 'train_loss', 'val_loss', 'train_metric', 'val_metric')
total_history_df = pd.DataFrame(history, columns=columns).set_index('epoch')

pth = os.path.join(plot_path, optimiser_name, norm_name)
check_path(pth)

# Split the history into a loss frame and a metric frame, both indexed by epoch.
loss_history_df = total_history_df[['train_loss', 'val_loss']]
metric_history_df = total_history_df[['train_metric', 'val_metric']]
plot_history(
    loss_history_df,
    metric_history_df,
    test_losses,
    test_metrics,
    ylabel_loss='MSE',
    ylabel_metric='MAPE',
    save_path=os.path.join(pth, f'{name}.png'),
)