In [1]:
import argparse
import os
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets
import torch.nn as nn
import torch

# Create every directory later cells write to (exist_ok makes this a no-op on re-runs).
for _dir in (
    "./images/gan/",     # sample grids saved during training
    "./save/gan/",       # model weights saved when training finishes
    "./datasets/mnist",  # intended dataset download location
    # NOTE(review): the MNIST cell downloads with root='./datasets/', so this
    # directory appears to be unused — confirm.
):
    os.makedirs(_dir, exist_ok=True)

ModuleNotFoundError: No module named 'numpy'

In [None]:
# ---- Configuration ----
channels = 1
img_size = 28
batch_size = 64
latent_dim = 100  # length of the noise vector fed to the generator
seed = 42  # fixed seed so weight init, shuffling and noise are reproducible

torch.manual_seed(seed)

# Per-sample image shape (1, 28, 28) and its flattened pixel count (784).
img_shape = (channels, img_size, img_size)
# Cast to a plain Python int: np.prod returns a numpy scalar, and this value is
# used as a layer size below.
img_area = int(np.prod(img_shape))

cuda = torch.cuda.is_available()
print('cuda:', cuda)

# ---- MNIST training set ----
# ToTensor yields values in [0, 1]; Normalize([0.5], [0.5]) computes (x - 0.5) / 0.5,
# mapping them to [-1, 1], the same range as the generator's Tanh output.
mnist = datasets.MNIST(
    root='./datasets/', train=True, download=True, transform=transforms.Compose(
        [transforms.Resize(img_size),
         transforms.ToTensor(),
         transforms.Normalize([0.5], [0.5])]
    ),
)

dataloader = DataLoader(mnist, batch_size=batch_size, shuffle=True)


In [None]:
# Discriminator: scores a batch of images with the probability that each is real.
class Discriminator(nn.Module):
    """MLP discriminator mapping a batch of images to per-sample realness scores.

    Args:
        in_features: number of values in one flattened input image. When omitted,
            falls back to the notebook-level ``img_area`` (784 for 1x28x28 MNIST),
            so ``Discriminator()`` behaves exactly as before.
    """

    def __init__(self, in_features=None):
        super(Discriminator, self).__init__()
        if in_features is None:
            # Keep the original zero-argument construction working.
            in_features = img_area
        self.model = nn.Sequential(
            nn.Linear(int(in_features), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),  # squash to [0, 1] so the score can be fed to nn.BCELoss
        )

    def forward(self, img):
        """Flatten ``img`` per sample and return a (batch, 1) tensor of scores in [0, 1]."""
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)
        return validity


# Generator: maps latent noise vectors to images.
class Generator(nn.Module):
    """MLP generator turning noise of shape (batch, z_dim) into images.

    Args:
        z_dim: length of each latent noise vector. When omitted, falls back to the
            notebook-level ``latent_dim``.
        out_shape: per-sample output shape, e.g. (channels, height, width). When
            omitted, falls back to the notebook-level ``img_shape``.

    ``Generator()`` with no arguments behaves exactly as before.
    """

    def __init__(self, z_dim=None, out_shape=None):
        super(Generator, self).__init__()
        # Fall back to the notebook globals so zero-argument construction still works.
        if z_dim is None:
            z_dim = latent_dim
        if out_shape is None:
            out_shape = img_shape
        # Stored on the instance so forward() does not depend on a global.
        self.img_shape = tuple(int(d) for d in out_shape)
        out_area = int(np.prod(self.img_shape))

        def block(in_feat, out_feat, normalize=True):
            """Return [Linear, (optional) BatchNorm1d, LeakyReLU] as a list of layers."""
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # NOTE(review): BatchNorm1d's 2nd positional argument is `eps`, so the
                # original call set eps=0.8 (not momentum). Spelled out explicitly to
                # preserve the trained model's behavior — confirm this was intended.
                layers.append(nn.BatchNorm1d(out_feat, eps=0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(z_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, out_area),
            nn.Tanh(),  # map every output value into [-1, 1]
        )

    def forward(self, z):
        """Generate a batch of images shaped (batch, *self.img_shape) from noise ``z``."""
        imgs = self.model(z)
        imgs = imgs.view(imgs.size(0), *self.img_shape)
        return imgs


In [None]:
# Build the generator and discriminator.
generator = Generator()
discriminator = Discriminator()

# Loss: binary cross-entropy, because the discriminator performs real-vs-fake
# binary classification on Sigmoid outputs.
criterion = torch.nn.BCELoss()

# Run on the GPU when one is available. The models are moved BEFORE the optimizers
# are constructed, as the torch.optim documentation recommends.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator = generator.to(device)
discriminator = discriminator.to(device)
criterion = criterion.to(device)

# Adam optimizers for both networks with learning rate 0.0002.
# betas: coefficients for the running averages of the gradient and of its square.
lr = 0.0002
b1 = 0.5
b2 = 0.999
optimizer_G = torch.optim.Adam(
    generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(
    discriminator.parameters(), lr=lr, betas=(b1, b2))

In [None]:
# ----------
# Training
# ----------
# 进行多个epoch的训练
from torch.autograd import Variable
from torchvision.utils import save_image

n_epochs = 50
sample_interval = 500  # save a grid of generated samples every this many batches

# Use whichever device the models were placed on in the setup cell instead of
# hard-coding .cuda(), so training also runs on CPU-only machines.
device = next(generator.parameters()).device

for epoch in range(n_epochs):
    for i, (imgs, _) in enumerate(dataloader):
        # The final batch can be smaller than batch_size (drop_last is not set),
        # so every size below is taken from the actual batch.
        batch = imgs.size(0)
        # Flatten to (batch, pixels); the discriminator also flattens internally.
        real_img = imgs.view(batch, -1).to(device)
        real_label = torch.ones(batch, 1, device=device)   # real images are labelled 1
        fake_label = torch.zeros(batch, 1, device=device)  # generated images are labelled 0

        # ---- Train the discriminator ----
        # Loss on real images.
        real_out = discriminator(real_img)
        loss_real_D = criterion(real_out, real_label)
        real_scores = real_out
        # Loss on generated images. detach() cuts the graph so this loss does not
        # backpropagate into the generator, which is not updated in this step.
        z = torch.randn(batch, latent_dim, device=device)
        fake_img = generator(z).detach()
        fake_out = discriminator(fake_img)
        loss_fake_D = criterion(fake_out, fake_label)
        fake_scores = fake_out

        loss_D = loss_real_D + loss_fake_D
        optimizer_D.zero_grad()
        loss_D.backward()
        optimizer_D.step()

        # ---- Train the generator ----
        # The generator's loss is low when the discriminator labels its output as real.
        z = torch.randn(batch, latent_dim, device=device)
        fake_img = generator(z)
        output = discriminator(fake_img)
        loss_G = criterion(output, real_label)
        optimizer_G.zero_grad()
        loss_G.backward()
        optimizer_G.step()

        # Log progress every 100 batches.
        if (i + 1) % 100 == 0:
            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] [D real: %f] [D fake: %f]"
                % (epoch, n_epochs, i, len(dataloader), loss_D.item(), loss_G.item(),
                   real_scores.mean().item(), fake_scores.mean().item())
            )
        # Periodically save a 5x5 grid of the current generator samples.
        batches_done = epoch * len(dataloader) + i
        if batches_done % sample_interval == 0:
            save_image(fake_img.detach()[:25], "./images/gan/%d.png" %
                       batches_done, nrow=5, normalize=True)

# Save the trained weights.
torch.save(generator.state_dict(), './save/gan/generator.pth')
torch.save(discriminator.state_dict(), './save/gan/discriminator.pth')


[Epoch 0/50] [Batch 99/938] [D loss: 1.314692] [G loss: 0.664975] [D real: 0.587226] [D fake: 0.515453]
[Epoch 0/50] [Batch 199/938] [D loss: 1.085302] [G loss: 1.073382] [D real: 0.659144] [D fake: 0.483829]
[Epoch 0/50] [Batch 299/938] [D loss: 1.169366] [G loss: 0.688896] [D real: 0.482463] [D fake: 0.332062]
[Epoch 0/50] [Batch 399/938] [D loss: 1.033351] [G loss: 1.379799] [D real: 0.632202] [D fake: 0.420711]
[Epoch 0/50] [Batch 499/938] [D loss: 1.073791] [G loss: 1.228435] [D real: 0.758415] [D fake: 0.538224]
[Epoch 0/50] [Batch 599/938] [D loss: 1.095375] [G loss: 0.857327] [D real: 0.535340] [D fake: 0.319981]
[Epoch 0/50] [Batch 699/938] [D loss: 0.921353] [G loss: 1.397030] [D real: 0.681502] [D fake: 0.387700]
[Epoch 0/50] [Batch 799/938] [D loss: 1.012501] [G loss: 0.808668] [D real: 0.480912] [D fake: 0.199825]
[Epoch 0/50] [Batch 899/938] [D loss: 1.112474] [G loss: 1.684664] [D real: 0.821570] [D fake: 0.587232]
[Epoch 1/50] [Batch 99/938] [D loss: 0.803504] [G loss: 