# Test of looping over $\beta$-VAE to detect and classify outliers

In [1]:
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
from tqdm import tqdm

from pathlib import Path

In [2]:
from magnetics_diagnostic_analysis.project_vae.setting_vae import config



Choosen device = cuda


In [3]:
import torch
from torch import nn

from magnetics_diagnostic_analysis.ml_tools.pytorch_device_selection import print_torch_info
print_torch_info()


Torch version?  2.4.1+cu121
Cuda?           True

GPU number : 2
GPU 0: Tesla T4
GPU 1: Tesla T4


In [4]:
suffix = "vae"

### 1. Create dataset and DataLoader

I took the decision that one data sample will be : all the time values of one shot and for all diagnostics. It will be easy after, to reduce to one diagnostic only (It wouldn't habe been the case if we wanted to use all diagnostics for one timestep -> there reduce to one diagnostic just give us one number and that is to small).

As all shots own different lenghts, we are going to use LSTM unit in entry of our VAE. This LSTM unit is combined with padded sequence and have masking behaviour.

Thus, after the LSTM, we will have a constant size tensor (the LSTM hidden state) that we can use in our VAE.

Consideration:

We want our model to be robust to any different size during testing time.

Thus, we are going to find the max_lenght for each batch size in the dataloader.

And thanks to the two functions `pack_padded_sequence`, `pad_packed_sequence`, the LSTM is aware of the true lenght of each sequence and use masking.

In [8]:
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence

In [9]:
path = Path().absolute().parent.parent / "data/preprocessed/mscred/data_magnetics_mscred_cleaned.nc"
data_all = xr.open_dataset(path)
data_all

In [10]:
def find_seq_length(data: xr.Dataset) -> np.ndarray:
    # Find the length of each sequence in the dataset
    seq_indices = data['shot_index'].values
    return np.bincount(seq_indices)

lengths = find_seq_length(data_all)
lengths

array([2084, 1989, 2230, ..., 3153, 2089, 1947])

In [11]:
class TimeSeriesDataset(Dataset):
    def __init__(self, data: xr.Dataset, n_var_to_keep: int = 4, n_subsample: int = 10):
        # Group data by shot_index
        self.shot_indices = data['shot_index'].values
        self.unique_shots = np.unique(self.shot_indices)
        
        # Precompute sequences for each shot index
        self.sequences = {}
        for shot in self.unique_shots:
            mask = self.shot_indices == shot
            shot_data = []
            for var in data.data_vars:
                if var == 'shot_index':
                    continue
                if data[var].ndim == 1:
                    var_data = data[var].values[mask][:, np.newaxis]
                else:
                    var_data = data[var].values[mask]
                    if var_data.shape[1] > n_var_to_keep:
                        var_data = var_data[:, :n_var_to_keep]
                var_data = var_data[::n_subsample]
                shot_data.append(var_data)
            self.sequences[shot] = np.concatenate(shot_data, axis=1)      # axis=1 => along features dimension
        
        self.lengths = {shot: len(self.sequences[shot]) for shot in self.unique_shots}

    def __len__(self):
        return len(self.unique_shots)
    
    def __getitem__(self, idx):
        shot = self.unique_shots[idx]
        return self.sequences[shot], self.lengths[shot]

In [12]:
def create_datasets(
    data: xr.Dataset,
    set_separation: int = 12000,
) -> tuple[Dataset]:
    """
    Create train, validation and test data loaders from time series data
    
    Args:
        data: xarray Dataset with shot_index variable
        batch_size: batch size for data loaders
        set_separation: boundarie between train and test sets
        device: device to load data on
    
    Returns:
        train_loader, valid_loader, test_loader: DataLoader objects
    """
    # Get shot indices
    shot_indices = data['shot_index'].values
    unique_shots = np.unique(shot_indices)

    # Find the time index where each shot starts
    shot_start_times = {}
    real_test_end = config.DATA_NUMBER
    for shot in unique_shots:
        start_idx = np.where(shot_indices == shot)[0][0]
        shot_start_times[shot] = start_idx
        if start_idx > real_test_end:
            break

    # Split shots into sets based on the set_separation values but also their start time
    train_end= set_separation
    test_shots = [shot for shot in unique_shots[:len(shot_start_times)] if real_test_end > shot_start_times[shot] >= train_end][:-1]    # [:-1] to exclude last incomplete shot
    # Get real start and end times for each split: preserving shots integrity
    real_test_start = shot_start_times[np.min(test_shots)]

    # Create datasets for each split
    train_dataset = TimeSeriesDataset(data.isel(time=slice(0, real_test_start)))
    test_dataset = TimeSeriesDataset(data.isel(time=slice(real_test_start, real_test_end)))
    
    return train_dataset, test_dataset

In [None]:
train_set, test_set = create_datasets(data_all, set_separation=config.SET_SEPARATION)

print("Training set size:", len(train_set))
print("Testing set size:", len(test_set))
print(train_set)

In [13]:
path_train = config.DIR_PREPROCESSED_DATA / f"dataset_magnetics_vae_train.pt"
path_test = config.DIR_PREPROCESSED_DATA / f"dataset_magnetics_vae_test.pt"
if not path_train.exists():
    torch.save(train_set, path_train)
    print(f"Saved dataset to {path_train}")
if not path_test.exists():
    torch.save(test_set, path_test)
    print(f"Saved dataset to {path_test}")


In [14]:
path_train = config.DIR_PREPROCESSED_DATA / f"dataset_magnetics_vae_train.pt"
path_test = config.DIR_PREPROCESSED_DATA / f"dataset_magnetics_vae_test.pt"
train_set = torch.load(path_train)
test_set = torch.load(path_test)

  train_set = torch.load(path_train)
  test_set = torch.load(path_test)


In [15]:
def pad_sequences_smartly(batch):
    """Custom collate function to pad sequences to max length in batch"""
    sequences, lengths = zip(*batch)
    
    # Convert sequences to tensors
    sequence_tensors = [torch.from_numpy(seq).float() for seq in sequences]
    padded_sequences = pad_sequence(
        sequence_tensors, 
        batch_first=True, 
        padding_value=0.0
    )
    length_tensor = torch.tensor(lengths, dtype=torch.long)
    
    return padded_sequences, length_tensor

In [16]:
def create_data_loaders(
    train_dataset: Dataset,
    test_dataset: Dataset,
    batch_size: int = 32,
    collate_fn: callable = None
) -> tuple[DataLoader]:
    """
    Create data loaders for training and testing

    Args:
        train_dataset: Dataset for training
        test_dataset: Dataset for testing
        batch_size: Batch size for data loaders
        collate_fn: Custom collate function

    Returns:
        train_loader, test_loader: DataLoader objects
    """
    # Create DataLoaders with custom collate function
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn,
        drop_last=True
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn,
        drop_last=True
    )
    
    return train_loader, test_loader

In [17]:
train_loader, test_loader = create_data_loaders(
    train_dataset=train_set,
    test_dataset=test_set,
    batch_size=config.BATCH_SIZE,
    collate_fn=pad_sequences_smartly
)

In [18]:
print("Length of train_loader:", len(train_loader))

print(next(iter(train_loader))[0].shape)  # Get the first batch: data
print(next(iter(train_loader))[1].shape)  # Get the first batch: lengths

Length of train_loader: 368
torch.Size([10, 248, 21])
torch.Size([10])


## 2. Model LSTM-$\beta$-VAE implementation

In [19]:
class MultiHeadAttention(nn.Module):
    def __init__(self, input_dim, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        assert embed_dim % num_heads == 0, "Embedding dimension must be divisible by number of heads"
        
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        
        # Projection layers for queries, keys and values
        self.q_proj = nn.Linear(input_dim, embed_dim)
        self.k_proj = nn.Linear(input_dim, embed_dim)
        self.v_proj = nn.Linear(input_dim, embed_dim)
        
        self.output_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
        
        self.scale = np.sqrt(self.head_dim)
        
    def forward(self, query, key, value, key_padding_mask=None):
        """
        query: [batch_size, query_len, input_dim]
        key: [batch_size, key_len, input_dim]
        value: [batch_size, value_len, input_dim]
        key_padding_mask: [batch_size, key_len] (True for padding positions)
        """
        batch_size = query.size(0)
        query_len = query.size(1)
        key_len = key.size(1)
        
        # Project queries, keys and values
        Q = self.q_proj(query)  # [batch_size, query_len, embed_dim]
        K = self.k_proj(key)    # [batch_size, key_len, embed_dim]
        V = self.v_proj(value)  # [batch_size, value_len, embed_dim]
        
        # Reshape for multi-head attention
        Q = Q.view(batch_size, query_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, key_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, key_len, self.num_heads, self.head_dim).transpose(1, 2)
        
        # Calculate attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale  # [batch_size, num_heads, query_len, key_len]
        
        # Apply mask if provided
        if key_padding_mask is not None:
            # Expand mask to [batch_size, num_heads, query_len, key_len]
            mask = key_padding_mask.unsqueeze(1).unsqueeze(2).expand(-1, self.num_heads, query_len, -1)
            scores = scores.masked_fill(mask, float('-inf'))
        
        # Apply softmax to get attention weights
        attn_weights = nn.functional.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        
        # Apply attention to values
        attn_output = torch.matmul(attn_weights, V)  # [batch_size, num_heads, query_len, head_dim]
        
        # Concatenate heads and put through final linear layer
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, query_len, self.embed_dim)
        attn_output = self.output_proj(attn_output)
        
        return attn_output, attn_weights
    
class MultiHeadAttention2(nn.Module):
    def __init__(self, dim, num_heads=8):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.out_proj = nn.Linear(dim, dim)
        
    def forward(self, x):
        batch_size, seq_len, _ = x.shape
        
        # Projections linéaires
        Q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
        K = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
        V = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
        
        # Calcul d'attention par tête
        scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) / math.sqrt(self.head_dim)
        attention = torch.softmax(scores, dim=-1)
        out = torch.einsum('bhqk,bkhd->bqhd', attention, V)
        
        # Concaténation et projection finale
        out = out.contiguous().view(batch_size, seq_len, -1)
        return self.out_proj(out)

In [20]:
class SafeAttentionEncoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers, num_heads=4):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, 
                           batch_first=True, bidirectional=True)
        
        # Attention sur la sortie LSTM
        self.attention = MultiHeadAttention(
            input_dim=hidden_dim * 2,
            embed_dim=hidden_dim * 2, 
            num_heads=num_heads
        )
        
        # Couches pour produire mean et logvar
        self.to_mean = nn.Linear(hidden_dim * 2, latent_dim)
        self.to_logvar = nn.Linear(hidden_dim * 2, latent_dim)
        
    def forward(self, x, lengths):
        # 1. LSTM
        packed_input = pack_padded_sequence(x, lengths.cpu(), batch_first=True, enforce_sorted=False)
        packed_output, _ = self.lstm(packed_input)
        lstm_output, _ = pad_packed_sequence(packed_output, batch_first=True)
        
        # 2. Masque pour l'attention
        mask = torch.arange(x.size(1), device=x.device)[None, :] >= lengths[:, None]
        
        # 3. Auto-attention sur la sortie LSTM
        attn_output, attn_weights = self.attention(
            query=lstm_output, 
            key=lstm_output, 
            value=lstm_output,
            key_padding_mask=mask
        )
        
        # 4. Pooling global (seulement sur les steps valides)
        mask_expanded = mask.unsqueeze(-1)
        summed = torch.sum(attn_output * (~mask_expanded).float(), dim=1)
        context = summed / lengths.float().unsqueeze(-1)
        
        # 5. Transformation en vecteur latent
        z_mean = self.to_mean(context)
        z_logvar = self.to_logvar(context)
        
        return z_mean, z_logvar, attn_weights

class SafeAttentionDecoder(nn.Module):
    def __init__(self, latent_dim, hidden_dim, output_dim, num_layers, num_heads=4):
        super().__init__()
        self.latent_to_hidden = nn.Linear(latent_dim, hidden_dim)
        
        self.lstm = nn.LSTM(hidden_dim, hidden_dim, num_layers, batch_first=True)
        self.output_layer = nn.Linear(hidden_dim, output_dim)
        
        # Pas d'attention croisée avec l'encodeur pour éviter la fuite!
        
    def forward(self, z, lengths):
        batch_size = z.size(0)
        max_length = lengths.max()
        
        # Initialisation à partir du vecteur latent
        h0 = self.latent_to_hidden(z).unsqueeze(0).repeat(self.lstm.num_layers, 1, 1)
        c0 = torch.zeros_like(h0)
        
        # Séquence d'entrée nulle
        input_seq = torch.zeros(batch_size, max_length, self.lstm.input_size, device=z.device)
        
        # LSTM
        packed_input = pack_padded_sequence(input_seq, lengths.cpu(), batch_first=True, enforce_sorted=False)
        packed_output, _ = self.lstm(packed_input, (h0, c0))
        output, _ = pad_packed_sequence(packed_output, batch_first=True)
        
        return self.output_layer(output)

In [21]:
class LengthAwareLSTMEncoder(nn.Module):
    def __init__(self, input_dim: int, hidden_dim: int, latent_dim: int, num_layers: int) -> None:
        super().__init__()
        self.encoder_lstm = nn.LSTM(input_dim, hidden_dim, num_layers, bidirectional=False, batch_first=True)
        self.encoder_linear_mean = nn.Linear(hidden_dim, latent_dim)
        self.encoder_linear_logvar = nn.Linear(hidden_dim, latent_dim)
        #self.dropout = nn.Dropout(p=0.5)

    def forward(self, x: torch.Tensor, lengths: torch.Tensor) -> tuple[torch.Tensor]:
        packed_input = pack_padded_sequence(x, lengths.cpu(), batch_first=True, enforce_sorted=False)
        packed_output, (hidden, _) = self.encoder_lstm(packed_input)
        #output, output_lengths = pad_packed_sequence(packed_output, batch_first=True)

        last_hidden = hidden[-1]
        mean = self.encoder_linear_mean(last_hidden)
        logvar = self.encoder_linear_logvar(last_hidden)
        return mean, logvar

In [22]:
batch_size = 10
input_dim = 24
hidden_dim = 128
latent_dim = 24
num_layers = 2

n_time = 200
length_foo = torch.randint(low=50, high=n_time+1, size=(batch_size,))
seq_length = max(length_foo).item()
x_foo = torch.randn(batch_size, seq_length, input_dim)
encoder_foo = LengthAwareLSTMEncoder(input_dim, hidden_dim, latent_dim, num_layers)
mean_foo, logvar_foo = encoder_foo(x_foo, length_foo)


assert mean_foo.shape == (batch_size, latent_dim)
assert logvar_foo.shape == (batch_size, latent_dim)

In [23]:
class LengthAwareLSTMDecoder(nn.Module):
    def __init__(self, latent_dim: int, hidden_dim: int, output_dim: int, num_layers: int) -> None:
        super().__init__()
        self.decoder_linear_init = nn.Linear(latent_dim, hidden_dim * num_layers * 2)  # For hidden and cell states of each layer
        self.decoder_lstm = nn.LSTM(hidden_dim, hidden_dim, num_layers, bidirectional=False, batch_first=True)
        self.decoder_output_layer = nn.Linear(hidden_dim, output_dim)
        #self.dropout = nn.Dropout(p=0.5)
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

    def forward(self, z: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
        batch_size = z.size(0)

        init_states = self.decoder_linear_init(z)
        h0 = init_states[:, :self.hidden_dim * self.num_layers].reshape(self.num_layers, batch_size, self.hidden_dim)
        c0 = init_states[:, self.hidden_dim * self.num_layers:].reshape(self.num_layers, batch_size, self.hidden_dim)

        max_length = torch.max(lengths)
        input_seq = torch.zeros(batch_size, max_length, self.hidden_dim, device=z.device)

        packed_input = pack_padded_sequence(input_seq, lengths.cpu(), batch_first=True, enforce_sorted=False)
        packed_output, _ = self.decoder_lstm(packed_input, (h0, c0))

        transformed_data = self.decoder_output_layer(packed_output.data)
        output_packed = torch.nn.utils.rnn.PackedSequence(
            data=transformed_data, 
            batch_sizes=packed_output.batch_sizes,
            sorted_indices=packed_output.sorted_indices,
            unsorted_indices=packed_output.unsorted_indices
        )
        output, _ = pad_packed_sequence(output_packed, batch_first=True)
        
        return output

In [24]:
batch_size = 10
seq_length = 200
input_dim = 24
hidden_dim = 128
latent_dim = 24
num_layers = 2

z_foo = torch.randn(batch_size, latent_dim)
decoder_foo = LengthAwareLSTMDecoder(latent_dim, hidden_dim, input_dim, num_layers)
masked_output_foo = decoder_foo(z_foo, length_foo)

assert masked_output_foo.shape == (batch_size, torch.max(length_foo), input_dim)

In [25]:
class LSTMBetaVAE(nn.Module):
    def __init__(self, input_dim: int, hidden_dim: int, latent_dim: int, lstm_num_layers: int) -> None:
        super().__init__()
        self.encoder = LengthAwareLSTMEncoder(input_dim, hidden_dim, latent_dim, lstm_num_layers)
        self.decoder = LengthAwareLSTMDecoder(latent_dim, hidden_dim, input_dim, lstm_num_layers)

    def forward(self, x: torch.Tensor, lengths: torch.Tensor) -> tuple[torch.Tensor]:
        z_mean, z_logvar = self.encoder(x, lengths)
        z = self.reparameterize(z_mean, z_logvar)
        return self.decoder(z, lengths), z_mean, z_logvar

    def reparameterize(self, mean: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mean + eps * std

In [26]:
batch_size = 10
input_dim = 24
hidden_dim = 128
latent_dim = 24
num_layers = 2

n_time = 200
length_foo = torch.randint(low=100, high=n_time+1, size=(batch_size,))
seq_length = max(length_foo).item()
x_foo = torch.randn(batch_size, seq_length, input_dim)
vae_foo = LSTMBetaVAE(input_dim, hidden_dim, latent_dim, num_layers)
output_foo, mean_foo, logvar_foo = vae_foo(x_foo, length_foo)


assert mean_foo.shape == (batch_size, latent_dim)
assert logvar_foo.shape == (batch_size, latent_dim)
assert output_foo.shape == (batch_size, torch.max(length_foo), input_dim)
assert output_foo.shape == x_foo.shape

In [27]:
def vae_loss_function(
    x_recon: torch.Tensor, 
    x: torch.Tensor, 
    z_mean: torch.Tensor, 
    z_logvar: torch.Tensor, 
    lengths: torch.Tensor, 
    beta: float = 1.0
    ) -> tuple[torch.Tensor]:

    batch_size, seq_length, _ = x.shape

    mask = torch.arange(seq_length, device=x_recon.device)[None, :] < lengths[:, None]      # shape [batch_size, max_length]
    mask = mask.unsqueeze(-1).float()                                                       # shape [batch_size, max_length, 1]

    MSE = nn.functional.mse_loss(x_recon, x, reduction='none')
    MSE = (MSE * mask).sum(dim=(1,2))        # Mask application
    num_valid_steps = mask.sum(dim=(1,2))    # Normalizing factor
    MSE = torch.where(num_valid_steps > 0, MSE / num_valid_steps, torch.zeros_like(MSE))

    KLD = -0.5 * torch.sum(1 + z_logvar - z_mean.pow(2) - z_logvar.exp(), dim=1)
    KLD = KLD / batch_size

    TOTAL = torch.mean(MSE + beta * KLD)
    MSE = torch.mean(MSE)
    KLD = torch.mean(KLD)

    return TOTAL, MSE, KLD

In [28]:
batch_size = 10
input_dim = 24
hidden_dim = 128
latent_dim = 24
num_layers = 2

n_time = 200
length_foo = torch.randint(low=100, high=n_time+1, size=(batch_size,))
seq_length = max(length_foo).item()
x_foo = torch.randn(batch_size, seq_length, input_dim)
x_recon_foo = torch.randn(batch_size, seq_length, input_dim)

z_mean_foo = torch.randn(batch_size, latent_dim)
z_logvar_foo = torch.randn(batch_size, latent_dim)

beta = 20.0

loss, loss_mse, loss_kld = vae_loss_function(
    x_foo, x_recon_foo, z_mean_foo, z_logvar_foo, length_foo, beta
)

print(loss)
print(loss_mse)
print(loss_kld)
assert loss.shape == torch.Size([])
assert loss_mse.shape == torch.Size([])
assert loss_kld.shape == torch.Size([])

tensor(91.8865)
tensor(48.0560)
tensor(2.1915)


## 3. Train loop

- On choisit d'éliminer les outliers par rapport à la densité de la reconstruction et non à la clusterisation de l'espace latent.

In [29]:
from torch.utils.data import Subset
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KernelDensity
from sklearn.cluster import DBSCAN

import gc

In [30]:
from magnetics_diagnostic_analysis.ml_tools.train_callbacks import EarlyStopping, LRScheduling, GradientClipping, DropOutScheduling

In [31]:
def vae_reconstruction_error(
    x_recon: torch.Tensor, 
    x: torch.Tensor, 
    lengths: torch.Tensor, 
    ) -> tuple[torch.Tensor]:

    _, seq_length, _ = x.shape

    mask = torch.arange(seq_length, device=x_recon.device)[None, :] < lengths[:, None]      # shape [batch_size, max_length]
    mask = mask.unsqueeze(-1).float()                                                       # shape [batch_size, max_length, 1]

    mse = nn.functional.mse_loss(x_recon, x, reduction='none')
    mse = (mse * mask).sum(dim=(1,2))        # Mask application
    num_valid_steps = mask.sum(dim=(1,2))    # Normalizing factor
    mse = torch.where(num_valid_steps > 0, mse / num_valid_steps, torch.zeros_like(mse))

    return mse

In [32]:
def train_iterative_vae_pipeline(
    train_dataset: Dataset,
    n_iterations: int = 5,
    n_epochs_per_iter: int = 50,
    batch_size: int = 32,
    reconstruction_percentile_threshold: float = 95,
    kde_bandwidth: float = 0.5,
    dbscan_eps: float = 0.5,
    dbscan_min_samples: int = 5,
    device: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
) -> dict:
    
    # Model parameters
    sample_data, _ = train_dataset[0]
    input_dim = sample_data.shape[-1]
    hidden_dim = 256
    latent_dim = 128
    lstm_layers = 2

    # Good and bad health indices initialization
    valid_indices = list(range(len(train_dataset)))
    all_anomaly_indices = np.array([], dtype=int)

    # Model storage
    vae_models = []

    full_loader = DataLoader(
            dataset=train_dataset, 
            batch_size=batch_size, 
            shuffle=True, 
            collate_fn=pad_sequences_smartly, 
            drop_last=False)

    for iteration in range(n_iterations):
        print(f"Iteration {iteration + 1}/{n_iterations}")
        print(f"Training on {len(valid_indices)} samples...")

        # Data SubSet creation
        current_subset = Subset(train_dataset, valid_indices)
        train_loader = DataLoader(
            dataset=current_subset, 
            batch_size=batch_size, 
            shuffle=True, 
            collate_fn=pad_sequences_smartly, 
            drop_last=False)

        # VAE Training
        vae = LSTMBetaVAE(input_dim, hidden_dim, latent_dim, lstm_layers).to(device)
        optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)

        # Training loop
        vae.train()
        print(f"{'Epoch':<40} {'Loss':<20} {'mse':<20} {'kld':<20}")
        for epoch in range(n_epochs_per_iter):
            total_loss = 0
            for batch_data, batch_lengths in tqdm(train_loader, desc=f"Training VAE number {iteration + 1}", leave=False):
                batch_data = batch_data.to(device)
                batch_lengths = batch_lengths.to(device)

                optimizer.zero_grad()
                recon_batch, z_mean, z_logvar = vae(batch_data, batch_lengths)
                loss, mse, kld = vae_loss_function(recon_batch, batch_data, z_mean, z_logvar, batch_lengths, beta=3.0)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
                total_mse, total_kld = mse.item(), kld.item()

            epo, total_loss, total_mse, total_kld = (epoch+1)/n_epochs_per_iter, total_loss/len(train_loader), total_mse/len(train_loader), total_kld/len(train_loader)
            print(f"{epo:<40} {total_loss:<20} {total_mse:<20} {total_kld:<20}")

        # VAE Evaluation
        vae.eval()
        reconstruction_errors = torch.zeros(len(train_dataset))           # or len(current_subset)
        with torch.no_grad():
            i = 0
            for batch_data, batch_lengths in tqdm(full_loader, desc=f"Evaluating VAE number {iteration + 1}", leave=False):
                batch_data = batch_data.to(device)
                batch_lengths = batch_lengths.to(device)
                
                recon_batch, _, _ = vae(batch_data, batch_lengths)
                mse = vae_reconstruction_error(recon_batch, batch_data, batch_lengths)
                reconstruction_errors[i*batch_size : (i+1)*batch_size] = mse
                i += 1

        reconstruction_errors = reconstruction_errors.cpu().numpy()
        # Outlier detection with KDE
        kde = KernelDensity(bandwidth=kde_bandwidth)
        kde.fit(reconstruction_errors.reshape(-1, 1))  # shape [n_samples, 1]
        density_scores = np.exp(kde.score_samples(reconstruction_errors.reshape(-1, 1)))       # shape [n_samples]

        threshold = np.percentile(density_scores, 100 - reconstruction_percentile_threshold)
        new_anomalies = np.where(density_scores < threshold)[0]

        # Update anomaly indices
        all_anomaly_indices = np.unique(np.concatenate([all_anomaly_indices, new_anomalies]))
        valid_indices = list(np.setdiff1d(np.arange(len(train_dataset)), all_anomaly_indices))
        
        vae_models.append(vae.state_dict())
        print(f"New anomalies detected: {len(new_anomalies)}")
        print(f"Total anomalies: {len(all_anomaly_indices)}\n")

        # Delete cache and big variables
        del vae, optimizer, reconstruction_errors, kde
        torch.cuda.empty_cache()
        gc.collect()



    # Last training phase
    print("Training final model...")
    final_subset = Subset(train_dataset, valid_indices)
    final_loader = DataLoader(final_subset, batch_size=batch_size, shuffle=True, collate_fn=pad_sequences_smartly, drop_last=False)

    final_vae = LSTMBetaVAE(input_dim, hidden_dim, latent_dim, lstm_layers).to(device)
    optimizer = torch.optim.Adam(final_vae.parameters(), lr=1e-3)
    
    final_vae.train()
    print(f"{'Epoch':<40} {'Loss':<20} {'mse':<20} {'kld':<20}")
    for epoch in range(n_epochs_per_iter):
        total_loss = 0
        for batch_data, batch_lengths in tqdm(final_loader, desc="Training final VAE", leave=False):
            batch_data = batch_data.to(device)
            batch_lengths = batch_lengths.to(device)

            optimizer.zero_grad()
            recon_batch, z_mean, z_logvar = final_vae(batch_data, batch_lengths)
            loss, mse, kld = vae_loss_function(recon_batch, batch_data, z_mean, z_logvar, batch_lengths)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            total_mse, total_kld = mse.item(), kld.item()

        epo, total_loss, total_mse, total_kld = (epoch+1)/n_epochs_per_iter, total_loss/len(train_loader), total_mse/len(train_loader), total_kld/len(train_loader)
        print(f"{epo:<40} {total_loss:<20} {total_mse:<20} {total_kld:<20}")

    # Latent features for all data
    final_vae.eval()
    with torch.no_grad():
        z_mean_all = []
        for batch_data, batch_lengths in tqdm(full_loader, desc="Extracting latent features", leave=False):
            batch_data = batch_data.to(device)
            batch_lengths = batch_lengths.to(device)

            z_mean, _ = final_vae.encoder(batch_data, batch_lengths)
            z_mean_all.append(z_mean.cpu().numpy())
        
        latent_features = np.concatenate(z_mean_all, axis=0)

    # Final clustering on latent space on all train_dataset, with DBScan
    dbscan = DBSCAN(eps=dbscan_eps, min_samples=dbscan_min_samples)
    clusters = dbscan.fit_predict(latent_features)
    outlier_mask = clusters == -1
    
    return {
        'final_vae': final_vae,
        'latent_features': latent_features,
        'anomaly_indices': all_anomaly_indices,
        'clusters': clusters,
        'outlier_mask': outlier_mask,
        'vae_models': vae_models
    }

In [33]:
n_iterations = 2
n_epochs_per_iter = 2
batch_size = 1
reconstruction_percentile_threshold = 50
kde_bandwidth = 0.5
dbscan_eps = 0.5
dbscan_min_samples = 5
device  = config.DEVICE

results = train_iterative_vae_pipeline(
    train_dataset=train_set,
    n_iterations=n_iterations,
    n_epochs_per_iter=n_epochs_per_iter,
    batch_size=batch_size,
    reconstruction_percentile_threshold=reconstruction_percentile_threshold,
    kde_bandwidth=kde_bandwidth,
    dbscan_eps=dbscan_eps,
    dbscan_min_samples=dbscan_min_samples,
    device=device
)

Iteration 1/2
Training on 3682 samples...
Epoch                                    Loss                 mse                  kld                 


                                                                          

0.5                                      259997949846.5964    77275950.30526887    0.07982567948274545 


                                                                          

1.0                                      259388006326.85715   74225366.97881585    2.474522380839218   


                                                                             

New anomalies detected: 0
Total anomalies: 0

Iteration 2/2
Training on 3682 samples...
Epoch                                    Loss                 mse                  kld                 


                                                                          

0.5                                      259998333782.9093    82269243.51548071    0.0049988540989235  


                                                                          

1.0                                      259383494653.2189    80830245.96197718    0.6236869318856939  


                                                                             

New anomalies detected: 0
Total anomalies: 0

Training final model...
Epoch                                    Loss                 mse                  kld                 


                                                            

OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB. GPU 0 has a total capacity of 14.58 GiB of which 163.12 MiB is free. Process 747003 has 135.45 MiB memory in use. Process 2951200 has 135.51 MiB memory in use. Process 3952432 has 135.51 MiB memory in use. Process 2130804 has 170.70 MiB memory in use. Including non-PyTorch memory, this process has 226.00 MiB memory in use. Of the allocated memory 74.15 MiB is allocated by PyTorch, and 13.85 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
def evaluate_on_test(test_dataset, trained_vae, device):
    test_loader = DataLoader(
        dataset=test_dataset, 
        batch_size=32, 
        shuffle=False, 
        collate_fn=pad_sequences_smartly, 
        drop_last=False
    )
    trained_vae.eval()
    
    reconstruction_errors = []
    latent_features = []
    
    with torch.no_grad():
        for batch_data, batch_lengths in test_loader:
            batch_data = batch_data.to(device)
            batch_lengths = batch_lengths.to(device)
            
            recon_batch, z_mean, _ = trained_vae(batch_data, batch_lengths)
            # Store latent features
            latent_features.append(z_mean.cpu().numpy())

            mse = vae_reconstruction_error(recon_batch, batch_data, batch_lengths)
            # Store reconstruction errors
            reconstruction_errors.extend(mse.cpu().numpy())
    
    return np.array(reconstruction_errors), np.concatenate(latent_features, axis=0)

In [None]:
def train_one_time(
    model: nn.Module,
    train_loader: DataLoader,
    n_epochs: int,
    optimizer: torch.optim.Optimizer,
    device: torch.device
):
    current_data = data.copy()
    anomaly_indices = np.array([], dtype=int)

    reconstruction_error_threshold_percentile = 95


    for epoch in range(n_epochs):
        model.train()
        for batch in train_loader:
            data, lengths = batch
            data = data.to(device)
            lengths = lengths.to(device)

            optimizer.zero_grad()
            output = model(data, lengths)
            loss = criterion(output, data)
            loss.backward()
            optimizer.step()




In [None]:
def train(n_iterations: int):
    for i in range(n_iterations):
        model = VAE()
        train_one_time(model, train_loader, n_epochs, optimizer, device)

In [None]:
from sklearn.ensemble import IsolationForest
