# P2: Autoencoder on MNIST

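**Objective:** Train a simple autoencoder (Keras) to compress and reconstruct MNIST digits, then train a denoising autoencoder (PyTorch) that reconstructs clean digits from noise-corrupted inputs.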

In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models
# Dimensions: each 28x28 digit is flattened into a single vector, and the
# encoder squeezes it into a much smaller code.
input_dim = 28 * 28
encoding_dim = 64

# Labels are discarded: an autoencoder's training target is its own input.
(x_train, _), (x_test, _) = tf.keras.datasets.mnist.load_data()

# Cast to float, divide by 255 to bring pixel values into [0, 1] (matching
# the sigmoid output / binary cross-entropy below), then flatten each image.
x_train, x_test = (
    (split.astype('float32') / 255.0).reshape(-1, input_dim)
    for split in (x_train, x_test)
)

# Single-hidden-layer autoencoder: input -> ReLU bottleneck -> sigmoid reconstruction.
input_layer = layers.Input(shape=(input_dim,))
encoded = layers.Dense(encoding_dim, activation='relu')(input_layer)
decoded = layers.Dense(input_dim, activation='sigmoid')(encoded)
autoencoder = models.Model(inputs=input_layer, outputs=decoded)

autoencoder.compile(optimizer='adam', loss='binary_crossentropy')

# Inputs double as targets; the test split is used for validation loss.
autoencoder.fit(
    x_train,
    x_train,
    epochs=5,
    batch_size=256,
    validation_data=(x_test, x_test),
)

In [None]:
# Practical 2: Denoising Autoencoder (PyTorch + MNIST)
import torch, torch.nn as nn, torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# Per-sample transform: torchvision's ToTensor turns each image into a float
# tensor scaled to [0, 1].
tfm = transforms.Compose([transforms.ToTensor()])

# Both splits share the same cache directory, download flag and transform.
_mnist_kwargs = dict(root='./data', download=True, transform=tfm)
train_ds = datasets.MNIST(train=True, **_mnist_kwargs)
test_ds = datasets.MNIST(train=False, **_mnist_kwargs)

def add_noise(x, p=0.3):
    """Corrupt ``x`` by inverting a random subset of its elements.

    Each element is selected independently when a fresh uniform draw from
    ``torch.rand_like(x)`` falls below ``p``. Selected elements become
    ``1 - value``; all others are copied unchanged. ``x`` itself is not
    modified.

    Args:
        x: Input tensor. For the inversion to stay in range its values are
            presumably in [0, 1] (as produced by ``ToTensor``); the caller
            must ensure this.
        p: Per-element probability of inversion.

    Returns:
        A new tensor with the same shape and dtype as ``x``.
    """
    flip = torch.rand_like(x) < p
    return torch.where(flip, 1 - x, x)

class AE(nn.Module):
    """Fully connected autoencoder for small images.

    The encoder flattens each image and projects it to a ``latent_dim``-wide
    ReLU code. The decoder maps that code back to one value per pixel,
    squashes it into [0, 1] with a sigmoid, and reshapes the result to
    ``img_shape``.

    Args:
        latent_dim: Width of the bottleneck layer. The default of 128 matches
            the value this class previously hard-coded.
        img_shape: ``(channels, height, width)`` of the images. The default
            ``(1, 28, 28)`` matches the previously hard-coded MNIST shape.
    """

    def __init__(self, latent_dim=128, img_shape=(1, 28, 28)):
        super().__init__()
        self.img_shape = tuple(img_shape)
        # Total number of values per image (784 for the default shape).
        n_pixels = torch.Size(self.img_shape).numel()
        self.enc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(n_pixels, latent_dim),
            nn.ReLU(),
        )
        self.dec = nn.Sequential(
            nn.Linear(latent_dim, n_pixels),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Reconstruct ``x``; the output has shape ``(batch, *img_shape)``."""
        z = self.enc(x)
        # The decoder emits a flat vector per sample; restore the image layout.
        return self.dec(z).view(-1, *self.img_shape)

ae = AE()
opt = optim.Adam(ae.parameters(), lr=1e-3)
crit = nn.MSELoss()
loader = DataLoader(train_ds, batch_size=256, shuffle=True)

# Denoising objective: the model receives a corrupted image but is scored
# against the clean original.
ae.train()
for epoch in range(2):
    running_loss, n_seen = 0.0, 0
    for xb, _ in loader:  # digit labels are not needed
        nb = add_noise(xb)
        xr = ae(nb)
        loss = crit(xr, xb)
        opt.zero_grad()
        loss.backward()
        opt.step()
        # MSELoss returns a per-batch mean; weight it by batch size so the
        # epoch figure is a mean over all samples rather than whatever the
        # last (possibly smaller) batch happened to score.
        running_loss += loss.item() * xb.size(0)
        n_seen += xb.size(0)
    print('epoch', epoch, 'mean train loss', running_loss / n_seen)

# Qualitative check on a few test digits.
n_show = 8
ae.eval()
xb, _ = next(iter(DataLoader(test_ds, batch_size=n_show)))
nb = add_noise(xb)
with torch.no_grad():
    xr = ae(nb)

fig, axs = plt.subplots(3, n_show, figsize=(12, 4))
for i in range(n_show):
    for row, imgs in enumerate((xb, nb, xr)):
        # Fix the intensity scale to [0, 1] so images are comparable; by
        # default imshow stretches each image to its own min/max.
        axs[row, i].imshow(imgs[i, 0], cmap='gray', vmin=0, vmax=1)
        axs[row, i].axis('off')
plt.suptitle('Original / Noisy / Reconstructed')
plt.show()
