**Train cDCGAN with Rock-Paper-Scissor Dataset**

---

You can get dataset from [Rock-Paper-Scissor, Kaggle](https://www.kaggle.com/datasets/sanikamal/rock-paper-scissors-dataset).

In [1]:
import numpy as np
import pandas as pd
import plotly.express as px

import torch
import torch.nn as nn
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import DataLoader, Subset

from torchvision.utils import make_grid
from torchvision.datasets import ImageFolder
from torchvision import transforms as T

from tqdm import trange, tqdm

from IPython.display import clear_output

from multiprocessing import cpu_count

In [2]:
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)

img_size = 128
img_dim = img_size ** 2
img_channels = 3
n_class = 3

z_dim = 100

n_epochs = 50
batch_size = 32
lr = 4e-4
k= 25 # G_lr/D_lr

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
print(f"Cpu core num: {cpu_count()}")

Using device: cuda
Cpu core num: 2


In [3]:
def download_dataset_to_dir(dataset_id: str, move: bool = True, target_dir: str = "./data/"):
    """
    使用 kagglehub 下载 Kaggle 数据集，并移动/复制到指定目录
    :param dataset_id: Kaggle 数据集 ID，比如 "splcher/animefacedataset"
    :param move: True=剪切数据至target_dir，False=数据位置不变
    :param target_dir: 目标目录，比如 "./data"
    :return: 目标目录的绝对路径
    """

    import kagglehub
    import shutil
    import os

    # 下载/获取缓存路径
    cache_path = kagglehub.dataset_download(dataset_id)
    print(f"Download path: {cache_path}")

    if not move:
        return cache_path
    else:
        # 确保目标目录存在
        os.makedirs(target_dir, exist_ok=True)

        # 移动内容
        for item in os.listdir(cache_path):
            s = os.path.join(cache_path, item)
            d = os.path.join(target_dir, item)

            # 剪切（移动）
            if os.path.exists(d):
                if os.path.isdir(d):
                    shutil.rmtree(d)
                else:
                    os.remove(d)
            shutil.move(s, d)

        return os.path.abspath(target_dir)

data_path = download_dataset_to_dir("sanikamal/rock-paper-scissors-dataset", move=False, target_dir="./data/rockpaperscissors/")

Using Colab cache for faster access to the 'rock-paper-scissors-dataset' dataset.
Download path: /kaggle/input/rock-paper-scissors-dataset


In [4]:
train_dataset = ImageFolder(root=data_path + '/Rock-Paper-Scissors/train', transform=T.Compose(
        [
            T.Resize(img_size),
            T.ToTensor(),
            T.Normalize([0.5] * 3, [0.5] * 3),
        ]
    ))

subset_flag = False
subset_size = 2000

data_loader_paras = {
    "batch_size": batch_size,
    "shuffle": True,
    "drop_last": True,

    "num_workers": 2,
    "pin_memory": True,
}

if subset_flag:
    indices = np.random.choice(range(len(train_dataset)), subset_size)
    subset_dataset = Subset(train_dataset, indices)
    print(f"Subset dataset size: {len(subset_dataset)}")

    train_loader = DataLoader(subset_dataset, **data_loader_paras)
else:
    train_loader = DataLoader(train_dataset, **data_loader_paras)

n_samples = len(train_loader.dataset)
n_batches = len(train_loader)
print(f"Number of training samples: {n_samples}")
print(f"Number of batches: {n_batches}")

Number of training samples: 2520
Number of batches: 78


In [5]:
imgs, labels = next(iter(train_loader))
print(f"Image batch shape: {imgs.shape}")
print(f"Label batch shape: {labels.shape}")

Image batch shape: torch.Size([32, 3, 128, 128])
Label batch shape: torch.Size([32])


In [6]:
from IPython.display import display, HTML

def show_centered(fig):
    html = f"<div style='display:flex; justify-content:center;'>{fig.to_html(include_plotlyjs='cdn')}</div>"
    display(HTML(html))


In [7]:
def denormalize(imgs):
    return imgs * 0.5 + 0.5

def show_images(imgs, grid_size=5):
    grid_imgs = make_grid(denormalize(imgs)[:24], nrow=6, padding=2).permute(1, 2, 0)

    fig = px.imshow(grid_imgs, aspect='auto')
    fig.update_xaxes(showticklabels=False).update_yaxes(showticklabels=False)
    fig.update_layout(
        width=800,
        height=400,
        title=dict(
            text="Rock-Paper-Scissor",
            x=0.5,
            xanchor='center',
            yanchor='top',
        ),
        coloraxis_showscale=False,
    )

    show_centered(fig)

show_images(imgs)

In [8]:
def weights_init(m):
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

def basic_G(in_channels):
    return nn.Sequential(
        nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=4, stride=2, padding=1, bias=False), # H -> 2*H
        nn.BatchNorm2d(in_channels // 2),
        nn.ReLU(True),
    )

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.net = nn.Sequential(
            basic_G(512), # input shape: batch_size * 512 * 4 * 4
            basic_G(256),
            basic_G(128),
            basic_G(64),
            nn.ConvTranspose2d(32, img_channels, 4, 2, 1),
            nn.Tanh(),
        )

        self.label_encoder = nn.Embedding(n_class, 1 * 4 * 4)

        # 每张虚假图片的初始值自z_dim采样，通过latent线性层映射，最终在forward与label的嵌入合并作为input
        self.latent = nn.Linear(z_dim, 511 * 4 * 4)

    def forward(self, z, labels):
        x = self.latent(z)
        c = self.label_encoder(labels)
        input = torch.cat([x, c], dim=1).view(-1, 512, 4, 4)
        return self.net(input)

def basic_D(in_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels, 2 * in_channels, kernel_size=4, stride=2, padding=1, bias=False), # H -> H/2
        nn.BatchNorm2d(2 * in_channels),
        nn.LeakyReLU(0.2, inplace=True),
    )

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(img_channels + 1, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            basic_D(64),
            basic_D(128),
            basic_D(256),
            basic_D(512),
            nn.Conv2d(1024, 1, kernel_size=4, stride=1, padding=0),
            nn.Flatten(),
            nn.Sigmoid(),
        )

        self.label_encoder = nn.Embedding(n_class, 1 * img_dim)

    def forward(self, imgs, labels):
        x = imgs.view(-1, img_channels * img_dim)
        c = self.label_encoder(labels)
        input = torch.cat([x, c], dim=1).view(-1, img_channels + 1, img_size, img_size)

        return self.net(input)

In [9]:
def train_D(D, D_optimizer, criterion, real_imgs, real_labels, fake_imgs, fake_labels):
    one_targets = torch.ones(batch_size, 1).to(device)
    zero_targets = torch.zeros(batch_size, 1).to(device)

    real_preds = D(real_imgs, real_labels)
    real_loss = criterion(real_preds, one_targets)
    fake_preds = D(fake_imgs.detach(), fake_labels)
    fake_loss = criterion(fake_preds, zero_targets)

    D_loss = real_loss + fake_loss
    D_optimizer.zero_grad()
    D_loss.backward()
    D_optimizer.step()

    real_score = torch.mean(real_preds).item()
    fake_score = torch.mean(fake_preds).item()

    return D_loss.item(), real_score, fake_score

def train_G(D, G_optimizer, criterion, fake_imgs, fake_labels):
    one_targets = torch.ones(fake_imgs.size(0), 1).to(device)

    preds = D(fake_imgs, fake_labels)
    G_loss = criterion(preds, one_targets)

    G_optimizer.zero_grad()
    G_loss.backward()
    G_optimizer.step()

    return G_loss.item()

def fit(D, G, D_optimizer: torch.optim.Adam, G_optimizer: torch.optim.Adam, scheduler, criterion, n_epochs, log_interval):

    train_history = {
        "D_loss": [],
        "G_loss": [],
        "D_real": [],
        "D_fake": [],
        "lr": [],
    }

    seed_seq = np.random.randint(0, 10000, size=n_epochs)


    for epoch in trange(n_epochs, desc="Epoch", leave=False):
        D_loss, G_loss, real_score, fake_score = 0.0, 0.0, 0.0, 0.0

        for real_imgs, real_labels in train_loader:
            real_imgs = real_imgs.to(device)
            real_labels = real_labels.to(device)

            # z = torch.randn(batch_size, z_dim).to(device)
            # 由于在最上方固定随机数种子，因此如果保留上面的代码，模型在每个epoch都会生成相同的fake_labels
            # 这里为了保证可控的随机性，遍历在外部提前生成的随机数序列作为epoch内部需要的随机数种子
            g = torch.Generator().manual_seed(seed)
            z = torch.randn(batch_size, z_dim, generator=g).to(device)
            fake_labels = torch.randint(0, 3, (batch_size, )).to(device)
            fake_imgs = G(z, fake_labels)

            _D_loss, _real_score, _fake_score = train_D(D, D_optimizer, criterion, real_imgs, real_labels, fake_imgs, fake_labels)
            D_loss += _D_loss
            real_score += _real_score
            fake_score += _fake_score

            _G_loss = train_G(D, G_optimizer, criterion, fake_imgs, fake_labels)
            G_loss += _G_loss

        train_history["D_loss"].append(D_loss / n_samples)
        train_history["G_loss"].append(G_loss / n_samples)
        train_history["D_real"].append(real_score / n_batches)
        train_history["D_fake"].append(fake_score / n_batches)

        lr = scheduler.get_last_lr()[0]
        train_history["lr"].append(lr)

        G_optimizer.param_groups[0]['lr'] = lr * k

        if epoch == 0 or (epoch+1) % log_interval == 0:
            clear_output(wait=True)

            tqdm.write(
                f"Epoch [{epoch+1}/{n_epochs}] lr: {train_history['lr'][-1]:.6f} "
                f"D_loss: {train_history['D_loss'][-1]:.4f} G_loss: {train_history['G_loss'][-1]:.4f} "
                f"D_real: {train_history['D_real'][-1]:.4f} D_fake: {train_history['D_fake'][-1]:.4f} "
            )

            z = torch.randn(batch_size, z_dim).to(device)
            # Generate fake labels for showing images
            fake_labels = torch.randint(0, 3, (batch_size, )).to(device)
            fake_imgs = G(z, fake_labels).detach().cpu()
            show_images(fake_imgs)

        scheduler.step()


    return train_history

In [10]:
D = Discriminator().to(device)
G = Generator().to(device)
D.apply(weights_init)
G.apply(weights_init)

optim_D = torch.optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))
optim_G = torch.optim.Adam(G.parameters(), lr=lr*k, betas=(0.5, 0.999))

v= lambda epoch: 0.9 ** epoch
scheduler = lr_scheduler.LambdaLR(optim_D, lr_lambda=v)
criterion = nn.BCELoss(reduction='sum')

log_interval = 10

train_history = fit(D, G, optim_D, optim_G, scheduler, criterion, n_epochs, log_interval)

Epoch:  98%|█████████▊| 49/50 [12:13<00:14, 14.59s/it]

Epoch [50/50] lr: 0.000002 D_loss: 0.6083 G_loss: 1.4773 D_real: 0.7546 D_fake: 0.2475 




In [11]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go

fig = make_subplots(rows=1, cols=3, subplot_titles=("Loss", "D(x) & D(G(z))", "Learning Rate"))

fig.add_traces([
    go.Scatter(y=train_history["D_loss"], mode='lines', name='D_loss'),
    go.Scatter(y=train_history["G_loss"], mode='lines', name='G_loss'),
], rows=1, cols=1)

fig.add_traces([
    go.Scatter(y=train_history["D_real"], mode='lines', name='D(x)'),
    go.Scatter(y=train_history["D_fake"], mode='lines', name='D(G(z))'),
], rows=1, cols=2)

fig.add_hline(
    y=0.5,
    line=dict(color="black", dash="dash"),
    row=1,
    col=2,
)

fig.add_trace(
    go.Scatter(y=train_history["lr"], mode='lines', name='Learning Rate'),
    row=1,
    col=3,
)

fig.update_layout(height=400, width=1200, title_text="GAN Training History")
show_centered(fig)