In [31]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

In [32]:
# Preprocessing pipeline: convert each image to a float tensor, then
# normalise with mean 0.5 and std 0.5, i.e. (x - 0.5) / 0.5.  MNIST is
# single-channel, so mean/std are one-element tuples -- note the trailing
# comma: `(0.5)` is just a parenthesised float, not a tuple.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# MNIST training split; downloaded into ./data on first use.
mnist_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

# Shuffled mini-batches of 64 images.
data_loader = torch.utils.data.DataLoader(dataset=mnist_data,
                                          batch_size=64,
                                          shuffle=True)



In [33]:
# Sanity check: draw one mini-batch and print the smallest and largest
# pixel value it contains after the preprocessing transform.
dataiter = iter(data_loader)
images, labels = next(dataiter)
print(images.min(), images.max())

tensor(-1.) tensor(1.)


In [34]:
class Autoencoder(nn.Module):
    """Fully connected autoencoder operating on flattened images.

    The encoder compresses an ``input_dim``-sized vector through hidden
    layers of width 128, 64 and 12 down to a ``latent_dim``-sized code; the
    decoder mirrors those layers back up to ``input_dim``.  The decoder ends
    in ``Tanh``, so every reconstructed value lies in [-1, 1].

    Args:
        input_dim: Size of each flattened input vector (default 28 * 28).
        latent_dim: Size of the bottleneck code (default 3).
    """

    def __init__(self, input_dim=28 * 28, latent_dim=3):
        super().__init__()
        # Layers are created in a fixed order (encoder first, then decoder)
        # so parameter initialisation under a given RNG seed is stable.
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 128),   # N,input_dim -> N,128
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 12),
            nn.ReLU(),
            nn.Linear(12, latent_dim),   # -> N,latent_dim (no activation on the code)
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 12),   # N,latent_dim -> N,12
            nn.ReLU(),
            nn.Linear(12, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, input_dim),   # N,128 -> N,input_dim
            nn.Tanh(),                   # squash reconstruction into [-1, 1]
        )

    def forward(self, x):
        """Encode ``x`` to the latent code and return its reconstruction.

        Args:
            x: Tensor whose last dimension has size ``input_dim``.

        Returns:
            Tensor with the same shape as ``x`` and values in [-1, 1].
        """
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded
    
    

In [35]:
# Network to train.
model = Autoencoder()

# Reconstruction objective: mean squared error between output and input.
criterion = nn.MSELoss()

# Adam over all model parameters, with a small weight-decay term.
optimizer = optim.Adam(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-5,
)

In [36]:
num_epochs = 1
# One (epoch, inputs, reconstructions) tuple per epoch, taken from that
# epoch's final mini-batch, for later inspection.
outputs = []

for epoch in range(num_epochs):
    for (img, _) in data_loader:  # labels are unused: the target is the input itself
        # Flatten each image into a 784-element vector for the linear layers.
        img = img.reshape(-1, 28*28)
        recon = model(img)
        loss = criterion(recon, img)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # NOTE: this reports the loss of the last mini-batch only, not an epoch mean.
    print(f'epoch:{epoch+1}, loss:{loss.item():.4f}')
    # Detach the reconstruction so the stored tensor does not keep the
    # autograd graph (and its intermediate activations) alive.
    outputs.append((epoch, img, recon.detach()))

epoch:1, loss:0.1886
