<a href="https://colab.research.google.com/github/Lzino/TIL_Today-I-Learned/blob/master/pytorch_AAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## PyTorch implementation of an Adversarial Autoencoder (AAE)
* Based on: https://github.com/eriklindernoren/PyTorch-GAN

In [1]:
import argparse
import os
import numpy as np
import math
import itertools

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

import matplotlib.pyplot as plt
import torchvision.transforms as transforms

In [2]:
# ---- Hyperparameters ----
n_epoch = 10            # number of passes over the training set
batch_size = 64         # images per mini-batch
lr = 0.0001             # Adam learning rate
b1 = 0.5                # Adam beta1: decay rate of the first-moment (mean) estimate
b2 = 0.999              # Adam beta2: decay rate of the second-moment estimate
n_cpu = 8               # presumably meant as a DataLoader worker count; not referenced elsewhere in this notebook
n_cput = n_cpu          # backward-compatible alias for the original (misspelled) name
latent_dim = 10         # size of the latent code z
img_size = 32           # digits are resized to img_size x img_size
channels = 1            # MNIST images are single-channel (grayscale)
sample_interval = 50    # save a grid of generated samples every this many batches
img_shape = (channels, img_size, img_size)

# Use the GPU whenever one is available.
cuda = torch.cuda.is_available()

### Autoencoder (Encoder & Decoder)

In [6]:
def reparameterization(mu, logvar):
  """Draw z ~ N(mu, exp(logvar)) using the reparameterization trick.

  The sample is written as ``z = eps * std + mu`` with ``eps ~ N(0, 1)`` so
  gradients reach ``mu`` and ``logvar`` through a deterministic expression.

  Args:
    mu: Mean tensor (the Encoder passes a (batch, latent_dim) tensor).
    logvar: Log-variance tensor, same shape as ``mu``.

  Returns:
    A tensor with the shape, dtype and device of ``std`` (i.e. of ``logvar``).
  """
  std = torch.exp(logvar / 2)  # exp(log(var) / 2) == standard deviation
  # randn_like follows std's shape, dtype and device (CPU or CUDA), so no
  # global Tensor type or hard-coded latent size is required here.
  eps = torch.randn_like(std)
  return eps * std + mu

class Encoder(nn.Module):
  """Encode a batch of images into latent samples.

  Each image is flattened and passed through a small MLP. The result is
  projected to a mean (``self.mu``) and a log-variance (``self.logvar``), and
  a latent code is drawn from them with ``reparameterization``.

  Args:
    in_shape: Shape of one input image, e.g. ``(channels, H, W)``.
      Defaults to the module-level ``img_shape``.
    z_dim: Size of the latent code. Defaults to the module-level ``latent_dim``.
  """

  def __init__(self, in_shape=None, z_dim=None):
    super().__init__()
    in_shape = img_shape if in_shape is None else in_shape
    z_dim = latent_dim if z_dim is None else z_dim

    # inplace=True makes LeakyReLU overwrite its input instead of allocating a new tensor.
    # ref. of inplace : https://discuss.pytorch.org/t/whats-the-difference-between-nn-relu-and-nn-relu-inplace-true/948
    self.model = nn.Sequential(
        nn.Linear(int(np.prod(in_shape)), 512),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Linear(512, 512),
        nn.BatchNorm1d(512),
        nn.LeakyReLU(0.2, inplace=True),
    )
    self.mu = nn.Linear(512, z_dim)
    self.logvar = nn.Linear(512, z_dim)

  def forward(self, img):
    """Flatten ``img`` per sample, encode it, and return a latent sample."""
    img_flat = img.view(img.shape[0], -1)
    x = self.model(img_flat)
    mu = self.mu(x)
    logvar = self.logvar(x)
    return reparameterization(mu, logvar)

class Decoder(nn.Module):
  """Decode latent codes into images with values in [-1, 1] (Tanh output).

  Args:
    z_dim: Size of the latent code. Defaults to the module-level ``latent_dim``.
    out_shape: Shape of one output image, e.g. ``(channels, H, W)``.
      Defaults to the module-level ``img_shape``.
  """

  def __init__(self, z_dim=None, out_shape=None):
    super().__init__()
    z_dim = latent_dim if z_dim is None else z_dim
    self.out_shape = tuple(img_shape if out_shape is None else out_shape)

    # Mirrors the Encoder: Linear -> LeakyReLU(0.2) -> Linear(512, 512) ->
    # BatchNorm -> LeakyReLU(0.2). The first activation previously read
    # ``nn.LeakyReLU(512, 512)`` (negative slope 512) and the hidden layer was missing.
    self.model = nn.Sequential(
        nn.Linear(z_dim, 512),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Linear(512, 512),
        nn.BatchNorm1d(512),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Linear(512, int(np.prod(self.out_shape))),
        nn.Tanh(),
    )

  def forward(self, z):
    """Decode ``z`` of shape (batch, z_dim) into images of shape (batch, *out_shape)."""
    img_flat = self.model(z)
    return img_flat.view(img_flat.shape[0], *self.out_shape)

### Discriminator

In [7]:
class Discriminator(nn.Module):
  """Score latent codes with a probability of being a prior sample.

  The training loop labels samples from N(0, 1) as 1 ("valid") and Encoder
  outputs as 0 ("fake").

  Args:
    z_dim: Size of the latent code. Defaults to the module-level ``latent_dim``.
  """

  def __init__(self, z_dim=None):
    super().__init__()
    z_dim = latent_dim if z_dim is None else z_dim

    self.model = nn.Sequential(
        nn.Linear(z_dim, 512),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Linear(512, 256),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Linear(256, 1),
        nn.Sigmoid(),
    )

  # forward must be a method of the class itself. It was previously nested
  # inside __init__, so nn.Module never saw it and calling the model failed.
  def forward(self, z):
    """Return a (batch, 1) tensor of probabilities in [0, 1] for codes ``z``."""
    return self.model(z)

### Losses, Optimizers & Other Settings

In [8]:
# Adversarial term: binary cross-entropy on the discriminator's sigmoid output.
adversarial_loss = nn.BCELoss()
# Reconstruction term: mean absolute (L1) pixel error; nn.MSELoss would also work.
pixelwise_loss = nn.L1Loss()

In [9]:
# Build the three networks; each takes its sizes from the settings cell.
encoder, decoder, discriminator = Encoder(), Decoder(), Discriminator()

In [10]:
# Move every network and loss module onto the GPU when one is available.
if cuda:
  for module in (encoder, decoder, discriminator, adversarial_loss, pixelwise_loss):
    module.cuda()

In [18]:
# One Adam optimizer for the autoencoder (encoder and decoder parameters
# together) and a separate one for the discriminator, both using betas=(b1, b2).
optimizer_G = torch.optim.Adam(
    itertools.chain(encoder.parameters(), decoder.parameters()), lr=lr, betas=(b1, b2)
)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))
# Backward-compatible alias for the original PascalCase name.
Optimizer_D = optimizer_D

# Legacy tensor constructor for the active device (CUDA floats when a GPU is used).
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

In [14]:
# MNIST training set: resized to img_size, converted to tensors and scaled
# from [0, 1] to [-1, 1] (matching the decoder's Tanh range), served in shuffled batches.
mnist_root = "../../data/mnist"
os.makedirs(mnist_root, exist_ok=True)
mnist_transform = transforms.Compose([
    transforms.Resize(img_size),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])
dataloader = DataLoader(
    datasets.MNIST(mnist_root, train=True, download=True, transform=mnist_transform),
    batch_size=batch_size,
    shuffle=True,
)

In [15]:
def sample_image(n_row, batches_done, out_dir="images"):
    """Decode random prior samples and save them as an ``n_row`` x ``n_row`` grid.

    Args:
        n_row: Grid side length; ``n_row ** 2`` latent codes are drawn from N(0, 1).
        batches_done: Global batch counter, used as the PNG file name.
        out_dir: Output directory (created if it does not exist).
    """
    os.makedirs(out_dir, exist_ok=True)
    # Sample from the same N(0, 1) prior the discriminator is taught to call "valid".
    z = Tensor(np.random.normal(0, 1, (n_row ** 2, latent_dim)))
    # Use BatchNorm running statistics (so sampling does not update them) and
    # skip autograd, then restore the decoder's previous train/eval mode.
    was_training = decoder.training
    decoder.eval()
    try:
        with torch.no_grad():
            gen_imgs = decoder(z)
    finally:
        decoder.train(was_training)
    save_image(gen_imgs, os.path.join(out_dir, "%d.png" % batches_done), nrow=n_row, normalize=True)

## Training

In [None]:
for epoch in range(n_epoch):
  for i, (imgs, _) in enumerate(dataloader):
    batch = imgs.shape[0]

    # Adversarial targets: 1 for prior samples ("valid"), 0 for encoder output ("fake").
    # Freshly built tensors do not require grad, so no Variable wrapper is needed.
    valid = Tensor(batch, 1).fill_(1.0)
    fake = Tensor(batch, 1).fill_(0.0)

    real_imgs = imgs.type(Tensor)

    ## Train Generator (encoder + decoder)

    optimizer_G.zero_grad()

    encoded_imgs = encoder(real_imgs)     # latent codes
    decoded_imgs = decoder(encoded_imgs)  # reconstructions

    # Mostly reconstruction error, plus a small term rewarding codes that fool D.
    g_loss = 0.001 * adversarial_loss(discriminator(encoded_imgs), valid) + 0.999 * pixelwise_loss(decoded_imgs, real_imgs)

    g_loss.backward()
    optimizer_G.step()

    ## Train Discriminator (Optimizer_D is the discriminator's Adam optimizer)

    Optimizer_D.zero_grad()

    # Samples from the N(0, 1) prior are the discriminator's "valid" examples.
    z = Tensor(np.random.normal(0, 1, (batch, latent_dim)))
    real_loss = adversarial_loss(discriminator(z), valid)

    # detach() keeps the discriminator loss from back-propagating into the encoder.
    # ref. of fake.detach() : https://redstarhong.tistory.com/64
    fake_loss = adversarial_loss(discriminator(encoded_imgs.detach()), fake)
    d_loss = 0.5 * (real_loss + fake_loss)

    d_loss.backward()
    Optimizer_D.step()

    print(
        "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
        % (epoch, n_epoch, i, len(dataloader), d_loss.item(), g_loss.item())
    )

    batches_done = epoch * len(dataloader) + i
    if batches_done % sample_interval == 0:
      sample_image(n_row=10, batches_done=batches_done)