<a href="https://colab.research.google.com/github/ayush12gupta/model_zoo/blob/master/cGAN/cGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import torch
import torch.optim as opt
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torchvision.utils as vutils

from torchvision import datasets
import torchsummary
from torchsummary import summary
from torch.autograd import Variable
import time
import numpy as np

import matplotlib.pyplot as plt
%matplotlib inline


In [33]:
######## Dataset ###########
dataset = datasets.CIFAR10(root='./content', train=True,download=True, 
     transform=transforms.Compose([transforms.Resize(32), 
         transforms.CenterCrop(32),
         transforms.ToTensor(), 
         transforms.Normalize((0.5,), (0.5,))]))

dataloader = torch.utils.data.DataLoader(dataset, batch_size = 32, shuffle=True)

Epoch = 100
channel = 3
image_size = 32
latent_dim = 100
num_class = 10
batch_size = 32
device = 'cuda:0'
image_shape = (channel, image_size, image_size)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./content/cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Failed download. Trying https -> http instead. Downloading http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./content/cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./content/cifar-10-python.tar.gz to ./content


In [0]:
class Generator(nn.Module):
  def __init__(self):
    super(Generator, self).__init__()

    self.label_embedding = nn.Embedding(10,10)
    self.layer = 128

    self.model = nn.Sequential(
        nn.Linear(latent_dim+num_class, self.layer),
        nn.BatchNorm1d(self.layer,0.8),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Linear(self.layer, self.layer*2),
        nn.BatchNorm1d(self.layer*2,0.8),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Linear(self.layer*2, self.layer*4),
        nn.BatchNorm1d(self.layer*4,0.8),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Linear(self.layer*4, self.layer*8),
        nn.BatchNorm1d(self.layer*8,0.8),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Linear(self.layer*8,channel*image_size*image_size),
        nn.Tanh()        
    )  


  def forward(self, noise, labels):
    c = self.label_embedding(labels)
    #print(noise.shape)
    z = noise.view(noise.size(0),latent_dim)
    x = torch.cat([c,z],1)
    out = self.model(x)
    return out.view(out.size(0),channel,image_size,image_size)

In [0]:
class Discriminator(nn.Module):
  def __init__(self):
    super(Discriminator,self).__init__()
    self.label_embedding = nn.Embedding(10,10)
    self.layer = 256
    
    self.model = nn.Sequential(
        nn.Linear(num_class+(channel*image_size*image_size),self.layer*4),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Dropout(0.4),
        nn.Linear(self.layer*4,self.layer*2),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Dropout(0.4),
        nn.Linear(self.layer*2,self.layer),
        nn.LeakyReLU(0.2, inplace=True),
        nn.Dropout(0.4),
        nn.Linear(self.layer, 1),
        nn.Sigmoid()
    )


  def forward(self, img, label):
    x = img.view(img.size(0),-1)
    #print(x.shape)
    z = self.label_embedding(label)
    x = torch.cat([x, z],1)
    out = self.model(x)
    return out

In [0]:
  def init_weights(m): 
    if type(m)==nn.Linear:
        torch.nn.init.xavier_uniform(m.weight)
        m.bias.data.fill_(0.01)

In [48]:
generator = Generator().to(device)
gen_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
discriminator = Discriminator().to(device)
discriminator.apply(init_weights)
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Loss functions 
a_loss = torch.nn.BCELoss()

  This is separate from the ipykernel package so we can avoid doing imports until


In [0]:
real_label = 0.9
fake_label = 0.0

label_type = torch.LongTensor
img_type = torch.FloatTensor

if device=='cuda:0': 
    generator.to(device)
    discriminator.to(device)
    a_loss.to(device)
    label_type = torch.cuda.LongTensor
    img_type = torch.cuda.FloatTensor


In [0]:
fix_noise = torch.FloatTensor(np.random.normal(0, 1,(batch_size, latent_dim))).to(device) # To evaluate on a particular noise
fix_label = torch.LongTensor(np.random.randint(0, num_class, batch_size)).to(device)


In [0]:
G_Loss_FM = G_losses
D_Loss_FM = D_losses

In [0]:
G_losses = []
D_losses = []
for epoch in range(1,Epoch+1):
  G_loss=0.
  D_loss=0.
  for i, data in enumerate(dataloader):
    (imgs,labels) = data
    batch_size = imgs.shape[0]
    imgs = Variable(imgs.type(img_type))
    labels = Variable(labels.type(label_type))

    # Creating real and fake label for calculation of loss
    r_label = Variable(img_type(batch_size,1).fill_(real_label)).to(device)
    f_label = Variable(img_type(batch_size,1).fill_(fake_label)).to(device)

    # Training Generator

    gen_optimizer.zero_grad()

    noise = Variable(img_type(np.random.normal(0, 1,(batch_size, latent_dim)))).to(device)
    rand_label = Variable(label_type(np.random.randint(0, num_class, batch_size))).to(device)
    dis = discriminator(generator(noise, rand_label),rand_label)
    #print(type(dis),'  ',type(r_label))
    g_loss = a_loss(dis,r_label)
    g_loss.backward()
    gen_optimizer.step()

    # Training Discriminator

    d_optimizer.zero_grad()

    noise = Variable(img_type(np.random.normal(0, 1,(batch_size, latent_dim)))).to(device)
    rand_label = Variable(label_type(np.random.randint(0, num_class, batch_size))).to(device)

    d_real = discriminator(imgs, labels)
    loss_real = a_loss(d_real, r_label)

    d_fake = discriminator(generator(noise,rand_label).detach(),rand_label)
    loss_fake = a_loss(d_fake, f_label)

    d_loss = 0.5*(loss_fake+loss_real)

    d_loss.backward()
    d_optimizer.step()

    G_loss += g_loss.item()
    D_loss += d_loss.item()

    if i%100 == 0: 
        
        static_fake = generator(fix_noise, fix_label)
        vutils.save_image(static_fake.detach(), '/content/drive/My Drive/cGAN/Image/fake_samples_epoch_%03d.png' % (epoch), normalize=True)

  print('Epoch {} || G_loss: {} || D_loss: {}'.format(epoch,G_loss/(i+1),D_loss/(i+1)))
  #print('Epoch {} || G_loss: {} || D_loss: {}'.format(epoch,g_loss.item(),d_loss.item()))
  G_losses.append(G_loss/(i+1))
  D_losses.append(D_loss/(i+1))
  # static_fake = generator(fix_noise, fix_label)
  # plt.imshow(static_fake.squeeze().detach().cpu(),normalize=True)#.view(channel,image_size,image_size
  # plt.show()
  # plt.savefig('/content/drive/My Drive/cGAN/Image/fake_samples_epoch_%03d.png' % (epoch))
  # #Checkpoint
  torch.save(generator.state_dict(),'/content/drive/My Drive/cGAN/generator/generator_epoch_{}_.pth'.format(epoch))
  torch.save(discriminator.state_dict(),'/content/drive/My Drive/cGAN/discriminator/discriminator_epoch_{}_.pth'.format(epoch))

Epoch 1 || G_loss: 1.5329498236406636 || D_loss: 0.5564369258595367
Epoch 2 || G_loss: 1.2858589751477891 || D_loss: 0.5990531109154262
Epoch 3 || G_loss: 1.133895427999173 || D_loss: 0.6189287347398503
Epoch 4 || G_loss: 1.0424635885468065 || D_loss: 0.6309133710879511
Epoch 5 || G_loss: 0.967129300247761 || D_loss: 0.6460659252247288
Epoch 6 || G_loss: 0.9323032888097025 || D_loss: 0.6508059439488275
Epoch 7 || G_loss: 0.9067564338579135 || D_loss: 0.6561434090480694
Epoch 8 || G_loss: 0.8960834864386366 || D_loss: 0.6606628908915773
Epoch 9 || G_loss: 0.8785627251318152 || D_loss: 0.6636681763773459
Epoch 10 || G_loss: 0.8634224061926282 || D_loss: 0.6677855755835866
Epoch 11 || G_loss: 0.8583885295911241 || D_loss: 0.669143032852229
Epoch 12 || G_loss: 0.8552864090571095 || D_loss: 0.6707005879090371
Epoch 13 || G_loss: 0.8468283362977411 || D_loss: 0.6747483720553661
Epoch 14 || G_loss: 0.8303324440237961 || D_loss: 0.6775244929740158
Epoch 15 || G_loss: 0.837846108567463 || D_los

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/drive
