# import libraries

In [52]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
import torch.optim as optim
from torch.utils.data import DataLoader


# Define Transformers

In [53]:
transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5),(0.5))
])

# Download Datasets

In [54]:
train_data = datasets.MNIST(root='./data_mnist', download=True, train = True, transform=transforms)

# Download Test Datesets

In [55]:
test_data = datasets.MNIST(root='./data_mnist', download=True, transform=transforms)

# Step - 4 Data Loaders

In [56]:
train_loader = DataLoader(train_data, batch_size = 128, shuffle = True)
test_loader = DataLoader(test_data, batch_size = 128 , shuffle = False)

# Build the Architecture

In [57]:
class denoise_AE(nn.Module):
    def __init__(self):
        super().__init__()
        self.enc = nn.Sequential(
            nn.Conv2d(1,16,3,stride=2,padding=1),
            nn.ReLU(True),
            nn.Conv2d(16,16,3,stride=2,padding=1),
            nn.ReLU(True),
            nn.Conv2d(16,8,3,stride=2,padding=1),
            nn.ReLU(True)
        )
        self.dec = nn.Sequential(
            nn.ConvTranspose2d(8,16,3,stride=2,padding=1,output_padding=0),   # 4 -> 7
            nn.ReLU(True),
            nn.ConvTranspose2d(16,16,3,stride=2,padding=1,output_padding=1),  # 7 -> 14
            nn.ReLU(True),
            nn.ConvTranspose2d(16,1,3,stride=2,padding=1,output_padding=1),   # 14 -> 28
            nn.Tanh()
        )

    def forward(self, x):
        x = self.enc(x)
        x = self.dec(x)
        return x

# Object Creation

In [58]:
model = denoise_AE()
criterion = nn.MSELoss()
optimizer = optim. Adam(model.parameters(),lr = 0.001)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
print(model)

denoise_AE(
  (enc): Sequential(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(16, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): Conv2d(16, 8, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (5): ReLU(inplace=True)
  )
  (dec): Sequential(
    (0): ConvTranspose2d(8, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): ConvTranspose2d(16, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): ConvTranspose2d(16, 1, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (5): Tanh()
  )
)


In [59]:
def add_noise(img):
    # use same device/shape as img
    noise = torch.randn_like(img) * 0.2
    noisy_img = img + noise
    noisy_img = torch.clamp(noisy_img, 0., 1.)
    return noisy_img

# Train

In [60]:
for epoch in range(20):
    for img,label in train_loader:
        img = img.to(device)
        noisy_img = add_noise(img)
        noisy_img = noisy_img.to(device)
        output = model(noisy_img)
        loss = criterion(output,img)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print('epoch [{}/{}], loss:{:.4f}'.format(epoch+1,20,loss.item()))

epoch [1/20], loss:0.0515
epoch [2/20], loss:0.0358
epoch [3/20], loss:0.0328
epoch [4/20], loss:0.0278
epoch [5/20], loss:0.0250
epoch [6/20], loss:0.0241
epoch [7/20], loss:0.0220
epoch [8/20], loss:0.0223
epoch [9/20], loss:0.0222
epoch [10/20], loss:0.0223
epoch [11/20], loss:0.0211
epoch [12/20], loss:0.0224
epoch [13/20], loss:0.0221
epoch [14/20], loss:0.0219
epoch [15/20], loss:0.0218
epoch [16/20], loss:0.0193
epoch [17/20], loss:0.0195
epoch [18/20], loss:0.0187
epoch [19/20], loss:0.0177
epoch [20/20], loss:0.0200


In [61]:
for data in test_loader:
    img,label = data
    img = img.to(device)
    noisy_img = add_noise(img)
    noisy_img = noisy_img.to(device)
    output = model(noisy_img)
    break

# Reconstruct

In [None]:
show_img = torch.cat([img,noisy_img,output],0)
show_img = show_img.cpu().detach()
print(show_img.shape)
import matplotlib.pyplot as plt
import numpy as np

torch.Size([384, 1, 28, 28])
