In [9]:
import torch
import os
import torchvision
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torch import optim
from torchvision.datasets import MNIST
from torchvision.utils import save_image
import warnings
warnings.filterwarnings("ignore")

# Directory where reconstructed-image snapshots are written during training.
if not os.path.exists('./dc_img'):
    os.mkdir('./dc_img')

# Training hyperparameters.
number_epochs = 10
batch_size = 128
learning_rate = 1e-4

# ToTensor yields values in [0, 1]; Normalize with mean=0.5, std=0.5 rescales
# them to [-1, 1]. MNIST images have a single channel, so mean/std must be
# 1-tuples: 3-tuples cannot be broadcast onto a (1, H, W) tensor and raise an
# error on current torchvision releases.
transform_image = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# download=True fetches the dataset into ./data on first use and is a no-op
# when the files are already present, so the notebook runs on a fresh machine.
dataset = MNIST('./data', transform=transform_image, download=True)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [6]:
def to_image(x):
    """Convert a batch of normalized tensors into displayable images.

    Values are shifted and scaled by ``0.5 * (x + 1)`` (mapping [-1, 1] onto
    [0, 1]), clamped to [0, 1], and the result is viewed as
    ``(batch, 1, 28, 28)``.

    Args:
        x: tensor whose first dimension is the batch size and which holds
            exactly 28 * 28 values per sample.

    Returns:
        A tensor of shape ``(x.size(0), 1, 28, 28)`` with values in [0, 1].
    """
    batch = x.size(0)
    rescaled = (x + 1) * 0.5
    return rescaled.clamp(0, 1).view(batch, 1, 28, 28)

class conv_autoencoder(nn.Module):
    """Convolutional autoencoder for single-channel 28x28 images.

    The encoder compresses a ``(b, 1, 28, 28)`` batch down to ``(b, 8, 2, 2)``;
    the decoder expands it back to ``(b, 1, 28, 28)`` and finishes with Tanh,
    so reconstructions lie in [-1, 1].
    """

    def __init__(self):
        super(conv_autoencoder, self).__init__()

        # Spatial sizes in the comments below assume a 28x28 input.
        encoder_layers = [
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3,
                      stride=3, padding=1),                  # -> (b, 16, 10, 10)
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),           # -> (b, 16, 5, 5)
            nn.Conv2d(in_channels=16, out_channels=8, kernel_size=3,
                      stride=2, padding=1),                  # -> (b, 8, 3, 3)
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=1),           # -> (b, 8, 2, 2)
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = [
            nn.ConvTranspose2d(in_channels=8, out_channels=16, kernel_size=3,
                               stride=2),                    # -> (b, 16, 5, 5)
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(in_channels=16, out_channels=8, kernel_size=5,
                               stride=3, padding=1),         # -> (b, 8, 15, 15)
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(in_channels=8, out_channels=1, kernel_size=2,
                               stride=2, padding=1),         # -> (b, 1, 28, 28)
            nn.Tanh(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` and return its reconstruction from the decoder."""
        latent = self.encoder(x)
        return self.decoder(latent)


# Build the network, the reconstruction loss (mean squared error between
# output and input), and an Adam optimizer over all model parameters.
model = conv_autoencoder()
criterion = nn.MSELoss()
optimizer = optim.Adam(
    params=model.parameters(),
    lr=learning_rate,
    weight_decay=1e-5,
)

In [8]:
# Train the autoencoder: each batch is both the input and the reconstruction
# target, so the dataset labels are ignored.
for epoch in range(number_epochs):
    for img, _ in data_loader:
        # Forward pass
        output = model(img)
        loss = criterion(output, img)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # Print results. Note: this is the loss of the LAST batch of the epoch,
    # not an epoch average. loss is a 0-dim tensor, so .item() is used to get
    # the Python float (indexing it with [0] raises on current PyTorch).
    print('epoch [{}/{}], loss:{:.4f}'
          .format(epoch+1, number_epochs, loss.item()))
    # Save a snapshot of the last batch's reconstructions every 10th epoch
    # (with number_epochs = 10 this fires only for epoch 0).
    if epoch % 10 == 0:
        pic = to_image(output.detach().cpu())
        save_image(pic, './dc_img/image_{}.png'.format(epoch))

# Persist the trained weights (state_dict only, not the whole module).
torch.save(model.state_dict(), './convolutional_autoencoder.pth')

epoch [1/10], loss:0.3679
epoch [2/10], loss:0.3323
epoch [3/10], loss:0.3080
epoch [4/10], loss:0.3054
epoch [5/10], loss:0.3046
epoch [6/10], loss:0.2923
epoch [7/10], loss:0.2887
epoch [8/10], loss:0.2803
epoch [9/10], loss:0.2607
epoch [10/10], loss:0.2772
