In [1]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.utils as utils
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import Compose, Normalize, Resize, ToTensor

In [2]:
# --- Run configuration -------------------------------------------------------
OUT_PATH = 'output'  # directory that receives sample images and checkpoints
IMAGE_SIZE = 64    # image side length; the 28*28 originals are resized to 64*64
BATCH_SIZE = 32
IMAGE_CHANNEL = 1  # number of channels in the output image
Z_DIM = 100        # channel count of the noise tensor fed to the generator
G_HIDDEN = 64      # base channel width of the generator
X_DIM = 64         # NOTE(review): not referenced anywhere in the visible cells
D_HIDDEN = 64      # base channel width of the discriminator
EPOCH_NUM = 10
REAL_LABEL = 1.0   # target value used for real samples
FAKE_LABEL = 0.0   # target value used for generated samples
lr = 2e-4

# The seed is drawn at random on every run, so it is reported here: without
# knowing which value was used, a run (including a failing one) cannot be
# reproduced.
seed = np.random.randint(1, 10000)
print(f'Random seed: {seed}')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

In [3]:
# Statistics handed to Normalize; a single value each because the images have
# one channel (IMAGE_CHANNEL).
mean = torch.tensor(0.5)
std = torch.tensor(0.5)

# Pre-processing applied to every image, in order: resize to IMAGE_SIZE,
# convert to a tensor, then normalise with the statistics above.
compose = Compose([
    Resize(IMAGE_SIZE, antialias=True),
    ToTensor(),
    Normalize(mean, std),
])

# MNIST training split, downloaded into ./data when it is not already there.
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    transform=compose,
    download=True,
)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

In [4]:
def weights_init(m):
    """Re-initialise one module's parameters; intended for ``nn.Module.apply``.

    Modules are selected by class name:

    * name contains ``'Conv'``      -> weight drawn from N(0.0, 0.02)
    * name contains ``'BatchNorm'`` -> weight drawn from N(1.0, 0.02), bias set to 0

    Any other module is left untouched. The original author's rationale: the
    default initialisation is uniform, and a normal initialisation is used
    instead to speed up convergence.

    A module that matches by name but has no tensor ``weight``/``bias`` (for
    example ``BatchNorm2d(affine=False)``, or a container class whose name
    happens to contain "Conv") is skipped instead of raising.

    Args:
        m: the module to initialise in place.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    has_weight = isinstance(weight, torch.Tensor)
    has_bias = isinstance(bias, torch.Tensor)

    # nn.init.* runs under no_grad and draws the same random numbers as the
    # in-place tensor methods, without going through the `.data` attribute.
    if 'Conv' in classname:
        if has_weight:
            nn.init.normal_(weight, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        if has_weight:
            nn.init.normal_(weight, 1.0, 0.02)
        if has_bias:
            nn.init.zeros_(bias)

In [5]:
class Generator(nn.Module):
    """Generator network built from transposed convolutions.

    With the default sizes it turns a z_dim@1*1 noise tensor into a 1@64*64
    image.

    Args:
        z_dim: number of channels of the input noise tensor.
        g_hidden: base channel width; the hidden layers use 8x, 4x, 2x and 1x
            this value.
        image_channel: number of channels of the produced image.
    """
    def __init__(self, z_dim=Z_DIM, g_hidden=G_HIDDEN,
                 image_channel=IMAGE_CHANNEL) -> None:
        super().__init__()
        self.z_dim = z_dim
        self.g_hidden = g_hidden
        self.image_channel = image_channel

        # Channel widths of the four hidden stages, widest first.
        widths = [g_hidden * 8, g_hidden * 4, g_hidden * 2, g_hidden]
        # Settings shared by every layer after the first one.
        upsample = dict(kernel_size=4, stride=2, padding=1, bias=False)

        # Layers are registered in the same order as they are applied, which
        # keeps the parameter order (and state_dict keys) stable.
        self.cnn1 = nn.ConvTranspose2d(z_dim, widths[0], kernel_size=4,
                                       stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(widths[0])
        self.cnn2 = nn.ConvTranspose2d(widths[0], widths[1], **upsample)
        self.bn2 = nn.BatchNorm2d(widths[1])
        self.cnn3 = nn.ConvTranspose2d(widths[1], widths[2], **upsample)
        self.bn3 = nn.BatchNorm2d(widths[2])
        self.cnn4 = nn.ConvTranspose2d(widths[2], widths[3], **upsample)
        self.bn4 = nn.BatchNorm2d(widths[3])
        self.cnn5 = nn.ConvTranspose2d(widths[3], image_channel, **upsample)

    def forward(self, X):
        """Map a noise batch to an image batch with values squashed by tanh."""
        # Shapes below are for the default sizes (channels@height*width).
        hidden_stages = (
            (self.cnn1, self.bn1),  # 100@1*1   -> 512@4*4
            (self.cnn2, self.bn2),  # 512@4*4   -> 256@8*8
            (self.cnn3, self.bn3),  # 256@8*8   -> 128@16*16
            (self.cnn4, self.bn4),  # 128@16*16 -> 64@32*32
        )
        for deconv, norm in hidden_stages:
            X = F.relu(norm(deconv(X)), inplace=True)
        # 64@32*32 -> 1@64*64; no batch norm on the output layer.
        return F.tanh(self.cnn5(X))

In [6]:
class Discriminator(nn.Module):
    """Discriminator network: a convolution-only classifier (no linear layer).

    With the default sizes, strided convolutions reduce a 1@64*64 image to a
    single 1@1*1 value, which is passed through a sigmoid.

    Args:
        d_hidden: base channel width; the hidden layers use 1x, 2x, 4x and 8x
            this value.
        image_channel: number of channels of the input image.
    """
    def __init__(self, d_hidden=D_HIDDEN, image_channel=IMAGE_CHANNEL) -> None:
        super().__init__()
        self.image_channel = image_channel
        self.d_hidden = d_hidden

        # Channel widths of the four hidden stages, narrowest first.
        widths = [d_hidden, d_hidden * 2, d_hidden * 4, d_hidden * 8]
        # Settings shared by every layer except the last one.
        downsample = dict(kernel_size=4, stride=2, padding=1, bias=False)

        # The first convolution has no batch norm; registration order matches
        # application order so state_dict keys stay the same.
        self.cnn1 = nn.Conv2d(image_channel, widths[0], **downsample)
        self.cnn2 = nn.Conv2d(widths[0], widths[1], **downsample)
        self.bn2 = nn.BatchNorm2d(widths[1])
        self.cnn3 = nn.Conv2d(widths[1], widths[2], **downsample)
        self.bn3 = nn.BatchNorm2d(widths[2])
        self.cnn4 = nn.Conv2d(widths[2], widths[3], **downsample)
        self.bn4 = nn.BatchNorm2d(widths[3])
        self.cnn5 = nn.Conv2d(widths[3], 1, kernel_size=4, stride=1,
                              padding=0, bias=False)

    def forward(self, X):
        """Return one sigmoid score per input image as a 1-D tensor."""
        # Shapes below are for the default sizes (channels@height*width).
        # 1@64*64 -> 64@32*32
        X = F.leaky_relu(self.cnn1(X), 0.2, inplace=True)
        normalised_stages = (
            (self.cnn2, self.bn2),  # 64@32*32  -> 128@16*16
            (self.cnn3, self.bn3),  # 128@16*16 -> 256@8*8
            (self.cnn4, self.bn4),  # 256@8*8   -> 512@4*4
        )
        for conv, norm in normalised_stages:
            X = F.leaky_relu(norm(conv(X)), 0.2, inplace=True)
        # 512@4*4 -> 1@1*1, then squash to a sigmoid score.
        scores = F.sigmoid(self.cnn5(X))
        # Flatten the trailing singleton dimensions away.
        return scores.view(-1, 1).squeeze(1)

In [7]:
# Build both networks on the target device and re-initialise their weights.
# nn.Module.apply returns the module itself, so the calls can be chained; the
# generator is still fully constructed and initialised before the discriminator.
netG = Generator().to(device).apply(weights_init)
netD = Discriminator().to(device).apply(weights_init)

# Binary cross-entropy on the discriminator's sigmoid outputs.
loss_fn = nn.BCELoss()

# One fixed noise batch, reused for every saved sample grid during training.
viz_noise = torch.randn(BATCH_SIZE, Z_DIM, 1, 1, device=device)

# Both optimizers share the same Adam hyper-parameters.
adam_betas = (0.5, 0.999)
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=adam_betas)
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=adam_betas)

# Create the output directory up front: save_image and torch.save below write
# into OUT_PATH and fail if it does not exist (e.g. on a fresh checkout).
os.makedirs(OUT_PATH, exist_ok=True)

for epoch in range(EPOCH_NUM):
    for i, (x_real, _) in enumerate(train_loader):
        x_real = x_real.to(device)
        # Size the targets from the actual batch so a short final batch works.
        batch_size = x_real.size(0)
        real_label = torch.full((batch_size,), REAL_LABEL, device=device)
        fake_label = torch.full((batch_size,), FAKE_LABEL, device=device)

        # --- Update D, part 1: real samples ---
        netD.zero_grad()
        y_real = netD(x_real)
        loss_D_real = loss_fn(y_real, real_label)
        loss_D_real.backward()

        # --- Update D, part 2: generated (fake) samples ---
        z_noise = torch.randn(batch_size, Z_DIM, 1, 1, device=device)
        x_fake = netG(z_noise)
        # detach() keeps this backward pass from reaching the generator.
        y_fake = netD(x_fake.detach())
        loss_D_fake = loss_fn(y_fake, fake_label)
        loss_D_fake.backward()
        optimizerD.step()

        # --- Update G: make D label the fakes as real ---
        netG.zero_grad()
        # The same x_fake is reused (not detached this time) so gradients flow
        # back into G. This backward pass also leaves gradients on D; they are
        # cleared by netD.zero_grad() at the start of the next iteration.
        y_fake_r = netD(x_fake)
        loss_G = loss_fn(y_fake_r, real_label)
        loss_G.backward()
        optimizerG.step()

        if i % 100 == 0:
            print(f'Epoch {epoch} [{i}/{len(train_loader)}] loss_D_real: {loss_D_real.item():.4f} loss_D_fake: {loss_D_fake.item():.4f} loss_G: {loss_G.item():.4f}')

            # File names only carry the epoch, so each save within an epoch
            # overwrites the previous one and the latest grid is what remains.
            utils.save_image(x_real, os.path.join(OUT_PATH, f'real_samples_{epoch}.png'), normalize=True)
            with torch.no_grad():
                viz_sample = netG(viz_noise)
                utils.save_image(viz_sample, os.path.join(OUT_PATH, f'fake_samples_{epoch}.png'), normalize=True)
    # Checkpoint both networks at the end of every epoch.
    torch.save(netG.state_dict(), os.path.join(OUT_PATH, f'netG_{epoch}.pth'))
    torch.save(netD.state_dict(), os.path.join(OUT_PATH, f'netD_{epoch}.pth'))


Epoch 0 [0/1875] loss_D_real: 0.2624 loss_D_fake: 1.5612 loss_G: 7.2237
Epoch 0 [100/1875] loss_D_real: 0.0124 loss_D_fake: 0.0000 loss_G: 16.6706
Epoch 0 [200/1875] loss_D_real: 0.0264 loss_D_fake: 0.5759 loss_G: 4.2032
Epoch 0 [300/1875] loss_D_real: 0.1879 loss_D_fake: 0.0709 loss_G: 4.7026
Epoch 0 [400/1875] loss_D_real: 0.1312 loss_D_fake: 0.1293 loss_G: 3.3828
Epoch 0 [500/1875] loss_D_real: 0.1946 loss_D_fake: 0.0253 loss_G: 3.9092
Epoch 0 [600/1875] loss_D_real: 0.0519 loss_D_fake: 0.0690 loss_G: 4.2776
Epoch 0 [700/1875] loss_D_real: 0.3297 loss_D_fake: 0.0522 loss_G: 3.0497
Epoch 0 [800/1875] loss_D_real: 0.1435 loss_D_fake: 0.0455 loss_G: 3.4699
Epoch 0 [900/1875] loss_D_real: 0.0585 loss_D_fake: 1.1727 loss_G: 2.5565
Epoch 0 [1000/1875] loss_D_real: 0.2620 loss_D_fake: 1.0710 loss_G: 6.6138
Epoch 0 [1100/1875] loss_D_real: 0.3057 loss_D_fake: 0.0555 loss_G: 2.3494
Epoch 0 [1200/1875] loss_D_real: 0.1161 loss_D_fake: 0.5404 loss_G: 3.2750
Epoch 0 [1300/1875] loss_D_real: 0.0

RuntimeError: cuDNN error: CUDNN_STATUS_EXECUTION_FAILED