<a href="https://colab.research.google.com/github/VedantDere0104/Progressive_Growing_of_GANs/blob/main/ProGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
####

In [None]:
import torch
from torch import nn
from torchsummary import summary
import torch.nn.functional as F

from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import VOCSegmentation
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import torchvision

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Shared model hyper-parameters used when the networks are built further down.
in_channels_gen = 512  # base channel width handed to Generator
out_channels = 3  # image channels: generator output / discriminator input
z_dim = 512  # channel size of the (N, z_dim, 1, 1) noise tensors
out_channels_disc = 1  # output channels of each discriminator head
# Per-stage channel multipliers. NOTE(review): Generator builds its own
# identical local list, so this module-level copy looks unused here.
factors = [1, 1, 1, 1, 1 / 2, 1 / 4, 1 / 8, 1 / 16, 1 / 32]

In [None]:
class Pixel_Norm(nn.Module):
  """Normalise every spatial position by the RMS of its values along dim 1.

  The output is ``x / sqrt(mean(x ** 2, dim=1) + epsilon)``; the mean keeps
  its dimension so it broadcasts back over dim 1.
  """

  def __init__(self):
    super().__init__()
    # Added under the square root so an all-zero position does not divide by 0.
    self.epsilon = 1e-8

  def forward(self, x):
    """Return ``x`` divided by its root-mean-square over dim 1."""
    mean_square = x.pow(2).mean(dim=1, keepdim=True)
    rms = (mean_square + self.epsilon).sqrt()
    return x / rms

In [None]:
# Smoke test: normalise a random (2, 3, 512, 512) batch and display the
# largest value of the result.
pixel_norm = Pixel_Norm().to(device)
x = torch.randn(2 , 3 , 512 , 512).to(device)
z = pixel_norm(x)
torch.max(z)

In [None]:
class Conv(nn.Module):
  """2-D convolution whose input is multiplied by a fixed scale first.

  The scale is ``sqrt(gain / (in_channels * kernel_size ** 2))``, so
  ``kernel_size`` is expected to be a single integer. The convolution weight
  is drawn from a standard normal distribution. The bias starts at zero, is
  owned by this wrapper (``self.bias``) rather than by the inner ``nn.Conv2d``,
  and is added after the convolution.
  """

  def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
               padding=1, gain=2):
    super().__init__()

    self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
    fan_in = in_channels * (kernel_size ** 2)
    self.scale = (gain / fan_in) ** 0.5

    # Take over the bias parameter created by nn.Conv2d and detach it from the
    # inner layer, so the inner layer runs bias-free.
    bias_parameter = self.conv1.bias
    self.bias = bias_parameter
    self.conv1.bias = None

    nn.init.normal_(self.conv1.weight)
    nn.init.zeros_(self.bias)

  def forward(self, x):
    """Convolve the scaled input and add the per-channel bias."""
    convolved = self.conv1(x * self.scale)
    return convolved + self.bias.view(1, -1, 1, 1)


In [None]:
# Smoke test: a 3 -> 32 channel Conv with its default kernel/stride/padding
# on a random batch; the cell displays the output shape.
conv = Conv(3 , 32).to(device)
x = torch.randn(2 , 3 , 512 , 512).to(device)
z = conv(x)
z.shape

In [None]:
class ConvT(nn.Module):
  """Transposed 2-D convolution whose input is multiplied by a fixed scale.

  Mirrors ``Conv``: the scale is ``sqrt(gain / (in_channels * kernel_size ** 2))``
  (``kernel_size`` is expected to be a single integer), the weight is drawn
  from a standard normal distribution, and the zero-initialised bias lives on
  this wrapper and is added after the transposed convolution.
  """

  def __init__(self, in_channels, out_channels, kernel_size=2, stride=2,
               padding=0, gain=2):
    super().__init__()

    self.convT = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding)
    fan_in = in_channels * (kernel_size ** 2)
    self.scale = (gain / fan_in) ** 0.5

    # Take over the bias parameter created by nn.ConvTranspose2d so the inner
    # layer runs bias-free and the bias is applied to the unscaled output.
    bias_parameter = self.convT.bias
    self.bias = bias_parameter
    self.convT.bias = None

    nn.init.normal_(self.convT.weight)
    nn.init.zeros_(self.bias)

  def forward(self, x):
    """Apply the transposed convolution to the scaled input, then add the bias."""
    expanded = self.convT(x * self.scale)
    return expanded + self.bias.view(1, -1, 1, 1)
    



In [None]:
# Smoke test: a 3 -> 32 channel ConvT with its default 2x2 kernel and stride 2
# on a random 256x256 batch; the cell displays the output shape.
convT = ConvT(3 , 32).to(device)
x = torch.randn(2 , 3 , 256 , 256).to(device)
z = convT(x)
z.shape

In [None]:
class Generator_Block(nn.Module):
  """Two ``Conv`` layers, each optionally followed by Pixel_Norm and LeakyReLU.

  The first convolution keeps ``in_channels``; the second maps to
  ``out_channels``. Both share ``kernel_size``, ``stride`` and ``padding``.
  ``use_norm`` / ``use_activation`` switch the Pixel_Norm and LeakyReLU(0.2)
  steps on or off for both convolutions at once.
  """

  def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
               padding=1, use_norm=True, use_activation=True):
    super().__init__()

    self.use_norm = use_norm
    self.use_activation = use_activation

    conv_settings = (kernel_size, stride, padding)
    self.conv1 = Conv(in_channels, in_channels, *conv_settings)
    self.conv2 = Conv(in_channels, out_channels, *conv_settings)

    # The optional layers only exist when they are enabled.
    if use_norm:
      self.norm = Pixel_Norm()
    if use_activation:
      self.activation = nn.LeakyReLU(0.2)

  def _norm_and_activate(self, x):
    """Apply whichever of the optional norm / activation steps are enabled."""
    if self.use_norm:
      x = self.norm(x)
    if self.use_activation:
      x = self.activation(x)
    return x

  def forward(self, x):
    """Run conv1 and conv2, each followed by the optional norm/activation."""
    hidden = self._norm_and_activate(self.conv1(x))
    return self._norm_and_activate(self.conv2(hidden))

In [None]:
# Print a torchsummary layer table for a 3 -> 32 channel Generator_Block on a
# (3, 512, 512) input.
generator_block = Generator_Block(3 , 32).to(device)
summary(generator_block , (3 , 512 , 512))

In [None]:
class Initial_Block(nn.Module):
  """First generator stage: a 4x4 stride-1 ``ConvT`` followed by a ``Conv``.

  Each of the two layers is followed by Pixel_Norm and LeakyReLU(0.2).
  """

  def __init__(self, in_channels, out_channels):
    super().__init__()

    self.convT = ConvT(in_channels, out_channels, kernel_size=4, stride=1, padding=0)
    self.conv = Conv(out_channels, out_channels)
    self.lrelu = nn.LeakyReLU(0.2)
    self.pixel_norm = Pixel_Norm()

  def forward(self, x):
    """Apply convT then conv, normalising and activating after each."""
    for layer in (self.convT, self.conv):
      x = self.lrelu(self.pixel_norm(layer(x)))
    return x

In [None]:
# Print a torchsummary layer table for the 512-channel initial block on a
# (512, 1, 1) input.
init_block = Initial_Block(512, 512).to(device)
summary(init_block , (512 , 1 , 1))

In [None]:
class Generator(nn.Module):
  """Progressively grown image generator.

  ``forward(x, alpha, steps)`` runs the initial block and then ``steps``
  stages, each of which doubles the spatial size with ``nn.Upsample`` and
  applies one ``Generator_Block``. The last stage's image is blended with the
  previous stage's upsampled image through ``fade_in``.

  Args:
    in_channels: channel width of the input and of the widest stages.
    out_channels: number of image channels produced.
  """

  def __init__(self ,
               in_channels ,
               out_channels):
    super(Generator , self).__init__()

    # Channel multipliers per stage, relative to in_channels.
    filters = [1, 1, 1, 1, 1 / 2, 1 / 4, 1 / 8, 1 / 16, 1 / 32]
    self.initial_block = Initial_Block(in_channels , in_channels)
    # Image head for the initial block's features (two 1x1 convs, no norm or
    # activation).
    self.init_last = Generator_Block(in_channels ,
                                     out_channels ,
                                     kernel_size=1 ,
                                     stride=1 ,
                                     padding=0 ,
                                     use_norm=False ,
                                     use_activation=False)

    self.sigmoid = nn.Sigmoid()

    self.conv_ = nn.ModuleList()  # one Generator_Block per grown stage
    self.last_ = nn.ModuleList()  # matching 1x1 image head per grown stage

    for scale_in , scale_out in zip(filters , filters[1:]):
      conv_in_channels = int(in_channels * scale_in)
      conv_out_channels = int(in_channels * scale_out)
      self.conv_.append(Generator_Block(conv_in_channels , conv_out_channels))
      self.last_.append(Conv(conv_out_channels , out_channels , kernel_size=1 , stride=1 , padding=0))

    self.upsample = nn.Upsample(scale_factor=2)

  def fade_in(self , upscaled , generated , alpha):
    """Blend the two images with weight ``alpha`` on ``generated``, then tanh."""
    return torch.tanh(alpha * generated + (1 - alpha) * upscaled)

  def forward(self , x , alpha , steps):
    """Generate an image after ``steps`` grown stages.

    Args:
      x: input passed to the initial block.
      alpha: blend weight given to the newest stage in ``fade_in``.
      steps: number of grown stages to run, ``0 <= steps <= len(self.conv_)``.

    Returns:
      The generated image batch. NOTE(review): ``steps == 0`` ends in a
      sigmoid while every other case ends in the tanh of ``fade_in``, so the
      value range differs between the two paths - confirm this is intended.

    Raises:
      ValueError: if ``steps`` is outside the supported range.
    """
    if not 0 <= steps <= len(self.conv_):
      raise ValueError(f"steps must be between 0 and {len(self.conv_)}, got {steps}")

    out = self.initial_block(x)
    if steps == 0:
      return self.sigmoid(self.init_last(out))

    for step in range(steps):
      upscaled = self.upsample(out)
      out = self.conv_[step](upscaled)

    # `upscaled` still has the channel count of the stage before the last one.
    # For the first grown stage that is the initial block, whose image head is
    # `init_last`; indexing `last_[-1]` there would pick the head of the final,
    # narrowest stage and fail on the channel mismatch.
    previous_head = self.init_last if steps == 1 else self.last_[steps - 2]
    upscaled = previous_head(upscaled)
    generated = self.last_[steps - 1](out)

    return self.fade_in(upscaled , generated , alpha)

In [None]:
# Smoke test: run the generator through all 8 grown stages with alpha = 1
# on two random inputs; the cell displays the output shape.
generator_ = Generator(512 , 3).to(device)
x = torch.randn(2 , 512 , 1 , 1).to(device)
alpha = 1
steps = 8
z = generator_(x , alpha , steps)
z.shape

In [None]:
class Disc_Block(nn.Module):
  """Discriminator stage: Conv, optional second Conv, optional 2x2 max-pool.

  Each convolution may be followed by InstanceNorm2d (``use_norm``) and
  LeakyReLU(0.2) (``use_activation``). ``conv2`` and ``pool`` are always
  constructed; ``use_second_conv`` and ``use_pool`` only decide whether
  ``forward`` applies them.
  """

  def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
               padding=1, use_norm=True, use_activation=True, use_pool=True,
               use_second_conv=True):
    super().__init__()

    self.use_norm = use_norm
    self.use_activation = use_activation
    self.use_pool = use_pool
    self.use_second_conv = use_second_conv

    self.conv1 = Conv(in_channels, out_channels, kernel_size, stride, padding)

    # Separate norm layers for the first and second convolution.
    if use_norm:
      self.norm = nn.InstanceNorm2d(out_channels)
      self.norm1 = nn.InstanceNorm2d(out_channels)
    if use_activation:
      self.activation = nn.LeakyReLU(0.2)

    self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
    self.conv2 = Conv(out_channels, out_channels)

  def forward(self, x):
    """Apply the enabled layers in order: conv1, [conv2], [pool]."""
    x = self.conv1(x)
    x = self.norm(x) if self.use_norm else x
    x = self.activation(x) if self.use_activation else x

    if not self.use_second_conv:
      return self.pool(x) if self.use_pool else x

    x = self.conv2(x)
    x = self.norm1(x) if self.use_norm else x
    x = self.activation(x) if self.use_activation else x
    return self.pool(x) if self.use_pool else x


In [None]:
# Print a torchsummary layer table for a 32 -> 3 channel Disc_Block that skips
# its second convolution, on a (32, 512, 512) input.
disc_block = Disc_Block(32 , 3 , use_second_conv=False).to(device)
summary(disc_block , (32, 512 , 512))

In [None]:
class To_Discriminator(nn.Module):
  """Three-layer fully connected head ending in a sigmoid.

  Layer widths are ``in_channels -> 2 * in_channels -> 4 * in_channels ->
  out_channels``. The first two linear layers are each followed by
  BatchNorm1d and ReLU.

  Args:
    in_channels: size of the last dimension of the input.
    out_channels: size of the last dimension of the output.
  """

  def __init__(self ,
               in_channels ,
               out_channels):
    super(To_Discriminator , self).__init__()

    self.linear1 = nn.Linear(in_channels , in_channels * 2)
    self.batchnorm1 = nn.BatchNorm1d(in_channels * 2)
    self.relu = nn.ReLU(inplace=True)

    self.linear2 = nn.Linear(in_channels * 2 , in_channels * 4)
    self.batchnorm2 = nn.BatchNorm1d(in_channels * 4)

    self.linear3 = nn.Linear(in_channels * 4 , out_channels)
    self.sigmoid = nn.Sigmoid()

  def forward(self , x):
    """Return sigmoid scores in ``[0, 1]`` for ``x``."""
    x = self.relu(self.batchnorm1(self.linear1(x)))
    # Was `self.linera2`, a typo that raised AttributeError.
    x = self.relu(self.batchnorm2(self.linear2(x)))
    # Was `self.sigmoid(self.linear3)`, which passed the layer object itself
    # to the sigmoid instead of applying the layer to x.
    x = self.sigmoid(self.linear3(x))
    return x

In [None]:
class Discriminator(nn.Module):
  """Progressively grown discriminator.

  ``forward(x, alpha, steps)`` applies ``init_conv`` and then ``steps``
  further ``Disc_Block`` stages. A minibatch-standard-deviation channel is
  appended before the final 2x2, stride-2 output convolution.

  Args:
    in_channels: number of image channels in the input.
    out_channels: number of channels produced by every output head.
    hidden_dim: channel width of the first stage; later stages are multiples
      of it.
  """

  def __init__(self ,
               in_channels ,
               out_channels ,
               hidden_dim = 16):
    super(Discriminator , self).__init__()

    self.init_conv = Disc_Block(in_channels , hidden_dim)
    # Heads take one extra input channel: the minibatch-std feature map.
    self.init_last = Conv(hidden_dim + 1 , out_channels , kernel_size=2 , stride=2 , padding=0)

    self.conv_ = nn.ModuleList()  # one Disc_Block per grown stage
    self.last_ = nn.ModuleList()  # matching output head per grown stage

    # Channel multipliers per stage, relative to hidden_dim.
    filters = [1 , 2 , 4 , 8 , 16 , 32 , 32 , 32 , 32]

    for scale_in , scale_out in zip(filters , filters[1:]):
      conv_in_channels = hidden_dim * scale_in
      conv_out_channels = hidden_dim * scale_out
      self.conv_.append(Disc_Block(conv_in_channels , conv_out_channels))
      self.last_.append(Conv(conv_out_channels + 1 , out_channels , kernel_size=2 , stride=2 , padding=0))

    # NOTE(review): these two layers are not used by forward(); they are kept
    # so the parameter set (and any saved state_dict) stays the same.
    self.last_conv = Disc_Block(conv_out_channels , conv_out_channels)
    self.last_layer = Conv(conv_out_channels , out_channels , kernel_size=2 , stride=2 , padding=0)

  def minibatch_std(self , x):
    """Append one channel holding the mean per-feature std over the batch.

    ``torch.std`` over a single sample is NaN, which would poison every later
    activation and the loss. A batch of one therefore gets an all-zero
    statistic channel instead.
    """
    batch_size , _ , height , width = x.shape
    if batch_size > 1:
      stat_map = torch.std(x , dim=0).mean().repeat(batch_size , 1 , height , width)
    else:
      stat_map = x.new_zeros((batch_size , 1 , height , width))
    return torch.cat([x , stat_map] , dim=1)

  def forward(self , x , alpha , steps):
    """Score ``x`` after ``steps`` grown stages.

    Args:
      x: image batch.
      alpha: accepted for symmetry with the generator but not used here.
      steps: number of grown stages to apply after ``init_conv``.

    Returns:
      For ``steps == 0`` the 4-D output of ``init_last``; otherwise the output
      of the matching head flattened to ``(batch, -1)``.
    """
    out = self.init_conv(x)
    if steps == 0:
      out = self.minibatch_std(out)
      out = self.init_last(out)
      return out

    for step in range(steps):
      out = self.conv_[step](out)
      # Only the last requested stage feeds its output head.
      if step + 1 == steps:
        out = self.minibatch_std(out)
        out = self.last_[step](out)

    return out.view(out.shape[0] , -1)

In [None]:
# Smoke test: score two random 1024x1024 RGB images with all 8 grown stages;
# the cell displays the output shape.
disc = Discriminator(3 , 1).to(device)
x = torch.randn(2 , 3 , 1024 , 1024).to(device)
z = disc(x , 0.5 , 8)
z.shape

In [None]:

def show_tensor_images(image_tensor, num_images=2, size=(3 , 1024 , 1024)):
  """Plot the first ``num_images`` images of a batch as one grid.

  The tensor is detached, moved to the CPU and viewed as ``(-1, *size)``
  before being tiled with ``make_grid`` (five images per row) and shown with
  matplotlib.
  """
  images = image_tensor.detach().cpu().view(-1, *size)[:num_images]
  grid = make_grid(images, nrow=5)
  # imshow expects channels last, so move the channel axis to the end.
  plt.imshow(grid.permute(1, 2, 0).squeeze())
  plt.show()

In [None]:
def crop(image, new_shape):
    """Return the centre crop of a 4-D ``image`` batch.

    The crop size is taken from ``new_shape[2]`` (height) and ``new_shape[3]``
    (width); the first two dimensions of ``image`` are left untouched.
    """
    target_height, target_width = new_shape[2], new_shape[3]
    top = image.shape[2] // 2 - target_height // 2
    left = image.shape[3] // 2 - target_width // 2
    return image[:, :, top:top + target_height, left:left + target_width]

In [None]:
# The only preprocessing is conversion to a tensor; no resizing or
# normalisation is applied at load time.
transform = transforms.Compose([
                                transforms.ToTensor()
])

In [None]:
# Training images are read from a Google Drive folder via ImageFolder.
dataset = torchvision.datasets.ImageFolder('/content/drive/MyDrive/Celeb_hq/celeba_hq/train/' , transform=transform)

In [None]:
# Images per batch for the DataLoader built in the next cell.
batch_size = 2

In [None]:
# Shuffled loader over the training images.
dataloader = DataLoader(dataset , batch_size , shuffle=True)

In [None]:
def resize_tensor(input_tensors, h, w):
  """Resize every image of a 4-D batch to ``h`` x ``w``.

  Each image is converted to a PIL image, resized with
  ``transforms.Resize([h, w])`` and converted back with ``ToTensor``.
  NOTE(review): because of the PIL round trip the input is presumably expected
  to hold image data in ``[0, 1]`` - confirm against the torchvision docs.

  Args:
    input_tensors: tensor of shape ``(batch, channel, height, width)``.
    h: target height.
    w: target width.

  Returns:
    A tensor of shape ``(batch, channel, h, w)`` on the same device as
    ``input_tensors``. This holds for every batch size: the previous
    accumulate-and-concatenate loop returned a 3-D tensor for a single image
    and raised for three or more.
  """
  batch_size, channel, height, width = input_tensors.shape
  if batch_size == 0:
    return input_tensors.new_empty((0, channel, h, w))

  # Drops the channel axis of single-channel batches; ToPILImage accepts the
  # resulting 2-D images.
  images = torch.squeeze(input_tensors, 1)

  # Build the transforms once instead of once per image.
  to_pil = transforms.ToPILImage()
  resize = transforms.Resize([h, w])
  to_tensor = transforms.ToTensor()

  resized = [to_tensor(resize(to_pil(img))) for img in images]
  # ToTensor always yields CPU tensors; move the batch back next to its input
  # so it can be compared with tensors living on the same device.
  return torch.stack(resized, dim=0).to(input_tensors.device)

In [None]:
# Sanity check on the first batch only: show the loaded images, then the same
# batch after resize_tensor has shrunk it to 64x64.
for x , y in dataloader:
  print(x.shape)
  show_tensor_images(x , num_images=2)
  x = resize_tensor(x , 64 , 64)
  print(x.shape)
  show_tensor_images(x , size=(3 , 64 , 64))

  break

In [None]:
# Training hyper-parameters.
# `n_epochs` and `display_step` are indexed by the progressive step (0..8), so
# each needs nine entries. They used to stop after eight and six entries
# respectively, which raised IndexError once training reached the later steps.
# The missing entries repeat the last configured value.
n_epochs = [2 , 3 , 4 , 5 , 10 , 50 , 100 , 200 , 200]  # epochs per progressive step
display_step = [100 , 75 , 50 , 25 , 10 , 5 , 5 , 5 , 5]  # iterations between previews
batch_size = 2
lr = 0.0002  # Adam learning rate for both networks
target_shape = 512
betas = (0.5 , 0.999)  # Adam betas for both networks

In [None]:
# Progressive steps visited by train(), in order; each value is passed to the
# networks as `steps` and used as an index into the per-step schedules.
progan_steps = [0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8]

In [None]:
# The two networks that are trained; the discriminator consumes images with
# the generator's number of output channels.
generator = Generator(in_channels_gen , out_channels).to(device)
discriminator = Discriminator(out_channels ,out_channels_disc).to(device)

In [None]:
# One Adam optimiser per network, sharing the learning rate and betas.
opt_generator = torch.optim.Adam(generator.parameters() , lr=lr , betas=betas)
opt_discriminator = torch.optim.Adam(discriminator.parameters() , lr=lr , betas = betas)

In [None]:
# Loss functions. `lambda_recon` is the weight get_gen_loss applies to both of
# its image-space terms.
adv_criterion = nn.BCEWithLogitsLoss()
recon_criterion = nn.L1Loss()
lambda_recon = 200

In [None]:
def get_gen_loss( cur_batch_size,
                 alpha ,
                 pro_steps ,
                 real_img ,
                 z_dim = z_dim ,
                 generator = generator ,
                 discriminator = discriminator ,
                 adv_criterion = adv_criterion ,
                 recon_criterion = recon_criterion ,
                 lambda_recon = lambda_recon):
  """Compute the generator loss for one batch.

  Args:
    cur_batch_size: number of noise samples to draw.
    alpha: fade-in weight forwarded to both networks.
    pro_steps: progressive step forwarded to both networks.
    real_img: batch of real images; it is resized to the generated resolution.
    z_dim, generator, discriminator, adv_criterion, recon_criterion,
      lambda_recon: default to the module-level objects of the same name,
      bound when this function is defined.

  Returns:
    ``(loss, resized_real)``: the scalar generator loss and the real batch
    resized to the generated resolution, on the device of the generated batch.
  """
  noise = torch.randn((cur_batch_size , z_dim , 1 , 1) , device = device , requires_grad=True , dtype=torch.float)
  fake_img = generator(noise , alpha , pro_steps)
  disc_fake_pred = discriminator(fake_img , alpha , pro_steps)
  # The generator is rewarded when the discriminator labels its images as
  # real, so the target is ones. A zeros target would instead train the
  # generator to be detected as fake.
  disc_fake_loss = adv_criterion(disc_fake_pred , torch.ones_like(disc_fake_pred))

  real_img = resize_tensor(real_img , fake_img.shape[2] , fake_img.shape[3])
  # The resized batch may come back on the CPU and, for a single image, without
  # its batch axis; align it with the generated batch before comparing.
  real_img = real_img.to(fake_img.device)
  if real_img.dim() == 3:
    real_img = real_img.unsqueeze(0)

  # NOTE(review): both terms below compare generated images with unrelated
  # real images, and the first feeds the generated image to a with-logits BCE.
  # Kept as written - confirm this is the intended objective.
  gen_adv_loss = adv_criterion(fake_img , real_img)
  gen_recon_loss = recon_criterion(fake_img , real_img)

  loss = disc_fake_loss + lambda_recon * gen_adv_loss + lambda_recon * gen_recon_loss

  return loss , real_img

In [None]:
def train():
  """Train the generator and discriminator through every progressive step.

  For each entry of ``progan_steps`` the loop runs the configured number of
  epochs. Every iteration performs one generator update followed by one
  discriminator update, and a preview of real, resized and generated images
  is shown every ``display_step`` iterations together with the running mean
  losses.
  """
  mean_generator_loss = 0
  mean_discriminator_loss = 0
  cur_step = 0
  for pro_step in progan_steps:
    # The schedules can be shorter than progan_steps; reuse their last entry
    # rather than failing with IndexError at the later steps.
    epochs_here = n_epochs[min(pro_step , len(n_epochs) - 1)]
    show_every = display_step[min(pro_step , len(display_step) - 1)]

    # Restart the fade-in at every resolution and raise alpha linearly to 1
    # over the first half of this step's iterations. With alpha fixed at 1e-5
    # the newest block contributed almost nothing to the output.
    alpha = 1e-5
    fade_iterations = max(1 , (epochs_here * len(dataloader)) // 2)

    for epoch in range(epochs_here):
      for real_img , _ in tqdm(dataloader):
        real_img = real_img.to(device)
        cur_batch_size = real_img.shape[0]

        # ---- generator update ----
        opt_generator.zero_grad()
        gen_loss , real_img_ = get_gen_loss(cur_batch_size , alpha , pro_step , real_img)
        gen_loss.backward()
        opt_generator.step()

        # ---- discriminator update ----
        opt_discriminator.zero_grad()
        noise = torch.randn((cur_batch_size , z_dim , 1 , 1) , dtype=torch.float).to(device)
        with torch.no_grad():
          fake_img = generator(noise , alpha , pro_step)

        # Show the discriminator real images at the resolution currently being
        # generated; feeding full-size reals next to low-resolution fakes lets
        # it separate the two by image size alone.
        real_resized = real_img_.detach().to(device)
        if real_resized.dim() == 3:
          # Guard against a resized batch that lost its batch axis.
          real_resized = real_resized.unsqueeze(0)

        disc_fake_pred = discriminator(fake_img , alpha , pro_step)
        disc_real_pred = discriminator(real_resized , alpha , pro_step)

        disc_fake_loss = adv_criterion(disc_fake_pred , torch.zeros_like(disc_fake_pred))
        disc_real_loss = adv_criterion(disc_real_pred , torch.ones_like(disc_real_pred))

        disc_loss = (disc_fake_loss + disc_real_loss)/2

        disc_loss.backward()
        opt_discriminator.step()

        mean_discriminator_loss += disc_loss.item() / show_every
        mean_generator_loss += gen_loss.item() / show_every

        if cur_step % show_every == 0:
          if cur_step > 0:
            print(f"ProGAN Steps {pro_step} :Epoch {epoch}: Step {cur_step}: Generator loss: {mean_generator_loss}, Discriminator loss: {mean_discriminator_loss}")
          else:
            print("Pretrained initial state")
          print('Real_image')
          show_tensor_images(real_img)
          print('Resized Real_image')
          show_tensor_images(real_img_ , size=(real_img_.shape[1] , real_img_.shape[2] , real_img_.shape[3]))
          print('Generated_image')
          show_tensor_images(fake_img , size=(fake_img.shape[1] , fake_img.shape[2] , fake_img.shape[3]))
          mean_generator_loss = 0
          mean_discriminator_loss = 0

        alpha = min(1.0 , alpha + 1.0 / fade_iterations)
        cur_step += 1

In [None]:
# Run the full progressive training schedule.
train()

In [None]:
# Persist the trained generator weights to Google Drive.
torch.save(generator.state_dict() , '/content/drive/MyDrive/Pro_GAN_Generator.pth')

In [None]:
# Persist the trained discriminator weights to Google Drive.
torch.save(discriminator.state_dict() , '/content/drive/MyDrive/Pro_GAN_Discriminator.pth')