<a href="https://colab.research.google.com/github/MLandML/MLandML/blob/learning_projects/Advanced_GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [5]:
import torch,torchvision,os,PIL,pdb
from torch import nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import make_grid
from tqdm.auto import tqdm
import numpy
from PIL import Image
import matplotlib.pyplot as plt

def show(tensor,num=25,wamb=0,name='',nrow=5):
  """Display the first `num` images of a batch as a single grid via matplotlib.

  Args:
    tensor: image batch; presumably (batch, channels, height, width) as
      torchvision's make_grid expects — TODO confirm at call sites.
    num: maximum number of images taken from the front of the batch.
    wamb: unused here; kept for call-site compatibility
      (NOTE(review): possibly intended as a wandb-logging flag — confirm).
    name: unused here; kept for call-site compatibility.
    nrow: images per grid row (defaults to 5, the previously fixed value).
  """
  # Detach from autograd and move to CPU so matplotlib can consume it.
  data = tensor.detach().cpu()
  # make_grid yields (C, H, W); imshow wants channels last, (H, W, C).
  grid = make_grid(data[:num], nrow=nrow).permute(1, 2, 0)

  # Clip into [0, 1] so float pixels are not oversaturated by imshow.
  # NOTE(review): the Generator in this notebook ends in Tanh ([-1, 1]); if its raw
  # output is passed here, all negative values become 0 — a (x + 1) / 2 rescale may
  # be intended before calling show. Confirm with the caller.
  plt.imshow(grid.clip(0, 1))
  plt.show()

# Training hyperparameters and running bookkeeping for the GAN loop.
nepochs = 10000
batch_size = 128
lr = 1e-4
z_dim = 200  # length of each latent noise vector fed to the generator
# Use the GPU when one is present, otherwise fall back to CPU so the notebook
# does not crash on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

cur_step = 0
crit_cycles = 5  # critic is trained 5 times for every generator update
gen_losses = []
crit_losses = []
show_step = 35  # presumably: display progress every N steps — confirm in training loop
save_step = 35  # presumably: checkpoint every N steps — confirm in training loop


In [14]:
from torch.nn.modules.batchnorm import BatchNorm2d
from torch.nn.modules.module import Module
#generator model

class Generator(nn.Module):
  """DCGAN-style generator: maps a latent vector to a 3-channel 128x128 image.

  Args:
    z_dim: length of each input noise vector (number of input channels of the
      first transposed convolution).
    d_dim: base channel width; layer widths are multiples of it
      (the channel counts in the comments below assume the default d_dim=16).
  """

  def __init__(self,z_dim=200,d_dim=16):
    super().__init__()
    self.z_dim = z_dim

    self.gen = nn.Sequential(
        # ConvTranspose2d(in_channels, out_channels, kernel_size, stride=1, padding=0)
        # New width/height: (n - 1) * stride - 2 * padding + kernel_size
        # (n = input width or height).
        # Input is a 1x1 "image" with z_dim channels (200 by default).
        nn.ConvTranspose2d(z_dim,d_dim*32,4,1,0), # 1x1 -> 4x4, ch: z_dim -> 512
        nn.BatchNorm2d(d_dim*32),
        nn.ReLU(True),

        nn.ConvTranspose2d(d_dim*32,d_dim*16,4,2,1), # 8x8, ch: 512 -> 256
        nn.BatchNorm2d(d_dim*16),
        nn.ReLU(True),

        nn.ConvTranspose2d(d_dim*16,d_dim*8,4,2,1), # 16x16, ch: 256 -> 128
        nn.BatchNorm2d(d_dim*8),
        nn.ReLU(True),

        nn.ConvTranspose2d(d_dim*8,d_dim*4,4,2,1), # 32x32, ch: 128 -> 64
        nn.BatchNorm2d(d_dim*4),
        nn.ReLU(True),

        nn.ConvTranspose2d(d_dim*4,d_dim*2,4,2,1), # 64x64, ch: 64 -> 32
        nn.BatchNorm2d(d_dim*2),
        nn.ReLU(True),

        nn.ConvTranspose2d(d_dim*2,3,4,2,1), # 128x128, ch: 32 -> 3 (RGB)
        nn.Tanh(), # squashes pixel values into [-1, 1]
    )

  def forward(self,noise):
    """Generate images from a batch of noise vectors.

    Args:
      noise: tensor holding `len(noise)` vectors of length `self.z_dim`,
        e.g. shape (batch, z_dim).

    Returns:
      Image tensor of shape (batch, 3, 128, 128) with values in [-1, 1].
    """
    # Reshape to (batch, z_dim, 1, 1); use the configured z_dim instead of a
    # hard-coded 200 so non-default latent sizes work.
    x = noise.view(len(noise),self.z_dim,1,1)
    return self.gen(x)

def gen_noise(num,z_dim,device='cuda'):
  """Sample a batch of latent noise vectors for the generator.

  Args:
    num: number of vectors (batch size).
    z_dim: length of each vector.
    device: device to allocate the tensor on.

  Returns:
    Tensor of shape (num, z_dim), e.g. 128x200.
  """
  # Standard-normal latent (zero mean, unit variance), the conventional GAN
  # prior; torch.rand would only produce non-negative values in [0, 1).
  return torch.randn(num,z_dim,device=device)

In [16]:
#critic 

class critic(nn.Module):
  """Convolutional critic that scores images with one unbounded value each.

  Args:
    d_dim: base channel width; layer widths are multiples of it (the channel
      counts in the comments below assume the default d_dim=16).
  """

  def __init__(self,d_dim=16):
    super().__init__()
    # Attribute name kept as `critic` so existing state_dict keys still match.
    self.critic = nn.Sequential(
        # Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0)
        # New width/height: (n + 2 * padding - kernel_size) // stride + 1
        nn.Conv2d(3,d_dim,4,2,1), # 128x128 -> 64x64, ch: 3 -> 16
        nn.InstanceNorm2d(d_dim), # per-instance rather than per-batch normalization in the critic
        nn.LeakyReLU(0.2),

        nn.Conv2d(d_dim,d_dim*2,4,2,1), # 32x32, ch: 16 -> 32
        nn.InstanceNorm2d(d_dim*2),
        nn.LeakyReLU(0.2),

        nn.Conv2d(d_dim*2,d_dim*4,4,2,1), # 16x16, ch: 32 -> 64
        nn.InstanceNorm2d(d_dim*4),
        nn.LeakyReLU(0.2),

        nn.Conv2d(d_dim*4,d_dim*8,4,2,1), # 8x8, ch: 64 -> 128
        nn.InstanceNorm2d(d_dim*8),
        nn.LeakyReLU(0.2),

        nn.Conv2d(d_dim*8,d_dim*16,4,2,1), # 4x4, ch: 128 -> 256
        nn.InstanceNorm2d(d_dim*16),
        nn.LeakyReLU(0.2),

        nn.Conv2d(d_dim*16,1,4,1,0), # 1x1, ch: 256 -> 1
    )

  def forward(self,image):
    """Score a batch of images.

    Args:
      image: tensor of shape (batch, 3, height, width); 128x128 inputs reduce
        to a single 1x1 score per image.

    Returns:
      Tensor of shape (batch, k) — (batch, 1) for 128x128 inputs; larger
      inputs yield k > 1 flattened spatial scores.
    """
    crit_pred = self.critic(image) # (batch, 1, 1, 1) for 128x128 input
    return crit_pred.view(len(crit_pred),-1)