In [43]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.autograd import Variable

In [44]:
#图像处理
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])

In [45]:
#MNIST 数据集
train_dataset = dsets.MNIST(root='../data/', train=True, transform=transform, download=True)

In [46]:
#加载数据集
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=100, shuffle=True)

In [53]:
#定义判别网络
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 1)
        
    def forward(self, x):
        h = F.relu(self.fc1(x))
        h = F.relu(self.fc2(h))
        out = F.sigmoid(self.fc3(h))
        return out
    
#定义生成网络
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.fc1 = nn.Linear(128, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 784)
            
    def forward(self, x):
        h = F.leaky_relu(self.fc1(x))
        h = F.leaky_relu(self.fc2(h))
        out = F.tanh(self.fc3(h))
        return out
    
    
discriminator = Discriminator()
generator = Generator()
discriminator.cuda()
generator.cuda()   

Generator (
  (fc1): Linear (128 -> 256)
  (fc2): Linear (256 -> 256)
  (fc3): Linear (256 -> 784)
)

In [54]:
#定义损失函数优化函数
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0005)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0005)

In [61]:
#训练网络
for epoch in range(200):
    for i, (images, _) in enumerate(train_loader):
        images = images.view(images.size(0), -1)
        images = Variable(images.cuda())
        real_labels = Variable(torch.ones(images.size(0))).cuda()
        fake_labels = Variable(torch.zeros(images.size(0))).cuda()
        
        #训练生成网络
        discriminator.zero_grad()
        
        #输入的是真实图片
        outputs = discriminator(images)
        real_loss = criterion(outputs, real_labels)
        real_score = outputs
        #输入的是随机噪声
        noise = Variable(torch.randn(images.size(0), 128)).cuda()
        fake_images = generator(noise)
        outputs = discriminator(fake_images.detach()) 
        fake_loss = criterion(outputs, fake_labels)
        fake_score = outputs
        
        d_loss = real_loss + fake_loss
        d_loss.backward()
        d_optimizer.step()
        
        
        #训练生成网络
        generator.zero_grad()
        noise = Variable(torch.randn(images.size(0), 128)).cuda()
        fake_images = generator(noise)
        outputs = discriminator(fake_images)
        #生成图片与真实图片的对比
        g_loss = criterion(outputs, real_labels)
        g_loss.backward()
        g_optimizer.step()
        
        if (i+1) % 300 == 0:
            print('Epoch [%d/%d], Step[%d/%d], d_loss: %.4f, g_loss: %.4f, ' 
                  'D(x): %.2f, D(G(z)): %.2f' 
                  %(epoch, 200, i+1, 600, d_loss.data[0], g_loss.data[0],
                    real_score.cpu().data.mean(), fake_score.cpu().data.mean()))
            
        #保存采样图片
        fake_images = fake_images.view(fake_images.size(0), 1, 28, 28)
        torchvision.utils.save_image(fake_images.data, './sample/fake_samples_%d.png' % (epoch+1))
        

Epoch [0/200], Step[300/600], d_loss: 0.1373, g_loss: 4.6390, D(x): 0.98, D(G(z)): 0.11
Epoch [0/200], Step[600/600], d_loss: 0.0594, g_loss: 5.2584, D(x): 0.96, D(G(z)): 0.02
Epoch [1/200], Step[300/600], d_loss: 1.6614, g_loss: 1.2567, D(x): 0.54, D(G(z)): 0.43
Epoch [1/200], Step[600/600], d_loss: 0.9018, g_loss: 1.3473, D(x): 0.67, D(G(z)): 0.35
Epoch [2/200], Step[300/600], d_loss: 0.5366, g_loss: 1.4849, D(x): 0.80, D(G(z)): 0.25
Epoch [2/200], Step[600/600], d_loss: 1.5921, g_loss: 1.0523, D(x): 0.50, D(G(z)): 0.43
Epoch [3/200], Step[300/600], d_loss: 1.8828, g_loss: 0.6064, D(x): 0.55, D(G(z)): 0.66
Epoch [3/200], Step[600/600], d_loss: 1.0602, g_loss: 1.2066, D(x): 0.61, D(G(z)): 0.38
Epoch [4/200], Step[300/600], d_loss: 0.9928, g_loss: 1.4936, D(x): 0.69, D(G(z)): 0.40
Epoch [4/200], Step[600/600], d_loss: 1.0316, g_loss: 1.4265, D(x): 0.66, D(G(z)): 0.43
Epoch [5/200], Step[300/600], d_loss: 0.5354, g_loss: 2.2022, D(x): 0.85, D(G(z)): 0.29
Epoch [5/200], Step[600/600], d_

In [65]:
#保存网络
torch.save(generator.state_dict(), 'generator.pkl')
torch.save(discriminator.state_dict(), 'discriminator.pkl')