In [None]:
import os

import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [None]:
# Run on the first GPU when one is available; otherwise fall back to the CPU so
# the notebook still runs end-to-end (a hard-coded 'cuda:0' crashes without CUDA).
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
batch_size = 128

# Seed torch's global RNG so DataLoader shuffling and weight initialisation are
# repeatable across kernel restarts (GPU kernels may still add nondeterminism).
seed = 0
torch.manual_seed(seed)

In [None]:
# train dataloader
train_loader = DataLoader(datasets.MNIST(root='data', train=True, download=True, transform=transforms.ToTensor()), batch_size=batch_size, shuffle=True)
train_iter = iter(train_loader)
# get first batch - images and labels
x_sample, _ = next(train_iter)
# float images 28 x 28, 0.0-1.0
print(x_sample[0].shape, torch.max(x_sample[0]).item(), torch.min(x_sample[0]).item(), x_sample[0].dtype)
# save first samples as 8-bit PNGs. cv2.imwrite does not create missing
# directories and reports failure via its return value rather than raising,
# so create the folder first and surface any write failure explicitly.
os.makedirs('mnist', exist_ok=True)
for i in range(10):
    inp_path = 'mnist/inp' + str(i).zfill(5) + ".png"
    pixels = np.asarray(x_sample[i].squeeze(0).detach().numpy() * 255, np.uint8)
    if not cv2.imwrite(inp_path, pixels):
        raise IOError('could not write ' + inp_path)

In [None]:
# test dataloader: the held-out MNIST split, iterated in a fixed order
# (no shuffling) so evaluation visits the same batches every run
test_loader = DataLoader(
    dataset=datasets.MNIST(
        root='data',
        train=False,
        download=True,
        transform=transforms.ToTensor(),
    ),
    batch_size=batch_size,
    shuffle=False,
)

In [None]:
# custom module
class Reshape(nn.Module):
    """Reshape every sample of a batch to a fixed shape, keeping the batch dimension.

    Args:
        *args: target per-sample shape; each value is converted with ``int``.
            For example ``Reshape(8, 4, 4)`` maps ``[n, 128]`` to ``[n, 8, 4, 4]``.
    """
    def __init__(self, *args):
        super().__init__()
        self.shape = tuple(map(int, args))

    def forward(self, x):
        # reshape (not view) so non-contiguous inputs are accepted; it still
        # returns a view without copying whenever the memory layout allows it
        return x.reshape((x.shape[0],) + self.shape)

    def extra_repr(self):
        # shown inside print(model) so the target shape is visible
        return 'shape={}'.format(self.shape)

In [None]:
# architecture
class Autoencoder(nn.Module):
    """Convolutional autoencoder for single-channel 28 x 28 images.

    The encoder compresses each image to a flat 128-value code (8 x 4 x 4).
    The decoder maps that code back to a 1 x 28 x 28 image. The decoder ends
    in Sigmoid, so reconstructions lie in (0, 1), which is what
    ``F.binary_cross_entropy`` requires.

    Layer comments give the shape change and then the learnable parameter count
    (out_channels x (in_channels x kernel_area + bias)).
    """
    def __init__(self):
        super().__init__()
        # convolution layers and max pooling of encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(1,16,(3,3),padding=1),  # [n, 1, 28, 28] -> [n, 16, 28, 28] # 16 x (1 x 9 + 1) = 160
            nn.ReLU(),                        # shape unchanged # 0
            nn.MaxPool2d(2),                  # [n, 16, 28, 28] -> [n, 16, 14, 14] # 0
            nn.Conv2d(16,8,(3,3),padding=1),  # [n, 16, 14, 14] -> [n, 8, 14, 14] # 8 x (16 x 9 + 1) = 1160
            nn.ReLU(),                        # shape unchanged # 0
            nn.MaxPool2d(2),                  # [n, 8, 14, 14] -> [n, 8, 7, 7] # 0
            nn.Conv2d(8,8,(3,3),padding=1),   # [n, 8, 7, 7] -> [n, 8, 7, 7] # 8 x (8 x 9 + 1) = 584
            nn.Sigmoid(),                     # shape unchanged # 0
            nn.MaxPool2d(2,padding=1),        # [n, 8, 7, 7] -> [n, 8, 4, 4]; padding=1 makes odd 7 pool to 4, not 3 # 0
            torch.nn.Flatten()                # [n, 8, 4, 4] -> [n, 128] # 0
        )

        # convolution layers and upsampling of decoder
        self.decoder = nn.Sequential(
            Reshape(8,4,4),                   # [n, 128] -> [n, 8, 4, 4] # 0
            nn.Conv2d(8,8,(3,3),padding=1),   # [n, 8, 4, 4] -> [n, 8, 4, 4] # 8 x (8 x 9 + 1) = 584
            nn.ReLU(),
            nn.Upsample(scale_factor=(2,2)),  # [n, 8, 4, 4] -> [n, 8, 8, 8] # 0
            nn.Conv2d(8,8,(3,3),padding=1),   # [n, 8, 8, 8] -> [n, 8, 8, 8] # 8 x (8 x 9 + 1) = 584
            nn.ReLU(),
            nn.Upsample(scale_factor=(2,2)),  # [n, 8, 8, 8] -> [n, 8, 16, 16] # 0
            nn.Conv2d(8,16,(3,3)),            # [n, 8, 16, 16] -> [n, 16, 14, 14]; unpadded on purpose so the next x2 lands on 28 # 16 x (8 x 9 + 1) = 1168
            nn.ReLU(),
            nn.Upsample(scale_factor=(2,2)),  # [n, 16, 14, 14] -> [n, 16, 28, 28] # 0
            nn.Conv2d(16,1,(3,3),padding=1),  # [n, 16, 28, 28] -> [n, 1, 28, 28] # 1 x (16 x 9 + 1) = 145
            nn.Sigmoid()                      # reconstruction in (0, 1) # 0
        )

    def forward(self, x):
        """Encode then decode ``x``; returns the reconstruction.

        Expects ``x`` shaped [n, 1, 28, 28] (per the layer comments above).
        """
        # apply encoder
        features = self.encoder(x)
        # apply decoder
        return self.decoder(features)

    def __str__(self):
        # newline between the two reprs; plain concatenation glued the decoder's
        # "Sequential(" header onto the encoder's closing parenthesis
        return str(self.encoder) + '\n' + str(self.decoder)

In [None]:
# build the model, move its parameters to the configured device, then show the layer stack
autoencoder = Autoencoder()
autoencoder = autoencoder.to(device)
print(str(autoencoder))

In [None]:
# Optimizer: Adadelta over every model parameter, using the library's default hyperparameters
optimizer = optim.Adadelta(params=autoencoder.parameters())

In [None]:
# Training
# "Accuracy" here is a reconstruction metric: the percentage of pixels whose
# reconstructed value lies within acc_tolerance of the input pixel.
epochs_count = 100
acc_tolerance = 0.1
for epoch in range(epochs_count):

    # switch to training mode (this model has no dropout/batch-norm, so it is
    # a no-op today, but keeps the loop correct if such layers are added)
    autoencoder.train()

    # per-batch loss and accuracy; list.append is O(1), unlike np.append,
    # which copies the whole array on every call
    batch_loss = []
    batch_acc = []

    for batch, (x_train, _) in enumerate(train_loader):

        # send data to device
        inputs = x_train.to(device)

        # reset parameter gradients to zero
        optimizer.zero_grad()

        # forward pass; the decoder ends in Sigmoid, so outputs are in (0, 1)
        outputs = autoencoder(inputs)

        # binary cross-entropy between reconstruction and input pixels (both in [0, 1])
        loss = F.binary_cross_entropy(outputs, inputs)

        # find gradients
        loss.backward()
        # update parameters using gradients
        optimizer.step()

        # recording loss
        batch_loss.append(loss.item())

        # recording accuracy: a metric only, so no autograd graph is needed
        with torch.no_grad():
            correct_train = (torch.abs(inputs - outputs) < acc_tolerance).sum().item()
        acc = (100.0 * correct_train) / inputs.numel()
        batch_acc.append(acc)

        if batch % 100 == 0 and batch > 0:
            print('Train Epoch: {} [{}/{}] Loss: {:.6f} Acc: {:.4f}'.format(
                epoch, batch * len(inputs), len(train_loader.dataset), loss.item(), acc))

    epoch_loss = float(np.mean(batch_loss))
    epoch_acc = float(np.mean(batch_acc))

    print('Epoch: {} Loss: {:.6f} Acc: {:.4f}'.format(epoch, epoch_loss, epoch_acc))

In [None]:
# validation (evaluation): pixel tallies filled in by the evaluation loop
# (total pixels seen, pixels reconstructed within tolerance)
total_test, correct_test = 0, 0

In [None]:
# Evaluate reconstruction accuracy on the test set.
# Counters are reset here so re-running this cell does not double-count.
autoencoder.eval()
total_test = 0
correct_test = 0
with torch.no_grad():  # inference only: skip building the autograd graph
    for x_test, _ in test_loader:

        # send data to device
        inputs = x_test.to(device)

        # forward pass to the model
        outputs = autoencoder(inputs)

        # a pixel counts as correct when within 0.1 of the input (same metric as training)
        total_test += inputs.numel()
        correct_test += (torch.abs(inputs - outputs) < 0.1).sum().item()

In [None]:
# percentage of test pixels reconstructed within tolerance
test_acc = 100.0 * correct_test / total_test
print('Test accuracy: {:.4f}'.format(test_acc))

In [None]:
# save model weights: only the state_dict (parameter tensors), not the module object
model_name = 'pytorch_mnist_autoencoder_model.pth'
torch.save(obj=autoencoder.state_dict(), f=model_name)
print('Saved trained model at %s ' % model_name)

In [None]:
# Use the model on a few samples: reconstruct the first 10 images of the batch
# saved earlier as inp*.png, and write them next to it as out*.png.
# cv2.imwrite does not create directories and reports failure via its return
# value, so make sure the folder exists and check each write.
os.makedirs('mnist', exist_ok=True)
autoencoder.eval()
input_images = x_sample[0:10].to(device)
with torch.no_grad():  # inference only
    output_images = autoencoder(input_images).to('cpu')
for i in range(10):
    out_path = 'mnist/out' + str(i).zfill(5) + ".png"
    pixels = np.asarray(output_images[i].squeeze(0).numpy() * 255, np.uint8)
    if not cv2.imwrite(out_path, pixels):
        raise IOError('could not write ' + out_path)