In [1]:
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid

In [23]:
# Training hyperparameters
n_epochs = 200          # passes over the MNIST training set
batch_size = 100        # images per DataLoader batch
lr = 0.0002             # Adam learning rate (both networks)
b1 = 0.5                # Adam beta1
b2 = 0.999              # Adam beta2
latent_dim = 100        # length of the generator's noise vector
n_classes = 10          # number of conditioning labels (MNIST digits)
img_size = 32           # images are resized to img_size x img_size
channels = 1            # single-channel (grayscale) images
sample_interval = 400   # save/log a sample grid every this many batches

# Seed every RNG the notebook draws from so Restart & Run All is repeatable.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

In [24]:
# (C, H, W) layout shared by every image tensor in this notebook
img_shape = (channels,) + (img_size,) * 2
img_shape

(1, 32, 32)

In [25]:
# Prefer the first GPU, but fall back to CPU so the notebook still runs without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda', index=0)

In [26]:
class Generator(nn.Module):
    """Conditional MLP generator mapping (noise, class label) to an image.

    The label is looked up in a learned ``num_classes x num_classes`` embedding,
    concatenated in front of the noise vector, and fed through an MLP
    (128 -> 256 -> 512 -> 1024 -> prod(out_shape)) that ends in Tanh, so pixel
    values lie in [-1, 1]. The flat output is reshaped to ``out_shape``.

    Args:
        num_classes: number of conditioning labels. Defaults to the notebook's
            global ``n_classes`` (read at construction time).
        noise_dim: length of the noise vector. Defaults to the global ``latent_dim``.
        out_shape: (C, H, W) of the generated images. Defaults to the global
            ``img_shape``.
    """

    def __init__(self, num_classes=None, noise_dim=None, out_shape=None):
        super().__init__()
        # Fall back to the notebook config so a plain ``Generator()`` behaves as before.
        num_classes = n_classes if num_classes is None else num_classes
        noise_dim = latent_dim if noise_dim is None else noise_dim
        self.img_shape = tuple(img_shape if out_shape is None else out_shape)

        self.label_emb = nn.Embedding(num_classes, num_classes)

        def block(in_feat, out_feat, normalize=True):
            """Linear -> optional BatchNorm1d -> LeakyReLU(0.2)."""
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # NOTE: BatchNorm1d's 2nd positional argument is eps, so
                # ``BatchNorm1d(out_feat, 0.8)`` set eps=0.8, not momentum.
                # The 0.8 was presumably a Keras-style momentum (weight of the old
                # running stat); PyTorch's momentum weights the NEW batch stat,
                # so the equivalent is momentum=0.2. eps keeps its default.
                layers.append(nn.BatchNorm1d(out_feat, momentum=0.2))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(noise_dim + num_classes, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(self.img_shape))),
            nn.Tanh()
        )

    def forward(self, noise, labels):
        """Generate one image per (noise, label) pair.

        Args:
            noise: float tensor of shape (batch, noise_dim).
            labels: integer class indices of shape (batch,).

        Returns:
            Tensor of shape (batch, *out_shape) with values in [-1, 1].
        """
        gen_input = torch.cat((self.label_emb(labels), noise), -1)
        img = self.model(gen_input)
        return img.view(img.shape[0], *self.img_shape)

In [27]:
# TEST: one forward pass through a fresh generator
generator = Generator().to(device)
noise = torch.randn(batch_size, latent_dim, device=device)
labels = torch.randint(low=0, high=n_classes, size=(batch_size,), device=device)
print(noise.shape, labels.shape)

# Call the module itself (not .forward) so nn.Module hooks run.
img = generator(noise, labels)
print(img.shape)
assert img.shape == (batch_size, *img_shape)

torch.Size([100, 100]) torch.Size([100])
torch.Size([100, 1, 32, 32])


In [28]:
class Discriminator(nn.Module):
    """Conditional MLP discriminator scoring (image, class label) pairs.

    The flattened image is concatenated with a learned ``num_classes``-dim label
    embedding and fed through an MLP with three 512-unit hidden layers
    (LeakyReLU 0.2; Dropout 0.4 on the second and third). The last Linear has no
    activation, so the output is an unbounded score of shape (batch, 1).

    Args:
        num_classes: number of conditioning labels. Defaults to the notebook's
            global ``n_classes`` (read at construction time).
        in_shape: shape of one input image. Defaults to the global ``img_shape``.
    """

    def __init__(self, num_classes=None, in_shape=None):
        super().__init__()
        # Fall back to the notebook config so a plain ``Discriminator()`` behaves as before.
        num_classes = n_classes if num_classes is None else num_classes
        in_shape = img_shape if in_shape is None else in_shape

        self.label_embedding = nn.Embedding(num_classes, num_classes)

        self.model = nn.Sequential(
            nn.Linear(num_classes + int(np.prod(in_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 512),
            nn.Dropout(0.4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 1)
        )

    def forward(self, img, labels):
        """Score a batch of images against their labels.

        Args:
            img: image tensor whose first dim is the batch; flattened per sample.
            labels: integer class indices of shape (batch,).

        Returns:
            Tensor of shape (batch, 1).
        """
        # reshape (not view) so non-contiguous image tensors are accepted too.
        flat = img.reshape(img.size(0), -1)
        d_in = torch.cat((flat, self.label_embedding(labels)), -1)
        return self.model(d_in)

In [29]:
# TEST: the discriminator returns one score per (image, label) pair
discriminator = Discriminator().to(device)
validity = discriminator(img, labels)
print(validity.shape)
assert validity.shape == (img.shape[0], 1)

torch.Size([100, 1])


In [30]:
from torchvision import datasets

# Resize to img_size, convert to a [0, 1] tensor, then map to [-1, 1]
# (the same range as the generator's Tanh output).
mnist_transform = transforms.Compose([
    transforms.Resize(img_size),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])
mnist_train = datasets.MNIST('./', train=True, download=True, transform=mnist_transform)
dataloader = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True)

In [31]:
# TEST: pull one batch. Use the builtin next(); DataLoader iterators in
# current PyTorch releases no longer provide a .next() method.
batch = next(iter(dataloader))
imgs, labels = batch
print(imgs.shape, labels.shape)
assert imgs.shape[1:] == img_shape

torch.Size([100, 1, 32, 32]) torch.Size([100])


In [32]:
import os
os.makedirs('images', exist_ok=True)

def sample_image(n_row, batches_done):
    """Generate an n_row x n_row grid of conditional samples and save it.

    Every grid row holds the labels 0, 1, ..., n_row - 1 (taken modulo
    n_classes, so n_row > n_classes no longer indexes past the label
    embedding). The grid is written to ``images/<batches_done>.png``.

    Args:
        n_row: grid side length; n_row ** 2 images are generated.
        batches_done: step number used as the output file name.

    Returns:
        Tensor of shape (n_row ** 2, *img_shape), not attached to any autograd graph.
    """
    z = torch.randn(n_row ** 2, latent_dim, device=device)
    labels = (torch.arange(n_row, device=device) % n_classes).repeat(n_row)
    # Sampling only: don't build an autograd graph. The generator is used in
    # whatever mode it is currently in (in train mode BatchNorm uses batch stats).
    with torch.no_grad():
        gen_imgs = generator(z, labels)
    save_image(gen_imgs, 'images/%d.png' % batches_done, nrow=n_row, normalize=True)
    return gen_imgs

In [33]:
# TEST: an 8x8 grid yields 64 images (also written to images/5.png)
imgs = sample_image(8, 5)
assert imgs.shape == (8 * 8, *img_shape)
imgs.shape

torch.Size([64, 1, 32, 32])

In [34]:
# Fresh, untrained networks for the real run (the TEST cells above reused these names).
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# MSE against real/fake targets (a least-squares-GAN-style objective).
adversarial_loss = nn.MSELoss()

In [35]:
# One Adam optimizer per network; each steps only its own network's parameters.
optimizer_G, optimizer_D = (
    torch.optim.Adam(net.parameters(), lr=lr, betas=(b1, b2))
    for net in (generator, discriminator)
)

In [None]:
from torch.utils.tensorboard import SummaryWriter


writer = SummaryWriter()
print(writer.log_dir)

batches_done = 0

try:
    for epoch in range(n_epochs):
        for i, (real_imgs, labels) in enumerate(dataloader):
            # Global step, computed first so this batch's losses are logged at
            # this batch's own step.
            batches_done = epoch * len(dataloader) + i
            # Size of this batch (the last one may be smaller). Kept local so the
            # global `batch_size` config is not overwritten.
            cur_batch_size = real_imgs.shape[0]

            real_imgs = real_imgs.to(device)
            labels = labels.to(device)

            # Adversarial ground truths, created directly on the target device.
            real = torch.ones(cur_batch_size, 1, device=device)
            fake = torch.zeros(cur_batch_size, 1, device=device)

            # ---- Train generator ----
            optimizer_G.zero_grad()

            z = torch.randn(cur_batch_size, latent_dim, device=device)
            gen_labels = torch.randint(0, n_classes, (cur_batch_size,), device=device)

            gen_imgs = generator(z, gen_labels)

            # The generator improves the more D scores its samples as real.
            validity = discriminator(gen_imgs, gen_labels)
            g_loss = adversarial_loss(validity, real)

            g_loss.backward()
            optimizer_G.step()

            # ---- Train discriminator ----
            optimizer_D.zero_grad()

            # D should score real images as real and generated images as fake.
            validity_real = discriminator(real_imgs, labels)
            d_real_loss = adversarial_loss(validity_real, real)

            # detach(): the D update must not backpropagate into the generator.
            validity_fake = discriminator(gen_imgs.detach(), gen_labels)
            d_fake_loss = adversarial_loss(validity_fake, fake)

            d_loss = (d_real_loss + d_fake_loss) / 2

            d_loss.backward()
            optimizer_D.step()

            writer.add_scalar('loss_G', g_loss.item(), batches_done)
            writer.add_scalar('loss_D', d_loss.item(), batches_done)

            if batches_done % sample_interval == 0:
                # One grid column per class.
                gen_imgs = sample_image(n_classes, batches_done)
                imgs = make_grid(gen_imgs, nrow=n_classes, normalize=True)
                writer.add_image('Generated Images', imgs, batches_done)
finally:
    # Flush pending TensorBoard events even if training is interrupted.
    writer.close()

runs/Nov18_21-25-49_dlgdev0001
