In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset

# ToTensor is the only preprocessing step applied to each sample.
transform = transforms.Compose([transforms.ToTensor()])

# Training pipeline: MNIST train split -> first 200 samples -> shuffled
# mini-batches of 10.
train_dataset = datasets.MNIST(
    root="./data",
    train=True,
    download=True,
    transform=transform,
)
train_subset = Subset(train_dataset, range(200))
train_loader = DataLoader(train_subset, batch_size=10, shuffle=True)

# Evaluation pipeline: MNIST test split -> first 50 samples -> batches of 10
# in fixed order.
test_dataset = datasets.MNIST(
    root="./data",
    train=False,
    download=True,
    transform=transform,
)
test_subset = Subset(test_dataset, range(50))
test_loader = DataLoader(test_subset, batch_size=10, shuffle=False)

class AutoEncoder(nn.Module):
    """Fully connected autoencoder with a single hidden layer on each side.

    The encoder maps ``input_dim -> hidden_dim -> latent_dim`` and the decoder
    mirrors it back to ``input_dim``, finishing with a sigmoid so every output
    value lies in [0, 1].

    Args:
        input_dim: Size of each flattened input vector (default ``28 * 28``).
        hidden_dim: Width of the hidden layer in both halves (default 256).
        latent_dim: Size of the bottleneck code (default 64).
    """

    def __init__(
        self,
        input_dim: int = 28 * 28,
        hidden_dim: int = 256,
        latent_dim: int = 64,
    ) -> None:
        super().__init__()
        # Encoder layers are created before decoder layers; keep this order so
        # seeded parameter initialization stays reproducible.
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, latent_dim),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` and return its reconstruction.

        ``x`` must already be flattened so that its last dimension equals
        ``input_dim``; the result has the same shape as ``x``.
        """
        return self.decoder(self.encoder(x))

# Build the network first; the optimizer is bound to its parameters, and the
# reconstruction criterion is mean squared error.
model = AutoEncoder()
criterion = nn.MSELoss()
optimizer = optim.Adam(params=model.parameters())

def train_model(num_epochs):
    """Train the module-level ``model`` and report losses after every epoch.

    Each epoch performs one optimization pass over ``train_loader`` followed by
    a gradient-free pass over ``test_loader``. The per-batch losses of each
    pass are averaged over the number of batches and printed on one line.

    Args:
        num_epochs: Number of full passes to run.

    Relies on the module-level ``model``, ``optimizer``, ``criterion``,
    ``train_loader`` and ``test_loader``. Returns nothing.
    """

    def reconstruction_loss(batch):
        # Labels are discarded: the flattened images serve as both input and
        # target.
        images, _ = batch
        flat = images.view(images.size(0), -1)
        return criterion(model(flat), flat)

    for epoch in range(num_epochs):
        # --- optimization pass ---
        model.train()
        running_loss = 0.0
        for batch in train_loader:
            batch_loss = reconstruction_loss(batch)
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()
            running_loss += batch_loss.item()
        avg_train_loss = running_loss / len(train_loader)

        # --- evaluation pass (no gradients tracked) ---
        model.eval()
        with torch.no_grad():
            held_out_loss = sum(
                reconstruction_loss(batch).item() for batch in test_loader
            )
        avg_test_loss = held_out_loss / len(test_loader)

        print(f'Epoch {epoch+1}, Train Loss: {avg_train_loss:.4f}, Test Loss: {avg_test_loss:.4f}')

# Run ten epochs; each one prints its average train and test loss.
train_model(num_epochs=10)


Epoch 1, Train Loss: 0.1349, Test Loss: 0.0747
Epoch 2, Train Loss: 0.0737, Test Loss: 0.0729
Epoch 3, Train Loss: 0.0703, Test Loss: 0.0675
Epoch 4, Train Loss: 0.0635, Test Loss: 0.0623
Epoch 5, Train Loss: 0.0571, Test Loss: 0.0553
Epoch 6, Train Loss: 0.0499, Test Loss: 0.0498
Epoch 7, Train Loss: 0.0425, Test Loss: 0.0453
Epoch 8, Train Loss: 0.0370, Test Loss: 0.0431
Epoch 9, Train Loss: 0.0338, Test Loss: 0.0422
Epoch 10, Train Loss: 0.0314, Test Loss: 0.0402
