In [1]:
import torch
import torchvision
import os
import random
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.utils as vutils

In [2]:
# --- Reproducibility --------------------------------------------------------
# Seed every RNG the notebook can touch so that Restart & Run All draws the
# same noise / label samples and the same weight initialisation each time.
# (GPU kernels may still introduce small nondeterminism.)
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# --- Hyperparameters --------------------------------------------------------
workers = 2                # passed to DataLoader as num_workers
batch_size = 50            # mini-batch size (also the number of fixed preview samples)
nz = 100                   # length of the latent noise vector
nch_g = 128                # base channel width handed to Generator
nch_d = 128                # base channel width handed to Discriminator
n_epoch = 15               # number of passes over the training set
lr = 0.0002                # Adam learning rate (used for both optimizers)
beta1 = 0.5                # first Adam beta (used for both optimizers)
outf = './result_cgan'     # directory for sample images and checkpoints
display_interval = 600     # log a status line every this many iterations

In [3]:
# Prefer the first CUDA device when one is visible; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print('device:', device)

device: cuda:0


### MNIST

In [4]:
# Preprocessing: convert to a tensor, then normalise the single channel with
# mean 0.5 / std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# MNIST training split, downloaded into ./data on first use.
dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# Shuffled mini-batches of `batch_size` images loaded by `workers` processes.
dataloader = torch.utils.data.DataLoader(
    dataset,
    shuffle=True,
    batch_size=batch_size,
    num_workers=int(workers),
)

In [5]:
class Generator(nn.Module):
    """Transposed-convolution generator.

    Args:
        nz: number of input channels (noise, plus any label channels the
            caller concatenates).
        nch_g: base channel width; the hidden layers use 4x, 2x and 1x of it.
        nch: number of channels in the generated image.

    A (B, nz, 1, 1) input is upsampled 1 -> 3 -> 7 -> 14 -> 28 pixels per side,
    and the final Tanh bounds the output to [-1, 1].
    """

    def __init__(self, nz=100, nch_g=64, nch=1):
        super().__init__()

        def up_block(c_in, c_out, kernel, stride, pad):
            # One upsampling stage: transposed conv -> batch norm -> ReLU.
            return [
                nn.ConvTranspose2d(c_in, c_out, kernel, stride, pad),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ]

        wide, mid = nch_g * 4, nch_g * 2

        stack = []
        stack += up_block(nz, wide, 3, 1, 0)
        stack += up_block(wide, mid, 3, 2, 0)
        stack += up_block(mid, nch_g, 4, 2, 1)
        # Output stage: no batch norm, Tanh instead of ReLU.
        stack += [nn.ConvTranspose2d(nch_g, nch, 4, 2, 1), nn.Tanh()]

        self.layers = nn.Sequential(*stack)

    def forward(self, z):
        """Run `z` through the layer stack and return the generated images."""
        generated = self.layers(z)
        return generated

In [6]:
class Discriminator(nn.Module):
    """Convolutional discriminator that scores images as real (1) or fake (0).

    Args:
        nch: number of input channels (image, plus any label channels the
            caller concatenates).
        nch_d: base channel width; the hidden layers use 1x, 2x and 4x of it.

    For 28x28 inputs the spatial size shrinks 28 -> 14 -> 7 -> 3 -> 1, so each
    sample ends up as a single Sigmoid probability.
    """

    def __init__(self, nch=1, nch_d=64):
        super().__init__()
        self.layers = nn.Sequential(
                nn.Conv2d(nch, nch_d, 4, 2, 1),     # convolution
                nn.LeakyReLU(negative_slope=0.2),
                nn.Conv2d(nch_d, nch_d * 2, 4, 2, 1),
                nn.BatchNorm2d(nch_d * 2),
                nn.LeakyReLU(negative_slope=0.2),
                nn.Conv2d(nch_d * 2, nch_d * 4, 3, 2, 0),
                nn.BatchNorm2d(nch_d * 4),
                nn.LeakyReLU(negative_slope=0.2),
                nn.Conv2d(nch_d * 4, 1, 3, 1, 0),
                nn.Sigmoid()
                )

    def forward(self, x):
        """Return per-sample probabilities with the batch axis preserved.

        Singleton channel/spatial axes are removed, so a (B, nch, 28, 28)
        input yields a (B,) tensor. A bare ``squeeze()`` would also drop the
        batch axis when B == 1 and return a 0-d tensor, which no longer
        matches a (B,)-shaped loss target; the batch axis is therefore always
        kept.
        """
        out = self.layers(x)
        # Drop size-1 axes after the batch axis only.
        kept = [d for d in out.shape[1:] if d != 1]
        return out.reshape(out.size(0), *kept)

In [7]:
# Both networks get 10 extra input channels: the one-hot class label is
# concatenated to the noise (generator) and to the image (discriminator).
Gnet = Generator(nz=nz + 10, nch_g=nch_g)
Gnet = Gnet.to(device)

Dnet = Discriminator(nch=1 + 10, nch_d=nch_d)
Dnet = Dnet.to(device)

# Show both architectures, generator first.
print(Gnet, Dnet, sep='\n')

Generator(
  (layers): Sequential(
    (0): ConvTranspose2d(110, 512, kernel_size=(3, 3), stride=(1, 1))
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): ConvTranspose2d(512, 256, kernel_size=(3, 3), stride=(2, 2))
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU()
    (9): ConvTranspose2d(128, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (10): Tanh()
  )
)
Discriminator(
  (layers): Sequential(
    (0): Conv2d(11, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2)
    (2): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats

In [8]:
# Binary cross-entropy on the discriminator's sigmoid output.
criterion = nn.BCELoss()

# One Adam optimizer per network, with identical settings.
optimizerD = optim.Adam(
    Dnet.parameters(),
    lr=lr,
    betas=(beta1, 0.999),
    weight_decay=1e-5,
)
optimizerG = optim.Adam(
    Gnet.parameters(),
    lr=lr,
    betas=(beta1, 0.999),
    weight_decay=1e-5,
)

In [9]:
def onehot_encode(label, device, n_class = 10):
    """One-hot encode `label` as a (N, n_class, 1, 1) float tensor.

    Args:
        label: class index or tensor of class indices used to pick rows of an
            identity matrix (a single int yields N == 1).
        device: device on which the identity matrix is created.
        n_class: number of classes, i.e. length of each one-hot vector.
    """
    eye = torch.eye(n_class, device = device)
    # Row `k` of the identity matrix is the one-hot vector for class `k`.
    return eye[label].view(-1, n_class, 1, 1)

def concat_image_label(image, label, device, n_class=10):
    """Append the one-hot label to `image` as `n_class` constant channels.

    Args:
        image: tensor of shape (B, C, H, W).
        label: class indices for the batch.
        device: device on which the one-hot tensor is built.
        n_class: number of classes (number of channels appended).

    Returns:
        Tensor of shape (B, C + n_class, H, W).
    """
    B, _, H, W = image.shape

    # Forward n_class so the one-hot width always agrees with the expand below.
    oh_label = onehot_encode(label, device, n_class)
    # Broadcast each one-hot entry over the full spatial extent of the image.
    oh_label = oh_label.expand(B, n_class, H, W)
    return torch.cat((image, oh_label), dim = 1)

def concat_noise_label(noise, label, device, n_class=10):
    """Append the one-hot label to `noise` along the channel axis.

    Args:
        noise: tensor of shape (B, nz, 1, 1).
        label: class indices for the batch.
        device: device on which the one-hot tensor is built.
        n_class: number of classes (number of channels appended).

    Returns:
        Tensor of shape (B, nz + n_class, 1, 1).
    """
    oh_label = onehot_encode(label, device, n_class)
    return torch.cat((noise, oh_label), dim = 1)

In [10]:
# Fixed generator input, reused after every epoch so that the saved sample
# grids are directly comparable across epochs.
fixed_noise = torch.randn(batch_size, nz, 1, 1, device=device)

# Labels cycle 0, 1, ..., 9, 0, 1, ... with exactly one label per noise vector.
# Deriving them from `batch_size` keeps the two tensors the same length even
# when batch_size is not a multiple of 10 (a repeated 10-element list would
# come up short and make the concatenation below fail).
fixed_label = torch.arange(batch_size, dtype=torch.long, device=device) % 10

fixed_noise_label = concat_noise_label(fixed_noise, fixed_label, device)
print(fixed_noise.shape)
print(fixed_label)
print(fixed_noise_label.shape)

torch.Size([50, 100, 1, 1])
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3,
        4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7,
        8, 9], device='cuda:0')
torch.Size([50, 110, 1, 1])


In [11]:
# Create the output directory (and any missing parents); a no-op when it
# already exists, with no check-then-create race.
os.makedirs(outf, exist_ok=True)

In [12]:
# Training loop
for epoch in range(n_epoch):
    for itr, data in enumerate(dataloader):
        # Real images and their class labels, with the label appended as
        # extra image channels.
        real_image = data[0].to(device)
        real_label = data[1].to(device)
        real_image_label = concat_image_label(real_image, real_label, device) 
        sample_size = real_image.size(0)

        # Generator input: random noise plus a randomly drawn class label.
        noise = torch.randn(sample_size, nz, 1, 1, device = device)
        fake_label = torch.randint(10, (sample_size,), dtype = torch.long, device = device)
        fake_noise_label = concat_noise_label(noise, fake_label, device)        

        # BCE targets: 1 for "real", 0 for "fake".
        real_target = torch.full((sample_size,), 1., device = device)
        fake_target = torch.full((sample_size,), 0., device = device)
        
        ############################
        # Update the discriminator D
        ###########################
        Dnet.zero_grad()    
        
        # D on real images should output 1.
        output = Dnet(real_image_label)
        errD_real = criterion(output, real_target)

        D_x = output.mean().item()

        fake_image = Gnet(fake_noise_label)
        fake_image_label = concat_image_label(fake_image, fake_label, device)   
        
        # detach(): D's loss must not back-propagate into the generator.
        output = Dnet(fake_image_label.detach()) 
        errD_fake = criterion(output, fake_target)  
        D_G_z1 = output.mean().item()

        errD = errD_real + errD_fake
        errD.backward() 
        optimizerD.step() 

        ############################
        # Update the generator G
        ###########################
        Gnet.zero_grad()
        
        # Score the same fakes with the just-updated D; G is trained so that
        # D labels them as real.
        output = Dnet(fake_image_label)
        errG = criterion(output, real_target) 
        errG.backward() 
        D_G_z2 = output.mean().item()

        optimizerG.step() 

        if itr % display_interval == 0: 
            print('[{}/{}][{}/{}] Loss_D: {:.3f} Loss_G: {:.3f} D(x): {:.3f} D(G(z)): {:.3f}/{:.3f}'
                  .format(epoch + 1, n_epoch,
                          itr + 1, len(dataloader),
                          errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
            
        # Save one batch of real images once, as a visual reference.
        if epoch == 0 and itr == 0: 
            vutils.save_image(real_image, '{}/real_samples.png'.format(outf),
                              normalize=True, nrow=10)

    ############################
    # Generate preview images
    ############################
    # Pure inference: no autograd graph is needed for the preview samples.
    with torch.no_grad():
        fake_image = Gnet(fixed_noise_label)  
    vutils.save_image(fake_image.detach(), '{}/fake_samples_epoch_{:03d}.png'.format(outf, epoch + 1),
                      normalize=True, nrow=10)

    ############################
    # Save the models
    ############################
    # Checkpoint every 10 epochs and always after the last epoch, so the final
    # weights are kept even when n_epoch is not a multiple of 10.
    if (epoch + 1) % 10 == 0 or (epoch + 1) == n_epoch:   
        torch.save(Gnet.state_dict(), '{}/Gnet_epoch_{}.pth'.format(outf, epoch + 1))
        torch.save(Dnet.state_dict(), '{}/Dnet_epoch_{}.pth'.format(outf, epoch + 1))

[1/15][1/1200] Loss_D: 1.480 Loss_G: 1.252 D(x): 0.491 D(G(z)): 0.522/0.296
[1/15][601/1200] Loss_D: 0.641 Loss_G: 2.058 D(x): 0.688 D(G(z)): 0.182/0.166
[2/15][1/1200] Loss_D: 1.181 Loss_G: 1.375 D(x): 0.489 D(G(z)): 0.306/0.281
[2/15][601/1200] Loss_D: 1.651 Loss_G: 1.158 D(x): 0.694 D(G(z)): 0.686/0.341
[3/15][1/1200] Loss_D: 0.920 Loss_G: 1.424 D(x): 0.552 D(G(z)): 0.237/0.267
[3/15][601/1200] Loss_D: 0.858 Loss_G: 1.675 D(x): 0.573 D(G(z)): 0.218/0.212
[4/15][1/1200] Loss_D: 1.210 Loss_G: 1.502 D(x): 0.799 D(G(z)): 0.600/0.237
[4/15][601/1200] Loss_D: 0.909 Loss_G: 1.070 D(x): 0.662 D(G(z)): 0.361/0.358
[5/15][1/1200] Loss_D: 1.420 Loss_G: 0.715 D(x): 0.368 D(G(z)): 0.287/0.514
[5/15][601/1200] Loss_D: 1.936 Loss_G: 0.984 D(x): 0.367 D(G(z)): 0.555/0.400
[6/15][1/1200] Loss_D: 0.751 Loss_G: 1.018 D(x): 0.659 D(G(z)): 0.259/0.387
[6/15][601/1200] Loss_D: 0.713 Loss_G: 1.758 D(x): 0.641 D(G(z)): 0.199/0.197
[7/15][1/1200] Loss_D: 1.143 Loss_G: 1.450 D(x): 0.724 D(G(z)): 0.506/0.260


In [17]:
# Directory for the individual interpolation frames; a no-op when it already
# exists, with no check-then-create race.
os.makedirs("./grad_img", exist_ok=True)
from torchvision.utils import save_image
import glob
from PIL import Image

def concat_noise_mixed_label(noise, label1, label2, weight, device):
    """Append a blend of two one-hot labels to `noise` along the channel axis.

    The label part is ``weight * onehot(label1) + (1 - weight) * onehot(label2)``,
    so weight == 0 is purely `label2` and weight == 1 is purely `label1`.

    This helper has its own name on purpose: reusing `concat_noise_label`
    would shadow the three-argument helper that the training cells call and
    break them when they are re-run after this cell.
    """
    oh_label1 = onehot_encode(label1, device)
    oh_label2 = onehot_encode(label2, device)
    oh_label = weight * oh_label1 + (1 - weight) * oh_label2
    return torch.cat((noise, oh_label), dim = 1) 
    
gradation = 100     # number of interpolation frames
noise = torch.randn(1, nz, 1, 1, device = device)

# Single-sample inference: run in eval mode so BatchNorm uses its running
# statistics instead of statistics of one sample (and does not overwrite the
# running statistics), then restore whatever mode the generator was in.
was_training = Gnet.training
Gnet.eval()
files = []
try:
    with torch.no_grad():
        for i in range(gradation):
            weight = i / gradation
            fake_noise_label = concat_noise_mixed_label(noise, 9, 6, weight, device)
            out = Gnet(fake_noise_label)
            path = "./grad_img/{:03d}.jpg".format(i)
            # The generator ends in Tanh, so map [-1, 1] -> [0, 1] before
            # saving; otherwise every negative value is clipped to black.
            save_image((out + 1) / 2, path)
            files.append(path)
finally:
    Gnet.train(was_training)

# Build the GIF from exactly the frames written above (a directory glob could
# pick up stale frames from an earlier run), closing each file after loading.
images = []
for path in files:
    with Image.open(path) as frame:
        images.append(frame.copy())
images[0].save('generating_process.gif' , save_all = True , append_images = images[1:] , duration = 100 , loop = 0)