In [None]:
import glob
import shutil
import tarfile
import urllib.request
from pathlib import Path
from statistics import mean

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import tqdm
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torchvision import transforms
from torchvision.datasets import ImageFolder

In [None]:
# Dataset preparation: download the Oxford 102 Flowers archive and arrange the
# images as ./oxford-102/jpg/*.jpg so ImageFolder (next cell) sees a single
# class directory. Written in Python because bare shell lines are not valid
# in a Python kernel; every step is skipped when already done, so the cell
# can be re-run safely.
FLOWERS_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz'
archive_path = Path('102flowers.tgz')
extracted_dir = Path('jpg')                 # the archive unpacks into ./jpg/
dataset_dir = Path('oxford-102') / 'jpg'

if not archive_path.exists():
  # Download to a temporary name first so an interrupted download never
  # leaves a truncated archive that later runs would mistake for a good one.
  partial_path = archive_path.with_name(archive_path.name + '.part')
  urllib.request.urlretrieve(FLOWERS_URL, partial_path)
  partial_path.replace(archive_path)

# Skip extraction/moving when a previous run already populated dataset_dir.
if not any(dataset_dir.glob('*.jpg')):
  with tarfile.open(archive_path) as tar:
    # Use the safe 'data' extraction filter on Python versions that support it.
    extract_kwargs = {'filter': 'data'} if hasattr(tarfile, 'data_filter') else {}
    tar.extractall(**extract_kwargs)
  dataset_dir.mkdir(parents=True, exist_ok=True)
  for img_path in extracted_dir.glob('*.jpg'):
    shutil.move(str(img_path), str(dataset_dir / img_path.name))

In [None]:
# Dataset and DataLoader.
# ImageFolder treats every sub-directory of ./oxford-102/ as one class; here the
# only one is "jpg", and the labels are ignored during training.
# NOTE(review): ToTensor() yields pixels in [0, 1] while GNet ends in Tanh, i.e.
# [-1, 1]. The usual DCGAN recipe appends
# transforms.Normalize((0.5,) * 3, (0.5,) * 3) here and maps generator output back
# with (x + 1) / 2 before saving images — change both places together.
preprocess = transforms.Compose([
    transforms.Resize(80),       # shorter side -> 80 px
    transforms.CenterCrop(64),   # 64x64: the resolution GNet produces and DNet consumes
    transforms.ToTensor(),
])
img_data = ImageFolder('./oxford-102/', transform=preprocess)

batch_size = 64
img_loader = DataLoader(img_data, batch_size=batch_size, shuffle=True)

In [None]:
nz = 100    # length of the latent vector z fed to the generator
ngf = 32    # base number of generator feature maps

class GNet(nn.Module):
  """DCGAN generator.

  Maps latent noise of shape (N, nz, 1, 1) to images of shape (N, 3, 64, 64)
  whose values lie in [-1, 1] (final Tanh).
  """

  def __init__(self):
    super().__init__()

    def up_block(in_ch, out_ch, stride, padding):
      # Transposed conv -> BatchNorm -> ReLU. inplace=True lets ReLU overwrite
      # its input instead of allocating a new tensor.
      return [
          nn.ConvTranspose2d(in_ch, out_ch, 4, stride, padding, bias=False),
          nn.BatchNorm2d(out_ch),
          nn.ReLU(inplace=True),
      ]

    layers = up_block(nz, ngf * 8, 1, 0)         # 1x1   -> 4x4
    layers += up_block(ngf * 8, ngf * 4, 2, 1)   # 4x4   -> 8x8
    layers += up_block(ngf * 4, ngf * 2, 2, 1)   # 8x8   -> 16x16
    layers += up_block(ngf * 2, ngf, 2, 1)       # 16x16 -> 32x32
    layers += [
        nn.ConvTranspose2d(ngf, 3, 4, 2, 1, bias=False),   # 32x32 -> 64x64, RGB
        nn.Tanh(),
    ]
    # One flat Sequential keeps parameter names (main.0.weight, ...) identical,
    # so previously saved state_dicts still load.
    self.main = nn.Sequential(*layers)

  def forward(self, x):
    return self.main(x)


ndf = 32    # base number of discriminator feature maps

class DNet(nn.Module):
  """DCGAN discriminator.

  Maps images of shape (N, 3, 64, 64) to a tensor of shape (N,) holding one
  raw logit per image (no sigmoid; pair it with BCEWithLogitsLoss).
  """

  def __init__(self):
    super().__init__()
    self.main = nn.Sequential(
        nn.Conv2d(3, ndf, 4, 2, 1, bias = False),            # 64 -> 32 (no BatchNorm on the input layer)
        nn.LeakyReLU(0.2, inplace = True),
        nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias = False),      # 32 -> 16
        nn.BatchNorm2d(ndf * 2),
        nn.LeakyReLU(0.2, inplace = True),
        nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias = False),  # 16 -> 8
        nn.BatchNorm2d(ndf * 4),
        nn.LeakyReLU(0.2, inplace = True),
        nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias = False),  # 8 -> 4
        nn.BatchNorm2d(ndf * 8),
        nn.LeakyReLU(0.2, inplace = True),
        nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias = False)         # 4 -> 1: one logit per image
    )

  def forward(self, x):
    out = self.main(x)
    # Reshape (N, 1, 1, 1) -> (N,). squeeze() would also drop the batch axis
    # when N == 1, producing a 0-d tensor that no longer matches (N,) targets.
    # view(N) also fails loudly if the input size yields more than one logit.
    return out.view(out.size(0))

In [None]:
# Fix the seed so weight initialisation, latent sampling and DataLoader
# shuffling are repeatable across runs.
torch.manual_seed(0)

# Use the first GPU when available, otherwise fall back to the CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
d = DNet().to(device)
g = GNet().to(device)

# Adam hyper-parameters are the values proposed in the original DCGAN paper.
opt_d = optim.Adam(d.parameters(), lr = 0.0002, betas = (0.5, 0.999))
opt_g = optim.Adam(g.parameters(), lr = 0.0002, betas = (0.5, 0.999))

# Helpers for the cross-entropy terms: real/fake label targets and the loss.
# DNet's last layer is a plain Conv2d (raw logits), hence the with-logits loss.
ones = torch.ones(batch_size, device = device)
zeros = torch.zeros(batch_size, device = device)
loss_f = nn.BCEWithLogitsLoss()

In [None]:
# Training function
def train_dcgan(g, d, opt_g, opt_d, loader):
  """Run one epoch of DCGAN training over ``loader``.

  For every mini-batch:
  1. The generator is updated so the discriminator labels its samples as real.
  2. The discriminator is updated on the real batch and on the same generated
     batch. That batch is detached, so no gradient reaches the generator.

  Args:
    g: generator; called with latent noise of shape (batch, nz, 1, 1).
    d: discriminator; expected to return one logit per image.
    opt_g: optimizer over g's parameters.
    opt_d: optimizer over d's parameters.
    loader: iterable of (image_batch, label) pairs; labels are ignored. Any
      batch size works, since targets are built per batch.

  Returns:
    (mean generator loss, mean discriminator loss) over the epoch.

  Raises:
    statistics.StatisticsError: if ``loader`` yields no batches.
  """
  # Follow the device the generator lives on instead of a hard-coded GPU.
  device = next(g.parameters()).device
  # BatchNorm must use batch statistics while training, even if an earlier
  # cell left either network in eval mode.
  g.train()
  d.train()

  log_loss_g = []
  log_loss_d = []

  for real_img, _ in tqdm.tqdm(loader):
    batch_len = len(real_img)
    real_img = real_img.to(device)
    z = torch.randn(batch_len, nz, 1, 1, device = device)
    fake_img = g(z)
    # Graph-free copy reused for the discriminator step below.
    fake_img_detached = fake_img.detach()

    # --- Generator step: try to make d label the fakes as real ---
    out = d(fake_img)
    loss_g = loss_f(out, torch.ones_like(out))
    log_loss_g.append(loss_g.item())

    d.zero_grad()
    g.zero_grad()
    loss_g.backward()
    opt_g.step()

    # --- Discriminator step: real -> 1, fake -> 0 ---
    real_out = d(real_img)
    loss_d_real = loss_f(real_out, torch.ones_like(real_out))
    fake_out = d(fake_img_detached)
    loss_d_fake = loss_f(fake_out, torch.zeros_like(fake_out))

    loss_d = loss_d_real + loss_d_fake
    log_loss_d.append(loss_d.item())

    # Also clears the gradients the generator step accumulated in d.
    d.zero_grad()
    g.zero_grad()
    loss_d.backward()
    opt_d.step()

  return mean(log_loss_g), mean(log_loss_d)

In [None]:
# Train the DCGAN and keep the per-epoch mean losses so the run can be inspected
# (or plotted) afterwards instead of being discarded.
N_EPOCHS = 100
loss_history = []   # list of (mean loss_g, mean loss_d), one entry per epoch
for epoch in tqdm.tqdm(range(N_EPOCHS)):
  loss_g_epoch, loss_d_epoch = train_dcgan(g, d, opt_g, opt_d, img_loader)
  loss_history.append((loss_g_epoch, loss_d_epoch))
  # tqdm.write prints without breaking the progress bars.
  tqdm.tqdm.write(f'epoch {epoch + 1}/{N_EPOCHS}: loss_g={loss_g_epoch:.4f} loss_d={loss_d_epoch:.4f}')

In [None]:
# Persist the trained weights of both networks (state_dicts only, pickle protocol 4).
# The discriminator file keeps its existing name so current loaders still find it.
for module, weight_path in ((g, './generate.prm'), (d, './descriminator.prm')):
  torch.save(module.state_dict(), weight_path, pickle_protocol=4)

In [None]:
# Sample a random latent vector, run it through the trained generator and save
# the resulting image for a visual check.
device = next(g.parameters()).device
was_training = g.training
g.eval()                 # BatchNorm uses its running statistics and leaves them untouched
with torch.no_grad():    # inference only: no autograd graph is built
  fixed_generate = torch.randn(1, nz, 1, 1, device = device)
  generated_img = g(fixed_generate)
g.train(was_training)    # restore the previous mode so later training cells are unaffected

# save_image expects values in [0, 1]; out-of-range values are clipped.
torchvision.utils.save_image(generated_img, 'hoge.png')