In [0]:
import os
import torch
import torchvision
import torch.nn as nn
from torchvision import transforms
from torchvision.utils import save_image
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Hyper-parameters
latent_size = 64
hidden_size = 256
image_size = 784
num_epochs = 200
batch_size = 100
sample_dir = '/content/'

In [0]:
#gradient penalty
lambda_gp = 10

In [0]:
# Create a directory if not exists
if not os.path.exists(sample_dir):
   os.makedirs(sample_dir)

In [0]:
# Image processing
transform = transforms.Compose([transforms.ToTensor(),
 transforms.Normalize((0.5,), (0.5,))])
# MNIST dataset
mnist = torchvision.datasets.MNIST(root='../../data/',
                                  train=True,
                                  transform=transform,
                                  download=True)
# Data loader
data_loader = torch.utils.data.DataLoader(dataset=mnist,
                                         batch_size=batch_size,
                                         shuffle=True)
# Discriminator
D = nn.Sequential(
   nn.Linear(image_size, hidden_size),
   nn.LeakyReLU(0.2),
   nn.Linear(hidden_size, hidden_size),
   nn.LeakyReLU(0.2),
   nn.Linear(hidden_size, 1))
   
# Generator
G = nn.Sequential(
   nn.Linear(latent_size, hidden_size),
   nn.ReLU(),
   nn.Linear(hidden_size, hidden_size),
   nn.ReLU(),
   nn.Linear(hidden_size, image_size),
   nn.Tanh())


In [0]:
# Device setting
D = D.to(device)
G = G.to(device)
# Binary cross entropy loss and optimizer
criterion = nn.BCELoss()
g_optimizer = torch.optim.RMSprop(G.parameters(), lr=5e-5)
d_optimizer = torch.optim.RMSprop(D.parameters(), lr=5e-5)

def denorm(x):
   out = (x + 1) / 2
   return out.clamp(0, 1)
def reset_grad():
   d_optimizer.zero_grad()
   g_optimizer.zero_grad()

In [0]:
def compute_gradient_penalty(D, real_samples, fake_samples):
   """Calculates the gradient penalty loss for WGAN GP"""
   # Random weight term for interpolation between real and fake samples
   alpha = Tensor(np.random.random((real_samples.size(0), 1, 1, 1)))
   # Get random interpolation between real and fake samples
   interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
   d_interpolates = D(interpolates)
   fake = Variable(Tensor(real_samples.shape[0], 1).fill_(1.0), requires_grad=False)
   # Get gradient w.r.t. interpolates
   gradients = autograd.grad(
       outputs=d_interpolates,
       inputs=interpolates,
       grad_outputs=fake,
       create_graph=True,
       retain_graph=True,
       only_inputs=True,
   )[0]
   gradients = gradients.view(gradients.size(0), -1)
   gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
   return gradient_penalty

In [14]:
# Start training
total_step = len(data_loader)
for epoch in range(num_epochs):
  for i, (images, _) in enumerate(data_loader):
      images = images.reshape(batch_size, -1).to(device)
      # Create the labels which are later used as input for the BCE loss
      real_labels = torch.ones(batch_size, 1).to(device)
      fake_labels = torch.zeros(batch_size, 1).to(device)
      # ================================================================== #
      #                      Train the discriminator                       #
      # ================================================================== #
      # Compute BCE_Loss using real images where BCE_Loss(x, y): - y * log(D(x)) - (1-y) * log(1 - D(x))
      # Second term of the loss is always zero since real_labels == 1
      for w in range(5):
       # outputs = D(images)
       # d_loss_real = criterion(outputs, real_labels)
       # real_score = outputs
       # Compute BCELoss using fake images
       # First term of the loss is always zero since fake_labels == 0
       z = torch.randn(batch_size, latent_size).to(device)
       # fake_images = G(z)
       # outputs = D(fake_images)
       # d_loss_fake = criterion(outputs, fake_labels)
       # fake_score = outputs
       # Dicriminator forward-loss-backward-update
       G_sample = G(z)
       D_real = D(images)
       D_fake = D(G_sample)
       d_loss = -(torch.mean(D_real) - torch.mean(D_fake))
       # Backprop and optimize
       # d_loss = d_loss_real + d_loss_fake
       reset_grad()
       d_loss.backward()
       d_optimizer.step()
       # Weight clipping
       for p in D.parameters():
         p.data.clamp_(-0.01, 0.01)
      # ================================================================== #
      #                        Train the generator                         #
      # ================================================================== #
      # Compute loss with fake images
      z = torch.randn(batch_size, latent_size).to(device)
      #Generator loss:
      G_sample = G(z)
      D_fake = D(G_sample)
      g_loss = -torch.mean(D_fake)
     #  fake_images = G(z)
     #  outputs = D(fake_images)
      # We train G to maximize log(D(G(z)) instead of minimizing log(1-D(G(z)))
      # For the reason, see the last paragraph of section 3. https://arxiv.org/pdf/1406.2661.pdf
     #  g_loss = criterion(outputs, real_labels)
      # Backprop and optimize
      reset_grad()
      g_loss.backward()
      g_optimizer.step()
      if (i+1) % 200 == 0:
          print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}'
                .format(epoch, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(),
                        D_real.mean().item(), D_fake.mean().item()))
  # Save real images
  if (epoch+1) == 1:
      images = images.reshape(images.size(0), 1, 28, 28)
      save_image(denorm(images), os.path.join(sample_dir, 'real_images.png'))
  # Save sampled images
  fake_images = G_sample.reshape(G_sample.size(0), 1, 28, 28)
  save_image(denorm(fake_images), os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch+1)))

Epoch [0/200], Step [200/600], d_loss: -0.1212, g_loss: -6.6026, D(x): 6.72, D(G(z)): 6.60
Epoch [0/200], Step [400/600], d_loss: -0.0274, g_loss: -5.0753, D(x): 5.09, D(G(z)): 5.08
Epoch [0/200], Step [600/600], d_loss: -0.1311, g_loss: -0.7612, D(x): 0.88, D(G(z)): 0.76
Epoch [1/200], Step [200/600], d_loss: -0.2411, g_loss: -2.2285, D(x): 2.48, D(G(z)): 2.23
Epoch [1/200], Step [400/600], d_loss: -0.2466, g_loss: -2.4689, D(x): 2.71, D(G(z)): 2.47
Epoch [1/200], Step [600/600], d_loss: -0.2132, g_loss: -2.6496, D(x): 2.83, D(G(z)): 2.65
Epoch [2/200], Step [200/600], d_loss: -0.1860, g_loss: -2.2708, D(x): 2.45, D(G(z)): 2.27
Epoch [2/200], Step [400/600], d_loss: -0.2440, g_loss: -1.8210, D(x): 2.08, D(G(z)): 1.82
Epoch [2/200], Step [600/600], d_loss: -0.2119, g_loss: -2.1141, D(x): 2.32, D(G(z)): 2.11
Epoch [3/200], Step [200/600], d_loss: -0.1920, g_loss: -1.4139, D(x): 1.60, D(G(z)): 1.41
Epoch [3/200], Step [400/600], d_loss: -0.2279, g_loss: -1.5342, D(x): 1.76, D(G(z)): 1.53

In [20]:
import cv2
import numpy as np
from google.colab.patches import cv2_imshow
cntr = 1
for blk in range(20):
 img_arr = [cv2.imread('/content/samples/fake_images-'+str(cntr)+'.png')]
 cntr += 1
 for i in range(1,10):
   img_arr.append(cv2.imread('/content/samples/fake_images-'+str(cntr)+'.png'))
   cntr += 1
 pctrs = np.hstack(img_arr)
 print("<<<<<<<<<>>>>>>>>>>")
 print("Batch :",blk)
 print("<<<aa<<<<<<>>>>>>>>>>")
 cv2_imshow(pctrs)
 print(img_arr)


<<<<<<<<<>>>>>>>>>>
Batch : 0
<<<<<<<<<>>>>>>>>>>


TypeError: ignored

In [0]:
img_arr = [cv2.imread('/content/samples/fake_images-'+str(1)+'.png')]