# VAE (Variational Autoencoder)

### 1. 데이터 로드 및 설정

In [1]:
import torch

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = "cuda"
else:
  device = "cpu"

In [2]:
import torchvision.datasets as datasets
from torchvision import transforms
from torch.utils.data import DataLoader

# MNIST training split stored under dataset/ (downloaded if missing); each
# image is converted with transforms.ToTensor().
dataset = datasets.MNIST(
  root='dataset/',
  train=True,
  download=True,
  transform=transforms.ToTensor(),
)
# Shuffled mini-batches of 32 samples for the training loop.
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)

100%|██████████| 9.91M/9.91M [00:00<00:00, 16.3MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 482kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.48MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 7.73MB/s]


### 2. 모델 생성

In [3]:
import torch.nn as nn

class VAE(nn.Module):
  """Fully connected variational autoencoder with a diagonal Gaussian latent.

  The encoder maps a flattened input to the mean and (strictly positive)
  standard deviation of q(z|x); the decoder maps a latent vector back to an
  input-sized vector squashed into (0, 1) by a sigmoid.

  Args:
    input_dim: Size of the flattened input (784 in this notebook).
    hidden_dim: Width of the single hidden layer in encoder and decoder.
    z_dim: Dimensionality of the latent space.
  """

  def __init__(self, input_dim, hidden_dim=200, z_dim=20):
    super().__init__()

    # Encoder layers.
    self.img2hid = nn.Linear(input_dim, hidden_dim)
    self.hid2mu = nn.Linear(hidden_dim, z_dim)
    # Outputs log(sigma^2). The attribute keeps its original name so the
    # module's state_dict keys are unchanged.
    self.hid2sigma = nn.Linear(hidden_dim, z_dim)

    # Decoder layers.
    self.z2hid = nn.Linear(z_dim, hidden_dim)
    self.hid2img = nn.Linear(hidden_dim, input_dim)

    self.relu = nn.ReLU()

  def encoder(self, x):
    """Return ``(mu, sigma)`` of q(z|x); ``sigma`` is always > 0."""
    h = self.relu(self.img2hid(x))
    mu = self.hid2mu(h)
    # A raw linear output can be zero or negative, which is not a valid
    # standard deviation and makes log(sigma^2) in the KL term diverge.
    # Treat it as log-variance instead: sigma = exp(0.5 * logvar) > 0.
    logvar = self.hid2sigma(h)
    sigma = torch.exp(0.5 * logvar)
    return mu, sigma

  def decoder(self, z):
    """Map latent vectors ``z`` to input-sized outputs in (0, 1)."""
    h = self.relu(self.z2hid(z))
    return torch.sigmoid(self.hid2img(h))

  def forward(self, x):
    """Encode, sample z via the reparameterization trick, and decode.

    Returns:
      ``(x_reconst, mu, sigma)`` where ``x_reconst`` has the shape of ``x``.
    """
    mu, sigma = self.encoder(x)
    # z = mu + sigma * eps, eps ~ N(0, I): keeps the sample differentiable
    # with respect to mu and sigma.
    epsilon = torch.randn_like(sigma)
    z_reparam = mu + sigma * epsilon
    x_reconst = self.decoder(z_reparam)
    return x_reconst, mu, sigma

In [4]:
# 784 flattened pixels -> 200 hidden units -> 20-dim latent.
model = VAE(input_dim=784, hidden_dim=200, z_dim=20).to(device)
# Summed (not averaged) BCE, on the same per-batch scale as the summed KL term.
criterion = nn.BCELoss(reduction="sum")
optimizer = torch.optim.Adam(params=model.parameters(), lr=3e-4)

### 3. 모델 학습

In [5]:
from tqdm import tqdm

NUM_EPOCHS = 10
# Added inside log(sigma^2) so the KL term (and its gradient) stays finite
# even if sigma reaches 0.
KL_EPS = 1e-8

model.train()
for epoch in range(NUM_EPOCHS):
  progress = tqdm(train_loader, desc=f"epoch {epoch + 1}/{NUM_EPOCHS}")
  for step, (x, _) in enumerate(progress):
    # Flatten each image to a 784-dim vector for the fully connected encoder.
    x = x.to(device).view(x.shape[0], 784)

    x_reconst, mu, sigma = model(x)

    # Reconstruction term: BCE summed over pixels and batch (see criterion).
    reconst_loss = criterion(x_reconst, x)
    # Closed-form KL(N(mu, sigma^2) || N(0, I)), summed over latents and batch:
    #   -0.5 * sum(1 + log sigma^2 - mu^2 - sigma^2)
    # Note the 0.5 factor; without it the KL term is weighted twice as much.
    var = sigma.pow(2)
    kl_div = -0.5 * torch.sum(1 + torch.log(var + KL_EPS) - mu.pow(2) - var)

    loss = reconst_loss + kl_div
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step % 100 == 0:
      # .item() synchronizes with the device, so only report periodically.
      progress.set_postfix(loss_per_image=loss.item() / x.shape[0])

1875it [00:11, 158.03it/s]
1875it [00:10, 172.13it/s]
1875it [00:10, 171.78it/s]
1875it [00:10, 172.66it/s]
1875it [00:10, 177.99it/s]
1875it [00:10, 172.16it/s]
1875it [00:10, 172.15it/s]
1875it [00:10, 171.41it/s]
1875it [00:11, 170.43it/s]
1875it [00:10, 171.85it/s]


### 4. 추론 (이미지 생성)

In [6]:
from torchvision.utils import save_image

model = model.to("cpu")

def inference(digit, num_samples=3):
  """Save ``num_samples`` decoded samples drawn around one encoded ``digit`` image.

  Finds the first image in ``dataset`` whose label equals ``digit`` and encodes
  it to ``(mu, sigma)``. It then decodes ``num_samples`` latents
  ``z = mu + sigma * eps`` and writes each one as a 28x28 image named
  ``digit{digit}_sample_{i}.png``.

  Args:
    digit: Label to search for in ``dataset``.
    num_samples: Number of images to generate and save.

  Raises:
    ValueError: If no image in ``dataset`` has label ``digit``.
  """
  # Only the first matching image is used, so stop at the first hit instead
  # of collecting and encoding ``num_samples`` images.
  image = next((x for x, y in dataset if y == digit), None)
  if image is None:
    raise ValueError(f"no image with label {digit!r} in dataset")

  # Run on whatever device the model currently lives on.
  model_device = next(model.parameters()).device

  # Pure inference: skip building autograd graphs for encoder and decoder.
  with torch.no_grad():
    mu, sigma = model.encoder(image.view(1, 784).to(model_device))
    for example in range(num_samples):
      epsilon = torch.randn_like(sigma)
      z = mu + sigma * epsilon
      out = model.decoder(z).view(-1, 1, 28, 28)
      save_image(out, f"digit{digit}_sample_{example}.png")

In [7]:
# Generate and save samples for the digit 7 using the default sample count.
inference(digit=7)