<a href="https://colab.research.google.com/github/JordanFoss/STAT3007_Project/blob/main/Denoising_Autoencoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torchvision.datasets import MNIST

In [2]:
class Encoder(nn.Module):
  """Single-layer convolutional encoder: one Conv2d followed by a ReLU.

  The convolution maps 1 input channel to `filters` output channels with a
  5x5 kernel and stride 1. The padding is computed from those values (it
  evaluates to 2), so the output keeps the input's height and width.
  """

  def __init__(self, filters):
    """Create the 1 -> `filters` channel convolution layer."""
    super().__init__()
    width, kernel, stride = 28, 5, 1
    # Same formula as W*S - W - S + K, halved: padding that preserves size.
    padding = ((width - 1) * stride - width + kernel) // 2
    self.conv = nn.Conv2d(
        in_channels=1,
        out_channels=filters,
        kernel_size=kernel,
        stride=stride,
        padding=padding,
    )

  def forward(self, x):
    """Return the ReLU-activated convolution of `x`."""
    return nn.functional.relu(self.conv(x))


class Decoder(nn.Module):
  """Single-layer convolutional decoder: one 1x1 Conv2d followed by a sigmoid.

  The convolution collapses `filters` input channels into 1 output channel.
  With a 1x1 kernel and stride 1 the computed padding is 0, so the output
  keeps the input's height and width.
  """

  def __init__(self, filters):
    """Create the `filters` -> 1 channel convolution layer."""
    super().__init__()
    width, kernel, stride = 28, 1, 1
    # Same formula as W*S - W - S + K, halved: padding that preserves size.
    padding = ((width - 1) * stride - width + kernel) // 2
    self.conv = nn.Conv2d(
        in_channels=filters,
        out_channels=1,
        kernel_size=kernel,
        stride=stride,
        padding=padding,
    )

  def forward(self, x):
    """Return the sigmoid-activated convolution of `x`."""
    return torch.sigmoid(self.conv(x))


In [4]:
class Autoencoder(nn.Module):
  """Autoencoder that feeds its input through an Encoder and then a Decoder.

  `filters` is the number of channels produced by the encoder and consumed
  by the decoder; it is also stored on the instance as `self.filters`.
  """

  def __init__(self, filters):
    """Build the encoder/decoder pair sharing the same `filters` count."""
    super().__init__()
    self.filters = filters
    # Encoder is created before the decoder, as in the original layout.
    self.encoder = Encoder(filters)
    self.decoder = Decoder(filters)

  def forward(self, x):
    """Return decoder(encoder(x))."""
    encoded = self.encoder(x)
    return self.decoder(encoded)


In [5]:
def noisify(img, noise_std=0.4):
  """Return a copy of `img` corrupted with additive Gaussian noise.

  The noise is created directly on `img`'s own device, so the function does
  not depend on any module-level `device` variable and works for tensors on
  any device.

  Args:
    img: image tensor to corrupt.
    noise_std: standard deviation of the zero-mean Gaussian noise
      (defaults to 0.4, the value that was previously hard-coded).

  Returns:
    A tensor with the same shape as `img`, clamped to the range [0, 1].
  """
  noise = noise_std * torch.randn(img.shape, device=img.device)
  # Clamp so the noisy values stay within [0, 1].
  return torch.clamp(img + noise, 0, 1)

In [6]:
def train_AE(ae, dloader, nepochs, lr):
  """Train `ae` as a denoising autoencoder.

  For every batch the clean images are corrupted with `noisify`, passed
  through the model, and the mean-squared error between the reconstruction
  and the clean images is minimised with Adam.

  Args:
    ae: the autoencoder module to train (updated in place).
    dloader: iterable yielding `(images, labels)` batches; labels are ignored.
    nepochs: number of passes over `dloader`.
    lr: learning rate for the Adam optimizer.

  Returns:
    None. Progress is printed to stdout, one overwritten line per epoch.
  """
  criterion = nn.MSELoss()
  optimizer = torch.optim.Adam(ae.parameters(), lr=lr)
  # Move batches to wherever the model's parameters live, instead of relying
  # on a module-level `device` variable matching the model's placement.
  model_device = next(ae.parameters()).device
  for epoch in range(nepochs):
    for i, (img, _) in enumerate(dloader):
      img = img.to(model_device)
      optimizer.zero_grad()
      # Reconstruct from the noisy input but score against the clean image.
      out = ae(noisify(img))
      loss = criterion(out, img)
      loss.backward()
      optimizer.step()
      # '\r' keeps the running loss on a single line within an epoch.
      print('[%2d,%3d] loss: %.6f' % (epoch+1,i+1,loss.item()),end='\r')
    # Move to a fresh line so the next epoch does not overwrite this one.
    print('\n', end ='\r')


In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Load the MNIST training and testing data from './'. download=True fetches
# the files only when they are not already on disk, so a fresh runtime works
# and an existing copy is reused.
mnist_tn = MNIST('./', transform=ToTensor(), download=True, train=True)
mnist_ts = MNIST('./', transform=ToTensor(), download=True, train=False)
# Shuffled mini-batches of 250 for training; the test loader uses one batch
# of up to 10000 samples.
dataloader_tn = DataLoader(mnist_tn, batch_size=250, shuffle=True,num_workers = 1)
dataloader_ts = DataLoader(mnist_ts, batch_size=10000)
# Train an autoencoder with 4 filters for 5 epochs at learning rate 0.01.
ae = Autoencoder(4).to(device)
train_AE(ae, dataloader_tn, 5, 0.01)
# NOTE: this pickles the whole module object rather than just its
# state_dict, so loading 'netAE' later needs the same class definitions
# to be available.
torch.save(ae,'netAE')
