In [0]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch import optim
from torch.autograd import Variable

In [0]:
# Dimensionality of the latent space
DIM = 100

In [0]:
image_size = 64 # We set the size of the generated images (64x64).
batch_size = 64
num_of_channels = 3

In [4]:
# check if CUDA is available
train_on_gpu = torch.cuda.is_available()

if train_on_gpu == True:
  print("CUDA is available!  Training on GPU ...")
else:
  print("CUDA is not available.  Training on CPU ...")

CUDA is available!  Training on GPU ...


In [5]:
transform = transforms.Compose([
                                transforms.Scale(image_size),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

  "please use transforms.Resize instead.")


In [6]:
dataset = datasets.CIFAR10(
    'data',
    download = True,
    transform = transform
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))


Extracting data/cifar-10-python.tar.gz to data


In [7]:
!ls

data  sample_data


In [8]:
dataset

Dataset CIFAR10
    Number of datapoints: 50000
    Root location: data
    Split: Train
    StandardTransform
Transform: Compose(
               Scale(size=64, interpolation=PIL.Image.BILINEAR)
               ToTensor()
               Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
           )

In [9]:
type(dataset)

torchvision.datasets.cifar.CIFAR10

In [0]:
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)

In [11]:
class make_generator_model(nn.Module):
  def __init__(self):
    super(make_generator_model, self).__init__()
    
    self.main = nn.Sequential(
        
        nn.ConvTranspose2d(DIM, 512, kernel_size=4, stride=1, padding=0, bias=False), # inverse convolution since we are turning a vector into an image
        nn.BatchNorm2d(512),
        nn.ReLU(True),

        nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=False),
        nn.BatchNorm2d(256),
        nn.ReLU(True),

        nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),
        nn.BatchNorm2d(128),
        nn.ReLU(True),

        nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),
        nn.BatchNorm2d(64),
        nn.ReLU(True),

        nn.ConvTranspose2d(64, num_of_channels, kernel_size=4, stride=2, padding=1, bias=False),
        nn.Tanh()
    )

  def forward(self, noise):
    return self.main(noise)

generator = make_generator_model()

print(generator)

make_generator_model(
  (main): Sequential(
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Ta

In [12]:
class make_discriminator_model(nn.Module):
  def __init__(self):
    super(make_discriminator_model, self).__init__()

    self.main = nn.Sequential(
        
        # Below conv layer gets torch.Size([64, 3, 64, 64])
        nn.Conv2d(num_of_channels, 64, kernel_size=4, stride=2, padding=1, bias=False),
        nn.LeakyReLU(0.2),

        # Below conv layer gets torch.Size([64, 64, 32, 32])
        nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
        nn.BatchNorm2d(128),
        nn.LeakyReLU(0.2),

        # Below conv layer gets torch.Size([64, 128, 16, 16])
        nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
        nn.BatchNorm2d(256),
        nn.LeakyReLU(0.2),

        # Below conv layer gets torch.Size([64, 256, 8, 8])
        nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False),
        nn.BatchNorm2d(512),
        nn.LeakyReLU(0.2),

        # Below conv layer gets torch.Size([64, 512, 4, 4]) & outputs torch.Size([64, 1, 1, 1])
        nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0, bias=False),
        nn.Sigmoid()
    )

  def forward(self, image):                         
    # print(image.shape)                 # torch.Size([64, 3, 64, 64])
    output = self.main(image)
    # print(output.shape)                # torch.Size([64, 1, 1, 1]) is the final output after going through sequential steps above          
    return output.view(-1)               # the .view converts it from torch.Size([64, 1, 1, 1]) to torch.Size([64])

discriminator = make_discriminator_model()

print(discriminator)

make_discriminator_model(
  (main): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (12): Sigmoid()
  )
)


In [0]:
criterion = nn.BCELoss()

In [0]:
optimizer_generator = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_discriminator = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

In [0]:
def train_d(model, loss, optimizer, inputs, labels):

  model.train()

  inputs = Variable(inputs, requires_grad=False)   # To create torch variable. By default requires_grad is False
  labels = Variable(labels, requires_grad=False)

  optimizer.zero_grad()

  # get output from the model, given the inputs
  logps = model.forward(inputs)                    # torch.Size([64]). Same shape as labels (ones & zeros)

  # get loss for the predicted output
  cost = loss.forward(logps, labels)

  # get gradients w.r.t to parameters
  cost.backward()

  # update parameters
  optimizer.step()

  return cost.item()

In [0]:
def train_g(model, loss, optimizer, inputs, labels):

  model.train()

  inputs = Variable(inputs, requires_grad=False)   # To create torch variable. By default requires_grad is False
  labels = Variable(labels, requires_grad=False)

  optimizer.zero_grad()

  # get output from the model, given the inputs
  logps = discriminator.forward(inputs)                    # torch.Size([64]). Same shape as labels (ones & zeros)

  # get loss for the predicted output
  cost = loss.forward(logps, labels)

  # get gradients w.r.t to parameters
  cost.backward()

  # update parameters
  optimizer.step()

  return cost.item()

In [0]:
import matplotlib.pyplot as plt

def plot_generated_images(epoch):
  noise = torch.randn(batch_size, DIM, 1, 1)  # 100 feature maps of size 1x1, torch.Size([64, 100, 1, 1])
  noise = Variable(noise, requires_grad=False)                   # torch.Size([64, 100, 1, 1])
  images = generator.forward(noise)                         # torch.Size([64, 3, 64, 64])
  #noise = np.random.randn(25, DIM)                             # (25, 100)
  #generated_images = generator.predict(noise)                  # (25, 784)
  #generated_images = generated_images.reshape(25,28,28)        # (25, 28, 28)
  fig, axes = plt.subplots(8, 8, figsize=(20,20))              # We use 5 rows & 5 columns as there are 25 images
  axes = axes.flatten()
  for img, ax in zip(images, axes):
    img = img / 2 + 0.5 
    img=img.permute(1,2,0)
    #ax.imshow(img)
    #plt.axis('off')
  plt.tight_layout()
  plt.savefig('gan_generated_image %d.png' %epoch)
  plt.close()

In [0]:
epochs = 4

for epoch in range(epochs):

  for batch_idx, (real_images, labels) in enumerate(dataloader):   
  # we don't care about labels & create our own labels
  # batch_idx goes from 0 to (50,000 / batch_size) = 781.25 = truncated to 781

    ############################
    # TRAINING THE DISCRIMINATOR
    ############################
    # print(real_images.shape)                                     # torch.Size([64, 3, 64, 64])
    ones = torch.ones(batch_size)                                  # torch.Size([64])
    disc_loss_real = train_d(discriminator, criterion, optimizer_discriminator, real_images, ones)
    # print(disc_loss_real)

    noise = torch.randn(batch_size, DIM, 1, 1)  # 100 feature maps of size 1x1, torch.Size([64, 100, 1, 1])
    noise = Variable(noise, requires_grad=False)                   # torch.Size([64, 100, 1, 1])
    fake_images = generator.forward(noise)                         # torch.Size([64, 3, 64, 64])

    zeros = torch.zeros(batch_size)                                # torch.Size([64])
    disc_loss_fake = train_d(discriminator, criterion, optimizer_discriminator, fake_images, zeros)
    # print(disc_loss_fake)
    discriminator_loss = (disc_loss_real + disc_loss_fake) / 2

    ########################
    # TRAINING THE GENERATOR
    ########################
    gen_loss = train_g(generator, criterion, optimizer_generator, fake_images, ones)

    if batch_idx % 500 == 0:
      print(f"epoch: {epoch}/{epochs}, discriminator_loss: {discriminator_loss:.2f}, generator_loss: {gen_loss:.2f}")

    if epoch % 2 == 0:
      plot_generated_images(batch_idx)

epoch: 0/4, discriminator_loss: 1.56, generator_loss: 1.79


In [0]:
import matplotlib.image as mpimg
image = mpimg.imread('gan_generated_image 0.png')
plt.imshow(image)

In [0]:
import matplotlib.image as mpimg
image = mpimg.imread('gan_generated_image 2.png')
plt.imshow(image)