# In [None]:
# ============================================================
# Exercise 10 - Part 1: Autoencoder for Defect Detection
# ============================================================

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import os
import glob

# ============================================================
# 1. Load and Prepare Image Data
# ============================================================

# Section banner: the title framed by two 60-character rules, one per line.
print("=" * 60, "Loading Defect Detection Dataset...", "=" * 60, sep="\n")

def load_images_from_folder(folder_path, img_size=(64, 64)):
    """Load every PNG in a folder as a flattened, normalized grayscale vector.

    Args:
        folder_path: Directory searched (non-recursively) for ``*.png`` files.
        img_size: Target ``(width, height)`` handed to ``Image.resize``.

    Returns:
        A 2-D float array of shape ``(n_images, width * height)`` with pixel
        values scaled to ``[0, 1]``. Rows follow the sorted file paths. When
        no PNG is found the result has shape ``(0, width * height)`` so it can
        still be stacked with the arrays loaded from other folders.
    """
    # glob does not guarantee any ordering; sort so that sample indices are
    # reproducible across runs and file systems.
    image_paths = sorted(glob.glob(os.path.join(folder_path, '*.png')))

    images = []
    for img_path in image_paths:
        # The context manager releases the file handle promptly; convert()
        # returns an independent, fully loaded image that outlives it.
        with Image.open(img_path) as raw:
            img = raw.convert('L')  # Grayscale
        img = img.resize(img_size)
        img_array = np.array(img) / 255.0  # Normalize to [0, 1]
        images.append(img_array.flatten())

    if not images:
        # np.array([]) would be 1-D with shape (0,), which cannot be
        # vstack-ed with (n, width * height) arrays.
        return np.empty((0, img_size[0] * img_size[1]), dtype=np.float64)
    return np.array(images)

# Load training data
train_good = load_images_from_folder('defect/train/good')
train_defective = load_images_from_folder('defect/train/defective')

# The model is later fitted on the good images only, so an empty "good"
# folder (wrong working directory, missing dataset) is reported right away
# instead of surfacing as an obscure shape or indexing error further down.
if len(train_good) == 0:
    raise FileNotFoundError("No PNG images found in 'defect/train/good'")

train_data = np.vstack([train_good, train_defective])

print(f"Training - Good images: {len(train_good)}")
print(f"Training - Defective images: {len(train_defective)}")
print(f"Total training images: {len(train_data)}")
print(f"Image shape (flattened): {train_data.shape}")

# Load test data
test_good = load_images_from_folder('defect/test/good')
test_defective = load_images_from_folder('defect/test/defective')
test_data = np.vstack([test_good, test_defective])

# Create labels for test data (0 = good, 1 = defective); the order matches
# the stacking order of test_data above.
test_labels = np.array([0]*len(test_good) + [1]*len(test_defective))

print(f"\nTest - Good images: {len(test_good)}")
print(f"Test - Defective images: {len(test_defective)}")
print(f"Total test images: {len(test_data)}")

# Visualize some samples
print("\nVisualizing sample images...")
fig, axes = plt.subplots(2, 5, figsize=(12, 5))
# Images are square, so the side length is the root of the flattened size.
img_size = int(np.sqrt(train_data.shape[1]))

# Hide every axis up front so columns without an image stay blank.
for ax in axes.ravel():
    ax.axis('off')

# Row 0 shows good samples, row 1 defective ones. Either class may hold
# fewer images than there are columns, so cap the count per row.
sample_rows = ((train_good, 'Good'), (train_defective, 'Defective'))
for row, (samples, title) in enumerate(sample_rows):
    for col in range(min(axes.shape[1], len(samples))):
        axes[row, col].imshow(samples[col].reshape(img_size, img_size), cmap='gray')
        axes[row, col].set_title(title)

plt.suptitle('Sample Training Images', fontsize=14)
plt.tight_layout()
plt.show()

# ============================================================
# 2. Define AutoEncoder
# ============================================================

# Section banner; the leading newline separates it from earlier output.
print("\n" + "=" * 60, "Building AutoEncoder Model...", "=" * 60, sep="\n")

# One input unit per pixel of a flattened image.
input_dim = train_data.shape[1]

class AutoEncoder(nn.Module):
    """Fully connected autoencoder over flattened inputs.

    The encoder narrows ``input_dim -> 128 -> 64 -> encoding_dim`` with a ReLU
    after every linear layer (including the bottleneck). The decoder mirrors
    that path back to ``input_dim`` and ends in a sigmoid, which keeps the
    reconstruction inside the unit interval.
    """

    def __init__(self, input_dim, encoding_dim=32):
        """Build the encoder and decoder stacks.

        Args:
            input_dim: Length of each flattened input vector.
            encoding_dim: Width of the bottleneck representation.
        """
        super().__init__()

        # Compress: input_dim -> 128 -> 64 -> encoding_dim.
        encoder_layers = [
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, encoding_dim),
            nn.ReLU(),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        # Expand: encoding_dim -> 64 -> 128 -> input_dim.
        decoder_layers = [
            nn.Linear(encoding_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, input_dim),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Return the reconstruction of ``x`` after encoding and decoding."""
        return self.decoder(self.encoder(x))

# Initialize model
# Bottleneck width, defined once so the model and the log line below cannot
# drift apart.
encoding_dim = 32
model = AutoEncoder(input_dim=input_dim, encoding_dim=encoding_dim)
criterion = nn.MSELoss()  # reconstruction loss: mean squared error
optimizer = optim.Adam(model.parameters(), lr=0.001)

print(f"Input dimension: {input_dim}")
print(f"Encoding dimension: {encoding_dim}")
print("\nModel Architecture:")
print(model)

# ============================================================
# 3. Train AutoEncoder (on GOOD images only)
# ============================================================

print("\n" + "=" * 60, "Training AutoEncoder on GOOD images only...", "=" * 60, sep="\n")

# Only the good images are used for fitting; the defective training images
# loaded earlier are not part of this tensor.
X_train_good = torch.FloatTensor(train_good)

num_epochs = 100
train_losses = []

# Full-batch training: each epoch is a single optimizer step in which the
# model reconstructs all good images at once.
for epoch in range(num_epochs):
    optimizer.zero_grad()
    outputs = model(X_train_good)
    loss = criterion(outputs, X_train_good)
    loss.backward()
    optimizer.step()

    epoch_loss = loss.item()
    train_losses.append(epoch_loss)

    # Report progress on every tenth epoch.
    if not (epoch + 1) % 10:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.6f}')

# Loss curve over the epochs.
loss_fig, loss_ax = plt.subplots(figsize=(10, 5))
loss_ax.plot(train_losses, label='Training Loss', color='blue')
loss_ax.set_xlabel('Epoch', fontsize=12)
loss_ax.set_ylabel('MSE Loss', fontsize=12)
loss_ax.set_title('AutoEncoder Training Loss', fontsize=14, fontweight='bold')
loss_ax.legend()
loss_ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

# ============================================================
# 4. Detect Anomalies in Test Data
# ============================================================

print("\n" + "="*60)
print("Detecting Anomalies in Test Data...")
print("="*60)

# Convert test data to tensor
X_test = torch.FloatTensor(test_data)

# Per-sample reconstruction error: mean squared error over all pixels.
model.eval()
with torch.no_grad():
    reconstructions = model(X_test)
    mse = torch.mean((X_test - reconstructions) ** 2, dim=1)

    # Errors on the good training images serve as the reference for what a
    # "normal" reconstruction error looks like.
    train_good_tensor = torch.FloatTensor(train_good)
    train_mse = torch.mean(
        (train_good_tensor - model(train_good_tensor)) ** 2, dim=1
    )

reconstruction_errors = mse.numpy()
train_errors = train_mse.numpy()

# Set threshold at the 95th percentile of the errors on good TRAINING images.
# Deriving it from the test errors themselves would always flag about 5% of
# the test samples, regardless of how many are actually defective, and would
# let the evaluation data choose its own decision boundary.
threshold = np.percentile(train_errors, 95)
predicted_anomalies = reconstruction_errors > threshold

print(f"\nReconstruction Error Statistics:")
print(f"  Mean: {reconstruction_errors.mean():.6f}")
print(f"  Std: {reconstruction_errors.std():.6f}")
print(f"  Min: {reconstruction_errors.min():.6f}")
print(f"  Max: {reconstruction_errors.max():.6f}")
print(f"  Threshold (95th percentile): {threshold:.6f}")

print(f"\nResults:")
print(f"  Total test samples: {len(test_data)}")
print(f"  Predicted anomalies: {np.sum(predicted_anomalies)}")
print(f"  Actual defective: {np.sum(test_labels == 1)}")

# ============================================================
# 5. Visualize Results
# ============================================================

print("\n" + "=" * 60, "Creating Visualizations...", "=" * 60, sep="\n")

# Reconstruction error per test sample, with the decision threshold drawn as
# a dashed horizontal line.
err_fig, err_ax = plt.subplots(figsize=(12, 5))
err_ax.plot(reconstruction_errors, 'o-', markersize=4, alpha=0.6, label='Reconstruction Error')
threshold_label = f'Threshold (95th percentile = {threshold:.6f})'
err_ax.axhline(threshold, color='r', linestyle='--', linewidth=2, label=threshold_label)
err_ax.set_xlabel('Test Sample Index', fontsize=12)
err_ax.set_ylabel('MSE (Reconstruction Error)', fontsize=12)
err_ax.set_title('AutoEncoder: Anomaly Detection Results', fontsize=14, fontweight='bold')
err_ax.legend(fontsize=10)
err_ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

# Visualize original vs reconstructed images
print("\nVisualizing Original vs Reconstructed Images...")


def _show_reconstruction_pairs(row_axes, indices, label):
    """Fill one figure row with (original, reconstruction) image pairs.

    Args:
        row_axes: 1-D sequence of axes; pair ``k`` is drawn at positions
            ``2k`` (original) and ``2k + 1`` (reconstruction).
        indices: Test-sample indices to display, one pair per index.
        label: First line of the title above each original image.
    """
    for pair, idx in enumerate(indices):
        original_ax = row_axes[2 * pair]
        recon_ax = row_axes[2 * pair + 1]
        original_ax.imshow(test_data[idx].reshape(img_size, img_size), cmap='gray')
        original_ax.set_title(f'{label}\nError: {reconstruction_errors[idx]:.4f}', fontsize=9)
        recon_ax.imshow(reconstructions[idx].numpy().reshape(img_size, img_size), cmap='gray')
        recon_ax.set_title('Reconstructed', fontsize=9)


fig, axes = plt.subplots(4, 6, figsize=(15, 10))

# Hide every axis first, so a row with fewer than three samples leaves blank
# cells instead of empty framed plots.
for ax in axes.ravel():
    ax.axis('off')

# Sample indices sorted by ascending reconstruction error (computed once).
error_order = np.argsort(reconstruction_errors)

# Row 0: first three good test images (low error expected).
good_indices = np.where(test_labels == 0)[0][:3]
_show_reconstruction_pairs(axes[0], good_indices, 'Good Original')

# Row 1: first three defective test images (should have high error).
defect_indices = np.where(test_labels == 1)[0][:3]
_show_reconstruction_pairs(axes[1], defect_indices, 'Defective Original')

# Row 2: the three highest-error samples, largest first.
highest_error_idx = error_order[-3:][::-1]
_show_reconstruction_pairs(axes[2], highest_error_idx, 'Highest Error')

# Row 3: the three lowest-error samples, smallest first.
lowest_error_idx = error_order[:3]
_show_reconstruction_pairs(axes[3], lowest_error_idx, 'Lowest Error')

plt.suptitle('AutoEncoder: Original vs Reconstructed Images', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

# Closing banner, preceded by a blank line.
print("\n" + "=" * 60, "AutoEncoder Defect Detection Complete!", "=" * 60, sep="\n")