In [1]:
import torch
import torch.optim as optim
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision
from tqdm.notebook import tqdm
!jupyter nbextension enable --py widgetsnbextension
from torch.utils.tensorboard import SummaryWriter
import albumentations as A
import torch.nn.functional as F
from math import log2
from albumentations.pytorch import ToTensorV2
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import save_image
import matplotlib.pyplot as plt
import numpy as np
torch.backends.cudnn.benchmark=True

  warn(


Enabling notebook extension jupyter-js-widgets/extension...
      - Validating: [32mOK[0m


In [2]:
# Per-stage channel multipliers applied to ``in_channels``: stage 0 is the 4x4
# stage and each later stage doubles the resolution. The first four stages keep
# the full width; from then on the width halves at every stage.
factors = [1] * 4 + [2.0 ** -k for k in range(1, 6)]

In [3]:
class WeightedSumConv2d(nn.Module):
    """Conv2d with equalized learning rate (run-time weight scaling).

    The convolution weights are initialised from N(0, 1) and the input is
    multiplied on every forward pass by ``sqrt(gain / fan_in)`` with
    ``fan_in = in_channels * kernel_size**2``. The bias is taken out of the
    inner convolution and added afterwards, so it is not affected by the scale.
    """

    def __init__(self,in_channels,out_channels,kernel_size=3,stride=1,padding=1,gain=2):
        super(WeightedSumConv2d, self).__init__()
        fan_in = in_channels * kernel_size ** 2
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.scale = (gain / fan_in) ** 0.5
        # Re-home the bias on this module and strip it from the conv so that it
        # is applied after (not inside) the scaled convolution.
        detached_bias = self.conv.bias
        self.bias = detached_bias
        self.conv.bias = None
        nn.init.normal_(self.conv.weight)
        nn.init.zeros_(self.bias)

    def forward(self,x):
        # Broadcast the per-output-channel bias over batch and spatial dims.
        channel_bias = self.bias.view(1, -1, 1, 1)
        scaled_input = x * self.scale
        return self.conv(scaled_input) + channel_bias

In [4]:
class PixelNorm(nn.Module):
    """Per-pixel feature normalisation across the channel dimension (dim 1).

    Each spatial location's feature vector is divided by the root-mean-square
    of its channels; ``eps`` guards against division by zero.
    """

    def __init__(self):
        super(PixelNorm, self).__init__()
        self.eps = 1e-8

    def forward(self,x):
        mean_square = x.pow(2).mean(dim=1, keepdim=True)
        rms = (mean_square + self.eps).sqrt()
        return x / rms

In [5]:
class ConvBlock(nn.Module):
    """Two equalized-LR 3x3 convolutions, each followed by LeakyReLU(0.2).

    When ``use_pix_norm`` is True, PixelNorm is applied after each activation
    (generator blocks); the critic builds these blocks with it disabled.
    """

    def __init__(self,in_channels,out_channels,use_pix_norm=True):
        super(ConvBlock, self).__init__()
        self.use_pix_norm = use_pix_norm
        self.con1 = WeightedSumConv2d(in_channels, out_channels)
        self.con2 = WeightedSumConv2d(out_channels, out_channels)
        self.leaky = nn.LeakyReLU(0.2)
        self.pn = PixelNorm()

    def forward(self,x):
        for conv_layer in (self.con1, self.con2):
            x = self.leaky(conv_layer(x))
            if self.use_pix_norm:
                x = self.pn(x)
        return x
    

In [6]:
class Generator(nn.Module):
    """Progressive-growing generator.

    ``forward(x, alpha, steps)`` runs the 4x4 ``inital_block`` and then ``steps``
    stages of (nearest x2 upsample -> ``ConvBlock``), so the output side length
    is ``4 * 2**steps``. For ``steps > 0`` the result is a fade-in between the
    RGB projection of the last upsampled features (the previous resolution's
    path) and the RGB projection of the newest block, weighted by ``alpha``.
    Every stage's output is passed through ``tanh``.

    Args:
        nosie_dim: channel count of the latent input (spelling kept so existing
            callers and keyword uses keep working).
        in_channels: base width; stage ``i`` uses ``int(in_channels * factors[i])``.
        img_channels: channels of the generated image.
    """
    def __init__(self,nosie_dim,in_channels,img_channels=3) -> None:
        super().__init__()
        # Assumes the latent is (N, nosie_dim, 1, 1) as built in this notebook;
        # a 4x4 stride-1 transposed conv with no padding then yields 4x4 maps.
        self.inital_block=nn.Sequential(
            PixelNorm(),
            nn.ConvTranspose2d(nosie_dim,in_channels,4,1,0),
            nn.LeakyReLU(0.2),
            WeightedSumConv2d(in_channels,in_channels,kernel_size=3,stride=1,padding=1),
            nn.LeakyReLU(0.2),
            PixelNorm(),
        )

        self.inital_rgb=WeightedSumConv2d(in_channels,img_channels,kernel_size=1,stride=1,padding=0)
        # rgb_layers[0] is inital_rgb; rgb_layers[i] projects the output of
        # progessive_blocks[i - 1] to image channels.
        self.progessive_blocks,self.rgb_layers=nn.ModuleList([]),nn.ModuleList([self.inital_rgb])

        for i in range(len(factors)-1):
            conv_in_c=int(in_channels*factors[i])
            conv_out_c=int(in_channels*factors[i+1])
            self.progessive_blocks.append(ConvBlock(conv_in_c,conv_out_c))
            self.rgb_layers.append(WeightedSumConv2d(conv_out_c,img_channels,
                                                     kernel_size=1,stride=1,padding=0))

    def fade_in(self,alpha,upsclaed,generated):
        """Blend ``generated`` (weight alpha) with ``upsclaed`` (weight 1-alpha), then tanh."""
        return torch.tanh(alpha*generated+(1-alpha)*upsclaed)

    def forward(self,x,alpha,steps):
        """Generate images at stage ``steps`` (0 -> 4x4, 1 -> 8x8, ...).

        Raises:
            ValueError: if ``steps`` is outside ``[0, len(self.rgb_layers) - 1]``.
        """
        if not 0 <= steps < len(self.rgb_layers):
            raise ValueError(
                f"steps must be in [0, {len(self.rgb_layers) - 1}], got {steps}"
            )
        out=self.inital_block(x)
        if steps==0:
            # Apply the same tanh that fade_in applies at every later stage, so
            # the output range is consistent and the next stage's fade-in (at
            # alpha close to 0) starts from a nearest-upsampled copy of this output.
            return torch.tanh(self.inital_rgb(out))

        for step in range(steps):
            upscaled=nn.functional.interpolate(out,scale_factor=2,mode='nearest')
            out=self.progessive_blocks[step](upscaled)

        # Previous resolution's RGB head applied to the upsampled features vs the
        # new block's RGB head.
        final_upscaled=self.rgb_layers[steps-1](upscaled)
        final_out=self.rgb_layers[steps](out)

        return self.fade_in(alpha,final_upscaled,final_out)

In [7]:
class CriticDiscriminator(nn.Module):
    """Progressive-growing WGAN critic, the mirror image of ``Generator``.

    ``_progessive_block`` and ``rgb_layers`` are ordered from the highest
    resolution (index 0) down to the lowest; ``rgb_layers[-1]`` is the 4x4
    ``initial_rgb``. For stage ``steps`` the input enters at index
    ``len(_progessive_block) - steps``, is faded in against an average-pooled
    copy of the input for the first block, is reduced through the remaining
    blocks, gets a minibatch-std channel appended, and is mapped to one score
    per sample by ``final_layer``.

    Args:
        noise_dim: unused; accepted for signature symmetry with ``Generator``.
        in_channels: base width (stage ``i`` uses ``int(in_channels * factors[i])``).
        img_channels: channels of the input image.
    """
    def __init__(self,noise_dim,in_channels,img_channels) -> None:
        super(CriticDiscriminator,self).__init__()
        self._progessive_block,self.rgb_layers=nn.ModuleList([]),nn.ModuleList([])
        self.leaky=nn.LeakyReLU(0.2)

        for i in range(len(factors)-1,0,-1):
            con_in_c=int(in_channels*factors[i])
            con_in_out=int(in_channels*factors[i-1])
            self._progessive_block.append(ConvBlock(con_in_c,con_in_out,use_pix_norm=False))
            # "from RGB" head feeding this block.
            self.rgb_layers.append(WeightedSumConv2d(img_channels,con_in_c,kernel_size=1,stride=1,padding=0))

        self.initial_rgb=WeightedSumConv2d(img_channels,in_channels,kernel_size=1,stride=1,padding=0)
        self.avg_pool=nn.AvgPool2d(kernel_size=2,stride=2)
        self.rgb_layers.append(self.initial_rgb)
        # in_channels + 1 input channels: the extra one is the minibatch-std map.
        # The 4x4 conv with no padding reduces a 4x4 map to 1x1.
        self.final_layer=nn.Sequential(
            WeightedSumConv2d(in_channels+1,in_channels,kernel_size=3,stride=1,padding=1),
            nn.LeakyReLU(0.2),
            WeightedSumConv2d(in_channels,in_channels,kernel_size=4,stride=1,padding=0),
            nn.LeakyReLU(0.2),
            WeightedSumConv2d(in_channels,1,kernel_size=1,stride=1,padding=0),
        )

    def fade_in(self,alpha,downscale,out):
        """Linear blend: alpha * out + (1 - alpha) * downscale (no squashing)."""
        return (alpha*out +(1-alpha)*downscale)

    def minibatchstd(self,x):
        """Append one channel holding the mean over features of the per-feature batch std.

        A batch of one has no spread, and the unbiased std of a single sample is
        NaN; in that case a zero statistic is used so the scores stay finite
        (e.g. a size-1 final DataLoader batch).
        """
        if x.shape[0] > 1:
            stat=torch.std(x,dim=0).mean()
        else:
            stat=x.new_zeros(())
        batch_statistics=stat.repeat(x.shape[0],1,x.shape[2],x.shape[3])
        return torch.cat([x,batch_statistics],dim=1)

    def forward(self,x,alpha,steps):# steps=0 (4x4). step=1 (8x8)
        """Score images at stage ``steps``; returns a tensor of shape (N, 1).

        Raises:
            ValueError: if ``steps`` is outside ``[0, len(self._progessive_block)]``
                (out-of-range values would otherwise silently wrap around via
                negative indexing).
        """
        num_blocks=len(self._progessive_block)
        if not 0 <= steps <= num_blocks:
            raise ValueError(f"steps must be in [0, {num_blocks}], got {steps}")
        current_step=num_blocks-steps
        out=self.leaky(self.rgb_layers[current_step](x))

        if steps==0:
            out=self.minibatchstd(out)
            return self.final_layer(out).view(out.shape[0],-1)

        # Previous-resolution path: pool the image first, then use the next
        # (lower-resolution) "from RGB" head.
        downscaled=self.leaky(self.rgb_layers[current_step+1](self.avg_pool(x)))
        out=self.avg_pool(self._progessive_block[current_step](out))
        out=self.fade_in(alpha,downscaled,out)

        for step in range(current_step+1,num_blocks):
            out=self.avg_pool(self._progessive_block[step](out))
        out=self.minibatchstd(out)
        return self.final_layer(out).view(out.shape[0],-1)
                    

In [8]:
# noise_dim=50
# in_channels=256
# g=Generator(noise_dim,in_channels=in_channels,img_channels=3)
# d=CriticDiscriminator(noise_dim,in_channels=in_channels,img_channels=3)
# for img_size in [4,8,16,32,64,128,256,512,1024]:
#     num_steps=int(log2(img_size/4))
#     x=torch.randn(1,noise_dim,1,1)
#     z=g(x,0.5,steps=num_steps)
#     print(z.shape)
#     assert z.shape == (1, 3, img_size, img_size)
#     out = d(z, alpha=0.5, steps=num_steps)
#     print(out.shape)
#     assert out.shape == (1, 1)
#     print(f"Success! At img size: {img_size}")

In [9]:
def gradient_penalty(critic, real, fake,alpha,train_step,device="cpu"):
    """WGAN-GP gradient penalty.

    Draws one mixing weight per sample, forms interpolates between ``real`` and
    (detached) ``fake``, and returns the mean of ``(||d critic / d x||_2 - 1)**2``
    over the batch. The graph is kept (``create_graph=True``) so the penalty can
    be back-propagated into the critic's parameters.

    Args:
        critic: callable ``critic(images, alpha, train_step)`` returning scores.
        real: batch of real images, shape (N, C, H, W).
        fake: generated images with the same shape as ``real``; detached here.
        alpha: fade-in weight forwarded to the critic.
        train_step: progressive stage forwarded to the critic.
        device: kept for backward compatibility. The mixing weights are created
            on ``real.device`` so a mismatched (or defaulted) value cannot cause
            a device-mismatch error.

    Returns:
        Scalar tensor with the penalty.
    """
    batch_size = real.shape[0]
    # One weight per sample; broadcasting spreads it over channels and pixels.
    beta = torch.rand((batch_size, 1, 1, 1), device=real.device, dtype=real.dtype)
    interpolated_images = real * beta + fake.detach() * (1 - beta)
    interpolated_images.requires_grad_(True)

    # Calculate critic scores
    mixed_scores = critic(interpolated_images, alpha, train_step)

    # Take the gradient of the scores with respect to the images
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    # reshape (not view) so a non-contiguous gradient cannot raise.
    gradient = gradient.reshape(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    return torch.mean((gradient_norm - 1) ** 2)

In [10]:
img_channels=3
def data_loader(image_size, root="/mnt/disk1/Gulshan/GAN/ProGAN/celeba_hq/train"):
    """Build the ImageFolder dataset and DataLoader for one resolution stage.

    Images are resized to ``image_size`` x ``image_size``, converted to tensors,
    randomly flipped horizontally (p=0.5) and normalised with mean 0.5 / std 0.5
    per channel, i.e. mapped from [0, 1] to [-1, 1].

    Args:
        image_size: side length, expected to be ``4 * 2**k``; the batch size is
            taken from ``bs[k]``.
        root: ImageFolder root directory. Defaults to the path this notebook
            originally hard-coded; pass a different path on another machine.

    Returns:
        ``(loader, dataset)``.

    Raises:
        ValueError: if ``image_size`` is not of the form ``4 * 2**k``.

    NOTE: reads the globals ``bs`` and ``num_workers`` at call time; they are
    defined in the hyper-parameter cell, which must run before this is called.
    """
    stage = log2(image_size / 4)
    if stage < 0 or stage != int(stage):
        raise ValueError(f"image_size must be 4 * 2**k, got {image_size}")
    transform = transforms.Compose(
        [
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.Normalize(
                [0.5 for _ in range(img_channels)],
                [0.5 for _ in range(img_channels)],
            ),
        ]
    )
    batch_size = bs[int(stage)]
    dataset = datasets.ImageFolder(root=root, transform=transform)
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
    )
    return loader, dataset

# Hyperparameters

In [11]:
# Training schedule: starting resolution plus per-stage batch sizes and epochs.
Start_Train_Img_size = 4
bs = [32, 32, 32, 16, 16, 16, 16, 8, 4]
progessive_epochs = [70 for _ in bs]
# Stage index of the starting resolution (4 -> 0, 8 -> 1, ...).
epoch_step = int(log2(Start_Train_Img_size / 4))

# Optimisation settings (critic_iteration is not read by train()).
lr = 1e-3
lambda_GP = 10
critic_iteration = 1

# Model / data dimensions.
img_channels = 3
noise_dim = 256
in_channels = 256

# Runtime settings.
num_workers = 4
if torch.cuda.is_available():
    device = "cuda:7"
else:
    device = "cpu"

# Fixed latent batch that train() reuses for its logged sample grids.
fixed_noise = torch.randn(8, noise_dim, 1, 1).to(device)
progessive_epochs

[70, 70, 70, 70, 70, 70, 70, 70, 70]

In [12]:
# Build both networks on `device` (generator first, then critic).
generater = Generator(noise_dim, in_channels, img_channels).to(device)
discriminator = CriticDiscriminator(noise_dim, in_channels, img_channels).to(device)

# Adam with beta1 = 0.0 and beta2 = 0.99 for both networks.
opt_gen = optim.Adam(generater.parameters(), lr=lr, betas=(0.0, 0.99))
opt_critic = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.0, 0.99))

# Independent AMP loss scalers for the critic and generator updates.
scaler_critic = torch.cuda.amp.GradScaler()
scaler_gen = torch.cuda.amp.GradScaler()

# TensorBoard event files go under logs/gan1.
writer = SummaryWriter("logs/gan1")

In [13]:
def train():
    """Run progressive-growing WGAN-GP training across all resolution stages.

    Starts at stage ``epoch_step`` (image size ``4 * 2**epoch_step``). For each
    remaining entry of ``progessive_epochs`` it trains that many epochs at the
    current resolution, then advances to the next (doubled) resolution. Within
    a stage ``alpha`` ramps linearly from ~0 to 1 over the first half of the
    stage's epochs and is then held at 1.

    Each iteration does one critic update (WGAN loss + gradient penalty + a
    small drift term) and one generator update, both under CUDA autocast with
    their own GradScaler. Every 500 batches it logs real/fake grids and both
    losses to TensorBoard using a monotonically increasing step.

    Relies on notebook globals from earlier cells: models, optimisers, scalers,
    ``writer``, hyper-parameters, ``fixed_noise``, ``data_loader`` and
    ``gradient_penalty``.
    """
    generater.train()
    discriminator.train()
    # Local stage counter; the global start stage `epoch_step` is left unchanged.
    step = epoch_step
    # Monotonic TensorBoard step so logged points never overwrite each other.
    tensorboard_step = 0
    for num_epochs in progessive_epochs[epoch_step:]:
        alpha = 1e-5
        img_size = 4 * 2 ** step
        loader, dataset = data_loader(img_size)
        print(f"Current image size: {img_size}")
        for epoch in tqdm(range(num_epochs), total=num_epochs):
            for idx, (real, _) in enumerate(tqdm(loader)):
                real = real.to(device)
                current_batch_size = real.shape[0]

                # ---- critic update ----
                noise = torch.randn(current_batch_size, noise_dim, 1, 1).to(device)
                with torch.cuda.amp.autocast():
                    fake = generater(noise, alpha, step)
                    critic_real = discriminator(real, alpha, step)
                    critic_fake = discriminator(fake.detach(), alpha, step)
                    gp = gradient_penalty(discriminator, real, fake, alpha, step, device)

                    loss_critic = (
                        -(torch.mean(critic_real) - torch.mean(critic_fake))
                        + lambda_GP * gp
                        # drift penalty (weight 0.001): keeps critic scores on
                        # real images from drifting far away from zero
                        + (0.001 * torch.mean(critic_real ** 2))
                    )
                opt_critic.zero_grad()
                scaler_critic.scale(loss_critic).backward()
                scaler_critic.step(opt_critic)
                scaler_critic.update()

                # ---- generator update (reuses `fake`, which still carries the generator graph) ----
                with torch.cuda.amp.autocast():
                    gen_fake = discriminator(fake, alpha, step)
                    loss_gen = -torch.mean(gen_fake)

                opt_gen.zero_grad()
                scaler_gen.scale(loss_gen).backward()
                scaler_gen.step(opt_gen)
                scaler_gen.update()

                # alpha reaches 1 halfway through the stage; clamp it so the
                # fade-in blend never extrapolates past the new layers.
                alpha += current_batch_size / (len(dataset) * num_epochs * 0.5)
                alpha = min(alpha, 1.0)

                if idx % 500 == 0:
                    with torch.no_grad():
                        fixed_fake = generater(fixed_noise, alpha, step) * 0.5 + 0.5
                        img_grid_real = torchvision.utils.make_grid(real[:8].detach(), normalize=True)
                        img_grid_fake = torchvision.utils.make_grid(fixed_fake[:8].detach(), normalize=True)
                        writer.add_image("Real", img_grid_real, global_step=tensorboard_step)
                        writer.add_image("Fake", img_grid_fake, global_step=tensorboard_step)
                        writer.add_scalar("Loss Critic", loss_critic.item(), global_step=tensorboard_step)
                        writer.add_scalar("Loss generator", loss_gen.item(), global_step=tensorboard_step)
                    tensorboard_step += 1
        # Advance to the next (doubled) resolution for the next stage.
        step += 1
train()

Current image size: 4


  0%|          | 0/70 [00:00<?, ?it/s]

  0%|          | 0/274 [00:00<?, ?it/s]