# 🔍 MVTec Anomaly Detection - Production Training

**Multi-Category Autoencoder Training for Industrial Anomaly Detection**

This notebook implements a production-ready training pipeline for anomaly detection using a UNet-style autoencoder on the MVTec AD dataset. The model learns to reconstruct normal ("good") images from the three selected categories (screw, capsule, hazelnut) and flags anomalies by their reconstruction error.

## Key Features:
- ✅ **Multi-category training** on all "good" images
- ✅ **Production-ready architecture** with proper error handling  
- ✅ **Comprehensive evaluation** with metrics and visualizations
- ✅ **Optimized threshold selection** for anomaly detection
- ✅ **Model persistence** for deployment

---

## 1️⃣ Environment Setup and Dependencies

In [None]:
# Install required packages.
# `%pip` is an IPython/Jupyter magic, so this cell only runs inside a notebook
# kernel, not as a plain Python script.
# NOTE(review): package versions are unpinned; pin them for reproducible runs.
%pip install torch torchvision matplotlib seaborn scikit-learn pandas pillow numpy -q

# Printed unconditionally: this message is not gated on pip's result.
print("✅ Package installation completed!")

In [None]:
import sys
import os
from pathlib import Path
import glob
import json

# Add project src to Python path
sys.path.insert(0, os.path.abspath("../src"))
sys.path.insert(0, os.path.abspath("../"))

# Core libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import torchvision.transforms as transforms
from PIL import Image

# Data science and visualization
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Machine learning metrics
from sklearn.metrics import (
    confusion_matrix, accuracy_score, precision_score, 
    recall_score, f1_score, roc_auc_score, roc_curve
)

# Global plot styling used by every figure in the notebook
# (the 'seaborn-v0_8' sheet must exist in the installed matplotlib).
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

# Query CUDA once, pick the compute device, then report the environment.
cuda_ok = torch.cuda.is_available()
device = torch.device("cuda") if cuda_ok else torch.device("cpu")

print("🔧 Environment Information:")
print(f"   PyTorch version: {torch.__version__}")
print(f"   CUDA available: {cuda_ok}")
if cuda_ok:
    print(f"   CUDA device: {torch.cuda.get_device_name(0)}")
print(f"   Working directory: {os.getcwd()}")
python_version, *_ = sys.version.split()
print(f"   Python version: {python_version}")
print(f"   Using device: {device}")

print("\n✅ Environment setup completed!")

## 2️⃣ Dataset Loading and Preprocessing

We load every "good" image from the training split of each category in `CATEGORIES` for training, and prepare test sets for evaluation.

In [None]:
# ---- Dataset configuration -------------------------------------------------
DATA_ROOT = "../data/mvtec_ad"                  # contains one folder per category
CATEGORIES = ["screw", "capsule", "hazelnut"]   # categories whose "good" images are used
IMAGE_SIZE = 128                                # square side length fed to the model
BATCH_SIZE = 32

# ---- Training transform ----------------------------------------------------
# Resize -> tensor in [0, 1] -> per-channel (v - 0.5) / 0.5, i.e. values in
# [-1, 1], the same range as the autoencoder's Tanh output.
_CHANNEL_MEAN = [0.5] * 3
_CHANNEL_STD = [0.5] * 3
transform = transforms.Compose(
    [
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
    ]
)

class MVTecDataset(Dataset):
    """Map-style dataset over a list of MVTec AD image files.

    Each item is ``(image, image_path)``. The image is loaded with PIL,
    converted to 3-channel RGB, and passed through ``transform`` when one is
    given. The path is returned unchanged so results can be traced back to
    files.

    Args:
        image_paths: Sequence of image file paths (str or path-like).
        transform: Optional callable applied to the RGB ``PIL.Image``.
    """

    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        # The context manager closes the file handle deterministically.
        # ``convert`` loads the pixels and returns a new, independent image,
        # so the result stays valid after the file is closed.
        with Image.open(image_path) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image, image_path


print("✅ Dataset class and transforms defined!")

In [None]:
# Load training images (all "good" images from every configured category)
print("📂 Loading training images...")

train_images = []
category_counts = {}

for category in CATEGORIES:
    train_path = f"{DATA_ROOT}/{category}/train/good"
    if os.path.exists(train_path):
        # glob() returns files in filesystem order, which differs between
        # machines. Sorting makes the list, and every sample drawn from it
        # later, reproducible.
        category_images = sorted(glob.glob(f"{train_path}/*.png"))
        train_images.extend(category_images)
        category_counts[category] = len(category_images)
        print(f"   {category}: {len(category_images)} images")
    else:
        print(f"   ⚠️ Warning: {train_path} not found")

print(f"\n📊 Total training images: {len(train_images)}")
print(f"📊 Category distribution: {category_counts}")

# Fail fast with an actionable message. Otherwise a shuffled DataLoader over
# an empty dataset fails later with a much less obvious sampler error.
if not train_images:
    raise FileNotFoundError(
        f"No training images found under {DATA_ROOT!r} for categories "
        f"{CATEGORIES}; check DATA_ROOT and that the dataset is downloaded."
    )

# Create training dataset and dataloader
train_dataset = MVTecDataset(train_images, transform=transform)
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    # NOTE(review): on spawn-based platforms (Windows/macOS), worker processes
    # may fail to pickle a Dataset class defined inside a notebook; use
    # num_workers=0 there if the loader errors out.
    num_workers=2,
    # Page-locked host memory only helps when batches are copied to a GPU.
    pin_memory=torch.cuda.is_available(),
)

print(f"✅ Training DataLoader created with batch size {BATCH_SIZE}")
print(f"   Total batches per epoch: {len(train_loader)}")

### 📊 Exploratory Data Analysis (EDA)

Let's analyze the training images to check for data quality issues and ensure consistency across categories.

In [None]:
# 1. Analyze image dimensions and file sizes
print("🔍 Analyzing image properties...")

image_info = []
corrupted_files = []

# Bounded, category-balanced sample. train_images is grouped by category, so
# taking the first 50 paths would typically inspect only the first category.
MAX_EDA_SAMPLES = 50
per_category = max(1, MAX_EDA_SAMPLES // max(1, len(CATEGORIES)))
sample_images = []
for category in CATEGORIES:
    in_category = [p for p in train_images if Path(p).parts[-4] == category]
    sample_images.extend(in_category[:per_category])

for img_path in sample_images:
    try:
        with Image.open(img_path) as img:
            record = {
                'path': img_path,
                # The layout is <root>/<category>/train/good/<file>. pathlib
                # handles both '/' and the '\\' separators glob() yields on Windows.
                'category': Path(img_path).parts[-4],
                'width': img.width,
                'height': img.height,
                'mode': img.mode,
                'size_kb': os.path.getsize(img_path) / 1024,
            }
            # Image.open() only parses the header. verify() also checks the
            # rest of the file for damage (without decoding pixels), so
            # truncated files are actually reported as corrupted.
            img.verify()
        image_info.append(record)
    except Exception as e:  # PIL signals bad files with several exception types
        corrupted_files.append({'path': img_path, 'error': str(e)})

# Create DataFrame for analysis
df_images = pd.DataFrame(image_info)

print(f"📊 Sample Analysis Results ({len(sample_images)} images):")
print(f"   Corrupted files: {len(corrupted_files)}")

if len(corrupted_files) > 0:
    print("⚠️ Corrupted files found:")
    for cf in corrupted_files:
        print(f"   - {cf['path']}: {cf['error']}")

# Display basic statistics
if len(df_images) > 0:
    print(f"\n📏 Image Dimensions:")
    print(f"   Width range: {df_images['width'].min()}-{df_images['width'].max()}")
    print(f"   Height range: {df_images['height'].min()}-{df_images['height'].max()}")
    print(f"   Most common size: {df_images.groupby(['width', 'height']).size().idxmax()}")

    print(f"\n🎨 Image Modes:")
    print(df_images['mode'].value_counts())

    print(f"\n💾 File Sizes:")
    print(f"   Average: {df_images['size_kb'].mean():.1f} KB")
    print(f"   Range: {df_images['size_kb'].min():.1f}-{df_images['size_kb'].max():.1f} KB")

    print(f"\n📂 Category Distribution (sample):")
    print(df_images['category'].value_counts())

In [None]:
# 2. Visual inspection of sample images
print("\n🖼️ Visual Sample Inspection:")

# One row per configured category, with up to SAMPLES_PER_ROW random images each.
SAMPLES_PER_ROW = 2
n_rows = max(1, len(CATEGORIES))
# squeeze=False keeps `axes` 2-D, so axes[i, j] also works for a single row.
fig, axes = plt.subplots(
    n_rows, SAMPLES_PER_ROW,
    figsize=(5 * SAMPLES_PER_ROW, 4 * n_rows),
    squeeze=False,
)
fig.suptitle("Sample Training Images by Category", fontsize=16)

for i, category in enumerate(CATEGORIES):
    # The layout is <root>/<category>/train/good/<file>.
    category_images = [p for p in train_images if Path(p).parts[-4] == category]
    n_pick = min(SAMPLES_PER_ROW, len(category_images))
    samples = np.random.choice(category_images, n_pick, replace=False) if n_pick else []

    for j in range(SAMPLES_PER_ROW):
        ax = axes[i, j]
        ax.axis('off')
        if j >= n_pick:
            continue  # fewer images than slots: leave this panel blank
        img_path = samples[j]
        try:
            # Load fully and close the file. Converting to RGB (as the dataset
            # does) also stops matplotlib from false-colouring single-channel
            # images with its default colormap.
            with Image.open(img_path) as src:
                img = src.convert('RGB')
            ax.imshow(img)
            ax.set_title(f"{category} - {img.size}")
        except Exception as e:
            ax.text(0.5, 0.5, f"Error loading\n{e}",
                    ha='center', va='center', transform=ax.transAxes)
            ax.set_title(f"{category} - ERROR")

plt.tight_layout()
plt.show()

print("✅ Visual inspection completed!")

In [None]:
# 3. Pixel intensity analysis
print("\n📈 Pixel Intensity Analysis:")

# Randomly sample a few images (without replacement) for per-channel statistics.
sample_size = min(10, len(train_images))
pixel_stats = []

for img_path in np.random.choice(train_images, sample_size, replace=False):
    try:
        # Close the file as soon as the pixels are in memory. convert('RGB')
        # guarantees three 8-bit channels; grayscale images are expanded.
        with Image.open(img_path) as src:
            img_array = np.array(src.convert('RGB'))
        # The layout is <root>/<category>/train/good/<file>. pathlib handles
        # both '/' and the '\\' separators glob() yields on Windows.
        img_category = Path(img_path).parts[-4]

        # Calculate statistics per channel
        for channel, name in enumerate(['R', 'G', 'B']):
            channel_data = img_array[:, :, channel]
            pixel_stats.append({
                'category': img_category,
                'channel': name,
                'mean': channel_data.mean(),
                'std': channel_data.std(),
                'min': channel_data.min(),
                'max': channel_data.max()
            })
    except Exception as e:
        print(f"⚠️ Error processing {img_path}: {e}")

if pixel_stats:
    df_pixels = pd.DataFrame(pixel_stats)

    print("📊 Pixel Intensity Statistics (0-255 range):")
    print(f"   Mean values: {df_pixels['mean'].mean():.1f} ± {df_pixels['mean'].std():.1f}")
    print(f"   Min values: {df_pixels['min'].min():.0f}")
    print(f"   Max values: {df_pixels['max'].max():.0f}")

    # Defensive sanity checks. RGB-converted PIL images are 8-bit, so these
    # are not expected to fire.
    if df_pixels['min'].min() < 0:
        print("⚠️ Warning: Found negative pixel values!")
    if df_pixels['max'].max() > 255:
        print("⚠️ Warning: Found pixel values > 255!")

    print("\n📈 Statistics by Category:")
    category_stats = df_pixels.groupby('category')['mean'].agg(['mean', 'std']).round(1)
    print(category_stats)

print("\n✅ Pixel analysis completed!")

In [None]:
# 4. ETL Requirements Assessment
print("\n🔍 ETL ASSESSMENT:")
print("=" * 50)

# etl_needed     -> every data issue detected in the EDA sample.
# etl_unresolved -> the subset the current pipeline does NOT already handle;
#                   training readiness is decided from this list alone.
etl_needed = []
etl_unresolved = []
etl_recommendations = []

has_sample = len(df_images) > 0

# Check 1: Consistent image dimensions
if has_sample:
    unique_sizes = df_images.groupby(['width', 'height']).size()
    if len(unique_sizes) > 1:
        etl_needed.append("Image resizing")
        etl_recommendations.append(
            f"✅ HANDLED: Images resized to {IMAGE_SIZE}x{IMAGE_SIZE} in transforms"
        )
    else:
        etl_recommendations.append("✅ OK: All images have consistent dimensions")

# Check 2: Color mode consistency
if not has_sample:
    # Nothing readable was sampled, so neither this check nor Check 1 could run.
    etl_needed.append("Inspect dataset")
    etl_unresolved.append("Inspect dataset")
    etl_recommendations.append(
        "⚠️ ACTION NEEDED: No readable sample images - EDA checks could not run"
    )
elif df_images['mode'].nunique() > 1:
    etl_needed.append("Color mode standardization")
    etl_recommendations.append("✅ HANDLED: Images converted to RGB in dataset")
else:
    etl_recommendations.append("✅ OK: Consistent color modes")

# Check 3: Corrupted files (not handled by the pipeline, which would fail on them)
if corrupted_files:
    etl_needed.append("Remove corrupted files")
    etl_unresolved.append("Remove corrupted files")
    etl_recommendations.append(f"⚠️ ACTION NEEDED: Remove {len(corrupted_files)} corrupted files")
else:
    etl_recommendations.append("✅ OK: No corrupted files detected")

# Check 4: Pixel value normalization
etl_recommendations.append("✅ HANDLED: Pixel normalization to [-1,1] in transforms")

# Summary: ready iff nothing detected is left unhandled.
print("📋 ETL STATUS:")
if not etl_unresolved:
    print("🎉 READY FOR TRAINING!")
    print("   All necessary ETL operations are handled by the pipeline")
else:
    print("⚠️ ADDITIONAL ETL NEEDED:")
    for need in etl_unresolved:
        print(f"   - {need}")

print("\n📝 ETL PIPELINE SUMMARY:")
for rec in etl_recommendations:
    print(f"   {rec}")

print("\n🔧 CURRENT ETL OPERATIONS:")
print("   1. ✅ Image loading and RGB conversion")
print(f"   2. ✅ Resize to {IMAGE_SIZE}x{IMAGE_SIZE} pixels")
print("   3. ✅ Tensor conversion")
print("   4. ✅ Normalization to [-1, 1] range")
print("   5. ✅ Batch loading with DataLoader")

print("\n💡 CONCLUSION:")
if not etl_unresolved:
    print("   The current ETL pipeline is sufficient for training!")
    print("   No additional preprocessing steps required.")
    print("\n✅ EDA completed - Ready to proceed with training!")
else:
    print("   Resolve the items listed above before training.")
    print("\n⚠️ EDA completed - issues need attention before training.")

## 3️⃣ Model Architecture Definition

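We use a UNet-style autoencoder with skip connections for better feature preservation. Note that skip connections also let the decoder partly bypass the bottleneck, so the model may reconstruct some anomalous regions as well. Keep this in mind when interpreting reconstruction-error scores.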

In [None]:
class AutoencoderUNetLite(nn.Module):
    """
    UNet-style Autoencoder for anomaly detection.

    Architecture:
    - Encoder: 3 levels (two 3x3 conv + ReLU each), each followed by 2x2 max pooling
    - Bottleneck: two 3x3 conv + ReLU at 1/8 resolution (256 -> 512 channels)
    - Decoder: 3 levels, each a 2x2 stride-2 transposed convolution followed
      by two 3x3 conv + ReLU
    - Skip connections: encoder features are concatenated into the matching
      decoder level to preserve fine-grained detail
    - Output: 1x1 convolution + Tanh, so reconstructions lie in [-1, 1]

    Shapes: input (N, in_channels, H, W) -> output (N, out_channels, H, W).
    H and W do not have to be multiples of 8. The input is edge-padded
    (``replicate``) on the bottom/right up to the next multiple of 8, so the
    three pooling stages divide evenly and the skip connections line up. The
    output is then cropped back to (H, W). Inputs whose sides are already
    multiples of 8 are processed without any padding.

    NOTE(review): the full-resolution skip (enc1 -> dec1) passes uncompressed
    input features straight to the decoder, which can let anomalies be
    reconstructed too. Worth checking if anomaly sensitivity is low.

    Args:
        in_channels: channels of the input image (default 3).
        out_channels: channels of the reconstruction (default 3).
    """

    # Overall spatial down-sampling of the encoder: three 2x2 max-pools.
    _DOWNSAMPLE = 2 ** 3

    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()

        # Encoder path
        self.enc1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.enc2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.enc3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Bottleneck
        self.bottleneck = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

        # Decoder path
        self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec3 = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=3, padding=1),  # 512 due to skip connection
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = nn.Sequential(
            nn.Conv2d(256, 128, kernel_size=3, padding=1),  # 256 due to skip connection
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1),   # 128 due to skip connection
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

        # Final output layer
        self.final = nn.Conv2d(64, out_channels, kernel_size=1)
        self.tanh = nn.Tanh()  # Output in [-1, 1] range

    def forward(self, x):
        # x: (N, C, H, W). Pad the bottom/right so H and W become multiples of
        # 8; otherwise the upsampled maps would not match the skip tensors.
        height, width = x.shape[-2:]
        pad_h = -height % self._DOWNSAMPLE
        pad_w = -width % self._DOWNSAMPLE
        padded = bool(pad_h or pad_w)
        if padded:
            # The pad tuple is (left, right, top, bottom) for the last two dims.
            x = nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='replicate')

        # Encoder path with skip connection storage
        # (shape comments are C x H x W for a 128x128 input)
        enc1 = self.enc1(x)          # 64 x 128 x 128
        pool1 = self.pool1(enc1)     # 64 x 64 x 64

        enc2 = self.enc2(pool1)      # 128 x 64 x 64
        pool2 = self.pool2(enc2)     # 128 x 32 x 32

        enc3 = self.enc3(pool2)      # 256 x 32 x 32
        pool3 = self.pool3(enc3)     # 256 x 16 x 16

        # Bottleneck
        bottleneck = self.bottleneck(pool3)  # 512 x 16 x 16

        # Decoder path with skip connections
        up3 = self.upconv3(bottleneck)          # 256 x 32 x 32
        merge3 = torch.cat([up3, enc3], dim=1)  # 512 x 32 x 32
        dec3 = self.dec3(merge3)                # 256 x 32 x 32

        up2 = self.upconv2(dec3)                # 128 x 64 x 64
        merge2 = torch.cat([up2, enc2], dim=1)  # 256 x 64 x 64
        dec2 = self.dec2(merge2)                # 128 x 64 x 64

        up1 = self.upconv1(dec2)                # 64 x 128 x 128
        merge1 = torch.cat([up1, enc1], dim=1)  # 128 x 128 x 128
        dec1 = self.dec1(merge1)                # 64 x 128 x 128

        # Final output: out_channels x 128 x 128, squashed to [-1, 1]
        output = self.tanh(self.final(dec1))
        if padded:
            # Drop the padded border so the reconstruction matches the input.
            output = output[..., :height, :width]
        return output

# Instantiate the autoencoder on the selected device and summarise its size.
model = AutoencoderUNetLite().to(device)

param_counts = [(p.numel(), p.requires_grad) for p in model.parameters()]
total_params = sum(count for count, _ in param_counts)
trainable_params = sum(count for count, needs_grad in param_counts if needs_grad)
# Size estimate assumes 4 bytes per parameter (float32 weights).
model_size_mb = total_params * 4 / 1024 / 1024

summary_lines = [
    "🏗️ Model Architecture:",
    "   Model: AutoencoderUNetLite",
    f"   Input shape: {IMAGE_SIZE}x{IMAGE_SIZE}x3",
    f"   Output shape: {IMAGE_SIZE}x{IMAGE_SIZE}x3",
    f"   Total parameters: {total_params:,}",
    f"   Trainable parameters: {trainable_params:,}",
    f"   Model size: ~{model_size_mb:.1f} MB",
]
print("\n".join(summary_lines))

print("\n✅ Model architecture defined and initialized!")