<a href="https://colab.research.google.com/github/TakuroTerui/object_gan_autoencoder_image/blob/master/ConditionalGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms

dataset = datasets.MNIST(
    root='/content/drive/MyDrive/Colab Notebooks/GAN/C-GAN/C-GAN_PyTorch/mnist',
    download=True,
    train=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
)

batch_size = 50
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('device:', device)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /content/drive/MyDrive/Colab Notebooks/GAN/C-GAN/C-GAN_PyTorch/mnist/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 341108851.40it/s]

Extracting /content/drive/MyDrive/Colab Notebooks/GAN/C-GAN/C-GAN_PyTorch/mnist/MNIST/raw/train-images-idx3-ubyte.gz to /content/drive/MyDrive/Colab Notebooks/GAN/C-GAN/C-GAN_PyTorch/mnist/MNIST/raw






Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /content/drive/MyDrive/Colab Notebooks/GAN/C-GAN/C-GAN_PyTorch/mnist/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 88162804.82it/s]


Extracting /content/drive/MyDrive/Colab Notebooks/GAN/C-GAN/C-GAN_PyTorch/mnist/MNIST/raw/train-labels-idx1-ubyte.gz to /content/drive/MyDrive/Colab Notebooks/GAN/C-GAN/C-GAN_PyTorch/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /content/drive/MyDrive/Colab Notebooks/GAN/C-GAN/C-GAN_PyTorch/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 83999021.01it/s]

Extracting /content/drive/MyDrive/Colab Notebooks/GAN/C-GAN/C-GAN_PyTorch/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to /content/drive/MyDrive/Colab Notebooks/GAN/C-GAN/C-GAN_PyTorch/mnist/MNIST/raw






Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /content/drive/MyDrive/Colab Notebooks/GAN/C-GAN/C-GAN_PyTorch/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 17445539.16it/s]


Extracting /content/drive/MyDrive/Colab Notebooks/GAN/C-GAN/C-GAN_PyTorch/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to /content/drive/MyDrive/Colab Notebooks/GAN/C-GAN/C-GAN_PyTorch/mnist/MNIST/raw

device: cuda


In [2]:
import torch.nn as nn

class Discriminator(nn.Module):
    '''識別器のクラス
    Attributes:
        layers: Sequentialオブジェクトのリスト
    '''
    def __init__(self):
        '''識別器のネットワークを構築する
        '''
        super(Discriminator, self).__init__()

        in_ch = 1+10
        start_ch = 128

        self.layers = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(
                    in_ch,
                    start_ch,
                    4,
                    2,
                    1
                ),
                nn.LeakyReLU(negative_slope=0.2)
            ),
            nn.Sequential(
                nn.Conv2d(
                    start_ch,
                    start_ch * 2,
                    4,
                    2,
                    1
                ),
                nn.BatchNorm2d(start_ch * 2),
                nn.LeakyReLU(negative_slope=0.2)
            ),
            nn.Sequential(
                nn.Conv2d(
                    start_ch * 2,
                    start_ch * 4,
                    3,
                    2,
                    0
                ),
                nn.BatchNorm2d(start_ch * 4),
                nn.LeakyReLU(negative_slope=0.2)
            ),
            nn.Sequential(
                nn.Conv2d(
                    start_ch * 4,
                    1,
                    3,
                    1,
                    0
                ),
                nn.Sigmoid()
            ),
        ])
    
    def forward(self, x):
        '''順伝播処理
        Parameter:
            x: 画像データまたは生成画像
        '''
        for layer in self.layers:
            x = layer(x)
        
        return x.squeeze()

In [3]:
class Generator(nn.Module):
    '''生成器のクラス
    Attributes:
        layers: Sequentialオブジェクトのリスト
    '''
    def __init__(self):
        '''生成器のネットワークを構築する
        '''
        super(Generator, self).__init__()

        input_dim = 100+10
        out_ch = 128
        img_ch = 1

        self.layers = nn.ModuleList([
            nn.Sequential(
                nn.ConvTranspose2d(
                    input_dim,
                    out_ch * 4,
                    3,
                    1,
                    0
                ),
                nn.BatchNorm2d(out_ch * 4),
                nn.ReLU()
            ),
            nn.Sequential(
                nn.ConvTranspose2d(
                    out_ch * 4,
                    out_ch * 2,
                    3,
                    2,
                    0
                ),
                nn.BatchNorm2d(out_ch * 2),
                nn.ReLU()
            ),
            nn.Sequential(
                nn.ConvTranspose2d(
                    out_ch * 2,
                    out_ch ,
                    4,
                    2,
                    1
                ),
                nn.BatchNorm2d(out_ch),
                nn.ReLU()
            ),
            nn.Sequential(
                nn.ConvTranspose2d(
                    out_ch,
                    img_ch,
                    4,
                    2,
                    1
                ),
                nn.Tanh()
            ),
        ])
    
    def forward(self, z):
        '''順伝播処理
        Parameter:
            z: 識別器の出力
        '''
        for layer in self.layers:
            z = layer(z)
        
        return z

In [4]:
def weights_init(m):
    '''
    DCGANの論文では重みを正規分布からサンプリングした値で初期化している
    Parameters:
        m: ネットワークのインスタンス
    '''
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
        m.bias.data.fill_(0)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

In [5]:
import torchsummary

generator = Generator().to(device)
generator.apply(weights_init)
torchsummary.summary(generator, (110, 1, 1))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
   ConvTranspose2d-1            [-1, 512, 3, 3]         507,392
       BatchNorm2d-2            [-1, 512, 3, 3]           1,024
              ReLU-3            [-1, 512, 3, 3]               0
   ConvTranspose2d-4            [-1, 256, 7, 7]       1,179,904
       BatchNorm2d-5            [-1, 256, 7, 7]             512
              ReLU-6            [-1, 256, 7, 7]               0
   ConvTranspose2d-7          [-1, 128, 14, 14]         524,416
       BatchNorm2d-8          [-1, 128, 14, 14]             256
              ReLU-9          [-1, 128, 14, 14]               0
  ConvTranspose2d-10            [-1, 1, 28, 28]           2,049
             Tanh-11            [-1, 1, 28, 28]               0
Total params: 2,215,553
Trainable params: 2,215,553
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forw

In [6]:
discriminator = Discriminator().to(device)
discriminator.apply(weights_init)
torchsummary.summary(discriminator, (11, 28, 28))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 128, 14, 14]          22,656
         LeakyReLU-2          [-1, 128, 14, 14]               0
            Conv2d-3            [-1, 256, 7, 7]         524,544
       BatchNorm2d-4            [-1, 256, 7, 7]             512
         LeakyReLU-5            [-1, 256, 7, 7]               0
            Conv2d-6            [-1, 512, 3, 3]       1,180,160
       BatchNorm2d-7            [-1, 512, 3, 3]           1,024
         LeakyReLU-8            [-1, 512, 3, 3]               0
            Conv2d-9              [-1, 1, 1, 1]           4,609
          Sigmoid-10              [-1, 1, 1, 1]               0
Total params: 1,733,505
Trainable params: 1,733,505
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.03
Forward/backward pass size (MB): 0.78
Params size (MB): 6.61
Estimat

In [7]:
import torch.optim as optim

criterion = nn.BCELoss()
optimizer_ds = optim.Adam(
    discriminator.parameters(),
    lr=0.0002
)

optimizer_gn = optim.Adam(
    generator.parameters(),
    lr=0.0003
)

In [8]:
def encoder(label, device, n_class=10):
    '''正解ラベルをone-hot表現に変換する
    Parameters:
        label: 変換対象の正解ラベル
        device: 使用するデバイス
        n_class: 分類先のクラス数
    '''
    one_hot = torch.eye(n_class, device=device)
    return one_hot[label].view(-1, n_class, 1, 1)

In [16]:
def concat_img_label(image, label, device, n_class=10):
    '''画像のテンソルとラベルのテンソルを連結して
    　　　識別器に入力するテンソルを作成する
    Parameters:
        image: 画像データを格納したテンソル(bs, 1, 28, 28)
        label: 正解ラベル
        device: 使用可能なデバイス
        n_class: 分類先のデバイス
    Return:
        画像とOne-hot化ラベルを結合したテンソル
        (bs, 11, 28, 28)
    '''
    bs, ch, h, w = image.shape
    oh_label = encoder(label, device)
    oh_label = oh_label.expand(bs, n_class, h, w)
    return torch.cat((image, oh_label), dim=1)

In [10]:
def concat_noise_label(noise, label, device):
    '''ノイズのテンソルとラベルのテンソルを連結して
      生成器に入力するテンソルを作成する
    Parameters:
        noise(Tensor): ノイズのテンソル（bs, 100, 1, 1）
        label(int): 正解ラベル
        device: 使用するデバイス
    Return:
        ノイズとOne-hot化ラベルを連結したテンソル
        （bs, 110, 1, 1）
    '''
    oh_label = encoder(label, device)
    return torch.cat((noise, oh_label), dim=1)

In [11]:
noise_num = 100

fixed_noise = torch.randn(
    batch_size,
    noise_num,
    1,
    1,
    device=device
)
fixed_label = [i for i in range(10)] * (batch_size // 10)
fixed_label = torch.tensor(fixed_label, dtype=torch.long, device=device)
fixed_noise_label = concat_noise_label(fixed_noise, fixed_label, device)

In [28]:
%%time
import torchvision.utils as vutils

n_epoch = 10

outf = '/content/drive/MyDrive/Colab Notebooks/GAN/C-GAN/C-GAN_PyTorch/result'

for epoch in range(n_epoch):
    print('Epoch {}/{}'.format(epoch+1, n_epoch))

    for itr, data in enumerate(dataloader):
        real_image = data[0].to(device)
        real_label = data[1].to(device)
        real_image_label = concat_img_label(real_image, real_label, device)
        noise = torch.randn(batch_size, noise_num, 1, 1, device=device)
        fake_label = torch.randint(10, (batch_size,), dtype=torch.long, device=device)
        fake_noise_label = concat_noise_label(noise, fake_label, device)
        real_target = torch.full((batch_size,), 1., device=device)
        fake_target = torch.full((batch_size,), 0., device=device)
        
        discriminator.zero_grad()
        output = discriminator(real_image_label)
        ds_real_err = criterion(output, real_target)
        true_dsout_mean = output.mean().item()

        fake_image = generator(fake_noise_label)
        fake_image_label = concat_img_label(fake_image, fake_label, device)

        output = discriminator(fake_image_label.detach())
        ds_fake_err = criterion(output, fake_target)
        fake_dsout_mean1 = output.mean().item()
        ds_err = ds_real_err + ds_fake_err

        ds_err.backward()

        optimizer_ds.step()

        generator.zero_grad()
        output = discriminator(fake_image_label)
        gn_err = criterion(output, real_target)
        gn_err.backward()
        fake_dsout_mean2 = output.mean().item()
        optimizer_gn.step()

        if itr % 100 == 0:
            print('({}/{}) ds_loss: {:.3f} - gn_loss: {:.3f} - true_out: {:.3f} - fake_out: {:.3f}>>{:.3f}'.format(
                itr+1,
                len(dataloader),
                ds_err.item(),
                gn_err.item(),
                true_dsout_mean,
                fake_dsout_mean1,
                fake_dsout_mean2
            ))

        if epoch == 0 and itr == 0:
            vutils.save_image(
                real_image,
                '{}/real_samples.png'.format(outf),
                normalize=True,
                nrow=10)
        
        fake_image = generator(fixed_noise_label)
        vutils.save_image(
            fake_image.detach(),
            '{}/generated_epoch_{:03d}.png'.format(outf, epoch+1),
            normalize=True,
            nrow=10
        )

Epoch 1/10
(1/1200) ds_loss: 0.084 - gn_loss: 5.758 - true_out: 0.958 - fake_out: 0.008>>0.004
(101/1200) ds_loss: 0.002 - gn_loss: 8.947 - true_out: 0.999 - fake_out: 0.001>>0.001
(201/1200) ds_loss: 0.013 - gn_loss: 6.504 - true_out: 0.993 - fake_out: 0.005>>0.003
(301/1200) ds_loss: 0.030 - gn_loss: 5.218 - true_out: 0.983 - fake_out: 0.011>>0.007
(401/1200) ds_loss: 0.031 - gn_loss: 6.392 - true_out: 0.996 - fake_out: 0.025>>0.006
(501/1200) ds_loss: 0.024 - gn_loss: 6.846 - true_out: 0.997 - fake_out: 0.020>>0.005
(601/1200) ds_loss: 0.196 - gn_loss: 6.175 - true_out: 0.938 - fake_out: 0.075>>0.016
(701/1200) ds_loss: 0.155 - gn_loss: 6.647 - true_out: 0.909 - fake_out: 0.010>>0.007
(801/1200) ds_loss: 0.171 - gn_loss: 7.387 - true_out: 0.933 - fake_out: 0.013>>0.004
(901/1200) ds_loss: 0.083 - gn_loss: 6.774 - true_out: 0.950 - fake_out: 0.014>>0.017
(1001/1200) ds_loss: 0.189 - gn_loss: 5.956 - true_out: 0.892 - fake_out: 0.008>>0.022
(1101/1200) ds_loss: 0.073 - gn_loss: 6.459 