In [1]:
import torch
import numpy as np
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms, datasets

In [2]:
train = datasets.MNIST('', train = True, download = True,
                       transform = transforms.Compose([transforms.ToTensor()]))
train_set = torch.utils.data.DataLoader(train, batch_size = 100, shuffle = True)

In [3]:
if torch.cuda.is_available():
    device = torch.device('cuda:0')
    print("Running on GPU")
else:
    device = torch.device('cpu')
    print("Running on cpu")

Running on GPU


In [4]:
def latent_space_vectors(size): #size is the number of samples in a batch
    return torch.randn(size, 100).to(device)

def real_data_target(size):
    return (torch.ones(size, 1)).to(device)

def fake_data_target(size):
    return (torch.zeros(size, 1)).to(device)

![gan.png](attachment:gan.png)

In [5]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        input_features = 100
        output_features = 784
        self.fc1 = nn.Linear(input_features, 256)
        self.fc2 = nn.Linear(256, 512)
        self.fc3 = nn.Linear(512, 1024)
        self.out = nn.Linear(1024, output_features)
        
    def forward(self, x):
        x = self.fc1(x)
        x = F.leaky_relu(x, 0.2)
        x = self.fc2(x)
        x = F.leaky_relu(x, 0.2)
        x = self.fc3(x)
        x = F.leaky_relu(x, 0.2)
        x = self.out(x)
        x = torch.tanh(x)
        return x

In [6]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        input_features = 784
        output_features = 1
        self.fc1 = nn.Linear(input_features, 1024)
        self.fc2 = nn.Linear(1024, 512)
        self.fc3 = nn.Linear(512, 256)
        self.out = nn.Linear(256, output_features)
    
    def forward(self, x):
        x = self.fc1(x)
        x = F.leaky_relu(x, 0.2)
        x = self.fc2(x)
        x = F.leaky_relu(x, 0.2)
        x = self.fc3(x)
        x = F.leaky_relu(x, 0.2)
        x = self.out(x)
        x = torch.sigmoid(x)
        return x

In [7]:
loss_function = torch.nn.BCELoss()
generator = Generator()
discriminator = Discriminator()

if torch.cuda.is_available():
    generator.cuda()
    discriminator.cuda()
    loss_function.cuda()
    print(generator)
    print(discriminator)
    print(loss_function)

Generator(
  (fc1): Linear(in_features=100, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=512, bias=True)
  (fc3): Linear(in_features=512, out_features=1024, bias=True)
  (out): Linear(in_features=1024, out_features=784, bias=True)
)
Discriminator(
  (fc1): Linear(in_features=784, out_features=1024, bias=True)
  (fc2): Linear(in_features=1024, out_features=512, bias=True)
  (fc3): Linear(in_features=512, out_features=256, bias=True)
  (out): Linear(in_features=256, out_features=1, bias=True)
)
BCELoss()


In [8]:
optimizer_generator = optim.Adam(generator.parameters(), lr=0.0002)
optimizer_discriminator = optim.Adam(discriminator.parameters(), lr=0.0002)

In [9]:
def train_discriminator(real_image, fake_image):
    optimizer_discriminator.zero_grad()
    
    #training discriminator using real images
    prediction_real_image = discriminator.forward(real_image)
    loss_real_image = loss_function(prediction_real_image, real_data_target(real_image.size(0)))
    loss_real_image.backward()
    
    #training discriminator using fake images
    prediction_fake_image = discriminator.forward(fake_image)
    loss_fake_image = loss_function(prediction_fake_image, fake_data_target(prediction_fake_image.size(0)))
    loss_fake_image.backward()
    
    optimizer_discriminator.step()
    
    return loss_fake_image + loss_real_image, prediction_real_image, prediction_fake_image

In [10]:
def tarin_generator(fake_image):
    optimizer_generator.zero_grad()
    prediction_fake_image = discriminator.forward(fake_image)
    error_fake_image = loss_function(prediction_fake_image, real_data_target(prediction_fake_image.size(0)))
    error_fake_image.backward()
    optimizer_generator.step()
    
    return error_fake_image

In [11]:
EPOCHS = 10

for epoch in range(EPOCHS):
    for real_batch,_ in tqdm(train_set):

        real_image = real_batch.view(real_batch.size(0), 784).to(device)
        fake_image = generator.forward(latent_space_vectors(100))
        d_error, d_pred_real, d_pred_fake = train_discriminator(real_image, fake_image)
        
        
        fake_image = generator.forward(latent_space_vectors(100))
        g_error = tarin_generator(fake_image)
        
    print("Discriminator loss = " + str(d_error.data.cpu().numpy()) + " Generator loss = " + str(g_error.data.cpu().numpy()))

100%|████████████████████████████████████████████████████████████████████████████████| 600/600 [00:30<00:00, 19.82it/s]
  0%|▎                                                                                 | 2/600 [00:00<00:32, 18.40it/s]

Discriminator loss = 1.1046861 Generator loss = 1.2974416


100%|████████████████████████████████████████████████████████████████████████████████| 600/600 [00:31<00:00, 18.86it/s]
  0%|▎                                                                                 | 2/600 [00:00<00:35, 16.85it/s]

Discriminator loss = 0.35793236 Generator loss = 1.8317511


100%|████████████████████████████████████████████████████████████████████████████████| 600/600 [00:32<00:00, 18.48it/s]
  0%|▎                                                                                 | 2/600 [00:00<00:38, 15.54it/s]

Discriminator loss = 0.33811456 Generator loss = 3.0206711


100%|████████████████████████████████████████████████████████████████████████████████| 600/600 [00:33<00:00, 18.14it/s]
  0%|▎                                                                                 | 2/600 [00:00<00:36, 16.31it/s]

Discriminator loss = 0.821491 Generator loss = 2.058123


100%|████████████████████████████████████████████████████████████████████████████████| 600/600 [00:33<00:00, 17.98it/s]
  0%|▎                                                                                 | 2/600 [00:00<00:34, 17.29it/s]

Discriminator loss = 0.46835536 Generator loss = 2.0661302


100%|████████████████████████████████████████████████████████████████████████████████| 600/600 [00:33<00:00, 17.82it/s]
  0%|▎                                                                                 | 2/600 [00:00<00:36, 16.17it/s]

Discriminator loss = 0.3697992 Generator loss = 2.9792717


100%|████████████████████████████████████████████████████████████████████████████████| 600/600 [00:33<00:00, 17.82it/s]
  0%|▎                                                                                 | 2/600 [00:00<00:39, 15.07it/s]

Discriminator loss = 0.48019934 Generator loss = 2.7912111


100%|████████████████████████████████████████████████████████████████████████████████| 600/600 [00:33<00:00, 17.79it/s]
  0%|▎                                                                                 | 2/600 [00:00<00:39, 15.19it/s]

Discriminator loss = 0.5994385 Generator loss = 2.1251178


100%|████████████████████████████████████████████████████████████████████████████████| 600/600 [00:33<00:00, 17.95it/s]
  0%|▎                                                                                 | 2/600 [00:00<00:42, 14.12it/s]

Discriminator loss = 0.5162189 Generator loss = 2.2121634


100%|████████████████████████████████████████████████████████████████████████████████| 600/600 [00:33<00:00, 17.74it/s]

Discriminator loss = 0.7448069 Generator loss = 2.4311585





<h3>GANs are hard to train</h3>
Looking for more fine tuning and optimized way for training.