# Variational Autoencoder (VAE) for MNIST Anomaly Detection

This notebook trains a variational autoencoder (VAE) on MNIST for anomaly detection.  
Each image's reconstruction loss serves as its anomaly score, and images whose loss exceeds a threshold are classified as abnormal. We evaluate the detector with ROC and precision-recall (PR) curves and a confusion matrix.

---

### Contents
- Setup and imports
- Model training
- Threshold selection using reconstruction loss
- Evaluation and visualization
- Conclusion


In [None]:
# Setup and Imports
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

from model import VAE_MNIST
from dataloader_generator import train_dl, val_dl, test_dl
from utils import (
    train_model, get_recon_losses_per_image_after_training,
    show_reconstructions, plot_confusion_matrix, plot_roc_pr
)

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


## Hyperparameters and Setup

We set the random seed for reproducibility and define model parameters, training epochs, and loss type.
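`torch.manual_seed` seeds PyTorch's generator on every device, but it does not seed Python's `random` module or NumPy. If `dataloader_generator` or `utils` draw from either (we have not checked), runs can still differ. The following is a minimal sketch of a broader seeding helper; `seed_everything` is a hypothetical name, not part of this project's `utils`. Even with the cuDNN flags set, some GPU operations remain nondeterministic, so bit-exact reproducibility is not guaranteed.

```python
import random

import numpy as np


def seed_everything(seed: int) -> None:
    """Hypothetical helper: seed Python, NumPy and (if installed) PyTorch RNGs."""
    random.seed(seed)
    np.random.seed(seed)
    try:
        import torch
    except ImportError:
        return  # PyTorch not available; Python and NumPy are still seeded
    torch.manual_seed(seed)  # seeds the CPU and all CUDA generators
    # Ask cuDNN for deterministic kernels and disable autotuning, which can pick different kernels per run
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


seed_everything(15)
```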


In [None]:
# Reproducibility and Hyperparameters
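seed = 15
torch.manual_seed(seed)  # seeds the PyTorch RNG on CPU and all CUDA devices

encoding_size = 10       # size of the latent encoding
input_channel = 1        # MNIST images are single-channel (grayscale)
learning_rate = 1e-3
LOSS_TYPE = 'bce'        # reconstruction loss: binary cross-entropy
num_epochs = 300

# Label mapping: 0 = normal, 1 = abnormal
index_to_name = {0: 'normal', 1: 'abnormal'}
name_to_index = {name: idx for idx, name in index_to_name.items()}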


## Initialize the Model, Optimizer, and Scheduler

We instantiate the VAE model, define the Adam optimizer, and set up a cosine annealing learning rate scheduler.
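For reference, `CosineAnnealingLR` (without warm restarts) follows the closed form η_t = η_min + ½(η_max − η_min)(1 + cos(π t / T_max)), where η_max is the initial learning rate and η_min defaults to 0. The following plain-Python sketch evaluates that formula; `cosine_annealing_lr` is a hypothetical helper. It assumes `train_model` calls `scheduler.step()` once per epoch, which this notebook does not show; stepping once per batch would instead reach η_min after only `T_max` batches and then rise again.

```python
import math


def cosine_annealing_lr(epoch: int, base_lr: float, t_max: int, eta_min: float = 0.0) -> float:
    """Closed-form cosine-annealed learning rate at a given epoch (no restarts)."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max))


# Learning rate at a few epochs for this notebook's settings (lr=1e-3, T_max=num_epochs=300)
for epoch in (0, 75, 150, 225, 300):
    print(epoch, f"{cosine_annealing_lr(epoch, 1e-3, 300):.2e}")
```

The rate starts at the base value, halves by the midpoint, and decays smoothly to η_min at the final epoch.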


In [None]:
# Model, optimizer and scheduler setup
model = VAE_MNIST(input_channel, encoding_size, drop_rate=0.1, multiple=4, skip_connect=False).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)


## Train the Model

We train the model for the specified number of epochs, tracking training and validation losses.
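The `utils` module is not shown, so we do not know exactly how `train_model` combines the reconstruction and KL losses it tracks. The standard VAE objective is a reconstruction term plus the KL divergence between the encoder's diagonal Gaussian and a standard normal prior. The following NumPy sketch assumes the encoder outputs `mu` and `logvar`, the decoder outputs sigmoid probabilities, and both terms are summed over pixels and batch; the project may instead average or weight the KL term. `vae_loss_bce` is a hypothetical helper.

```python
import numpy as np


def vae_loss_bce(x, x_recon, mu, logvar, eps=1e-7):
    """Summed BCE reconstruction loss plus analytic KL term for a diagonal Gaussian encoder.

    x, x_recon: pixel intensities in [0, 1] (x_recon = decoder output after a sigmoid).
    mu, logvar: encoder outputs of shape (batch, latent_dim).
    """
    x_recon = np.clip(x_recon, eps, 1 - eps)  # keep log() finite
    recon = -np.sum(x * np.log(x_recon) + (1 - x) * np.log(1 - x_recon))
    # KL(N(mu, exp(logvar)) || N(0, I)), summed over latent dimensions and batch
    kl = -0.5 * np.sum(1 + logvar - mu**2 - np.exp(logvar))
    return recon + kl, recon, kl


# Toy example: 2 binary "images" of 4 pixels, a decoder that predicts 0.5 everywhere, 10-D latent
rng = np.random.default_rng(0)
x = rng.integers(0, 2, size=(2, 4)).astype(float)
total, recon, kl = vae_loss_bce(x, np.full((2, 4), 0.5), mu=np.zeros((2, 10)), logvar=np.zeros((2, 10)))
```

When the posterior equals the prior (`mu = 0`, `logvar = 0`) the KL term vanishes, and a decoder that predicts 0.5 for every pixel costs log 2 per pixel regardless of the input.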


In [None]:
# Train the model
train_loss, val_loss, train_recon_loss, train_kl_loss = train_model(
    model, num_epochs, train_dl, val_dl, test_dl, optimizer,
    loss_type=LOSS_TYPE, scheduler=scheduler, clip_norm=True, max_norm=100.0
)


## Determine Anomaly Detection Threshold

Using the per-image reconstruction losses on the training set (normal digits only), we set the threshold to the mean plus two standard deviations.
Test images whose reconstruction loss exceeds this threshold are classified as abnormal; the rest are classified as normal.
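In symbols, the rule is τ = μ_train + 2σ_train, and an image x is flagged abnormal when its loss L(x) > τ. If the training losses were roughly Gaussian, about 2.3% of normal images would exceed τ, which gives a rough expected false-positive rate. The following sketch uses synthetic losses rather than this notebook's data; `fit_threshold` and `predict_abnormal` are hypothetical helpers. Note that `np.std` computes the population standard deviation (`ddof=0`) while `torch.Tensor.std` defaults to the unbiased estimator; with thousands of training images the difference is negligible.

```python
import numpy as np


def fit_threshold(train_losses, k=2.0):
    """Mean + k standard deviations of per-image reconstruction losses on normal data."""
    train_losses = np.asarray(train_losses, dtype=float)
    return train_losses.mean() + k * train_losses.std()


def predict_abnormal(losses, threshold):
    """Label 1 (abnormal) where the reconstruction loss exceeds the threshold, else 0."""
    return (np.asarray(losses, dtype=float) > threshold).astype(int)


# Synthetic stand-in for per-image training losses (not real MNIST results)
rng = np.random.default_rng(0)
normal_losses = rng.normal(loc=100.0, scale=10.0, size=10_000)
thr = fit_threshold(normal_losses)
fpr = predict_abnormal(normal_losses, thr).mean()
print(f"threshold={thr:.1f}, fraction of normal images flagged={fpr:.3f}")
```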


In [None]:
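model.eval()  # disable dropout before scoring images

# Per-image reconstruction losses on the training set (normal digits) and the test set
train_losses = get_recon_losses_per_image_after_training(model, train_dl, LOSS_TYPE)
test_losses = get_recon_losses_per_image_after_training(model, test_dl, LOSS_TYPE)

# Threshold: mean + 2 standard deviations of the training losses
THRESHOLD = train_losses.mean() + 2 * train_losses.std()

print(f"Anomaly detection threshold set at: {THRESHOLD:.4f}")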


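## Visualize Training Loss Components Over Epochs

We plot the total training and validation losses together with the reconstruction and KL components of the training loss.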


In [None]:
plt.plot(train_loss, label='Total Train Loss')
plt.plot(val_loss, label='Total Val Loss')
plt.plot(train_recon_loss, label='Train Recon Loss')
plt.plot(train_kl_loss, label='Train KL Loss')
plt.title("VAE Loss Components Over Epochs")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.grid(True)
plt.show()


## Visualize Reconstructions of Abnormal Samples

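We collect the test samples labeled abnormal and display their reconstructions. Since the VAE is trained only on normal digits, these reconstructions are expected to be noticeably worse than those of normal digits.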


In [None]:
# Collect every test image labeled abnormal (label 1), batch by batch
abnormal_batches = [x[y == name_to_index['abnormal']] for x, y in test_dl]
abnormal_samples = torch.cat(abnormal_batches)

show_reconstructions(model, dataloader=None, sample_input=abnormal_samples, num_images=20)


## Confusion Matrix on Test Set

Classify each test sample as abnormal if its reconstruction loss exceeds the threshold, then plot the resulting confusion matrix.
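Assuming `plot_confusion_matrix` flags an image as abnormal when its loss exceeds the threshold (its implementation is not shown), the underlying counts can be computed as in the following sketch. Rows are true labels and columns are predicted labels, the same layout as scikit-learn's `confusion_matrix`. The `normalize` option divides each row by its total, one common convention that may differ from what `plot_confusion_matrix` does. `confusion_matrix_from_losses` is a hypothetical helper.

```python
import numpy as np


def confusion_matrix_from_losses(losses, labels, threshold, normalize=False):
    """2x2 matrix: rows = true label (0 normal, 1 abnormal), columns = predicted label."""
    y_true = np.asarray(labels, dtype=int)
    y_pred = (np.asarray(losses, dtype=float) > threshold).astype(int)
    cm = np.zeros((2, 2), dtype=int)
    np.add.at(cm, (y_true, y_pred), 1)  # increment cm[true, pred] once per sample
    if normalize:
        return cm / cm.sum(axis=1, keepdims=True)  # each row sums to 1
    return cm


# Toy example: three normal and three abnormal samples, threshold 3.0
losses = [1.0, 2.0, 5.0, 2.0, 8.0, 9.0]
labels = [0, 0, 0, 1, 1, 1]
cm = confusion_matrix_from_losses(losses, labels, threshold=3.0)
```

Here one normal sample lands above the threshold (a false positive) and one abnormal sample lands below it (a false negative).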


In [None]:
plot_confusion_matrix(
    model, test_dl, threshold=THRESHOLD,
    labels=['normal', 'abnormal'], normalize=False,
    title='Confusion Matrix', loss_type=LOSS_TYPE
)


## ROC and Precision-Recall Curves

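ROC and PR curves sweep the decision threshold over all values, so they measure how well the reconstruction loss ranks abnormal samples above normal ones, independently of the particular threshold chosen above. Unlike the `plot_confusion_matrix` call, this call does not pass `loss_type`; `plot_roc_pr` should score samples with the same loss (`LOSS_TYPE`) that was used to set the threshold.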
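`plot_roc_pr` is not shown, so we assume it uses each test image's reconstruction loss as the score, with label 1 meaning abnormal. The following pure-NumPy sketch computes the curve points and their areas; scikit-learn's `roc_curve`, `roc_auc_score`, and `average_precision_score` compute the same areas. `roc_pr_curves`, `roc_auc`, and `average_precision` are hypothetical helpers.

```python
import numpy as np


def roc_pr_curves(scores, labels):
    """ROC (fpr, tpr) and PR (recall, precision) points from sweeping a threshold over every distinct score.

    scores: anomaly scores (higher = more abnormal); labels: 1 abnormal, 0 normal. Both classes required.
    """
    scores = np.asarray(scores, dtype=float)
    labels = np.asarray(labels, dtype=int)
    order = np.argsort(-scores, kind="stable")  # most anomalous first
    s, y = scores[order], labels[order]
    # Last index of each run of equal scores, so tied samples share one threshold
    last = np.r_[np.flatnonzero(np.diff(s)), s.size - 1]
    tp = np.cumsum(y)[last]        # abnormal samples flagged at each threshold
    fp = (last + 1) - tp           # normal samples flagged at each threshold
    fpr = np.r_[0.0, fp / fp[-1]]  # prepend the "flag nothing" point (0, 0)
    tpr = np.r_[0.0, tp / tp[-1]]
    precision = tp / (tp + fp)
    recall = tp / tp[-1]
    return fpr, tpr, recall, precision


def roc_auc(fpr, tpr):
    """Area under the ROC curve by the trapezoidal rule."""
    return float(np.sum(np.diff(fpr) * (tpr[1:] + tpr[:-1]) / 2))


def average_precision(recall, precision):
    """AP = sum_n (R_n - R_{n-1}) * P_n, with R_0 = 0."""
    return float(np.sum(np.diff(np.r_[0.0, recall]) * precision))


# Toy example: abnormal samples have the two highest losses, so both areas equal 1
fpr, tpr, recall, precision = roc_pr_curves([0.1, 0.2, 0.8, 0.9], [0, 0, 1, 1])
print(roc_auc(fpr, tpr), average_precision(recall, precision))
```

In the notebook, `scores` would be the per-image test losses and `labels` the concatenated `y` values from `test_dl`. An AUC of 0.5 corresponds to a score that ranks samples no better than chance.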


In [None]:
plot_roc_pr(model, test_dl)


## Reconstruction Loss Distribution

Plot the per-image reconstruction loss distributions for the training set (normal samples only) and the test set (normal and abnormal samples), with the threshold overlaid.


In [None]:
sns.histplot(train_losses, label='Train (normal)', stat='density', kde=True)
sns.histplot(test_losses, label='Test', stat='density', kde=True)
plt.axvline(THRESHOLD, color='red', linestyle='--', label='Threshold')
plt.title("Reconstruction Loss Distribution")
plt.xlabel("Reconstruction loss per image")
plt.legend()
plt.grid(True)
plt.show()


## Conclusion

- The VAE is trained to reconstruct normal MNIST digits, and each image's reconstruction loss serves as its anomaly score.
- A threshold at the mean plus two standard deviations of the training losses turns this score into a normal/abnormal decision.
- The confusion matrix evaluates that single threshold, while the ROC and PR curves measure how well the loss ranks abnormal samples above normal ones across all thresholds.
- The reconstruction loss distributions, with the threshold overlaid, show how much of the training and test losses fall on each side of the threshold.

This pipeline can be extended to other datasets and architectures for anomaly detection tasks.
