# Generative adversarial network (using the GAN class)

In [1]:
from __future__ import print_function
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from tensorboardX import SummaryWriter

from tqdm import tqdm

batch_size = 64
epochs = 10
seed = 1
torch.manual_seed(seed)

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [2]:
root = '../data'
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Lambda(lambd=lambda x: x.view(-1))])
kwargs = {'batch_size': batch_size, 'num_workers': 1, 'pin_memory': True}

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root=root, train=True, transform=transform, download=True),
    shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root=root, train=False, transform=transform),
    shuffle=False, **kwargs)

In [3]:
from pixyz.distributions import Deterministic, DataDistribution
from pixyz.distributions import Normal
from pixyz.models import GAN

In [4]:
x_dim = 784
z_dim = 100

# generator model p(x|z)    
class Generator(Deterministic):
    def __init__(self):
        super(Generator, self).__init__(cond_var=["z"], var=["x"], name="p")

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(z_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, x_dim),
            nn.Sigmoid()
        )

    def forward(self, z):
        x = self.model(z)
        return {"x": x}
    
    
# prior model p(z)
loc = torch.tensor(0.).to(device)
scale = torch.tensor(1.).to(device)
prior = Normal(loc=loc, scale=scale, var=["z"], dim=z_dim, name="p_prior")

In [5]:
# generative model
p_g = Generator()
p = (p_g*prior).marginalize_var("z")
p.to(device)

# data distribution
p_data = DataDistribution(["x"])
p_data.to(device)

print(p)
print(p_data)

Distribution:
  p(x) = ∫p(x|z)p_prior(z)dz
Network architecture:
  p_prior(z) (Normal): Normal()
  p(x|z) (Deterministic): Generator(
    (model): Sequential(
      (0): Linear(in_features=100, out_features=128, bias=True)
      (1): LeakyReLU(negative_slope=0.2, inplace)
      (2): Linear(in_features=128, out_features=256, bias=True)
      (3): BatchNorm1d(256, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
      (4): LeakyReLU(negative_slope=0.2, inplace)
      (5): Linear(in_features=256, out_features=512, bias=True)
      (6): BatchNorm1d(512, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
      (7): LeakyReLU(negative_slope=0.2, inplace)
      (8): Linear(in_features=512, out_features=1024, bias=True)
      (9): BatchNorm1d(1024, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
      (10): LeakyReLU(negative_slope=0.2, inplace)
      (11): Linear(in_features=1024, out_features=784, bias=True)
      (12): Sigmoid()
    )
  )
Distribution:


In [6]:
# discriminator model p(t|x)
class Discriminator(Deterministic):
    def __init__(self):
        super(Discriminator, self).__init__(cond_var=["x"], var=["t"], name="d")

        self.model = nn.Sequential(
            nn.Linear(x_dim, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        t = self.model(x)
        return {"t": t}
    
d = Discriminator()
d.to(device)

print(d)

Distribution:
  d(t|x) (Deterministic)
Network architecture:
  Discriminator(
    (model): Sequential(
      (0): Linear(in_features=784, out_features=512, bias=True)
      (1): LeakyReLU(negative_slope=0.2, inplace)
      (2): Linear(in_features=512, out_features=256, bias=True)
      (3): LeakyReLU(negative_slope=0.2, inplace)
      (4): Linear(in_features=256, out_features=1, bias=True)
      (5): Sigmoid()
    )
  )


In [7]:
model = GAN(p_data, p, d,
            optimizer=optim.Adam, optimizer_params={"lr":0.0002},
            d_optimizer=optim.Adam, d_optimizer_params={"lr":0.0002})
print(model)

Distributions (for training): 
  p(x) 
Loss function: 
  mean(mean(AdversarialJS[p_data(x)||p(x)])) 
Optimizer: 
  Adam (
  Parameter Group 0
      amsgrad: False
      betas: (0.9, 0.999)
      eps: 1e-08
      lr: 0.0002
      weight_decay: 0
  )


In [8]:
def train(epoch):
    train_loss = 0
    train_d_loss = 0
    for x, _ in tqdm(train_loader):
        x = x.to(device)
        loss, d_loss = model.train({"x": x})
        train_loss += loss
        train_d_loss += d_loss
 
    train_loss = train_loss * train_loader.batch_size / len(train_loader.dataset)
    train_d_loss = train_d_loss * train_loader.batch_size / len(train_loader.dataset)
    print('Epoch: {} Train loss: {:.4f}, {:.4f}'.format(epoch, train_loss.item(), train_d_loss.item()))
    return train_loss

In [9]:
def test(epoch):
    test_loss = 0
    test_d_loss = 0
    for x, _ in test_loader:
        x = x.to(device)
        loss, d_loss = model.test({"x": x})
        test_loss += loss
        test_d_loss += d_loss

    test_loss = test_loss * test_loader.batch_size / len(test_loader.dataset)
    test_d_loss = test_d_loss * test_loader.batch_size / len(test_loader.dataset)
    
    print('Test loss: {:.4f}, {:.4f}'.format(test_loss, test_d_loss.item()))
    return test_loss

In [10]:
def plot_image_from_latent(z_sample):
    with torch.no_grad():
        sample = p_g.sample({"z": z_sample})["x"].view(-1, 1, 28, 28).cpu()
        return sample

In [11]:
writer = SummaryWriter()

z_sample = torch.randn(64, z_dim).to(device)
_x, _y = iter(test_loader).next()
_x = _x.to(device)
_y = _y.to(device)

for epoch in range(1, epochs + 1):
    train_loss = train(epoch)
    test_loss = test(epoch)
    
    sample = plot_image_from_latent(z_sample)

    writer.add_scalar('train_loss', train_loss.item(), epoch)
    writer.add_scalar('test_loss', test_loss.item(), epoch)      
    
    writer.add_images('Image_from_latent', sample, epoch)
    
writer.close()

100%|██████████| 938/938 [00:10<00:00, 92.17it/s] 


Epoch: 1 Train loss: 12.1933, 0.1214


  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 14.7451, 0.0489


100%|██████████| 938/938 [00:09<00:00, 95.00it/s] 

Epoch: 2 Train loss: 19.1830, 0.0474



  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 21.1420, 0.0389


100%|██████████| 938/938 [00:10<00:00, 87.31it/s]

Epoch: 3 Train loss: 24.6523, 0.0356



  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 28.3799, 0.0162


100%|██████████| 938/938 [00:08<00:00, 97.76it/s]

Epoch: 4 Train loss: 28.3241, 0.0393



  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 29.3135, 0.0132


100%|██████████| 938/938 [00:10<00:00, 90.14it/s]

Epoch: 5 Train loss: 28.2658, 0.0378



  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 28.8593, 0.0516


100%|██████████| 938/938 [00:09<00:00, 97.95it/s] 


Epoch: 6 Train loss: 29.4387, 0.0379


  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 27.2278, 0.0347


100%|██████████| 938/938 [00:10<00:00, 88.19it/s]

Epoch: 7 Train loss: 30.3915, 0.0406



  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 31.6369, 0.0100


100%|██████████| 938/938 [00:10<00:00, 87.15it/s]


Epoch: 8 Train loss: 31.4287, 0.0414


  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 30.3456, 0.0501


100%|██████████| 938/938 [00:11<00:00, 83.19it/s]

Epoch: 9 Train loss: 29.3756, 0.0569



  0%|          | 0/938 [00:00<?, ?it/s]

Test loss: 28.7867, 0.0378


100%|██████████| 938/938 [00:10<00:00, 85.68it/s]

Epoch: 10 Train loss: 28.6039, 0.0601





Test loss: 30.1322, 0.0383
