In [None]:
import os 

from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision.transforms import Compose , ToTensor
import torch.nn as nn
from torch.optim import Adam
import torch.nn.functional as F
import torch
import matplotlib.pyplot as plt 
import numpy as np 

# Hyperparameters
batch_size = 16
learning_rate = 0.00099
n_epochs = 25
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Convert every image to a tensor before it reaches the data loaders.
transform = Compose([ToTensor()])

# MNIST train / test splits; the files are downloaded into '/data' when missing.
train_set = MNIST('/data', train=True, download=True, transform=transform)
test_set = MNIST('/data', train=False, download=True, transform=transform)

# Both loaders honour the `batch_size` hyperparameter instead of a separate
# hard-coded literal, so changing it above actually takes effect.
# Only the training data is shuffled; the test order stays fixed.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)



In [None]:
def visualize_sample_batch(images, labels, c=4, r=4):
    """
    Visualize a batch of images on an r x c grid, titled with their labels.

    At most ``c * r`` samples are drawn; any extra samples in the batch are
    ignored so the subplot index never runs past the grid.

    args:
        images(torch.Tensor): BCHW tensor of images; the channel axis is
            squeezed away when it has size 1.
        labels(torch.Tensor): tensor holding one label per image.
        c(int): number of columns of the grid (default 4).
        r(int): number of rows of the grid (default 4).
    """
    # detach() / cpu() make the tensors convertible to numpy even when they
    # track gradients or live on a GPU.
    images = images.detach().cpu().squeeze(1)
    labels = labels.detach().cpu()

    # Cap the count BEFORE asking for a subplot: subplot indices are only
    # valid in 1..r*c, so checking after the call is too late.
    n_shown = min(len(labels), len(images), c * r)
    for i in range(n_shown):
        ax = plt.subplot(r, c, i + 1)
        ax.set_title(str(labels[i].item()))
        ax.axis('off')
        ax.imshow(images[i].numpy(), cmap='gray')
    # Adjust the layout once, after every axis exists, instead of per sample.
    plt.tight_layout()

# Draw a single mini-batch from the training loader and display it.
image, label = next(iter(train_loader))
visualize_sample_batch(images=image, labels=label)

In [None]:

class SimpleAutoencoders(nn.Module):
    """
    Fully connected autoencoder with a 128-unit code.

    The encoder maps ``input_size`` features to 128 units through two linear
    layers; the decoder mirrors it back to ``input_size`` features. A ReLU
    follows every linear layer, including the final one.

    Args:
        input_size(int): number of features of each (flattened) input sample.
    """
    def __init__(self, input_size):
        super().__init__()
        width = 128
        # encoder: input -> hidden -> code
        self.encoder_hid = nn.Linear(in_features=input_size, out_features=width)
        self.encoder_out = nn.Linear(in_features=width, out_features=width)
        # decoder: code -> hidden -> reconstruction
        self.decoder_hid = nn.Linear(in_features=width, out_features=width)
        self.decoder_out = nn.Linear(in_features=width, out_features=input_size)

    def forward(self, x):
        """Encode and decode ``x``, returning the reconstruction."""
        stages = (
            self.encoder_hid,
            self.encoder_out,
            self.decoder_hid,
            self.decoder_out,
        )
        signal = x
        for stage in stages:
            # every stage is a linear layer followed by a ReLU
            signal = F.relu(stage(signal))
        return signal

               

In [None]:
# Build the autoencoder for 784-feature inputs and move it to the chosen device.
model = SimpleAutoencoders(input_size=784)
model = model.to(device)
# Mean squared error serves as the reconstruction loss.
criterion = nn.MSELoss()
# Adam updates every parameter of the model with the configured learning rate.
optimizer = Adam(params=model.parameters(), lr=learning_rate)

In [None]:

# Mean reconstruction loss of every epoch, stored as plain Python floats.
loss_over_time = []

for epoch in range(n_epochs):
    model.train()
    running_loss = 0.0
    # The labels are not needed to train an autoencoder, so they are discarded.
    for feature, _ in train_loader:
        # flatten every input frame (image) into one vector per sample
        feature = feature.view(feature.size(0), -1).to(device)
        # clear the gradients left over from the previous step
        optimizer.zero_grad()
        # forward pass through the model
        output = model(feature)
        # reconstruction loss: the target is the input itself
        loss = criterion(output, feature)
        # backprop
        loss.backward()
        # optimize the parameters
        optimizer.step()

        running_loss += loss.item()
    # average the batch losses over the whole epoch
    running_loss = running_loss / len(train_loader)
    # Store the epoch mean (a float), not `loss`: `loss` is only the last
    # batch's tensor and would keep its autograd graph alive.
    loss_over_time.append(running_loss)
    print("Epochs: {}/{}  , loss = {:.6f}".format(epoch + 1, n_epochs, running_loss))

In [None]:
# Take the first batch of the (unshuffled) test loader and reconstruct it.
image, label = next(iter(test_loader))
# flatten every image into one vector per sample
feature = image.view(image.size(0), -1).to(device)
model.eval()
# No gradients are needed for inference.
with torch.no_grad():
    output = model(feature)
    output = output.to('cpu')
    print(output.size())
# Restore the image layout from the input batch's own shape rather than a
# hard-coded batch size of 16, which breaks for any other batch size.
output = output.reshape(image.shape)

print(output.shape)

visualize_sample_batch(output, label)

In [None]:
# Show the original test images for comparison with their reconstructions.
visualize_sample_batch(images=image, labels=label)

In [None]:
saved_path = "./saved"

# `os.exists` / `os.mkdirs` do not exist and raise AttributeError.
# os.makedirs with exist_ok=True creates the directory when it is missing
# and does nothing when it is already there.
os.makedirs(saved_path, exist_ok=True)

# Persist only the learned parameters (state_dict), not the whole module.
torch.save(model.state_dict(), os.path.join(saved_path, "simple_autoencoder.pth"))