In [None]:
import os
import torch
import torchvision
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import save_image
from torch.optim import Optimizer
from tqdm import tqdm 

class AdamOptimizer(Optimizer):
    """Adam optimizer (Kingma & Ba) with optional L2 weight decay.

    A from-scratch re-implementation of Adam that uses the paper's "efficient"
    bias-corrected step size. Weight decay is applied as a classic L2 penalty
    added to the gradient (not the decoupled AdamW variant).

    Args:
        params: iterable of parameters (or param-group dicts) to optimize.
        lr: base learning rate.
        betas: decay rates ``(b1, b2)`` of the first- and second-moment
            exponential moving averages.
        eps: term added to the denominator for numerical stability.
        weight_decay: L2 penalty coefficient; 0 disables it.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.99), eps=1e-8, weight_decay=0):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super(AdamOptimizer, self).__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss (standard ``torch.optim`` convention).

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            # The closure runs forward/backward, so autograd must be enabled for it.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            b1, b2 = group['betas']
            for p in group['params']:
                # Parameters that received no gradient (unused or frozen) are skipped.
                if p.grad is None:
                    continue
                grad = p.grad
                state = self.state[p]

                # Lazy state initialization on the first step for this parameter.
                if len(state) == 0:
                    state['step'] = 0
                    # First moment: exponential moving average of gradients (momentum).
                    state['exp_avg'] = torch.zeros_like(p)
                    # Second moment: exponential moving average of squared gradients
                    # (the RMSProp component, used in the denominator).
                    state['exp_avg_sq'] = torch.zeros_like(p)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                state['step'] += 1
                t = state['step']

                # L2 penalty folded into the gradient; out of place so p.grad is untouched.
                if group['weight_decay'] != 0:
                    grad = grad.add(p, alpha=group['weight_decay'])

                # Update the running moments IN PLACE so they persist in self.state
                # across steps (rebinding a local would silently drop the history).
                exp_avg.mul_(b1).add_(grad, alpha=1 - b1)
                exp_avg_sq.mul_(b2).addcmul_(grad, grad, value=1 - b2)

                denom = exp_avg_sq.sqrt().add_(group['eps'])
                # Bias-corrected step size: lr * sqrt(1 - b2^t) / (1 - b1^t).
                step_size = group['lr'] * (1 - b2 ** t) ** 0.5 / (1 - b1 ** t)

                p.addcdiv_(exp_avg, denom, value=-step_size)

        return loss

class AutoEncoder(nn.Module):
    """Fully-connected autoencoder for flattened 28x28 images.

    Encoder widths: 784 -> 128 -> 64 -> 12 -> 3 (in-place ReLU between layers,
    no activation on the 3-d code). The decoder mirrors the encoder back to 784
    and ends in Tanh, so outputs lie in [-1, 1].
    """

    def __init__(self):
        super().__init__()
        widths = [28 * 28, 128, 64, 12, 3]
        self.encoder = nn.Sequential(*self._linear_stack(widths))
        self.decoder = nn.Sequential(*self._linear_stack(widths[::-1]), nn.Tanh())

    @staticmethod
    def _linear_stack(widths):
        """Linear layers between consecutive widths, with an in-place ReLU between
        each pair of Linear layers (none after the last one)."""
        layers = []
        for idx, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
            if idx:
                layers.append(nn.ReLU(True))
            layers.append(nn.Linear(fan_in, fan_out))
        return layers

    def forward(self, x):
        code = self.encoder(x)
        return self.decoder(code)

In [None]:
# Directory for saved reconstruction images; exist_ok makes re-running this
# cell a no-op and avoids the check-then-create race.
os.makedirs('./output', exist_ok=True)

def to_img(x):
    """Undo the [-1, 1] normalization and reshape flat vectors into image tensors.

    Args:
        x: tensor whose first dimension is the batch; each row must hold 28*28 values.

    Returns:
        Tensor of shape (N, 1, 28, 28) with values clamped to [0, 1].
    """
    batch = x.size(0)
    rescaled = (x + 1) * 0.5  # inverse of Normalize([0.5], [0.5])
    return rescaled.clamp(0, 1).view(batch, 1, 28, 28)


# ===================== Hyper-parameters =====================
num_epochs = 50
batch_size = 128
learning_rate = 1e-3

# Seed torch's global RNG so weight initialization (done in a later cell) and
# the training shuffle order are reproducible across Restart & Run All.
seed = 0
torch.manual_seed(seed)

# Image transformer: ToTensor yields pixels in [0, 1]; Normalize(0.5, 0.5)
# maps them to [-1, 1], matching the decoder's Tanh output range.
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

# Download dataset
train_dataset = MNIST('./data', transform=img_transform, download=True, train=True)
val_dataset = MNIST('./data', transform=img_transform, download=True, train=False)

# Dataset length
num_train = len(train_dataset)
num_val = len(val_dataset)

print(f"Num. training samples:   {num_train}")
print(f"Num. validated samples:  {num_val}")

# Build dataloaders (only the training set is shuffled)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Use the GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Move the model to `device` (instead of hard-coding .cuda()) so the notebook
# also runs on CPU-only machines and matches the .to(device) calls on the data.
model = AutoEncoder().to(device)
# Mean squared error between reconstruction and input.
criterion = nn.MSELoss()

optimizer = AdamOptimizer(model.parameters(), lr=learning_rate, weight_decay=1e-5)

In [None]:
# Train for `num_epochs`, recording the per-sample mean reconstruction loss
# on the training and validation sets after every epoch.
history = {'train': [], 'val': []}

for epoch in range(num_epochs):
    # =================== Training ====================
    model.train()
    recon_loss = 0.0
    train_iterator = tqdm(train_loader, leave=True,
                          desc='(Train) Epoch [{}/{}]'.format(epoch + 1, num_epochs))
    for img, _ in train_iterator:
        # Flatten each image into a vector to match the MLP's 28*28 input layer.
        img = img.view(img.size(0), -1).to(device)
        # ===================forward=====================
        output = model(img)
        loss = criterion(output, img)
        # ===================backward====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        # Weight by batch size so a smaller final batch doesn't skew the epoch mean.
        recon_loss += batch_loss * img.size(0)
        train_iterator.set_postfix(train_recon_loss=batch_loss)

    recon_loss /= len(train_loader.dataset)
    history['train'].append(recon_loss)

    # Every 10 epochs (0-based indices 0, 10, ..., 40) save the last training
    # batch's reconstructions; the final cell displays ./output/image_40.png.
    if epoch % 10 == 0:
        save_image(to_img(output.detach().cpu()), './output/image_{}.png'.format(epoch))

    # =================== Validation ==================
    model.eval()
    val_recon_loss = 0.0
    val_iterator = tqdm(val_loader, leave=True,
                        desc='(Val) Epoch [{}/{}]'.format(epoch + 1, num_epochs))
    # Evaluation needs no gradients.
    with torch.no_grad():
        for val_img, _ in val_iterator:
            val_img = val_img.view(val_img.size(0), -1).to(device)
            val_output = model(val_img)
            val_loss = criterion(val_output, val_img).item()
            val_recon_loss += val_loss * val_img.size(0)
            val_iterator.set_postfix(val_recon_loss=val_loss)
    val_recon_loss /= len(val_loader.dataset)
    history['val'].append(val_recon_loss)

    print(f"Epoch [{epoch+1}/{num_epochs}]: Train reconstruct loss = {recon_loss:.5f} | Val reconstruct loss = {val_recon_loss:.5f}")

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
img = mpimg.imread('./output/image_40.png')
imgplot = plt.imshow(img)
plt.show()