In [1]:
import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image
from torch.autograd import Variable
import matplotlib.pyplot as plt
import pylab
import numpy as np

#### 参数设置

In [2]:

latent_size = 64
hidden_size = 256
image_size = 784  # 28*28
num_epochs = 100
batch_size = 32
sample_dir = 'samples'
save_dir = 'save'


In [3]:
if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)

if not os.path.exists(save_dir):
    os.makedirs(save_dir)

#### 加载数据集

In [4]:
transform = transforms.Compose([
    transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])

# MNIST 数据集下载
mnist = torchvision.datasets.MNIST(root='./data/',
                                   train=True,
                                   transform=transform,
                                   download=False)

# 利用pytorch库加载数据集
data_loader = torch.utils.data.DataLoader(dataset=mnist,
                                          batch_size=batch_size,
                                          shuffle=True)

## 模型定义

### 判别器

判别器由三层神经网络构成

输入层是28*28的图片展开

隐藏层的维度是256

输入层和隐藏层的激活函数均使用斜率为0.2的LeakyReLu

输出层激活函数使用 sigmoid, 目的是将结果映射到 0,1 区间

### 生成器

输入一个100维的0～1之间的高斯分布，然后通过第一层线性变换将其映射到256维,
然后通过LeakyReLU激活函数，接着进行一个线性变换，再经过一个LeakyReLU激活函数，
然后经过线性变换将其变成784维，最后经过Tanh激活函数是希望生成的假的图片数据分布
能够在-1～1之间。

In [5]:
D = nn.Sequential(
    nn.Linear(image_size, hidden_size),  # 784的特征到256的隐藏层特征
    nn.LeakyReLU(0.2),  # 斜率为0.2的LeakyReLu激活函数
    nn.Linear(hidden_size, hidden_size),  # 线性映射
    nn.LeakyReLU(0.2),
    nn.Linear(hidden_size, 1),
    nn.Sigmoid())

G = nn.Sequential(
    nn.Linear(latent_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, image_size),
    nn.Tanh())

D = D.cuda()
G = G.cuda()

### 训练判别器

 分为两部分：1、真的图像判别为真；2、假的图像判别为假
 此过程中，生成器参数不断更新

 其次定义 优化函数,优化函数的学习率为0.0002

In [6]:
criterion = nn.BCELoss()  # 是单目标二分类交叉熵函数
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)


def denorm(x):
    out = (x + 1) / 2
    return out.clamp(0, 1)


def reset_grad():
    d_optimizer.zero_grad()
    g_optimizer.zero_grad()


d_losses = np.zeros(num_epochs)
g_losses = np.zeros(num_epochs)
real_scores = np.zeros(num_epochs)
fake_scores = np.zeros(num_epochs)


## 模型训练


In [7]:
total_step = len(data_loader)
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(data_loader):
        # view()函数作用是将一个多行的Tensor,拼接成一行
        # 第一个参数是要拼接的tensor,第二个参数是-1
        images = images.view(batch_size, -1).cuda()
        images = Variable(images)
        real_labels = torch.ones(batch_size, 1).cuda()  # 定义真实的图片label为1
        real_labels = Variable(real_labels)
        fake_labels = torch.zeros(batch_size, 1).cuda()  # 定义假的图片的label为0
        fake_labels = Variable(fake_labels)

        ## 都要再将其转化为Tensor 给模型训练

        ## 训练判别器

        # 用真实的图片取训练判别器
        outputs = D(images)
        d_loss_real = criterion(outputs, real_labels)
        real_score = outputs

        # 用假图片训练判别器
        z = torch.randn(batch_size, latent_size).cuda()
        z = Variable(z)
        fake_images = G(z)
        outputs = D(fake_images)
        d_loss_fake = criterion(outputs, fake_labels)
        fake_score = outputs

        # 反向传播误差 通过优化函数取迭代每一步的参数
        d_loss = d_loss_real + d_loss_fake
        reset_grad()
        d_loss.backward()
        d_optimizer.step()

        # 训练生成器

        z = torch.randn(batch_size, latent_size).cuda()
        z = Variable(z)
        fake_images = G(z)
        outputs = D(fake_images)

        # 我们通过最大化 log(D(G(z)) 函数来训练生成器
        g_loss = criterion(outputs, real_labels)

        # 反向传播误差
        reset_grad()
        g_loss.backward()
        g_optimizer.step()

        # 更新得到的数据
        d_losses[epoch] = d_losses[epoch] * (i / (i + 1.)) + d_loss.data * (1. / (i + 1.))
        g_losses[epoch] = g_losses[epoch] * (i / (i + 1.)) + g_loss.data * (1. / (i + 1.))
        real_scores[epoch] = real_scores[epoch] * (i / (i + 1.)) + real_score.mean().data * (1. / (i + 1.))
        fake_scores[epoch] = fake_scores[epoch] * (i / (i + 1.)) + fake_score.mean().data * (1. / (i + 1.))

        if (i + 1) % 200 == 0:
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}'
                  .format(epoch, num_epochs, i + 1, total_step, d_loss.data, g_loss.data,
                          real_score.mean().data, fake_score.mean().data))
        # 保存真图片

    if (epoch + 1) == 1:
        images = images.view(images.size(0), 1, 28, 28)
        save_image(denorm(images.data), os.path.join(sample_dir, 'real_images.png'))

    # 保存每一轮生成的假图片
    fake_images = fake_images.view(fake_images.size(0), 1, 28, 28)
    save_image(denorm(fake_images.data), os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch + 1)))

    # 保存数据并绘图
    np.save(os.path.join(save_dir, 'd_losses.npy'), d_losses)
    np.save(os.path.join(save_dir, 'g_losses.npy'), g_losses)
    np.save(os.path.join(save_dir, 'fake_scores.npy'), fake_scores)
    np.save(os.path.join(save_dir, 'real_scores.npy'), real_scores)

    plt.figure()
    pylab.xlim(0, num_epochs + 1)
    plt.plot(range(1, num_epochs + 1), d_losses, label='d loss')
    plt.plot(range(1, num_epochs + 1), g_losses, label='g loss')
    plt.legend()
    plt.savefig(os.path.join(save_dir, 'loss.pdf'))
    plt.close()

    plt.figure()
    pylab.xlim(0, num_epochs + 1)
    pylab.ylim(0, 1)
    plt.plot(range(1, num_epochs + 1), fake_scores, label='fake score')
    plt.plot(range(1, num_epochs + 1), real_scores, label='real score')
    plt.legend()
    plt.savefig(os.path.join(save_dir, 'accuracy.pdf'))
    plt.close()


Epoch [0/100], Step [200/1875], d_loss: 0.0587, g_loss: 3.7573, D(x): 0.99, D(G(z)): 0.05
Epoch [0/100], Step [400/1875], d_loss: 0.2250, g_loss: 4.8788, D(x): 0.93, D(G(z)): 0.09
Epoch [0/100], Step [600/1875], d_loss: 0.0379, g_loss: 6.2319, D(x): 0.97, D(G(z)): 0.01
Epoch [0/100], Step [800/1875], d_loss: 0.1264, g_loss: 4.8664, D(x): 0.94, D(G(z)): 0.05
Epoch [0/100], Step [1000/1875], d_loss: 0.1927, g_loss: 2.5057, D(x): 0.99, D(G(z)): 0.17
Epoch [0/100], Step [1200/1875], d_loss: 0.1646, g_loss: 5.3826, D(x): 0.95, D(G(z)): 0.02
Epoch [0/100], Step [1400/1875], d_loss: 0.4125, g_loss: 2.9136, D(x): 0.90, D(G(z)): 0.21
Epoch [0/100], Step [1600/1875], d_loss: 0.2241, g_loss: 4.5864, D(x): 0.90, D(G(z)): 0.10
Epoch [0/100], Step [1800/1875], d_loss: 0.9241, g_loss: 3.6340, D(x): 0.72, D(G(z)): 0.23
Epoch [1/100], Step [200/1875], d_loss: 0.2924, g_loss: 2.9576, D(x): 0.93, D(G(z)): 0.17
Epoch [1/100], Step [400/1875], d_loss: 0.7535, g_loss: 2.1950, D(x): 0.83, D(G(z)): 0.37
Epoch

Epoch [10/100], Step [200/1875], d_loss: 0.6198, g_loss: 2.6211, D(x): 0.86, D(G(z)): 0.12
Epoch [10/100], Step [400/1875], d_loss: 0.9265, g_loss: 3.1154, D(x): 0.79, D(G(z)): 0.24
Epoch [10/100], Step [600/1875], d_loss: 0.2794, g_loss: 3.5018, D(x): 0.87, D(G(z)): 0.04
Epoch [10/100], Step [800/1875], d_loss: 0.4389, g_loss: 3.5904, D(x): 0.89, D(G(z)): 0.14
Epoch [10/100], Step [1000/1875], d_loss: 0.2938, g_loss: 3.2127, D(x): 0.94, D(G(z)): 0.16
Epoch [10/100], Step [1200/1875], d_loss: 0.3084, g_loss: 3.1836, D(x): 0.95, D(G(z)): 0.19
Epoch [10/100], Step [1400/1875], d_loss: 0.2762, g_loss: 3.7759, D(x): 0.89, D(G(z)): 0.08
Epoch [10/100], Step [1600/1875], d_loss: 0.7426, g_loss: 2.7195, D(x): 0.82, D(G(z)): 0.26
Epoch [10/100], Step [1800/1875], d_loss: 0.4169, g_loss: 3.5733, D(x): 0.88, D(G(z)): 0.13
Epoch [11/100], Step [200/1875], d_loss: 0.6362, g_loss: 3.7300, D(x): 0.79, D(G(z)): 0.10
Epoch [11/100], Step [400/1875], d_loss: 0.3714, g_loss: 2.2955, D(x): 0.86, D(G(z)):

Epoch [19/100], Step [1800/1875], d_loss: 0.8409, g_loss: 2.1671, D(x): 0.66, D(G(z)): 0.20
Epoch [20/100], Step [200/1875], d_loss: 0.5060, g_loss: 2.0669, D(x): 0.83, D(G(z)): 0.19
Epoch [20/100], Step [400/1875], d_loss: 1.4575, g_loss: 3.5038, D(x): 0.48, D(G(z)): 0.07
Epoch [20/100], Step [600/1875], d_loss: 0.5728, g_loss: 2.6245, D(x): 0.78, D(G(z)): 0.19
Epoch [20/100], Step [800/1875], d_loss: 0.6230, g_loss: 2.1273, D(x): 0.83, D(G(z)): 0.27
Epoch [20/100], Step [1000/1875], d_loss: 0.6590, g_loss: 2.4989, D(x): 0.83, D(G(z)): 0.27
Epoch [20/100], Step [1200/1875], d_loss: 0.9063, g_loss: 2.2102, D(x): 0.69, D(G(z)): 0.23
Epoch [20/100], Step [1400/1875], d_loss: 0.8942, g_loss: 2.3477, D(x): 0.61, D(G(z)): 0.13
Epoch [20/100], Step [1600/1875], d_loss: 0.7878, g_loss: 2.0480, D(x): 0.71, D(G(z)): 0.19
Epoch [20/100], Step [1800/1875], d_loss: 0.9528, g_loss: 1.6386, D(x): 0.62, D(G(z)): 0.15
Epoch [21/100], Step [200/1875], d_loss: 0.5508, g_loss: 2.2191, D(x): 0.86, D(G(z))

Epoch [29/100], Step [1600/1875], d_loss: 0.6659, g_loss: 1.3372, D(x): 0.76, D(G(z)): 0.22
Epoch [29/100], Step [1800/1875], d_loss: 0.8606, g_loss: 1.6311, D(x): 0.62, D(G(z)): 0.18
Epoch [30/100], Step [200/1875], d_loss: 0.7205, g_loss: 1.5978, D(x): 0.79, D(G(z)): 0.32
Epoch [30/100], Step [400/1875], d_loss: 0.6170, g_loss: 2.2186, D(x): 0.80, D(G(z)): 0.22
Epoch [30/100], Step [600/1875], d_loss: 0.9863, g_loss: 1.9899, D(x): 0.62, D(G(z)): 0.22
Epoch [30/100], Step [800/1875], d_loss: 0.9728, g_loss: 2.0707, D(x): 0.61, D(G(z)): 0.26
Epoch [30/100], Step [1000/1875], d_loss: 0.7466, g_loss: 1.6589, D(x): 0.72, D(G(z)): 0.23
Epoch [30/100], Step [1200/1875], d_loss: 1.1761, g_loss: 1.2864, D(x): 0.81, D(G(z)): 0.48
Epoch [30/100], Step [1400/1875], d_loss: 0.9941, g_loss: 1.4570, D(x): 0.74, D(G(z)): 0.36
Epoch [30/100], Step [1600/1875], d_loss: 1.0021, g_loss: 2.3130, D(x): 0.58, D(G(z)): 0.23
Epoch [30/100], Step [1800/1875], d_loss: 0.8216, g_loss: 1.2812, D(x): 0.74, D(G(z)

Epoch [39/100], Step [1400/1875], d_loss: 0.7491, g_loss: 1.9368, D(x): 0.72, D(G(z)): 0.26
Epoch [39/100], Step [1600/1875], d_loss: 0.9810, g_loss: 1.2525, D(x): 0.71, D(G(z)): 0.37
Epoch [39/100], Step [1800/1875], d_loss: 0.8143, g_loss: 1.6173, D(x): 0.81, D(G(z)): 0.35
Epoch [40/100], Step [200/1875], d_loss: 0.9227, g_loss: 1.1987, D(x): 0.77, D(G(z)): 0.34
Epoch [40/100], Step [400/1875], d_loss: 0.9826, g_loss: 1.4216, D(x): 0.64, D(G(z)): 0.29
Epoch [40/100], Step [600/1875], d_loss: 0.9238, g_loss: 1.5173, D(x): 0.82, D(G(z)): 0.36
Epoch [40/100], Step [800/1875], d_loss: 1.0040, g_loss: 1.3776, D(x): 0.56, D(G(z)): 0.21
Epoch [40/100], Step [1000/1875], d_loss: 0.6095, g_loss: 1.8043, D(x): 0.80, D(G(z)): 0.25
Epoch [40/100], Step [1200/1875], d_loss: 0.7896, g_loss: 1.5460, D(x): 0.75, D(G(z)): 0.30
Epoch [40/100], Step [1400/1875], d_loss: 0.8533, g_loss: 1.4861, D(x): 0.83, D(G(z)): 0.42
Epoch [40/100], Step [1600/1875], d_loss: 1.1149, g_loss: 1.2469, D(x): 0.60, D(G(z)

Epoch [49/100], Step [1200/1875], d_loss: 1.0442, g_loss: 1.7993, D(x): 0.58, D(G(z)): 0.28
Epoch [49/100], Step [1400/1875], d_loss: 1.0361, g_loss: 1.2130, D(x): 0.68, D(G(z)): 0.39
Epoch [49/100], Step [1600/1875], d_loss: 0.9041, g_loss: 2.3120, D(x): 0.63, D(G(z)): 0.14
Epoch [49/100], Step [1800/1875], d_loss: 1.0528, g_loss: 1.7149, D(x): 0.61, D(G(z)): 0.29
Epoch [50/100], Step [200/1875], d_loss: 1.1232, g_loss: 1.2430, D(x): 0.72, D(G(z)): 0.42
Epoch [50/100], Step [400/1875], d_loss: 0.8278, g_loss: 1.6715, D(x): 0.74, D(G(z)): 0.32
Epoch [50/100], Step [600/1875], d_loss: 1.0260, g_loss: 1.0441, D(x): 0.73, D(G(z)): 0.38
Epoch [50/100], Step [800/1875], d_loss: 0.7746, g_loss: 1.6993, D(x): 0.71, D(G(z)): 0.26
Epoch [50/100], Step [1000/1875], d_loss: 0.7689, g_loss: 1.7237, D(x): 0.77, D(G(z)): 0.33
Epoch [50/100], Step [1200/1875], d_loss: 0.8448, g_loss: 1.4703, D(x): 0.70, D(G(z)): 0.28
Epoch [50/100], Step [1400/1875], d_loss: 1.1030, g_loss: 1.6627, D(x): 0.61, D(G(z)

Epoch [59/100], Step [1000/1875], d_loss: 0.9219, g_loss: 2.0341, D(x): 0.71, D(G(z)): 0.33
Epoch [59/100], Step [1200/1875], d_loss: 0.9988, g_loss: 1.7073, D(x): 0.67, D(G(z)): 0.35
Epoch [59/100], Step [1400/1875], d_loss: 0.9951, g_loss: 1.7663, D(x): 0.78, D(G(z)): 0.42
Epoch [59/100], Step [1600/1875], d_loss: 0.9148, g_loss: 1.2574, D(x): 0.68, D(G(z)): 0.32
Epoch [59/100], Step [1800/1875], d_loss: 1.2321, g_loss: 1.5480, D(x): 0.53, D(G(z)): 0.29
Epoch [60/100], Step [200/1875], d_loss: 0.8553, g_loss: 1.4564, D(x): 0.64, D(G(z)): 0.25
Epoch [60/100], Step [400/1875], d_loss: 1.0116, g_loss: 1.1900, D(x): 0.66, D(G(z)): 0.35
Epoch [60/100], Step [600/1875], d_loss: 1.1448, g_loss: 1.3937, D(x): 0.68, D(G(z)): 0.42
Epoch [60/100], Step [800/1875], d_loss: 0.8927, g_loss: 1.4152, D(x): 0.68, D(G(z)): 0.32
Epoch [60/100], Step [1000/1875], d_loss: 1.1190, g_loss: 1.4037, D(x): 0.68, D(G(z)): 0.33
Epoch [60/100], Step [1200/1875], d_loss: 1.0541, g_loss: 1.3495, D(x): 0.61, D(G(z)

Epoch [69/100], Step [800/1875], d_loss: 0.9066, g_loss: 1.5735, D(x): 0.71, D(G(z)): 0.30
Epoch [69/100], Step [1000/1875], d_loss: 0.8092, g_loss: 1.7600, D(x): 0.67, D(G(z)): 0.23
Epoch [69/100], Step [1200/1875], d_loss: 1.0720, g_loss: 1.3418, D(x): 0.72, D(G(z)): 0.42
Epoch [69/100], Step [1400/1875], d_loss: 0.7535, g_loss: 2.0887, D(x): 0.71, D(G(z)): 0.24
Epoch [69/100], Step [1600/1875], d_loss: 0.7470, g_loss: 1.0042, D(x): 0.75, D(G(z)): 0.31
Epoch [69/100], Step [1800/1875], d_loss: 0.8084, g_loss: 1.4245, D(x): 0.65, D(G(z)): 0.20
Epoch [70/100], Step [200/1875], d_loss: 0.9113, g_loss: 0.8449, D(x): 0.80, D(G(z)): 0.40
Epoch [70/100], Step [400/1875], d_loss: 0.9810, g_loss: 1.3560, D(x): 0.76, D(G(z)): 0.38
Epoch [70/100], Step [600/1875], d_loss: 0.9312, g_loss: 1.7846, D(x): 0.56, D(G(z)): 0.21
Epoch [70/100], Step [800/1875], d_loss: 0.8395, g_loss: 1.7277, D(x): 0.69, D(G(z)): 0.27
Epoch [70/100], Step [1000/1875], d_loss: 0.8428, g_loss: 1.9267, D(x): 0.68, D(G(z))

Epoch [79/100], Step [600/1875], d_loss: 1.1639, g_loss: 1.5018, D(x): 0.64, D(G(z)): 0.34
Epoch [79/100], Step [800/1875], d_loss: 1.2432, g_loss: 1.1649, D(x): 0.58, D(G(z)): 0.31
Epoch [79/100], Step [1000/1875], d_loss: 1.1082, g_loss: 1.2962, D(x): 0.76, D(G(z)): 0.43
Epoch [79/100], Step [1200/1875], d_loss: 1.0520, g_loss: 1.2582, D(x): 0.73, D(G(z)): 0.41
Epoch [79/100], Step [1400/1875], d_loss: 0.8175, g_loss: 1.4137, D(x): 0.73, D(G(z)): 0.31
Epoch [79/100], Step [1600/1875], d_loss: 1.0114, g_loss: 1.3629, D(x): 0.63, D(G(z)): 0.33
Epoch [79/100], Step [1800/1875], d_loss: 0.9708, g_loss: 1.1151, D(x): 0.69, D(G(z)): 0.31
Epoch [80/100], Step [200/1875], d_loss: 0.6929, g_loss: 1.7520, D(x): 0.75, D(G(z)): 0.26
Epoch [80/100], Step [400/1875], d_loss: 0.9951, g_loss: 1.1492, D(x): 0.72, D(G(z)): 0.38
Epoch [80/100], Step [600/1875], d_loss: 0.9750, g_loss: 1.3094, D(x): 0.70, D(G(z)): 0.37
Epoch [80/100], Step [800/1875], d_loss: 1.1657, g_loss: 1.5061, D(x): 0.62, D(G(z)):

Epoch [89/100], Step [400/1875], d_loss: 0.8882, g_loss: 1.0469, D(x): 0.72, D(G(z)): 0.33
Epoch [89/100], Step [600/1875], d_loss: 0.9151, g_loss: 1.5268, D(x): 0.76, D(G(z)): 0.38
Epoch [89/100], Step [800/1875], d_loss: 0.9694, g_loss: 1.3342, D(x): 0.72, D(G(z)): 0.38
Epoch [89/100], Step [1000/1875], d_loss: 0.6857, g_loss: 1.1489, D(x): 0.74, D(G(z)): 0.26
Epoch [89/100], Step [1200/1875], d_loss: 0.9596, g_loss: 1.4175, D(x): 0.68, D(G(z)): 0.32
Epoch [89/100], Step [1400/1875], d_loss: 1.3338, g_loss: 1.3056, D(x): 0.62, D(G(z)): 0.45
Epoch [89/100], Step [1600/1875], d_loss: 0.8364, g_loss: 1.1456, D(x): 0.83, D(G(z)): 0.40
Epoch [89/100], Step [1800/1875], d_loss: 0.7836, g_loss: 1.2608, D(x): 0.78, D(G(z)): 0.34
Epoch [90/100], Step [200/1875], d_loss: 0.8755, g_loss: 1.2264, D(x): 0.70, D(G(z)): 0.33
Epoch [90/100], Step [400/1875], d_loss: 1.1811, g_loss: 1.3955, D(x): 0.52, D(G(z)): 0.28
Epoch [90/100], Step [600/1875], d_loss: 0.8803, g_loss: 1.6449, D(x): 0.73, D(G(z)):

Epoch [99/100], Step [200/1875], d_loss: 0.8077, g_loss: 1.5056, D(x): 0.75, D(G(z)): 0.32
Epoch [99/100], Step [400/1875], d_loss: 1.0466, g_loss: 1.5229, D(x): 0.63, D(G(z)): 0.33
Epoch [99/100], Step [600/1875], d_loss: 1.2768, g_loss: 1.3586, D(x): 0.55, D(G(z)): 0.36
Epoch [99/100], Step [800/1875], d_loss: 1.1079, g_loss: 1.2901, D(x): 0.64, D(G(z)): 0.37
Epoch [99/100], Step [1000/1875], d_loss: 0.9466, g_loss: 1.4828, D(x): 0.66, D(G(z)): 0.32
Epoch [99/100], Step [1200/1875], d_loss: 0.9838, g_loss: 1.4148, D(x): 0.61, D(G(z)): 0.25
Epoch [99/100], Step [1400/1875], d_loss: 0.9303, g_loss: 1.2012, D(x): 0.66, D(G(z)): 0.32
Epoch [99/100], Step [1600/1875], d_loss: 0.8438, g_loss: 1.5138, D(x): 0.76, D(G(z)): 0.36
Epoch [99/100], Step [1800/1875], d_loss: 1.2586, g_loss: 1.0770, D(x): 0.64, D(G(z)): 0.38


## 训练结果

#### 100次迭代后的生成的图片

<img src='./samples/fake_images-100.png' align='left'>