In [None]:
import torch
import torchvision
import matplotlib.pyplot as plt

In [None]:
# Seed torch's global RNG so weight initialisation (model cell) and the
# DataLoader's shuffling are reproducible on Restart & Run All
# (up to any nondeterministic GPU kernels).
torch.manual_seed(0)

# ToTensor converts each PIL image to a float tensor scaled to [0, 1], the same
# range as the decoder's Sigmoid output.
trans = torchvision.transforms.ToTensor()

# MNIST training split; downloaded into ./basic_autoencoder/data on first run.
dataset = torchvision.datasets.MNIST(root="./basic_autoencoder/data",
                                     train=True,
                                     download=True,
                                     transform=trans)

loader = torch.utils.data.DataLoader(dataset=dataset,
                                     batch_size=32,
                                     shuffle=True)

# Train on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
class AutoEncoder(torch.nn.Module):
    """Fully-connected autoencoder for flattened images.

    The encoder maps ``input_dim`` features through ``hidden_dims`` down to a
    ``latent_dim`` code, with ReLU between Linear layers and no activation on
    the code itself. The decoder mirrors those widths back up to ``input_dim``
    and ends in a Sigmoid, so reconstructions lie in [0, 1].

    The defaults reproduce the original 784-128-64-32-16-8 architecture with
    the same layer order, so parameter names / state_dict keys are unchanged.

    Args:
        input_dim: Features per flattened input (28*28 for MNIST).
        hidden_dims: Hidden-layer widths, from the input side toward the code.
        latent_dim: Size of the bottleneck code.
    """

    def __init__(self, input_dim=28 * 28, hidden_dims=(128, 64, 32, 16), latent_dim=8):
        super().__init__()
        widths = [input_dim, *hidden_dims, latent_dim]
        self.encoder = self._mlp(widths)
        # The decoder walks the same widths in reverse; the Sigmoid keeps the
        # output in the same [0, 1] range as the ToTensor-scaled inputs.
        self.decoder = self._mlp(widths[::-1], torch.nn.Sigmoid())

    @staticmethod
    def _mlp(widths, final_activation=None):
        """Build Linear layers between consecutive ``widths`` with ReLU between them.

        No ReLU follows the last Linear; ``final_activation`` (if given) is
        appended after it instead.
        """
        layers = []
        for i, (w_in, w_out) in enumerate(zip(widths[:-1], widths[1:])):
            if i > 0:
                layers.append(torch.nn.ReLU())
            layers.append(torch.nn.Linear(w_in, w_out))
        if final_activation is not None:
            layers.append(final_activation)
        return torch.nn.Sequential(*layers)

    def forward(self, x):
        """Encode ``x`` to the latent code and decode it back.

        Args:
            x: Tensor whose last dimension is ``input_dim`` (e.g. a
               (batch, 784) batch of flattened images). It must already be on
               the model's device; this method does not move it.

        Returns:
            Reconstruction with the same leading dimensions as ``x`` and last
            dimension ``input_dim``.
        """
        return self.decoder(self.encoder(x))


In [None]:
# Reconstruction objective: mean squared error between output and input pixels.
loss_function = torch.nn.MSELoss()

# Build the autoencoder and place its parameters on the training device.
model = AutoEncoder().to(device)

# Adam, with a very small L2 penalty applied through weight_decay.
optimizer = torch.optim.Adam(params=model.parameters(),
                             lr=0.001,
                             weight_decay=1e-08)

In [None]:
epochs = 20
outputs = []  # per epoch: (epoch index, last input batch, its reconstruction), detached on CPU
losses = []   # per-iteration training loss as plain Python floats

model.train()
for epoch in range(epochs):
    epoch_loss_sum = 0.0
    n_batches = 0
    for (image, _) in loader:
        # Flatten each 28x28 image into a 784-vector for the fully-connected model.
        image = image.reshape(-1, 28*28).to(device)

        # Call the module rather than .forward() so registered hooks run.
        reconstructed = model(image)

        # The autoencoder is trained to reproduce its own input.
        loss = loss_function(reconstructed, image)

        # Clear stale gradients, backpropagate, then update the parameters.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Store a float, not the tensor. A stored tensor keeps its autograd node
        # and device memory alive, and matplotlib cannot convert tensors that
        # require grad (or live on CUDA) when plotting.
        loss_value = loss.item()
        losses.append(loss_value)
        epoch_loss_sum += loss_value
        n_batches += 1

    # Snapshot this epoch's last batch. Record the loop index `epoch` (not the
    # constant `epochs`), and detach/move to CPU so the snapshot is plottable.
    outputs.append((epoch, image.detach().cpu(), reconstructed.detach().cpu()))
    # Report the mean loss over the epoch's batches, which is less noisy than the last batch.
    print("Epoch:{}, Loss:{:.4f}".format(epoch + 1, epoch_loss_sum / max(n_batches, 1)))

# A style context keeps the theme from leaking into later figures.
with plt.style.context('fivethirtyeight'):
    fig, ax = plt.subplots()
    # Plot only the last 100 training iterations.
    ax.plot(losses[-100:])
    ax.set(title='Training loss (last 100 iterations)', xlabel='Iterations', ylabel='Loss')
    plt.show()

In [None]:
# Compare the last training batch with its reconstructions: originals on the
# top row, reconstructions on the bottom row. The original loop redrew every
# image onto the same axes, so only the final one was visible. Tensors are
# detached and moved to the CPU because matplotlib cannot convert CUDA tensors
# or tensors that require grad.
n_show = min(8, image.shape[0])
originals = image.detach().cpu().reshape(-1, 28, 28)[:n_show]
recons = reconstructed.detach().cpu().reshape(-1, 28, 28)[:n_show]

fig, axes = plt.subplots(2, n_show, figsize=(2 * n_show, 4), squeeze=False)
for col in range(n_show):
    for row, batch in enumerate((originals, recons)):
        ax = axes[row, col]
        ax.imshow(batch[col], cmap='gray')
        ax.axis('off')
fig.suptitle('Top: original, bottom: reconstruction')
plt.show()