<a href="https://colab.research.google.com/github/ahsanGoheer/Autoencoder-Collection/blob/main/MNIST_Autoencoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Install Necessary Modules.

!pip install torch
!pip install torchvision

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:
# Import Modules

import torch
from torchvision import datasets
from torchvision import transforms
import matplotlib.pyplot as plt

In [3]:
# Load the Mnist Dataset
 

# Prepare a tensor transformation for the images.
image_tensor_trfm = transforms.ToTensor()

# Download the MNIST Dataset.
mnist_dataset = datasets.MNIST(root = "./data",
                               train = True,
                               transform = image_tensor_trfm,
                               download = True)
# Create a Data Loader.
data_loader = torch.utils.data.DataLoader(dataset = mnist_dataset,
                                          batch_size = 32,
                                          shuffle = True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 282225119.43it/s]

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw






Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 120893906.01it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 152921866.15it/s]

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz



100%|██████████| 4542/4542 [00:00<00:00, 23204054.53it/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



In [15]:
class AutoEncoder(torch.nn.Module):
  def __init__(self):
    super().__init__()
    self.encoder = torch.nn.Sequential(torch.nn.Linear(28 * 28, 128),
                                       torch.nn.ReLU(),
                                       torch.nn.Linear(128, 64),
                                       torch.nn.ReLU(),
                                       torch.nn.Linear(64, 32),
                                       torch.nn.ReLU(),
                                       torch.nn.Linear(32, 16),
                                       torch.nn.ReLU(),
                                       torch.nn.Linear(16, 8)
                                       )
    self.decoder = torch.nn.Sequential(torch.nn.Linear(8, 16),
                                       torch.nn.ReLU(),
                                       torch.nn.Linear(16, 32),
                                       torch.nn.ReLU(),
                                       torch.nn.Linear(32, 64),
                                       torch.nn.ReLU(),
                                       torch.nn.Linear(64, 128),
                                       torch.nn.ReLU(),
                                       torch.nn.Linear(128, 28 * 28)
                                       )
    
  def forward(self, x):
    encoded = self.encoder(x)
    decoded = self.decoder(encoded)
    return decoded




In [16]:
auto_encoder_model = AutoEncoder()
auto_encoder_model.cuda()

AutoEncoder(
  (encoder): Sequential(
    (0): Linear(in_features=784, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=64, bias=True)
    (3): ReLU()
    (4): Linear(in_features=64, out_features=32, bias=True)
    (5): ReLU()
    (6): Linear(in_features=32, out_features=16, bias=True)
    (7): ReLU()
    (8): Linear(in_features=16, out_features=8, bias=True)
  )
  (decoder): Sequential(
    (0): Linear(in_features=8, out_features=16, bias=True)
    (1): ReLU()
    (2): Linear(in_features=16, out_features=32, bias=True)
    (3): ReLU()
    (4): Linear(in_features=32, out_features=64, bias=True)
    (5): ReLU()
    (6): Linear(in_features=64, out_features=128, bias=True)
    (7): ReLU()
    (8): Linear(in_features=128, out_features=784, bias=True)
  )
)

In [17]:
loss_func = torch.nn.MSELoss()
optimizer = torch.optim.Adam(auto_encoder_model.parameters(),
                             lr = 1e-1,
                             weight_decay = 1e-8
                             )

In [18]:
import numpy as np

In [None]:
epochs = 20
outputs = []
losses = []
for epoch in range(epochs):
    print(f"Epoch:{epoch}")
    counter = 1
    for (image, _) in data_loader:
      
      # Reshaping the image to (-1, 784)
      image = image.reshape(-1, 28*28).cuda()
       
      # Output of Autoencoder
      reconstructed = auto_encoder_model(image)
       
      # Calculating the loss function
      loss = loss_func(reconstructed, image)
       
      # The gradients are set to zero,
      # the gradient is computed and stored.
      # .step() performs parameter update
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
      print(f"Image:{counter} Loss:{loss}")
      counter+=1
      # Storing the losses in a list for plotting
      losses.append(loss)
    outputs.append((epochs, image, reconstructed))
 
# Defining the Plot Style
plt.style.use('fivethirtyeight')
plt.xlabel('Iterations')
plt.ylabel('Loss')
 
# Plotting the last 100 values
plt.plot(losses[-100:])

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Image:877 Loss:0.06578904390335083
Image:878 Loss:0.06413836032152176
Image:879 Loss:0.06319347769021988
Image:880 Loss:0.07653617858886719
Image:881 Loss:0.06498543173074722
Image:882 Loss:0.06603313982486725
Image:883 Loss:0.07097313553094864
Image:884 Loss:0.07378967106342316
Image:885 Loss:0.06381241977214813
Image:886 Loss:0.0748901441693306
Image:887 Loss:0.07242116332054138
Image:888 Loss:0.07409358024597168
Image:889 Loss:0.0645185261964798
Image:890 Loss:0.06739546358585358
Image:891 Loss:0.06643690168857574
Image:892 Loss:0.06411612778902054
Image:893 Loss:0.07070748507976532
Image:894 Loss:0.07048924267292023
Image:895 Loss:0.06838658452033997
Image:896 Loss:0.06844557821750641
Image:897 Loss:0.06827075034379959
Image:898 Loss:0.06596952676773071
Image:899 Loss:0.06634218990802765
Image:900 Loss:0.06733051687479019
Image:901 Loss:0.0686182752251625
Image:902 Loss:0.06875759363174438
Image:903 Loss:0.06694062054

In [None]:
for i, item in enumerate(image):
   
  # Reshape the array for plotting
  item = item.reshape(-1, 28, 28)
  plt.imshow(item[0])
 
for i, item in enumerate(reconstructed):
  item = item.reshape(-1, 28, 28)
  plt.imshow(item[0])