### Using prtorch to build a GAN zoo


In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
import torchvision.utils as vutils
from torch.autograd import Variable
import torch.utils.data

In [10]:
# load the data

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=64, shuffle=True)

Files already downloaded


### 按照Goodfeli最初的GAN paper中设计的网络结构； 

In [11]:
#使用一个类构建网络，
#通过super函数继承 nn.Module的构造方法
class G(nn.Module):
    def __init__(self):
        super(G, self).__init__() 
        
        self.g1 = nn.Linear(100, 1200)
        self.g2 = nn.Linear(1200,1200)
        self.g3 = nn.Linear(1200,1 * 28 * 28) #construction images    
        
    def forward(self, x):
        x = x.view(x.size(0), x.size(1))
        x = F.relu(self.g1(x))
        x = F.relu(self.g2(x))
        x = self.g3(x)
        
        return x.view(x.size(0), 1, 28, 28) 
g_model = G()

In [12]:
class D(nn.Module):
    def __init__(self):
        super(D, self).__init__()
        self.d1 = nn.Linear(1*28*28, 240)
        self.d2 = nn.Linear(240, 240)
        self.d3 = nn.Linear(240, 1)
        
    def forward(self, x):
        x = x.view(x.size(0), x.size(1) * x.size(2) * x.size(3)) #view 的操作应该是 reshape； 
        x = F.relu(self.d1(x))
        x = F.relu(self.d2(x))
        x = F.sigmoid(self.d3(x))
        
        return x
d_model = D()

In [13]:
input_ = torch.FloatTensor(64, 1, 28, 28)
noise = torch.FloatTensor(64, 100, 1, 1) #(batch size; 100 dimension; 1 * 1)
fixed_noise = torch.FloatTensor(32, 100, 1, 1).normal_(0, 1)
label = torch.FloatTensor(64)

input_ = Variable(input_)
label = Variable(label)
noise = Variable(noise)
fixed_noise = Variable(fixed_noise)
#noise.data.normal_(0,1)

In [14]:
optimizerD = optim.Adam(d_model.parameters(), lr = 0.001, betas = (0.5, 0.999))
optimizerG = optim.Adam(g_model.parameters(), lr = 0.001, betas = (0.5, 0.999))

criterion = nn.BCELoss()

In [15]:
for epoch in range(500):
    for i, data in enumerate(train_loader, 0):
        #update the D model
        d_model.zero_grad()
        real, _ = data
        batch_size = real.size(0)
        input_.data.resize_(real.size()).copy_(real)
        label.data.resize_(batch_size).fill_(1) # real label is 1
        
        output = d_model(input_)
        loss_D_r = criterion(output, label)
        loss_D_r.backward()
        D_real = output.data.mean()
        
        noise.data.resize_(64, 100, 1, 1)
        noise.data.normal_(0, 1)
        label.data.fill_(0) # fake label
        fake_input = g_model(noise)
        output = d_model(fake_input)
        loss_D_f = criterion(output, label)
        loss_D_f.backward()
        D_fake = output.data.mean()
        
        errD = D_real + D_fake
        
        optimizerD.step()
        
        #update G mdoel
        g_model.zero_grad()
        label.data.fill_(1) # G model want the G samples be 1
        noise.data.resize_(64, 100, 1, 1)
        noise.data.normal_(0, 1)
        fake_input = g_model(noise)
        output = d_model(fake_input)
        
        loss_G = criterion(output, label)
        loss_G.backward()
        loss_D_G = output.data.mean()
        
        optimizerG.step()
        
        
        if i%100 == 0:
            print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
              % (epoch, 1000, i, len(train_loader),
                 D_real + D_fake, loss_D_G, D_real, D_fake, loss_D_G ))
            
            vutils.save_image(real, 
                             '%s/real_sample.png' % 'logs')
            fake = g_model(fixed_noise)
            vutils.save_image(fake.data,
                             '%s/fake_sample_epoch_%03d.png' % ('logs', epoch))
        
        

[0/1000][0/938] Loss_D: 1.0040 Loss_G: 0.4980 D(x): 0.5049 D(G(z)): 0.4992 / 0.4980
[0/1000][100/938] Loss_D: 1.3195 Loss_G: 0.3876 D(x): 0.7905 D(G(z)): 0.5291 / 0.3876
[0/1000][200/938] Loss_D: 1.2756 Loss_G: 0.0831 D(x): 0.9557 D(G(z)): 0.3200 / 0.0831
[0/1000][300/938] Loss_D: 1.0716 Loss_G: 0.0972 D(x): 0.8748 D(G(z)): 0.1968 / 0.0972
[0/1000][400/938] Loss_D: 1.1398 Loss_G: 0.0483 D(x): 0.8408 D(G(z)): 0.2990 / 0.0483
[0/1000][500/938] Loss_D: 1.0742 Loss_G: 0.0153 D(x): 0.7606 D(G(z)): 0.3136 / 0.0153
[0/1000][600/938] Loss_D: 0.9814 Loss_G: 0.1266 D(x): 0.7196 D(G(z)): 0.2618 / 0.1266
[0/1000][700/938] Loss_D: 1.0952 Loss_G: 0.0678 D(x): 0.7386 D(G(z)): 0.3566 / 0.0678
[0/1000][800/938] Loss_D: 0.9973 Loss_G: 0.1938 D(x): 0.6589 D(G(z)): 0.3384 / 0.1938
[0/1000][900/938] Loss_D: 1.0699 Loss_G: 0.1873 D(x): 0.6991 D(G(z)): 0.3707 / 0.1873


AssertionError: 