# PyTorch_Learningn (MIVRC)

![](https://ws1.sinaimg.cn/large/abaebc48ly1fqgrsuu355j209x02wjr7.jpg)

## 课程5：卷积生成对抗网络构建（GAN） - 以CIFAR10生成为例

![](https://ws1.sinaimg.cn/large/abaebc48ly1fqixn0bkasj20g70grjul.jpg)

### 代码示例：

In [1]:
#导入所需的库/包
import torch
import torch.nn as nn
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.autograd import Variable

In [2]:
#定义超参
num_epochs = 50
batch_size = 100
learning_rate = 0.002

In [3]:
#图片预处理
transform = transforms.Compose([
        transforms.Resize(36),
        transforms.RandomCrop(32),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])

In [4]:
#下载 CIFAR-10 数据集
train_dataset = datasets.CIFAR10(root='./data/', train=True, transform=transform, download=True)
#加载数据集
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

Files already downloaded and verified


In [5]:
# 构建判别网络
# 实质上就是一个二分类网络

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=4,stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(16, 32, kernel_size=4,stride=2, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(32, 64, kernel_size=4,stride=2, padding=1, bias=False), 
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 1, kernel_size=4),
            nn.Sigmoid())
    
    def forward(self, x):
        out = self.model(x)
        out = out.view(out.size(0), -1)
        return out

In [6]:
D = Discriminator()
D.cuda()

Discriminator(
  (model): Sequential(
    (0): Conv2d(3, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(0.2, inplace)
    (2): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True)
    (4): LeakyReLU(0.2, inplace)
    (5): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
    (7): LeakyReLU(0.2, inplace)
    (8): Conv2d(64, 1, kernel_size=(4, 4), stride=(1, 1))
    (9): Sigmoid()
  )
)

In [7]:
#构建生成网络

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(16, 3, kernel_size=4, stride=2, padding=1, bias=True),
            nn.Tanh())
    
    def forward(self, x):
        x = x.view(x.size(0), 128, 1, 1)
        out = self.model(x)
        return out

In [8]:
G = Generator()
G.cuda()

Generator(
  (model): Sequential(
    (0): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
    (2): ReLU(inplace)
    (3): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True)
    (5): ReLU(inplace)
    (6): ConvTranspose2d(32, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True)
    (8): ReLU(inplace)
    (9): ConvTranspose2d(16, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (10): Tanh()
  )
)

In [9]:
# 定义损失函数
criterion = nn.BCELoss()

# 定义优化器
d_optimizer = torch.optim.Adam(D.parameters(), lr=learning_rate)
g_optimizer = torch.optim.Adam(G.parameters(), lr=learning_rate)

In [10]:
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(train_loader):
        images = Variable(images).cuda()
        real_labels = Variable(torch.ones(images.size(0))).cuda()
        fake_labels = Variable(torch.zeros(images.size(0))).cuda()
        
    # 训练判别网络（D）
        # 判断真样本
        outputs = D(images)
        real_loss = criterion(outputs, real_labels)
        
        # 判断假样本
        noise = Variable(torch.randn(images.size(0), 128)).cuda()
        fake_images = G(noise)
        outputs = D(fake_images) 
        fake_loss = criterion(outputs, fake_labels)
        
        # 损失总和
        d_loss = real_loss + fake_loss
        
        # 反向传播及迭代更新
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()
        
    # 训练生成器（G）
        noise = Variable(torch.randn(images.size(0), 128)).cuda()
        fake_images = G(noise)
        outputs = D(fake_images)
        g_loss = criterion(outputs, real_labels)
        
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()
        
        if (i+1) % 100 == 0:
            print('Epoch [%d/%d], Step[%d/%d], d_loss: %.4f, g_loss: %.4f, '  'D(x): %.2f, D(G(z)): %.2f' 
                  %(epoch, num_epochs, i+1, 500, d_loss.data[0], g_loss.data[0], real_loss.data.mean(), fake_loss.data.mean()))
            
            # 保存生成的图像
            torchvision.utils.save_image(fake_images.data, './sample/fake_samples_%d_%d.png' %(epoch+1, i+1))

  "Please ensure they have the same size.".format(target.size(), input.size()))


Epoch [0/50], Step[100/500], d_loss: 0.0862, g_loss: 6.5287, D(x): 0.06, D(G(z)): 0.03
Epoch [0/50], Step[200/500], d_loss: 0.0472, g_loss: 5.3358, D(x): 0.01, D(G(z)): 0.03
Epoch [0/50], Step[300/500], d_loss: 0.0122, g_loss: 9.8844, D(x): 0.00, D(G(z)): 0.01
Epoch [0/50], Step[400/500], d_loss: 0.0882, g_loss: 9.8392, D(x): 0.09, D(G(z)): 0.00
Epoch [0/50], Step[500/500], d_loss: 0.0823, g_loss: 6.4358, D(x): 0.02, D(G(z)): 0.06
Epoch [1/50], Step[100/500], d_loss: 0.0337, g_loss: 8.6659, D(x): 0.01, D(G(z)): 0.02
Epoch [1/50], Step[200/500], d_loss: 0.3523, g_loss: 10.3504, D(x): 0.35, D(G(z)): 0.00
Epoch [1/50], Step[300/500], d_loss: 0.1087, g_loss: 6.0589, D(x): 0.06, D(G(z)): 0.05
Epoch [1/50], Step[400/500], d_loss: 0.0164, g_loss: 8.2462, D(x): 0.01, D(G(z)): 0.01
Epoch [1/50], Step[500/500], d_loss: 0.0106, g_loss: 10.5304, D(x): 0.01, D(G(z)): 0.00
Epoch [2/50], Step[100/500], d_loss: 0.2328, g_loss: 5.0847, D(x): 0.08, D(G(z)): 0.16
Epoch [2/50], Step[200/500], d_loss: 0.07

Epoch [18/50], Step[500/500], d_loss: 1.3838, g_loss: 2.2400, D(x): 0.23, D(G(z)): 1.15
Epoch [19/50], Step[100/500], d_loss: 0.1044, g_loss: 4.0532, D(x): 0.05, D(G(z)): 0.05
Epoch [19/50], Step[200/500], d_loss: 0.3336, g_loss: 6.5023, D(x): 0.33, D(G(z)): 0.00
Epoch [19/50], Step[300/500], d_loss: 2.0037, g_loss: 1.4236, D(x): 0.09, D(G(z)): 1.92
Epoch [19/50], Step[400/500], d_loss: 0.6323, g_loss: 6.0423, D(x): 0.56, D(G(z)): 0.07
Epoch [19/50], Step[500/500], d_loss: 0.5689, g_loss: 3.2707, D(x): 0.39, D(G(z)): 0.18
Epoch [20/50], Step[100/500], d_loss: 0.2628, g_loss: 5.0705, D(x): 0.23, D(G(z)): 0.03
Epoch [20/50], Step[200/500], d_loss: 0.0739, g_loss: 5.2666, D(x): 0.03, D(G(z)): 0.05
Epoch [20/50], Step[300/500], d_loss: 0.5625, g_loss: 4.2440, D(x): 0.36, D(G(z)): 0.20
Epoch [20/50], Step[400/500], d_loss: 0.2778, g_loss: 5.2569, D(x): 0.15, D(G(z)): 0.13
Epoch [20/50], Step[500/500], d_loss: 0.3630, g_loss: 5.0824, D(x): 0.26, D(G(z)): 0.10
Epoch [21/50], Step[100/500], d_

Epoch [37/50], Step[400/500], d_loss: 0.8875, g_loss: 4.4708, D(x): 0.21, D(G(z)): 0.67
Epoch [37/50], Step[500/500], d_loss: 0.1954, g_loss: 4.5093, D(x): 0.03, D(G(z)): 0.17
Epoch [38/50], Step[100/500], d_loss: 0.1243, g_loss: 6.4659, D(x): 0.04, D(G(z)): 0.09
Epoch [38/50], Step[200/500], d_loss: 0.2752, g_loss: 3.8673, D(x): 0.07, D(G(z)): 0.20
Epoch [38/50], Step[300/500], d_loss: 0.3209, g_loss: 3.3679, D(x): 0.07, D(G(z)): 0.25
Epoch [38/50], Step[400/500], d_loss: 0.8196, g_loss: 4.4441, D(x): 0.16, D(G(z)): 0.66
Epoch [38/50], Step[500/500], d_loss: 0.3088, g_loss: 5.7579, D(x): 0.21, D(G(z)): 0.10
Epoch [39/50], Step[100/500], d_loss: 0.0790, g_loss: 6.4840, D(x): 0.05, D(G(z)): 0.03
Epoch [39/50], Step[200/500], d_loss: 0.1761, g_loss: 8.3874, D(x): 0.16, D(G(z)): 0.02
Epoch [39/50], Step[300/500], d_loss: 1.2684, g_loss: 2.1373, D(x): 0.81, D(G(z)): 0.46
Epoch [39/50], Step[400/500], d_loss: 0.0785, g_loss: 5.6017, D(x): 0.06, D(G(z)): 0.02
Epoch [39/50], Step[500/500], d_

In [11]:
# 保存模型文件及参数
torch.save(G.state_dict(), './generator.pkl')
torch.save(D.state_dict(), './discriminator.pkl')