**Train DCGAN with Anime Face Dataset**

---

The dataset's GitHub repository can be found [here](https://github.com/bchao1/Anime-Face-Dataset?tab=readme-ov-file).

You can also get the dataset from Kaggle: [Anime Face Dataset](https://www.kaggle.com/datasets/splcher/animefacedataset/).

In [None]:
import numpy as np
import pandas as pd
import plotly.express as px

import torch
import torch.nn as nn
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import DataLoader, Subset

from torchvision.utils import make_grid
from torchvision.datasets import ImageFolder
from torchvision import transforms as T

from tqdm import trange, tqdm

from IPython.display import clear_output

In [None]:
# Reproducibility: seed numpy (used for optional subsetting) and torch (weights, noise, shuffling).
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

# Image / latent geometry
img_size = 64        # spatial size after resize + center crop
img_channels = 3     # RGB
z_dim = 100          # length of the latent vector fed to the generator

# Optimization
n_epochs = 20
batch_size = 64
lr = 1e-4            # discriminator base learning rate
k = 18               # ratio G_lr / D_lr

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [None]:
def download_dataset_to_dir(dataset_id: str, target_dir: str = "./data/", move: bool = True):
    """
    Download a Kaggle dataset with kagglehub and optionally move it into `target_dir`.

    :param dataset_id: Kaggle dataset ID, e.g. "splcher/animefacedataset"
    :param target_dir: destination directory, e.g. "./data" (only used when move=True)
    :param move: True = move (cut) the downloaded files into target_dir;
                 False = leave the data where kagglehub put it and return that path
    :return: the directory holding the dataset: the kagglehub path when move=False,
             otherwise the absolute path of target_dir
    """

    import os
    import shutil

    import kagglehub

    # Download, or reuse kagglehub's cached copy, and get its local path.
    cache_path = kagglehub.dataset_download(dataset_id)
    print(f"Download path: {cache_path}")

    if not move:
        return cache_path

    # Make sure the destination exists.
    os.makedirs(target_dir, exist_ok=True)

    # NOTE(review): this moves files OUT of the kagglehub download directory; a later
    # call may get back that (now emptied) path — confirm against kagglehub's caching.
    for item in os.listdir(cache_path):
        src = os.path.join(cache_path, item)
        dst = os.path.join(target_dir, item)

        # Remove any stale copy first so shutil.move replaces it instead of nesting
        # `src` inside an existing directory. A symlink (even one pointing to a
        # directory, or a broken one) is unlinked rather than rmtree'd, which would raise.
        if os.path.isdir(dst) and not os.path.islink(dst):
            shutil.rmtree(dst)
        elif os.path.lexists(dst):
            os.remove(dst)
        shutil.move(src, dst)

    abs_target = os.path.abspath(target_dir)
    # Only the move path reaches here, so the message is unconditional.
    print(f"Moved dataset to: {abs_target}")
    return abs_target

data_path = download_dataset_to_dir(dataset_id="splcher/animefacedataset", target_dir="./data/AnimeFace/", move=False)  # move=False: use the downloaded copy in place (target_dir unused)

Downloading from https://www.kaggle.com/api/v1/datasets/download/splcher/animefacedataset?dataset_version_number=3...


100%|██████████| 395M/395M [00:10<00:00, 39.8MB/s]

Extracting files...





Download path: /root/.cache/kagglehub/datasets/splcher/animefacedataset/versions/3


In [None]:
# Preprocessing: resize (int size -> shorter side), center-crop to img_size x img_size,
# convert to a [0, 1] tensor, then normalize each channel to [-1, 1] so real images share
# the range of the generator's Tanh output. Class labels from ImageFolder are not used.
train_dataset = ImageFolder(root=data_path, transform=T.Compose(
        [
            T.Resize(img_size),
            T.CenterCrop(img_size),
            T.ToTensor(),
            T.Normalize([0.5] * 3, [0.5] * 3),
        ]
    ))

# Set subset_flag = True to train on a small random subset: a quick way to check that the
# whole pipeline runs end to end before launching a full training run.
subset_flag = False
subset_size = 2000

if subset_flag:
    # Sample WITHOUT replacement: np.random.choice defaults to replace=True, which would put
    # duplicate images into the subset. Clamp the size so it never exceeds the dataset.
    n_pick = min(subset_size, len(train_dataset))
    indices = np.random.choice(len(train_dataset), n_pick, replace=False)
    subset_dataset = Subset(train_dataset, indices.tolist())
    print(f"Subset dataset size: {len(subset_dataset)}")

    train_loader = DataLoader(subset_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
else:
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

# Check the number of samples and batches.
# NOTE: with drop_last=True the final partial batch is skipped, so one epoch actually sees
# n_batches * batch_size images, which can be slightly fewer than n_samples.
n_samples = len(train_loader.dataset) # type: ignore
n_batches = len(train_loader)
print(f"Number of training samples: {n_samples}")
print(f"Number of batches: {n_batches}")

Number of training samples: 63565
Number of batches: 993


In [None]:
# Sanity check: fetch one batch and confirm the tensor shapes the loader produces.
imgs, labels = next(iter(train_loader))
print("Image batch shape: " + str(imgs.shape))
print("Label batch shape: " + str(labels.shape))

Image batch shape: torch.Size([64, 3, 64, 64])
Label batch shape: torch.Size([64])


In [None]:
# # Monkey patch for centered Plotly display in Jupyter
# from IPython.display import display as ipy_display, HTML
# import plotly.graph_objs as go

# # Keep a reference to the original display
# _original_display = ipy_display

# def display(obj, *args, **kwargs):
#     """Automatically center Plotly figures"""
#     if isinstance(obj, go.Figure):
#         html = f"<div style='display:flex; justify-content:center;'>{obj.to_html(include_plotlyjs='cdn')}</div>"
#         _original_display(HTML(html), *args, **kwargs)
#     else:
#         _original_display(obj, *args, **kwargs)

# # Monkey patch
# import IPython.display
# IPython.display.display = display

from IPython.display import display, HTML
import plotly.graph_objects as go

def show_centered(fig):
    """
    Render a Plotly figure horizontally centered in the notebook output.

    `full_html=False` makes Plotly emit only a <div> fragment (plus a CDN <script> tag for
    plotly.js). The default would emit a complete <html> document, which is then wrongly
    nested inside the centering <div>.
    """
    fragment = fig.to_html(include_plotlyjs='cdn', full_html=False)
    display(HTML(f"<div style='display:flex; justify-content:center;'>{fragment}</div>"))


In [None]:
def denormalize(imgs):
    """Undo Normalize(mean=0.5, std=0.5): map values from [-1, 1] back to [0, 1]."""
    return (imgs + 1) / 2

def show_images(imgs, grid_size=5, n_images=24, nrow=6, title="Anime Face"):
    """
    Show the first `n_images` images of a normalized batch as a grid in a centered Plotly figure.

    :param imgs: image batch normalized to [-1, 1]; assumed (N, C, H, W) — TODO confirm for new callers.
                 May live on any device and may track gradients.
    :param grid_size: unused; kept only so existing calls keep working
    :param n_images: maximum number of images to show (default 24)
    :param nrow: images per grid row (default 6)
    :param title: figure title
    """
    # Slice first, then move to CPU without autograd history, so GPU / grad-tracking
    # tensors (e.g. raw generator output) can be passed directly.
    batch = imgs[:n_images].detach().cpu()
    grid_imgs = make_grid(denormalize(batch), nrow=nrow, padding=2).permute(1, 2, 0)  # CHW -> HWC

    fig = px.imshow(grid_imgs.numpy(), aspect='auto')
    fig.update_xaxes(showticklabels=False).update_yaxes(showticklabels=False)
    fig.update_layout(
        width=800,
        height=400,
        title=dict(
            text=title,
            x=0.5,
            xanchor='center',
            yanchor='top',
        ),
        coloraxis_showscale=False,
    )
    show_centered(fig)

show_images(imgs=imgs)  # preview the real training batch fetched earlier

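+ Subset
    + Train on a small slice of the data first to verify that the whole pipeline works
    + `random_split` simply returns `Subset` objects (e.g. two of them for a train/validation split)

+ Monkey patching

+ Downloading data with `kagglehub` and its local cache

+ weights_init
    + Matching layers via `classname.find` vs. `type(m) == ...` vs. `isinstance`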

In [None]:
# def weights_init(m):
#     classname = m.__class__.__name__
#     if classname.find('Conv') != -1 or classname.find('Linear') != -1:
#         nn.init.normal_(m.weight.data, 0.0, 0.02)
#     if classname.find('BatchNorm') != -1:
#         nn.init.normal_(m.weight.data, 1.0, 0.02)
#         nn.init.constant_(m.bias.data, 0)

# def weights_init(m):
#     if type(m) == nn.Linear:
#         nn.init.normal_(m.weight.data, 0.0, 0.02)
#         if m.bias is not None:
#             nn.init.constant_(m.bias.data, 0)

# def weights_init(m):
#     if isinstance(m, nn.Linear):
#         nn.init.normal_(m.weight.data, 0.0, 0.02)
#         if m.bias is not None:
#             nn.init.constant_(m.bias.data, 0)


In [None]:
def weights_init(m):
    """
    DCGAN-style initialization, meant to be used as `module.apply(weights_init)`.

    - BatchNorm2d: scale (weight) ~ N(1, 0.02), shift (bias) = 0.
    - Conv2d / ConvTranspose2d: weight ~ N(0, 0.02); any bias is left untouched.
    - Every other module type is left unchanged.
    """
    if isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight, mean=1.0, std=0.02)
        nn.init.zeros_(m.bias)
        return
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, mean=0.0, std=0.02)

def basic_G(out_channels):
    """Generator upsampling block: ConvTranspose2d(2*out -> out, H -> 2H) + BatchNorm2d + ReLU."""
    in_channels = out_channels * 2
    upsample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False)  # H -> 2*H
    layers = [upsample, nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True)]
    return nn.Sequential(*layers)

class Generator(nn.Module):
    """DCGAN generator: latent batch (N, z_dim) -> images (N, img_channels, 64, 64) in [-1, 1]."""

    def __init__(self):
        super().__init__()
        # Project the latent vector (viewed as z_dim x 1 x 1) to 512 x 4 x 4.
        stem = [
            nn.ConvTranspose2d(z_dim, 512, kernel_size=4, stride=1, padding=0, bias=False),  # 1 -> 4
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
        ]
        # Upsampling stages: channels 512 -> 256 -> 128 -> 64, spatial 4 -> 8 -> 16 -> 32.
        body = [basic_G(ch) for ch in (256, 128, 64)]
        # Last upsampling to image channels (32 -> 64); Tanh squashes pixels into [-1, 1].
        head = [
            nn.ConvTranspose2d(64, img_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        ]
        self.net = nn.Sequential(*stem, *body, *head)

    def forward(self, imgs):
        # `imgs` is the latent batch; reshape (N, z_dim) -> (N, z_dim, 1, 1) for the conv stack.
        latent = imgs.view(-1, z_dim, 1, 1)
        return self.net(latent)

def basic_D(in_channels):
    """Discriminator downsampling block: Conv2d(in -> 2*in, H -> H/2) + BatchNorm2d + LeakyReLU(0.2)."""
    out_channels = in_channels * 2
    downsample = nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False)  # H -> H/2
    norm = nn.BatchNorm2d(out_channels)
    act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
    return nn.Sequential(downsample, norm, act)

class Discriminator(nn.Module):
    """DCGAN discriminator: images (N, img_channels, 64, 64) -> (N, 1) probabilities of being real."""

    def __init__(self):
        super().__init__()
        # Input layer: 64 -> 32, without BatchNorm.
        layers = [
            nn.Conv2d(img_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Downsampling stages: channels 64 -> 128 -> 256 -> 512, spatial 32 -> 16 -> 8 -> 4.
        layers += [basic_D(ch) for ch in (64, 128, 256)]
        # Collapse 512 x 4 x 4 to a single logit per image, then map it to a probability.
        layers += [
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0),
            nn.Flatten(),
            nn.Sigmoid(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

In [None]:
def train_D(D, D_optimizer, criterion, real_imgs, fake_imgs):
    """
    Run one discriminator update.

    Real images are scored against target 1 and detached fake images against target 0.
    The summed loss is backpropagated and D_optimizer is stepped.

    Returns:
        (D_loss, real_score, fake_score) as Python floats, where the scores are the mean
        D output on the real and on the fake batch.
    """
    n_real = real_imgs.size(0)
    n_fake = fake_imgs.size(0)
    target_real = torch.ones(n_real, 1, device=device)
    target_fake = torch.zeros(n_fake, 1, device=device)

    real_preds = D(real_imgs)
    real_loss = criterion(real_preds, target_real)
    # detach(): this backward pass must not reach the generator.
    fake_preds = D(fake_imgs.detach())
    fake_loss = criterion(fake_preds, target_fake)

    total_loss = real_loss + fake_loss
    D_optimizer.zero_grad()
    total_loss.backward()
    D_optimizer.step()

    return total_loss.item(), real_preds.mean().item(), fake_preds.mean().item()

def train_G(G, D, G_optimizer, criterion, fake_imgs):
    """
    Run one generator update (non-saturating loss).

    `fake_imgs` are scored by D against the "real" target 1. Gradients flow through D
    into G, but only G_optimizer is stepped. D's parameters also receive gradients here;
    those are cleared by D_optimizer.zero_grad() in train_D before D's next update.

    Returns:
        The generator loss as a Python float.
    """
    preds = D(fake_imgs)
    target_real = torch.ones(fake_imgs.size(0), 1, device=device)
    G_loss = criterion(preds, target_real)

    G_optimizer.zero_grad()
    G_loss.backward()
    G_optimizer.step()

    return G_loss.item()

def fit(D, G, D_optimizer: torch.optim.Adam, G_optimizer: torch.optim.Adam, scheduler, criterion, n_epochs):
    """
    Train the GAN for `n_epochs` over the global `train_loader`, doing one D step and one G step per batch.

    `scheduler` drives D_optimizer's learning rate. After every scheduler step, G's learning rate
    is set to k times D's (k = G_lr / D_lr from the config), so both follow the same schedule in
    the same epoch.

    After the first epoch and every second epoch, the cell output is cleared, the epoch metrics
    are printed, and a grid of freshly generated samples is shown.

    Returns:
        dict of per-epoch lists:
          - "D_loss": epoch sum of D losses (real + fake terms) / number of real images seen
          - "G_loss": epoch sum of G losses / number of images seen
            (per-image values assuming criterion uses reduction='sum')
          - "D_real" / "D_fake": mean D(x) / D(G(z)), averaged over batches
          - "lr": D's learning rate used during the epoch
    """

    train_history = {
        "D_loss": [],
        "G_loss": [],
        "D_real": [],
        "D_fake": [],
        "lr": [],
    }

    for epoch in trange(n_epochs, desc="Epoch", leave=False):
        D_loss, G_loss, real_score, fake_score = 0.0, 0.0, 0.0, 0.0
        seen_samples, seen_batches = 0, 0

        for real_imgs, _ in train_loader:
            real_imgs = real_imgs.to(device)
            cur_batch = real_imgs.size(0)

            # One fake per real image (equals batch_size when drop_last=True). The noise is drawn
            # on the CPU generator and then moved, matching the seeded CPU RNG stream.
            z = torch.randn(cur_batch, z_dim).to(device)
            fake_imgs = G(z)

            _D_loss, _real_score, _fake_score = train_D(D, D_optimizer, criterion, real_imgs, fake_imgs)
            D_loss += _D_loss
            real_score += _real_score
            fake_score += _fake_score

            G_loss += train_G(G, D, G_optimizer, criterion, fake_imgs)

            seen_samples += cur_batch
            seen_batches += 1

        # Normalize by what was actually iterated: drop_last may skip a partial batch, so the
        # dataset length overcounts. max(..., 1) avoids ZeroDivisionError on an empty loader.
        n_seen = max(seen_samples, 1)
        n_steps = max(seen_batches, 1)
        train_history["D_loss"].append(D_loss / n_seen)
        train_history["G_loss"].append(G_loss / n_seen)
        train_history["D_real"].append(real_score / n_steps)
        train_history["D_fake"].append(fake_score / n_steps)

        # D's learning rate that was in effect during this epoch.
        epoch_lr = scheduler.get_last_lr()[0]
        train_history["lr"].append(epoch_lr)

        if epoch == 0 or (epoch + 1) % 2 == 0:
            clear_output(wait=True)

            tqdm.write(
                f"Epoch [{epoch+1}/{n_epochs}] lr: {train_history['lr'][-1]:.6f} "
                f"D_loss: {train_history['D_loss'][-1]:.4f} G_loss: {train_history['G_loss'][-1]:.4f} "
                f"D_real: {train_history['D_real'][-1]:.4f} D_fake: {train_history['D_fake'][-1]:.4f} "
            )

            # Preview only: no autograd graph is needed.
            with torch.no_grad():
                z = torch.randn(batch_size, z_dim).to(device)
                preview_imgs = G(z).detach().cpu()
            show_images(preview_imgs)

        scheduler.step()
        # Sync G AFTER stepping, so next epoch G uses k * (D's new lr). Setting it from the
        # pre-step value made G lag one epoch behind D once the decay began.
        next_lr = scheduler.get_last_lr()[0]
        for group in G_optimizer.param_groups:
            group["lr"] = next_lr * k

    return train_history

In [None]:
# Build D then G before re-initializing either. This order fixes how the torch RNG stream
# is consumed, so the seeded run stays reproducible.
D = Discriminator().to(device)
G = Generator().to(device)
D.apply(weights_init)
G.apply(weights_init)


def v(epoch):
    """LambdaLR factor: 1 for the first 5 epochs, then decays by a factor of 0.72 per epoch."""
    return 1 if epoch < 5 else 0.72 ** (epoch - 5)


optim_D = torch.optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))
# G starts with a k-times larger learning rate than D (k = G_lr / D_lr).
optim_G = torch.optim.Adam(G.parameters(), lr=k * lr, betas=(0.5, 0.999))

# Only D's optimizer is scheduled; `fit` copies the scheduled lr (times k) onto G's optimizer.
scheduler = lr_scheduler.LambdaLR(optim_D, lr_lambda=v)
# Sum-reduced BCE on the discriminator's sigmoid outputs.
criterion = nn.BCELoss(reduction="sum")

train_history = fit(D, G, D_optimizer=optim_D, G_optimizer=optim_G, scheduler=scheduler, criterion=criterion, n_epochs=n_epochs)

Epoch:  95%|█████████▌| 19/20 [37:11<01:51, 111.31s/it]

Epoch [20/20] lr: 0.000001 D_loss: 1.3979 G_loss: 0.6881 D_real: 0.4970 D_fake: 0.5027 




In [None]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go

# 1-based epoch numbers for the x axis (each history list holds one value per epoch).
epochs = list(range(1, len(train_history["D_loss"]) + 1))

fig = make_subplots(rows=1, cols=3, subplot_titles=("Loss", "D(x) & D(G(z))", "Learning Rate"))

fig.add_traces([
    go.Scatter(x=epochs, y=train_history["D_loss"], mode='lines', name='D_loss'),
    go.Scatter(x=epochs, y=train_history["G_loss"], mode='lines', name='G_loss'),
], rows=1, cols=1)

fig.add_traces([
    go.Scatter(x=epochs, y=train_history["D_real"], mode='lines', name='D(x)'),
    go.Scatter(x=epochs, y=train_history["D_fake"], mode='lines', name='D(G(z))'),
], rows=1, cols=2)

# 0.5 is the theoretical equilibrium, where D can no longer tell real images from fakes.
fig.add_hline(
    y=0.5,
    line=dict(color="black", dash="dash"),
    row=1, # type: ignore
    col=2, # type: ignore
)

fig.add_trace(
    go.Scatter(x=epochs, y=train_history["lr"], mode='lines', name='Learning Rate'),
    row=1, col=3
)

fig.update_xaxes(title_text="Epoch")
# After warm-up the schedule decays geometrically, so a log axis shows it as a straight line.
fig.update_yaxes(type="log", row=1, col=3)
fig.update_layout(height=400, width=1200, title_text="GAN Training History")
show_centered(fig)


![train_history](https://cdn.jsdelivr.net/gh/KuiMian/NoteImage@master/2025/10/upgit_20251004_train_history.png)