# VT-AE Training with MDN for Anomaly Detection
Author: Pankaj Mishra

This notebook trains a Vision Transformer Autoencoder (VT-AE) with Mixture Density Network (MDN) for anomaly detection on the MVTech dataset.

## 1. Import Libraries

In [1]:
import torch
import mvtech
import torchvision.utils as utils
import matplotlib.pyplot as plt
from torch.optim import Adam
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
import os
import numpy as np
import pytorch_ssim
import mdn1
from VT_AE import VT_AE as ae
import argparse

  from .autonotebook import tqdm as notebook_tqdm


## 2. Configuration Parameters
Set your training parameters here instead of using command line arguments.

In [None]:
# Configuration parameters (modify these as needed)
config = {
    'product': 'hazelnut',  # product from the dataset MvTec or BTAD
    'epochs': 400,          # Number of epochs to train
    'learning_rate': 0.0001, # learning rate
    'patch_size': 16,       # Patch size of the images (match pretrained ViT)
    'batch_size': 8         # batch size
}

print("Configuration:")
for key, value in config.items():
    print(f"{key}: {value}")

print("\nNote: Using patch_size=32 to match pretrained ViT-Large model.")

Configuration:
product: hazelnut
epochs: 400
learning_rate: 0.0001
patch_size: 32
batch_size: 8

Note: Using patch_size=32 to match pretrained ViT-Large model.


## 3. Initialize Training Components

In [3]:
# Initialize tensorboard writer
writer = SummaryWriter()

# Training variables
prdt = config["product"]
epoch = config["epochs"]
minloss = 1e10
ep = 0

# SSIM Loss
ssim_loss = pytorch_ssim.SSIM()

print(f"Training setup complete for product: {prdt}")
print(f"Training for {epoch} epochs")

Training setup complete for product: hazelnut
Training for 400 epochs


## 4. Load Dataset

In [4]:
# Dataset
root = "/Users/dania/code/EmergingTechnologies/VT-ADL/datasets/mvtec"
data = mvtech.Mvtec(1, root, "bottle")


total train images of good bottle are: 209
total test images of broken_small bottle are: 22
total test images of broken_large bottle are: 20
the good images for test images of good bottle is not included in the test anomolous data
total test images of contamination bottle are: 21
total ground_truth images of broken_small bottle are: 22
total ground_truth images of broken_large bottle are: 20
total ground_truth images of contamination bottle are: 21
total test images of good bottle are: 20
 --Size of bottle train loader: torch.Size([209, 3, 512, 512])--
 --Size of bottle test anomaly loader: torch.Size([63, 3, 512, 512])--
 --Size of bottle test normal loader: torch.Size([20, 3, 512, 512])--
 --Total Image in bottle Validation loader: 20--


## 5. Initialize Models

In [None]:
# Model declaration
model = ae(patch_size=config["patch_size"], train=True, 
          pretrained_vit_name='microsoft/beit-large-patch16-512',
          load_pretrained=True)#.cuda()
G_estimate = mdn1.MDN()#.cuda()

# Put models in training mode
# (The two models are trained as a separate module so that it would be easy to use as an independent module in different scenarios)
model.train()
G_estimate.train()

print("Models initialized")
print(f"Patch size: {config['patch_size']}")

You are using a model of type beit to instantiate a model of type vit. This is not supported for all configurations of models and can yield errors.


## 6. Initialize Optimizer

In [6]:
# Optimizer Declaration
Optimiser = Adam(
    list(model.parameters()) + list(G_estimate.parameters()), 
    lr=config["learning_rate"], 
    weight_decay=0.0001
)

print(f"Optimizer initialized with learning rate: {config['learning_rate']}")

Optimizer initialized with learning rate: 0.0001


## 7. Training Loop

In [7]:
# Uncomment the line below if you want to track errors
# torch.autograd.set_detect_anomaly(True)

print('\nNetwork training started.....')

for i in range(epoch):
    t_loss = []
    
    for j, m in data.train_loader:
        # Handle grayscale images by converting to 3-channel
        if j.size(1) == 1:
            j = torch.stack([j, j, j]).squeeze(2).permute(1, 0, 2, 3)
        
        model.zero_grad()
        
        # Forward pass
        vector, reconstructions = model(j) #j.cuda()
        pi, mu, sigma = G_estimate(vector)
        
        # Loss calculations
        loss1 = F.mse_loss(reconstructions, j, reduction='mean')  # Reconstruction Loss #j.cuda()
        loss2 = -ssim_loss(j, reconstructions)  # SSIM loss for structural similarity
        loss3 = mdn1.mdn_loss_function(vector, mu, sigma, pi)  # MDN loss for gaussian approximation
        
        print(f'loss3: {loss3.item()}')
        loss = 5 * loss1 + 0.5 * loss2 + loss3  # Total loss
        
        t_loss.append(loss.item())  # storing all batch losses to calculate mean epoch loss
        
        # Tensorboard logging
        writer.add_scalar('recon-loss', loss1.item(), i)
        writer.add_scalar('ssim loss', loss2.item(), i)
        writer.add_scalar('Gaussian loss', loss3.item(), i)
        writer.add_histogram('Vectors', vector)
        
        ## Uncomment below to store the distributions of pi, var and mean ##        
        # writer.add_histogram('Pi', pi)
        # writer.add_histogram('Variance', sigma)
        # writer.add_histogram('Mean', mu)

        # Backward pass and optimization
        loss.backward()
        Optimiser.step()
    
    # Log epoch-level information
    writer.add_image('Reconstructed Image', utils.make_grid(reconstructions), i, dataformats='CHW')
    writer.add_scalar('Mean Epoch loss', np.mean(t_loss), i)
    
    print(f'Mean Epoch {i} loss: {np.mean(t_loss)}')
    print(f'Min loss epoch: {ep} with min loss: {minloss}')
    
    # Save the best model
    if np.mean(t_loss) <= minloss:
        minloss = np.mean(t_loss)
        ep = i
        os.makedirs('./saved_model', exist_ok=True)
        torch.save(model.state_dict(), f'./saved_model/VT_AE_Mvtech_{prdt}.pt')
        torch.save(G_estimate.state_dict(), f'./saved_model/G_estimate_Mvtech_{prdt}.pt')
        print(f"New best model saved at epoch {i} with loss {minloss}")

writer.close()
print("Training completed!")


Network training started.....
loss3: 87506.96875


  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


KeyboardInterrupt: 

## 8. Training Summary

In [None]:
print("\n=== Training Summary ===")
print(f"Product: {prdt}")
print(f"Total epochs: {epoch}")
print(f"Best model saved at epoch: {ep}")
print(f"Best loss achieved: {minloss}")
print(f"Models saved in: ./saved_model/")

## Notes

**Abbreviations used:**
- **GN** - Gaussian Noise
- **LD** - Linear Decoder  
- **DR** - Dynamic Routing
- **Gn** - No of gaussian for the estimation of density, with n as the number
- **Pn** - Patch with n is dim of patch
- **SS** - trained with ssim loss

**Loss Components:**
1. **Reconstruction Loss (MSE)**: Measures pixel-wise difference between input and reconstructed images
2. **SSIM Loss**: Structural similarity loss for perceptual quality
3. **MDN Loss**: Mixture Density Network loss for gaussian approximation of the latent space

**Total Loss**: `5 * reconstruction_loss + 0.5 * ssim_loss + mdn_loss`