## [VAE的原理](https://www.cnblogs.com/huangshiyu13/p/6209016.html)

In [1]:
import torch
from torch.utils.data import DataLoader
from torch import nn,optim
from torchvision import transforms,datasets
import visdom

In [2]:
class VAE(nn.Module):
    
    def __init__(self):
        super(VAE,self).__init__()
        
        # VAE的编码器会产生两个向量:一个是均值向量，一个是标准差向量。
        # [b,784] => [b,20]
        # u: [b,10]
        # sigma: [b,10]
        self.encoder=nn.Sequential(
            nn.Linear(784,256),
            nn.ReLU(),
            nn.Linear(256,64),
            nn.ReLU(),
            nn.Linear(64,20),
            nn.ReLU()
        )
        # [b,10] => [b,784]
        self.decoder=nn.Sequential(
            nn.Linear(10,64),
            nn.ReLU(),
            nn.Linear(64,256),
            nn.ReLU(),
            nn.Linear(256,784),
            nn.Sigmoid()
        )
        
    def forward(self,x):
        """
        param x: [b,1,28,28]
        return x 
        """
        batchsz=x.size(0)
        # flatten
        x=x.view(batchsz,-1)
        # encoder
        # [b,20] include mean and sigma
        h_=self.encoder(x)
        
        # [b,20] => u=[b,10], sigma=[b,10]
        mu,sigma=h_.chunk(2,dim=1)
        # reparametrize trick, epison~N(0,1)
        epison=torch.randn_like(sigma)
        h=mu+sigma*epison
        
        # decoder
        x_hat=self.decoder(h)
        # reshape
        x_hat=x.view(batchsz,1,28,28)
        
        # KL divegence
        kld=0.5*torch.sum(
            torch.pow(mu,2)+torch.pow(sigma,2)-torch.log(1e-8+torch.pow(sigma,2))-1 
        )  #/ (batchsz*28*28)
        
        return x_hat,kld

In [3]:
def main():
    batch_size=32 
    epochs=100
    learning_rate=1e-3 
    
    train_db=datasets.MNIST('./data/mnist_data',train=True,download=True,
                  transform=transforms.Compose([
                      transforms.ToTensor(),
                      transforms.Normalize((0.1307,),(0.3081,))
                  ]))
    train_loader=torch.utils.data.DataLoader(train_db,batch_size=batch_size,shuffle=True)

    test_db=datasets.MNIST('./data/mnist_data',train=False,
                  transform=transforms.Compose([
                      transforms.ToTensor(),
                      transforms.Normalize((0.1307,),(0.3081,))
                  ]))
    test_loader=torch.utils.data.DataLoader(test_db,batch_size=batch_size,shuffle=True)
    
    
    x,_ =iter(train_loader).next()
    print('x:',x.shape)
    
    device=torch.device('cuda')
    model=VAE().to(device)
    criteon=nn.MSELoss()
    optimizer=optim.Adam(model.parameters(),lr=learning_rate)
    print(model)
    
    
    viz=visdom.Visdom()
    for epoch in range(epochs):
        for batchidx,(x,_) in enumerate(train_loader):
            # [b,1,28,28]
            x=x.to(device)
            
            x_hat,kld=model(x)   # logits
            reconstruct_loss=criteon(x_hat,x)
            
            # 损失函数可以把这两方面进行加和。
            # 一方面，是图片的重构误差，可以用平均平方误差来度量
            # 另一方面，用KL散度来度量潜在变量的分布和单位高斯分布的差异
            loss=reconstruct_loss+1.0*kld
            
            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
        print('epoch={} loss={} kld={}'.format(epoch,loss.item(),kld))
        
        x,_=iter(test_loader).next()
        x=x.to(device)
        with torch.no_grad():
            x_hat,kld=model(x)
            
        viz.images(x,nrow=8,win='x',opts=dict(title='x'))
        viz.images(x_hat,nrow=8,win='x_hat',opts=dict(title='x_hat'))

In [4]:
if __name__=='__main__':
    main()

x: torch.Size([32, 1, 28, 28])


Setting up a new session...


VAE(
  (encoder): Sequential(
    (0): Linear(in_features=784, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=64, bias=True)
    (3): ReLU()
    (4): Linear(in_features=64, out_features=20, bias=True)
    (5): ReLU()
  )
  (decoder): Sequential(
    (0): Linear(in_features=10, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=256, bias=True)
    (3): ReLU()
    (4): Linear(in_features=256, out_features=784, bias=True)
    (5): Sigmoid()
  )
)
epoch=0 loss=2.069458105324884e-06 kld=2.069458105324884e-06
epoch=1 loss=1.8934547370008659e-06 kld=1.8934547370008659e-06
epoch=2 loss=8.11341124062892e-07 kld=8.11341124062892e-07
epoch=3 loss=8.413691716668836e-07 kld=8.413691716668836e-07
epoch=4 loss=5.489101226885396e-07 kld=5.489101226885396e-07
epoch=5 loss=2.53642838288215e-07 kld=2.53642838288215e-07
epoch=6 loss=6.020738965162309e-07 kld=6.020738965162309e-07
epoch=7 loss=6.033544650563272e-07 kld=6.0335