
# Unsupervised Learning: Autoencoder on MNIST

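This demo notebook trains a simple fully connected autoencoder on MNIST:
- Input: **28×28** grayscale images, flattened to 784 values
- Encoder: 784 → 256 → 9 (ReLU hidden layer, linear 9-dimensional latent code)
- Decoder: 9 → 256 → 784 (ReLU hidden layer, sigmoid output in [0, 1])
- Loss: binary cross-entropy (BCE) between the reconstruction and the input pixels, which are scaled to [0, 1]
- Optimizer: Adam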


In [None]:

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
from torchinfo import summary

# Seed torch's global RNG so torch-based randomness in later cells is repeatable.
torch.manual_seed(42)

# Device preference: CUDA first, then Apple MPS (only if this torch build exposes it), else CPU.
if torch.cuda.is_available():
    _device_name = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    _device_name = "mps"
else:
    _device_name = "cpu"
device = torch.device(_device_name)

# Pinned host memory and non-blocking host->device copies are only enabled for CUDA.
use_cuda = device.type == "cuda"
pin_memory = bool(use_cuda)
non_blocking = bool(use_cuda)
device



## Data
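We load the MNIST training set. `ToTensor` scales pixels to [0, 1]. We hold out 5,000 images for validation using a fixed random split, then create the train and validation data loaders.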


In [None]:

# ToTensor converts the uint8 images to float tensors in [0, 1]. No mean/std
# standardization is applied because the BCE reconstruction target must stay in [0, 1].
transform = transforms.Compose([
    transforms.ToTensor(),  # [0,1]
])

data_root = "./data"
train_dataset_full = datasets.MNIST(root=data_root, train=True, download=True, transform=transform)

# Hold out a validation subset. A dedicated, seeded generator makes the split identical
# on every run of this cell, regardless of how much of the global torch RNG other
# cells (model init, earlier runs) have already consumed.
val_size = 5000
train_size = len(train_dataset_full) - val_size
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(
    train_dataset_full, [train_size, val_size], generator=split_generator
)

batch_size = 256
# pin_memory is True only when training on CUDA.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=pin_memory)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=pin_memory)

print("Number of samples in train dataset:", len(train_dataset))
print("Number of samples in validation dataset:", len(val_dataset))



## Model
A symmetric, fully connected autoencoder with a **9-dimensional latent code**. For visualization, the code is reshaped into a 3×3 grid.


In [None]:

class AE(nn.Module):
    """Fully connected autoencoder: input -> hidden -> latent -> hidden -> input.

    Args:
        latent_dim: Size of the bottleneck code ``z`` (default 9).
        hidden_dim: Width of the single hidden layer on each side (default 256).
        input_shape: Per-sample input shape, default ``(1, 28, 28)`` (MNIST).
            Inputs are flattened to ``prod(input_shape)`` features.

    ``forward(x)`` returns ``(xhat, z)``:
        * ``xhat`` has shape ``(B, *input_shape)`` with values in [0, 1] (sigmoid output).
        * ``z`` has shape ``(B, latent_dim)``.
    """

    def __init__(self, latent_dim=9, hidden_dim=256, input_shape=(1, 28, 28)):
        super().__init__()
        self.latent_dim = latent_dim
        self.input_shape = tuple(input_shape)
        in_features = torch.Size(self.input_shape).numel()
        # Encoder: flattened pixels -> ReLU hidden layer -> unconstrained linear latent code.
        self.encoder = nn.Sequential(
            nn.Linear(in_features, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, latent_dim),
        )
        # Decoder mirrors the encoder. The final sigmoid keeps outputs in [0, 1],
        # which nn.BCELoss requires of its input.
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, in_features),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Encode and reconstruct a batch.

        Args:
            x: Tensor whose first dim is the batch and whose remaining dims hold
               ``prod(input_shape)`` values, e.g. ``(B, 1, 28, 28)`` or ``(B, 784)``.

        Returns:
            (xhat, z): the reconstruction, shaped ``(B, *input_shape)``, and the
            latent code, shaped ``(B, latent_dim)``.
        """
        b = x.size(0)
        # torch.flatten (unlike .view) also accepts non-contiguous inputs
        # (e.g. transposed or sliced tensors) and empty batches.
        z = self.encoder(torch.flatten(x, start_dim=1))
        xhat = self.decoder(z)
        return xhat.view(b, *self.input_shape), z

# Build the autoencoder with a 9-dimensional latent code and place it on `device`.
model = AE(latent_dim=9)
model.to(device)  # nn.Module.to moves parameters in place and returns the same module
model



## Architecture Diagram (schematic)


In [None]:

# Hand-written layer labels (kept in sync with the AE definition by hand).
layers = ["784 (28x28)", "256", "9", "256", "784 (28x28)"]

# Box geometry in data coordinates: boxes are centred at x = 1..len(layers).
BOX_W = 0.8       # box width -> box spans [cx - 0.4, cx + 0.4]
BOX_Y = 0.2       # bottom edge of every box
BOX_H = 0.6       # box height, so boxes are vertically centred on y = 0.5
ARROW_DX = 0.2    # gap between one box's right edge and the next box's left edge

plt.figure(figsize=(12, 2))
ax = plt.gca()
ax.set_xlim(0, len(layers) + 1)
ax.set_ylim(0, 1)

last_idx = len(layers) - 1
for idx, label in enumerate(layers):
    cx = idx + 1
    ax.add_patch(plt.Rectangle((cx - BOX_W / 2, BOX_Y), BOX_W, BOX_H, fill=False))
    ax.text(cx, BOX_Y + BOX_H / 2, label, ha='center', va='center', wrap=True)
    if idx != last_idx:
        # Right-pointing arrow bridging the gap to the next box.
        ax.arrow(cx + BOX_W / 2, 0.5, ARROW_DX, 0, length_includes_head=True, head_width=0.03, head_length=0.03)

ax.axis('off')
ax.set_title("Autoencoder (Fully Connected)")
plt.show()



## Training
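The decoder ends in a sigmoid and the pixel values lie in [0, 1], so we treat each pixel as a Bernoulli target and minimize binary cross-entropy (BCE). Training uses Adam (lr = 1e-3) for 10 epochs.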


In [None]:

epochs = 10
# NOTE: this cell does not re-create `model`; re-running it continues training the
# existing weights with a fresh optimizer. Re-run the model cell for a from-scratch run.
model.to(device)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
# Default reduction='mean': BCE averaged over every pixel in the batch.
criterion = nn.BCELoss()

def run_epoch(loader, train=True):
    """Make one pass over `loader` and return the mean reconstruction loss.

    Args:
        loader: Iterable of (images, labels) batches; labels are ignored.
        train: If True, use training mode and take one optimizer step per batch.
            If False, evaluate only: eval mode and no autograd graph.

    Returns:
        The per-batch BCE losses averaged with batch-size weights (float).

    Raises:
        ValueError: If `loader` yields no samples.
    """
    model.train(train)
    total = 0.0
    count = 0
    # Evaluation needs no autograd graph; disabling it saves memory and time.
    with torch.set_grad_enabled(train):
        for x, _ in loader:
            # non_blocking is only enabled for CUDA (where pinned memory is used).
            x = x.to(device, non_blocking=non_blocking)
            xhat, _ = model(x)
            loss = criterion(xhat, x)
            if train:
                opt.zero_grad(set_to_none=True)
                loss.backward()
                opt.step()
            total += loss.item() * x.size(0)
            count += x.size(0)
    if count == 0:
        raise ValueError("run_epoch: loader yielded no samples")
    return total / count

# Train for `epochs` epochs, evaluating on the held-out split after each one.
train_losses = []
val_losses = []
for epoch_idx in range(epochs):
    epoch_num = epoch_idx + 1
    train_loss = run_epoch(train_loader, train=True)
    val_loss = run_epoch(val_loader, train=False)
    train_losses += [train_loss]
    val_losses += [val_loss]
    print(f"Epoch {epoch_num:02d}/{epochs} | train {train_loss:.4f} | val {val_loss:.4f}")

# Learning curves: per-epoch mean BCE on the training and validation splits.
epoch_axis = range(1, epochs + 1)
fig, ax = plt.subplots(figsize=(5, 3))
for series, series_label in ((train_losses, "train"), (val_losses, "val")):
    ax.plot(epoch_axis, series, marker="o", label=series_label)
ax.set_xlabel("Epoch")
ax.set_ylabel("BCE loss")
ax.set_title("Training & Validation Loss")
ax.legend()
ax.grid(True)
plt.show()



## Reconstructions: Original → Compressed → Decompressed
The figure below shows one row per validation sample, with three panels:
- the original image (left);
- its 9-dimensional latent code, drawn as a 3×3 grayscale grid and normalized per sample (middle);
- the decoder's reconstruction (right).


In [None]:

# Reconstruct the first 16 images of the (unshuffled) validation loader without gradients.
model.eval()
first_val_images, _ = next(iter(val_loader))
x = first_val_images[:16].to(device)
with torch.no_grad():
    xhat, z = model(x)

import matplotlib
from torchvision.utils import make_grid

def show_triplets(orig_batch, z_batch, recon_batch, n=8, latent_shape=None):
    """Plot original / latent code / reconstruction side by side for up to `n` samples.

    Args:
        orig_batch: Image tensor, presumably (B, 1, H, W); each sample is squeezed
            to 2-D before display.
        z_batch: Latent codes, shape (B, D) (any trailing shape is flattened to D).
        recon_batch: Reconstructions with the same per-sample shape as `orig_batch`.
        n: Maximum number of rows to draw; clipped to the batch size. If it ends up
            <= 0, nothing is drawn.
        latent_shape: (rows, cols) grid used to display each latent code. The default
            is a square grid when D is a perfect square (3×3 for D=9), otherwise a
            1×D strip. rows * cols must equal D.
    """
    n = min(n, orig_batch.size(0))
    if n <= 0:
        # Nothing to draw (plt.subplots(nrows=0, ...) would raise).
        return
    # detach() lets this also accept tensors that still require grad.
    orig = orig_batch[:n].detach().cpu()
    recon = recon_batch[:n].detach().cpu()
    codes = z_batch[:n].detach().reshape(n, -1).cpu()

    latent_dim = codes.size(1)
    if latent_shape is None:
        side = round(latent_dim ** 0.5)
        latent_shape = (side, side) if side * side == latent_dim else (1, latent_dim)
    grid_rows, grid_cols = latent_shape
    # reshape raises if grid_rows * grid_cols != latent_dim.
    z = codes.reshape(n, grid_rows, grid_cols).numpy()

    orig_h, orig_w = orig.shape[-2:]
    recon_h, recon_w = recon.shape[-2:]

    # squeeze=False keeps `axes` 2-D (n x 3) even when n == 1.
    fig, axes = plt.subplots(nrows=n, ncols=3, figsize=(6, 2 * n), squeeze=False)
    for i in range(n):
        ax0, ax1, ax2 = axes[i]
        # Original
        ax0.imshow(orig[i].squeeze(), cmap='gray')
        ax0.axis('off')
        if i == 0:
            ax0.set_title(f'Original ({orig_h}x{orig_w})')
        # Compressed: min-max normalize per sample and draw as a grayscale grid;
        # nearest interpolation keeps each latent unit a crisp cell.
        zi = z[i]
        zn = (zi - zi.min()) / (zi.max() - zi.min() + 1e-8)
        ax1.imshow(zn, cmap='gray', interpolation='nearest', aspect='equal')
        ax1.set_xticks([]); ax1.set_yticks([])
        if i == 0:
            ax1.set_title(f'Compressed ({grid_rows}×{grid_cols})')
        # Reconstruction
        ax2.imshow(recon[i].squeeze(), cmap='gray')
        ax2.axis('off')
        if i == 0:
            ax2.set_title(f'Reconstruction ({recon_h}x{recon_w})')
    fig.tight_layout()
    plt.show()

# Display original / latent / reconstruction for the first 8 of the reconstructed samples.
show_triplets(orig_batch=x, z_batch=z, recon_batch=xhat, n=8)
