In [30]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = (
    torch.device("cuda")
    if torch.cuda.is_available()
    else torch.device("cpu")
)


In [37]:
# Data configuration: tweak these instead of editing the loader calls below.
DATA_ROOT = "./data"
BATCH_SIZE = 128

# ToTensor() turns each image into a float tensor with values scaled to [0, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
])

train_dataset = datasets.MNIST(root=DATA_ROOT, train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root=DATA_ROOT, train=False, download=True, transform=transform)

# Only the training data is shuffled. The test set is iterated in a fixed order so
# evaluation metrics and any per-sample visualisations are reproducible across runs.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [49]:
class RAE(nn.Module):
    """Fully connected regularized autoencoder.

    Encoder: ``input_dim -> hidden_dim -> latent_dim``, with a ReLU between the
    two linear layers and no activation on the latent code ``z``.
    Decoder: ``latent_dim -> hidden_dim -> input_dim``, with a ReLU between the
    two linear layers and a final Sigmoid, so every reconstructed value lies in [0, 1].

    The module itself applies no penalty to ``z``. Any latent regularization has to
    be added to the loss by the caller.

    Args:
        input_dim: Size of each (already flattened) input vector. The default 784
            matches a flattened 28x28 image.
        latent_dim: Size of the latent code ``z``.
        hidden_dim: Width of the single hidden layer in both encoder and decoder.
            Defaults to 256, the value previously hard-coded.
    """

    def __init__(self, input_dim=784, latent_dim=32, hidden_dim=256):
        super().__init__()

        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim),
        )

        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Encode ``x`` and reconstruct it.

        Args:
            x: Tensor whose last dimension has size ``input_dim``, e.g.
                ``(batch, input_dim)``. Inputs must be flattened by the caller.

        Returns:
            Tuple ``(x_hat, z)``. ``x_hat`` is the reconstruction (same trailing
            size as ``x``, values in [0, 1]). ``z`` is the latent code (trailing
            size ``latent_dim``), returned so the caller can regularize it.
        """
        z = self.encoder(x)
        x_hat = self.decoder(z)
        return x_hat, z


In [53]:
# Training configuration.
SEED = 0
NUM_EPOCHS = 10
LEARNING_RATE = 1e-3

# Seed the global torch RNG before building the model, so that weight
# initialisation and the DataLoader's shuffle order are reproducible when this
# cell is re-run.
torch.manual_seed(SEED)

rae = RAE().to(device)
optimizer = optim.Adam(rae.parameters(), lr=LEARNING_RATE)

lambda_reg = 1e-3  # regularization strength: weight of the latent penalty vs. reconstruction

for epoch in range(NUM_EPOCHS):
    rae.train()
    total_loss = 0.0  # sum of per-sample losses over the epoch
    n_samples = 0

    for x, _ in train_loader:
        # Flatten everything after the batch dimension so each sample is a vector.
        x = x.flatten(start_dim=1).to(device)

        x_hat, z = rae(x)

        recon_loss = nn.functional.mse_loss(x_hat, x)
        # Penalise the squared magnitude of the latent code (mean over batch and
        # latent dimensions) to keep the latent space compact.
        reg_loss = z.pow(2).mean()
        loss = recon_loss + lambda_reg * reg_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # `loss` is a mean over the batch. Weight it by the batch size so the epoch
        # average is a true per-sample mean even when the last batch is smaller.
        batch_size = x.size(0)
        total_loss += loss.item() * batch_size
        n_samples += batch_size

    print(f"RAE Epoch [{epoch+1}/{NUM_EPOCHS}], Loss: {total_loss / max(n_samples, 1):.4f}")


RAE Epoch [1/10], Loss: 0.0391
RAE Epoch [2/10], Loss: 0.0155
RAE Epoch [3/10], Loss: 0.0117
RAE Epoch [4/10], Loss: 0.0101
RAE Epoch [5/10], Loss: 0.0092
RAE Epoch [6/10], Loss: 0.0086
RAE Epoch [7/10], Loss: 0.0081
RAE Epoch [8/10], Loss: 0.0077
RAE Epoch [9/10], Loss: 0.0074
RAE Epoch [10/10], Loss: 0.0071
