<a href="https://colab.research.google.com/github/ayush12gupta/GAN-Implementations/blob/master/cGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import os
# Fail early, with guidance, when the runtime has no TPU attached.
# `.get` is used so a missing variable triggers the assertion message below
# instead of a bare KeyError.
assert os.environ.get('COLAB_TPU_ADDR'), 'Make sure to select TPU from Edit > Notebook settings > Hardware accelerator'

In [0]:
# PyTorch/XLA build to install. The trailing `#@param` annotation lists the
# selectable values and must stay on the same line as the assignment.
VERSION = "20200325"  #@param ["1.5" , "20200325", "nightly"]
# Download the PyTorch/XLA environment setup script ...
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
# ... and run it for the version selected above.
!python pytorch-xla-env-setup.py --version $VERSION

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0100  3727  100  3727    0     0  36900      0 --:--:-- --:--:-- --:--:-- 36900
Updating TPU and VM. This may take around 2 minutes.
Updating TPU runtime to pytorch-dev20200325 ...
Uninstalling torch-1.5.0a0+d6149a7:
  Successfully uninstalled torch-1.5.0a0+d6149a7
Uninstalling torchvision-0.6.0a0+3c254fb:
  Successfully uninstalled torchvision-0.6.0a0+3c254fb
Copying gs://tpu-pytorch/wheels/torch-nightly+20200325-cp36-cp36m-linux_x86_64.whl...
- [1 files][ 83.4 MiB/ 83.4 MiB]                                                
Operation completed over 1 objects/83.4 MiB.                                     
Copying gs://tpu-pytorch/wheels/torch_xla-nightly+20200325-cp36-cp36m-linux_x86_64.whl...
- [1 files][114.5 MiB/114.5 MiB]                             

In [0]:
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

# "Map function": acquires a corresponding Cloud TPU core, creates a tensor on it,
# and prints its core
def simple_map_fn(index, flags):
  """Smoke test run in each spawned process: grab a device, touch it, report it.

  Args:
    index: process index supplied by the spawner; only used in the printout.
    flags: extra argument forwarded by the spawner; not read by this function.
  """
  # Every process seeds identically so initialization (and the traced graph)
  # match across processes.
  torch.manual_seed(1234)

  # The XLA device assigned to this process.
  xla_dev = xm.xla_device()

  # Allocate a small random tensor directly on that device.
  _probe = torch.randn((2, 2), device=xla_dev)

  # Translate the XLA device name into the underlying real device name.
  real_name = xm.xla_real_devices([str(xla_dev)])[0]
  print("Process", index, "is using", real_name)

  # Hold every process here so the master does not exit before the workers
  # have connected.
  xm.rendezvous('init')

# Launch one `simple_map_fn` process per Cloud TPU core (eight in total).
# `flags` is forwarded to every process as the second positional argument.
flags = {}
xmp.spawn(
    simple_map_fn,
    args=(flags,),
    nprocs=8,
    start_method='fork',  # Colab only supports the 'fork' start method
)

Process 0 is using TPU:0
Process 6 is using TPU:6
Process 2 is using TPU:2
Process 3 is using TPU:3
Process 5 is using TPU:5
Process 7 is using TPU:7
Process 4 is using TPU:4
Process 1 is using TPU:1


In [0]:
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.distributed.parallel_loader as pl
import torch
import torch.optim as opt
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torchvision.utils as vutils

from torchvision import datasets
import torchsummary
from torchsummary import summary
from torch.autograd import Variable
import time
import numpy as np

import matplotlib.pyplot as plt
%matplotlib inline


In [0]:
######## Dataset ###########
# Preprocessing: resize and center-crop to 32 px, convert to a tensor, then
# normalize with mean 0.5 / std 0.5.
transform = transforms.Compose([
    transforms.Resize(size=32),
    transforms.CenterCrop(size=32),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# Non-master processes wait at the 'download_only_once' rendezvous.
# NOTE(review): this cell has no matching master-side rendezvous after the
# download (compare `map_fn`) - confirm this half-barrier is intentional.
if not xm.is_master_ordinal():
    xm.rendezvous('download_only_once')

# CIFAR-10 training split, downloaded into ./content when missing.
dataset = datasets.CIFAR10(
    root='./content',
    train=True,
    transform=transform,
    download=True,
)

# Shuffled mini-batches of 32 images.
dataloader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=32,
    shuffle=True,
)



Files already downloaded and verified


In [0]:
# Number of passes over the training set.
Epoch = 100

# Image geometry: channel count and square side length, plus the combined
# (C, H, W) shape.
channel, image_size = 3, 32
image_shape = (channel, image_size, image_size)

# Size of the generator's noise vector and number of conditioning classes.
latent_dim, num_class = 100, 10

# Mini-batch size.
batch_size = 32

In [0]:
class Generator(nn.Module):
  """Fully connected conditional generator.

  A class label is embedded into a `num_class`-dimensional vector, concatenated
  with the noise vector, and pushed through four Linear/BatchNorm/LeakyReLU
  stages followed by a Linear + Tanh head. The output is reshaped to
  (batch, channel, image_size, image_size).

  Reads the notebook-level constants `latent_dim`, `num_class`, `channel` and
  `image_size`.
  """

  def __init__(self):
    super(Generator, self).__init__()

    # The embedding width must equal `num_class` because the first Linear layer
    # expects `latent_dim + num_class` inputs. Tying both embedding sizes to
    # `num_class` (instead of a literal 10) keeps the model consistent when the
    # constant changes.
    self.label_embedding = nn.Embedding(num_class, num_class)
    self.layer = 128

    layers = []
    in_features = latent_dim + num_class
    for width in (self.layer, self.layer*2, self.layer*4, self.layer*8):
      layers += [
          nn.Linear(in_features, width),
          # NOTE(review): the original passed 0.8 as the second positional
          # argument, which BatchNorm1d interprets as `eps`. Kept as-is to
          # preserve behavior - confirm whether `momentum` was intended.
          nn.BatchNorm1d(width, eps=0.8),
          nn.LeakyReLU(0.2, inplace=True),
      ]
      in_features = width
    layers += [
        nn.Linear(in_features, channel*image_size*image_size),
        nn.Tanh(),
    ]
    # Flat Sequential: sub-module indices (and therefore state_dict keys) are
    # the same as with the layers written out one by one.
    self.model = nn.Sequential(*layers)

  def forward(self, noise, labels):
    """Generate images for `labels` from `noise`.

    Args:
      noise: tensor viewable as (batch, latent_dim).
      labels: integer class indices, one per sample.

    Returns:
      Tensor of shape (batch, channel, image_size, image_size); values lie in
      [-1, 1] because of the final Tanh.
    """
    c = self.label_embedding(labels)
    z = noise.view(noise.size(0), latent_dim)
    # Label embedding first, noise second - the order the first Linear sees.
    x = torch.cat([c, z], 1)
    out = self.model(x)
    return out.view(out.size(0), channel, image_size, image_size)

In [0]:
class Discriminator(nn.Module):
  """Fully connected conditional discriminator.

  The image is flattened, concatenated with a `num_class`-dimensional label
  embedding, and passed through three Linear/LeakyReLU/Dropout stages followed
  by a Linear + Sigmoid head that yields one score per sample.

  Reads the notebook-level constants `num_class`, `channel` and `image_size`.
  """

  def __init__(self):
    super(Discriminator, self).__init__()
    # The embedding width must equal `num_class` because the first Linear layer
    # expects `num_class + channel*image_size*image_size` inputs. Tying both
    # embedding sizes to `num_class` (instead of a literal 10) keeps the model
    # consistent when the constant changes.
    self.label_embedding = nn.Embedding(num_class, num_class)
    self.layer = 256

    self.model = nn.Sequential(
        nn.Linear(num_class + (channel*image_size*image_size), self.layer*4),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Dropout(0.4),
        nn.Linear(self.layer*4, self.layer*2),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Dropout(0.4),
        nn.Linear(self.layer*2, self.layer),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Dropout(0.4),
        nn.Linear(self.layer, 1),
        nn.Sigmoid()
    )

  def forward(self, img, label):
    """Score images conditioned on their labels.

    Args:
      img: image batch; flattened to (batch, -1) before use.
      label: integer class indices, one per sample.

    Returns:
      Tensor of shape (batch, 1) with values in [0, 1] (Sigmoid output).
    """
    x = img.view(img.size(0), -1)
    z = self.label_embedding(label)
    # Flattened pixels first, label embedding second.
    x = torch.cat([x, z], 1)
    out = self.model(x)
    return out

In [0]:
  def init_weights(m):
    """Initialize a Linear layer; meant to be passed to `Module.apply`.

    Linear layers get Xavier-uniform weights and a constant bias of 0.01.
    Every other module type is left untouched.

    Args:
      m: the module currently visited by `apply`.
    """
    if isinstance(m, nn.Linear):
        # In-place variant; the un-suffixed `xavier_uniform` is deprecated.
        torch.nn.init.xavier_uniform_(m.weight)
        # Layers built with bias=False have no bias tensor to fill.
        if m.bias is not None:
            m.bias.data.fill_(0.01)

In [0]:
# NOTE(review): `device` is not assigned at notebook scope in any visible cell
# (only inside the spawned functions); confirm where it comes from before
# running this cell on a fresh kernel.

# Both optimizers share the same Adam settings.
_ADAM_KWARGS = dict(lr=0.0002, betas=(0.5, 0.999))

# Networks: generator first, then discriminator; only the discriminator gets
# the custom `init_weights` initialization.
generator = Generator().to(device)
discriminator = Discriminator().to(device)
discriminator.apply(init_weights)

# One optimizer per network.
gen_optimizer = torch.optim.Adam(generator.parameters(), **_ADAM_KWARGS)
d_optimizer = torch.optim.Adam(discriminator.parameters(), **_ADAM_KWARGS)

# Adversarial loss: binary cross-entropy on the discriminator's output.
a_loss = torch.nn.BCELoss()

  This is separate from the ipykernel package so we can avoid doing imports until


In [0]:
# Target values for the adversarial loss (real targets are 0.9, fake are 0.0).
real_label, fake_label = 0.9, 0.0

# Tensor constructors used by the training code; CPU variants by default.
label_type, img_type = torch.LongTensor, torch.FloatTensor

# Only when `device` is exactly the string 'cuda:0': move the networks and the
# loss onto it and switch to the CUDA tensor constructors.
if device == 'cuda:0':
    for _module in (generator, discriminator, a_loss):
        _module.to(device)
    label_type, img_type = torch.cuda.LongTensor, torch.cuda.FloatTensor


In [0]:
# Fixed generator inputs, reused whenever sample images are written so that
# successive samples can be compared on the same noise and labels.
fix_noise = torch.FloatTensor(
    np.random.normal(loc=0, scale=1, size=(batch_size, latent_dim))
).to(device)
fix_label = torch.LongTensor(
    np.random.randint(low=0, high=num_class, size=batch_size)
).to(device)


In [0]:
# Bind the loss-history lists to additional names. These are plain aliases,
# not copies.
# NOTE(review): `G_losses` / `D_losses` are first assigned by the training cell
# further down; run on a fresh kernel, this cell raises NameError (as in the
# recorded output). Execute it only after training.
G_Loss_FM = G_losses
D_Loss_FM = D_losses

NameError: ignored

In [0]:
# Single-device training loop.
# Reads notebook-level state: dataloader, generator, discriminator, their
# optimizers, a_loss, device, img_type/label_type, real_label/fake_label,
# fix_noise/fix_label, latent_dim and num_class.
# Writes sample images and per-epoch checkpoints under
# '/content/drive/My Drive/cGAN/'.

# Per-epoch mean losses; reset every time this cell runs.
G_losses = []
D_losses = []
for epoch in range(1, Epoch+1):
  G_loss = 0.
  D_loss = 0.
  for i, (imgs, labels) in enumerate(dataloader):
    # Actual size of this batch (the last one may be smaller). Deliberately not
    # named `batch_size`: at notebook scope that would overwrite the
    # configuration constant for every cell run afterwards.
    n_samples = imgs.shape[0]
    imgs = imgs.type(img_type).to(device)
    labels = labels.type(label_type).to(device)

    # Loss targets for "real" and "fake", shape (n_samples, 1).
    r_label = img_type(n_samples, 1).fill_(real_label).to(device)
    f_label = img_type(n_samples, 1).fill_(fake_label).to(device)

    # ---- Generator step: make the discriminator score fakes as real ----
    gen_optimizer.zero_grad()

    noise = img_type(np.random.normal(0, 1, (n_samples, latent_dim))).to(device)
    rand_label = label_type(np.random.randint(0, num_class, n_samples)).to(device)
    dis = discriminator(generator(noise, rand_label), rand_label)
    g_loss = a_loss(dis, r_label)
    g_loss.backward()
    gen_optimizer.step()

    # ---- Discriminator step: real batch vs. freshly generated fakes ----
    # zero_grad also clears the discriminator gradients accumulated by the
    # generator step's backward pass.
    d_optimizer.zero_grad()

    noise = img_type(np.random.normal(0, 1, (n_samples, latent_dim))).to(device)
    rand_label = label_type(np.random.randint(0, num_class, n_samples)).to(device)

    d_real = discriminator(imgs, labels)
    loss_real = a_loss(d_real, r_label)

    # detach(): the generator must not receive gradients from this loss.
    d_fake = discriminator(generator(noise, rand_label).detach(), rand_label)
    loss_fake = a_loss(d_fake, f_label)

    d_loss = 0.5*(loss_fake+loss_real)

    d_loss.backward()
    d_optimizer.step()

    G_loss += g_loss.item()
    D_loss += d_loss.item()

    if i % 100 == 0:
      # Sample from the fixed inputs; no autograd graph is needed for this.
      # The file name depends only on the epoch, so later samples within an
      # epoch overwrite earlier ones.
      with torch.no_grad():
        static_fake = generator(fix_noise, fix_label)
      vutils.save_image(static_fake.detach(), '/content/drive/My Drive/cGAN/Image/samples_%d.png' % (epoch), normalize=True)

  # `i` is the index of the last batch, so i+1 is the number of batches.
  print('Epoch {} || G_loss: {} || D_loss: {}'.format(epoch, G_loss/(i+1), D_loss/(i+1)))
  G_losses.append(G_loss/(i+1))
  D_losses.append(D_loss/(i+1))

  # Checkpoint both networks once per epoch.
  torch.save(generator.state_dict(), '/content/drive/My Drive/cGAN/generator/generator_{}_.pth'.format(epoch))
  torch.save(discriminator.state_dict(), '/content/drive/My Drive/cGAN/discriminator/discriminator_{}_.pth'.format(epoch))

KeyboardInterrupt: ignored

In [0]:
# Mount Google Drive at /content/drive. The training code writes sample images
# and checkpoints under '/content/drive/My Drive/cGAN/', so this cell has to be
# run before any training cell.
from google.colab import drive
drive.mount('/content/drive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/drive


In [0]:
def map_fn(index):
  """Train the conditional GAN on one Cloud TPU core (target of `xmp.spawn`).

  Each process builds its own data pipeline, networks and optimizers, trains
  on its shard of CIFAR-10, and steps the optimizers through
  `xm.optimizer_step`.

  Reads notebook-level state: Epoch, latent_dim, num_class, real_label,
  fake_label, fix_noise, fix_label, G_losses, D_losses, Generator,
  Discriminator and init_weights.

  Writes sample images (master process only) and per-epoch checkpoints (via
  `xm.save`) under '/content/drive/My Drive/cGAN/'.

  Args:
    index: process index supplied by the spawner; not read in the body.
  """
  # Same seed in every process so all replicas start from identical weights.
  torch.manual_seed(1234)

  device = xm.xla_device()

  transform = transforms.Compose([transforms.Resize(32),
         transforms.CenterCrop(32),
         transforms.ToTensor(),
         transforms.Normalize((0.5,), (0.5,))])

  # Download once: non-master processes wait here until the master has
  # finished constructing (and, if needed, downloading) the dataset.
  if not xm.is_master_ordinal():
    xm.rendezvous('download_only_once')

  train_dataset = datasets.CIFAR10(
    "./content",
    train=True,
    download=True,
    transform=transform)

  if xm.is_master_ordinal():
    xm.rendezvous('download_only_once')

  # Distributed sampler: this process only sees its own shard of the data.
  train_sampler = torch.utils.data.distributed.DistributedSampler(
    train_dataset,
    num_replicas=xm.xrt_world_size(),
    rank=xm.get_ordinal(),
    shuffle=True)

  dataloader = torch.utils.data.DataLoader(
      train_dataset,
      batch_size=32,
      sampler=train_sampler,
      num_workers=8,
      drop_last=True)

  generator = Generator().to(device)
  gen_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
  discriminator = Discriminator().to(device)
  discriminator.apply(init_weights)
  d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

  a_loss = torch.nn.BCELoss()

  # The fixed evaluation inputs are notebook-level tensors; make sure they
  # live on this process's device before feeding them to the generator.
  eval_noise = fix_noise.to(device)
  eval_label = fix_label.to(device)

  for epoch in range(1, Epoch+1):
    # Without set_epoch the sampler reuses the same shuffle order every epoch.
    train_sampler.set_epoch(epoch)
    para_train_loader = pl.ParallelLoader(dataloader, [device]).per_device_loader(device)
    G_loss = 0.
    D_loss = 0.
    for i, (imgs, labels) in enumerate(para_train_loader):
      batch_size = imgs.shape[0]
      # Cast in place on the target device. The legacy `.type(FloatTensor)`
      # constructors name CPU tensor types, which would bounce every batch
      # through host memory.
      imgs = imgs.to(device=device, dtype=torch.float32)
      labels = labels.to(device=device, dtype=torch.long)

      # Loss targets for "real" and "fake", shape (batch_size, 1).
      r_label = torch.full((batch_size, 1), real_label, dtype=torch.float32, device=device)
      f_label = torch.full((batch_size, 1), fake_label, dtype=torch.float32, device=device)

      # ---- Generator step ----
      gen_optimizer.zero_grad()

      # NOTE(review): noise is drawn from NumPy's global RNG, which is not
      # seeded per process here - confirm replicas are meant to (possibly)
      # share the same noise stream.
      noise = torch.as_tensor(np.random.normal(0, 1, (batch_size, latent_dim)), dtype=torch.float32).to(device)
      rand_label = torch.as_tensor(np.random.randint(0, num_class, batch_size), dtype=torch.long).to(device)
      dis = discriminator(generator(noise, rand_label), rand_label)
      g_loss = a_loss(dis, r_label)
      g_loss.backward()
      xm.optimizer_step(gen_optimizer)

      # ---- Discriminator step ----
      # zero_grad also clears the discriminator gradients accumulated by the
      # generator step's backward pass.
      d_optimizer.zero_grad()

      noise = torch.as_tensor(np.random.normal(0, 1, (batch_size, latent_dim)), dtype=torch.float32).to(device)
      rand_label = torch.as_tensor(np.random.randint(0, num_class, batch_size), dtype=torch.long).to(device)

      d_real = discriminator(imgs, labels)
      loss_real = a_loss(d_real, r_label)

      # detach(): the generator must not receive gradients from this loss.
      d_fake = discriminator(generator(noise, rand_label).detach(), rand_label)
      loss_fake = a_loss(d_fake, f_label)

      d_loss = 0.5*(loss_fake+loss_real)

      d_loss.backward()
      xm.optimizer_step(d_optimizer)

      G_loss += g_loss.item()
      D_loss += d_loss.item()

      if i % 100 == 0:
        # Every replica runs the forward pass (as before), but only the master
        # writes the file, so processes do not race on the same path.
        with torch.no_grad():
          static_fake = generator(eval_noise, eval_label)
        if xm.is_master_ordinal():
          vutils.save_image(static_fake.cpu(), '/content/drive/My Drive/cGAN/Image/samples_%d.png' % (epoch), normalize=True)

    # `i` is the index of the last batch, so i+1 is the number of batches.
    print('Epoch {} || G_loss: {} || D_loss: {}'.format(epoch, G_loss/(i+1), D_loss/(i+1)))
    # NOTE(review): these are notebook-level lists; when this function runs in
    # a forked process the appends presumably stay in that process - verify
    # before relying on them in the parent notebook.
    G_losses.append(G_loss/(i+1))
    D_losses.append(D_loss/(i+1))

    # Checkpoint once per epoch. xm.save is called by every process; it is the
    # torch_xla helper for saving XLA-resident state from multi-process runs.
    xm.save(generator.state_dict(), '/content/drive/My Drive/cGAN/generator/generator_{}_.pth'.format(epoch))
    xm.save(discriminator.state_dict(), '/content/drive/My Drive/cGAN/discriminator/discriminator_{}_.pth'.format(epoch))

In [0]:
# Start training: one `map_fn` process per TPU core, using the 'fork' start
# method. No extra arguments are forwarded.
xmp.spawn(
    map_fn,
    args=(),
    nprocs=8,
    start_method='fork',
)

Exception: ignored