# MNIST Denoising Autoencoder

This notebook implements a Denoising Autoencoder using PyTorch to remove noise from MNIST digits.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# Select the compute device: CUDA when a GPU is usable, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

## Data Preparation
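We download the MNIST dataset and convert the images to tensors. Gaussian noise is not part of the dataset transform; it is added to each batch on the fly by the `add_noise` helper defined below.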

In [None]:
# Hyperparameters
BATCH_SIZE = 64
NOISE_FACTOR = 0.5   # multiplier applied to the Gaussian noise in add_noise
LEARNING_RATE = 1e-3
EPOCHS = 10
SEED = 42

# Seed PyTorch's global RNG before anything random happens, so that loader
# shuffling, weight initialisation and the injected noise are repeatable on
# Restart & Run All (GPU kernels may still introduce small nondeterminism).
torch.manual_seed(SEED)

# Transforms: ToTensor is the only preprocessing step; noise is NOT added
# here but per batch by add_noise.
transform = transforms.Compose([
    transforms.ToTensor(),
])

# Download Datasets (cached under ./data after the first run)
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Only the training split is shuffled; the test loader keeps a fixed order so
# the preview and results cells show the same digits on every run.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

print("Data loaded successfully.")

## Visualization Helper
Let's visualize what the noisy images look like.

In [None]:
def add_noise(img, noise_factor=0.5):
    """Return ``img`` corrupted with Gaussian noise and clipped to [0, 1].

    Args:
        img: Tensor to corrupt. It is not modified; a new tensor is returned.
        noise_factor: Multiplier applied to a standard-normal sample shaped
            like ``img`` before it is added.

    Returns:
        A tensor with the same shape as ``img`` whose values lie in [0, 1].
    """
    gaussian = torch.randn_like(img).mul(noise_factor)
    return (img + gaussian).clamp(min=0., max=1.)

def show_images(original, noisy, denoised=None, num=10, image_shape=(28, 28)):
    """Plot clean, noisy and (optionally) denoised images in aligned rows.

    Row 1 shows ``original``, row 2 shows ``noisy`` and, when ``denoised`` is
    given, row 3 shows ``denoised``. Column ``i`` of every row shows item
    ``i`` of the corresponding batch.

    Args:
        original: Indexable batch of clean images.
        noisy: Indexable batch of noisy images.
        denoised: Optional indexable batch of model outputs; when ``None``
            only two rows are drawn.
        num: Maximum number of columns. It is clamped to the length of the
            shortest batch so a short batch no longer raises IndexError.
        image_shape: 2-D shape each item is reshaped to before display.
            Defaults to (28, 28).

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    rows = [("Original", original), ("Noisy", noisy)]
    if denoised is not None:
        rows.append(("Denoised", denoised))

    # Never index past the end of the shortest batch (and never go negative).
    count = max(0, min(num, *(len(batch) for _, batch in rows)))

    fig = plt.figure(figsize=(20, 4))
    for row_idx, (title, batch) in enumerate(rows):
        for col_idx in range(count):
            # Subplot indices are 1-based and run row by row.
            ax = fig.add_subplot(len(rows), count, row_idx * count + col_idx + 1)
            ax.imshow(batch[col_idx].reshape(image_shape), cmap='gray')
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
            # Label each row once, above its first image.
            if col_idx == 0:
                ax.set_title(title)
    plt.show()

# Preview: take the first test batch (labels are discarded), corrupt it, and
# plot the clean and noisy versions in two rows.
dataiter = iter(test_loader)
images, _ = next(dataiter)
noisy_images = add_noise(images, noise_factor=NOISE_FACTOR)
show_images(original=images, noisy=noisy_images, num=10)

## Model Architecture
We'll use a convolutional autoencoder: the encoder downsamples each 1×28×28 image to a 32×7×7 feature map, and the decoder upsamples it back to 1×28×28.

In [None]:
class Denoiser(nn.Module):
    """Convolutional autoencoder that maps a noisy image to a cleaned one.

    The encoder is two stride-2 convolutions and the decoder is two stride-2
    transposed convolutions. The shape comments below assume 1x28x28 inputs
    (MNIST digits).
    """

    def __init__(self):
        super().__init__()

        # Encoder: each stride-2 convolution halves the spatial size.
        downsample = [
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=2, padding=1),   # -> 16x14x14
            nn.ReLU(),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=1),  # -> 32x7x7
            nn.ReLU(),
        ]
        self.encoder = nn.Sequential(*downsample)

        # Decoder: mirror of the encoder; output_padding=1 makes each
        # transposed convolution exactly double the spatial size.
        upsample = [
            nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=3,
                               stride=2, padding=1, output_padding=1),  # -> 16x14x14
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=16, out_channels=1, kernel_size=3,
                               stride=2, padding=1, output_padding=1),  # -> 1x28x28
            nn.Sigmoid(),  # squash output pixels into [0, 1]
        ]
        self.decoder = nn.Sequential(*upsample)

    def forward(self, x):
        """Encode ``x`` and decode the result back to image space."""
        return self.decoder(self.encoder(x))

# Build the network, move its parameters to the selected device, and print
# the layer summary.
model = Denoiser()
model = model.to(device)
print(model)

## Training Loop

In [None]:
# Reconstruction objective: the model sees the noisy image but is scored
# against the clean one.
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

print("Starting training...")
for epoch in range(EPOCHS):
    # Set training mode explicitly every epoch so this cell stays correct
    # when it is re-run after a cell that switched the model to eval mode.
    model.train()

    loss_sum = 0.0     # sum of per-batch mean losses, weighted by batch size
    num_samples = 0    # number of images seen this epoch
    for img, _ in train_loader:  # labels are not needed for reconstruction
        img = img.to(device)

        # Add noise on the fly so every epoch sees a fresh corruption.
        noisy_img = add_noise(img, NOISE_FACTOR)

        optimizer.zero_grad()
        outputs = model(noisy_img)
        loss = criterion(outputs, img)
        loss.backward()
        optimizer.step()

        # Weight each batch by its size: the final batch may be smaller, and
        # a plain average over batches would over-weight it.
        batch_size = img.size(0)
        loss_sum += loss.item() * batch_size
        num_samples += batch_size

    # max(..., 1) avoids a ZeroDivisionError when the loader is empty.
    train_loss = loss_sum / max(num_samples, 1)
    print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {train_loss:.4f}")

print("Training finished.")

## Results Visualization

In [None]:
# Get a batch of test images (labels are discarded)
dataiter = iter(test_loader)
images, _ = next(dataiter)
images = images.to(device)

# Corrupt the batch with the same noise level used during training.
noisy_images = add_noise(images, NOISE_FACTOR)

# Run inference in eval mode with autograd disabled. The previous mode is
# restored afterwards so that re-running the training cell is unaffected by
# this cell having been executed.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        outputs = model(noisy_images)
finally:
    model.train(was_training)

# Move back to CPU for plotting
images = images.cpu()
noisy_images = noisy_images.cpu()
outputs = outputs.cpu()

show_images(images, noisy_images, outputs)