#### Exercise 10 solution Autoencoder

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
import glob
import numpy as np

In [2]:
# Setup
IMAGE_SIZE = 128
BATCH_SIZE = 32
LEARNING_RATE = 1e-3
NUM_EPOCHS = 20
TRAIN_GOOD_DIR = './train/good' # Folder containing only normal/good images

# Check for MPS first, then CUDA, then fall back to CPU
if torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
elif torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

print(f"Using device: {DEVICE}")

Using device: mps


In [3]:
# Create Dataset Class and dataloaders

class DefectDataset(Dataset):
    """Custom Dataset for loading images from a directory."""
    def __init__(self, root_dir, transform=None):
        self.image_paths = sorted(glob.glob(os.path.join(root_dir, '*.png')) +
                                  glob.glob(os.path.join(root_dir, '*.jpg')))
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')
        
        if self.transform:
            image = self.transform(image)
        
        # The autoencoder is trained unsupervised, so the target is the input itself
        return image, image 

# Image transformations
data_transforms = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
])

# Load the training data (ONLY normal/good images)
train_dataset = DefectDataset(root_dir=TRAIN_GOOD_DIR, transform=data_transforms)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

print(f"Loaded {len(train_dataset)} normal images for training.")

Loaded 2463 normal images for training.


In [4]:
# AUTOENCODER MODEL

class ConvolutionalAutoencoder(nn.Module):
    def __init__(self):
        super(ConvolutionalAutoencoder, self).__init__()
        
        # Encoder (Downsampling to Latent Space)
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1), # 128x128 -> 128x128
            nn.ReLU(True),
            nn.MaxPool2d(2, 2),                      # 128x128 -> 64x64
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(True),
            nn.MaxPool2d(2, 2),                      # 64x64 -> 32x32
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(True),
            nn.MaxPool2d(2, 2)                       # 32x32 -> 16x16 (Latent Space)
        )
        
        # Decoder (Upsampling from Latent Space)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2), # 16x16 -> 32x32
            nn.ReLU(True),
            nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2), # 32x32 -> 64x64
            nn.ReLU(True),
            nn.ConvTranspose2d(16, 3, kernel_size=2, stride=2),  # 64x64 -> 128x128
            nn.Sigmoid() # Use Sigmoid to ensure output is in [0, 1] range for image pixels
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded

model = ConvolutionalAutoencoder().to(DEVICE)
criterion = nn.MSELoss() # Mean Squared Error is the standard reconstruction loss
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)



In [5]:
# TRAINING LOOP

print("\nStarting Training...")
model.train()
for epoch in range(NUM_EPOCHS):
    running_loss = 0.0
    for data in train_loader:
        inputs, targets = data # targets are the same as inputs (unsupervised)
        inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item() * inputs.size(0)

    epoch_loss = running_loss / len(train_dataset)
    print(f'Epoch [{epoch+1}/{NUM_EPOCHS}], Loss: {epoch_loss:.6f}')

print("Training finished.")




Starting Training...
Epoch [1/20], Loss: 0.154788
Epoch [2/20], Loss: 0.000327
Epoch [3/20], Loss: 0.000038
Epoch [4/20], Loss: 0.000021
Epoch [5/20], Loss: 0.000014
Epoch [6/20], Loss: 0.000010
Epoch [7/20], Loss: 0.000007
Epoch [8/20], Loss: 0.000005
Epoch [9/20], Loss: 0.000004
Epoch [10/20], Loss: 0.000003
Epoch [11/20], Loss: 0.000003
Epoch [12/20], Loss: 0.000002
Epoch [13/20], Loss: 0.000002
Epoch [14/20], Loss: 0.000002
Epoch [15/20], Loss: 0.000002
Epoch [16/20], Loss: 0.000001
Epoch [17/20], Loss: 0.000001
Epoch [18/20], Loss: 0.000001
Epoch [19/20], Loss: 0.000001
Epoch [20/20], Loss: 0.000001
Training finished.


In [6]:
# EVALUATION/DEFECT IDENTIFICATION ON TEST DATA

def identify_defect_score(model, image_path, threshold=0.000001):
    """
    Calculates the reconstruction error for an image and classifies it.
    The threshold is typically determined experimentally on a validation set.
    """
    model.eval()
    
    # Load and transform image
    image = Image.open(image_path).convert('RGB')
    input_tensor = data_transforms(image).unsqueeze(0).to(DEVICE)
    
    # Reconstruct
    with torch.no_grad():
        reconstructed_tensor = model(input_tensor)
    
    # Calculate Reconstruction Error (MSE)
    # The loss function is what we use to measure error
    reconstruction_error = criterion(reconstructed_tensor, input_tensor).item()
    
    # Classification
    is_defective = reconstruction_error > threshold
    
    print(f"\nImage: {os.path.basename(image_path)}")
    print(f"Reconstruction Error (MSE): {reconstruction_error:.6f}")
    print(f"Classification: {'DEFECTIVE (Anomaly)' if is_defective else 'NORMAL'}")
    
    return reconstruction_error, is_defective, reconstructed_tensor.squeeze(0)


print("\nModel saved to 'autoencoder_defect_model.pth'")
torch.save(model.state_dict(), 'autoencoder_defect_model.pth')


Model saved to 'autoencoder_defect_model.pth'


In [7]:
# Predict
list_def_images = os.listdir('./test/defective')

for i in list_def_images:
    image = './test/defective/'+i
    _,isdefect,_ = identify_defect_score(model,image)
    if isdefect:
        print(image)


Image: 4983.png
Reconstruction Error (MSE): 0.011640
Classification: DEFECTIVE (Anomaly)
./test/defective/4983.png

Image: 4996.png
Reconstruction Error (MSE): 0.000001
Classification: NORMAL

Image: 4957.png
Reconstruction Error (MSE): 0.006133
Classification: DEFECTIVE (Anomaly)
./test/defective/4957.png

Image: 4943.png
Reconstruction Error (MSE): 0.001521
Classification: DEFECTIVE (Anomaly)
./test/defective/4943.png

Image: 4994.png
Reconstruction Error (MSE): 0.000001
Classification: NORMAL

Image: 4980.png
Reconstruction Error (MSE): 0.007685
Classification: DEFECTIVE (Anomaly)
./test/defective/4980.png

Image: 4995.png
Reconstruction Error (MSE): 0.000001
Classification: NORMAL

Image: 4946.png
Reconstruction Error (MSE): 0.000001
Classification: NORMAL

Image: 4990.png
Reconstruction Error (MSE): 0.002462
Classification: DEFECTIVE (Anomaly)
./test/defective/4990.png

Image: 4947.png
Reconstruction Error (MSE): 0.005178
Classification: DEFECTIVE (Anomaly)
./test/defective/4947.