<a href="https://colab.research.google.com/github/Pinokcio/ML_Study/blob/main/MNIST_VAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip3 install torch
!pip3 install torchvision



In [None]:
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = "cuda"
else:
  device = "cpu"

# FashionMNIST train / test splits stored under ./data.
# Only the training split triggers a download; the test split expects the files to be present.
train_ds = datasets.FashionMNIST(root="data", train=True, download=True, transform=ToTensor())
test_ds = datasets.FashionMNIST(root="data", train=False, download=False, transform=ToTensor())

# Mini-batch loaders: training batches are reshuffled each epoch, test batches keep dataset order.
batch_size = 100
trainDL = DataLoader(train_ds, shuffle=True, batch_size=batch_size)
testDL = DataLoader(test_ds, shuffle=False, batch_size=batch_size)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/26421880 [00:00<?, ?it/s]

Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/29515 [00:00<?, ?it/s]

Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/4422102 [00:00<?, ?it/s]

Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/5148 [00:00<?, ?it/s]

Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw



In [None]:
class VAE(nn.Module):
  """Fully connected variational auto-encoder.

  The encoder maps a flattened input of size ``x_dim`` through two hidden
  layers to the mean and log-variance of a ``z_dim``-dimensional Gaussian.
  The decoder mirrors the encoder and ends with a sigmoid, so every
  reconstructed value lies in [0, 1].

  Args:
    x_dim: number of features of one flattened input sample.
    h_dim1: width of the outer hidden layer.
    h_dim2: width of the inner hidden layer.
    z_dim: dimensionality of the latent space.
  """

  def __init__(self, x_dim, h_dim1, h_dim2, z_dim):
    super(VAE, self).__init__()

    # Encoder layers; fc31 / fc32 are the two heads producing mu and log_var.
    self.fc1 = nn.Linear(x_dim, h_dim1)
    self.fc2 = nn.Linear(h_dim1, h_dim2)
    self.fc31 = nn.Linear(h_dim2, z_dim)
    self.fc32 = nn.Linear(h_dim2, z_dim)

    # Decoder layers (mirror image of the encoder).
    self.fc4 = nn.Linear(z_dim, h_dim2)
    self.fc5 = nn.Linear(h_dim2, h_dim1)
    self.fc6 = nn.Linear(h_dim1, x_dim)

  def encoder(self, x):
    """Return ``(mu, log_var)`` of the approximate posterior for flattened input ``x``."""
    h = F.relu(self.fc1(x))
    h = F.relu(self.fc2(h))
    return self.fc31(h), self.fc32(h) # mu, log_var

  def sampling(self, mu, log_var):
    """Draw ``z = mu + eps * std`` with ``eps ~ N(0, 1)`` (reparameterization trick)."""
    std = torch.exp(0.5 * log_var)
    eps = torch.randn_like(std) # random values with the same size as std
    # Scale the standard-normal noise by the std and shift it by the mean to obtain z.
    return eps.mul(std).add_(mu)

  def decoder(self, z):
    """Map latent vectors ``z`` to reconstructions with values in [0, 1]."""
    h = F.relu(self.fc4(z))
    h = F.relu(self.fc5(h))
    # torch.sigmoid is the recommended spelling; F.sigmoid warns on some PyTorch releases.
    return torch.sigmoid(self.fc6(h))

  def forward(self, x):
    """Encode, sample and decode ``x``.

    ``x`` may have any shape whose trailing dimensions flatten to ``x_dim``
    features per sample, e.g. (100, 1, 28, 28) => (100, 784).

    Returns:
      ``(reconstruction, mu, log_var)`` - everything the loss needs.
    """
    # Flatten using the width the first layer was built with instead of a hard-coded 784.
    mu, log_var = self.encoder(x.view(-1, self.fc1.in_features))
    z = self.sampling(mu, log_var)
    return self.decoder(z), mu, log_var # values needed to compute the loss
  
# Build the model (784 inputs, 512/256 hidden units, 2-D latent space) and place it on `device`.
vae = VAE(x_dim=784, h_dim1=512, h_dim2=256, z_dim=2).to(device)

In [None]:
# Adam over every parameter of the VAE, learning rate 1e-3.
optimizer = optim.Adam(params=vae.parameters(), lr=0.001)

def loss_function(recon_x, x, mu, log_var):
  """Negative ELBO of a VAE, summed over the whole batch.

  Args:
    recon_x: decoder output of shape (batch, features) with values in [0, 1].
    x: original inputs; any shape that flattens to (batch, features).
    mu: posterior means.
    log_var: posterior log-variances.

  Returns:
    Scalar tensor: summed binary cross-entropy reconstruction term plus the
    KL divergence between N(mu, exp(log_var)) and N(0, 1).
  """
  # Flatten the target to the width of the reconstruction rather than a hard-coded 784.
  target = x.view(-1, recon_x.size(-1))
  BCE = F.binary_cross_entropy(recon_x, target, reduction='sum')
  # Closed-form KL(N(mu, sigma^2) || N(0, 1)) = 0.5 * sum(mu^2 + sigma^2 - log(sigma^2) - 1)
  KLD = 0.5*torch.sum(-1 - log_var + mu.pow(2) + log_var.exp())
  return BCE + KLD

In [None]:
def train(epoch):
  """Run one optimization pass over `trainDL` and print progress.

  Uses the notebook-level `vae`, `optimizer`, `loss_function`, `trainDL`
  and `device`.

  Args:
    epoch: epoch number, used only in the printed messages.
  """
  vae.train()
  train_loss = 0
  for batch_idx, (data, _ ) in enumerate(trainDL):
    data = data.to(device)
    optimizer.zero_grad()
    recon_batch, mu, log_var = vae(data)
    loss = loss_function(recon_batch, data, mu, log_var)
    loss.backward()
    optimizer.step()
    # loss_function returns a sum over the batch, so this accumulates a sum over all samples.
    train_loss += loss.item()

    if batch_idx % 100 == 0:
      print(f"Train Epoch : {epoch} [{batch_idx * len(data)}/{len(trainDL.dataset)}] ({100 * batch_idx / len(trainDL):.0f}%)\tLoss : {loss.item()/len(data):.6f} ")
  # Divide by the number of samples (not the number of batches) so the figure is a
  # per-sample loss, comparable with the per-batch log above and with test().
  print(f"====> Epoch: {epoch} Average loss: {train_loss/len(trainDL.dataset):.4f}")

In [None]:
def test():
  """Evaluate `vae` on `testDL` and print the mean per-sample loss.

  Gradients are disabled for the whole pass. Uses the notebook-level `vae`,
  `loss_function`, `testDL` and `device`.
  """
  vae.eval()
  summed_loss = 0
  with torch.no_grad():
    for images, _labels in testDL:
      images = images.to(device)
      outputs = vae(images)  # (reconstruction, mu, log_var)
      batch_loss = loss_function(outputs[0], images, outputs[1], outputs[2])
      summed_loss += batch_loss.item()

  # Batch losses are sums, so dividing by the dataset size gives a per-sample average.
  test_loss = summed_loss / len(testDL.dataset)
  print(f'====> Test set loss : {test_loss:.4f}')

In [None]:
# Train for 50 epochs (numbered 1..50), evaluating on the test split after each one.
for epoch in range(1, 50 + 1):
  train(epoch)
  test()

In [None]:
# Colab-only: mount Google Drive at /content/drive and switch the working
# directory to MyDrive, so that files written with relative paths afterwards
# end up in Drive.
from google.colab import drive
drive.mount('/content/drive')
%cd /content/drive/MyDrive/

In [None]:
import matplotlib.pyplot as plt
from torchvision.utils import save_image
# Decode 100 random latent vectors and save them as a single image grid.
with torch.no_grad():
  # Latent size is read from the decoder's first layer instead of being hard-coded.
  z = torch.randn(100, vae.fc4.in_features).to(device)
  # The decoder output already lives on the same device as z, so no extra .to() is needed.
  sample = vae.decoder(z).view(100, 1, 28, 28)
  save_image(sample, './save_image.png')