In [None]:
"""
@Description: AutoEncoder
@Author: Dezhan Tu
@LastEditTime: 08/23/2020
"""

import numpy as np
import os

import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader

from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import save_image

import matplotlib.pyplot as plt


# Fix the RNG seed first so that weight initialisation and DataLoader
# shuffling are repeatable after "Restart Kernel -> Run All".
# torch.manual_seed also seeds the CUDA generators when a GPU is used.
SEED = 42
torch.manual_seed(SEED)

# Define hyperparameters
NUM_EPOCHS = 20         # number of full passes over the training set
LEARNING_RATE = 1e-3    # step size handed to the Adam optimizer
BATCH_SIZE = 128        # samples per mini-batch for the train and test loaders

In [None]:
# Data utility

# ToTensor turns every image into a torch.FloatTensor
img_transform = transforms.Compose([transforms.ToTensor()])

# MNIST train / test splits, stored under ./dataset (download requested)
train_dataset = MNIST(root="./dataset", train=True, transform=img_transform, download=True)
test_dataset = MNIST(root="./dataset", train=False, transform=img_transform, download=True)

# Mini-batch loaders: the training split is shuffled, the test split keeps its order
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)


def save_epoch_img(img, epoch):
    """
    Write a batch of images produced during training to a PNG file.

    The batch is viewed as (batch, 1, 28, 28) and handed to save_image;
    the file name carries the given epoch value:
    ./AutoEncoder_Result/AutoEncoder<epoch>.png
    """
    batch_size = img.size(0)
    image_batch = img.view(batch_size, 1, 28, 28)
    target_path = './AutoEncoder_Result/AutoEncoder{}.png'.format(epoch)
    save_image(image_batch, target_path)


def make_dir():
    """
    Create the 'AutoEncoder_Result' output directory if it is missing.

    Safe to call repeatedly: an already existing directory is left untouched.
    """
    # exist_ok=True replaces the separate os.path.exists() check, which could
    # race with another process creating the directory in between.
    os.makedirs('AutoEncoder_Result', exist_ok=True)
        
make_dir()    #Create result directory

In [None]:
class AutoEncoder(nn.Module):
    """
    Fully connected autoencoder for flattened 784-value inputs.

    Four linear layers narrow the input down to a 32-value code
    (784 -> 256 -> 128 -> 64 -> 32) and four more widen it back
    (32 -> 64 -> 128 -> 256 -> 784). A ReLU follows every layer,
    including the last one.
    """
    def __init__(self):
        """
        Create the eight linear layers.
        """
        super().__init__()

        # Encoder: 784 -> 256 -> 128 -> 64 -> 32
        self.encoder1 = nn.Linear(784, 256)
        self.encoder2 = nn.Linear(256, 128)
        self.encoder3 = nn.Linear(128, 64)
        self.encoder4 = nn.Linear(64, 32)

        # Decoder: 32 -> 64 -> 128 -> 256 -> 784
        self.decoder1 = nn.Linear(32, 64)
        self.decoder2 = nn.Linear(64, 128)
        self.decoder3 = nn.Linear(128, 256)
        self.decoder4 = nn.Linear(256, 784)

    def forward(self, x):
        """
        Run x through the encoder and then the decoder.

        Returns the output of the last decoder layer after its ReLU.
        """
        pipeline = (
            self.encoder1, self.encoder2, self.encoder3, self.encoder4,
            self.decoder1, self.decoder2, self.decoder3, self.decoder4,
        )
        for layer in pipeline:
            x = torch.relu(layer(x))
        return x

# Training Process

In [None]:
# Use gpu if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Instantiate an AutoEncoder
model = AutoEncoder().to(device)

# Mean-Squared Error Loss
criterion = nn.MSELoss()

# Adam optimizer with learning rate LEARNING_RATE
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Training Process
loss_output = []    # mean training loss of every epoch, in order

model.train()       # make the training mode explicit
for epoch in range(NUM_EPOCHS):
    loss = 0.0      # running sum of mini-batch losses for this epoch
    for img, _ in train_loader:
        img = img.view(-1, 784).to(device)    # Flatten each image to 784 values
        optimizer.zero_grad()                 # Reset the gradients back to zero
        outputs = model(img)                  # Reconstruction
        train_loss = criterion(outputs, img)  # Reconstruction loss
        train_loss.backward()                 # Backpropagate to compute gradients
        optimizer.step()                      # Parameters update
        loss += train_loss.item()             # Add the mini-batch training loss to epoch loss

    loss = loss / len(train_loader)           # Average over the number of mini-batches
    loss_output.append(loss)

    # Display the epoch training loss
    print("Epoch : {}/{}, Training loss = {:.3f}".
          format(epoch + 1, NUM_EPOCHS, loss))

    # Save the reconstructions of the last mini-batch every 10th epoch
    # (epoch is 0-based here, so this fires on epochs 0, 10, ...).
    if epoch % 10 == 0:
        # detach() instead of the discouraged .data: it drops the autograd
        # history safely, and detaching before .cpu() keeps the copy out of
        # the graph.
        save_epoch_img(outputs.detach().cpu(), epoch)

In [None]:
# Visualize the loss: one point per epoch, written to the result directory
fig, ax = plt.subplots()
ax.plot(loss_output)
ax.set_title('Training Loss')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
fig.savefig('./AutoEncoder_Result/Autoencoder_train_loss.png')

# Testing Process

In [None]:
#Testing Process
# Reconstruct every test mini-batch and save it as Test_result_<i>.png,
# with i counting mini-batches from 1.
model.eval()             # make the inference mode explicit
with torch.no_grad():    # inference only: do not build the autograd graph
    for i, (img, _) in enumerate(test_loader, start=1):
        img = img.view(-1, 784).to(device)    # Flatten each image to 784 values
        outputs = model(img)                  # Reconstruction
        # Back to image layout (batch, 1, 28, 28) on the CPU for saving
        outputs = outputs.view(outputs.size(0), 1, 28, 28).cpu()
        save_image(outputs, './AutoEncoder_Result/Test_result_{}.png'.format(i))