<a href="https://colab.research.google.com/github/MamedenQ/VolleyballDesign/blob/main/GAN%E3%81%A7%E6%AC%A1%E4%B8%96%E4%BB%A3%E3%83%90%E3%83%AC%E3%83%BC%E3%83%9C%E3%83%BC%E3%83%AB%E6%9F%84%E7%94%9F%E6%88%90.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 事前準備

インポート

In [None]:
import random
import time
import numpy as np
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
%matplotlib inline

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.utils as vutils
import torchvision.transforms as transforms

各種設定

In [None]:
# One seed shared by Python's random module, NumPy and PyTorch so that
# repeated runs draw the same random numbers.
SEED = 1234
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Length of the latent vector z that is fed to the generator
z_dim = 100

# Number of images per mini-batch
batch_size = 64

# Side length (pixels) the training images are resized to.
# The networks also reuse this value as their base channel width.
img_size = 64

# Number of passes over the training set
num_epochs = 500

使用デバイス確認

In [None]:
# Use the first CUDA GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print("使用デバイス：", device)

リポジトリクローン、訓練データ展開

In [None]:
!git clone https://github.com/MamedenQ/VolleyballDesign
!unzip VolleyballDesign/data.zip

# Generatorの作成

定義

In [None]:
class Generator(nn.Module):
    """Transposed-convolution generator mapping a latent vector to an image.

    Input:  tensor of shape (N, z_dim, 1, 1).
    Output: tensor of shape (N, 3, 64, 64) with values in [-1, 1] (Tanh).

    The spatial size grows 1 -> 4 -> 8 -> 16 -> 32 -> 64 while the channel
    count shrinks img_size*8 -> img_size*4 -> img_size*2 -> img_size -> 3.
    Note that the module-level ``img_size`` is used here as the base channel
    width; the 64x64 output size is fixed by the number of layers.
    """

    def __init__(self):
        super().__init__()

        width = img_size

        # 1x1 latent "pixel" -> 4x4 feature map (stride 1, no padding)
        self.layer1 = self._up_block(z_dim, width * 8, stride=1, padding=0)
        # Each following block doubles the spatial resolution
        self.layer2 = self._up_block(width * 8, width * 4)
        self.layer3 = self._up_block(width * 4, width * 2)
        self.layer4 = self._up_block(width * 2, width)

        # Final upsampling to 3 channels; no BatchNorm, Tanh squashes to [-1, 1]
        self.last = nn.Sequential(
            nn.ConvTranspose2d(width, 3, kernel_size=4,
                               stride=2, padding=1, bias=False),
            nn.Tanh())

    @staticmethod
    def _up_block(in_channels, out_channels, stride=2, padding=1):
        """Build ConvTranspose2d(kernel 4, no bias) -> BatchNorm2d -> ReLU."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4,
                               stride=stride, padding=padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True))

    def forward(self, z):
        """Run ``z`` through layer1..layer4 and the output layer, in order."""
        out = z
        for stage in (self.layer1, self.layer2, self.layer3,
                      self.layer4, self.last):
            out = stage(out)
        return out

生成、動作確認

In [None]:
generator = Generator()

# One latent vector drawn from the standard normal distribution
# (mean 0, variance 1), shaped (1, z_dim, 1, 1) as the generator expects
input_z = torch.randn(1, z_dim, 1, 1)

# Run the (still untrained) generator once to confirm it produces an image
fake_imgs = generator(input_z)

# Channels-first (C, H, W) -> channels-last (H, W, C) for matplotlib
preview = np.transpose(fake_imgs[0].detach().numpy(), (1, 2, 0))

# The generator ends in Tanh, so pixel values lie in [-1, 1]. imshow expects
# float RGB data in [0, 1] and clips everything else, which would render all
# negative values as black; map the range to [0, 1] before plotting.
plt.imshow((preview + 1.0) / 2.0)
plt.show()

# Discriminatorの作成

定義

In [None]:
class Discriminator(nn.Module):
    """Convolutional discriminator scoring a 3-channel 64x64 image.

    Input:  tensor of shape (N, 3, 64, 64).
    Output: tensor of shape (N, 1, 1, 1) with values in [0, 1] (Sigmoid).

    Four stride-2 convolutions halve the resolution 64 -> 32 -> 16 -> 8 -> 4
    while widening the channels img_size -> img_size*2 -> img_size*4 ->
    img_size*8; a final 4x4 convolution reduces the 4x4 map to one value.
    The module-level ``img_size`` is used here as the base channel width.
    """

    def __init__(self):
        super().__init__()

        width = img_size

        # The first block has no BatchNorm
        self.layer1 = self._down_block(3, width, batch_norm=False)
        self.layer2 = self._down_block(width, width * 2)
        self.layer3 = self._down_block(width * 2, width * 4)
        self.layer4 = self._down_block(width * 4, width * 8)

        # 4x4 feature map -> single score squashed by Sigmoid
        self.last = nn.Sequential(
            nn.Conv2d(width * 8, 1, kernel_size=4, stride=1,
                      padding=0, bias=False),
            nn.Sigmoid())

    @staticmethod
    def _down_block(in_channels, out_channels, batch_norm=True):
        """Build Conv2d(kernel 4, stride 2) -> [BatchNorm2d] -> LeakyReLU(0.2)."""
        modules = [nn.Conv2d(in_channels, out_channels, kernel_size=4,
                             stride=2, padding=1, bias=False)]
        if batch_norm:
            modules.append(nn.BatchNorm2d(out_channels))
        modules.append(nn.LeakyReLU(0.2, inplace=True))
        return nn.Sequential(*modules)

    def forward(self, x):
        """Run ``x`` through layer1..layer4 and the output layer, in order."""
        out = x
        for stage in (self.layer1, self.layer2, self.layer3,
                      self.layer4, self.last):
            out = stage(out)
        return out

生成、動作確認

In [None]:
discriminator = Discriminator()

# Sample a fresh latent vector and turn it into one fake image
latent_shape = (1, z_dim, 1, 1)
input_z = torch.randn(latent_shape)
fake_imgs = generator(input_z)

# Feed the fake image to the discriminator; the result is its Sigmoid score
d_out = discriminator(fake_imgs)

# Show the score
print(d_out)

# DataLoaderの作成

In [None]:
# Preprocessing and augmentation applied to every image when it is loaded:
# resize to img_size x img_size, random flips, colour jitter, conversion to
# a tensor, then per-channel (x - 0.5) / 0.5 normalisation.
train_transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.RandomHorizontalFlip(p=0.3),
    transforms.RandomVerticalFlip(p=0.3),
    transforms.ColorJitter(brightness=(1, 1.3), contrast=(1, 1.2), saturation=(0.8, 1.2)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Images are read from the "data" directory
train_dataset = datasets.ImageFolder(root="data", transform=train_transform)

# Shuffled mini-batches of batch_size images
dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True)

# Take one mini-batch and tile its images into a single grid for a preview
real_batch = next(iter(dataloader))
preview_imgs = real_batch[0].to(device)[:batch_size]
preview_grid = vutils.make_grid(preview_imgs, padding=2, normalize=True).cpu()

plt.figure(figsize=(8, 8))
plt.axis("off")
plt.title("Training Images")
# Channels-first (C, H, W) -> channels-last (H, W, C) for matplotlib
plt.imshow(np.transpose(preview_grid, (1, 2, 0)))

# 学習

ネットワーク初期化

In [None]:
def weights_init(m):
    """Re-initialise one module in place, chosen by its class name.

    * class name contains "Conv":      weight ~ Normal(mean 0.0, std 0.02)
    * class name contains "BatchNorm": weight ~ Normal(mean 1.0, std 0.02),
                                       bias set to 0
    * anything else is left untouched.

    Intended to be passed to ``nn.Module.apply``.
    """
    layer_kind = type(m).__name__

    if "Conv" in layer_kind:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        return

    if "BatchNorm" in layer_kind:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

# Re-initialise both networks with weights_init. nn.Module.apply calls the
# given function on every submodule of the network and on the network itself.
generator.apply(weights_init)
discriminator.apply(weights_init)

学習処理

In [None]:
t_start = time.time()

# Optimizers: Adam with the same learning rate and betas for both networks
g_lr, d_lr = 0.0002, 0.0002
beta1, beta2 = 0.5, 0.999
g_optimizer = torch.optim.Adam(generator.parameters(), g_lr, [beta1, beta2])
d_optimizer = torch.optim.Adam(discriminator.parameters(), d_lr, [beta1, beta2])

# Binary cross-entropy on the discriminator's Sigmoid output
criterion = nn.BCELoss()

# Move both networks to the selected device
generator.to(device)
discriminator.to(device)

# Training mode (BatchNorm uses batch statistics)
generator.train()
discriminator.train()

# Let cuDNN auto-tune its algorithms; useful because the network and input
# sizes stay fixed. NOTE: auto-tuning may make GPU runs not bit-for-bit
# reproducible even though the seeds are fixed.
torch.backends.cudnn.benchmark = True

# Per-epoch mean losses, kept for the plot below
g_loss_all = []
d_loss_all = []

for epoch in tqdm(range(num_epochs)):
    t_epoch_start = time.time()

    # Per-iteration losses of the current epoch
    epoch_g_loss = []
    epoch_d_loss = []

    for batch in dataloader:
        # Each batch is indexable; element 0 holds the image tensor
        imgs = batch[0].to(device)

        # The last mini-batch of an epoch can be smaller than batch_size,
        # so the targets are sized from the actual batch.
        # Float targets are created directly on the device: 1 = real, 0 = fake.
        cur_batch_size = imgs.size(0)
        label_real = torch.ones(cur_batch_size, device=device)
        label_fake = torch.zeros(cur_batch_size, device=device)

        ##################
        # Discriminator update
        ##################

        # Score the real images
        d_out_real = discriminator(imgs)

        # Generate fake images and score them. The generator is not updated
        # in this step, so its forward pass runs without autograd tracking;
        # this keeps d_loss.backward() from propagating into the generator.
        # The noise is sampled on the CPU and then moved, as before.
        input_z = torch.randn(cur_batch_size, z_dim, 1, 1).to(device)
        with torch.no_grad():
            fake_imgs = generator(input_z)
        d_out_fake = discriminator(fake_imgs)

        # Loss: real images should score 1, fake images should score 0
        d_loss_real = criterion(d_out_real.view(-1), label_real)
        d_loss_fake = criterion(d_out_fake.view(-1), label_fake)
        d_loss = d_loss_real + d_loss_fake

        # zero_grad also clears the discriminator gradients that the previous
        # iteration's generator step left behind
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        ##################
        # Generator update
        ##################

        # Generate a new batch of fake images and score them with the
        # just-updated discriminator
        input_z = torch.randn(cur_batch_size, z_dim, 1, 1).to(device)
        fake_imgs = generator(input_z)
        d_out_fake = discriminator(fake_imgs)

        # The generator is rewarded when its fakes are scored as real
        g_loss = criterion(d_out_fake.view(-1), label_real)

        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

        epoch_d_loss.append(d_loss.item())
        epoch_g_loss.append(g_loss.item())

    g_losses_mean = np.mean(epoch_g_loss)
    d_losses_mean = np.mean(epoch_d_loss)
    g_loss_all.append(g_losses_mean)
    d_loss_all.append(d_losses_mean)

    tqdm.write("epoch {} || d_loss:{:.4f} g_loss:{:.4f} timer: {:.4f} sec.".format(
        epoch + 1,
        d_losses_mean,
        g_losses_mean,
        time.time() - t_epoch_start))

# Plot the per-epoch mean losses
fig, ax = plt.subplots(1, 1)
ax.plot(g_loss_all, label="g loss", marker="o")
ax.plot(d_loss_all, label="d loss", marker="*")
ax.set_xlabel("epoch")
ax.set_ylabel("mean BCE loss")
ax.legend()

print("finish time:{:.4f} sec.".format(time.time() - t_start))

# 画像生成

モデル読み込み（学習なしで画像生成する場合）

In [None]:
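# Uncomment to load the pretrained generator weights instead of training.
# map_location remaps the stored tensors to the current device, so the
# checkpoint also loads on a CPU-only runtime.
# generator.load_state_dict(torch.load("VolleyballDesign/model/generate.pt", map_location=device))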

画像生成

In [None]:
plt.figure(figsize=(8, 8))
plt.axis("off")
plt.title("Gen Images")

# 64 latent vectors, shaped (64, z_dim, 1, 1) as the generator expects
fixed_z = torch.randn(64, z_dim, 1, 1)

# Only the training cell moves the generator to `device`. When training is
# skipped and pretrained weights are loaded instead, the model is still on
# the CPU, so move it here as well (a no-op if it is already there).
generator.to(device)

# Inference mode: BatchNorm uses its running statistics
generator.eval()

# No gradients are needed for sampling; this also guarantees that the result
# can be converted to a NumPy array for plotting.
with torch.no_grad():
    fake_img = generator(fixed_z.to(device))

# Tile the images into one grid (values rescaled to [0, 1] by normalize=True)
# and convert channels-first (C, H, W) to channels-last (H, W, C)
plt.imshow(
    np.transpose(
        vutils.make_grid(fake_img, padding=2, normalize=True).cpu(),
        (1, 2, 0)))