## 读取数据集

In [None]:
from torchvision.datasets import CIFAR10 #下载图片数据集
from torch.utils.data import DataLoader #读取批次
import torchvision.transforms as transforms #张量转换
from torchvision.utils import save_image #保存图片

In [None]:
# Load CIFAR-10 from ./data (downloaded on first use) with every image converted to a tensor
to_tensor = transforms.ToTensor()
dataset = CIFAR10(root='./data', download=True, transform=to_tensor)
# Serve the data in shuffled mini-batches of 64 images.
# With 50000 images this gives 781 full batches plus one partial batch = 782 batches.
dataloader = DataLoader(dataset, shuffle=True, batch_size=64)

In [None]:
# Walk over the dataset once, report the size of every batch and
# dump every 100th batch to disk as an image grid for visual inspection.
for batch_idx, (real_images, _) in enumerate(dataloader):
    batch_size = real_images.size(0)  # number of images in this batch
    print('#{} has {} images.'.format(batch_idx, batch_size))
    if batch_idx % 100 != 0:
        continue  # only batches 0, 100, 200, ... are saved
    path = './data/CIFAR10_shuffled_batch{:03d}.png'.format(batch_idx)
    # nrow=8 puts 8 images on each row of the grid; pixel values are written unscaled
    save_image(real_images, path, nrow=8, normalize=False)

In [None]:
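# Notes on the for-loop above:
# batch_idx is the index of the current batch; batch_size is the number of images in that batch (64, except for the last batch, which only holds the remaining 16). real_images.size(0) returns the same value as real_images.shape[0].
# save_image writes an image file. The if-statement saves the current batch once every 100 batches; with 782 batches in total (indices 0-781) this produces 8 files (batches 0, 100, ..., 700). Each file tiles the 64 images of the batch; nrow=8 places 8 images on each row, giving an 8*8 = 64 grid. normalize controls whether the values of the input tensor are rescaled into the range [0, 1].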

## 设置网络

In [None]:
import torch.nn as nn

In [None]:
# Build the discriminator: a stack of strided convolutions that maps an image
# to a single raw logit (no sigmoid here; the loss is expected to apply it).
n_d_feature = 64  # base number of feature maps; deeper layers use 2x and 4x this
n_channel = 3     # number of channels of the input images

# kernel 4 / stride 2 / padding 1 halves an even spatial size (e.g. 32 -> 16 -> 8 -> 4)
halve = dict(kernel_size=4, stride=2, padding=1)

d_layers = [
    # first stage keeps its bias because no batch norm follows it
    nn.Conv2d(n_channel, n_d_feature, **halve),
    nn.LeakyReLU(0.2),
    # stages followed by BatchNorm2d drop the conv bias
    nn.Conv2d(n_d_feature, 2 * n_d_feature, bias=False, **halve),
    nn.BatchNorm2d(2 * n_d_feature),
    nn.LeakyReLU(0.2),
    nn.Conv2d(2 * n_d_feature, 4 * n_d_feature, bias=False, **halve),
    nn.BatchNorm2d(4 * n_d_feature),
    nn.LeakyReLU(0.2),
    # unpadded 4x4 convolution: collapses a 4x4 feature map into one value per image
    nn.Conv2d(4 * n_d_feature, 1, kernel_size=4),
]
dnet = nn.Sequential(*d_layers)
print(dnet)

In [None]:
# Build the generator: transposed convolutions that grow a latent vector of
# shape (latent_size, 1, 1) into an n_channel-channel image with values in (0, 1).
latent_size = 64  # number of channels of the input noise vector
n_g_feature = 64  # base number of feature maps; earlier layers use 4x and 2x this

# kernel 4 / stride 2 / padding 1 doubles the spatial size (4 -> 8 -> 16 -> 32)
double = dict(kernel_size=4, stride=2, padding=1)

g_layers = [
    # unpadded, stride-1 transposed conv: expands the 1x1 input to 4x4
    nn.ConvTranspose2d(latent_size, 4 * n_g_feature, kernel_size=4, bias=False),
    nn.BatchNorm2d(4 * n_g_feature),
    nn.ReLU(),
    nn.ConvTranspose2d(4 * n_g_feature, 2 * n_g_feature, bias=False, **double),
    nn.BatchNorm2d(2 * n_g_feature),
    nn.ReLU(),
    nn.ConvTranspose2d(2 * n_g_feature, n_g_feature, bias=False, **double),
    nn.BatchNorm2d(n_g_feature),
    nn.ReLU(),
    # last stage has no batch norm, so it keeps its bias
    nn.ConvTranspose2d(n_g_feature, n_channel, **double),
    # squash the output into (0, 1)
    nn.Sigmoid(),
]
gnet = nn.Sequential(*g_layers)
print(gnet)

In [None]:
# Network weight initialisation
import torch.nn.init as init
def weights_init(m):
    """Initialise the parameters of a single module in place.

    Intended to be passed to ``nn.Module.apply``, which calls it once for
    every sub-module.

    * ``nn.Conv2d`` / ``nn.ConvTranspose2d``: weight drawn with Xavier-normal
      initialisation; the bias (if any) is left untouched.
    * ``nn.BatchNorm2d``: weight drawn from N(1.0, 0.02) and bias set to 0.
      Layers created with ``affine=False`` have no weight/bias and are skipped
      instead of raising.
    * Any other module type is left unchanged.

    Args:
        m: the module to initialise.
    """
    # Exact type match: subclasses of these layers are not re-initialised.
    if type(m) in (nn.ConvTranspose2d, nn.Conv2d):
        init.xavier_normal_(m.weight)
    elif type(m) == nn.BatchNorm2d:
        # weight and bias are None when the layer was built with affine=False
        if m.weight is not None:
            init.normal_(m.weight, 1.0, 0.02)
        if m.bias is not None:
            init.constant_(m.bias, 0)
        
# Run weights_init over both networks (Module.apply visits every sub-module and the network itself).
# The last expression is left bare, so the notebook shows the discriminator's repr as cell output.
gnet.apply(weights_init)
dnet.apply(weights_init)

## 训练网络

In [None]:
import torch
import torch.optim

In [None]:
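# Optional: load previously trained parameters, mapping the saved tensors onto the CPU (uncomment the lines below to resume from the saved checkpoints)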
#checkpoint_d = torch.load('discriminator.pth', map_location='cpu')
#checkpoint_g = torch.load('generator.pth', map_location='cpu')
#dnet.load_state_dict(checkpoint_d)
#gnet.load_state_dict(checkpoint_g)

In [None]:
# Loss: binary cross-entropy computed directly on raw logits
# (the discriminator has no final sigmoid layer).
criterion = nn.BCEWithLogitsLoss()

# Optimizers: one Adam instance per network, both with the same hyper-parameters.
adam_settings = dict(lr=0.0002, betas=(0.5, 0.999))
goptimizer = torch.optim.Adam(gnet.parameters(), **adam_settings)
doptimizer = torch.optim.Adam(dnet.parameters(), **adam_settings)

In [None]:
# Fixed batch of latent vectors fed to the generator; reusing the same noise
# lets the saved preview images be compared across training steps.
batch_size = 64
noise_shape = (batch_size, latent_size, 1, 1)
fixed_noises = torch.randn(noise_shape)

In [None]:
# Training loop: alternate one discriminator (D) update and one generator (G)
# update per mini-batch, log the losses, and periodically save preview images.
epoch_num = 5  # number of passes over the whole dataset
for epoch in range(epoch_num):
    for batch_idx, data in enumerate(dataloader):
        real_images, _ = data  # class labels are not used
        batch_size = real_images.shape[0]  # the last batch may be smaller

        # ---- Train the discriminator D ----
        # Real images, target label 1
        labels = torch.ones(batch_size)
        preds = dnet(real_images)
        outputs = preds.reshape(-1)  # flatten so the logits line up with the 1-D labels
        dloss_real = criterion(outputs, labels)
        # mean predicted probability of "real" on real images (logging only)
        dmean_real = outputs.sigmoid().mean()

        # Generated images, target label 0
        noises = torch.randn(batch_size, latent_size, 1, 1)
        fake_images = gnet(noises)
        labels = torch.zeros(batch_size)
        # detach so the D loss does not back-propagate into the generator
        fake = fake_images.detach()
        preds = dnet(fake)
        outputs = preds.reshape(-1)
        dloss_fake = criterion(outputs, labels)
        # mean predicted probability of "real" on fakes, before the D update (logging only)
        dmean_fake = outputs.sigmoid().mean()

        dloss = dloss_real + dloss_fake  # total D loss is the sum of both parts
        dnet.zero_grad()
        dloss.backward()
        doptimizer.step()

        # ---- Train the generator G ----
        # G wants D to classify its images as real, so the target label is 1
        labels = torch.ones(batch_size)
        preds = dnet(fake_images)  # not detached: gradients must reach G
        outputs = preds.reshape(-1)
        gloss = criterion(outputs, labels)
        # mean predicted probability of "real" on fakes, after the D update (logging only)
        gmean_fake = outputs.sigmoid().mean()

        gnet.zero_grad()
        gloss.backward()
        goptimizer.step()

        # Log this step. The message labels D = discriminator (鉴别器) and
        # G = generator (生成器); .item() turns the 0-dim tensors into floats.
        print('[{}/{}]'.format(epoch, epoch_num) + '[{}/{}]'.format(batch_idx, len(dataloader)) +
             '鉴别器D损失:{:g} 生成器G损失：{:g}'.format(dloss.item(), gloss.item()) +
             '真数据判真比例：{:g} 假数据判真比例：{:g}/{:g}'.format(
                 dmean_real.item(), dmean_fake.item(), gmean_fake.item()))
        if batch_idx % 100 == 0:
            # Preview from the fixed noise; no gradients are needed here,
            # so skip building the autograd graph.
            with torch.no_grad():
                fake = gnet(fixed_noises)
            path = './data/images_epoch{:02d}_batch{:03d}.png'.format(epoch, batch_idx)
            save_image(fake, path, normalize=False)

## 保存模型

In [None]:
# Persist the learned parameters (state_dict only) of both networks
for net, filename in ((gnet, './generator.pth'), (dnet, './discriminator.pth')):
    torch.save(net.state_dict(), filename)