In [1]:
# Regularized Autoencoders

import torch
import torch.nn as nn
from torchvision import datasets
from torchvision import transforms
import matplotlib.pyplot as plt

device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [3]:
# We'll use the MNIST dataset

transform = transforms.ToTensor()

train_dataset = datasets.MNIST(root="./data",
                         train=True,
                         download=True,
                         transform=transform)

test_dataset = datasets.MNIST(root="./data",
                              train=False,
                              download=True,
                              transform=transform)

train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=32,
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=32,
                                          shuffle=False)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 105371273.72it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 37158188.29it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz



100%|██████████| 1648877/1648877 [00:00<00:00, 28378941.96it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 6451245.77it/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



In [4]:
class RegularizedAE(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 16),
            nn.ReLU(),
            nn.Linear(16, 8)
        )

        self.decoder = nn.Sequential(
            nn.Linear(8, 16),
            nn.ReLU(),
            nn.Linear(16, 32),
            nn.ReLU(),
            nn.Linear(32, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

In [5]:
# For our regularized autoencoder, we will add a L2 penalty term or weight decay. 

ae_model = RegularizedAE().to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(ae_model.parameters(), lr=1e-3, weight_decay=1e-5)

In [7]:
n_epochs = 10

for epoch in range(n_epochs):
    for data, _ in train_loader:
        data = data.to(device).view(-1, 128)
        optimizer.zero_grad()
        outputs = ae_model(data)
        loss = criterion(outputs, data)

        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch} of {n_epochs}, Loss: {loss.item()}")

Epoch 0 of 10, Loss: 0.020155735313892365
Epoch 1 of 10, Loss: 0.023483606055378914
Epoch 2 of 10, Loss: 0.019208386540412903
Epoch 3 of 10, Loss: 0.01898408867418766
Epoch 4 of 10, Loss: 0.01972731202840805
Epoch 5 of 10, Loss: 0.0189221128821373
Epoch 6 of 10, Loss: 0.017561331391334534
Epoch 7 of 10, Loss: 0.018997181206941605
Epoch 8 of 10, Loss: 0.01691749505698681
Epoch 9 of 10, Loss: 0.01783318817615509
