<a href="https://colab.research.google.com/github/TakuroTerui/object_gan_autoencoder_image/blob/master/DCGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms

# ToTensor() scales pixels to [0, 1]; Normalize(mean=0.5, std=0.5) then maps
# them to [-1, 1], the same range as the generator's Tanh output.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# MNIST training split, downloaded into the Drive folder if not already present.
dataset = datasets.MNIST(
    root='/content/drive/MyDrive/Colab Notebooks/GAN/DCGAN_PyTorch/mnist',
    train=True,
    download=True,
    transform=mnist_transform,
)

batch_size = 50
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Prefer the GPU when one is available.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
print('device:', device)

device: cpu


In [11]:
import torch.nn as nn

class Discriminator(nn.Module):
    '''DCGAN discriminator for 28x28 images.

    Four convolution stages reduce an image batch to one Sigmoid score per
    image, i.e. the estimated probability that the image is real.

    Attributes:
        layers: nn.ModuleList of nn.Sequential stages, applied in order.
    '''
    def __init__(self, in_ch=1, start_ch=128):
        '''Build the discriminator network.

        Parameters:
            in_ch: number of channels of the input image (1 for MNIST).
            start_ch: channel width of the first conv stage; the following
                stages use 2x and 4x this value.
        '''
        super(Discriminator, self).__init__()

        self.layers = nn.ModuleList([
            # 28x28 -> 14x14 (kernel 4, stride 2, padding 1); no BatchNorm on
            # the first stage.
            nn.Sequential(
                nn.Conv2d(in_ch, start_ch, 4, 2, 1),
                nn.LeakyReLU(negative_slope=0.2)
            ),
            # 14x14 -> 7x7 (kernel 4, stride 2, padding 1)
            nn.Sequential(
                nn.Conv2d(start_ch, start_ch * 2, 4, 2, 1),
                nn.BatchNorm2d(start_ch * 2),
                nn.LeakyReLU(negative_slope=0.2)
            ),
            # 7x7 -> 3x3 (kernel 3, stride 2, no padding)
            nn.Sequential(
                nn.Conv2d(start_ch * 2, start_ch * 4, 3, 2, 0),
                nn.BatchNorm2d(start_ch * 4),
                nn.LeakyReLU(negative_slope=0.2)
            ),
            # 3x3 -> 1x1 single score, squashed into [0, 1]
            nn.Sequential(
                nn.Conv2d(start_ch * 4, 1, 3, 1, 0),
                nn.Sigmoid()
            ),
        ])

    def forward(self, x):
        '''Forward pass.

        Parameters:
            x: batch of real or generated images, expected shape
                (N, in_ch, 28, 28).
        Returns:
            Tensor of shape (N,) holding one probability per image.
        '''
        for layer in self.layers:
            x = layer(x)

        # The final map is (N, 1, 1, 1). Squeeze only the channel/spatial dims:
        # a bare .squeeze() would also drop the batch dim when N == 1 and
        # return a 0-d tensor that no longer matches (N,)-shaped BCE targets.
        return x.squeeze(3).squeeze(2).squeeze(1)

In [20]:
class Generator(nn.Module):
    '''DCGAN generator: maps a latent noise batch to 28x28 images.

    Attributes:
        layers: nn.ModuleList of nn.Sequential stages, applied in order.
    '''
    def __init__(self, input_dim=100, out_ch=128, img_ch=1):
        '''Build the generator network.

        Parameters:
            input_dim: length of the latent vector z, i.e. the channel count
                of the (N, input_dim, 1, 1) input.
            out_ch: base channel width; the stages use 4x, 2x and 1x this value.
            img_ch: number of channels of the generated image (1 for MNIST).
        '''
        super(Generator, self).__init__()

        self.layers = nn.ModuleList([
            # 1x1 -> 3x3 (kernel 3, stride 1, no padding)
            nn.Sequential(
                nn.ConvTranspose2d(input_dim, out_ch * 4, 3, 1, 0),
                nn.BatchNorm2d(out_ch * 4),
                nn.ReLU()
            ),
            # 3x3 -> 7x7 (kernel 3, stride 2, no padding)
            nn.Sequential(
                nn.ConvTranspose2d(out_ch * 4, out_ch * 2, 3, 2, 0),
                nn.BatchNorm2d(out_ch * 2),
                nn.ReLU()
            ),
            # 7x7 -> 14x14 (kernel 4, stride 2, padding 1)
            nn.Sequential(
                nn.ConvTranspose2d(out_ch * 2, out_ch, 4, 2, 1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU()
            ),
            # 14x14 -> 28x28; Tanh keeps pixels in [-1, 1], the same range as
            # the normalized real images.
            nn.Sequential(
                nn.ConvTranspose2d(out_ch, img_ch, 4, 2, 1),
                nn.Tanh()
            ),
        ])

    def forward(self, z):
        '''Forward pass.

        Parameters:
            z: latent noise batch of shape (N, input_dim, 1, 1).
        Returns:
            Generated images of shape (N, img_ch, 28, 28) with values in [-1, 1].
        '''
        for layer in self.layers:
            z = layer(z)

        return z

In [21]:
def weights_init(m):
    '''Initialise a submodule's parameters from normal distributions (DCGAN style).

    Intended to be passed to ``nn.Module.apply`` so it visits every submodule.
    Conv / ConvTranspose weights are drawn from N(0, 0.02), BatchNorm weights
    from N(1, 0.02), and the corresponding biases are set to zero. Any other
    module is left unchanged.

    Parameters:
        m: a submodule of the network (nn.Module instance).
    '''
    classname = m.__class__.__name__
    if 'Conv' in classname:
        mean, std = 0.0, 0.02
    elif 'BatchNorm' in classname:
        mean, std = 1.0, 0.02
    else:
        return

    # Conv layers built with bias=False and BatchNorm layers with affine=False
    # have weight/bias set to None; skip them instead of raising AttributeError.
    if getattr(m, 'weight', None) is not None:
        nn.init.normal_(m.weight, mean, std)
    if getattr(m, 'bias', None) is not None:
        nn.init.zeros_(m.bias)

In [22]:
import torchsummary

# Move the generator to the selected device and initialise its weights in one
# chain (both .to() and .apply() return the module itself), then print a
# per-layer summary for a single (100, 1, 1) noise input.
generator = Generator().to(device).apply(weights_init)
torchsummary.summary(generator, (100, 1, 1))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
   ConvTranspose2d-1            [-1, 512, 3, 3]         461,312
       BatchNorm2d-2            [-1, 512, 3, 3]           1,024
              ReLU-3            [-1, 512, 3, 3]               0
   ConvTranspose2d-4            [-1, 256, 7, 7]       1,179,904
       BatchNorm2d-5            [-1, 256, 7, 7]             512
              ReLU-6            [-1, 256, 7, 7]               0
   ConvTranspose2d-7          [-1, 128, 14, 14]         524,416
       BatchNorm2d-8          [-1, 128, 14, 14]             256
              ReLU-9          [-1, 128, 14, 14]               0
  ConvTranspose2d-10            [-1, 1, 28, 28]           2,049
             Tanh-11            [-1, 1, 28, 28]               0
Total params: 2,169,473
Trainable params: 2,169,473
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forw

In [23]:
# Move the discriminator to the selected device and initialise its weights in
# one chain (both .to() and .apply() return the module itself), then print a
# per-layer summary for a single 1x28x28 image.
discriminator = Discriminator().to(device).apply(weights_init)
torchsummary.summary(discriminator, (1, 28, 28))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 128, 14, 14]           2,176
         LeakyReLU-2          [-1, 128, 14, 14]               0
            Conv2d-3            [-1, 256, 7, 7]         524,544
       BatchNorm2d-4            [-1, 256, 7, 7]             512
         LeakyReLU-5            [-1, 256, 7, 7]               0
            Conv2d-6            [-1, 512, 3, 3]       1,180,160
       BatchNorm2d-7            [-1, 512, 3, 3]           1,024
         LeakyReLU-8            [-1, 512, 3, 3]               0
            Conv2d-9              [-1, 1, 1, 1]           4,609
          Sigmoid-10              [-1, 1, 1, 1]               0
Total params: 1,713,025
Trainable params: 1,713,025
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.78
Params size (MB): 6.53
Estimat

In [24]:
import torch.optim as optim

# Binary cross-entropy on the discriminator's Sigmoid outputs.
criterion = nn.BCELoss()

# Both networks are optimised with identical Adam hyper-parameters.
adam_kwargs = dict(lr=0.0002, betas=(0.5, 0.999))
optimizer_ds = optim.Adam(discriminator.parameters(), **adam_kwargs)
optimizer_gn = optim.Adam(generator.parameters(), **adam_kwargs)

In [None]:
%%time
import os
import torchvision.utils as vutils

n_epoch = 10
gn_input_dim = 100

outf = '/content/drive/MyDrive/Colab Notebooks/GAN/DCGAN_PyTorch/result'
# save_image does not create missing directories.
os.makedirs(outf, exist_ok=True)

# One fixed latent batch, reused at the end of every epoch so the saved
# snapshots show the generator's progress on identical inputs.
fixed_noise = torch.randn(
    batch_size,
    gn_input_dim,
    1,
    1,
    device=device
)

for epoch in range(n_epoch):
    print('Epoch {}/{}'.format(epoch+1, n_epoch))

    for itr, data in enumerate(dataloader):
        real_image = data[0].to(device)
        # Size noise/targets from the actual batch (the last one may be smaller).
        sample_size = real_image.size(0)

        noise = torch.randn(sample_size, gn_input_dim, 1, 1, device=device)
        real_target = torch.full((sample_size,), 1., device=device)
        fake_target = torch.full((sample_size,), 0., device=device)

        # --- Discriminator step: real images labelled 1, fakes labelled 0 ---
        discriminator.zero_grad()

        output = discriminator(real_image)
        ds_real_err = criterion(output, real_target)
        true_dsout_mean = output.mean().item()

        fake_image = generator(noise)
        # detach(): the discriminator loss must not backpropagate into the generator.
        output = discriminator(fake_image.detach())
        ds_fake_err = criterion(output, fake_target)
        fake_dsout_mean1 = output.mean().item()
        ds_err = ds_real_err + ds_fake_err

        ds_err.backward()
        optimizer_ds.step()

        # --- Generator step: score fakes against "real" labels with the updated D ---
        generator.zero_grad()
        output = discriminator(fake_image)
        gn_err = criterion(output, real_target)
        gn_err.backward()
        fake_dsout_mean2 = output.mean().item()
        optimizer_gn.step()

        if itr % 100 == 0:
            print('({}/{}) ds_loss: {:.3f} - gn_loss: {:.3f} - true_out: {:.3f} - fake_out: {:.3f}>>{:.3f}'.format(
                itr+1,
                len(dataloader),
                ds_err.item(),
                gn_err.item(),
                true_dsout_mean,
                fake_dsout_mean1,
                fake_dsout_mean2
            ))

        if epoch == 0 and itr == 0:
            vutils.save_image(
                real_image,
                '{}/real_samples.png'.format(outf),
                normalize=True,
                nrow=10)

    # Per-epoch snapshot. This belongs after the inner loop: running it per
    # iteration regenerates and overwrites the same file ~len(dataloader) times.
    # no_grad(): pure inference, so no autograd graph is built.
    with torch.no_grad():
        fake_image = generator(fixed_noise)
    vutils.save_image(
        fake_image,
        '{}/generated_epoch_{:03d}.png'.format(outf, epoch+1),
        normalize=True,
        nrow=10
    )

Epoch 1/10
(1/1200) ds_loss: 0.652 - gn_loss: 5.440 - true_out: 0.796 - fake_out: 0.284>>0.006
(101/1200) ds_loss: 0.101 - gn_loss: 4.310 - true_out: 0.960 - fake_out: 0.055>>0.017
(201/1200) ds_loss: 0.298 - gn_loss: 2.989 - true_out: 0.830 - fake_out: 0.077>>0.072
(301/1200) ds_loss: 0.761 - gn_loss: 2.440 - true_out: 0.538 - fake_out: 0.020>>0.125
(401/1200) ds_loss: 0.370 - gn_loss: 2.317 - true_out: 0.889 - fake_out: 0.190>>0.135
(501/1200) ds_loss: 0.247 - gn_loss: 2.817 - true_out: 0.883 - fake_out: 0.099>>0.082
(601/1200) ds_loss: 0.383 - gn_loss: 1.432 - true_out: 0.758 - fake_out: 0.066>>0.286
(701/1200) ds_loss: 0.832 - gn_loss: 1.543 - true_out: 0.557 - fake_out: 0.033>>0.295
(801/1200) ds_loss: 0.605 - gn_loss: 2.855 - true_out: 0.862 - fake_out: 0.308>>0.079
(901/1200) ds_loss: 0.470 - gn_loss: 3.714 - true_out: 0.881 - fake_out: 0.258>>0.034
(1001/1200) ds_loss: 0.685 - gn_loss: 4.047 - true_out: 0.951 - fake_out: 0.429>>0.023
(1101/1200) ds_loss: 0.410 - gn_loss: 3.058 