In [1]:
import os
import torch
from torch.autograd import Variable
from torch import nn
from torch.utils.data import DataLoader

from torchvision.datasets import MNIST
from torchvision import transforms as tfs
from torchvision.utils import save_image

In [3]:
im_tfs=tfs.Compose([
    tfs.ToTensor(),
    tfs.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])
])

train_set=MNIST('./mnist',transform=im_tfs,download=True)
train_data=DataLoader(train_set,batch_size=128,shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Processing...
Done!


In [4]:
class autoencoder(nn.Module):
    def __init__(self):
        super(autoencoder,self).__init__()
        
        self.encoder=nn.Sequential(
            nn.Linear(28*28,128),
            nn.ReLU(True),
            nn.Linear(128,64),
            nn.ReLU(True),
            nn.Linear(64,12),
            nn.ReLU(True),
            nn.Linear(12,3),
        )
        
        self.decoder=nn.Sequential(
            nn.Linear(3,12),
            nn.ReLU(True),
            nn.Linear(12,64),
            nn.ReLU(True),
            nn.Linear(64,128),
            nn.ReLU(True),
            nn.Linear(128,28*28),
            nn.Tanh()
        )
    def forward(self,x):
        encode=self.encoder(x)
        decode=self.decoder(encode)
        return encode,decode

In [5]:
net=autoencoder()
x=Variable(torch.randn(1,28*28))
code,_=net(x)
print(code.shape)

torch.Size([1, 3])


In [6]:
criterion=nn.MSELoss(size_average=False)
optimizer=torch.optim.Adam(net.parameters(),lr=1e-3)

def to_img(x):
    x=0.5*(x+1.)
    x=x.clamp(0,1)
    x=x.view(x.shape[0],1,28,28)
    return x



In [None]:
for e in range(100):
    for im,_ in train_data:
        im=im.view(im.shape[0],-1)
        im=Variable(im)
        _,output=net(im)
        loss=criterion(output,im)/im.shape[0]
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    if(e+1)%20==0:
        print(e+1,loss.data)
        pic=to_img(output.cpu().data)
        if not os.path.exists('./simple_autoencoder'):
            os.mkdir('./simple_autoencoder')
        save_image(pic,'./simple_autoencoder/image_{}.png'.format(e + 1))

20 tensor(94.4271)
40 tensor(96.6156)
60 tensor(98.2562)
80 tensor(104.6222)
