


![DCGAN中的G结构图](../images/DCGAN/DCGAN_G.png)


![DCGAN中的D结构图](../images/DCGAN/DCGAN_D.png)


数据：https://pan.baidu.com/s/1eSifHcA  提取码：g5qa 

以上内容参考 https://zhuanlan.zhihu.com/p/24767059 

In [12]:
import torch
import torch.nn as nn
from torch.autograd import Variable
from torchvision import transforms
from torch.utils import data
import os
from PIL import Image
from torch import optim

In [10]:
# Build the generator network.
class Generator(nn.Module):
    """DCGAN generator: maps a noise vector z to a 3 x 64 x 64 image in [-1, 1].

    A linear layer projects z to a 1024 x 4 x 4 feature map; four stride-2
    transposed convolutions then double the spatial size each time
    (4 -> 8 -> 16 -> 32 -> 64) while reducing the channels to 3. Hidden layers
    use BatchNorm + ReLU; the output layer uses tanh.

    Args:
        noise_dim: length of the input noise vector.
    """
    def __init__(self, noise_dim=100):
        super(Generator, self).__init__()
        self.fc1 = nn.Linear(noise_dim, 4 * 4 * 1024)
        self.bn0 = nn.BatchNorm2d(1024)
        # With kernel 5 / stride 2 / padding 2 a transposed conv yields
        # out = 2 * in - 1; output_padding=1 makes it exactly 2 * in.
        self.conv1 = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=5, stride=2, padding=2, output_padding=1)
        self.bn1 = nn.BatchNorm2d(512)
        self.conv2 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=5, stride=2, padding=2, output_padding=1)
        self.bn2 = nn.BatchNorm2d(256)
        self.conv3 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=5, stride=2, padding=2, output_padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.conv4 = nn.ConvTranspose2d(in_channels=128, out_channels=3, kernel_size=5, stride=2, padding=2, output_padding=1)

    def forward(self, x):
        """Map noise of shape (batch, noise_dim) to images of shape (batch, 3, 64, 64)."""
        A1 = self.fc1(x).view(-1, 1024, 4, 4)
        # Functional activations: nn.ReLU(t) / nn.Tanh(t) would construct a
        # module object instead of applying the activation to t.
        A1 = torch.relu(self.bn0(A1))
        A2 = torch.relu(self.bn1(self.conv1(A1)))
        A3 = torch.relu(self.bn2(self.conv2(A2)))
        A4 = torch.relu(self.bn3(self.conv3(A3)))
        y_hat = torch.tanh(self.conv4(A4))
        return y_hat
    

In [9]:
# Build the discriminator network.
class Discriminator(nn.Module):
    """DCGAN discriminator: scores how "real" an input image looks.

    Four stride-2 convolutions reduce a 64 x 64 input to a 512 x 4 x 4 feature
    map (the fc layer's input size 4 * 4 * 512 fixes the input to 64 x 64).
    Hidden layers use LeakyReLU(0.2), with BatchNorm on all but the first.

    Args:
        input_chanel: number of channels of the input images.
        result: number of output scores per image.

    Returns (from forward):
        Sigmoid probabilities in [0, 1], shape (batch,) when result == 1,
        otherwise (batch, result) -- suitable for nn.BCELoss.
    """
    def __init__(self, input_chanel=3, result=1):
        super(Discriminator, self).__init__()
        self.conv1 = nn.Conv2d(input_chanel, 64, kernel_size=5, stride=2, padding=2)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=5, stride=2, padding=2)
        self.bn3 = nn.BatchNorm2d(256)
        self.conv4 = nn.Conv2d(256, 512, kernel_size=5, stride=2, padding=2)
        self.bn4 = nn.BatchNorm2d(512)
        self.fc = nn.Linear(4 * 4 * 512, result)

    def forward(self, x):
        leaky = nn.functional.leaky_relu
        A1 = leaky(self.conv1(x), 0.2)
        A2 = leaky(self.bn2(self.conv2(A1)), 0.2)
        A3 = leaky(self.bn3(self.conv3(A2)), 0.2)
        A4 = leaky(self.bn4(self.conv4(A3)), 0.2)
        # Flatten (batch, 512, 4, 4) -> (batch, 8192) before the linear layer.
        y_hat = self.fc(A4.view(A4.size(0), -1))
        # Sigmoid so BCELoss receives probabilities; squeeze(-1) turns a
        # (batch, 1) result into (batch,) to match 1-D label tensors.
        return torch.sigmoid(y_hat).squeeze(-1)
    

In [16]:
# Dataset handling.
class ImageDataset(data.Dataset):
    """Unlabelled image dataset backed by a single directory.

    Every regular file directly inside ``path`` is treated as an image. Each
    item is that image converted to RGB, passed through ``transform`` when one
    is given. No label is returned.
    """
    def __init__(self, path, transform=None):
        """
        Args:
            path: directory containing the image files.
            transform: optional callable applied to each RGB PIL image.
        """
        # Sorted for a deterministic index -> file mapping (os.listdir order is
        # arbitrary); sub-directories are skipped since they are not images.
        candidates = (os.path.join(path, name) for name in os.listdir(path))
        self.images = sorted(p for p in candidates if os.path.isfile(p))
        self.transform = transform

    def __getitem__(self, index):
        # The context manager closes the file handle; convert() returns a new,
        # independent image object that outlives the `with` block.
        with Image.open(self.images[index]) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image

    def __len__(self):
        return len(self.images)


def get_dataset(path, img_scale, batch_size):
    """Build a shuffling DataLoader over the images stored in ``path``.

    Each image is resized so its shorter side is ``img_scale``, center-cropped
    to ``img_scale x img_scale``, converted to a tensor in [0, 1] and then
    normalised with mean/std 0.5 to [-1, 1].

    Args:
        path: directory of training images (see ImageDataset).
        img_scale: side length of the square output images.
        batch_size: images per batch; incomplete trailing batches are dropped.

    Returns:
        A torch.utils.data.DataLoader yielding image tensors only (no labels).
    """
    # The local must not be named `transforms`: assigning to that name makes it
    # local to the function and the right-hand side raises UnboundLocalError.
    transform = transforms.Compose([
        transforms.Resize(img_scale),      # `Scale` was removed from torchvision
        transforms.CenterCrop(img_scale),  # fixed square size for the networks
        transforms.ToTensor(),             # must be an instance, not the class
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    dataset = ImageDataset(path, transform)
    data_loader = data.DataLoader(dataset=dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  drop_last=True)
    return data_loader

In [15]:
# Utility functions.

def gen_noisy(batch_size, noisy_dim):
    """Sample generator input noise z.

    Returns a float tensor of shape (batch_size, noisy_dim) drawn from the
    standard normal distribution.
    """
    noise_shape = (batch_size, noisy_dim)
    return torch.randn(noise_shape)

# Move a tensor to the compute device.
def to_variable(x):
    """Return tensor ``x`` on the GPU when CUDA is available, otherwise as-is.

    The former ``Variable`` wrapper is dropped: it has been deprecated since
    PyTorch 0.4, where plain tensors carry autograd state themselves.
    """
    if torch.cuda.is_available():
        return x.cuda()
    return x

In [18]:
# A Config class holding the training hyperparameters.
class Config(object):
    """Training hyperparameters.

    ``Config()`` yields the defaults below; any value can be overridden by
    keyword, e.g. ``Config(batch_size=64, EPOCH=10)``.

    Attributes:
        batch_size: images per training batch.
        image_path: directory holding the training images.
        noisy_dim: length of the generator's noise vector z.
        G_lr: Adam learning rate for the generator.
        D_lr: Adam learning rate for the discriminator.
        EPOCH: number of passes over the dataset.
        img_scale: image side length passed to get_dataset.
        k_step: the generator is updated on batches whose index is a
            multiple of k_step (the discriminator on every batch).
    """
    def __init__(self, batch_size=128, image_path='./images/', noisy_dim=100,
                 G_lr=1e-4, D_lr=1e-6, EPOCH=200, img_scale=64, k_step=2):
        self.batch_size = batch_size
        self.image_path = image_path
        self.noisy_dim = noisy_dim
        self.G_lr = G_lr
        self.D_lr = D_lr
        self.EPOCH = EPOCH
        self.img_scale = img_scale
        self.k_step = k_step

In [13]:
# Training stage.
# For every batch D takes one step on real + generated images; G takes one
# step every `config.k_step` batches. After each epoch the last generated
# batch is written to disk as an image grid.
from torchvision.utils import save_image

config = Config()

train_data_loader = get_dataset(config.image_path, config.img_scale, config.batch_size)
if len(train_data_loader) == 0:
    # drop_last=True: fewer images than batch_size means no batches at all.
    raise ValueError('no full batch of images found in %s' % config.image_path)

# to_variable() moves tensors to CUDA when available, so the models must be
# placed on the same device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
G = Generator(noise_dim=config.noisy_dim).to(device)
D = Discriminator().to(device)

# BCELoss expects probabilities in [0, 1] (D ends in a sigmoid).
loss_fn = torch.nn.BCELoss()
# parameters() must be called: Adam needs the parameter iterator.
g_optimizer = optim.Adam(G.parameters(), lr=config.G_lr)
d_optimizer = optim.Adam(D.parameters(), lr=config.D_lr)

os.makedirs('./data/test_DCGAN', exist_ok=True)

for epoch in range(config.EPOCH):
    # ImageDataset yields images only (no labels), so each batch is a single
    # tensor. `batch` avoids shadowing the torch.utils `data` module.
    for i, batch in enumerate(train_data_loader):
        images = to_variable(batch)
        batch_size = images.size(0)

        z = to_variable(gen_noisy(batch_size, config.noisy_dim))

        real_labels = to_variable(torch.ones(batch_size))
        fake_labels = to_variable(torch.zeros(batch_size))

        ###          train D           ###
        # .view(-1) makes D's scores shape (batch,), matching the labels.
        real_score = D(images).view(-1)
        d_loss_real = loss_fn(real_score, real_labels)

        fake_images = G(z)
        # detach(): the discriminator step must not backpropagate into G.
        fake_score = D(fake_images.detach()).view(-1)
        d_loss_fake = loss_fn(fake_score, fake_labels)

        d_loss = d_loss_real + d_loss_fake
        D.zero_grad()
        d_loss.backward()
        d_optimizer.step()

        ###          train G           ###
        if i % config.k_step == 0:
            z = to_variable(gen_noisy(batch_size, config.noisy_dim))
            fake_images = G(z)
            outputs = D(fake_images).view(-1)
            # G is rewarded when D labels its samples as real.
            g_loss = loss_fn(outputs, real_labels)
            G.zero_grad()
            g_loss.backward()
            g_optimizer.step()

        if (i + 1) % 300 == 0:
            # i = 0 always trains G, so g_loss is defined by this point.
            print('Epoch [%d/%d], Step[%d/%d], d_loss: %.4f, '
                  'g_loss: %.4f, D(x): %.2f, D(G(z)): %.2f'
                  % (epoch, config.EPOCH, i + 1, len(train_data_loader),
                     d_loss.item(), g_loss.item(),
                     real_score.mean().item(), fake_score.mean().item()))

    # Save the last generated batch (3 x 64 x 64 images), mapped from the
    # tanh range [-1, 1] to [0, 1] as expected by save_image.
    save_image((fake_images.detach() + 1) / 2, './data/test_DCGAN/fake_images-%d.png' % (epoch + 1))
        