In [None]:
import torch.utils.data
from torchvision import datasets, transforms
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from torch import nn, optim

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'


In [None]:
# One root directory for both splits so MNIST is downloaded/cached only once.
DATA_ROOT = './data'

# ToTensor() scales pixels to [0, 1], which already matches the decoder's
# sigmoid output range, so no further normalisation is applied
# (Normalize(mean=0, std=1) would be an identity transform anyway).
transform = transforms.Compose([transforms.ToTensor()])

# Load the MNIST training split
mnist_dataset_train = datasets.MNIST(
    root=DATA_ROOT, train=True, download=True, transform=transform)
# Load the MNIST test split
mnist_dataset_test = datasets.MNIST(
    root=DATA_ROOT, train=False, download=True, transform=transform)

batch_size = 128
# The test loader is only used to draw a handful of examples for plotting.
test_batch_size = 5
train_loader = torch.utils.data.DataLoader(
    mnist_dataset_train, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    mnist_dataset_test, batch_size=test_batch_size, shuffle=False)


In [3]:
class DAE(nn.Module):
    """Fully-connected denoising autoencoder for flattened 28x28 images.

    Encoder: 784 -> 512 -> 256 -> 128, with a ReLU after every layer
    (including the 128-d code). Decoder: 128 -> 256 -> 512 -> 784, with ReLU
    on the hidden layers and a sigmoid on the output, so reconstructions lie
    in [0, 1].
    """

    def __init__(self):
        super().__init__()
        # Registered in this exact order (fc1..fc6, then the activations) so
        # parameter order, initialisation order and state_dict keys match.
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 128)
        self.fc4 = nn.Linear(128, 256)
        self.fc5 = nn.Linear(256, 512)
        self.fc6 = nn.Linear(512, 784)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def encode(self, x):
        """Map inputs whose last dim is 784 to a non-negative 128-d code."""
        for layer in (self.fc1, self.fc2, self.fc3):
            x = self.relu(layer(x))
        return x

    def decode(self, z):
        """Map a 128-d code back to a 784-d reconstruction in [0, 1]."""
        for layer in (self.fc4, self.fc5):
            z = self.relu(layer(z))
        return self.sigmoid(self.fc6(z))

    def forward(self, x):
        """Flatten ``x`` to (-1, 784), encode it, then decode the code."""
        flat = x.view(-1, 784)
        return self.decode(self.encode(flat))


In [None]:
def train(epoch, model, train_loader, optimizer, cuda=True, criterion=None,
          noise_std=1.0):
    """Train a denoising autoencoder for one epoch.

    Each batch is corrupted with additive Gaussian noise, the model
    reconstructs it, and the loss is computed against the *clean* images.

    Args:
        epoch: Epoch number, used only for logging.
        model: Autoencoder whose output is compared with the clean batch
            flattened to (batch, -1).
        train_loader: DataLoader yielding ``(images, labels)``; labels are
            ignored.
        optimizer: Optimizer over ``model``'s parameters.
        cuda: Unused; kept for backward compatibility. The device is taken
            from the model's own parameters.
        criterion: Reconstruction loss. Defaults to ``nn.MSELoss()``.
        noise_std: Standard deviation of the additive Gaussian noise.

    Returns:
        Average per-sample loss over the epoch (float), assuming a
        mean-reduced ``criterion``.
    """
    # Follow the model rather than a notebook-global device variable.
    device = next(model.parameters()).device
    if criterion is None:
        criterion = nn.MSELoss()

    model.train()
    total_loss = 0.0
    n_batches = len(train_loader)
    n_samples = len(train_loader.dataset)
    for batch_idx, (data, _) in enumerate(train_loader):
        # Tensor.to() is not in-place: the result must be reassigned.
        data = data.to(device)
        data_noise = data + noise_std * torch.randn_like(data)

        optimizer.zero_grad()
        recon_batch = model(data_noise)
        # Target is the clean image, flattened to match the model output.
        loss = criterion(recon_batch, data.view(data.size(0), -1))
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        # Weight each (mean-reduced) batch loss by its size so the epoch
        # average is correct even when the last batch is smaller.
        total_loss += batch_loss * data.size(0)

        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), n_samples,
                100. * batch_idx / n_batches, batch_loss))

    avg_loss = total_loss / n_samples if n_samples else float('nan')
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, avg_loss))
    return avg_loss


In [None]:
# Seed torch's global RNG before the model is built so weight initialisation,
# noise draws and the training loader's shuffle order are repeatable.
SEED = 0
torch.manual_seed(SEED)

epochs = 10
# NOTE(review): 1e-2 is on the high side for Adam (library default 1e-3);
# kept as-is, tune here if training stalls.
learning_rate = 1e-2

model = DAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.MSELoss()


In [None]:
# Run the training epochs; epoch numbers are 1-based for logging.
for epoch in range(1, epochs + 1):
    train(epoch, model, train_loader, optimizer, cuda=True)


In [None]:
# Visualise denoising on one test batch: noisy input (top row),
# reconstruction (middle row), clean original (bottom row).
model.eval()  # DAE has no dropout/batch-norm, but be explicit for inference.
data, labels = next(iter(test_loader))
# Same corruption as in training: additive standard-normal noise.
data_noise = data + torch.randn_like(data)
with torch.no_grad():  # inference only: no autograd graph needed
    # Bring the output back to CPU so it can be converted to NumPy.
    recon_batch = model(data_noise.to(device)).cpu()

n_show = min(5, data.size(0))
fig, axes = plt.subplots(3, n_show, figsize=(20, 12), squeeze=False)
rows = (
    ('Noisy', data_noise),
    ('Reconstruction', recon_batch),
    ('Original', data),
)
for row, (title, images) in enumerate(rows):
    for i in range(n_show):
        ax = axes[row, i]
        ax.imshow(images[i].reshape(28, 28).numpy(), cmap='binary')
        ax.set_title(f'{title} #{i} (label {labels[i].item()})')
        ax.axis('off')
plt.tight_layout()
plt.show()
