In [55]:
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd

import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid

from torch.utils.tensorboard import SummaryWriter

In [56]:
# ---- Training schedule ----
n_epochs = 200          # number of full passes over the dataloader
batch_size = 64         # samples per mini-batch
n_critic = 5            # the generator is updated on every n_critic-th batch

# ---- Adam optimizer (shared by generator and discriminator) ----
lr = 2e-4               # learning rate
b1 = 0.5                # first Adam beta coefficient
b2 = 0.999              # second Adam beta coefficient

# ---- Model / data geometry ----
latent_dim = 100        # length of the noise vector fed to the generator
img_size = 28           # image side length in pixels
channels = 1            # number of image channels

# ---- Loss ----
lambda_gp = 10          # weight of the gradient-penalty term in the critic loss

# ---- Not referenced by the visible cells ----
# NOTE(review): clip_value and sample_interval are defined but never read in
# the cells shown here; kept so any external use keeps working.
clip_value = 0.01
sample_interval = 400

In [57]:
# Shape of a single image as (channels, height, width); square images, so
# img_size is used for both spatial dimensions.
img_shape = (channels, img_size, img_size)
# Bare expression so the notebook echoes the tuple as the cell output.
img_shape

(1, 28, 28)

In [58]:
# Select the compute device.
# The original run used the second GPU ('cuda:1'); keep that preference when
# it exists, but fall back to the first GPU or the CPU so the notebook still
# runs on machines with fewer than two GPUs.
if torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
# Bare expression so the notebook echoes the chosen device.
device

device(type='cuda', index=1)

In [59]:
class Generator(nn.Module):
    """Fully connected generator mapping a latent vector to an image.

    The network is a stack of Linear -> (BatchNorm1d) -> LeakyReLU(0.2) stages
    with widths 128, 256, 512 and 1024 (the first stage has no batch norm),
    followed by a Linear layer producing ``prod(img_shape)`` values and a Tanh.
    Reads the module-level ``latent_dim`` and ``img_shape``.
    """

    def __init__(self):
        super().__init__()

        hidden_widths = (128, 256, 512, 1024)
        stages = []
        fan_in = latent_dim
        for position, fan_out in enumerate(hidden_widths):
            stages.append(nn.Linear(fan_in, fan_out))
            if position > 0:
                # NOTE(review): the second positional argument of BatchNorm1d
                # is `eps`, so 0.8 is used as epsilon here (spelled out as a
                # keyword for clarity). Confirm whether momentum was intended.
                stages.append(nn.BatchNorm1d(fan_out, eps=0.8))
            stages.append(nn.LeakyReLU(0.2, inplace=True))
            fan_in = fan_out

        # Output head: one value per pixel, squashed to [-1, 1] by Tanh.
        stages.append(nn.Linear(fan_in, int(np.prod(img_shape))))
        stages.append(nn.Tanh())
        self.model = nn.Sequential(*stages)

    def forward(self, z):
        """Return generated images of shape ``(z.shape[0], *img_shape)``."""
        flat = self.model(z)
        return flat.view(flat.size(0), *img_shape)

In [60]:
# Instantiate the generator and smoke-test it on one batch of random latents.
# NOTE: `generator` created here is not throwaway -- it is the instance the
# optimizer and training cells below operate on. `img` is reused by the
# discriminator smoke test.
generator = Generator().to(device)
z = torch.randn((batch_size, latent_dim, )).to(device)
img = generator(z)
# Echo the output shape; expected to be (batch_size, *img_shape).
img.shape

torch.Size([64, 1, 28, 28])

In [61]:
class Discriminator(nn.Module):
    """Fully connected critic that scores a batch of images.

    Images are flattened and passed through Linear(…, 512) -> LeakyReLU(0.2)
    -> Linear(512, 256) -> LeakyReLU(0.2) -> Linear(256, 1). There is no final
    activation, so the output is one unbounded score per image.
    Reads the module-level ``img_shape``.
    """

    def __init__(self):
        super().__init__()

        widths = (int(np.prod(img_shape)), 512, 256)
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.LeakyReLU(0.2, inplace=True))
        # Single-score output head.
        stack.append(nn.Linear(widths[-1], 1))
        self.model = nn.Sequential(*stack)

    def forward(self, img):
        """Return a ``(batch, 1)`` tensor of scores for ``img``."""
        flattened = img.view(img.size(0), -1)
        return self.model(flattened)

In [62]:
# Instantiate the discriminator and smoke-test it on the generated batch `img`
# from the generator cell above.
# NOTE: `discriminator` created here is the instance the optimizer and
# training cells below operate on.
discriminator = Discriminator().to(device)
validity = discriminator(img)
# One score per image is expected: shape (batch_size, 1).
print(validity.shape)

torch.Size([64, 1])


In [63]:
from torchvision import datasets

# Pre-processing: resize to the configured side length, convert to a tensor,
# then normalise with mean 0.5 / std 0.5 (single channel).
mnist_transform = transforms.Compose([
    transforms.Resize(img_size),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

# MNIST training split, downloaded into the current directory if missing.
mnist_train = datasets.MNIST('./', train=True, download=True,
                             transform=mnist_transform)

# Shuffled mini-batches of `batch_size` images.
dataloader = torch.utils.data.DataLoader(mnist_train,
                                         batch_size=batch_size,
                                         shuffle=True)

In [64]:
# Pull a single batch to sanity-check the tensor shapes coming out of the
# dataloader. Use the built-in next(): the iterator's `.next()` method is a
# Python-2-era alias and is not available on current DataLoader iterators.
batch = next(iter(dataloader))
imgs, labels = batch
print(imgs.shape, labels.shape)

torch.Size([64, 1, 28, 28]) torch.Size([64])


In [65]:
# Generator and discriminator each get their own Adam optimizer, configured
# with the same learning rate and beta coefficients.
adam_settings = {'lr': lr, 'betas': (b1, b2)}
optimizer_G = torch.optim.Adam(generator.parameters(), **adam_settings)
optimizer_D = torch.optim.Adam(discriminator.parameters(), **adam_settings)

In [66]:
def compute_gradient_penalty(D, real_samples, fake_samples):
    """Compute the WGAN-GP gradient penalty for critic ``D``.

    Random per-sample interpolations between ``real_samples`` and
    ``fake_samples`` are scored by ``D``; the penalty is the batch mean of
    ``(||dD/dx||_2 - 1) ** 2`` evaluated at those interpolated points.

    Args:
        D: Callable (typically an ``nn.Module``) mapping a batch of samples to
            a tensor of scores.
        real_samples: Batch of real samples, shape ``(N, ...)``.
        fake_samples: Batch of generated samples with the same shape, dtype
            and device as ``real_samples``.

    Returns:
        Scalar tensor holding the penalty. It is built with
        ``create_graph=True`` so it can be back-propagated into ``D``'s
        parameters.
    """
    # The penalty must only train D, so cut any graph attached to the inputs.
    real_samples = real_samples.detach()
    fake_samples = fake_samples.detach()

    n_samples = real_samples.size(0)

    # One random mixing weight per sample, broadcast over all remaining dims.
    # Created directly with the inputs' dtype/device so nothing is silently
    # promoted to float64 and no global `device` is needed.
    alpha_shape = (n_samples,) + (1,) * (real_samples.dim() - 1)
    alpha = torch.rand(alpha_shape,
                       dtype=real_samples.dtype,
                       device=real_samples.device)

    # Random interpolation between real and fake samples.
    interpolates = (alpha * real_samples
                    + (1 - alpha) * fake_samples).requires_grad_(True)
    d_interpolates = D(interpolates)

    # Gradient of the scores w.r.t. the interpolated inputs; ones as
    # grad_outputs sums the per-sample scores.
    gradients = autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,
        retain_graph=True,
    )[0]

    gradients = gradients.reshape(n_samples, -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty

In [None]:
# WGAN-GP training loop.
# The critic (discriminator) is updated on every batch; the generator is
# updated on every n_critic-th batch. Losses and a grid of generated images
# are logged to TensorBoard on each generator update.
writer = SummaryWriter()
print(writer.log_dir)

batches_done = 0
n_batches = len(dataloader)

try:
    for epoch in range(n_epochs):
        for i, (imgs, _) in enumerate(dataloader):
            real_imgs = imgs.to(device)

            # ---- train discriminator ----
            optimizer_D.zero_grad()

            # Sample latent noise directly on the target device.
            z = torch.randn(real_imgs.size(0), latent_dim, device=device)

            # Detach: the critic step must not back-propagate through the
            # generator (those gradients would be discarded anyway).
            fake_imgs = generator(z).detach()

            # Adversarial loss: d_loss gets smaller when the critic's score is
            # low for fake images and high for real images.
            real_validity = discriminator(real_imgs)
            fake_validity = discriminator(fake_imgs)
            gradient_penalty = compute_gradient_penalty(discriminator, real_imgs, fake_imgs)
            d_loss = torch.mean(fake_validity) - torch.mean(real_validity) + lambda_gp * gradient_penalty
            d_loss.backward()
            optimizer_D.step()

            # ---- train the generator every n_critic iterations ----
            if i % n_critic == 0:
                optimizer_G.zero_grad()

                # Regenerate with the graph attached so gradients reach the
                # generator's parameters.
                fake_imgs = generator(z)

                # Maximise the critic's score on generated images.
                fake_validity = discriminator(fake_imgs)
                g_loss = - torch.mean(fake_validity)

                g_loss.backward()
                optimizer_G.step()

                writer.add_scalar('loss_G', g_loss.item(), batches_done)
                writer.add_scalar('loss_D', d_loss.item(), batches_done)

                # `i` is the batch index within the current epoch.
                print('[Epoch {}/{}] [Batch {}/{}] [D_loss: {:.3f}] [G_loss: {:.3f}]'.format(
                    epoch, n_epochs,
                    i, n_batches,
                    d_loss.item(), g_loss.item()))

                # Separate name so the loop variable `imgs` is not overwritten.
                sample_grid = make_grid(fake_imgs.detach()[:25], nrow=5, normalize=True)
                writer.add_image('Generated Images', sample_grid, batches_done)

            batches_done += 1
finally:
    # Flush and close the event file even if training is interrupted.
    writer.close()