# GAN

### 生成对抗模型自从提出以来就一直备受人类的关注 这里使用Pytorch简单实现一个生成对抗模型

In [1]:
import os

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import save_image

In [2]:
if not os.path.exists('gan_img/'):
    os.mkdir('gan_img/')

In [3]:
def to_img(x):
    out = 0.5*(x+1)
    out = out.clamp(0,1)
    out = out.view(-1,1,28,28)
    return out

In [12]:
batch_size = 32
num_epoch = 100
z_dimension = 100

In [13]:
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,0.5,0.5),std=(0.5,0.5,0.5))
])
dataset = MNIST('data/',transform=img_transform,train=True)
dataloader = DataLoader(dataset,batch_size=batch_size,shuffle=True)

In [14]:
# discriminator
class discriminator(nn.Module):
    def __init__(self):
        super(discriminator,self).__init__()
        self.dis = nn.Sequential(
            nn.Linear(784,256),
            nn.LeakyReLU(0.2),
            nn.Linear(256,256),
            nn.LeakyReLU(0.2),
            nn.Linear(256,1),
            nn.Sigmoid()
        )
    def forward(self,x):
        x = self.dis(x)
        return x

In [15]:
class generator(nn.Module):
    def __init__(self):
        super(generator,self).__init__()
        self.gen = nn.Sequential(
            nn.Linear(100,256),
            nn.ReLU(True),
            nn.Linear(256,256),
            nn.ReLU(True),
            nn.Linear(256,784),
            nn.Tanh()
        )
    def forward(self,x):
        x = self.gen(x)
        return x

In [16]:
D =discriminator()
G = generator()

In [17]:
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(),lr=0.0003)
g_optimizer = torch.optim.Adam(G.parameters(),lr=0.0003)

In [18]:
for epoch in range(num_epoch):
    for i,(img,_) in enumerate(dataloader):
        num_img = img.size(0)
        
        # ===========train discriminator
        img = img.view(num_img,-1)
        real_img = Variable(img)
        
        real_label = Variable(torch.ones(num_img))
        fake_label = Variable(torch.zeros(num_img))
        
        # loss of real_img
        real_out = D(real_img)
        # 计算 真实的图片与真实的label之间的差距
        d_loss_real = criterion(real_out,real_label)
        real_scores = real_out
        
        # loss of fake_img
        z = Variable(torch.randn(num_img,z_dimension))
        fake_img = G(z)
        fake_out = D(fake_img)
        d_loss_fake = criterion(fake_out,fake_label)
        fake_scores = fake_out
        
        d_loss = d_loss_real+d_loss_fake
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()
        
        # ===========train generator
        z = Variable(torch.randn(num_img,z_dimension))
        fake_img = G(z)
        output = D(fake_img)
        g_loss = criterion(output, real_label)
        
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()
        
        if (i+1)%100 == 0:
            print('Epoch[{}/{}], d_loss:{:.6f},g_loss:{:.6f} D real:{:.6f}, D fake:{:.6f}'.
                 format(epoch,num_epoch,d_loss.data[0],g_loss.data[0],real_scores.data.mean(),fake_scores.data.mean())
                 )
        if epoch == 0:
            real_images = to_img(real_img.cpu().data)
            save_image(real_images,'gan_img/real_images.png')
        fake_images = to_img(fake_img.cpu().data)
        save_image(fake_images,'gan_img/fake_images-{}.png'.format(epoch+1))

  "Please ensure they have the same size.".format(target.size(), input.size()))


Epoch[0/100], d_loss:0.208112,g_loss:3.407711 D real:0.957910, D fake:0.127103
Epoch[0/100], d_loss:0.014907,g_loss:5.143075 D real:0.997821, D fake:0.012622
Epoch[0/100], d_loss:0.176355,g_loss:5.960326 D real:0.988008, D fake:0.142136
Epoch[0/100], d_loss:0.007931,g_loss:8.736539 D real:0.992913, D fake:0.000551
Epoch[0/100], d_loss:0.127201,g_loss:7.142903 D real:0.920611, D fake:0.026585
Epoch[0/100], d_loss:0.059487,g_loss:6.814245 D real:0.994137, D fake:0.050181
Epoch[0/100], d_loss:0.039839,g_loss:6.136299 D real:0.981029, D fake:0.019239
Epoch[0/100], d_loss:0.043768,g_loss:5.502153 D real:0.988890, D fake:0.031494
Epoch[0/100], d_loss:0.199733,g_loss:4.842411 D real:0.929091, D fake:0.041840


KeyboardInterrupt: 