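# **03-VAE: Variational Autoencoder on MNIST**

Trains a fully-connected variational autoencoder (VAE) on the MNIST training set, then writes randomly sampled digits and side-by-side reconstructions to the `samples/` directory.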

In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.utils import save_image

### **Device**

In [2]:
# Run on a CUDA GPU when PyTorch can see one; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

### **Hyper-parameters**

In [3]:
image_size = 784
h_dim = 400 # hidden state dim
z_dim = 20 # z-latent dim
num_epochs = 15
batch_size = 128
learning_rate = 1e-3

In [4]:
# Output directory for the sampled and reconstructed image grids.
sample_dir = 'samples'
# exist_ok=True makes the cell idempotent on re-run and avoids the
# check-then-create race of os.path.exists() followed by os.makedirs().
# Unlike the exists() check, it also fails immediately if 'samples' is a
# regular file rather than silently continuing until save_image() fails.
os.makedirs(sample_dir, exist_ok=True)

### **Dataset & DataLoader**

In [5]:
# MNIST training split. ToTensor() converts each image to a float tensor in
# [0, 1], which binary_cross_entropy in the training loop requires of its targets.
# download=True is a no-op when the files already exist under ./data, and it
# lets the notebook run top-to-bottom on a fresh machine instead of raising.
dataset = torchvision.datasets.MNIST(root="./data",
                                     train=True,
                                     transform=transforms.ToTensor(),
                                     download=True)

# Mini-batch loader; shuffle=True re-orders the data every epoch.
data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                          batch_size=batch_size,
                                          shuffle=True)

### **VAE model**

- Zhihu article 【VAE模型】 ("The VAE model", in Chinese): [link](https://zhuanlan.zhihu.com/p/34998569)

In [6]:
class VAE(nn.Module):
    """Fully-connected variational autoencoder.

    Encoder: image_size -> h_dim (ReLU) -> two linear heads of size z_dim that
    give the mean ``mu`` and log-variance ``log_var`` of q(z | x).
    Decoder: z_dim -> h_dim (ReLU) -> image_size, followed by a sigmoid so every
    output lies in (0, 1) and can be scored with binary cross-entropy.

    Args:
        image_size: Number of input features per example (a flattened image).
        h_dim: Hidden-layer size of both encoder and decoder.
        z_dim: Dimensionality of the latent code.
    """

    def __init__(self, image_size=784, h_dim=400, z_dim=20):
        super(VAE, self).__init__()
        # Names fc1..fc5 are kept unchanged so saved state_dicts still load.
        self.fc1 = nn.Linear(image_size, h_dim)  # encoder hidden layer
        self.fc2 = nn.Linear(h_dim, z_dim)       # encoder head -> mu
        self.fc3 = nn.Linear(h_dim, z_dim)       # encoder head -> log_var
        self.fc4 = nn.Linear(z_dim, h_dim)       # decoder hidden layer
        self.fc5 = nn.Linear(h_dim, image_size)  # decoder output layer

    def encode(self, x):
        """Map inputs whose last dimension is image_size to ``(mu, log_var)``.

        Both returned tensors have last dimension z_dim.
        """
        h = F.relu(self.fc1(x))
        return self.fc2(h), self.fc3(h)

    def decode(self, z):
        """Map latent codes (last dimension z_dim) to per-feature values in (0, 1)."""
        h = F.relu(self.fc4(z))
        # torch.sigmoid is the supported API; the functional alias F.sigmoid
        # is deprecated. The result is identical.
        return torch.sigmoid(self.fc5(h))

    def reparameterize(self, mu, log_var):
        """Sample z ~ N(mu, exp(log_var)) as ``mu + eps * std`` with eps ~ N(0, I).

        The noise is drawn separately, so the result stays differentiable
        with respect to ``mu`` and ``log_var`` (the reparameterization trick).
        """
        std = torch.exp(log_var / 2)  # std = sqrt(var) = exp(log_var / 2)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Encode ``x``, sample a latent code, and decode it.

        Returns:
            Tuple ``(x_reconst, mu, log_var)``. ``x_reconst`` has the same
            last dimension as ``x``. ``mu`` and ``log_var`` are the posterior
            parameters needed for the KL term.
        """
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        x_reconst = self.decode(z)
        return x_reconst, mu, log_var

In [7]:
# Build the model from the hyper-parameter cell rather than the constructor
# defaults, so edits to image_size / h_dim / z_dim there take effect. The
# sampling code draws z with shape (batch_size, z_dim), so z_dim must match.
model = VAE(image_size=image_size, h_dim=h_dim, z_dim=z_dim).to(device)

### **Training Config**

In [8]:
# Adam over all parameters of the VAE, with the step size from the hyper-parameter cell.
optimizer = torch.optim.Adam(params=model.parameters(),
                             lr=learning_rate)

### **Train**

In [24]:
for epoch in range(num_epochs):
    for i, (x, _) in enumerate(data_loader):
        # Forward pass on images flattened to vectors of length image_size.
        x = x.to(device).view(-1, image_size)
        x_reconst, mu, log_var = model(x)

        # Reconstruction term: per-pixel binary cross-entropy, summed over the
        # batch. reduction='sum' replaces the deprecated size_average=False
        # (same value, but without the deprecation warning).
        reconst_loss = F.binary_cross_entropy(x_reconst, x, reduction='sum')
        # Closed-form KL(q(z|x) || N(0, I)) for a diagonal Gaussian, also summed
        # over the batch. See Appendix B of Kingma & Welling, "Auto-Encoding
        # Variational Bayes", or http://yunjey47.tistory.com/43
        kl_div = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

        # Negative ELBO. Both terms are sums, so the logged values scale with
        # the number of images in the batch; the final, smaller batch of an
        # epoch reports lower numbers.
        loss = reconst_loss + kl_div
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 10 == 0:
            print("Epoch[{}/{}], Step [{}/{}], Reconst Loss: {:.4f}, KL Div: {:.4f}"
                  .format(epoch + 1, num_epochs, i + 1, len(data_loader),
                          reconst_loss.item(), kl_div.item()))



Epoch[1/15], Step [10/469], Reconst Loss: 36330.8164, KL Div: 2750.3308
Epoch[1/15], Step [20/469], Reconst Loss: 29143.3828, KL Div: 1110.9963
Epoch[1/15], Step [30/469], Reconst Loss: 27420.0352, KL Div: 1354.1023
Epoch[1/15], Step [40/469], Reconst Loss: 26775.8516, KL Div: 672.9094
Epoch[1/15], Step [50/469], Reconst Loss: 26968.9824, KL Div: 768.7137
Epoch[1/15], Step [60/469], Reconst Loss: 26433.8750, KL Div: 876.1646
Epoch[1/15], Step [70/469], Reconst Loss: 24678.7422, KL Div: 827.0153
Epoch[1/15], Step [80/469], Reconst Loss: 24115.4375, KL Div: 1039.8115
Epoch[1/15], Step [90/469], Reconst Loss: 24112.4688, KL Div: 1069.9922
Epoch[1/15], Step [100/469], Reconst Loss: 24058.1660, KL Div: 1140.9454
Epoch[1/15], Step [110/469], Reconst Loss: 22115.7031, KL Div: 1292.7358
Epoch[1/15], Step [120/469], Reconst Loss: 21786.5488, KL Div: 1503.6377
Epoch[1/15], Step [130/469], Reconst Loss: 19728.6113, KL Div: 1670.1440
Epoch[1/15], Step [140/469], Reconst Loss: 20538.4141, KL Div: 1

Epoch[3/15], Step [220/469], Reconst Loss: 11914.9844, KL Div: 3046.1194
Epoch[3/15], Step [230/469], Reconst Loss: 12075.8916, KL Div: 3044.9729
Epoch[3/15], Step [240/469], Reconst Loss: 11203.0723, KL Div: 3040.5618
Epoch[3/15], Step [250/469], Reconst Loss: 11593.9111, KL Div: 3105.9246
Epoch[3/15], Step [260/469], Reconst Loss: 11412.0947, KL Div: 3141.5129
Epoch[3/15], Step [270/469], Reconst Loss: 11440.8037, KL Div: 3095.9314
Epoch[3/15], Step [280/469], Reconst Loss: 11580.3516, KL Div: 3061.3296
Epoch[3/15], Step [290/469], Reconst Loss: 11701.6367, KL Div: 2971.4810
Epoch[3/15], Step [300/469], Reconst Loss: 12139.8730, KL Div: 3059.2178
Epoch[3/15], Step [310/469], Reconst Loss: 11401.8047, KL Div: 3092.9895
Epoch[3/15], Step [320/469], Reconst Loss: 11215.0664, KL Div: 3035.7310
Epoch[3/15], Step [330/469], Reconst Loss: 11274.5723, KL Div: 3127.3665
Epoch[3/15], Step [340/469], Reconst Loss: 11621.3125, KL Div: 2971.7576
Epoch[3/15], Step [350/469], Reconst Loss: 11453.94

Epoch[5/15], Step [430/469], Reconst Loss: 10667.5488, KL Div: 3100.7109
Epoch[5/15], Step [440/469], Reconst Loss: 10691.4229, KL Div: 3211.2402
Epoch[5/15], Step [450/469], Reconst Loss: 10635.2461, KL Div: 3117.2146
Epoch[5/15], Step [460/469], Reconst Loss: 10773.6582, KL Div: 3097.5054
Epoch[6/15], Step [10/469], Reconst Loss: 11012.7012, KL Div: 3128.3452
Epoch[6/15], Step [20/469], Reconst Loss: 10502.8926, KL Div: 3166.0437
Epoch[6/15], Step [30/469], Reconst Loss: 10568.4951, KL Div: 3238.9385
Epoch[6/15], Step [40/469], Reconst Loss: 10822.2900, KL Div: 3084.0029
Epoch[6/15], Step [50/469], Reconst Loss: 10526.2930, KL Div: 3111.9221
Epoch[6/15], Step [60/469], Reconst Loss: 11105.4922, KL Div: 3204.7832
Epoch[6/15], Step [70/469], Reconst Loss: 10580.2891, KL Div: 3187.8926
Epoch[6/15], Step [80/469], Reconst Loss: 11055.5000, KL Div: 3191.9565
Epoch[6/15], Step [90/469], Reconst Loss: 10904.3506, KL Div: 3237.1150
Epoch[6/15], Step [100/469], Reconst Loss: 10720.0938, KL Di

Epoch[8/15], Step [180/469], Reconst Loss: 10879.2207, KL Div: 3198.4856
Epoch[8/15], Step [190/469], Reconst Loss: 10264.0898, KL Div: 3150.5872
Epoch[8/15], Step [200/469], Reconst Loss: 10799.5977, KL Div: 3178.5706
Epoch[8/15], Step [210/469], Reconst Loss: 10458.1738, KL Div: 3127.1321
Epoch[8/15], Step [220/469], Reconst Loss: 10429.3984, KL Div: 3108.5457
Epoch[8/15], Step [230/469], Reconst Loss: 10880.7744, KL Div: 3289.7722
Epoch[8/15], Step [240/469], Reconst Loss: 10876.1807, KL Div: 3218.0251
Epoch[8/15], Step [250/469], Reconst Loss: 10865.7324, KL Div: 3253.1882
Epoch[8/15], Step [260/469], Reconst Loss: 10274.0078, KL Div: 3261.6040
Epoch[8/15], Step [270/469], Reconst Loss: 10486.6055, KL Div: 3273.8674
Epoch[8/15], Step [280/469], Reconst Loss: 10392.4541, KL Div: 3104.3740
Epoch[8/15], Step [290/469], Reconst Loss: 10749.8076, KL Div: 3307.3662
Epoch[8/15], Step [300/469], Reconst Loss: 10439.2275, KL Div: 3234.4814
Epoch[8/15], Step [310/469], Reconst Loss: 10052.16

Epoch[10/15], Step [390/469], Reconst Loss: 10126.7715, KL Div: 3251.8684
Epoch[10/15], Step [400/469], Reconst Loss: 10399.1582, KL Div: 3190.6189
Epoch[10/15], Step [410/469], Reconst Loss: 10064.6875, KL Div: 3190.8428
Epoch[10/15], Step [420/469], Reconst Loss: 10534.4043, KL Div: 3206.6299
Epoch[10/15], Step [430/469], Reconst Loss: 10430.2832, KL Div: 3194.7488
Epoch[10/15], Step [440/469], Reconst Loss: 10230.6396, KL Div: 3314.7737
Epoch[10/15], Step [450/469], Reconst Loss: 9947.9775, KL Div: 3065.6089
Epoch[10/15], Step [460/469], Reconst Loss: 10150.1270, KL Div: 3219.0342
Epoch[11/15], Step [10/469], Reconst Loss: 10523.9287, KL Div: 3174.7942
Epoch[11/15], Step [20/469], Reconst Loss: 10616.7812, KL Div: 3210.3657
Epoch[11/15], Step [30/469], Reconst Loss: 10077.4961, KL Div: 3242.1147
Epoch[11/15], Step [40/469], Reconst Loss: 10241.6836, KL Div: 3181.1416
Epoch[11/15], Step [50/469], Reconst Loss: 10065.3203, KL Div: 3225.1790
Epoch[11/15], Step [60/469], Reconst Loss: 1

Epoch[13/15], Step [130/469], Reconst Loss: 10298.8555, KL Div: 3154.7683
Epoch[13/15], Step [140/469], Reconst Loss: 9693.7998, KL Div: 3183.0391
Epoch[13/15], Step [150/469], Reconst Loss: 10347.5938, KL Div: 3193.1753
Epoch[13/15], Step [160/469], Reconst Loss: 10031.3730, KL Div: 3176.6130
Epoch[13/15], Step [170/469], Reconst Loss: 9889.5332, KL Div: 3178.3655
Epoch[13/15], Step [180/469], Reconst Loss: 10461.6113, KL Div: 3252.5852
Epoch[13/15], Step [190/469], Reconst Loss: 10122.4971, KL Div: 3217.4688
Epoch[13/15], Step [200/469], Reconst Loss: 10442.4941, KL Div: 3286.5334
Epoch[13/15], Step [210/469], Reconst Loss: 10430.1914, KL Div: 3330.8237
Epoch[13/15], Step [220/469], Reconst Loss: 10290.0137, KL Div: 3203.7034
Epoch[13/15], Step [230/469], Reconst Loss: 10280.9600, KL Div: 3280.4812
Epoch[13/15], Step [240/469], Reconst Loss: 10418.8477, KL Div: 3245.9163
Epoch[13/15], Step [250/469], Reconst Loss: 10149.3262, KL Div: 3172.5564
Epoch[13/15], Step [260/469], Reconst Lo

Epoch[15/15], Step [330/469], Reconst Loss: 10240.4756, KL Div: 3228.3867
Epoch[15/15], Step [340/469], Reconst Loss: 10180.2539, KL Div: 3248.8364
Epoch[15/15], Step [350/469], Reconst Loss: 10284.3047, KL Div: 3363.6687
Epoch[15/15], Step [360/469], Reconst Loss: 9983.1523, KL Div: 3155.3672
Epoch[15/15], Step [370/469], Reconst Loss: 10298.2949, KL Div: 3303.8367
Epoch[15/15], Step [380/469], Reconst Loss: 10109.8828, KL Div: 3176.0752
Epoch[15/15], Step [390/469], Reconst Loss: 9964.9180, KL Div: 3380.7402
Epoch[15/15], Step [400/469], Reconst Loss: 9970.1660, KL Div: 3110.7456
Epoch[15/15], Step [410/469], Reconst Loss: 9976.4492, KL Div: 3225.6387
Epoch[15/15], Step [420/469], Reconst Loss: 10055.7607, KL Div: 3193.4941
Epoch[15/15], Step [430/469], Reconst Loss: 9845.7998, KL Div: 3266.4980
Epoch[15/15], Step [440/469], Reconst Loss: 10076.4316, KL Div: 3183.8064
Epoch[15/15], Step [450/469], Reconst Loss: 10464.2637, KL Div: 3294.3083
Epoch[15/15], Step [460/469], Reconst Loss:

### **Inference**

In [28]:
with torch.no_grad():
    # Random samples: decode latent codes drawn from the prior N(0, I).
    z = torch.randn(batch_size, z_dim, device=device)
    out = model.decode(z).view(-1, 1, 28, 28)
    # num_epochs equals the `epoch + 1` value the training loop ends on, so the
    # file name is unchanged without depending on a leaked loop variable.
    save_image(out, os.path.join(sample_dir, 'sampled-{}.png'.format(num_epochs)))

    # Reconstructions: draw a fresh batch rather than reusing the `x` left over
    # from the training loop, so this cell has no hidden-state dependency.
    x_batch, _ = next(iter(data_loader))
    x_batch = x_batch.to(device).view(-1, image_size)
    out, _, _ = model(x_batch)
    # Place each original image beside its reconstruction (concatenate along width).
    x_concat = torch.cat([x_batch.view(-1, 1, 28, 28), out.view(-1, 1, 28, 28)], dim=3)
    save_image(x_concat, os.path.join(sample_dir, 'reconst-{}.png'.format(num_epochs)))