In [1]:
import torch

In [2]:
from torch import nn,optim
from torch.autograd.variable import Variable
from torchvision import datasets,transforms


In [3]:
from utils import Logger

In [4]:
def mnist_data():
    """Return the MNIST training split as normalised tensors.

    Each image is converted with ``ToTensor`` and then normalised with
    mean 0.5 and std 0.5. The dataset is stored under ``./dataset`` and
    is requested with ``download=True``.

    Returns:
        A ``torchvision.datasets.MNIST`` instance for the training split.
    """
    # MNIST images have a single channel, so Normalize needs exactly one
    # mean and one std entry; three-entry tuples do not match the tensor's
    # channel dimension.
    compose = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((.5,), (.5,)),
    ])
    out_dir = './dataset'

    return datasets.MNIST(root=out_dir, train=True, transform=compose, download=True)


# Load the MNIST training set and serve it in shuffled batches of 100.
data = mnist_data()
dataloader = torch.utils.data.DataLoader(dataset=data, batch_size=100, shuffle=True)
# Number of batches the loader yields per pass over the data.
num_batches = len(dataloader)


In [5]:
class Discriminator(nn.Module):
    """Fully connected classifier over flattened 784-feature inputs.

    Three hidden stages (Linear -> LeakyReLU(0.2) -> Dropout) narrow the
    width 784 -> 1024 -> 512 -> 256, with dropout rates 0.3, 0.2 and 0.2.
    The output stage is a Linear layer to one unit followed by a Sigmoid,
    so each sample yields one value between 0 and 1.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        in_width, out_width = 784, 1

        # Stage attribute names feed into the state_dict keys; keep them as is.
        self.hidden0 = nn.Sequential(
            nn.Linear(in_width, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
        )
        self.hidden1 = nn.Sequential(
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.2),
        )
        self.hidden2 = nn.Sequential(
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.2),
        )
        self.out = nn.Sequential(
            nn.Linear(256, out_width),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Pass ``x`` through hidden0, hidden1, hidden2 and out, in that order."""
        for stage in (self.hidden0, self.hidden1, self.hidden2, self.out):
            x = stage(x)
        return x
    
        

In [6]:
# Module-level discriminator instance; the training helpers below use it by name.
discriminator = Discriminator()

In [7]:
# Print the module tree to check the layer sizes.
print(discriminator)

Discriminator(
  (hidden0): Sequential(
    (0): Linear(in_features=784, out_features=1024, bias=True)
    (1): LeakyReLU(negative_slope=0.2)
    (2): Dropout(p=0.3)
  )
  (hidden1): Sequential(
    (0): Linear(in_features=1024, out_features=512, bias=True)
    (1): LeakyReLU(negative_slope=0.2)
    (2): Dropout(p=0.2)
  )
  (hidden2): Sequential(
    (0): Linear(in_features=512, out_features=256, bias=True)
    (1): LeakyReLU(negative_slope=0.2)
    (2): Dropout(p=0.2)
  )
  (out): Sequential(
    (0): Linear(in_features=256, out_features=1, bias=True)
    (1): Sigmoid()
  )
)


In [8]:
def images_to_vector(images):
    """Flatten a batch of images into rows of 784 values.

    The leading (batch) dimension is preserved. ``view`` is used, so the
    input must hold exactly 784 elements per sample.
    """
    batch_size = images.size(0)
    return images.view(batch_size, 784)

In [9]:
def vector_to_images(vector):
    """Reshape rows of 784 values into ``(batch, 1, 28, 28)`` image tensors.

    The leading (batch) dimension is preserved; this is the inverse layout
    of ``images_to_vector``.
    """
    n_samples = vector.size(0)
    image_shape = (1, 28, 28)
    return vector.view(n_samples, *image_shape)

In [12]:
class Generator(nn.Module):
    """Fully connected network mapping 100-value noise rows to 784 outputs.

    Three hidden stages (Linear -> LeakyReLU(0.2)) widen the signal
    100 -> 256 -> 512 -> 1024. The output stage is a Linear layer to 784
    units followed by Tanh, so every output value lies between -1 and 1.
    """

    def __init__(self):
        super(Generator, self).__init__()
        latent_width, out_width = 100, 784

        # Stage attribute names feed into the state_dict keys; keep them as is.
        self.hidden0 = nn.Sequential(
            nn.Linear(latent_width, 256),
            nn.LeakyReLU(0.2),
        )
        self.hidden1 = nn.Sequential(
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
        )
        self.hidden2 = nn.Sequential(
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
        )
        self.out = nn.Sequential(
            nn.Linear(1024, out_width),
            nn.Tanh(),
        )

    def forward(self, x):
        """Pass ``x`` through hidden0, hidden1, hidden2 and out, in that order."""
        for stage in (self.hidden0, self.hidden1, self.hidden2, self.out):
            x = stage(x)
        return x
        

In [13]:
# Module-level generator instance used by the optimizer setup and training loop.
generator = Generator()

In [14]:
# Print the module tree to check the layer sizes.
print(generator)

Generator(
  (hidden0): Sequential(
    (0): Linear(in_features=100, out_features=256, bias=True)
    (1): LeakyReLU(negative_slope=0.2)
  )
  (hidden1): Sequential(
    (0): Linear(in_features=256, out_features=512, bias=True)
    (1): LeakyReLU(negative_slope=0.2)
  )
  (hidden2): Sequential(
    (0): Linear(in_features=512, out_features=1024, bias=True)
    (1): LeakyReLU(negative_slope=0.2)
  )
  (out): Sequential(
    (0): Linear(in_features=1024, out_features=784, bias=True)
    (1): Tanh()
  )
)


In [15]:
def noise(size):
    """Sample ``size`` noise rows of 100 standard-normal values each.

    Args:
        size: number of rows (samples) to draw.

    Returns:
        A tensor of shape ``(size, 100)``.
    """
    # The Variable wrapper is deprecated and has been a no-op since tensors
    # and Variables were merged in PyTorch 0.4, so return the tensor directly.
    return torch.randn(size, 100)

In [23]:
# Quick check: ten noise rows should give a (10, 100) tensor.
x = noise(10)
print(x.size())

torch.Size([10, 100])


In [17]:
# Display the noise tensor sampled by the noise(10) check cell.
x

tensor([[-1.2050e+00,  1.5677e-01, -1.9934e-01,  9.8299e-01, -6.0401e-01,
         -2.6841e+00, -6.2995e-01, -1.5590e+00,  9.8946e-01,  1.9900e+00,
         -4.9580e-01,  1.3135e+00, -4.3641e-01, -3.9364e-02,  9.6730e-01,
          1.2551e+00,  7.9360e-01, -1.6929e+00, -7.8316e-01,  6.5946e-01,
         -7.7041e-01,  3.9155e-01,  1.9606e-01,  2.9355e-01, -6.7046e-01,
         -1.9068e-01,  1.0934e+00,  9.7931e-01, -7.2651e-01, -1.4129e+00,
         -6.1852e-01, -1.8512e+00, -1.8685e+00, -7.4001e-01,  1.4150e+00,
         -5.7120e-01, -1.6259e+00, -3.1629e-01,  1.7295e+00, -1.8937e+00,
          1.2015e+00, -1.8396e+00,  2.8268e-01, -2.9296e-01, -2.2385e-01,
         -1.6030e-02, -9.6601e-01,  6.5134e-01, -4.1908e-01, -5.7547e-01,
         -9.5684e-02,  3.8932e-01, -1.2595e+00, -1.4368e+00, -3.6383e-01,
         -8.7506e-01, -1.3983e+00, -8.6403e-01,  4.8898e-01, -1.1303e+00,
          8.1996e-01, -7.8545e-01,  8.9278e-01, -1.0144e+00, -9.3250e-01,
          2.8742e-01,  1.1580e+00,  1.

In [19]:
# One Adam optimizer per network, both with a learning rate of 2e-4.
d_optimizer = optim.Adam(params=discriminator.parameters(), lr=2e-4)
g_optimizer = optim.Adam(params=generator.parameters(), lr=2e-4)

In [20]:
# Binary cross-entropy, used by the training helpers to score the
# discriminator's outputs against all-ones / all-zeros targets.
loss = nn.BCELoss()

In [21]:
def ones_target(size):
    """Build a ``(size, 1)`` tensor of ones, used as the "real" label target.

    Args:
        size: number of rows (one per sample in the batch).

    Returns:
        A tensor of shape ``(size, 1)`` filled with 1.0.
    """
    # The Variable wrapper is deprecated and has been a no-op since PyTorch
    # 0.4; a plain tensor is a valid loss target.
    return torch.ones(size, 1)

In [22]:
def zeros_target(size):
    """Build a ``(size, 1)`` tensor of zeros, used as the "fake" label target.

    Args:
        size: number of rows (one per sample in the batch).

    Returns:
        A tensor of shape ``(size, 1)`` filled with 0.0.
    """
    # The Variable wrapper is deprecated and has been a no-op since PyTorch
    # 0.4; a plain tensor is a valid loss target.
    return torch.zeros(size, 1)

In [25]:
def train_discriminator(optimizer,real_data,fake_data):
    """Run one discriminator update on a real batch and a fake batch.

    Gradients from the real-batch loss and the fake-batch loss are
    accumulated by two backward passes, then applied in a single
    optimizer step. Uses the module-level ``discriminator`` and ``loss``.

    Args:
        optimizer: optimizer to zero and step.
        real_data: batch of real samples fed to the discriminator.
        fake_data: batch of generated samples fed to the discriminator.

    Returns:
        A tuple ``(error_real + error_fake, prediction_real, prediction_fake)``.
    """
    optimizer.zero_grad()

    # Real samples are scored against a target of ones.
    prediction_real = discriminator(real_data)
    error_real = loss(prediction_real, ones_target(real_data.size(0)))
    error_real.backward()

    # Fake samples are scored against a target of zeros. The target is sized
    # from the fake batch itself so it always matches prediction_fake, even
    # if the fake batch length differs from the real one.
    prediction_fake = discriminator(fake_data)
    error_fake = loss(prediction_fake, zeros_target(fake_data.size(0)))
    error_fake.backward()

    optimizer.step()

    return error_real + error_fake, prediction_real, prediction_fake

In [26]:
def train_generator(optimizer,fake_data):
    """Run one generator update from the discriminator's verdict on fakes.

    The module-level ``discriminator`` scores ``fake_data`` and the result
    is compared against a target of ones; the loss is back-propagated and
    ``optimizer`` takes one step.

    Args:
        optimizer: optimizer to zero and step.
        fake_data: batch of generated samples.

    Returns:
        The loss tensor computed for this batch.
    """
    optimizer.zero_grad()

    # The generator is rewarded when its samples are scored as real (1).
    real_labels = ones_target(fake_data.size(0))
    generator_error = loss(discriminator(fake_data), real_labels)
    generator_error.backward()

    optimizer.step()
    return generator_error

In [27]:
# Sample 16 noise rows once, up front.
# NOTE(review): presumably reused to render comparable samples as training
# progresses — that usage is not shown in this section; confirm.
num_test_samples = 16
test_noise = noise(num_test_samples)

In [28]:
logger = Logger(model_name = 'GAN',data_name = 'MNIST')

epochs = 20

# `epochs` is an int, so iterate over range(epochs); looping over the int
# itself raises "TypeError: 'int' object is not iterable".
for epoch in range(epochs):
    for n_batch,(real_batch,_) in enumerate(dataloader):
        N = real_batch.size(0)
        # Flatten the real images to 784-value rows for the discriminator.
        real_data = Variable(images_to_vector(real_batch))
        # Detach so that gradients from the discriminator update do not
        # flow back into the generator.
        fake_data = generator(noise(N)).detach()
        # TODO: complete the training step.