In [1]:
import torch
from torch.utils.data import DataLoader
from torch import nn,optim
from torchvision import transforms,datasets
import visdom

In [2]:
class AE(nn.Module):
    
    def __init__(self):
        super(AE,self).__init__()
        
        # [b,784] => [b,20]
        self.encoder=nn.Sequential(
            nn.Linear(784,256),
            nn.ReLU(),
            nn.Linear(256,64),
            nn.ReLU(),
            nn.Linear(64,20),
            nn.ReLU()
        )
        # [b,20] => [b,784]
        self.decoder=nn.Sequential(
            nn.Linear(20,64),
            nn.ReLU(),
            nn.Linear(64,256),
            nn.ReLU(),
            nn.Linear(256,784),
            nn.Sigmoid()
        )
        
    def forward(self,x):
        """
        param x: [b,1,28,28]
        return x 
        """
        batchsz=x.size(0)
        # flatten
        x=x.view(batchsz,-1)
        # encoder
        x=self.encoder(x)
        # decoder
        x=self.decoder(x)
        # reshape
        x=x.view(batchsz,1,28,28)
        
        return x

In [3]:
def main():
    batch_size=32 
    epochs=100
    learning_rate=1e-3 
    
    train_db=datasets.MNIST('./data/mnist_data',train=True,download=True,
                  transform=transforms.Compose([
                      transforms.ToTensor(),
                      transforms.Normalize((0.1307,),(0.3081,))
                  ]))
    train_loader=torch.utils.data.DataLoader(train_db,batch_size=batch_size,shuffle=True)

    test_db=datasets.MNIST('./data/mnist_data',train=False,
                  transform=transforms.Compose([
                      transforms.ToTensor(),
                      transforms.Normalize((0.1307,),(0.3081,))
                  ]))
    test_loader=torch.utils.data.DataLoader(test_db,batch_size=batch_size,shuffle=True)
    
    
    x,_ =iter(train_loader).next()
    print('x:',x.shape)
    
    device=torch.device('cuda')
    model=AE().to(device)
    criteon=nn.MSELoss()
    optimizer=optim.Adam(model.parameters(),lr=learning_rate)
    print(model)
    
    
    viz=visdom.Visdom()
    for epoch in range(epochs):
        for batchidx,(x,_) in enumerate(train_loader):
            # [b,1,28,28]
            x=x.to(device)
            
            x_hat=model(x)   # logits
            loss=criteon(x_hat,x)
            
            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
        print('epoch={} loss={}'.format(epoch,loss.item()))
        
        x,_=iter(test_loader).next()
        x=x.to(device)
        with torch.no_grad():
            x_hat=model(x)
            
        viz.images(x,nrow=8,win='x',opts=dict(title='x'))
        viz.images(x_hat,nrow=8,win='x_hat',opts=dict(title='x_hat'))

In [4]:
if __name__=='__main__':
    main()

x: torch.Size([32, 1, 28, 28])


Setting up a new session...


AE(
  (encoder): Sequential(
    (0): Linear(in_features=784, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=64, bias=True)
    (3): ReLU()
    (4): Linear(in_features=64, out_features=20, bias=True)
    (5): ReLU()
  )
  (decoder): Sequential(
    (0): Linear(in_features=20, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=256, bias=True)
    (3): ReLU()
    (4): Linear(in_features=256, out_features=784, bias=True)
    (5): Sigmoid()
  )
)
epoch=0 loss=0.5714783072471619
epoch=1 loss=0.5603113770484924
epoch=2 loss=0.5409227609634399
epoch=3 loss=0.47055497765541077
epoch=4 loss=0.527624249458313
epoch=5 loss=0.5254924297332764
epoch=6 loss=0.5143334865570068
epoch=7 loss=0.5338783860206604
epoch=8 loss=0.5495006442070007
epoch=9 loss=0.5165415406227112
epoch=10 loss=0.4805402159690857
epoch=11 loss=0.47055667638778687
epoch=12 loss=0.5304577946662903
epoch=13 loss=0.5037247538566589
epoch=14 loss=0.49

KeyboardInterrupt: 