In [2]:
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid

In [3]:
n_epochs = 200
batch_size = 64
lr = 0.00005
latent_dim = 100
img_size = 28
channels = 1
n_critic = 5
clip_value = 0.01
sample_interval = 400

In [4]:
img_shape = (channels, img_size, img_size)
img_shape

(1, 28, 28)

In [5]:
device = torch.device('cuda:1')
device

device(type='cuda', index=1)

In [22]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
    
        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers
        
        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.shape[0], *img_shape)
        return img

In [23]:
# TEST
generator = Generator().to(device)
z = torch.randn((batch_size, latent_dim, )).to(device)
img = generator(z)
img.shape

torch.Size([64, 1, 28, 28])

In [24]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
    
        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1)
        )

    def forward(self, img):
        img_flat = img.view(img.shape[0], -1)
        validity = self.model(img_flat)
        return validity

In [25]:
# TEST
discriminator = Discriminator().to(device)
validity = discriminator(img)
print(validity.shape)

torch.Size([64, 1])


In [26]:
from torchvision import datasets

dataloader = torch.utils.data.DataLoader(
    datasets.MNIST('./', train=True, download=True,
                   transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]),),
    batch_size=batch_size,
    shuffle=True)

In [27]:
batch = iter(dataloader).next()
imgs, labels = batch
print(imgs.shape, labels.shape)

torch.Size([64, 1, 28, 28]) torch.Size([64])


In [28]:
optimizer_G = torch.optim.RMSprop(generator.parameters(), lr)
optimizer_D = torch.optim.RMSprop(discriminator.parameters(), lr)

In [29]:
from torch.utils.tensorboard import SummaryWriter


writer = SummaryWriter()
print(writer.log_dir)

batches_done = 0

for epoch in range(n_epochs):
    for i, (imgs, _) in enumerate(dataloader):
        real_imgs = imgs.to(device)
        
        # train discriminator
        optimizer_D.zero_grad()
        
        z = np.random.normal(0, 1, (imgs.shape[0], latent_dim))
        z = torch.from_numpy(z).float().to(device)

        fake_imgs = generator(z).detach()

        # adversarial loss
        # fake_imgs入れたときは最小化してreal_imgsを入れたときは最大化するとloss_Dが小さくなる
        loss_D = torch.mean(discriminator(fake_imgs)) - torch.mean(discriminator(real_imgs))
        loss_D.backward()
        optimizer_D.step()

        # clip weights of discriminator
        for p in discriminator.parameters():
            p.data.clamp_(-clip_value, clip_value)

        # train the generator every n_critic iterations
        if i % n_critic == 0:
            optimizer_G.zero_grad()
            
            gen_imgs = generator(z)
            
            # fake_imgsを入れたときのDの出力を最大化する
            loss_G = - torch.mean(discriminator(gen_imgs))
            
            loss_G.backward()
            optimizer_G.step()
            
            writer.add_scalar('loss_G', loss_G.item(), batches_done)
            writer.add_scalar('loss_D', loss_D.item(), batches_done)

            print('[Epoch {}/{}] [Batch {}/{}] [D_loss: {:.3f}] [G_loss: {:.3f}]'.format(
                epoch, n_epochs,
                batches_done % len(dataloader), len(dataloader),
                loss_D.item(), loss_G.item()))

            imgs = make_grid(gen_imgs.data[:25], nrow=5, normalize=True)
            writer.add_image('Generated Images', imgs, batches_done)
#             save_image(gen_imgs.data[:25], f'images/epoch_{epoch}.png', nrow=5, normalize=True)

        batches_done += 1

[Epoch 48/200] [Batch 240/938] [D_loss: -0.274] [G_loss: -0.148]
[Epoch 48/200] [Batch 245/938] [D_loss: -0.279] [G_loss: -0.169]
[Epoch 48/200] [Batch 250/938] [D_loss: -0.258] [G_loss: -0.134]
[Epoch 48/200] [Batch 255/938] [D_loss: -0.287] [G_loss: -0.210]
[Epoch 48/200] [Batch 260/938] [D_loss: -0.280] [G_loss: -0.161]
[Epoch 48/200] [Batch 265/938] [D_loss: -0.248] [G_loss: -0.149]
[Epoch 48/200] [Batch 270/938] [D_loss: -0.243] [G_loss: -0.165]
[Epoch 48/200] [Batch 275/938] [D_loss: -0.265] [G_loss: -0.250]
[Epoch 48/200] [Batch 280/938] [D_loss: -0.276] [G_loss: -0.172]
[Epoch 48/200] [Batch 285/938] [D_loss: -0.284] [G_loss: -0.125]
[Epoch 48/200] [Batch 290/938] [D_loss: -0.228] [G_loss: -0.282]
[Epoch 48/200] [Batch 295/938] [D_loss: -0.224] [G_loss: -0.205]
[Epoch 48/200] [Batch 300/938] [D_loss: -0.251] [G_loss: -0.218]
[Epoch 48/200] [Batch 305/938] [D_loss: -0.264] [G_loss: -0.235]
[Epoch 48/200] [Batch 310/938] [D_loss: -0.266] [G_loss: -0.244]
[Epoch 48/200] [Batch 315