# **VAE**

In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.utils import save_image

### **Device**

In [2]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

### **Hyper-parameters**

In [3]:
image_size = 784       # flattened 28x28 MNIST image
h_dim = 400            # hidden layer width
z_dim = 20             # latent (z) dimension
num_epochs = 15
batch_size = 128
learning_rate = 1e-3

# Seed PyTorch's global RNG so data shuffling, weight initialisation and the
# reparameterisation noise repeat across Restart & Run All.
# (Some GPU kernels may still be nondeterministic.)
seed = 0
torch.manual_seed(seed)

In [4]:
sample_dir = 'samples'
# exist_ok=True keeps the cell idempotent on re-run and avoids the
# check-then-create race of `if not os.path.exists(...)`.
os.makedirs(sample_dir, exist_ok=True)

### **Dataset & DataLoader**

In [5]:
# MNIST dataset
dataset = torchvision.datasets.MNIST(root="./data", 
                                     train=True, 
                                     transform=transforms.ToTensor(),
                                     download=False)

# Data Loader
data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                          batch_size=batch_size,
                                          shuffle=True)

### **VAE model**

- 知乎 (Zhihu) article on the VAE model, in Chinese (知乎【VAE模型】): [link](https://zhuanlan.zhihu.com/p/34998569)

In [6]:
class VAE(nn.Module):
    """Fully connected variational autoencoder.

    Encoder: image_size -> h_dim (ReLU) -> two linear heads of size z_dim that
    produce the mean ``mu`` and log-variance ``log_var`` of the latent code.
    Decoder: z_dim -> h_dim (ReLU) -> image_size, squashed into (0, 1) by a sigmoid.

    Args:
        image_size: number of input features per sample (784 = flattened 28x28 image).
        h_dim: width of the hidden layer used by both encoder and decoder.
        z_dim: dimensionality of the latent code z.
    """

    def __init__(self, image_size=784, h_dim=400, z_dim=20):
        super().__init__()
        # Stored so callers can, e.g., draw z ~ N(0, I) with the right width.
        self.image_size = image_size
        self.h_dim = h_dim
        self.z_dim = z_dim
        self.fc1 = nn.Linear(image_size, h_dim)  # encoder hidden layer
        self.fc2 = nn.Linear(h_dim, z_dim)       # encoder head -> mu
        self.fc3 = nn.Linear(h_dim, z_dim)       # encoder head -> log_var
        self.fc4 = nn.Linear(z_dim, h_dim)       # decoder hidden layer
        self.fc5 = nn.Linear(h_dim, image_size)  # decoder output layer

    def encode(self, x):
        """Map x (last dim == image_size, e.g. (batch, image_size)) to (mu, log_var)."""
        h = F.relu(self.fc1(x))
        return self.fc2(h), self.fc3(h)

    def decode(self, z):
        """Map latent codes z (last dim == z_dim) to reconstructions in (0, 1)."""
        h = F.relu(self.fc4(z))
        # torch.sigmoid replaces the deprecated F.sigmoid.
        return torch.sigmoid(self.fc5(h))

    def reparameterize(self, mu, log_var):
        """Sample z = mu + eps * sigma with eps ~ N(0, I).

        Writing the sample this way keeps it differentiable w.r.t. mu and log_var.
        sigma = exp(log_var / 2) because log_var is the log of the variance.
        """
        std = torch.exp(log_var / 2)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Encode, sample a latent code and decode.

        Returns:
            (x_reconst, mu, log_var): the reconstruction plus the latent
            distribution parameters needed for the KL term of the loss.
        """
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        x_reconst = self.decode(z)
        return x_reconst, mu, log_var

In [7]:
# Pass the configured hyper-parameters explicitly; VAE() alone would silently use
# the constructor defaults and ignore edits to the Hyper-parameter cell.
model = VAE(image_size=image_size, h_dim=h_dim, z_dim=z_dim).to(device)

### **Training Config**

In [8]:
# Adam over all parameters of the VAE, with the step size from the Hyper-parameter cell.
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

## **Train**