# Denoising Autoencoder

In [None]:
import torch, torchvision
import torch.nn as nn
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

from torch.utils.data import DataLoader
from torch.autograd import Variable

### Hyperparameters

In [None]:
# Prefer the GPU when CUDA is available; otherwise train on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Data / model dimensions.
batch_size = 100
image_size = 28 * 28  # one flattened 28x28 MNIST image
hidden_size = 28      # width of the autoencoder's bottleneck layer

# Optimisation settings.
learning_rate = 2e-4
num_epochs = 10

### MNIST Data

In [None]:
# Build the train split first, then the test split. Each split gets its own
# ToTensor transform and is downloaded into the shared data directory if missing.
MNIST_train, MNIST_test = (
    torchvision.datasets.MNIST(
        root='./../data/MNIST/',
        train=is_train,
        transform=transforms.ToTensor(),
        download=True,
    )
    for is_train in (True, False)
)

# Both loaders yield shuffled mini-batches of `batch_size` samples.
train_loader = DataLoader(MNIST_train, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(MNIST_test, batch_size=batch_size, shuffle=True)

### Model

In [None]:
class AutoEncoder(nn.Module):
    """Linear autoencoder with a single hidden (bottleneck) layer.

    The encoder maps a flattened image to ``hidden_features`` values. The
    decoder maps them back to ``in_features`` values. No activation is
    applied between or after the two linear layers.

    Args:
        in_features: Length of one flattened input image. When omitted, the
            notebook-level ``image_size`` hyperparameter is used.
        hidden_features: Width of the bottleneck layer. When omitted, the
            notebook-level ``hidden_size`` hyperparameter is used.
    """

    def __init__(self, in_features=None, hidden_features=None):
        super(AutoEncoder, self).__init__()
        # Fall back to the hyperparameter globals at construction time, exactly
        # as before, so a bare AutoEncoder() keeps working.
        if in_features is None:
            in_features = image_size
        if hidden_features is None:
            hidden_features = hidden_size

        # Attribute names are unchanged so existing state_dict keys still match.
        self.Encoder = nn.Linear(in_features, hidden_features)
        self.Decoder = nn.Linear(hidden_features, in_features)

    def forward(self, input):
        """Reconstruct a batch of images.

        Args:
            input: Tensor whose first dimension is the batch. The remaining
                dimensions must flatten to ``in_features`` values.

        Returns:
            Tensor of shape ``(N, 1, 28, 28)``, where ``N`` is the batch size of
            ``input``. The final reshape assumes ``in_features == 28 * 28``.
        """
        # Take the batch size from the input rather than the global
        # `batch_size`, so partial batches, single images and empty batches
        # work. flatten() also accepts non-contiguous tensors, unlike view().
        n = input.size(0)
        encoded = self.Encoder(input.flatten(1))
        return self.Decoder(encoded).reshape(n, 1, 28, 28)

### Loss Function & Optimizer

In [None]:
# Instantiate the autoencoder and move its parameters to the chosen device.
model = AutoEncoder()
model = model.to(device)

# Mean-squared-error reconstruction loss, optimised with Adam.
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

### Train

In [None]:
# Loss samples recorded every 100 steps.
all_losses = []

total_step = len(train_loader)
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(train_loader):
        # `images` stays on the CPU because the Result cell plots it; only the
        # copy used for training is moved to the device.
        clean = images.to(device)

        # Corrupt the batch with zero-mean Gaussian noise (std 0.1). The noise
        # is shaped like the actual batch rather than the nominal `batch_size`
        # and is generated directly on the device. This replaces the
        # deprecated nn.init.normal and the no-op Variable wrapper.
        noisy = clean + 0.1 * torch.randn_like(clean)

        # Denoising objective: reconstruct the clean image from the noisy one.
        # Calling model(...) rather than model.forward(...) keeps module hooks.
        output = model(noisy)
        loss = criterion(output, clean)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], loss [{:.4f}]'.format(epoch+1, num_epochs, i+1, total_step, loss.item()))
            all_losses.append(loss.item())

### Result

In [None]:
# Compare each clean image of the last training batch with the model's
# reconstruction of its noisy version. The model output is (N, 1, 28, 28).
# squeeze(1) drops only the channel axis, so a batch of one still yields one
# image instead of collapsing to (28, 28). detach() replaces the legacy .data.
out_img = output.detach().cpu().squeeze(1)

for i in range(out_img.size(0)):
    fig = plt.figure()
    origin = fig.add_subplot(1, 2, 1)
    generated = fig.add_subplot(1, 2, 2)

    origin.imshow(torch.squeeze(images[i]).numpy(), cmap='gray')
    origin.set_title('original')
    generated.imshow(out_img[i].numpy(), cmap='gray')
    generated.set_title('reconstruction')

    # plt.show() renders the figure now. Under the notebook inline backend it
    # also typically closes it, so figures do not accumulate across the loop
    # as they do with fig.show().
    plt.show()