In [1]:
import torch.nn as nn
import torch.nn.functional as F
import os
from torch.autograd import Variable
from torchvision.utils import save_image
import torch
import torchvision
import torchvision.transforms as transforms

In [10]:
class DenoiseAutoEncoder(nn.Module):
    def __init__(self):
        super(DenoiseAutoEncoder,self).__init__()
        self.drop = nn.Dropout(0.4)
        self.encoder = nn.Sequential(
            nn.Linear(28*28,1000),
            nn.ReLU(),
            nn.BatchNorm1d(1000),
            nn.Linear(1000,500),
            nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.Linear(500,1000),
            nn.ReLU(),
            nn.BatchNorm1d(1000),
            nn.Linear(1000,28*28),
            nn.ReLU()
        )
    def forward(self,x):
        x = x.view(-1,784)
        #print(x.shape)
        noisex = self.drop(x)
        en =  self.encoder(noisex)
        de = self.decoder(en)
        de = de.view(-1,1,28,28)
        return x.view(-1,1,28,28),noisex.view(-1,1,28,28),de

dAE = DenoiseAutoEncoder()
dAE = dAE.cuda()

In [11]:
loss_criterion=nn.MSELoss()
optimizer = torch.optim.SGD( dAE.parameters(), lr=0.01)
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,),(0.5,))])
realimages = torchvision.datasets.MNIST(root='./data', train=True,download=True, transform=transform)
train_loader= torch.utils.data.DataLoader(realimages, batch_size=50,shuffle=True, num_workers=2)
it = iter(train_loader)
fdata,_ = it.next()

In [12]:
for epoch in range(200):
    for data,_ in train_loader:    
        optimizer.zero_grad()
        data=Variable(data.cuda())
        actual,corrupted,output=dAE(data)
        loss=loss_criterion(output,data)
        loss.backward()     
        optimizer.step()
    print('epoch [{}/{}], loss:{:.4f}'.format(epoch + 1, 30, loss.data[0]))
    actual,corrupted,output = dAE(fdata.cuda())
    if epoch % 10 == 0:
        save_image(actual.data, './GANs/samples/DAE/actual2.png'.format(epoch),nrow=10)
        save_image(corrupted.data, './GANs/samples/DAE/corrupted2.png'.format(epoch),nrow=10)
        save_image(output.data, './GANs/samples/DAE/output2.png'.format(epoch),nrow=10)

epoch [1/30], loss:1.1058


  if __name__ == '__main__':


epoch [2/30], loss:0.9917
epoch [3/30], loss:0.9464
epoch [4/30], loss:0.9156
epoch [5/30], loss:0.9153
epoch [6/30], loss:0.9127
epoch [7/30], loss:0.9166
epoch [8/30], loss:0.9200
epoch [9/30], loss:0.9048
epoch [10/30], loss:0.9028
epoch [11/30], loss:0.8990
epoch [12/30], loss:0.8999
epoch [13/30], loss:0.8961
epoch [14/30], loss:0.8935
epoch [15/30], loss:0.9061
epoch [16/30], loss:0.8938
epoch [17/30], loss:0.8897
epoch [18/30], loss:0.8987
epoch [19/30], loss:0.8998
epoch [20/30], loss:0.9007
epoch [21/30], loss:0.9014
epoch [22/30], loss:0.9033
epoch [23/30], loss:0.8925
epoch [24/30], loss:0.8968
epoch [25/30], loss:0.8948
epoch [26/30], loss:0.8938
epoch [27/30], loss:0.8856
epoch [28/30], loss:0.8851
epoch [29/30], loss:0.8881
epoch [30/30], loss:0.8843
epoch [31/30], loss:0.8837
epoch [32/30], loss:0.8899
epoch [33/30], loss:0.8882
epoch [34/30], loss:0.8740
epoch [35/30], loss:0.8819
epoch [36/30], loss:0.8860
epoch [37/30], loss:0.8909
epoch [38/30], loss:0.8836
epoch [39