# Progressive Growing GAN (PGGAN) with PyTorch

## Overview

This notebook contains a PyTorch implementation of a Progressive Growing Generative Adversarial Network (PGGAN), a GAN architecture that generates high-resolution, detailed images by growing its networks during training.

## What is PGGAN

Progressive Growing GANs (PGGANs) are a training approach for Generative Adversarial Networks (GANs) introduced in 2017 by Tero Karras, Timo Aila, Samuli Laine, and Jaakko Lehtinen. The approach targets training stability, scaling to high resolutions, and variation in the generated images. Its main characteristic is the gradual growth of the discriminator and generator networks: both begin at a low resolution, and layers are added one stage at a time to handle higher resolutions.

## Key Features

- **Progressive Growth:** Incremental addition of layers to G and D during training for high-quality image generation.
- **Stability:** Addresses common issues such as mode collapse and training instability for more reliable training.
- **Quality and Variation:** The progressive training approach leads to improved image quality and the generation of diverse samples.

## Visualization

![Progressive Growing GAN](87937a70-cafc-43e3-a4ea-7ae472b0fef8.png)

The image above visualizes the process of gradual growth. The process begins with both the generator (G) and discriminator (D) having a low spatial resolution of 4×4 pixels. As training progresses, layers are incrementally added to G and D, resulting in a gradual increase in the spatial resolution of the generated images. Importantly, all existing layers remain trainable throughout this process.

## Refer to Original Paper

For a detailed understanding of Progressive Growing GANs, please refer to the [original paper](https://arxiv.org/pdf/1710.10196.pdf):

- [Progressive Growing of GANs for Improved Quality, Stability, and Variation](https://arxiv.org/pdf/1710.10196.pdf) by Tero Karras, Timo Aila, Samuli Laine, and Jaakko Lehtinen (2018).

## **Model Structure**
![Model Structure](025262ee-87dd-4f8a-aaf6-5b8b5ed73086.png)

The image above shows the model structure proposed in the paper. This part of the notebook implements that structure.

### Tools for Building Generator:

#### 1. **Equalized Learning Rate Conv2d (`WSConv2d`):**
   - **Purpose:** Scales the convolution weights at runtime by a per-layer constant to obtain equalized learning rates, as suggested by the ProGAN paper.
   - **Why Use:** Keeping the effective weight scale comparable across layers helps stabilize the learning process and improves convergence, so that the scale of a layer's weights does not dictate how quickly it is updated.
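
As a rough illustration of the scaling constant, the sketch below computes the same expression that `WSConv2d` stores in `self.scale`, namely `sqrt(gain / (in_channels * kernel_size**2))`. The helper name `he_scale` and the example channel count of 512 are assumptions made for this illustration only.

```python
import math

def he_scale(in_channels: int, kernel_size: int = 3, gain: float = 2.0) -> float:
    """Per-layer constant sqrt(gain / fan_in), with fan_in = in_channels * kernel_size**2."""
    fan_in = in_channels * kernel_size ** 2
    return math.sqrt(gain / fan_in)

# Hypothetical layer sizes, chosen only to show how the constant shrinks as fan_in grows.
scale_3x3 = he_scale(512, kernel_size=3)  # fan_in = 4608
scale_1x1 = he_scale(512, kernel_size=1)  # fan_in = 512
print(scale_3x3, scale_1x1)
```

Layers with a larger fan-in get a smaller constant, which is what keeps the effective weight magnitudes comparable from layer to layer.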

#### 2. **PixelNorm (`PixelNorm`):**
   - **Purpose:** Normalizes the feature vector at each pixel across the channel dimension, so that every spatial location has a comparable overall magnitude.
   - **Why Use:** Pixel-wise normalization helps maintain stability and improves the quality of generated images. It keeps feature magnitudes at individual pixels from growing out of control and dominating the learning process.
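
The arithmetic behind pixel-wise normalization can be sketched for a single pixel in plain Python. This is only an illustration of the formula used by `PixelNorm` (divide by the square root of the mean squared value over channels, plus a small epsilon); the function name `pixel_norm` and the sample values are made up for the example.

```python
import math

def pixel_norm(channel_values, epsilon=1e-8):
    """Normalize one pixel's channel vector by the root of its mean squared value."""
    mean_sq = sum(v * v for v in channel_values) / len(channel_values)
    return [v / math.sqrt(mean_sq + epsilon) for v in channel_values]

# Hypothetical feature values of one pixel across four channels.
features = [3.0, 4.0, 0.0, -5.0]
normalized = pixel_norm(features)

# After normalization the mean squared value over channels is approximately 1.
mean_sq_after = sum(v * v for v in normalized) / len(normalized)
print(normalized, mean_sq_after)
```

In the notebook the same operation runs on a 4D tensor with `dim=1`, so every pixel of every sample is normalized independently.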

#### 3. **Convolution Block (`ConvBlock`):**
   - **Purpose:** Combines multiple layers, including equalized learning rate convolution, Leaky ReLU activation, and optional pixel-wise normalization.
   - **Why Use:** The convolution block serves as a fundamental building block for both the Generator and Discriminator. It introduces non-linearity through activation functions and enables the network to learn hierarchical features.

---

### Tools for Building Discriminator:

#### 1. **Equalized Learning Rate Conv2d (`WSConv2d`):**
   - **Purpose:** Same as in the Generator, weight scaling is applied to convolutional layers for equalized learning rates.
   - **Why Use:** Consistency in weight scaling between the Generator and Discriminator ensures stable adversarial training.

#### 2. **PixelNorm (`PixelNorm`):**
   - **Purpose:** Available through `ConvBlock`, but switched off in this implementation: the Discriminator builds its blocks with `use_pixelnorm=False`.
   - **Why Use:** Listed for completeness. In the code of this notebook, pixel-wise normalization is applied only in the Generator.

#### 3. **Convolution Block (`ConvBlock`):**
   - **Purpose:** Similar to the Generator, the convolution block is utilized in the Discriminator to introduce non-linearity and hierarchical feature learning.
   - **Why Use:** The Leaky ReLU activation after each convolution contributes to the Discriminator's ability to discern features in the input images. Pixel-wise normalization is disabled for these blocks.

#### 4. **Average Pooling (`nn.AvgPool2d`):**
   - **Purpose:** Down-samples the input tensor using average pooling.
   - **Why Use:** Down-sampling is crucial for progressive growing. It reduces the spatial resolution of the input, enabling the Discriminator to focus on larger-scale features in higher-resolution images.

#### 5. **Minibatch Standard Deviation:**
   - **Purpose:** Computes the standard deviation of each feature across the batch, averages the result into a single value, and concatenates that value to the input tensor as one extra channel.
   - **Why Use:** Gives the Discriminator a signal about how much the samples in a batch vary. Batches of generated images with little variation (a symptom of mode collapse) become easier to detect, which encourages the Generator to produce diverse samples.
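
A minimal plain-Python sketch of the statistic is shown below. It assumes each sample has been flattened into a list of feature values, and it uses the sample standard deviation (dividing by `n - 1`), which is also the default of `torch.std`. The function name `minibatch_std_scalar` and the toy batches are illustrative only.

```python
from statistics import stdev

def minibatch_std_scalar(batch):
    """batch: list of samples, each a flat list of feature values of equal length.

    Returns the per-feature standard deviation across the batch, averaged
    over all features into a single number.
    """
    num_features = len(batch[0])
    per_feature_std = [
        stdev([sample[i] for sample in batch]) for i in range(num_features)
    ]
    return sum(per_feature_std) / num_features

# A varied batch versus a batch in which every sample is identical.
diverse = [[0.0, 1.0], [2.0, 5.0], [4.0, 9.0]]
collapsed = [[1.0, 2.0], [1.0, 2.0], [1.0, 2.0]]

diverse_stat = minibatch_std_scalar(diverse)
collapsed_stat = minibatch_std_scalar(collapsed)
print(diverse_stat, collapsed_stat)
```

The identical batch yields zero, so a discriminator that sees this value as an extra channel can tell collapsed batches apart from varied ones.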

#### 6. **Fade-in Operation:**
   - **Purpose:** Linearly interpolates between two input tensors (e.g., between down-scaled and original images).
   - **Why Use:** Facilitates the smooth transition between resolutions during progressive growing, preventing sudden jumps and promoting stable training.
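
The interpolation itself is a one-line blend. The sketch below applies it to two short lists standing in for the feature maps of the old (down-scaled) path and the new path; the names `old_path` and `new_path` and the values are placeholders for this example.

```python
def fade_in(alpha, old_path, new_path):
    """Linear blend: alpha = 0 keeps the old path, alpha = 1 uses only the new path."""
    return [alpha * n + (1 - alpha) * o for o, n in zip(old_path, new_path)]

old_path = [0.2, 0.4]  # stand-in for features from the previous resolution
new_path = [1.0, 0.0]  # stand-in for features from the newly added block

start = fade_in(0.0, old_path, new_path)
quarter = fade_in(0.25, old_path, new_path)
end = fade_in(1.0, old_path, new_path)
print(start, quarter, end)
```

Because the new block contributes almost nothing when `alpha` is near zero, adding it does not disturb what the network has already learned at the lower resolution.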

These tools collectively contribute to the stability, convergence, and overall performance of both the Generator and Discriminator in the Progressive Growing GAN architecture.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from math import log2

# In generator conv blocks, the channels go like "512->512->512->512->256->128->64->32->16"
factors = [1,1,1,1,1/2,1/4,1/8,1/16,1/32]

class WSConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, gain=2):
        """
        Weight-scaled convolutional layer with equalized learning rates.

        Parameters:
        - in_channels: Number of input channels.
        - out_channels: Number of output channels.
        - kernel_size: Size of the convolutional kernel (default: 3).
        - stride: Stride of the convolution (default: 1).
        - padding: Padding of the convolution (default: 1).
        - gain: Gain for weight scaling (default: 2).
        """
        super(WSConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.scale = (gain / (in_channels * (kernel_size ** 2))) ** 0.5
        self.bias = self.conv.bias
        self.conv.bias = None

        # Initialize conv layer
        nn.init.normal_(self.conv.weight)
        nn.init.zeros_(self.bias)

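    def forward(self, x):
        """Applies the convolution operation with weight scaling and adds bias."""
        # Convolution is linear, so scaling the input by self.scale is equivalent to
        # scaling the weights. The bias is added afterwards so that it is not scaled.
        return self.conv(x * self.scale) + self.bias.view(1, self.bias.shape[0], 1, 1)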

class PixelNorm(nn.Module):
    def __init__(self):
        """Pixel-wise normalization to ensure consistent scales across pixels."""
        super(PixelNorm, self).__init__()
        self.epsilon = 1e-8

    def forward(self, x):
        """Normalizes the input tensor by dividing each pixel by the square root of the mean of squared pixel values."""
        return x / torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True) + self.epsilon)

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, use_pixelnorm=True):
        """
        Convolutional block with Leaky ReLU activation.

        Parameters:
        - in_channels: Number of input channels.
        - out_channels: Number of output channels.
        - use_pixelnorm: Boolean indicating whether to use pixel-wise normalization (default: True).
        """
        super(ConvBlock, self).__init__()
        self.use_pn = use_pixelnorm
        self.conv1 = WSConv2d(in_channels, out_channels)
        self.conv2 = WSConv2d(out_channels, out_channels)
        self.leaky = nn.LeakyReLU(0.2)
        self.pn = PixelNorm()

    def forward(self, x):
        """Applies Leaky ReLU activation after each convolutional layer and optionally applies pixel-wise normalization."""
        x = self.leaky(self.conv1(x))
        x = self.pn(x) if self.use_pn else x
        x = self.leaky(self.conv2(x))
        x = self.pn(x) if self.use_pn else x
        return x


# Generator in Progressive Growing GAN

The Generator in a Progressive Growing GAN is a neural network responsible for creating realistic images. Its primary function is to generate images by transforming a random noise vector into increasingly detailed and high-resolution images. Here's how it achieves this:

## 1. Initial Block:
- The process starts with an initial block that takes a random noise vector as input.
- This block consists of operations like normalization (PixelNorm), a transposed convolution, and activation functions (LeakyReLU).
- It aims to create a preliminary feature map representing basic patterns.

## 2. Progressive Growing:
- The Generator progressively refines the feature map through a series of blocks.
- Each block includes equalized learning rate convolutions, activation functions (LeakyReLU), and optional normalization (PixelNorm).
- As the training progresses, additional blocks are added, increasing the spatial resolution of the feature map.
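
The resolutions and channel widths that result from this growth can be tabulated from the `factors` list defined in this notebook. The sketch below is an illustration only: `block_schedule` is a hypothetical helper, and `in_channels=256` is used because it matches the `IN_CHANNELS` value configured in this notebook.

```python
# Same multipliers as the `factors` list in the notebook.
factors = [1, 1, 1, 1, 1 / 2, 1 / 4, 1 / 8, 1 / 16, 1 / 32]

def block_schedule(in_channels):
    """Return (step, resolution, block_in_channels, block_out_channels) per growth step."""
    rows = []
    for step in range(1, len(factors)):
        resolution = 4 * 2 ** step                    # step 0 is the initial 4x4 block
        c_in = int(in_channels * factors[step - 1])
        c_out = int(in_channels * factors[step])
        rows.append((step, resolution, c_in, c_out))
    return rows

schedule = block_schedule(256)
for step, resolution, c_in, c_out in schedule:
    print(f"step {step}: {resolution}x{resolution}, channels {c_in} -> {c_out}")
```

The first blocks keep the full channel width, and the width is halved for each further doubling of the resolution.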

## 3. RGB Conversion:
- At each step, a corresponding 1x1 convolutional layer (RGB layer) converts the feature map to the desired number of image channels (commonly 3 for RGB).
- These RGB layers contribute to the final synthesized image.

## 4. Fade-in Operation:
- To ensure a smooth transition between different resolutions, a fade-in operation is employed.
- It linearly blends the output of the current resolution with the upscaled version from the previous step.
- This technique prevents abrupt changes and aids in stable training.

## 5. Linear Interpolation:
- Linear interpolation is used to smoothly transition between different resolutions during progressive growing.
- It involves blending the images produced at different resolutions, controlled by a linear interpolation factor.
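
How the interpolation factor evolves depends on the training loop. The sketch below replays the update rule from this notebook's `train_fn` (`alpha += (batch_size / dataset_size) * (1 / epochs) * 2`, clamped at 1, starting from `1e-4`) without any model. The dataset size of 6400 and batch size of 32 are made-up numbers, and the simulation assumes the dataset size is divisible by the batch size.

```python
import math

def simulate_alpha(dataset_size, batch_size, epochs, skip_every_other=True, start=1e-4):
    """Return alpha at the end of each epoch of one resolution step."""
    alpha = start
    batches_per_epoch = math.ceil(dataset_size / batch_size)
    history = []
    for _ in range(epochs):
        for batch_idx in range(batches_per_epoch):
            # The notebook's loop trains on every other batch only.
            if skip_every_other and batch_idx % 2 == 1:
                continue
            alpha += (batch_size / dataset_size) * (1 / epochs) * 2
            alpha = min(alpha, 1)
        history.append(alpha)
    return history

with_skip = simulate_alpha(6400, 32, epochs=4, skip_every_other=True)
without_skip = simulate_alpha(6400, 32, epochs=4, skip_every_other=False)
print(with_skip)
print(without_skip)
```

Under these assumptions, `alpha` reaches 1 only at the end of the step when every other batch is skipped, and about halfway through the step when all batches are used.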

## 6. Final Output:
- The final output of the Generator is a high-resolution, realistic image that has been progressively refined through multiple steps.

In summary, the Generator's role is to transform random noise into detailed and coherent images, progressively improving their quality during training. The fade-in operation and linear interpolation contribute to the stable evolution of generated images in the Progressive Growing GAN.


In [None]:
class Generator(nn.Module):
    def __init__(self, z_dim, in_channels, img_channels=3):
        """
        Generator network for Progressive Growing GAN.

        Parameters:
        - z_dim: Dimension of the input noise vector.
        - in_channels: Number of channels in the initial convolutional layer.
        - img_channels: Number of channels in the generated images (default: 3 for RGB).
        """
        super(Generator, self).__init__()
        # Initial structure of Generator should be opposite of ending structure of Discriminator
        # Initial takes 1x1 -> 4x4
        self.initial = nn.Sequential(
            PixelNorm(),
            nn.ConvTranspose2d(z_dim, in_channels, 4, 1, 0),
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2),
            PixelNorm(),
        )

        self.initial_rgb = WSConv2d(
            in_channels, img_channels, kernel_size=1, stride=1, padding=0
        )
        self.prog_blocks, self.rgb_layers = (
            nn.ModuleList([]),
            nn.ModuleList([self.initial_rgb]),
        )

        for i in range(len(factors) - 1):
            conv_in_c = int(in_channels * factors[i])
            conv_out_c = int(in_channels * factors[i + 1])
            self.prog_blocks.append(ConvBlock(conv_in_c, conv_out_c))
            self.rgb_layers.append(
                WSConv2d(conv_out_c, img_channels, kernel_size=1, stride=1, padding=0)
            )

    def fade_in(self, alpha, upscaled, generated):
        """Fade-in between upscaled and generated images."""
        return torch.tanh(alpha * generated + (1 - alpha) * upscaled)

    def forward(self, x, alpha, steps):
        """
        Forward pass of the Generator.

        Parameters:
        - x: Input noise vector.
        - alpha: Linear interpolation factor for fade-in.
        - steps: Number of progressive growing steps.

        Returns:
        - Generated image.
        """
        out = self.initial(x)

        if steps == 0:
            return self.initial_rgb(out)

        for step in range(steps):
            upscaled = F.interpolate(out, scale_factor=2, mode="nearest")
            out = self.prog_blocks[step](upscaled)

        final_upscaled = self.rgb_layers[steps - 1](upscaled)
        final_out = self.rgb_layers[steps](out)
        return self.fade_in(alpha, final_upscaled, final_out)

# Discriminator in Progressive Growing GAN

The Discriminator in a Progressive Growing GAN is a neural network tasked with distinguishing between real and generated images. Its design allows it to handle images at different resolutions progressively. Let's explore its functionality:

## 1. Progressive Discrimination:
- The Discriminator works in a progressive manner, just like the Generator. It processes images at various resolutions, enabling the model to handle different levels of details.

## 2. Reverse Architecture:
- The Discriminator's architecture is designed in reverse compared to the Generator. It starts with higher-resolution processing and gradually decreases resolution.

## 3. From-RGB Layers:
- The Discriminator begins with a set of 'from-RGB' layers, each corresponding to a specific input resolution. These layers convert the input image to a feature map that the Discriminator can process.
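
Which from-RGB layer handles a given resolution follows from the indexing in this notebook's `Discriminator`, where `cur_step = len(self.prog_blocks) - steps`. The sketch below reproduces only that bookkeeping, without building any layers; `from_rgb_channels` and `select_from_rgb` are hypothetical helper names.

```python
# Same multipliers as the `factors` list in the notebook.
factors = [1, 1, 1, 1, 1 / 2, 1 / 4, 1 / 8, 1 / 16, 1 / 32]

def from_rgb_channels(in_channels):
    """Output channels of each from-RGB layer, in the order the notebook appends them."""
    channels = [int(in_channels * factors[i]) for i in range(len(factors) - 1, 0, -1)]
    channels.append(in_channels)  # the last entry is initial_rgb, used for 4x4 inputs
    return channels

def select_from_rgb(in_channels, steps):
    """Return (index, output channels) of the from-RGB layer used at a growth step."""
    num_blocks = len(factors) - 1
    cur_step = num_blocks - steps
    return cur_step, from_rgb_channels(in_channels)[cur_step]

print(from_rgb_channels(256))
print(select_from_rgb(256, steps=0))  # 4x4 input
print(select_from_rgb(256, steps=5))  # 128x128 input
```

The list is ordered from the highest resolution to the lowest, so a larger `steps` value selects an earlier entry with fewer channels.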

## 4. Down-Sampling:
- To accommodate different resolutions, the Discriminator utilizes down-sampling through operations like Average Pooling. This downscales the input feature map while maintaining important information.
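
To make the down-sampling concrete, here is a plain-Python sketch of 2×2 average pooling with stride 2 on a single-channel grid, the same operation that `nn.AvgPool2d(kernel_size=2, stride=2)` performs per channel. The function name `avg_pool_2x2` and the 4×4 grid are illustrative, and the sketch assumes even height and width.

```python
def avg_pool_2x2(grid):
    """Average each non-overlapping 2x2 window of a 2D list (even height and width)."""
    height, width = len(grid), len(grid[0])
    return [
        [
            (grid[r][c] + grid[r][c + 1] + grid[r + 1][c] + grid[r + 1][c + 1]) / 4
            for c in range(0, width, 2)
        ]
        for r in range(0, height, 2)
    ]

grid = [
    [1, 2, 3, 4],
    [5, 6, 7, 8],
    [9, 10, 11, 12],
    [13, 14, 15, 16],
]
pooled = avg_pool_2x2(grid)
print(pooled)
```

Each pooling step halves the height and width, which mirrors the nearest-neighbor up-sampling by a factor of 2 used in the Generator.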

## 5. Minibatch Standard Deviation:
- The Discriminator incorporates a Minibatch Standard Deviation operation. This operation adds information about the variation within a batch, helping the Discriminator consider the diversity of the input images.

## 6. Final Block:
- The Discriminator concludes with a final block responsible for making the ultimate discrimination between real and generated images. This block includes convolutions, activation functions (LeakyReLU), and a final convolutional layer to produce the output.

## 7. Fade-in Operation:
- Similar to the Generator, the Discriminator employs a fade-in operation, which smooths the transition between resolutions and contributes to more stable training.

## 8. Linear Interpolation:
- Linear interpolation is used for a seamless transition in the Discriminator's evaluation of images at different resolutions. It ensures a consistent discrimination process during progressive growing.

In summary, the Discriminator's role is to distinguish between real and generated images, progressively adapting to varying levels of detail. The use of from-RGB layers, down-sampling, and the fade-in operation allows the Discriminator to effectively handle images at different resolutions in the Progressive Growing GAN.


In [None]:
class Discriminator(nn.Module):
    def __init__(self, in_channels, img_channels=3):
        """
        Discriminator network for Progressive Growing GAN.

        Parameters:
        - in_channels: Number of channels in the initial convolutional layer.
        - img_channels: Number of channels in the input images (default: 3 for RGB).
        """
        super(Discriminator, self).__init__()
        self.prog_blocks, self.rgb_layers = nn.ModuleList([]), nn.ModuleList([])
        self.leaky = nn.LeakyReLU(0.2)

        for i in range(len(factors) - 1, 0, -1):
            conv_in = int(in_channels * factors[i])
            conv_out = int(in_channels * factors[i - 1])
            self.prog_blocks.append(ConvBlock(conv_in, conv_out, use_pixelnorm=False))
            self.rgb_layers.append(
                WSConv2d(img_channels, conv_in, kernel_size=1, stride=1, padding=0)
            )

        self.initial_rgb = WSConv2d(
            img_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.rgb_layers.append(self.initial_rgb)
        self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)

        self.final_block = nn.Sequential(
            WSConv2d(in_channels + 1, in_channels, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, in_channels, kernel_size=4, padding=0, stride=1),
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, 1, kernel_size=1, padding=0, stride=1),
        )

    def fade_in(self, alpha, downscaled, out):
        """Fade-in between downscaled and output images."""
        return alpha * out + (1 - alpha) * downscaled

    def minibatch_std(self, x):
        """Calculate minibatch standard deviation and concatenate it with the input."""
        batch_statistics = torch.std(x, dim=0).mean().repeat(x.shape[0], 1, x.shape[2], x.shape[3])
        return torch.cat([x, batch_statistics], dim=1)

    def forward(self, x, alpha, steps):
        """
        Forward pass of the Discriminator.

        Parameters:
        - x: Input image.
        - alpha: Linear interpolation factor for fade-in.
        - steps: Number of progressive growing steps.

        Returns:
        - Discriminator output.
        """
        cur_step = len(self.prog_blocks) - steps

        out = self.leaky(self.rgb_layers[cur_step](x))

        if steps == 0:
            out = self.minibatch_std(out)
            return self.final_block(out).view(out.shape[0], -1)

        downscaled = self.leaky(self.rgb_layers[cur_step + 1](self.avg_pool(x)))
        out = self.avg_pool(self.prog_blocks[cur_step](out))
        out = self.fade_in(alpha, downscaled, out)

        for step in range(cur_step + 1, len(self.prog_blocks)):
            out = self.prog_blocks[step](out)
            out = self.avg_pool(out)

        out = self.minibatch_std(out)
        return self.final_block(out).view(out.shape[0], -1)

# Configurations & Hyperparameters
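
The configuration below converts between image sizes and growth steps with `int(log2(size / 4))`. As a quick sanity check of that arithmetic, the following sketch uses two hypothetical helpers, `step_for_size` and `size_for_step`, with the sizes that appear in this notebook (16 as the starting size and 128 as the target size).

```python
from math import log2

def step_for_size(img_size):
    """Growth step at which the networks operate on img_size x img_size images."""
    return int(log2(img_size / 4))

def size_for_step(step):
    """Inverse mapping: image resolution handled at a given growth step."""
    return 4 * 2 ** step

start_step = step_for_size(16)       # step for the starting resolution
final_step = step_for_size(128)      # step for the target resolution
num_steps = final_step + 1           # steps 0..final_step inclusive
sizes = [size_for_step(s) for s in range(num_steps)]
print(start_step, final_step, num_steps, sizes)
```

With these values, training starts at step 2 and the target resolution corresponds to step 5.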

In [None]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

START_TRAIN_IMG_SIZE = 16
DATASET = "../input/human-faces"

CHECKPOINT_GEN = "generator.pth"
CHECKPOINT_CRITIC = "critic.pth"
SAVE_MODEL = False
LOAD_MODEL = False

LR = 1e-3
BATCH_SIZES = [32,32,32,32,16,16,16,4,4,4] ## modifiable/ Batch_sizes for each step
IMAGE_SIZE = 128 ## 1024 for paper
IMG_CHANNELS = 3
Z_DIM = 256 ## 512 for paper
IN_CHANNELS = 256 ## 512 for paper
LAMBDA_GP = 10
NUM_STEPS = int(log2(IMAGE_SIZE/4)) + 1

PROGRESSIVE_EPOCHS = [4] * NUM_STEPS ## one entry per step, so training stops at IMAGE_SIZE
FIXED_NOISE = torch.randn(8,Z_DIM,1,1).to(DEVICE)
# NUM_WORKERS = 4
NUM_WORKERS = 2

print(DEVICE)

# Utils

In [None]:
import random
import torchvision
from torchvision.utils import save_image

def save_on_tensorboard(writer,loss_critic,loss_gen,real,fake,tensorboard_step):
    writer.add_scalar("Loss Critic",loss_critic,global_step=tensorboard_step)
    writer.add_scalar("Loss Generator", loss_gen, global_step=tensorboard_step)
    
    with torch.no_grad():
        img_grid_real = torchvision.utils.make_grid(real[:8],normalize=True)
        img_grid_fake = torchvision.utils.make_grid(fake[:8],normalize=True)
        
        writer.add_image("Real",img_grid_real,global_step = tensorboard_step)
        writer.add_image("Fake",img_grid_fake,global_step = tensorboard_step)
        
def gradient_penalty(critic,real,fake,alpha,train_step,device="cpu"):
    BATCH_SIZE,C,H,W = real.shape
    beta = torch.rand((BATCH_SIZE,1,1,1)).repeat(1,C,H,W).to(device)
    
    interpolated_images = real * beta + fake.detach() * (1-beta)
    interpolated_images.requires_grad_(True)
    
    ## Calculate critic scores
    mixed_scores = critic(interpolated_images,alpha,train_step)
    
    ## Take the gradient of the scores with respect to the image
    gradient = torch.autograd.grad(
        inputs = interpolated_images,
        outputs = mixed_scores,
        grad_outputs = torch.ones_like(mixed_scores),
        create_graph = True,
        retain_graph = True
    )[0]
    
    gradient = gradient.view(gradient.shape[0],-1)
    gradient_norm = gradient.norm(2,dim=1)
    penalty = torch.mean((gradient_norm - 1)**2)
    return penalty

def save_checkpoint(model,optimizer,filename="my_checkpoint.pth"):
    print("Saving Checkpoint")
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer" : optimizer.state_dict()
    }
    torch.save(checkpoint,filename)
    
def load_checkpoint(checkpoint_file,model,optimizer,lr):
    print("Loading Checkpoint")
    checkpoint = torch.load(checkpoint_file,map_location=DEVICE)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr
        
def generate_examples(gen,current_epoch,steps,n=16):
    import os
    os.makedirs("generated_images",exist_ok=True) ## save_image does not create the folder
    gen.eval()
    alpha = 1.0
    
    for i in range(n):
        with torch.no_grad():
            noise = torch.randn(1,Z_DIM,1,1).to(DEVICE)
            generated_img = gen(noise,alpha=alpha,steps=steps)
            save_image(generated_img*0.5+0.5,f"generated_images/step{steps}_epoch{current_epoch}_{i}.png")
    
    gen.train()

# Train
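
For reference, the losses computed in the training cell below can be written out as follows. This is a transcription of the code rather than an additional derivation: $D$ is the critic, $G$ the generator, $\lambda_{GP}$ corresponds to `LAMBDA_GP`, $\beta$ is the random mixing weight drawn in `gradient_penalty`, and the $0.001$ coefficient is the extra term on the critic's real scores.

$$
\hat{x} = \beta\,x + (1 - \beta)\,G(z), \qquad \beta \sim U(0, 1)
$$

$$
L_{\text{critic}} = -\big(\mathbb{E}[D(x)] - \mathbb{E}[D(G(z))]\big) + \lambda_{GP}\,\mathbb{E}\big[(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1)^2\big] + 0.001\,\mathbb{E}\big[D(x)^2\big]
$$

$$
L_{\text{gen}} = -\mathbb{E}[D(G(z))]
$$

The first term of the critic loss is the Wasserstein estimate, the second is the gradient penalty, and the third keeps the critic's scores on real images from drifting far from zero.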

In [None]:
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

In [None]:
torch.backends.cudnn.benchmark = True

def get_loader(img_size):
    transform = transforms.Compose(
    [
        transforms.Resize((img_size,img_size)),
        transforms.ToTensor(),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.Normalize([0.5 for _ in range(IMG_CHANNELS)],[0.5 for _ in range(IMG_CHANNELS)])
    ])
    
    batch_size = BATCH_SIZES[int(log2(img_size/4))]
    dataset = datasets.ImageFolder(root=DATASET,transform=transform)
    loader = DataLoader(dataset,batch_size=batch_size,shuffle=True,num_workers=NUM_WORKERS,pin_memory=True)
    
    return loader,dataset

def train_fn(gen,critic,loader,dataset,step,alpha,opt_gen,opt_critic,tensorboard_step,writer,scaler_gen,scaler_critic):
    loop = tqdm(loader,leave=True)
    
    for batch_idx,(real,_) in enumerate(loop):
        ## Only every other batch is trained on: batches with an odd batch_idx are skipped
        if batch_idx % 2 == 1:
            continue
        real = real.to(DEVICE)
        cur_batch_size = real.shape[0]
        noise = torch.randn(cur_batch_size,Z_DIM,1,1).to(DEVICE)
        
        ## Train Critic
        ## Wasserstein Loss : Maximize "E[Critic(real)] - E[Critic(fake)]"   ==   Minimize "-(E[Critic(real)] - E[Critic(fake)])"
        with torch.cuda.amp.autocast():
            fake = gen(noise,alpha,step).to(DEVICE)
            critic_real = critic(real,alpha,step)
            critic_fake = critic(fake.detach(),alpha,step)
            gp = gradient_penalty(critic,real,fake,alpha,step,device=DEVICE)
            loss_critic = -1 * (torch.mean(critic_real) - torch.mean(critic_fake)) + LAMBDA_GP * gp + 0.001 * torch.mean(critic_real**2)
        
        critic.zero_grad()
        scaler_critic.scale(loss_critic).backward()
        scaler_critic.step(opt_critic)
        scaler_critic.update()
        
        ## Train Generator
        ## Maximize "E[Critic(fake)]"   ==   Minimize "- E[Critic(fake)]"
        with torch.cuda.amp.autocast():
            gen_fake = critic(fake,alpha,step)
            loss_gen = -1 * torch.mean(gen_fake)
            
        gen.zero_grad()
        scaler_gen.scale(loss_gen).backward()
        scaler_gen.step(opt_gen)
        scaler_gen.update()
    
        alpha += (cur_batch_size/len(dataset)) * (1/PROGRESSIVE_EPOCHS[step]) * 2
        alpha = min(alpha,1)
        
        if batch_idx % 500 == 0:
            with torch.no_grad():
                fixed_fakes = gen(FIXED_NOISE,alpha,step) * 0.5 + 0.5
                save_on_tensorboard(writer,loss_critic.item(),loss_gen.item(),real.detach(),fixed_fakes.detach(),tensorboard_step)
                tensorboard_step += 1
    
    return tensorboard_step,alpha
        
## build model
gen = Generator(Z_DIM,IN_CHANNELS,IMG_CHANNELS).to(DEVICE)
critic = Discriminator(IN_CHANNELS,IMG_CHANNELS).to(DEVICE)

## initialize optimizer,scalers (for FP16 training)
opt_gen = optim.Adam(gen.parameters(),lr=LR,betas=(0.0,0.99))
opt_critic = optim.Adam(critic.parameters(),lr=LR,betas=(0.0,0.99))
scaler_gen = torch.cuda.amp.GradScaler()
scaler_critic = torch.cuda.amp.GradScaler()

## tensorboard writer
writer = SummaryWriter(f"runs/PG_GAN")
tensorboard_step = 0

## if checkpoint files exist, load model
if LOAD_MODEL:
    load_checkpoint(CHECKPOINT_GEN,gen,opt_gen,LR)
    load_checkpoint(CHECKPOINT_CRITIC,critic,opt_critic,LR)
    
gen.train()
critic.train()

step = int(log2(START_TRAIN_IMG_SIZE/4)) ## starts from 0

global_epoch = 0
generate_examples_at = [4,8,12,16,20,24,28,32]

for num_epochs in PROGRESSIVE_EPOCHS[step:]:
    alpha = 1e-4
    loader,dataset = get_loader(4*2**step)
    print(f"Image size:{4*2**step} | Current step:{step}")
    
    for epoch in range(num_epochs):
        print(f"Epoch [{epoch+1}/{num_epochs}] Global Epoch:{global_epoch}")
        tensorboard_step,alpha = train_fn(gen,critic,loader,dataset,step,alpha,opt_gen,opt_critic,tensorboard_step,writer,scaler_gen,scaler_critic)
        global_epoch += 1
        if global_epoch in generate_examples_at:
            generate_examples(gen,global_epoch,step,n=6)
        
        ## (epoch+1) never reaches 8 with 4 epochs per step, so the save is keyed on global_epoch
        if SAVE_MODEL and global_epoch%8==0:
            save_checkpoint(gen,opt_gen,filename=CHECKPOINT_GEN)
            save_checkpoint(critic,opt_critic,filename=CHECKPOINT_CRITIC)
            
    step += 1 ## Progressive Growing
    
print("Training finished")