# VAE-GNA Jupyter Notebook
---------------------------------

**Autor:** Matheus Becali Rocha \
**Email:** matheusbecali@gmail.com


#### Notebook Summary

This notebook explores Variational Autoencoders (VAEs) and includes various functions, models, and optimization techniques.

Setup and Imports:
- Imports necessary libraries like `os`, `optuna`, `torch`, and `pyspectra` to handle data, models, and optimization.

Auxiliary Functions:
- **save_checkpoint**: Saves the current state of the model to a file.
- **SNV**: Standard Normal Variate function to normalize data.
- **idx2onehot**: Converts index values to one-hot encoded vectors for categorical data.

Loss Functions:
- **focal_loss**: A loss function used to handle class imbalance by focusing on hard-to-classify examples.
- **adaptative_focal_loss**: An adaptive version of the focal loss function.

Data Preparation:
- Sets up data handling by defining constants and loading the necessary datasets.
- Configures settings for data visualization to plot and analyze data distributions.

Model Definitions:
- **prepare_data_loader**: Prepares the data loader for feeding data into the model.
- **ClassifyingNetwork**: Defines a Multi-Layer Perceptron (MLP) network for classifying data.
- **AttentionLayer**: Implements an attention mechanism to improve model focus on important features.

Optimization with Optuna:
- **optuna_run**: Runs an optimization process using Optuna to find the best hyperparameters for the model.

Experiment Runs:
- Defines hyperparameters and timings for running multiple experiments.
- Includes scripts to execute and visualize results for various models and configurations.

Result Visualization:
- Plots reconstructed data to visualize how well the model has learned to replicate input data.
- Plots loss curves to show the model's training progress and performance over time.


In [None]:
import os
import optuna
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from collections import OrderedDict
import matplotlib
import csv
import pickle

# matplotlib.style.use('dark_background')

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import accuracy_score, balanced_accuracy_score, recall_score, precision_score, f1_score, roc_auc_score, confusion_matrix
from sklearn.model_selection import KFold, StratifiedKFold

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, ConcatDataset

# Set the device to GPU if available, otherwise use CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Report the selected device. An explicit check on the device type replaces the
# former broad try/except, which used an exception as control flow and would
# also have hidden unrelated CUDA errors behind the "not found" message.
if device.type == 'cuda':
    # Print the name of the CUDA device
    print(torch.cuda.get_device_name(device))
else:
    print('CUDA device not found, using CPU instead.')

# Seed value intended for reproducibility. It is only stored here; this cell
# does not pass it to any random number generator.
_seed = 78645


# Loading some functions and dataset

### Auxiliary Functions

In [None]:
import os
import csv

def save_results_to_csv(filename, results_data, header):
    """
    Append one row of results to a CSV file, writing the header first when needed.

    The header is written when the file does not exist yet or exists but is
    empty (for example, a file that was pre-created), so the resulting CSV
    always starts with its column names.

    Parameters:
    filename (str): The name of the file to which the results will be saved.
    results_data (dict): The results data to be written; keys must be in `header`.
    header (list): A list of strings representing the header of the CSV file.
    """
    # A missing file raises OSError in getsize; an empty one reports size 0.
    # Both cases need a header row before the first data row.
    try:
        needs_header = os.path.getsize(filename) == 0
    except OSError:
        needs_header = True

    # Append mode keeps the rows written by previous calls
    with open(filename, 'a', newline='') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=header)

        if needs_header:
            writer.writeheader()

        writer.writerow(results_data)


In [None]:
def SNV(input_data):
    """
    Standardise every column of the input to zero mean and unit sample variance.

    NOTE(review): textbook Standard Normal Variate normalises each spectrum
    (row). This function has always normalised each column (feature) instead,
    using the sample standard deviation (ddof=1). That behaviour is kept here;
    confirm that column-wise scaling is really what is intended.

    Non-floating input (e.g. integer counts) is converted to float64 first so
    that the result is not truncated to integers.

    Parameters:
    input_data (pd.DataFrame): Input data with rows as samples and columns as features.

    Returns:
    np.ndarray: Array of the same shape with every column standardised.
    """
    values = input_data.to_numpy()
    # The result is fractional, so it must never be stored in an integer array
    if not np.issubdtype(values.dtype, np.floating):
        values = values.astype(np.float64)

    # Per-column statistics, broadcast over the rows
    column_mean = values.mean(axis=0)
    column_std = values.std(axis=0, ddof=1)

    return (values - column_mean) / column_std

def LOaO(X):
    """
    Rescale the input linearly so that its minimum maps to -1 and its maximum to 1.

    The minimum and maximum are whatever `X.min()` and `X.max()` return for the
    given container (a single value for a NumPy array, one value per column
    for a pandas DataFrame).

    Parameters:
    X (pd.DataFrame or np.ndarray): Input data.

    Returns:
    Same container type as `X`, with values scaled to the range [-1, 1].
    """
    lowest = X.min()
    span = X.max() - lowest
    unit_scaled = (X - lowest) / span
    return 2 * unit_scaled - 1


# Modified from Pyspectra

class GeneralTransformer:
    """
    Base class for preprocessing transformers.

    Subclasses override `fit` and `transform`; `fit_transform` chains the two.
    The base implementations do nothing and return None.
    """

    def __init__(self):
        # The base class keeps no state
        return None

    def fit(self, spc):
        """
        Learn whatever parameters the transformer needs (no-op in the base class).

        Parameters:
        spc (pd.DataFrame): Input data with rows as samples and columns as features.
        """
        return None

    def transform(self, spc):
        """
        Apply the transformation (no-op in the base class, which returns None).

        Parameters:
        spc (pd.DataFrame): Input data with rows as samples and columns as features.

        Returns:
        pd.DataFrame: Transformed data (None for the base class).
        """
        return None

    def fit_transform(self, spc):
        """
        Call `fit` on the data and then return the result of `transform`.

        Parameters:
        spc (pd.DataFrame): Input data with rows as samples and columns as features.

        Returns:
        pd.DataFrame: Transformed data.
        """
        self.fit(spc)
        transformed = self.transform(spc)
        return transformed

class SNVTransformer(GeneralTransformer):
    """
    Transformer that centres and scales data with statistics learned in `fit`.

    The statistics are taken along axis 0 of the fitted data, so they can be
    learned on one data set and re-applied to another.
    """

    def __init__(self):
        """
        Create an unfitted transformer; both statistics start as None.
        """
        self.MeanSpectra, self.StdSpectra = None, None

    def fit(self, spc):
        """
        Store the axis-0 mean and standard deviation of the data.

        Parameters:
        spc (pd.DataFrame): Input data with rows as samples and columns as features.
        """
        column_mean = spc.mean(axis=0)
        self.MeanSpectra = column_mean
        column_std = spc.std(axis=0)
        self.StdSpectra = column_std

    def transform(self, spc):
        """
        Subtract the stored mean and divide by the stored standard deviation.

        Parameters:
        spc (pd.DataFrame): Input data with rows as samples and columns as features.

        Returns:
        pd.DataFrame: Centred and scaled data.
        """
        centred = spc - self.MeanSpectra
        return centred / self.StdSpectra

In [None]:
def idx2onehot(idx, n):
    """
    Build one-hot rows from a tensor of class indices.

    Parameters:
    idx (torch.Tensor): Tensor of indices, either 1-D or already a column (rows x 1).
    n (int): Number of classes (width of the one-hot rows).

    Returns:
    torch.Tensor: Tensor of shape (rows, n) with a 1 at each row's index, on the
    same device as `idx`.
    """
    assert torch.max(idx).item() < n

    # scatter_ needs a column of indices, one per output row
    column = idx.unsqueeze(1) if idx.dim() == 1 else idx

    encoded = torch.zeros(column.size(0), n).to(column.device)
    encoded.scatter_(1, column, 1)
    return encoded

def plot_gallery(spectre, epoch, fold, model_name, n_row=3, n_col=6, all_plot=False):
    """
    Plot a grid of spectra and save it under ``imgs/<model_name>/``.

    Panel 0 is titled "Input Data" and every other panel "Recon Data". The
    x axis of each panel is derived from the length of the plotted curve, so
    the function is not tied to a specific number of spectral points. The
    output directory is created when it does not exist.

    Parameters:
    spectre (list of torch.Tensor): One entry per panel; entry ``j`` is indexed by sample.
    epoch (int): Current epoch number (used in the file name).
    fold (int): Current fold number (used in the file name).
    model_name (str): Name of the model (used as the output sub-directory).
    n_row (int, optional): Number of rows in the plot. Defaults to 3.
    n_col (int, optional): Number of columns in the plot. Defaults to 6.
    all_plot (bool, optional): If True, draw every sample of every panel on a
        large figure; otherwise draw only sample 0 of each panel. Defaults to False.
    """
    n_panels = n_row * n_col

    def _draw_panel(position, series):
        # Draw one curve in the grid cell `position` (0-based)
        ax = plt.subplot(n_row, n_col, position + 1)
        ax.set_title("Input Data" if position == 0 else "Recon Data")
        ax.axis("off")
        curve = series.squeeze(0)
        # x positions follow the curve length instead of a hard-coded 125
        ax.plot(np.arange(curve.shape[0]), curve)

    if all_plot:
        plt.figure(figsize=(8 * n_col, 8 * n_row))
        for i in range(spectre[0].size()[0]):
            for j in range(n_panels):
                _draw_panel(j, spectre[j][i])
    else:
        plt.figure(figsize=(2 * n_col, 2 * n_row))
        for j in range(n_panels):
            _draw_panel(j, spectre[j][0])

    out_path = f'imgs/{model_name}/image_at_epoch_{epoch:04d}_fold_{str(fold)}.png'
    # savefig does not create missing directories
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    plt.savefig(out_path)
    plt.show()
    plt.close()

In [None]:
from torch.distributions.normal import Normal

def expected_sigm_of_norm(
    mean: torch.Tensor, std: torch.Tensor, method = 'probit'
) -> torch.Tensor:
    r"""
    Approximate the expected value of the sigmoid of a normal distribution.

    Background:
    https://math.stackexchange.com/questions/207861/expected-value-of-applying-the-sigmoid-function-to-a-normal-distribution

    The 'maclauren-*' methods are Taylor expansions of the sigmoid around `mean`:

    .. math::
        E[\sigma(X)] \approx \sigma(\mu) + \frac{\sigma''(\mu)}{2} s^2
                             + \frac{\sigma''''(\mu)}{8} s^4

    where, with :math:`u = e^{-\mu}`,
    :math:`\sigma'' = u(u-1)/(1+u)^3` and
    :math:`\sigma'''' = u(u^3 - 11u^2 + 11u - 1)/(1+u)^5`.
    The factor 1/8 of the last term is :math:`E[(X-\mu)^4]/4! = 3 s^4 / 24`.

    Parameters:
    mean (torch.Tensor): Mean of the normal distribution.
    std (torch.Tensor): Standard deviation of the normal distribution.
    method (str): Method to use for the approximation. Options are 'probit', 'maclauren-2', 'maclauren-3'.

    Returns:
    torch.Tensor: An approximation to `E(sigmoid(N(mean, std**2)))`.

    Raises:
    ValueError: If `method` is not one of the supported names.
    """
    if method in ('maclauren-2', 'maclauren-3'):
        eu = torch.exp(-mean)
        # sigmoid(mean) plus the second-order correction
        approx_exp = 1 / (eu + 1) + 0.5 * (eu - 1) * eu / ((eu + 1) ** 3) * std ** 2
        if method == 'maclauren-3':
            # Fourth-order correction: sigmoid''''(mean) * std^4 / 8
            fourth_derivative = eu * (eu ** 3 - 11 * (eu ** 2) + 11 * eu - 1) / ((eu + 1) ** 5)
            approx_exp = approx_exp + fourth_derivative / 8 * std ** 4
        # The truncated series can leave [0, 1]; an expectation of a sigmoid cannot
        return torch.clamp(approx_exp, min=0, max=1)

    if method == 'probit':
        # lambda = sqrt(pi / 8), so 1 / lambda^2 = 8 / pi
        # (the article https://arxiv.org/abs/1703.00091 suggests lambda approx 0.61)
        inv_lambd_sq = 8.0 / np.pi
        dist = Normal(loc = 0, scale = 1)
        return dist.cdf(mean / torch.sqrt(inv_lambd_sq + std ** 2))

    raise ValueError('Method "%s" not known' % method)

Focal Loss

In [None]:
def focal_loss(
    input: torch.Tensor, target: torch.Tensor, alpha: float, gamma: float = 2.0, reduction: str = 'none'
) -> torch.Tensor:
    r"""Criterion that computes Focal loss.
    According to :cite:`lin2018focal`, the Focal loss is computed as follows:
    .. math::
        \text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \text{log}(p_t)
    Where:
       - :math:`p_t` is the model's estimated probability for each class.
    Args:
        input: logits tensor with shape :math:`(N, C)` where C = number of classes.
        target: int64 labels tensor with shape :math:`(N,)` where each value is
          :math:`0 ≤ targets[i] ≤ C−1`.
        alpha: Weighting factor :math:`\alpha \in [0, 1]`.
        gamma: Focusing parameter :math:`\gamma >= 0`.
        reduction: ``'none'`` | ``'mean'`` | ``'sum'``. Any other value raises
          ``NotImplementedError``.
    Return:
        the computed loss.

    NOTE(review): the einsum below contracts the batch axis as well as the
    class axis, so for :math:`(N, C)` logits the un-reduced loss is already a
    scalar (the sum over the batch) and the three reduction modes return the
    same value. This is kept unchanged because callers may rely on receiving a
    scalar; confirm whether a per-sample loss was intended.

    Example:
        >>> input = torch.randn(4, 5, requires_grad=True)
        >>> target = torch.empty(4, dtype=torch.long).random_(5)
        >>> output = focal_loss(input, target, alpha=0.5, gamma=2.0, reduction='mean')
        >>> output.backward()

    Thanks: https://github.com/zhezh/focalloss/blob/master/focalloss.py
    """

    # compute softmax over the classes axis
    input_soft = torch.softmax(input, dim=1)
    log_input_soft = torch.log_softmax(input, dim=1)

    # create the labels one hot tensor directly with the logits' dtype and
    # device (no detour through a CPU float32 tensor)
    target_one_hot = torch.nn.functional.one_hot(target, num_classes=input.shape[1]).to(
        device=input.device, dtype=input.dtype
    )

    # compute the actual focal loss; the one-hot mask keeps the true-class term
    weight = torch.pow(-input_soft + 1.0, gamma)
    focal = -alpha * weight * log_input_soft
    loss_tmp = torch.einsum('...bc,...bc->...', target_one_hot, focal)

    if reduction == 'none':
        loss = loss_tmp
    elif reduction == 'mean':
        loss = torch.mean(loss_tmp)
    elif reduction == 'sum':
        loss = torch.sum(loss_tmp)
    else:
        raise NotImplementedError(f"Invalid reduction mode: {reduction}")
    return loss


class FocalLoss(nn.Module):
    r"""Module wrapper around :func:`focal_loss`.

    The focal loss of :cite:`lin2018focal` is
    .. math::
        \text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \text{log}(p_t)
    with :math:`p_t` the model's estimated probability for each class.

    The constructor stores the three settings; ``forward`` hands them,
    together with the logits and labels, to :func:`focal_loss`.

    Args:
        alpha: Weighting factor :math:`\alpha \in [0, 1]`.
        gamma: Focusing parameter :math:`\gamma >= 0`.
        reduction: ``'none'`` | ``'mean'`` | ``'sum'``, forwarded unchanged
          to :func:`focal_loss`.
    Shape:
        - Input: logits with the class axis at dimension 1.
        - Target: integer labels where each value is
          :math:`0 ≤ targets[i] ≤ C−1`.
    Example:
        >>> criterion = FocalLoss(alpha=0.5, gamma=2.0, reduction='mean')
        >>> input = torch.randn(4, 5, requires_grad=True)
        >>> target = torch.empty(4, dtype=torch.long).random_(5)
        >>> output = criterion(input, target)
        >>> output.backward()
    """

    def __init__(self, alpha: float, gamma: float = 2.0, reduction: str = 'none') -> None:
        super().__init__()
        # Plain Python settings; the module has no learnable parameters
        self.alpha: float = alpha
        self.gamma: float = gamma
        self.reduction: str = reduction

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        settings = (self.alpha, self.gamma, self.reduction)
        return focal_loss(input, target, *settings)

Adaptative Focal Loss

In [None]:
def adaptative_focal_loss(
    input: torch.Tensor, target: torch.Tensor, beta: torch.Tensor, alpha: float, gamma_start: float = 2.0, reduction: str = 'none'
) -> torch.Tensor:
    r"""Criterion that computes Adaptative Focal loss.

    Same as the Focal loss of :cite:`lin2018focal`,
    .. math::
        \text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \text{log}(p_t)
    except that the focusing parameter is derived from ``beta``:
    .. math::
        \gamma = \gamma_{start} / \sqrt{\beta}

    Args:
        input: logits tensor with shape :math:`(N, C)` where C = number of classes.
        target: int64 labels tensor with shape :math:`(N,)` where each value is
          :math:`0 ≤ targets[i] ≤ C−1`.
        beta: Positive scalar (tensor or plain number) that divides, through its
          square root, the starting focusing parameter.
        alpha: Weighting factor :math:`\alpha \in [0, 1]`.
        gamma_start: Focusing parameter used when ``beta == 1``.
        reduction: ``'none'`` | ``'mean'`` | ``'sum'``. Any other value raises
          ``NotImplementedError``.
    Return:
        the computed loss.

    NOTE(review): the einsum below contracts the batch axis as well as the
    class axis, so for :math:`(N, C)` logits the un-reduced loss is already a
    scalar (the sum over the batch) and the three reduction modes return the
    same value. This is kept unchanged because callers may rely on receiving a
    scalar; confirm whether a per-sample loss was intended.

    Example:
        >>> input = torch.randn(4, 5, requires_grad=True)
        >>> target = torch.empty(4, dtype=torch.long).random_(5)
        >>> beta = torch.tensor(4.0)
        >>> output = adaptative_focal_loss(input, target, beta, alpha=0.5, gamma_start=2.0, reduction='mean')
        >>> output.backward()
    """

    # compute softmax over the classes axis
    input_soft = torch.softmax(input, dim=1)
    log_input_soft = torch.log_softmax(input, dim=1)

    # create the labels one hot tensor directly with the logits' dtype and
    # device (no detour through a CPU float32 tensor)
    target_one_hot = torch.nn.functional.one_hot(target, num_classes=input.shape[1]).to(
        device=input.device, dtype=input.dtype
    )

    # focusing parameter for the current beta; as_tensor also accepts a plain number
    beta_tensor = torch.as_tensor(beta, dtype=input.dtype, device=input.device)
    gamma_epoch = gamma_start / torch.sqrt(beta_tensor)

    # compute the actual focal loss; the one-hot mask keeps the true-class term
    weight = torch.pow(-input_soft + 1.0, gamma_epoch)
    focal = -alpha * weight * log_input_soft
    loss_tmp = torch.einsum('...bc,...bc->...', target_one_hot, focal)

    if reduction == 'none':
        loss = loss_tmp
    elif reduction == 'mean':
        loss = torch.mean(loss_tmp)
    elif reduction == 'sum':
        loss = torch.sum(loss_tmp)
    else:
        raise NotImplementedError(f"Invalid reduction mode: {reduction}")
    return loss


class AdaptativeFocalLoss(nn.Module):
    r"""Module wrapper around :func:`adaptative_focal_loss`.

    The focal loss of :cite:`lin2018focal` is
    .. math::
        \text{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \, \text{log}(p_t)
    with :math:`p_t` the model's estimated probability for each class.

    The constructor stores the settings; ``forward`` additionally receives
    ``beta`` and passes everything to :func:`adaptative_focal_loss`, where the
    stored ``gamma`` is used as the starting focusing parameter.

    Args:
        alpha: Weighting factor :math:`\alpha \in [0, 1]`.
        gamma: Starting focusing parameter :math:`\gamma >= 0`.
        reduction: ``'none'`` | ``'mean'`` | ``'sum'``, forwarded unchanged
          to :func:`adaptative_focal_loss`.
    Shape:
        - Input: logits with the class axis at dimension 1.
        - Target: integer labels where each value is
          :math:`0 ≤ targets[i] ≤ C−1`.
    Example:
        >>> criterion = AdaptativeFocalLoss(alpha=0.5, gamma=2.0, reduction='mean')
        >>> input = torch.randn(4, 5, requires_grad=True)
        >>> target = torch.empty(4, dtype=torch.long).random_(5)
        >>> output = criterion(input, target, torch.tensor(4.0))
        >>> output.backward()
    """

    def __init__(self, alpha: float, gamma: float = 2.0, reduction: str = 'none') -> None:
        super().__init__()
        # Plain Python settings; the module has no learnable parameters
        self.alpha: float = alpha
        self.gamma: float = gamma
        self.reduction: str = reduction

    def forward(self, input: torch.Tensor, target: torch.Tensor, beta: torch.Tensor) -> torch.Tensor:
        settings = (self.alpha, self.gamma, self.reduction)
        return adaptative_focal_loss(input, target, beta, *settings)

In [None]:
from pyspectra.transformers.spectral_correction import msc, snv, detrend, derivative, sav_gol

def preprocessing(data, NORM, _normalization, _train=True, d=1):
    """
    Apply the normalization selected by name and return the result with `NORM`.

    `NORM` is returned unchanged as an object reference in every branch; only
    the "SNV", "StdScaler", "derivate" and "Sav_Gol" branches actually call it.
    The "MSC", "SNV_Detrend" and "MinMax" branches build their own transformer
    on every call.

    Parameters:
    data (pd.DataFrame or np.ndarray): Input data to be normalized.
    NORM (object): Normalization object used by the branches listed above.
    _normalization (str or None): Name of the normalization; None returns `data` as is.
    _train (bool, optional): For "SNV" and "StdScaler": fit-and-transform when
        True, transform only when False. Defaults to True.
    d (int, optional): Derivative order for 'derivate' and 'Sav_Gol' normalization. Defaults to 1.

    Returns:
    tuple: Transformed data and normalization object.

    Raises:
    Exception: If `_normalization` is not one of the supported names.
    """
    def _fit_or_apply():
        # Fit on training data, re-use the fitted statistics otherwise
        return NORM.fit_transform(data) if _train else NORM.transform(data)

    if _normalization == "SNV":
        print("Normalization SNV has been choice!")
        return _fit_or_apply(), NORM

    if _normalization == "MSC":
        msc_step = msc()
        msc_step.fit(data)
        return msc_step.transform(data), NORM

    if _normalization == "SNV_Detrend":
        snv_step = snv()
        X_data = snv_step.fit_transform(data)

        detrend_step = detrend()
        # Column labels are replaced by these hard-coded wavelength values
        # before detrending
        X_data.columns = [
            908.1, 914.294, 920.489, 926.683, 932.877, 939.072, 
            945.266, 951.46, 957.655, 963.849, 970.044, 976.238, 
            982.432, 988.627, 994.821, 1001.015, 1007.21, 1013.404, 
            1019.598, 1025.793, 1031.987, 1038.181, 1044.376, 1050.57, 
            1056.764, 1062.959, 1069.153, 1075.348, 1081.542, 1087.736, 
            1093.931, 1100.125, 1106.319, 1112.514, 1118.708, 1124.902, 
            1131.097, 1137.291, 1143.485, 1149.68, 1155.874, 1162.069, 
            1168.263, 1174.457, 1180.652, 1186.846, 1193.04, 1199.235, 
            1205.429, 1211.623, 1217.818, 1224.012, 1230.206, 1236.401, 
            1242.595, 1248.789, 1254.984, 1261.178, 1267.373, 1273.567, 
            1279.761, 1285.956, 1292.15, 1298.344, 1304.539, 1310.733, 
            1316.927, 1323.122, 1329.316, 1335.51, 1341.705, 1347.899, 
            1354.094, 1360.288, 1366.482, 1372.677, 1378.871, 1385.065, 
            1391.26, 1397.454, 1403.648, 1409.843, 1416.037, 1422.231, 
            1428.426, 1434.62, 1440.814, 1447.009, 1453.203, 1459.398, 
            1465.592, 1471.786, 1477.981, 1484.175, 1490.369, 1496.564, 
            1502.758, 1508.952, 1515.147, 1521.341, 1527.535, 1533.73, 
            1539.924, 1546.119, 1552.313, 1558.507, 1564.702, 1570.896, 
            1577.09, 1583.285, 1589.479, 1595.673, 1601.868, 1608.062, 
            1614.256, 1620.451, 1626.645, 1632.839, 1639.034, 1645.228, 
            1651.423, 1657.617, 1663.811, 1670.006, 1676.2
        ]
        X_data = detrend_step.fit_transform(spc=X_data, wave=np.array(X_data.columns))
        return X_data, NORM

    if _normalization == 'derivate':
        print(f"Normalization DERIV{d} has been chosen!")
        return NORM.fit_transform(spc=data, d=d, drop=True), NORM

    if _normalization == "Sav_Gol":
        print(f"Normalization Sav_Gol{d} has been chosen!")
        return NORM.transform(spc=data, window=11, poly=3, deriv=d), NORM

    if _normalization == "MinMax":
        print("Normalization MinMaxScaler has been chosen!")
        min_max = MinMaxScaler()  # x - x_min / x_max - x_min
        return min_max.fit_transform(data), NORM

    if _normalization == "StdScaler":
        print("Normalization StandardScaler has been chosen!")
        return _fit_or_apply(), NORM

    if _normalization == "LOaO":
        print("Normalization Less One and One has been chosen!")
        return LOaO(data), NORM

    if _normalization is None:
        return data, NORM

    raise Exception("Sorry, not implemented yet!")

### NIR-SC-UFES

In [None]:
# Experiment-wide constants
_data = "PAD-UFES-IR"
Sampling_mode = "None"
_augmentation = "None"
_dataset_name = 'CandNC-ALL'  # Carcinoma_vs_ACK, Nev_vs_Mel
_seed = 78645

# Read the spectra table (the CSV uses a comma as decimal separator)
dataset = pd.read_csv(f'data/IR-Spectroscopy-PAD-UFES-V4-{_dataset_name}.csv', decimal=",")

# Remove every row whose 'y' is 9 or 6, then renumber the remaining rows
for unwanted_label in (9, 6):
    dataset.drop(dataset[dataset['y'] == unwanted_label].index, inplace=True)
dataset.reset_index(inplace=True, drop=True)

# Targets ('y'), spectral columns ('data1'..'data125') and class names ('Classe')
labels = dataset.loc[:, ['y']]
features = dataset.loc[:, 'data1':'data125']
classes = dataset.loc[:, ['Classe']]

# Keep the class name next to the features so it survives the splits
FeaturesAndClass = pd.concat([features, classes], axis=1)

# Integer codes of the class names, used only to stratify the first split
labels_subclasses = pd.Categorical(FeaturesAndClass['Classe']).codes

# First split: hold out 10% as the test set, stratified by class name
x_to_train_valid, X_test, y_to_train_valid, y_test = train_test_split(
    FeaturesAndClass, labels, stratify=labels_subclasses, test_size=0.1, random_state=_seed
)

# Class codes recomputed on the remaining rows for the second stratified split
labels_subclasses_train_valid = pd.Categorical(x_to_train_valid['Classe']).codes

# Second split: carve the validation set out of the remaining rows
X_train, X_valid, y_train, y_valid = train_test_split(
    x_to_train_valid, y_to_train_valid, stratify=labels_subclasses_train_valid, test_size=0.112, random_state=_seed
)

# Show how many samples of each class ended up in each split
for split_name, split_frame in (("X_train", X_train), ("X_valid", X_valid), ("X_test", X_test)):
    print(f"{split_name} Classe: \n{split_frame['Classe'].value_counts()}")

# Renumber the rows of every split from zero
for split_part in (X_train, y_train, X_valid, y_valid, X_test, y_test):
    split_part.reset_index(inplace=True, drop=True)

# Class names of the test samples
test_classes = X_test.loc[:, 'Classe']


In [None]:
# Notebook display cell: per-split counts of each value of the 'y' target
y_train['y'].value_counts(), y_valid['y'].value_counts(), y_test['y'].value_counts()

#### Plot

In [None]:
# Switches for the optional plots of this cell
_plot_raw_data = False
_tsne_plot = False
_test_plot = False

if _plot_raw_data:
    from matplotlib import colors as mcolors

    # x-axis values for the 125 spectral columns
    columns = np.array([908.1, 914.294, 920.489, 926.683, 932.877, 939.072, 
                945.266, 951.46, 957.655, 963.849, 970.044, 976.238, 
                982.432, 988.627, 994.821, 1001.015, 1007.21, 1013.404, 
                1019.598, 1025.793, 1031.987, 1038.181, 1044.376, 1050.57, 
                1056.764, 1062.959, 1069.153, 1075.348, 1081.542, 1087.736, 
                1093.931, 1100.125, 1106.319, 1112.514, 1118.708, 1124.902, 
                1131.097, 1137.291, 1143.485, 1149.68, 1155.874, 1162.069, 
                1168.263, 1174.457, 1180.652, 1186.846, 1193.04, 1199.235, 
                1205.429, 1211.623, 1217.818, 1224.012, 1230.206, 1236.401, 
                1242.595, 1248.789, 1254.984, 1261.178, 1267.373, 1273.567, 
                1279.761, 1285.956, 1292.15, 1298.344, 1304.539, 1310.733, 
                1316.927, 1323.122, 1329.316, 1335.51, 1341.705, 1347.899, 
                1354.094, 1360.288, 1366.482, 1372.677, 1378.871, 1385.065, 
                1391.26, 1397.454, 1403.648, 1409.843, 1416.037, 1422.231, 
                1428.426, 1434.62, 1440.814, 1447.009, 1453.203, 1459.398, 
                1465.592, 1471.786, 1477.981, 1484.175, 1490.369, 1496.564, 
                1502.758, 1508.952, 1515.147, 1521.341, 1527.535, 1533.73, 
                1539.924, 1546.119, 1552.313, 1558.507, 1564.702, 1570.896, 
                1577.09, 1583.285, 1589.479, 1595.673, 1601.868, 1608.062, 
                1614.256, 1620.451, 1626.645, 1632.839, 1639.034, 1645.228, 
                1651.423, 1657.617, 1663.811, 1670.006, 1676.2])
    
    # Training spectra of each class, without the class-name column
    CBC_data = X_train[X_train['Classe'] == 'CBC'].reset_index(drop=True).drop('Classe', axis=1)
    CEC_data = X_train[X_train['Classe'] == 'CEC'].reset_index(drop=True).drop('Classe', axis=1)
    MEL_data = X_train[X_train['Classe'] == 'MEL'].reset_index(drop=True).drop('Classe', axis=1)
    ACK_data = X_train[X_train['Classe'] == 'ACK'].reset_index(drop=True).drop('Classe', axis=1)
    SEK_data = X_train[X_train['Classe'] == 'SEK'].reset_index(drop=True).drop('Classe', axis=1)
    NEV_data = X_train[X_train['Classe'] == 'NEV'].reset_index(drop=True).drop('Classe', axis=1)

    plt.figure(figsize=(16,10))
    plt.rcParams['legend.fontsize'] = 22
    plt.rcParams.update({'font.size': 25})

    # One example spectrum per class (row 6 of each class subset)
    plt.plot(columns, CBC_data.to_numpy()[6], linewidth=3, label="CBC", color='#F2071B')
    plt.plot(columns, CEC_data.to_numpy()[6], linewidth=3, label="CEC", color='#FC5A50')
    plt.plot(columns, MEL_data.to_numpy()[6], linewidth=3, label="MEL", color='#000000')
    plt.plot(columns, ACK_data.to_numpy()[6], linewidth=3, label="ACK", color='#15B01A')
    plt.plot(columns, SEK_data.to_numpy()[6], linewidth=3, label="SEK", color='#7BC8F6')
    plt.plot(columns, NEV_data.to_numpy()[6], linewidth=3, label="NEV", color='#031CA6')

    plt.legend()
    plt.xlabel("Comprimento de onda (nm)")
    plt.ylabel("Absorbância")


    dir_save = './plots_imgs/NIR-SC-UFES_6CLASS_OriginalData_Disserta.pdf'
    plt.savefig(dir_save)
    plt.show()

    # NOTE(review): this t-SNE section is nested inside `if _plot_raw_data:`, so
    # it only runs when both flags are True — confirm that is intended.
    if _tsne_plot:
        data = dataset.loc[:,'data1':'data125']
        data_labels = dataset.loc[:,['y']].squeeze(1)

        _preprocess = True

        if _preprocess:
            print("Using preprocess data")
            features_norm, _ = preprocessing(data, snv(), _normalization = "SNV")
        else:
            print("Using original data")

        # Sklearn
        from sklearn.manifold import TSNE
        # import fitsne

        def plot_tsne(data, labels, n_components=2, perplexity=30, n_iter=750):
            """
            Plot t-SNE data.

            Args:
                data: The data to be plotted.
                labels: The labels of the data.
                n_components: The number of dimensions to reduce to.
                perplexity: The perplexity parameter.
                n_iter: The number of iterations.

            Returns:
                A matplotlib figure object, or None if the embedding or the
                plotting raised an exception (the error is printed).
            """
            try:
                tsne = TSNE(n_components=n_components, perplexity=perplexity, n_iter=n_iter, learning_rate = 'auto', init = 'pca')
                df_x_embedded = tsne.fit_transform(X=data)
                # df_x_embedded = fitsne.FItSNE(X=data, no_dims=n_components, perplexity=perplexity, max_iter=n_iter, rand_seed=_seed)

                fig = plt.figure(figsize=(16,10))
                ax = sns.scatterplot(x=df_x_embedded[:, 0], y=df_x_embedded[:, 1], 
                                hue=labels, legend='full', s=250, palette=sns.color_palette("bright", len(set(labels))))
                plt.title(f't-SNE Plot (Perplexity={perplexity}, Iterations={n_iter})')
                # Replace the automatic legend entries with fixed class names
                handles, labels  =  ax.get_legend_handles_labels()
                ax.legend(handles, ["Benigno", "Maligno"], loc='lower right')
                return fig

            except Exception as e:
                print(f"Error during t-SNE: {e}")
                return None

        if _test_plot:
            # Sweep the perplexity from 5 to 100 in steps of 5
            for i in range(5, 101, 5):
            # i = 30
                if _preprocess:
                    fig = plot_tsne(features_norm, data_labels, n_components=2, perplexity=i, n_iter=1000)
                else: 
                    fig = plot_tsne(data, data_labels, n_components=2, perplexity=i, n_iter=1000)
                # plt.show()

                if fig is not None:
                    if _preprocess:
                        plt.savefig(f'./t-SNE/NIR-SC-UFES/{i}_NIR-SC-UFES_fig_NORM.png')
                    else:
                        plt.savefig(f'./t-SNE/NIR-SC-UFES/{i}_NIR-SC-UFES_fig.png')
                    plt.show()
        else:
            plt.figure(figsize=(16,10))
            plt.rcParams['legend.fontsize'] = 22
            plt.rcParams.update({'font.size': 25})

            # Single plot at a fixed perplexity
            i = 30

            if _preprocess:
                print("pre-processed ON")
                fig = plot_tsne(features_norm, data_labels, n_components=2, perplexity=i, n_iter=1000)
            else:
                print("pre-processed OFF")
                fig = plot_tsne(data, data_labels, n_components=2, perplexity=i, n_iter=1000)
            
            if fig is not None:
                if _preprocess:
                    plt.savefig(f'./t-SNE/{i}_NIR-SC-UFES_fig_NORM.pdf')
                    plt.savefig(f'./t-SNE/{i}_NIR-SC-UFES_fig_NORM.png')
                else:
                    plt.savefig(f'./t-SNE/{i}_NIR-SC-UFES_fig.pdf')
                    plt.savefig(f'./t-SNE/{i}_NIR-SC-UFES_fig.png')
                plt.show()

# VAE 1-D

In [None]:
def prepare_data_loader(data, target, batch_size=32, shuffle=True):
    """
    Wrap features and targets in a PyTorch ``DataLoader``.

    Both inputs are first copied into float32 NumPy arrays and then into
    tensors, so targets are float32 as well (callers cast them to integer
    labels when needed).

    Parameters:
    data (pd.DataFrame or np.ndarray): Feature matrix; anything ``np.array`` accepts.
    target (pd.Series or np.ndarray): Targets, one entry per row of ``data``.
    batch_size (int, optional): Number of samples per batch. Defaults to 32.
    shuffle (bool, optional): Whether the loader shuffles the samples. Defaults to True.

    Returns:
    tuple: ``(loader, sizes)`` where ``sizes`` maps ``'x_size'`` and ``'y_size'``
    to the ``torch.Size`` of the feature and target tensors respectively.
    """
    # Same conversion for both inputs: array-like -> float32 ndarray -> tensor
    features, targets = (
        torch.tensor(np.array(source, dtype=np.float32))
        for source in (data, target)
    )

    loader = torch.utils.data.DataLoader(
        dataset=torch.utils.data.TensorDataset(features, targets),
        batch_size=batch_size,
        shuffle=shuffle,
    )

    return loader, {'x_size': features.size(), 'y_size': targets.size()}


## VAE - IR - Conv1D

In [None]:
# Identifier of this architecture; it is interpolated into the checkpoint
# file names built further down in the notebook.
model_name = 'E_MLP-IR_CNN-1D'
# Convolutional model: downstream code adds a channel axis to the inputs
# (np.expand_dims(..., 1)) when this flag is True.
conv = True
    
class ResNet1DBlock(nn.Module):
    """
    A 1D residual block: two convolutions with a Tanh in between, added to a
    1x1-convolution shortcut of the input, followed by a final Tanh.

    Only the first convolution (and the shortcut) apply ``stride``; the second
    convolution always uses stride 1. Striding both convolutions would
    downsample the main path twice but the shortcut once, so the residual
    addition would fail for any ``stride > 1``. With the default ``stride=1``
    the block is unchanged.

    The main path and the shortcut only have matching lengths when
    ``2 * padding == kernel_size - 1`` (true for the defaults).

    Parameters:
    in_channels (int): Number of input channels.
    out_channels (int): Number of output channels.
    kernel_size (int, optional): Size of the convolving kernel. Defaults to 3.
    stride (int, optional): Stride of the first convolution and of the shortcut. Defaults to 1.
    padding (int, optional): Zero-padding added to both sides of the input. Defaults to 1.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding)
        self.tanh = nn.Tanh()
        # Stride 1 here: the downsampling (if any) already happened in conv1.
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size, 1, padding)
        # 1x1 convolution so the shortcut matches the main path in channels and length.
        self.shortcut = nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=stride)

    def forward(self, x):
        """
        Forward pass for the ResNet1DBlock.

        Parameters:
        x (torch.Tensor): Input tensor of shape (batch, in_channels, length).

        Returns:
        torch.Tensor: Output tensor of shape (batch, out_channels, new_length),
        with values in [-1, 1] because of the final Tanh.
        """
        out = self.conv1(x)
        out = self.tanh(out)
        out = self.conv2(out)
        out = out + self.shortcut(x)
        return self.tanh(out)
    
class VAE(nn.Module):
    """
    A Variational Autoencoder (VAE) with 1D ResNet blocks.

    The encoder expands 1 -> 32 -> 64 -> 128 channels; its flattened output
    feeds two linear heads (mean and log-variance). The decoder maps a latent
    vector back through a linear layer and a mirrored 128 -> 64 -> 32 -> 1
    convolutional stack.

    Parameters:
    latent_dims (int): Dimension of the latent space.
    attention (optional): Stored on the instance; not used inside this class. Defaults to None.
    incerteza (bool, optional): Uncertainty flag; stored on the instance, not used inside this class. Defaults to False.
    input_length (int, optional): Number of points in each input signal, used to
        size the linear layers. Defaults to 125 (NIR-SC-UFES). Other lengths
        noted by the author: 1401 (BCA), 22 (BCA w/ PCA), 1557 (Urine),
        1555 (Urine w/ derivative).
    """
    def __init__(self, latent_dims, attention=None, incerteza=False, input_length=125):
        super().__init__()

        self.attention = attention
        self.incerteza = incerteza
        self.input_length = input_length

        # Define the encoder
        self.encoder = nn.Sequential(OrderedDict([
            ('ConvUnit_1', nn.Conv1d(1, 32, kernel_size=3, stride=1, padding=1)),
            ('ConvUnit_2', ResNet1DBlock(32, 64)),
            ('ConvUnit_3', ResNet1DBlock(64, 128)),
        ]))

        # Determine the size of the flattened layer by probing the encoder.
        # torch.randn is kept (rather than zeros) so the global RNG stream, and
        # therefore the seeded initialisation of the layers below, is unchanged.
        with torch.no_grad():
            probe = self.encoder(torch.randn(1, 1, input_length))
        self.encoded_channels = probe.size(1)
        self.flattened_size = probe.view(probe.size(0), -1).size(1)

        # Define the mean and log-variance layers
        self.mean_layer = nn.Linear(self.flattened_size, latent_dims)
        self.logvar_layer = nn.Linear(self.flattened_size, latent_dims)

        # Define the fully connected layer for the decoder
        self.fc = nn.Linear(latent_dims, self.flattened_size)

        # Define the decoder
        self.decoder = nn.Sequential(OrderedDict([
            ('DeLinearUnit_1', ResNet1DBlock(128, 64)),
            ('DeLinearUnit_2', ResNet1DBlock(64, 32)),
            ('DeLinearUnit_3', nn.Conv1d(32, 1, kernel_size=3, stride=1, padding=1)),
        ]))

    def encode(self, x):
        """
        Encode the input into the latent space.

        Parameters:
        x (torch.Tensor): Input tensor of shape (batch, 1, input_length).

        Returns:
        tuple: Mean and log-variance tensors, each of shape (batch, latent_dims).
        """
        z = self.encoder(x)
        z = torch.reshape(z, (-1, self.flattened_size))
        return self.mean_layer(z), self.logvar_layer(z)

    def decode(self, x):
        """
        Decode the latent representation back to the input space.

        Parameters:
        x (torch.Tensor): Latent representation tensor of shape (batch, latent_dims).

        Returns:
        torch.Tensor: Reconstructed tensor.
        """
        z = self.fc(x)
        # Undo the flattening done in encode(): (batch, channels, length)
        z = torch.reshape(z, (-1, self.encoded_channels, self.flattened_size // self.encoded_channels))
        return self.decoder(z)

    def reparameterize_trick(self, mean, logvar):
        """
        Reparameterization trick to sample from the latent space.

        Parameters:
        mean (torch.Tensor): Mean of the latent distribution.
        logvar (torch.Tensor): Log-variance of the latent distribution.

        Returns:
        torch.Tensor: Sampled latent vector ``mean + eps * exp(0.5 * logvar)``
        with ``eps`` drawn from a standard normal.
        """
        std_dev = torch.exp(0.5 * logvar)
        epsilon = torch.randn_like(std_dev)
        return epsilon * std_dev + mean

    def sample_latent_vector(self, x):
        """
        Sample a latent vector from the input.

        Parameters:
        x (torch.Tensor): Input tensor.

        Returns:
        torch.Tensor: Sampled latent vector.
        """
        mean, logvar = self.encode(x)
        return self.reparameterize_trick(mean, logvar)

    def mean_std(self, x):
        """
        Get the mean and standard deviation of the encoded input.

        Parameters:
        x (torch.Tensor): Input tensor.

        Returns:
        tuple: Mean and standard deviation tensors.
        """
        mean, logvar = self.encode(x)
        std_dev = torch.exp(0.5 * logvar)
        return mean, std_dev

    def forward(self, x, encoder=True, decoder=False):
        """
        Forward pass for the VAE.

        Parameters:
        x (torch.Tensor): Input tensor.
        encoder (bool, optional): Return the encoder statistics. Takes precedence
            over ``decoder`` when both are True. Defaults to True.
        decoder (bool, optional): Return the reconstruction (only consulted when
            ``encoder`` is False). Defaults to False.

        Returns:
        tuple: ``(mean, std_dev)`` when ``encoder`` is True; otherwise
        ``(reconstruction, mean, logvar)`` when ``decoder`` is True.

        Raises:
        ValueError: If both ``encoder`` and ``decoder`` are False (previously this
            returned None silently and failed later at the call site).
        """
        if encoder:
            return self.mean_std(x)
        if decoder:
            mean, logvar = self.encode(x)
            z = self.reparameterize_trick(mean, logvar)
            return self.decode(z), mean, logvar
        raise ValueError("VAE.forward requires encoder=True or decoder=True.")

##### MLP Classify Network

In [None]:
class ClassifyingNetwork(nn.Module):
    """
    A neural network for the classification tasks of VAE-GNA.

    Architecture: Flatten -> Linear(num_ftrs, num_ftrs // 2) -> BatchNorm1d ->
    Dropout(0.1) -> Tanh -> Linear(num_ftrs // 2, 2).

    Parameters:
    num_ftrs (int): Number of input features (after flattening).
    """
    def __init__(self, num_ftrs):
        super().__init__()

        # Define a simple Multi-Layer Perceptron (MLP) for classification task
        self.MLPclassify = nn.Sequential(
            nn.Flatten(),                           # Flatten the input tensor
            nn.Linear(num_ftrs, num_ftrs // 2),     # First fully connected layer
            nn.BatchNorm1d(num_ftrs // 2),          # Batch normalization
            nn.Dropout(0.1),                        # Dropout for regularization
            nn.Tanh(),                              # Tanh activation function
            nn.Linear(num_ftrs // 2, 2),            # Second fully connected layer to output 2 classes
        )

    def forward(self, x):
        """
        Forward pass for the ClassifyingNetwork.

        Parameters:
        x (torch.Tensor): Input tensor with ``num_ftrs`` values per sample.

        Returns:
        torch.Tensor: Output tensor of shape (batch, 2) with unnormalised class scores.
        """
        # Move the input to wherever this module's weights live, instead of
        # relying on a notebook-level `device` global that may not match.
        x = x.to(self.MLPclassify[1].weight.device)
        return self.MLPclassify(x)

##### Attention Default

In [None]:
class AttentionLayer(nn.Module):
    """
    Attention layer used by VAE-GNA to weight latent features.

    A linear projection is batch-normalised and squashed with a sigmoid; the
    result is then normalised with a softmax over the feature axis, so each
    output row sums to 1.

    Parameters:
    mean_std_size (int): Number of input features.
    mean_size (int): Number of output features (attention weights per sample).
    """
    def __init__(self, mean_std_size, mean_size):
        super().__init__()

        # Scoring stack: projection -> batch norm -> sigmoid
        scoring_stack = [
            nn.Linear(mean_std_size, mean_size),
            nn.BatchNorm1d(mean_size),
            nn.Sigmoid(),
        ]
        self.Attention = nn.Sequential(*scoring_stack)

        # Normalisation over dim 1 (the feature axis of a 2-D input)
        self.Softlayer = nn.Sequential(nn.Softmax(dim=1))

    def forward(self, x):
        """
        Compute the attention weights for a batch.

        Parameters:
        x (torch.Tensor): Input tensor with ``mean_std_size`` features per sample.

        Returns:
        torch.Tensor: Softmax-normalised weights with ``mean_size`` features per sample.
        """
        gates = self.Attention(x)
        return self.Softlayer(gates)


## Optuna

In [None]:
%%time

def optuna_run(_classify=False, optim_params=None):
    """
    Run an Optuna hyperparameter search inside each fold of a stratified K-fold split.

    One independent study is created per fold. Every trial builds fresh
    networks, makes a single pass over the fold's training data followed by a
    single pass over its validation data, and reports the validation metric
    (summed VAE loss when ``_classify`` is False, balanced accuracy otherwise).
    The best trial across folds is printed and its parameters are returned.

    Relies on notebook-level names defined outside this function: ``X_train``,
    ``X_valid``, ``y_train``, ``y_valid``, ``_seed``, ``Sampling_mode``,
    ``conv``, ``device``, ``_attention``, ``_incerteza``, ``_data``,
    ``model_name``, ``_augmentation``, ``snv``, ``preprocessing``,
    ``prepare_data_loader``, ``VAE``, ``ClassifyingNetwork``,
    ``AttentionLayer``, ``JointModel``, ``FocalLoss``, ``AdaptativeFocalLoss``,
    ``StratifiedKFold`` and ``balanced_accuracy_score``.

    Parameters:
    _classify (bool, optional): When True, only ``_gamma_3`` (classification loss
        weight) is searched and the VAE settings come from ``optim_params``;
        when False, ``_gamma_2``, ``_latent_dims`` and ``_batch_size`` are
        searched. Defaults to False.
    optim_params (dict, optional): Required when ``_classify`` is True; must
        contain ``'_gamma_2'``, ``'_latent_dims'`` and ``'_batch_size'``.
        Defaults to None.

    Returns:
    dict: Parameters suggested in the best trial (only the searched ones).

    Raises:
    Exception: If ``Sampling_mode`` is anything other than ``'None'``.
    NotImplementedError: If the configured loss name is not recognised.
    """

    # Hyperparameters definition
    _k_folds = 5
    _lr = 1e-4  # Learning rate
    _epochs = 2000  # Used below as the number of Optuna trials per fold
    _sched_factor = 0.1
    _sched_min_lr = 1e-6
    _sched_patience = 20

    # Normalization method
    _normalization = "SNV"  # "SNV", "MinMax", "StdScaler", "LOaO", "SNV_Detrend", "derivate", Sav_Gol

    # Loss function
    _set_loss = "cross_entropy_loss"  # "adaptative_focal_loss" or "focal_loss" or "cross_entropy_loss"

    # For fold results
    results = {}

    # Define the K-fold Cross Validator
    kfold = StratifiedKFold(n_splits=_k_folds, random_state=_seed, shuffle=True)

    # Start print
    print('--------------------------------')

    trials = []

    # Select Data Type
    print("Normal has been choice!")
    X = pd.concat([X_train, X_valid])
    y = pd.concat([y_train, y_valid])

    # Stratify on the 'Classe' column when it exists, otherwise on the targets
    try:
        labels_subclasses_train_valid = pd.Categorical(X['Classe']).codes
    except KeyError:
        labels_subclasses_train_valid = y

    torch.autograd.set_detect_anomaly(True)

    def _prepare_batch(batch):
        """Move one (inputs, labels) batch to `device` with integer 1-D labels."""
        inputs, labels = batch
        inputs = inputs.to(device)
        if not conv:
            inputs = inputs.squeeze(1)
        labels = labels.long().to(device)
        if labels.dim() > 1:
            labels = labels.squeeze(1)
        return inputs, labels

    # K-fold Cross Validation model evaluation
    for fold, (train_index, valid_index) in enumerate(kfold.split(X, labels_subclasses_train_valid)):

        # Print
        print(f'FOLD {fold}')
        print('--------------------------------')

        # Get the training and validation data for the current fold
        X_train_fold, X_valid_fold = X.iloc[train_index], X.iloc[valid_index]
        y_train_fold, y_valid_fold = y.iloc[train_index], y.iloc[valid_index]

        if Sampling_mode == 'None':
            print("None Oversampling mode has been chosen!")
        else:
            raise Exception("Sorry, not implemented yet!")

        # The subclass column is not a feature; drop it if present
        X_train_fold = X_train_fold.drop(columns=['Classe'], errors='ignore')
        X_valid_fold = X_valid_fold.drop(columns=['Classe'], errors='ignore')

        # Set normalization
        NORM = snv()  # None, derivative(), sav_gol(), StandardScaler()

        # Preprocess the data (the transform returned for the training split is
        # the one passed in for the validation split)
        X_train_fold, NORM = preprocessing(X_train_fold, NORM, _normalization, _train=True)
        X_valid_fold, _ = preprocessing(X_valid_fold, NORM, _normalization, _train=False)

        if conv:
            # Add the channel axis expected by the convolutional VAE
            X_train_fold = np.expand_dims(X_train_fold, 1)
            X_valid_fold = np.expand_dims(X_valid_fold, 1)

        if np.isnan(X_train_fold).any():
            print("Your tensor X_train_fold contains NaN values!")

        MSE_criterion = nn.MSELoss(reduction='sum')

        def objective(trial):
            """
            Objective function for Optuna.

            Parameters:
            trial (optuna.trial.Trial): A trial object for parameter suggestions.

            Returns:
            float: Balanced accuracy on the validation split when classifying,
            otherwise the validation loss.
            """
            _gamma_1 = 1

            if _classify:
                _gamma_2 = optim_params['_gamma_2']
                _latent_dims = optim_params['_latent_dims']
                _batch_size = optim_params['_batch_size']
                _gamma_3 = trial.suggest_float('_gamma_3', 0.001, 100)
            else:
                _gamma_2 = trial.suggest_float('_gamma_2', 0.001, 100)
                _latent_dims = trial.suggest_int('_latent_dims', 4, 512)
                # `step` is passed by keyword. 256 is not reachable from 12 in
                # steps of 12, so Optuna uses 252 as the effective upper bound.
                _batch_size = trial.suggest_int('_batch_size', 12, 256, step=12)

            train_dataloader, _ = prepare_data_loader(X_train_fold, y_train_fold, batch_size=_batch_size, shuffle=True)
            valid_dataloader, valid_size = prepare_data_loader(X_valid_fold, y_valid_fold, batch_size=_batch_size, shuffle=False)

            input_attention_dims = _latent_dims if _incerteza else _latent_dims * 2

            ATT_Layer = AttentionLayer(input_attention_dims, _latent_dims).to(device) if _attention else None
            VAE_network = VAE(latent_dims=_latent_dims).to(device)
            CLASS_network = ClassifyingNetwork(num_ftrs=input_attention_dims).to(device) if _classify else None

            model_path = './model/'
            vae_model_path = model_path + f'Kfold_results/{Sampling_mode}/{_data}_MS_{model_name}_{_normalization}_{_augmentation}_Fold_{str(fold)}.pth'

            # Reuse a previously trained VAE for this fold when classifying
            if os.path.exists(vae_model_path) and _classify:
                VAE_network.load_state_dict(torch.load(vae_model_path, map_location=device))

            JOINT_Model = JointModel(VAE_network, CLASS_network, ATT_Layer)

            opt_parameters = list(VAE_network.parameters())
            if _classify:
                opt_parameters += list(CLASS_network.parameters())
            if _attention:
                opt_parameters += list(ATT_Layer.parameters())

            optimizer = torch.optim.Adam(opt_parameters, weight_decay=1e-3, lr=_lr)
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=_sched_factor, min_lr=_sched_min_lr, patience=_sched_patience)

            if _classify:
                if _set_loss == "focal_loss":
                    CLASS_criterion = FocalLoss(alpha=0.25, gamma=2.0, reduction='sum')
                elif _set_loss == "adaptative_focal_loss":
                    CLASS_criterion = AdaptativeFocalLoss(alpha=0.25, gamma=2.0, reduction='sum')
                elif _set_loss == "cross_entropy_loss":
                    CLASS_criterion = nn.CrossEntropyLoss(reduction='sum')
                else:
                    raise NotImplementedError(f"Invalid Loss: {_set_loss}")

            # NOTE(review): `beta` was incremented and read below without ever
            # being initialised, so the adaptive focal loss path could not run.
            # Starting it at 1 matches the constant passed during training, but
            # the intended schedule should be confirmed against AdaptativeFocalLoss.
            beta = torch.Tensor([1])

            def vae_loss(batch_inputs):
                """Reconstruction (summed MSE) plus weighted KL divergence for one batch."""
                recon, mean, logvar = VAE_network.forward(batch_inputs, encoder=False, decoder=True)
                MSE_loss = MSE_criterion(recon, batch_inputs)
                KLD_loss = - 0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
                return (_gamma_1 * MSE_loss) + (_gamma_2 * KLD_loss)

            # Training: a single pass over the fold's training data
            for train_data in train_dataloader:
                inputs, labels = _prepare_batch(train_data)

                optimizer.zero_grad()

                t_loss = vae_loss(inputs)

                if _classify:
                    outputs = JOINT_Model(inputs, attention=_attention, incerteza=_incerteza)

                    if _set_loss == "adaptative_focal_loss":
                        CLASS_loss = CLASS_criterion(outputs, labels, torch.Tensor([1]))
                        beta += 1
                    else:
                        CLASS_loss = CLASS_criterion(outputs, labels)

                    t_loss = t_loss + (_gamma_3 * CLASS_loss)

                t_loss.backward()
                optimizer.step()

            # Validation
            # NOTE(review): the networks stay in training mode here (dropout and
            # batch-norm batch statistics active); kept as in the original so
            # the searched metric does not change.
            with torch.no_grad():
                ypredVector = []
                labelsVector = []

                valid_running_loss = 0.0
                valid_running_corrects = 0

                for valid_data in valid_dataloader:
                    inputs, labels = _prepare_batch(valid_data)

                    loss = vae_loss(inputs)

                    if _classify:
                        outputs = JOINT_Model(inputs, attention=_attention, incerteza=_incerteza)

                        _, y_pred = torch.max(outputs, 1)

                        if _set_loss == "adaptative_focal_loss":
                            CLASS_loss = CLASS_criterion(outputs, labels, beta)
                        else:
                            CLASS_loss = CLASS_criterion(outputs, labels)

                        valid_running_corrects += torch.sum(y_pred == labels.data)
                        loss = loss + (_gamma_3 * CLASS_loss)
                        ypredVector += y_pred.detach().cpu()
                        labelsVector += labels.data.detach().cpu()

                    # The losses are already summed over the batch; the extra
                    # batch-size weighting is kept from the original metric.
                    valid_running_loss += loss.item() * inputs.size(0)

                valid_epoch_loss = valid_running_loss / valid_size['y_size'][0]
                scheduler.step(valid_epoch_loss)

            if _classify:
                valid_acc = valid_running_corrects.double() / valid_size['y_size'][0]
                results[fold] = 100.0 * valid_acc

                return balanced_accuracy_score(labelsVector, ypredVector)

            return valid_epoch_loss

        # Accuracy is maximised, the VAE loss is minimised
        direction = 'maximize' if _classify else 'minimize'
        study = optuna.create_study(direction=direction, sampler=optuna.samplers.TPESampler())

        study.optimize(objective, n_trials=_epochs)
        trials.append(study.best_trial)

    # Pick the best fold-level trial (first one wins on ties)
    best_fold, best_trial = None, None
    for i, trial in enumerate(trials):
        if not _classify:
            print(trial.value)

        if best_trial is None:
            improved = True
        elif _classify:
            improved = trial.value > best_trial.value
        else:
            improved = trial.value < best_trial.value

        if improved:
            best_fold, best_trial = i, trial

    best_value = best_trial.value
    best_params = best_trial.params

    print(f"  Fold: {best_fold}")
    print(f"  Value: {best_value}")

    print("  Params: ")
    for key, value in best_params.items():
        print(f"    {key}: {value}")

    return best_params


## Runs

In [None]:
%%time

# Hyperparameters definition (notebook-level settings for the runs below)
_k_folds = 5  # Number of folds for the stratified K-fold cross-validation
_lr = 1e-4  # Learning rate
_epochs = 2000  # Number of training epochs
_sched_factor = 0.1  # Factor by which the learning rate is reduced
_sched_min_lr = 1e-6  # Minimum learning rate for the scheduler
_sched_patience = 20  # Number of epochs with no improvement after which learning rate will be reduced

# Normalization method; passed to preprocessing() and embedded in checkpoint file names
_normalization = "SNV" # Options: ("SNV", "MinMax", "StdScaler", "LOaO", "SNV_Detrend", "derivate", "Sav_Gol")

# Classification loss function
_set_loss = "cross_entropy_loss" # Options: ("cross_entropy_loss", "adaptative_focal_loss", "focal_loss")

# Loss weight for reconstruction loss in the VAE
_gamma_1 = 1

# When True, an AttentionLayer is built and handed to JointModel
_attention = False
# Uncertainty flag: selects the classifier input width (latent_dims instead of
# 2 * latent_dims) and the uncertainty branch of JointModel.forward
_incerteza = False
# Attention method; JointModel.forward compares against the literals 'Loung'
# and 'Bahdanau', so the spelling here must stay in sync with them
_att_method = 'Loung'  # Attention method
_plotLoss = True  # Flag to plot loss

class JointModel(nn.Module):
    """
    A joint model combining a Variational Autoencoder (VAE), a Classifying Network, and an Attention Layer.

    The VAE encodes the input into a mean and a standard deviation, which are
    turned into a feature vector and classified.

    Parameters:
    VAE (nn.Module): The Variational Autoencoder model.
    ClassifyingNetwork (nn.Module or None): The Classifying Network model. When
        None, no ``classifying_net`` attribute is set and ``forward`` cannot be used.
    AttentionLayer (nn.Module or None): The Attention Layer model. When None, no
        ``attention_layer`` attribute is set.
    """
    def __init__(self, VAE, ClassifyingNetwork, AttentionLayer):
        super().__init__()
        self.vae = VAE
        if ClassifyingNetwork is not None:
            self.classifying_net = ClassifyingNetwork
        if AttentionLayer is not None:
            self.attention_layer = AttentionLayer

    @staticmethod
    def _propagate_moments(mean, std_dev, weight, bias):
        """
        Push the latent mean and standard deviation through a linear map built
        from the classifier's first layer.

        The layer's weight matrix is bilinearly resized to (latent, latent) and
        its bias is replaced by ``latent`` evenly spaced values between the
        bias minimum and maximum (detached from the graph, as before).

        The whole batch is handled with one matrix product. The previous
        per-sample loop squeezed away the batch axis when the batch held a
        single sample and then indexed rows of the weight matrix instead of
        samples, silently producing wrong values.

        Parameters:
        mean (torch.Tensor): Latent means, shape (batch, latent).
        std_dev (torch.Tensor): Latent standard deviations, shape (batch, latent).
        weight (torch.Tensor): 2-D weight matrix of the classifier's first linear layer.
        bias (torch.Tensor): Bias vector of the same layer.

        Returns:
        tuple: ``(z_mean, z_std)``, both of shape (batch, latent), where
        ``z_mean = W @ mean + b`` and ``z_std = W**2 @ std_dev**2`` per sample.
        """
        latent = mean.size(1)

        # Expand the weights to a square (latent, latent) matrix
        square_weight = torch.nn.functional.interpolate(
            weight.unsqueeze(0).unsqueeze(0),
            size=(latent, latent),
            mode='bilinear',
            align_corners=False,
        )[0, 0]

        # Created on the input's device rather than on a notebook-level global
        offsets = torch.linspace(
            bias.min().item(), bias.max().item(), steps=latent,
            device=mean.device, dtype=mean.dtype,
        )

        z_mean = torch.matmul(mean, square_weight.t()) + offsets
        z_std = torch.matmul(std_dev ** 2, (square_weight ** 2).t())
        return z_mean, z_std

    def forward(self, x, attention=False, incerteza=False, method='Loung'):
        """
        Forward pass for the JointModel.

        Parameters:
        x (torch.Tensor): Input tensor.
        attention (bool, optional): Apply the attention layer. Only consulted
            when ``incerteza`` is True. Defaults to False.
        incerteza (bool, optional): Use the uncertainty-aware features instead of
            the concatenated mean and standard deviation. Defaults to False.
        method (str, optional): ``'Loung'`` multiplies the features by the
            attention weights, ``'Bahdanau'`` adds them; any other value leaves
            the features unchanged. Defaults to 'Loung'.

        Returns:
        torch.Tensor: Output of the classifying network.
        """
        # Encode the input using the VAE to get mean and std deviation
        mean_vae, std_dev = self.vae.forward(x)

        if not incerteza:
            ms_vector = torch.cat((mean_vae, std_dev), dim=1)
        else:
            # Handling uncertainty: the same first-layer weights are used for
            # both the mean and the spread
            first_linear = self.classifying_net.MLPclassify[1]
            z_mean, z_std = self._propagate_moments(
                mean_vae, std_dev, first_linear.weight, first_linear.bias
            )

            ms_vector = expected_sigm_of_norm(z_mean, z_std, method='probit')

            # NOTE(review): attention is only reachable on the uncertainty
            # path, as in the original; confirm this nesting is intended.
            if attention:
                attention_weights = self.attention_layer(ms_vector)

                if method == 'Loung':
                    ms_vector = ms_vector * attention_weights
                elif method == 'Bahdanau':
                    ms_vector = ms_vector + attention_weights

        # Pass the combined vector through the classifying network
        return self.classifying_net(ms_vector)

# Run optimization for both classification and non-classification cases
for _classify in [False, True]:
    
    if _classify:
        # Run Optuna optimization for classification and get the best parameters
        best_params = optuna_run(_classify=True, optim_params=best_params)
        _gamma_3 = best_params['_gamma_3']
    else: 
        # Run Optuna optimization for non-classification and get the best parameters
        best_params = optuna_run(_classify=False)
        _gamma_2 = best_params['_gamma_2']  # Retrieve the best gamma_2 parameter
        _latent_dims = best_params['_latent_dims']  # Retrieve the best latent dimensions
        _batch_size = best_params['_batch_size']  # Retrieve the best batch size

    # For storing fold results
    results = {}

    # Define the K-fold Cross Validator
    kfold = StratifiedKFold(n_splits=_k_folds, random_state=_seed, shuffle=True)

    # Start print
    print('--------------------------------')


    print("Normal has been choice!")
    X = pd.concat([X_train, X_valid])
    y = pd.concat([y_train, y_valid])

    # Get subclass labels for training and validation data
    try:
        labels_subclasses_train_valid = pd.Categorical(X['Classe']).codes
    except:
        labels_subclasses_train_valid = y

    # Enable anomaly detection
    torch.autograd.set_detect_anomaly(True)

    # K-fold Cross Validation model evaluation
    for fold, (train_index, valid_index) in enumerate(kfold.split(X, labels_subclasses_train_valid)):

        # Print the fold number
        print(f'FOLD {fold}')
        print('--------------------------------')

        # Get the training and validation data for the current fold
        X_train_fold, X_valid_fold = X.iloc[train_index], X.iloc[valid_index]
        y_train_fold, y_valid_fold = y.iloc[train_index], y.iloc[valid_index]

        # Handle oversampling based on the sampling mode
        if Sampling_mode == 'SMOTE' or Sampling_mode == 'EMD_SMOTE':
            print("Oversampling SMOTE has been choice!")
            from imblearn.over_sampling import SMOTE

            try:
                X_train_fold = X_train_fold.drop(['Classe'], axis=1)
                X_valid_fold = X_valid_fold.drop(['Classe'], axis=1)
            except:
                pass

            sm = SMOTE(random_state=_seed) # Only apply SMOTE to the training dataset
            X_train_fold, y_train_fold = sm.fit_resample(X_train_fold, y_train_fold)
        elif Sampling_mode == 'GAN':
            print("Oversampling GAN has been choice!")
            X_train_fold = pd.concat([X_train_fold, GANDataset.loc[:,:'Classe']])
            df_new = pd.DataFrame({'y': np.ones(num_samples_aug, dtype=int)})
            y_train_fold = pd.concat([y_train_fold, df_new], ignore_index=True)
        elif Sampling_mode == 'None':
            print("None Oversampling mode has been chosen!")
        else:
            raise Exception("Sorry, not implemented yet!")

        try:
            X_train_fold = X_train_fold.drop(['Classe'], axis=1)
            X_valid_fold = X_valid_fold.drop(['Classe'], axis=1)
        except:
            pass
        
        # Set normalization method
        NORM = snv()  # Options: None, derivative(), sav_gol(), StandardScaler()
        
        # Preprocess the data
        X_train_fold, NORM = preprocessing(X_train_fold, NORM, _normalization, _train=True)
        X_valid_fold, _ = preprocessing(X_valid_fold, NORM, _normalization, _train=False)

        if conv:
            # Expand dimensions for convolutional input
            X_train_fold = np.expand_dims(X_train_fold, 1)
            X_valid_fold = np.expand_dims(X_valid_fold, 1)

         # Prepare data loaders
        train_dataloader, train_size = prepare_data_loader(X_train_fold, y_train_fold, batch_size=_batch_size, shuffle=True)
        valid_dataloader, valid_size = prepare_data_loader(X_valid_fold, y_valid_fold, batch_size=_batch_size, shuffle=False)
            
        # Set input dimensions for attention layer based on uncertainty handling
        if _incerteza:
            input_attention_dims = _latent_dims
        else:
            input_attention_dims = _latent_dims * 2

        # Initialize attention layer if needed
        if _attention:
            ATT_Layer = AttentionLayer(input_attention_dims, _latent_dims).to(device)
        else:
            ATT_Layer = None

         # Initialize the VAE network
        VAE_network = VAE(latent_dims=_latent_dims).to(device)

        # Initialize the classification network if classification is enabled
        if _classify:
            CLASS_network = ClassifyingNetwork(num_ftrs=input_attention_dims).to(device)
        else:
            CLASS_network = None

        # Define model paths for saving/loading
        model_path = './model/'
        vae_model_path = model_path + f'Kfold_results/{Sampling_mode}/{_data}_MS_{model_name}_{_normalization}_{_augmentation}_Fold_{str(fold)}.pth'
        
        # Load pre-trained VAE model if it exists
        if os.path.exists(vae_model_path) and _classify:
            print(f'VAE already exists. Loading model {vae_model_path}.')
            VAE_network.load_state_dict(torch.load(vae_model_path, map_location=device))

        # Initialize the joint model
        JOINT_Model = JointModel(VAE_network, CLASS_network, ATT_Layer)

        # Collect parameters for optimization
        opt_parameters = list(VAE_network.parameters())

        if _classify:
            opt_parameters += list(CLASS_network.parameters())
        
        if _attention:
            opt_parameters += list(ATT_Layer.parameters())

        # Set optimizer and scheduler
        optimizer = torch.optim.Adam(opt_parameters, weight_decay=1e-3, lr=_lr)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=_sched_factor, min_lr=_sched_min_lr, patience=_sched_patience, verbose=True)

        # Set loss criteria
        MSE_criterion = nn.MSELoss(reduction='sum')

        if _classify:
            if _set_loss == "focal_loss":
                print('Focal Loss Chosen!')
                CLASS_criterion = FocalLoss(alpha=0.25, gamma=2.0, reduction='sum')
            elif _set_loss == "adaptative_focal_loss":
                print('Adaptative Focal Loss Chosen!')
                CLASS_criterion = AdaptativeFocalLoss(alpha=0.25, gamma=2.0, reduction='sum')
            elif _set_loss == "cross_entropy_loss":
                print('Cross Entropy Loss Chosen!')
                CLASS_criterion = nn.CrossEntropyLoss(reduction='sum')
            else:
                raise NotImplementedError(f"Invalid Loss: {_set_loss}")
        
        # Train the model
        print('Training VAE')
        mean_train_loss = []
        mean_valid_loss = []
        curr_loss = 0
        prev_loss = np.inf  # Initialize previous loss to infinity for comparison
        train_acc = 0.0
        valid_acc = 0.0
        limit_stop = 20  # Early stopping limit
        train_loss_graph = []
        valid_loss_graph = []
        train_acc_graph = []
        valid_acc_graph = []

        # Initialize tensors to store features and labels for training and validation
        x_train_ftrs = torch.Tensor([]).to(device)
        y_train_ftrs = torch.Tensor([]).to(device)
        x_valid_ftrs = torch.Tensor([]).to(device)
        y_valid_ftrs = torch.Tensor([]).to(device)

        # Training loop
        for epoch in range(_epochs):
            train_loss = []
            valid_loss = []
            running_loss = 0.0
            running_kld = 0.0
            running_recon = 0.0
            running_corrects = 0

            # Beta for Adaptative Focal Loss
            beta = torch.Tensor([epoch+1])

            # Iterate over training data
            for i, train_data in enumerate(train_dataloader):

                inputs, labels = train_data
                
                if conv:
                    inputs = inputs.to(device)
                else:
                    inputs = inputs.to(device).squeeze(1)

                labels = labels.type(dtype=torch.LongTensor)
                labels = labels.to(device)

                try:
                    labels = labels.squeeze(1)
                except:
                    pass

                optimizer.zero_grad()

                # Variational AutoEncoder NN
                inputs_pred, mean, logvar = VAE_network.forward(inputs, encoder=False, decoder=True)
                
                # Compute reconstruction loss (MSE) and KL divergence
                MSE_loss = MSE_criterion(inputs_pred, inputs)
                KLD_loss = - 0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
                VAE_loss = (_gamma_1 * MSE_loss) + (_gamma_2 * KLD_loss)
                
                if _classify:
                    # Forward pass through joint model
                    outputs = JOINT_Model(inputs, attention=_attention, incerteza=_incerteza, method=_att_method)
                    
                    _, y_pred = torch.max(outputs, 1)

                    if _set_loss == "adaptative_focal_loss":
                        CLASS_loss = CLASS_criterion(outputs, labels, beta)
                        beta += 1
                    else:
                        CLASS_loss = CLASS_criterion(outputs, labels)

                    running_corrects += torch.sum(y_pred == labels.data)
                    t_loss = VAE_loss + (_gamma_3 * CLASS_loss)
                else:
                    t_loss =  VAE_loss

                # Accumulate losses
                running_recon += _gamma_1 * MSE_loss.item() * inputs.size(0)
                running_kld += _gamma_2 * KLD_loss.item() * inputs.size(0)
                running_loss += t_loss.item() * inputs.size(0)

                t_loss.backward()
                optimizer.step()
        
            if _classify:
                train_acc = running_corrects.double() / train_size['y_size'][0]
                train_acc_graph.append(train_acc)

            epoch_kld = running_kld / train_size['y_size'][0]
            epoch_recon = running_recon / train_size['y_size'][0]
            epoch_loss = running_loss / train_size['y_size'][0]
            train_loss_graph.append(epoch_loss)

            # Validation step
            with torch.no_grad():
                valid_running_loss = 0.0
                valid_running_kld = 0.0
                valid_running_recon = 0.0
                valid_running_corrects = 0

                for i, valid_data in enumerate(valid_dataloader):

                    inputs, labels = valid_data
                    
                    if conv:
                        inputs = inputs.to(device)
                    else:
                        inputs = inputs.to(device).squeeze(1)
                    
                    labels = labels.type(dtype=torch.LongTensor)
                    labels = labels.to(device)
                    
                    try:
                        labels = labels.squeeze(1)
                    except:
                        pass


                    # Variational AutoEncoder NN
                    x_pred, mean, logvar = VAE_network.forward(inputs, encoder=False, decoder=True)

                    MSE_loss = MSE_criterion(x_pred, inputs)
                    KLD_loss = - 0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
                    TVAE_loss = (_gamma_1 * MSE_loss) + (_gamma_2 * KLD_loss)

                    if _classify:
                        outputs = JOINT_Model(inputs, attention=_attention, incerteza=_incerteza, method=_att_method)
                            
                        _, y_pred = torch.max(outputs, 1)
                        
                        if _set_loss == "adaptative_focal_loss":
                            CLASS_loss = CLASS_criterion(outputs, labels, beta)
                        else:
                            CLASS_loss = CLASS_criterion(outputs, labels)

                        valid_running_corrects += torch.sum(y_pred == labels.data)
                        loss = TVAE_loss + (_gamma_3 * CLASS_loss)
                    else:
                        loss = TVAE_loss

                    valid_running_loss += loss.item() * inputs.size(0)
                    valid_running_recon += _gamma_1 * MSE_loss.item() * inputs.size(0)
                    valid_running_kld += _gamma_2 * KLD_loss.item() * inputs.size(0)

                if epoch % 10 == 0 and not _classify:
                    plot_gallery([inputs.detach().cpu(), x_pred.detach().cpu()], epoch, fold, model_name, 1, 2, all_plot=True)  

                valid_epoch_loss = valid_running_loss / valid_size['y_size'][0]
                valid_epoch_kld = valid_running_kld / valid_size['y_size'][0]
                valid_epoch_recon = valid_running_recon / valid_size['y_size'][0]
                valid_loss_graph.append(valid_epoch_loss)

                scheduler.step(valid_epoch_loss)
                scheduler.get_last_lr()
                curr_loss = valid_epoch_loss

            if _classify:
                valid_acc = valid_running_corrects.double() / valid_size['y_size'][0]
                valid_acc_graph.append(valid_acc)
                
                results[fold] = 100.0 * valid_acc
                print(f'Epoch: {epoch}')
                print(f'Train loss: {epoch_loss}, Train Acc: {train_acc*100:.2f}')
                print(f'Valid loss: {valid_epoch_loss}, Valid  Acc: {valid_acc*100:.2f}')
            else:
                print(f'Epoch: {epoch} | Train loss: {epoch_loss} | Train KLD: {epoch_kld} | Train Recon Loss:{epoch_recon}')
                print(f'Epoch: {epoch} | Valid loss: {valid_epoch_loss} | Valid KLD: {valid_epoch_kld} | Valid Recon Loss:{valid_epoch_recon}')
            
            # Check if the current loss is less than the best so far
            if curr_loss < prev_loss:
                best_epoch = epoch
                print(f"best_epoch: {best_epoch}")
                if _classify: 
                    if _attention:
                        saved_model = {
                            'VAE': VAE_network.state_dict(),
                            'CLASS': CLASS_network.state_dict(),
                            'ATT': ATT_Layer.state_dict(),
                        }
                    else:
                        saved_model = {
                            'VAE': VAE_network.state_dict(),
                            'CLASS': CLASS_network.state_dict(),
                        }
                else:
                    saved_model = {
                            'VAE': VAE_network.state_dict(),
                        }

            # Early Stopping!
            if curr_loss > prev_loss and epoch > 50:
                print(f"prev_loss: {prev_loss}, curr_loss: {curr_loss}")
                trigger_times += 1
                print(f'Times without improved: {trigger_times}')

                if trigger_times >= limit_stop:
                    print(f'[*] Early stopping in Epoch: {epoch} !')
                    print('Saving Model at epoch {}'.format(epoch+1))
                    break
            else:
                trigger_times = 0
                prev_loss = curr_loss

        # Plot and save loss and accuracy data if required
        if _plotLoss:
            if _classify:
                loss_acc_data = {
                    "train_loss": train_loss_graph,
                    "valid_loss": valid_loss_graph,
                    "train_acc": train_acc_graph,
                    "valid_acc": valid_acc_graph
                }
                with open(f'{_data}_step2_loss_acc_data_{model_name}_{_normalization}_Fold_{str(fold)}-A-{_att_method}-{_incerteza}-Loss-{_set_loss}.pkl', 'wb') as f:
                    pickle.dump(loss_acc_data, f)
            else:
                loss_data = {
                    "train_loss": train_loss_graph,
                    "valid_loss": valid_loss_graph
                }
                with open(f'{_data}_step1_loss_data_{model_name}_{_normalization}_Fold_{str(fold)}-A-{_att_method}-{_incerteza}-Loss-{_set_loss}.pkl', 'wb') as f:
                    pickle.dump(loss_data, f)

        # Save models
        if _classify:
            print('Saving VAE and CLASS model!')
            if _attention and _incerteza:
                torch.save(saved_model['VAE'], model_path + f'Kfold_results/{Sampling_mode}/{_data}_MS_VAE_{model_name}_{_normalization}_{_augmentation}_Fold_{str(fold)}-A-{_att_method}-I-{_incerteza}-Loss-{_set_loss}-C.pth')
                torch.save(saved_model['CLASS'], model_path + f'Kfold_results/{Sampling_mode}/{_data}_MS_CLASS_{model_name}_{_normalization}_{_augmentation}_Fold_{str(fold)}-A-{_att_method}-I-{_incerteza}-Loss-{_set_loss}-C.pth')
                torch.save(saved_model['ATT'], model_path + f'Kfold_results/{Sampling_mode}/{_data}_MS_ATT_{model_name}_{_normalization}_{_augmentation}_Fold_{str(fold)}-A-{_att_method}-I-{_incerteza}-Loss-{_set_loss}-C.pth')
            elif _incerteza == True and _attention == False:
                torch.save(saved_model['VAE'], model_path + f'Kfold_results/{Sampling_mode}/{_data}_MS_VAE_{model_name}_{_normalization}_{_augmentation}_Fold_{str(fold)}-I-{_incerteza}-Loss-{_set_loss}-C.pth')
                torch.save(saved_model['CLASS'], model_path + f'Kfold_results/{Sampling_mode}/{_data}_MS_CLASS_{model_name}_{_normalization}_{_augmentation}_Fold_{str(fold)}-I-{_incerteza}-Loss-{_set_loss}-C.pth')
            elif _attention == True and _incerteza == False:
                torch.save(saved_model['VAE'], model_path + f'Kfold_results/{Sampling_mode}/{_data}_MS_VAE_{model_name}_{_normalization}_{_augmentation}_Fold_{str(fold)}-A-{_att_method}-{_incerteza}-Loss-{_set_loss}-C.pth')
                torch.save(saved_model['CLASS'], model_path + f'Kfold_results/{Sampling_mode}/{_data}_MS_CLASS_{model_name}_{_normalization}_{_augmentation}_Fold_{str(fold)}-A-{_att_method}-{_incerteza}-Loss-{_set_loss}-C.pth')
                torch.save(saved_model['ATT'], model_path + f'Kfold_results/{Sampling_mode}/{_data}_MS_ATT_{model_name}_{_normalization}_{_augmentation}_Fold_{str(fold)}-A-{_att_method}-{_incerteza}-Loss-{_set_loss}-C.pth')
            else:
                torch.save(saved_model['VAE'], model_path + f'Kfold_results/{Sampling_mode}/{_data}_MS_VAE_{model_name}_{_normalization}_{_augmentation}_Fold_{str(fold)}-Loss-{_set_loss}-C.pth')
                torch.save(saved_model['CLASS'], model_path + f'Kfold_results/{Sampling_mode}/{_data}_MS_CLASS_{model_name}_{_normalization}_{_augmentation}_Fold_{str(fold)}-Loss-{_set_loss}-C.pth')
        else:
            if _attention:
                torch.save(saved_model['VAE'], model_path + f'Kfold_results/{Sampling_mode}/{_data}_MS_{model_name}_{_normalization}_{_augmentation}_A_{_att_method}_Fold_{str(fold)}.pth')
            else:
                torch.save(saved_model['VAE'], model_path + f'Kfold_results/{Sampling_mode}/{_data}_MS_{model_name}_{_normalization}_{_augmentation}_Fold_{str(fold)}.pth')

    # Print fold results
    print(f'K-FOLD CROSS VALIDATION RESULTS FOR {_k_folds} FOLDS')
    print('--------------------------------')
    sum = 0.0
    for key, value in results.items():
        print(f'Fold {key}: {value} %')
        sum += value
        print(f'Average: {sum/len(results.items())} %')


# ---------------------------------------------------------------------------
# Test-set evaluation of the per-fold VAE + classifier checkpoints.
#
# For every fold the script rebuilds the networks, loads the weights saved by
# the training step, runs the test loader once and stores the metrics in two
# CSV files (one row per fold, plus one row with mean/std over all folds).
#
# NOTE(review): `NORM`, `X_test`, `y_test`, `conv`, `device` and all `_`-prefixed
# settings are expected to exist already in the notebook namespace; `NORM` is
# presumably the normalizer fitted during training -- confirm before running.
# ---------------------------------------------------------------------------
print('Testing VAE+CLASS')

with torch.no_grad():
    # Per-fold metric accumulators (one value appended per fold)
    accuracy = []
    balancedAccuracyScore = []
    recall = []
    precision = []
    f1 = []
    auc = []
    test_loss = []

    # The feature table may or may not still carry the label column
    try:
        X_test_fold = X_test.drop(['Classe'], axis=1)
    except KeyError:
        X_test_fold = X_test.copy()

    # Preprocess test data with the already-fitted normalizer (no re-fitting)
    X_test_fold, _ = preprocessing(X_test_fold, NORM, _normalization, _train=False)

    if conv:
        # Add a channel axis for the convolutional variant
        X_test_fold = np.expand_dims(X_test_fold, 1)

    # Prepare test data loader
    test_dataloader, test_size = prepare_data_loader(X_test_fold, y_test, batch_size=_batch_size, shuffle=False)

    # Same rule the training step uses to size the classifier / attention input;
    # the checkpoints were saved from networks built with this width.
    num_ftrs = _latent_dims if _incerteza else _latent_dims * 2

    model_path = './model/'

    for fold in range(_k_folds):
        # Fresh networks for this fold; weights come from the checkpoints below
        VAE_network = VAE(latent_dims=_latent_dims).to(device)
        CLASS_network = ClassifyingNetwork(num_ftrs=num_ftrs).to(device)
        ATT_Layer = AttentionLayer(num_ftrs, _latent_dims).to(device) if _attention else None

        # Checkpoint names must match, character for character, the ones
        # written at the end of training for the same configuration.
        if _attention and _incerteza:
            config_tag = f'-A-{_att_method}-I-{_incerteza}'
        elif _incerteza == True and _attention == False:
            config_tag = f'-I-{_incerteza}'
        elif _attention == True and _incerteza == False:
            config_tag = f'-A-{_att_method}-{_incerteza}'
        else:
            config_tag = ''

        path_head = model_path + f'Kfold_results/{Sampling_mode}/{_data}_MS_'
        path_tail = f'_{model_name}_{_normalization}_{_augmentation}_Fold_{str(fold)}{config_tag}-Loss-{_set_loss}-C.pth'
        vae_model_path = path_head + 'VAE' + path_tail
        class_model_path = path_head + 'CLASS' + path_tail
        att_model_path = path_head + 'ATT' + path_tail  # only read when _attention is set

        MSE_criterion = nn.MSELoss(reduction='sum')

        # Set the classification criterion
        if _set_loss == "focal_loss":
            print('Focal Loss Chosen!')
            CLASS_criterion = FocalLoss(alpha=0.25, gamma=2.0, reduction='sum')
        elif _set_loss == "adaptative_focal_loss":
            print('Adaptative Focal Loss Chosen!')
            CLASS_criterion = AdaptativeFocalLoss(alpha=0.25, gamma=2.0, reduction='sum')
        elif _set_loss == "cross_entropy_loss":
            print('Cross Entropy Loss Chosen!')
            CLASS_criterion = nn.CrossEntropyLoss(reduction='sum')
        else:
            raise NotImplementedError(f"Invalid Loss: {_set_loss}")

        # Load the pre-trained models if they exist
        if os.path.exists(vae_model_path) and os.path.exists(class_model_path):
            print(f'VAE already exists. Loading model {vae_model_path}.')
            VAE_network.load_state_dict(torch.load(vae_model_path, map_location=device))
            print(f'CLASS already exists. Loading model {class_model_path}.')
            CLASS_network.load_state_dict(torch.load(class_model_path, map_location=device))
            if _attention:
                print(f'ATT Weights already exists. Loading model {att_model_path}.')
                ATT_Layer.load_state_dict(torch.load(att_model_path, map_location=device))
        else:
            # Keep going (best effort), but make the situation visible: the
            # metrics of this fold then come from randomly initialised weights.
            print(f'[!] Checkpoints for fold {fold} not found ({vae_model_path}); evaluating untrained networks.')

        # Always evaluate in inference mode, whether or not weights were loaded
        VAE_network.eval()
        CLASS_network.eval()
        if ATT_Layer is not None:
            ATT_Layer.eval()

        # Initialize the joint model
        JOINT_Model = JointModel(VAE_network, CLASS_network, ATT_Layer)

        test_running_loss = 0.0
        test_running_corrects = 0
        ypredVector = []
        labelsVector = []

        # Test loop
        for inputs, labels in test_dataloader:
            if conv:
                inputs = inputs.to(device)
            else:
                inputs = inputs.to(device).squeeze(1)
            labels = labels.type(dtype=torch.LongTensor)
            labels = labels.to(device)

            # Labels may arrive with a trailing singleton axis; drop it if present
            if labels.dim() > 1:
                labels = labels.squeeze(1)

            # Variational AutoEncoder NN
            x_pred, mean, logvar = VAE_network.forward(inputs, encoder=False, decoder=True)

            # VAE loss of THIS batch (same formula as the validation loop):
            # weighted reconstruction error plus weighted KL divergence.
            TMSE_loss = MSE_criterion(x_pred, inputs)
            TKLD_loss = - 0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
            TVAE_loss = (_gamma_1 * TMSE_loss) + (_gamma_2 * TKLD_loss)

            outputs = JOINT_Model(inputs, attention=_attention, incerteza=_incerteza, method=_att_method)

            _, y_pred = torch.max(outputs, 1)

            if _set_loss == "adaptative_focal_loss":
                TCLASS_loss = CLASS_criterion(outputs, labels, torch.Tensor([1]))
            else:
                TCLASS_loss = CLASS_criterion(outputs, labels)

            test_running_corrects += torch.sum(y_pred == labels.data)

            # Collect plain Python ints so the metric functions get simple lists
            ypredVector.extend(y_pred.detach().cpu().tolist())
            labelsVector.extend(labels.detach().cpu().tolist())

            loss = TVAE_loss + (_gamma_3 * TCLASS_loss)

            test_running_loss += loss.item() * inputs.size(0)

        test_epoch_loss = test_running_loss / test_size['y_size'][0]

        # Calculate evaluation metrics (once per fold, reused below)
        accuracy_fold = accuracy_score(labelsVector, ypredVector)
        balancedAccuracyScore_fold = balanced_accuracy_score(labelsVector, ypredVector)
        recall_fold = recall_score(labelsVector, ypredVector, average='weighted')
        precision_fold = precision_score(labelsVector, ypredVector, average='weighted', zero_division=True)
        f1_fold = f1_score(labelsVector, ypredVector, average='weighted')

        test_loss.append(test_epoch_loss)
        accuracy.append(accuracy_fold)
        balancedAccuracyScore.append(balancedAccuracyScore_fold)
        recall.append(recall_fold)
        precision.append(precision_fold)
        f1.append(f1_fold)

        # Save results per fold
        fold_csv_filename = f'./results_per_fold/{_data}_resultados_{model_name}_{_normalization}_{Sampling_mode}_per_fold.csv'
        fold_results_data = {
            'Fold': fold,
            'Dataset': _dataset_name,
            '_gamma_2': _gamma_2,
            '_gamma_3': _gamma_3,
            '_batch_size': _batch_size,
            '_latent_dims': _latent_dims,
            'Model': f'{model_name}_{_normalization}',
            'Attention': _attention,
            'Attention_Method': _att_method,
            'Incerteza': _incerteza,
            'Sampling Mode': Sampling_mode,
            'Loss Function': _set_loss,
            'Test - loss': test_epoch_loss,
            'Test - Accuracy Score': accuracy_fold,
            'Test - Balanced Accuracy Score': balancedAccuracyScore_fold,
            'Test - Precision Score': precision_fold,
            'Test - Recall Score': recall_fold,
            'Test - F1 Score': f1_fold,
        }
        fold_header = fold_results_data.keys()

        save_results_to_csv(fold_csv_filename, fold_results_data, fold_header)

    # Save general results (mean and standard deviation over the folds)
    general_csv_filename = f'./results/{_data}_resultados_{model_name}_{_normalization}_{Sampling_mode}.csv'
    general_results_data = {
        'Dataset': _dataset_name,
        'Model': f'{model_name}_{_normalization}',
        'Attention': _attention,
        'Attention_Method': _att_method,
        'Incerteza': _incerteza,
        'Sampling Mode': Sampling_mode,
        '_gamma_2': _gamma_2,
        '_gamma_3': _gamma_3,
        '_batch_size': _batch_size,
        '_latent_dims': _latent_dims,
        'Loss Function': _set_loss,
        'loss - mean': np.mean(test_loss),
        'loss - std': np.std(test_loss),
        'Accuracy Score - mean': np.mean(accuracy),
        'Accuracy Score - std': np.std(accuracy),
        'Balanced Accuracy Score - mean': np.mean(balancedAccuracyScore),
        'Balanced Accuracy Score - std': np.std(balancedAccuracyScore),
        'Precision Score - mean': np.mean(precision),
        'Precision Score - std': np.std(precision),
        'Recall Score - mean': np.mean(recall),
        'Recall Score - std': np.std(recall),
        'F1 Score - mean': np.mean(f1),
        'F1 Score - std': np.std(f1),
    }
    general_header = general_results_data.keys()

    save_results_to_csv(general_csv_filename, general_results_data, general_header)

    print(f"Resultados salvos em {general_csv_filename}")


### Script to run all models experiments

This code executes multiple experiments, including:
- Attention mechanism enabled/disabled
- Gaussian neurons enabled/disabled
- Loss function 

Some functions are intentionally redefined in this section, so the experiments can be run right after the Optuna search without re-executing the "Run" section.

In [None]:
def plot_loss(loss_file, step=1):
    """
    Plot the training/validation curves stored in a pickle file and save them as a PDF.

    The figure title and the output file name are built from notebook-level
    variables (`model_name`, `_normalization`, `fold`, `_set_loss`, `_incerteza`,
    `_attention`, `_att_method`), which must be defined before calling this function.
    The figure is written to ``./Loss_Plot/`` (created if missing) and then closed;
    nothing is shown on screen.

    Args:
        loss_file (str): Path to the pickle file containing the curve data.
                         It must hold a dict with the keys 'train_loss' and
                         'valid_loss' (plus 'train_acc' and 'valid_acc' when step is 2).
        step (int): Indicates the step of training. Default is 1.
                    If step is 2, the training and validation accuracy are plotted too.

    Returns:
        None.

    Raises:
        KeyError: If an expected key is missing from the loaded dict.
    """
    # NOTE: pickle can execute arbitrary code while loading; only open files
    # that were written by this notebook.
    with open(loss_file, 'rb') as f:
        loss_dict = pickle.load(f)

    with_accuracy = step == 2

    # Read every series up front so a malformed file fails before a figure is opened
    if with_accuracy:
        train_acc_values = loss_dict['train_acc']
        valid_acc_values = loss_dict['valid_acc']
    train_loss_values = loss_dict['train_loss']
    valid_loss_values = loss_dict['valid_loss']

    # One x position per recorded epoch, starting at 1
    epochs = np.arange(1, len(train_loss_values) + 1)

    # Plot the loss curves
    plt.figure(figsize=(16,10))
    plt.rcParams['legend.fontsize'] = 22
    plt.rcParams.update({'font.size': 25})
    plt.title(f"Loss Function {model_name}_{_normalization}_Fold_{fold}")
    plt.plot(epochs, train_loss_values, label="Train_Loss", linewidth=2, color='Blue')
    plt.plot(epochs, valid_loss_values, label="Valid_Loss", linewidth=2, color='Red')

    # Plot accuracy if step is 2
    if with_accuracy:
        plt.plot(epochs, train_acc_values, label="Train_Accuracy", linewidth=2, color='Green')
        plt.plot(epochs, valid_acc_values, label="Valid_Accuracy", linewidth=2, color='Orange')

    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    # The curves carry labels, but they are only drawn once a legend is requested
    plt.legend()

    # savefig does not create missing folders
    os.makedirs('./Loss_Plot', exist_ok=True)
    plt.savefig(f"./Loss_Plot/Loss_{model_name}_{_set_loss}_I-{_incerteza}_A-{_attention}-{_att_method}_{_normalization}_Fold_{fold}-Step{step}.pdf")
    plt.close()

# Example usage
# loss_file = './Loss_Plot/pkls/loss_acc_data_22PC_BCA_in_GNA_VAE_cross_entropy_loss_I-False_A-False_None_Fold_0.pkl'
# plot_loss(loss_file, step=2)


In [None]:
# Normalization label used for the experiments below; it is embedded in the
# results file name built on the next statement.
_normalization = "SNV"

# Path of the aggregated-results CSV for the current dataset / model /
# normalization / sampling mode. `_data`, `model_name` and `Sampling_mode`
# must already be defined in the notebook namespace.
general_csv_filename = f'./results/{_data}_resultados_{model_name}_{_normalization}_{Sampling_mode}.csv'

# Bare expression as the last line of the cell: the notebook echoes the path
# so it can be checked visually.
general_csv_filename

In [None]:
%%time

# ---- Experiment hyperparameters ----

# Cross-validation and training length
_k_folds = 5  # folds used by the K-fold split
_epochs = 2000  # maximum number of epochs per fold

# Optimizer
_lr = 1e-4  # initial learning rate

# Learning-rate scheduler (reduce on plateau)
_sched_patience = 20  # epochs without improvement before the rate is lowered
_sched_factor = 0.1  # multiplier applied to the rate on each reduction
_sched_min_lr = 1e-6  # floor below which the rate is never reduced

# Losses
_gamma_1 = 1  # weight of the VAE reconstruction term
_set_loss = "cross_entropy_loss"  # one of: "cross_entropy_loss", "adaptative_focal_loss", "focal_loss"

# Output
_plotLoss = True  # whether the loss curves should be plotted


class JointModel(nn.Module):
    """
    Joint model chaining a VAE encoder, an optional attention layer and a classifier.

    Parameters:
    VAE (nn.Module): Encoder; calling it with only the input must return
        the pair ``(mean, std_dev)``.
    ClassifyingNetwork (nn.Module or None): Classifier applied to the feature
        vector. The uncertainty path additionally reads
        ``ClassifyingNetwork.MLPclassify[1].weight`` and ``.bias``.
    AttentionLayer (nn.Module or None): Layer mapping the feature vector to
        attention weights.

    When ``ClassifyingNetwork`` / ``AttentionLayer`` is None the corresponding
    attribute (``classifying_net`` / ``attention_layer``) is not created.
    """

    # Both spellings are accepted: 'Loung' is the historical default of
    # forward(), while the rest of the notebook passes 'Luong'. Before this,
    # 'Luong' matched no branch and the attention weights were discarded.
    _MULTIPLICATIVE_METHODS = ('Luong', 'Loung')
    _ADDITIVE_METHODS = ('Bahdanau',)

    def __init__(self, VAE, ClassifyingNetwork, AttentionLayer):
        super(JointModel, self).__init__()
        self.vae = VAE
        if ClassifyingNetwork is not None:
            self.classifying_net = ClassifyingNetwork
        if AttentionLayer is not None:
            self.attention_layer = AttentionLayer

    def forward(self, x, attention=False, incerteza=False, method='Loung'):
        """
        Forward pass for the JointModel.

        Parameters:
        x (torch.Tensor): Input tensor.
        attention (bool, optional): Apply the attention layer. Defaults to False.
            NOTE(review): as in the original code, attention is only applied
            when ``incerteza`` is also True; confirm whether it should also
            act on the concatenated (mean, std) vector.
        incerteza (bool, optional): Propagate the latent mean/std through a
            resized copy of the classifier's linear layer and feed
            ``expected_sigm_of_norm`` of the result to the classifier,
            instead of the plain concatenation. Defaults to False.
        method (str, optional): 'Luong' (or the legacy spelling 'Loung') for
            multiplicative attention, 'Bahdanau' for additive attention.
            Defaults to 'Loung'.

        Returns:
        torch.Tensor: Output of the classifying network.

        Raises:
        ValueError: If attention is applied with an unknown ``method``.
        """
        # Call the module (not .forward) so registered hooks are honoured.
        mean_vae, std_dev = self.vae(x)

        if incerteza:
            z_mean, z_var = self._propagate_moments(mean_vae, std_dev)
            # expected_sigm_of_norm is defined elsewhere in the notebook.
            # NOTE(review): its second argument receives the propagated
            # variance here - confirm it does not expect a standard deviation.
            ms_vector = expected_sigm_of_norm(z_mean, z_var, method='probit')
            if attention:
                ms_vector = self._apply_attention(ms_vector, method)
        else:
            ms_vector = torch.cat((mean_vae, std_dev), dim=1)

        # Pass the feature vector through the classifying network
        return self.classifying_net(ms_vector)

    @staticmethod
    def _resize_weight(weight, size):
        """Bilinearly resize a 2-D weight matrix to ``(size, size)``."""
        return torch.nn.functional.interpolate(
            weight.unsqueeze(0).unsqueeze(0),
            size=(size, size),
            mode='bilinear',
            align_corners=False,
        ).squeeze(0).squeeze(0)

    def _propagate_moments(self, mean_vae, std_dev):
        """
        Push the latent mean and std through the classifier's linear layer.

        The weight of ``classifying_net.MLPclassify[1]`` is resized to a square
        matrix W matching the latent width, and its bias is replaced by a
        linear ramp from ``bias.min()`` to ``bias.max()``. Returns, per sample,
        ``(W @ mean + bias, W**2 @ std**2)``, each of shape (batch, latent).

        One batched matmul replaces the former per-sample loop, which returned
        wrong values for a batch of one sample because ``squeeze(0)`` removed
        the batch axis.
        """
        linear = self.classifying_net.MLPclassify[1]
        mean_dims = mean_vae.size(1)
        std_dims = std_dev.size(1)

        weight_mean = self._resize_weight(linear.weight, mean_dims)
        if std_dims == mean_dims:
            weight_std = weight_mean
        else:
            weight_std = self._resize_weight(linear.weight, std_dims)

        # .item() detaches the bias on purpose-preserving grounds: the original
        # code built the ramp from Python floats as well. The ramp is created
        # on the input's device instead of relying on a global `device`.
        bias = torch.linspace(
            linear.bias.min().item(),
            linear.bias.max().item(),
            steps=mean_dims,
            device=mean_vae.device,
        )

        z_mean = torch.matmul(mean_vae, weight_mean.t()) + bias
        z_var = torch.matmul(std_dev ** 2, (weight_std ** 2).t())
        return z_mean, z_var

    def _apply_attention(self, ms_vector, method):
        """Combine ``ms_vector`` with its attention weights according to ``method``."""
        attention_weights = self.attention_layer(ms_vector)
        if method in self._MULTIPLICATIVE_METHODS:
            return ms_vector * attention_weights
        if method in self._ADDITIVE_METHODS:
            return ms_vector + attention_weights
        raise ValueError(
            f"Unknown attention method: {method!r} (expected 'Luong' or 'Bahdanau')"
        )

# Run-wide flags for the experiment loop below.
# _classify: when True the classifier is built, trained jointly with the VAE
#            and its loss/accuracy are tracked; when False only the VAE is trained.
# _att_method: forwarded as `method` to JOINT_Model(...) and embedded in the
#            names of the saved model, pickle and plot files.
_classify = True
_att_method = 'Luong'

# Loop through different loss functions
for _set_loss in ["cross_entropy_loss", "focal_loss", "adaptative_focal_loss"]:
    print(f"Loss Running {_set_loss}!")

    # Loop through different uncertainty settings
    for _incerteza in [False, True]:
        print(f"Running with Incerteza: {_incerteza}!")

        # Loop through different attention settings
        for _attention in [False, True]:
            print(f"Running with Attention: {_attention}!")

            # Define the column names to look for in the CSV
            column_names = ['_gamma_2', '_gamma_3', '_batch_size', '_latent_dims']

            # Initialize the variables with default values
            _gamma_2, _gamma_3, _batch_size, _latent_dims = None, None, None, None

            # Open the CSV file for reading
            with open(general_csv_filename, newline='') as csvfile:
                reader = csv.DictReader(csvfile)

                # Iterate through the rows of the CSV file
                for row in reader:
                    # Check if all the desired columns are present in the current row
                    if all(col in row for col in column_names):
                        # Convert the current row's values to the desired types
                        if Sampling_mode == row['Sampling Mode']:
                            _gamma_2 = float(row['_gamma_2'])
                            _gamma_3 = float(row['_gamma_3'])
                            _batch_size = int(row['_batch_size'])
                            _latent_dims = int(row['_latent_dims'])
                            break  # Stop searching after finding the first match
            
            if _gamma_2 is not None:
                # Do something with the found values
                print(f"_gamma_2: {_gamma_2}, _gamma_3: {_gamma_3}, _batch_size: {_batch_size}, _latent_dims: {_latent_dims}")
            else:
                raise NotImplementedError(f"Invalid, values not found in the CSV dataset.")
                
            # For fold results
            results = {}

            # Define the K-fold Cross Validator
            kfold = StratifiedKFold(n_splits=_k_folds, random_state=_seed, shuffle=True)

            # Start print
            print('--------------------------------')

            print("Normal has been choice!")
            X = pd.concat([X_train, X_valid])
            y = pd.concat([y_train, y_valid])

            # Get subclass labels for training and validation data
            try:
                labels_subclasses_train_valid = pd.Categorical(X['Classe']).codes
            except:
                labels_subclasses_train_valid = y


            torch.autograd.set_detect_anomaly(True)

            # K-fold Cross Validation model evaluation
            for fold, (train_index, valid_index) in enumerate(kfold.split(X, labels_subclasses_train_valid)):

                # Print
                print(f'FOLD {fold}')
                print('--------------------------------')

                # Get the training and validation data for the current fold
                X_train_fold, X_valid_fold = X.iloc[train_index], X.iloc[valid_index]
                y_train_fold, y_valid_fold = y.iloc[train_index], y.iloc[valid_index]

                if Sampling_mode == 'None':
                    print("None Oversampling mode has been chosen!")
                else:
                    raise Exception("Sorry, not implemented yet!")

                try:
                    X_train_fold = X_train_fold.drop(['Classe'], axis=1)
                    X_valid_fold = X_valid_fold.drop(['Classe'], axis=1)
                except:
                    pass
                
                # Set normalization method
                NORM = snv()  # Options: None, derivative(), sav_gol(), StandardScaler()
        
                # Preprocess the data
                X_train_fold, NORM = preprocessing(X_train_fold, NORM, _normalization, _train=True)
                X_valid_fold, _ = preprocessing(X_valid_fold, NORM, _normalization, _train=False)

                if conv:
                    X_train_fold = np.expand_dims(X_train_fold, 1)
                    X_valid_fold = np.expand_dims(X_valid_fold, 1)

                train_dataloader, train_size = prepare_data_loader(X_train_fold, y_train_fold, batch_size=_batch_size, shuffle=True)
                valid_dataloader, valid_size = prepare_data_loader(X_valid_fold, y_valid_fold, batch_size=_batch_size, shuffle=False)
                    
                if _incerteza:
                    input_attention_dims = _latent_dims
                else:
                    input_attention_dims = _latent_dims*2

                if _attention:
                    ATT_Layer = AttentionLayer(input_attention_dims, _latent_dims).to(device)
                else:
                    ATT_Layer = None

                VAE_network = VAE(latent_dims=_latent_dims).to(device)

                if _classify:
                    CLASS_network = ClassifyingNetwork(num_ftrs=input_attention_dims).to(device)
                else:
                    CLASS_network = None

                model_path = './model/'
                vae_model_path = model_path + f'Kfold_results/{Sampling_mode}/{_data}_MS_{model_name}_{_normalization}_{_augmentation}_Fold_{str(fold)}.pth'
                
                if os.path.exists(vae_model_path) and _classify:
                    print(f'VAE already exists. Loading model {vae_model_path}.')
                    VAE_network.load_state_dict(torch.load(vae_model_path, map_location=device))

                JOINT_Model = JointModel(VAE_network, CLASS_network, ATT_Layer)

                opt_parameters = list(VAE_network.parameters())

                if _classify:
                    opt_parameters += list(CLASS_network.parameters())
                
                if _attention:
                    opt_parameters += list(ATT_Layer.parameters())

                optimizer = torch.optim.Adam(opt_parameters, weight_decay=1e-3, lr=_lr)

                scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=_sched_factor, min_lr=_sched_min_lr, patience=_sched_patience)

                MSE_criterion = nn.MSELoss(reduction='sum')

                if _classify:
                    if _set_loss == "focal_loss":
                        print('Focal Loss Chosen!')
                        CLASS_criterion = FocalLoss(alpha=0.25, gamma=2.0, reduction='sum')
                    elif _set_loss == "adaptative_focal_loss":
                        print('Adaptative Focal Loss Chosen!')
                        CLASS_criterion = AdaptativeFocalLoss(alpha=0.25, gamma=2.0, reduction='sum')
                    elif _set_loss == "cross_entropy_loss":
                        print('Cross Entropy Loss Chosen!')
                        CLASS_criterion = nn.CrossEntropyLoss(reduction='sum')
                    else:
                        raise NotImplementedError(f"Invalid Loss: {_set_loss}")
                
                print('Training VAE')
                mean_train_loss = []
                mean_valid_loss = []
                curr_loss = 0
                prev_loss = np.inf
                train_acc = 0.0
                valid_acc = 0.0
                limit_stop = 20 #100
                train_loss_graph = []
                valid_loss_graph = []
                train_acc_graph = []
                valid_acc_graph = []

                x_train_ftrs = torch.Tensor([]).to(device)
                y_train_ftrs = torch.Tensor([]).to(device)
                x_valid_ftrs = torch.Tensor([]).to(device)
                y_valid_ftrs = torch.Tensor([]).to(device)

                for epoch in range(_epochs):
                    train_loss = []
                    valid_loss = []
                    running_loss = 0.0
                    running_kld = 0.0
                    running_recon = 0.0
                    running_corrects = 0

                    beta = torch.Tensor([epoch+1])
                
                    for i, train_data in enumerate(train_dataloader):

                        inputs, labels = train_data
                        
                        if conv:
                            inputs = inputs.to(device)
                        else:
                            inputs = inputs.to(device).squeeze(1)

                        labels = labels.type(dtype=torch.LongTensor)
                        labels = labels.to(device)

                        try:
                            labels = labels.squeeze(1)
                        except:
                            pass

                        optimizer.zero_grad()

                        # Variational AutoEncoder NN
                        inputs_pred, mean, logvar = VAE_network.forward(inputs, encoder=False, decoder=True)
                        
                        MSE_loss = MSE_criterion(inputs_pred, inputs) # Reconstruction Loss 
                        KLD_loss = - 0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
                        VAE_loss = (_gamma_1 * MSE_loss) + (_gamma_2 * KLD_loss)
                        
                        if _classify:
                            outputs = JOINT_Model(inputs, attention=_attention, incerteza=_incerteza, method=_att_method)
                            
                            _, y_pred = torch.max(outputs, 1)

                            if _set_loss == "adaptative_focal_loss":
                                CLASS_loss = CLASS_criterion(outputs, labels, beta)
                                beta += 1
                            else:
                                CLASS_loss = CLASS_criterion(outputs, labels)

                            running_corrects += torch.sum(y_pred == labels.data)
                            t_loss = VAE_loss + (_gamma_3 * CLASS_loss)
                        else:
                            t_loss =  VAE_loss

                        running_recon += _gamma_1 * MSE_loss.item() * inputs.size(0)
                        running_kld += _gamma_2 * KLD_loss.item() * inputs.size(0)
                        running_loss += t_loss.item() * inputs.size(0)

                        t_loss.backward()
                        optimizer.step()
                
                    if _classify:
                        train_acc = running_corrects.double() / train_size['y_size'][0]
                    

                    epoch_kld = running_kld / train_size['y_size'][0]
                    epoch_recon = running_recon / train_size['y_size'][0]
                    epoch_loss = running_loss / train_size['y_size'][0]
                    train_loss_graph.append(epoch_loss)

                    # print("Valid")
                    with torch.no_grad():
                        valid_running_loss = 0.0
                        valid_running_kld = 0.0
                        valid_running_recon = 0.0
                        valid_running_corrects = 0

                        for i, valid_data in enumerate(valid_dataloader):

                            inputs, labels = valid_data
                            
                            if conv:
                                inputs = inputs.to(device)
                            else:
                                inputs = inputs.to(device).squeeze(1)
                            
                            labels = labels.type(dtype=torch.LongTensor)
                            labels = labels.to(device)
                            
                            try:
                                labels = labels.squeeze(1)
                            except:
                                pass


                            # Variational AutoEncoder NN
                            x_pred, mean, logvar = VAE_network.forward(inputs, encoder=False, decoder=True)
                            MSE_loss = MSE_criterion(x_pred, inputs)
                            KLD_loss = - 0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
                            TVAE_loss = (_gamma_1 * MSE_loss) + (_gamma_2 * KLD_loss)

                            if _classify:
                                outputs = JOINT_Model(inputs, attention=_attention, incerteza=_incerteza, method=_att_method)
                                    
                                _, y_pred = torch.max(outputs, 1)
                                
                                if _set_loss == "adaptative_focal_loss":
                                    CLASS_loss = CLASS_criterion(outputs, labels, beta)
                                else:
                                    CLASS_loss = CLASS_criterion(outputs, labels)

                                valid_running_corrects += torch.sum(y_pred == labels.data)
                                loss = TVAE_loss + (_gamma_3 * CLASS_loss)
                            else:
                                loss = TVAE_loss

                            valid_running_loss += loss.item() * inputs.size(0)
                            valid_running_recon += _gamma_1 * MSE_loss.item() * inputs.size(0)
                            valid_running_kld += _gamma_2 * KLD_loss.item() * inputs.size(0)

                        if epoch % 10 == 0 and not _classify:
                            plot_gallery([inputs.detach().cpu(), x_pred.detach().cpu()], epoch, fold, model_name, 1, 2, all_plot=True)  

                        valid_epoch_loss = valid_running_loss / valid_size['y_size'][0]
                        valid_epoch_kld = valid_running_kld / valid_size['y_size'][0]
                        valid_epoch_recon = valid_running_recon / valid_size['y_size'][0]
                        valid_loss_graph.append(valid_epoch_loss)

                        scheduler.step(valid_epoch_loss)
                        scheduler.get_last_lr()
                        curr_loss = valid_epoch_loss

                    if _classify:
                        valid_acc = valid_running_corrects.double() / valid_size['y_size'][0]
                        results[fold] = 100.0 * valid_acc
                        print(f'Epoch: {epoch}')
                        print(f'Train loss: {epoch_loss}, Train Acc: {train_acc*100:.2f}')
                        print(f'Valid loss: {valid_epoch_loss}, Valid  Acc: {valid_acc*100:.2f}')
                    else:
                        print(f'Epoch: {epoch} | Train loss: {epoch_loss} | Train KLD: {epoch_kld} | Train Recon Loss:{epoch_recon}')
                        print(f'Epoch: {epoch} | Valid loss: {valid_epoch_loss} | Valid KLD: {valid_epoch_kld} | Valid Recon Loss:{valid_epoch_recon}')
                        
                    # Check whether the current loss is lower than the previous one
                    if curr_loss < prev_loss:
                        best_epoch = epoch
                        print(f"best_epoch: {best_epoch}")
                        if _classify: 
                            if _attention:
                                saved_model = {
                                    'VAE': VAE_network.state_dict(),
                                    'CLASS': CLASS_network.state_dict(),
                                    'ATT': ATT_Layer.state_dict(),
                                }
                            else:
                                saved_model = {
                                    'VAE': VAE_network.state_dict(),
                                    'CLASS': CLASS_network.state_dict(),
                                }
                        else:
                            saved_model = {
                                    'VAE': VAE_network.state_dict(),
                                }


                    # Early Stopping!
                    if curr_loss > prev_loss and epoch > 50:
                        print(f"prev_loss: {prev_loss}, curr_loss: {curr_loss}")
                        trigger_times += 1
                        print(f'Times without improved: {trigger_times}')

                        if trigger_times >= limit_stop:
                            print(f'[*] Early stopping in Epoch: {epoch} !')
                            print('Saving Model at epoch {}'.format(epoch+1))
                            break
                    else:
                        trigger_times = 0
                        prev_loss = curr_loss

                if _plotLoss:
                    # Create a dictionary to store loss and accuracy data for training and validation
                    loss_acc_data = {
                        "train_loss": train_loss_graph,  # List of training loss values for each epoch
                        "valid_loss": valid_loss_graph,  # List of validation loss values for each epoch
                        "train_acc": train_acc_graph,    # List of training accuracy values for each epoch
                        "valid_acc": valid_acc_graph     # List of validation accuracy values for each epoch
                    }

                    # Save the loss and accuracy data to a pickle file
                    with open(f'./Loss_Plot/pkls/loss_acc_data_{model_name}_{_set_loss}_I-{_incerteza}_A-{_attention}-{_att_method}_{_normalization}_Fold_{fold}.pkl', 'wb') as f:
                        pickle.dump(loss_acc_data, f)

                    # Extract the loss values for plotting
                    train_loss_values = loss_acc_data['train_loss']
                    valid_loss_values = loss_acc_data['valid_loss']

                    # Create an array of epoch numbers for the x-axis
                    epochs = np.arange(1, len(loss_acc_data['train_loss']) + 1)

                    # Plot the loss function
                    plt.figure(figsize=(16, 10))  # Set the figure size
                    plt.rcParams['legend.fontsize'] = 22  # Set the legend font size
                    plt.rcParams.update({'font.size': 25})  # Update the font size
                    plt.title(f"Loss Function {model_name}_{_normalization}_Fold_{fold}")  # Set the plot title
                    plt.plot(epochs, train_loss_values, label="Train_Loss", linewidth=2, color='Blue')  # Plot training loss
                    plt.plot(epochs, valid_loss_values, label="Valid_Loss", linewidth=2, color='Red')  # Plot validation loss
                    plt.xlabel('Epoch')  # Set the x-axis label
                    plt.ylabel('Loss')  # Set the y-axis label
                    plt.legend()  # Display the legend
                    plt.savefig(f"./Loss_Plot/Loss_{model_name}_{_set_loss}_I-{_incerteza}_A-{_attention}-{_att_method}_{_normalization}_Fold_{fold}.pdf")  # Save the plot to a PDF file
                    plt.close()  # Close the plot


                if _classify:
                    print('Saving VAE and CLASS model!')
                    if _attention == True and _incerteza == True:
                        torch.save(saved_model['VAE'], model_path + f'Kfold_results/{Sampling_mode}/{_data}_MS_VAE_{model_name}_{_normalization}_{_augmentation}_Fold_{str(fold)}-A-{_att_method}-I-{_incerteza}-Loss-{_set_loss}-C.pth')
                        torch.save(saved_model['CLASS'], model_path + f'Kfold_results/{Sampling_mode}/{_data}_MS_CLASS_{model_name}_{_normalization}_{_augmentation}_Fold_{str(fold)}-A-{_att_method}-I-{_incerteza}-Loss-{_set_loss}-C.pth')
                        torch.save(saved_model['ATT'], model_path + f'Kfold_results/{Sampling_mode}/{_data}_MS_ATT_{model_name}_{_normalization}_{_augmentation}_Fold_{str(fold)}-A-{_att_method}-I-{_incerteza}-Loss-{_set_loss}-C.pth')
                    elif _incerteza == True and _attention == False:
                        torch.save(saved_model['VAE'], model_path + f'Kfold_results/{Sampling_mode}/{_data}_MS_VAE_{model_name}_{_normalization}_{_augmentation}_Fold_{str(fold)}-I-{_incerteza}-Loss-{_set_loss}-C.pth')
                        torch.save(saved_model['CLASS'], model_path + f'Kfold_results/{Sampling_mode}/{_data}_MS_CLASS_{model_name}_{_normalization}_{_augmentation}_Fold_{str(fold)}-I-{_incerteza}-Loss-{_set_loss}-C.pth')
                    elif _attention == True and _incerteza == False:
                        torch.save(saved_model['VAE'], model_path + f'Kfold_results/{Sampling_mode}/{_data}_MS_VAE_{model_name}_{_normalization}_{_augmentation}_Fold_{str(fold)}-A-{_att_method}-Loss-{_set_loss}-C.pth')
                        torch.save(saved_model['CLASS'], model_path + f'Kfold_results/{Sampling_mode}/{_data}_MS_CLASS_{model_name}_{_normalization}_{_augmentation}_Fold_{str(fold)}-A-{_att_method}-Loss-{_set_loss}-C.pth')
                        torch.save(saved_model['ATT'], model_path + f'Kfold_results/{Sampling_mode}/{_data}_MS_ATT_{model_name}_{_normalization}_{_augmentation}_Fold_{str(fold)}-A-{_att_method}-Loss-{_set_loss}-C.pth')
                    else:
                        torch.save(saved_model['VAE'], model_path + f'Kfold_results/{Sampling_mode}/{_data}_MS_VAE_{model_name}_{_normalization}_{_augmentation}_Fold_{str(fold)}-Loss-{_set_loss}-C.pth')
                        torch.save(saved_model['CLASS'], model_path + f'Kfold_results/{Sampling_mode}/{_data}_MS_CLASS_{model_name}_{_normalization}_{_augmentation}_Fold_{str(fold)}-Loss-{_set_loss}-C.pth')
                else:
                    if _attention:
                        torch.save(saved_model['VAE'], model_path + f'Kfold_results/{Sampling_mode}/{_data}_MS_{model_name}_{_normalization}_{_augmentation}_A-{_att_method}_Fold_{str(fold)}.pth')
                    else:
                        torch.save(saved_model['VAE'], model_path + f'Kfold_results/{Sampling_mode}/{_data}_MS_{model_name}_{_normalization}_{_augmentation}_Fold_{str(fold)}.pth')

            # Print fold results
            print(f'K-FOLD CROSS VALIDATION RESULTS FOR {_k_folds} FOLDS')
            print('--------------------------------')
            sum = 0.0
            for key, value in results.items():
                print(f'Fold {key}: {value} %')
                sum += value
                print(f'Average: {sum/len(results.items())} %')

            print('Testing VAE+CLASS')

            with torch.no_grad():

                accuracy = []
                balancedAccuracyScore = []
                recall = []
                precision = []
                f1 = []
                auc = []
                test_loss = []

                try:
                    X_test_fold = X_test.drop(['Classe'], axis=1)
                except:
                    X_test_fold = X_test.copy()

                X_test_fold, _ = preprocessing(X_test_fold, NORM, _normalization, _train=False)

                if conv:
                    X_test_fold = np.expand_dims(X_test_fold, 1)

                test_dataloader, test_size = prepare_data_loader(X_test_fold, y_test, batch_size=_batch_size, shuffle=False)
                
                for fold in range(_k_folds):
                    
                    VAE_network = VAE(latent_dims=_latent_dims).to(device)
                        
                                        
                    if _incerteza:
                        input_attention_dims = _latent_dims
                    else:
                        input_attention_dims = _latent_dims*2
                    
                    CLASS_network = ClassifyingNetwork(num_ftrs=input_attention_dims).to(device)
                    
                    model_path = './model/'
                    if _attention == True and _incerteza == True:
                        vae_model_path = model_path + f'Kfold_results/{Sampling_mode}/{_data}_MS_VAE_{model_name}_{_normalization}_{_augmentation}_Fold_{str(fold)}-A-{_att_method}-I-{_incerteza}-Loss-{_set_loss}-C.pth'
                        class_model_path = model_path + f'Kfold_results/{Sampling_mode}/{_data}_MS_CLASS_{model_name}_{_normalization}_{_augmentation}_Fold_{str(fold)}-A-{_att_method}-I-{_incerteza}-Loss-{_set_loss}-C.pth'
                        att_model_path = model_path + f'Kfold_results/{Sampling_mode}/{_data}_MS_ATT_{model_name}_{_normalization}_{_augmentation}_Fold_{str(fold)}-A-{_att_method}-I-{_incerteza}-Loss-{_set_loss}-C.pth'
                    elif _incerteza == True and  _attention == False:
                        vae_model_path = model_path + f'Kfold_results/{Sampling_mode}/{_data}_MS_VAE_{model_name}_{_normalization}_{_augmentation}_Fold_{str(fold)}-I-{_incerteza}-Loss-{_set_loss}-C.pth'
                        class_model_path = model_path + f'Kfold_results/{Sampling_mode}/{_data}_MS_CLASS_{model_name}_{_normalization}_{_augmentation}_Fold_{str(fold)}-I-{_incerteza}-Loss-{_set_loss}-C.pth'
                    elif _attention == True and _incerteza == False:
                        vae_model_path = model_path + f'Kfold_results/{Sampling_mode}/{_data}_MS_VAE_{model_name}_{_normalization}_{_augmentation}_Fold_{str(fold)}-A-{_att_method}-Loss-{_set_loss}-C.pth'
                        class_model_path = model_path + f'Kfold_results/{Sampling_mode}/{_data}_MS_CLASS_{model_name}_{_normalization}_{_augmentation}_Fold_{str(fold)}-A-{_att_method}-Loss-{_set_loss}-C.pth'
                        att_model_path = model_path + f'Kfold_results/{Sampling_mode}/{_data}_MS_ATT_{model_name}_{_normalization}_{_augmentation}_Fold_{str(fold)}-A-{_att_method}-Loss-{_set_loss}-C.pth'
                    else:
                        vae_model_path = model_path + f'Kfold_results/{Sampling_mode}/{_data}_MS_VAE_{model_name}_{_normalization}_{_augmentation}_Fold_{str(fold)}-Loss-{_set_loss}-C.pth'
                        class_model_path = model_path + f'Kfold_results/{Sampling_mode}/{_data}_MS_CLASS_{model_name}_{_normalization}_{_augmentation}_Fold_{str(fold)}-Loss-{_set_loss}-C.pth'

                    MSE_criterion = nn.MSELoss(reduction='sum')

                    if _set_loss == "focal_loss":
                        print('Focal Loss Chosen!')
                        CLASS_criterion = FocalLoss(alpha=0.25, gamma=2.0, reduction='sum')
                    elif _set_loss == "adaptative_focal_loss":
                        print('Adaptative Focal Loss Chosen!')
                        CLASS_criterion = AdaptativeFocalLoss(alpha=0.25, gamma=2.0, reduction='sum')
                    elif _set_loss == "cross_entropy_loss":
                        print('Cross Entropy Loss Chosen!')
                        CLASS_criterion = nn.CrossEntropyLoss(reduction='sum')
                    else:
                        raise NotImplementedError(f"Invalid Loss: {_set_loss}")

                    if os.path.exists(vae_model_path) and os.path.exists(class_model_path):
                        print(f'VAE already exists. Loading model {vae_model_path}.')
                        VAE_network.load_state_dict(torch.load(vae_model_path, map_location = device))
                        VAE_network.eval()
                        print(f'CLASS already exists. Loading model {class_model_path}.')
                        CLASS_network.load_state_dict(torch.load(class_model_path, map_location = device))
                        CLASS_network.eval()
                        if _attention:
                            print(f'ATT Weights already exists. Loading model {att_model_path}.')
                            ATT_Layer.load_state_dict(torch.load(att_model_path, map_location = device))
                            ATT_Layer.eval()

                    JOINT_Model = JointModel(VAE_network, CLASS_network, ATT_Layer)

                    test_running_loss = 0.0
                    test_running_corrects = 0
                    test_acc = 0
                    ypredVector = []
                    labelsVector = []

                    # Evaluate the loaded models on every test batch, collecting the
                    # predictions and labels used for the per-fold metrics.
                    for inputs, labels in test_dataloader:

                        # For non-convolutional models dimension 1 of the input is squeezed away.
                        if conv:
                            inputs = inputs.to(device)
                        else:
                            inputs = inputs.to(device).squeeze(1)
                        labels = labels.type(dtype=torch.LongTensor)
                        labels = labels.to(device)

                        # Drop dimension 1 of the labels when it exists. An explicit rank
                        # check replaces the former bare `except: pass`, which also hid
                        # unrelated errors.
                        if labels.dim() > 1:
                            labels = labels.squeeze(1)

                        # Variational AutoEncoder NN
                        x_pred, mean, logvar = VAE_network.forward(inputs, encoder=False, decoder=True)

                        # Reconstruction error (summed over the batch, see MSE_criterion).
                        TMSE_loss = MSE_criterion(x_pred, inputs)

                        outputs = JOINT_Model(inputs, attention=_attention, incerteza=_incerteza, method=_att_method)

                        # Predicted class = index of the largest output score.
                        _, y_pred = torch.max(outputs, 1)

                        if _set_loss == "adaptative_focal_loss":
                            TCLASS_loss = CLASS_criterion(outputs, labels, torch.Tensor([1]))
                        else:
                            TCLASS_loss = CLASS_criterion(outputs, labels)

                        # Count hits in the *test* counter initialised before this loop;
                        # the original incremented `valid_running_corrects` here.
                        test_running_corrects += torch.sum(y_pred == labels.data)

                        ypredVector += y_pred.detach().cpu()
                        labelsVector += labels.data.detach().cpu()

                        # NOTE(review): TVAE_loss is not assigned anywhere in this loop
                        # (only TMSE_loss is, and it is not used here). Presumably the
                        # value is left over from an earlier loop in this cell - verify
                        # and compute the test VAE loss explicitly.
                        loss = TVAE_loss + (_gamma_3 * TCLASS_loss)

                        test_running_loss += loss.item() * inputs.size(0)

                        # plot_gallery([inputs.detach().cpu(), x_pred.detach().cpu()], epoch, fold, 1, 2, all_plot=True)  

                    test_epoch_loss = test_running_loss / test_size['y_size'][0]

                    # Calculate the performance metrics for this fold once ...
                    accuracy_fold = accuracy_score(labelsVector, ypredVector)
                    balancedAccuracyScore_fold = balanced_accuracy_score(labelsVector, ypredVector)
                    recall_fold = recall_score(labelsVector, ypredVector, average='weighted')
                    precision_fold = precision_score(labelsVector, ypredVector, average='weighted', zero_division=True)
                    f1_fold = f1_score(labelsVector, ypredVector, average='weighted')

                    # ... and append those same values to the cross-fold lists instead of
                    # recomputing every metric a second time.
                    test_loss.append(test_epoch_loss)
                    accuracy.append(accuracy_fold)
                    balancedAccuracyScore.append(balancedAccuracyScore_fold)
                    recall.append(recall_fold)
                    precision.append(precision_fold)
                    f1.append(f1_fold)

                    # Plot the confusion matrix
                    plt.figure(figsize=(16, 10))  # Set the figure size
                    plt.rcParams['legend.fontsize'] = 22  # Set the legend font size
                    plt.rcParams.update({'font.size': 25})  # Update the font size

                    # Calculate the confusion matrix
                    conf_matrix = confusion_matrix(labelsVector, ypredVector)

                    # Convert the confusion matrix to a DataFrame for better readability with seaborn
                    DetaFrame_cm = pd.DataFrame(conf_matrix)

                    # Plot the confusion matrix using seaborn
                    sns.heatmap(DetaFrame_cm, annot=True, xticklabels=['BCA True', 'BCA False'], yticklabels=['BCA True', 'BCA False'], fmt='d')
                    plt.title(f"Confusion Matrix {model_name}_{_normalization}_Fold_{fold}")  # Set the plot title

                    # Save the confusion matrix plot to a PDF file
                    plt.savefig(f"./Confusion_Matrix/Folds/Confusion_Matrix_{model_name}_{_set_loss}_I-{_incerteza}_A-{_attention}-{_att_method}_{_normalization}_Fold_{fold}.pdf")

                    # plt.show()  # Uncomment this line to display the plot if running interactively
                    plt.close()  # Close the plot

                    # Save results per fold
                    fold_csv_filename = f'./results_per_fold/{_data}_resultados_{model_name}_{_normalization}_{Sampling_mode}_per_fold.csv'
                    # One CSV row describing this fold: run configuration followed by the
                    # test metrics. Key order defines the CSV column order.
                    fold_results_data = {
                        'Fold': fold,
                        'Dataset': _dataset_name,
                        '_gamma_2': _gamma_2,
                        '_gamma_3': _gamma_3,
                        '_batch_size': _batch_size,
                        '_latent_dims': _latent_dims,
                        'Model': f'{model_name}_{_normalization}',
                        'Attention': _attention,
                        'Attention_Method': _att_method,
                        'Incerteza': _incerteza,
                        'Sampling Mode': Sampling_mode,
                        'Loss Function': _set_loss,
                        'Test - loss': test_epoch_loss,
                        'Test - Accuracy Score': accuracy_fold,
                        'Test - Balanced Accuracy Score': balancedAccuracyScore_fold,
                        # Precision and recall were swapped here, mislabeling both columns.
                        'Test - Precision Score': precision_fold,
                        'Test - Recall Score': recall_fold,
                        'Test - F1 Score': f1_fold,
                    }
                    fold_header = fold_results_data.keys()

                    save_results_to_csv(fold_csv_filename, fold_results_data, fold_header)

                # Save general results
                general_csv_filename = f'./results/{_data}_resultados_{model_name}_{_normalization}_{Sampling_mode}.csv'
                general_results_data = {
                    'Dataset': _dataset_name,
                    'Model': f'{model_name}_{_normalization}',
                    'Attention': _attention,
                    'Attention_Method': _att_method,
                    'Incerteza': _incerteza,
                    'Sampling Mode': Sampling_mode,
                    '_gamma_2': _gamma_2,
                    '_gamma_3': _gamma_3,
                    '_batch_size': _batch_size,
                    '_latent_dims': _latent_dims,
                    'Loss Function': _set_loss,
                    'loss - mean': np.mean(test_loss),
                    'loss - std': np.std(test_loss),
                    'Accuracy Score - mean': np.mean(accuracy),
                    'Accuracy Score - std': np.std(accuracy),
                    'Balanced Accuracy Score - mean': np.mean(balancedAccuracyScore),
                    'Balanced Accuracy Score - std': np.std(balancedAccuracyScore),
                    'Precision Score - mean': np.mean(precision),
                    'Precision Score - std': np.std(precision),
                    'Recall Score - mean': np.mean(recall),
                    'Recall Score - std': np.std(recall),
                    'F1 Score - mean': np.mean(f1),
                    'F1 Score - std': np.std(f1),
                }
                general_header = general_results_data.keys()

                save_results_to_csv(general_csv_filename, general_results_data, general_header)

                print(f"Resultados salvos em ./results/{_data}_resultados_{model_name}_{_normalization}_{Sampling_mode}.csv")

## Plot Reconstructed Data

In [None]:
# Plot one test spectrum next to its VAE reconstruction and save the figure.
# Only the first batch of the test loader is used.
with torch.no_grad():

    # NOTE(review): the metric lists and running counters below are not read
    # anywhere in this cell. They are kept because notebook cells share one
    # namespace and a later cell may rely on them being reset here - confirm
    # and delete if nothing does.
    accuracy = []
    balancedAccuracyScore = []
    recall = []
    precision = []
    f1 = []
    auc = []
    test_loss = []
    test_running_loss = 0.0
    test_running_corrects = 0
    test_acc = 0
    ypredVector = []
    labelsVector = []

    # Drop the 'Classe' column when X_test still carries it.
    try:
        X_test_fold = X_test.drop(['Classe'], axis=1)
    except KeyError:
        X_test_fold = X_test.copy()

    # Apply the configured normalisation in inference mode (_train=False).
    X_test_fold, _ = preprocessing(X_test_fold, NORM, _normalization, _train=False)

    # Convolutional models get an extra axis inserted at position 1.
    if conv:
        X_test_fold = np.expand_dims(X_test_fold, 1)

    test_dataloader, test_size = prepare_data_loader(X_test_fold, y_test, batch_size=_batch_size, shuffle=False)

    VAE_network = VAE(latent_dims=_latent_dims).to(device)

    model_path = './model/'
    _set_loss = 'focal_loss'
    # Checkpoint of the stand-alone VAE. Alternatives used previously:
    #   f'Kfold_results/{Sampling_mode}/{_data}_MS_{model_name}_{_normalization}_{_augmentation}_Fold_4.pth'
    #   f'Kfold_results/{Sampling_mode}/{_data}_MS_VAE_{model_name}_{_normalization}_{_augmentation}_Fold_3-A-{_att_method}-I-{_incerteza}-Loss-{_set_loss}-C.pth'
    vae_model_path = model_path + f'Kfold_results/{Sampling_mode}/NIR-SC-UFES_MS_E_MLP-IR_CNN-1D_SNV_None_Fold_4.pth'
    print(vae_model_path)
    MSE_criterion = nn.MSELoss(reduction='sum')

    if os.path.exists(vae_model_path):
        print(f'VAE already exists. Loading model {vae_model_path}.')
        VAE_network.load_state_dict(torch.load(vae_model_path, map_location = device))
    else:
        # Previously this case was silent, so the figure could show the output
        # of a randomly initialised network without any hint.
        print(f'WARNING: checkpoint {vae_model_path} not found; the reconstruction below comes from untrained weights.')
    # Always run inference in eval mode, whether or not a checkpoint was loaded.
    VAE_network.eval()

    # Wavelength axis (nm) of the spectra, 125 points. Built once, outside the loop.
    eixo_x = [908.1, 914.294, 920.489, 926.683, 932.877, 939.072,
              945.266, 951.46, 957.655, 963.849, 970.044, 976.238,
              982.432, 988.627, 994.821, 1001.015, 1007.21, 1013.404,
              1019.598, 1025.793, 1031.987, 1038.181, 1044.376, 1050.57,
              1056.764, 1062.959, 1069.153, 1075.348, 1081.542, 1087.736,
              1093.931, 1100.125, 1106.319, 1112.514, 1118.708, 1124.902,
              1131.097, 1137.291, 1143.485, 1149.68, 1155.874, 1162.069,
              1168.263, 1174.457, 1180.652, 1186.846, 1193.04, 1199.235,
              1205.429, 1211.623, 1217.818, 1224.012, 1230.206, 1236.401,
              1242.595, 1248.789, 1254.984, 1261.178, 1267.373, 1273.567,
              1279.761, 1285.956, 1292.15, 1298.344, 1304.539, 1310.733,
              1316.927, 1323.122, 1329.316, 1335.51, 1341.705, 1347.899,
              1354.094, 1360.288, 1366.482, 1372.677, 1378.871, 1385.065,
              1391.26, 1397.454, 1403.648, 1409.843, 1416.037, 1422.231,
              1428.426, 1434.62, 1440.814, 1447.009, 1453.203, 1459.398,
              1465.592, 1471.786, 1477.981, 1484.175, 1490.369, 1496.564,
              1502.758, 1508.952, 1515.147, 1521.341, 1527.535, 1533.73,
              1539.924, 1546.119, 1552.313, 1558.507, 1564.702, 1570.896,
              1577.09, 1583.285, 1589.479, 1595.673, 1601.868, 1608.062,
              1614.256, 1620.451, 1626.645, 1632.839, 1639.034, 1645.228,
              1651.423, 1657.617, 1663.811, 1670.006, 1676.2]

    # Labels are not needed for the plot, so they are left untouched.
    for inputs, labels in test_dataloader:

        if conv:
            inputs = inputs.to(device)
        else:
            inputs = inputs.to(device).squeeze(1)

        # Variational AutoEncoder NN
        x_pred, mean, logvar = VAE_network.forward(inputs, encoder=False, decoder=True)

        # Sample shown in the figure: the third one, or the last one when the
        # batch holds fewer than three samples.
        # NOTE(review): the [sample_idx][0] indexing assumes a
        # (batch, channel, points) layout, i.e. conv=True - confirm before
        # using this cell with a non-convolutional model.
        sample_idx = min(2, inputs.size(0) - 1)

        plt.figure(figsize=(16,10))

        plt.rcParams['legend.fontsize'] = 22
        plt.rcParams.update({'font.size': 25})
        plt.plot(eixo_x, inputs[sample_idx][0].squeeze(0).cpu(), linewidth=3, c='g')
        plt.plot(eixo_x, x_pred[sample_idx][0].squeeze(0).cpu(), linewidth=3, c='r')

        plt.legend(['Original', 'Reconstruído'])
        plt.xlabel("Comprimento de onda (nm)")
        plt.ylabel("Absorbância")
        dir_save = '/mnt/hdd/matheusbecali/E-MLP/plots_imgs/Dissertarion/Orig_Recon_NIR-SC-UFES_in_GNA_VAE.pdf'
        plt.savefig(dir_save)
        plt.show()
        # A single batch is enough for the figure.
        break