# Model

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from math import log2
#tensorboard
from torch.utils.tensorboard import SummaryWriter
#tqdm
from tqdm import tqdm

from AdnGAN import Generator_pro, Discriminator_pro,tools,get_loader

### Equalized learning rate

$$w_f = w_i\sqrt{\frac{2}{k^2 \cdot c}}$$

* k : Kernel size
* c : in channels
* w : weight

**Weight Initialization:** Initializing weights from a normal distribution means that the initial weights have varying magnitudes. In convolutional neural networks (CNNs), each weight corresponds to a feature detector, and these detectors might have very different magnitudes initially.

**Consequence of Weight Initialization:** During the forward pass through a layer, the input is convolved with these weights. If the weights have varying magnitudes, the output of the convolution will also have varying magnitudes. This can lead to instability in training because the network might respond more strongly to some input features than to others, simply due to the magnitude of the weights.

**Normalization:** By normalizing the weights, you ensure that each weight vector has a consistent magnitude. This makes the network more robust and less sensitive to the scale of the input features.

**Scaling:** However, when you normalize the weights, you might inadvertently reduce the overall magnitude of the signal passing through the layer. To counteract this, the layer's input is multiplied by a constant factor (`self.scale` in the code below). This scaling factor ensures that the signal's magnitude remains approximately constant despite the weight normalization.

In [3]:
# Weight-scaled ("equalized learning rate") 2D convolution
class WSConv2d(nn.Module):
    """Conv2d whose input is rescaled at run time by ``sqrt(gain / fan_in)``.

    The convolution weights are drawn from a standard normal distribution and
    the constant ``scale = sqrt(gain / (in_channels * kernel_size ** 2))`` is
    applied to the input on every forward pass. The bias is held on this
    module (not on the inner convolution) so that it is added after the
    convolution and is never affected by the scale.

    Args:
        in_channels: Number of input feature maps.
        out_channels: Number of output feature maps.
        kernel_size: Side of the square kernel (an int; it is squared to get fan-in).
        stride: Convolution stride.
        padding: Zero padding applied on each side.
        gain: Numerator of the scale factor.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, gain=2):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)

        fan_in = in_channels * (kernel_size ** 2)
        self.scale = (gain / fan_in) ** 0.5

        # Move the bias parameter from the inner conv onto this module: the inner
        # conv then runs bias-free on the scaled input, and the unscaled bias is
        # added afterwards in forward().
        self.bias = self.conv.bias
        self.conv.bias = None

        # Weights ~ N(0, 1), bias = 0
        nn.init.normal_(self.conv.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        """Return ``conv(x * scale) + bias`` with the bias broadcast over N, H, W."""
        scaled_input = x * self.scale
        channel_bias = self.bias.view(1, self.bias.shape[0], 1, 1)
        return self.conv(scaled_input) + channel_bias


### Pixel normalisation

In [4]:
# Pixel-wise feature normalisation (used in place of batch norm)
class PixelNorm(nn.Module):
    """Normalise every pixel's feature vector to (approximately) unit RMS.

    For an input of shape (N, C, H, W) the mean of the squared activations is
    taken over the channel axis (dim=1, dim=0 being the batch) and each value
    is divided by the square root of that mean plus a small epsilon.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x / sqrt(mean_c(x^2) + 1e-8)``."""
        eps = 1e-8  # guards against division by zero on all-zero pixels
        mean_square = x.pow(2).mean(dim=1, keepdim=True)
        return x / (mean_square + eps).sqrt()

### ConvBlock

In [5]:
class ConvBlock(nn.Module):
    """Two weight-scaled 3x3 convolutions, each followed by LeakyReLU(0.2).

    Shared by the generator and the critic. When ``use_pixelnorm`` is true a
    PixelNorm is applied after each activation; otherwise the activations are
    passed through unchanged.

    Args:
        in_channels: Channels entering the first convolution.
        out_channels: Channels produced by both convolutions.
        use_pixelnorm: Whether to apply PixelNorm after each activation.
    """

    def __init__(self, in_channels, out_channels, use_pixelnorm=True):
        super().__init__()
        self.conv1 = WSConv2d(in_channels, out_channels)
        self.conv2 = WSConv2d(out_channels, out_channels)
        self.leaky = nn.LeakyReLU(0.2)
        self.pn = PixelNorm()
        self.use_pn = use_pixelnorm

    def forward(self, x):
        """Apply conv -> LeakyReLU -> (optional PixelNorm) twice."""
        out = x
        for conv in (self.conv1, self.conv2):
            out = self.leaky(conv(out))
            if self.use_pn:
                out = self.pn(out)
        return out

### Generator

In [6]:
class Generator(nn.Module):
    """Progressively grown generator mapping a latent vector to an image.

    The network starts with a 4x4 block and holds one ``ConvBlock`` plus one
    to-RGB layer per additional resolution stage. ``forward`` runs only the
    first ``steps`` stages, so the same module can produce 4x4, 8x8, ... images.

    Args:
        z_dim: Channels of the latent input, expected as (N, z_dim, 1, 1).
        in_channels: Base number of feature maps (used by the 4x4 block).
        img_channels: Channels of the generated image.
        factors: Channel multiplier per resolution stage, relative to
            ``in_channels``. Defaults to ``DEFAULT_FACTORS``.
    """

    # Channel multipliers for the 4x4 ... 1024x1024 stages (nine stages, i.e.
    # eight progressive blocks). Kept on the class because the notebook never
    # defines a module-level ``factors``, so reading it as a global raised
    # NameError on a fresh kernel.
    DEFAULT_FACTORS = (1, 1, 1, 1, 1 / 2, 1 / 4, 1 / 8, 1 / 16, 1 / 32)

    def __init__(self, z_dim, in_channels, img_channels=3, factors=None):
        super().__init__()
        if factors is None:
            factors = self.DEFAULT_FACTORS

        self.initial = nn.Sequential(
            PixelNorm(),
            nn.ConvTranspose2d(z_dim, in_channels, kernel_size=4, stride=1, padding=0),  # 1x1 -> 4x4
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2),
            PixelNorm(),
        )

        self.initial_rgb = WSConv2d(in_channels, img_channels, kernel_size=1, stride=1, padding=0)
        # rgb_layers[k] converts the features of stage k to an image; index 0 is the 4x4 stage.
        self.prog_blocks, self.rgb_layers = nn.ModuleList([]), nn.ModuleList([self.initial_rgb])

        for i in range(len(factors) - 1):
            # Stage i -> stage i + 1
            conv_in_c = int(in_channels * factors[i])
            conv_out_c = int(in_channels * factors[i + 1])
            self.prog_blocks.append(ConvBlock(conv_in_c, conv_out_c))
            self.rgb_layers.append(WSConv2d(conv_out_c, img_channels, kernel_size=1, stride=1, padding=0))

    def fade_in(self, alpha, upscaled, generated):
        """Blend the new stage's image with the upscaled previous one, squashed to [-1, 1]."""
        return torch.tanh(alpha * generated + (1 - alpha) * upscaled)

    def forward(self, x, alpha, steps):
        """Generate an image at resolution ``4 * 2 ** steps``.

        Args:
            x: Latent input.
            alpha: Fade-in weight of the newest stage (1 means new stage only).
            steps: Number of progressive blocks to run; 0 yields the 4x4 image.

        Returns:
            The generated image tensor.
        """
        out = self.initial(x)  # 4x4 features
        if steps == 0:
            # NOTE(review): this path returns the raw to-RGB output without the tanh
            # that fade_in applies for steps > 0 - confirm that is intended.
            return self.initial_rgb(out)

        for step in range(steps):
            upscaled = F.interpolate(out, scale_factor=2, mode="nearest")
            out = self.prog_blocks[step](upscaled)

        # `upscaled` still carries the previous stage's channel count, so it is
        # converted with the previous stage's to-RGB layer.
        final_upscaled = self.rgb_layers[steps - 1](upscaled)
        final_out = self.rgb_layers[steps](out)
        return self.fade_in(alpha, final_upscaled, final_out)

### Critic

In [7]:
class Discriminator(nn.Module):
    """Progressively grown critic mapping an image to one score per sample.

    Blocks are stored from the highest resolution down to the lowest, so
    ``forward`` enters at index ``len(prog_blocks) - steps`` and walks towards
    the final 4x4 block.

    Args:
        in_channels: Base number of feature maps (used by the 4x4 block).
        img_channels: Channels of the input image.
        factors: Channel multiplier per resolution stage, relative to
            ``in_channels``. Defaults to ``DEFAULT_FACTORS``.
    """

    # Channel multipliers for the 4x4 ... 1024x1024 stages. Kept on the class
    # because the notebook never defines a module-level ``factors``, so reading
    # it as a global raised NameError on a fresh kernel.
    DEFAULT_FACTORS = (1, 1, 1, 1, 1 / 2, 1 / 4, 1 / 8, 1 / 16, 1 / 32)

    def __init__(self, in_channels, img_channels=3, factors=None):
        super().__init__()
        if factors is None:
            factors = self.DEFAULT_FACTORS

        self.prog_blocks, self.rgb_layers = nn.ModuleList([]), nn.ModuleList([])
        self.leaky = nn.LeakyReLU(0.2)

        # Walk the stages from the largest resolution to the smallest.
        for i in range(len(factors) - 1, 0, -1):
            conv_in_c = int(in_channels * factors[i])
            conv_out_c = int(in_channels * factors[i - 1])
            self.prog_blocks.append(ConvBlock(conv_in_c, conv_out_c, use_pixelnorm=False))
            self.rgb_layers.append(WSConv2d(img_channels, conv_in_c, kernel_size=1, stride=1, padding=0))

        # From-RGB layer for the 4x4 resolution (last entry of rgb_layers)
        self.initial_rgb = WSConv2d(img_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.rgb_layers.append(self.initial_rgb)
        self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)

        # Block for the 4x4 resolution; the +1 input channel is the minibatch-std map.
        self.final_block = nn.Sequential(
            WSConv2d(in_channels + 1, in_channels, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, in_channels, kernel_size=4, stride=1, padding=0),  # 4x4 -> 1x1
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, 1, kernel_size=1, stride=1, padding=0),  # 1x1 conv standing in for a linear layer
        )

    def fade_in(self, alpha, downscaled, out):
        """Blend the newest stage's features with the downscaled shortcut path."""
        return alpha * out + (1 - alpha) * downscaled

    def minibatch_std(self, x):
        """Append one channel holding the mean per-feature std across the batch.

        The statistic is a single scalar (std over dim 0, averaged over all
        remaining positions) broadcast to shape (N, 1, H, W).
        """
        if x.shape[0] > 1:
            statistic = torch.std(x, dim=0).mean()
        else:
            # The default (N - 1) std estimator is NaN for a single sample and the
            # NaN would propagate into the score; a lone sample has no spread, so use 0.
            statistic = torch.zeros((), dtype=x.dtype, device=x.device)
        batch_statistics = statistic.repeat(x.shape[0], 1, x.shape[2], x.shape[3])
        return torch.cat([x, batch_statistics], dim=1)

    def forward(self, x, alpha, steps):
        """Score images of resolution ``4 * 2 ** steps``.

        Args:
            x: Input images.
            alpha: Fade-in weight of the newest (highest-resolution) block.
            steps: 0 for 4x4 inputs, 1 for 8x8, and so on.

        Returns:
            Tensor of shape (N, 1) with one score per image.
        """
        cur_step = len(self.prog_blocks) - steps
        out = self.leaky(self.rgb_layers[cur_step](x))

        if steps == 0:
            out = self.minibatch_std(out)
            return self.final_block(out).view(out.shape[0], -1)

        # Shortcut path: shrink the image first, then use the next-lower stage's from-RGB layer.
        downscaled = self.leaky(self.rgb_layers[cur_step + 1](self.avg_pool(x)))
        out = self.avg_pool(self.prog_blocks[cur_step](out))
        out = self.fade_in(alpha, downscaled, out)

        for step in range(cur_step + 1, len(self.prog_blocks)):
            out = self.prog_blocks[step](out)
            out = self.avg_pool(out)

        out = self.minibatch_std(out)
        return self.final_block(out).view(out.shape[0], -1)


### Test

In [11]:
# Shape smoke test: run the packaged generator/critic pair at every resolution.
Z_DIM = 50
IN_CHANNELS = 256
gen = Generator_pro(Z_DIM, IN_CHANNELS)
critic = Discriminator_pro(IN_CHANNELS)

# Resolutions 4, 8, ..., 1024; a resolution's position in this sequence is its
# progressive step (equal to log2(img_size / 4)).
for num_steps, img_size in enumerate(4 * 2 ** k for k in range(9)):
    x = torch.randn(1, Z_DIM, 1, 1)
    z = gen(x, 0.5, steps=num_steps)  # generated image
    assert z.shape == (1, 3, img_size, img_size)
    out = critic(z, alpha=0.5, steps=num_steps)  # one score per image
    assert out.shape == (1, 1)
    print(f"Succes at image size :{img_size}")

  


Succes at image size :4
Succes at image size :8
Succes at image size :16
Succes at image size :32
Succes at image size :64
Succes at image size :128
Succes at image size :256
Succes at image size :512
Succes at image size :1024


---

# Configuration

In [2]:

import cv2
import torch
from math import log2

# --- Run control -------------------------------------------------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SAVE_MODEL = True
LOAD_MODEL = False
CHECKPOINT_GEN = "gen_pro.pth"
CHECKPOINT_CRITIC = "cri_pro.pth"
NUM_WORKERS = 4

# --- Resolution schedule -----------------------------------------------------
START_TRAIN_AT_IMG_SIZE = 128
IMAGE_SIZE = 512
CHANNELS_IMG = 1
# One batch size per resolution, looked up by int(log2(image_size / 4))
BATCH_SIZES = [32, 32, 32, 16, 16, 16, 16, 8, 4]
NUM_STEPS = int(log2(IMAGE_SIZE / 4)) + 1
# Epochs per resolution stage, one entry per batch-size entry
PROGRESSIVE_EPOCHS = [20 for _ in BATCH_SIZES]

# --- Model -------------------------------------------------------------------
Z_DIM = 256  # should be 512 in original paper
IN_CHANNELS = 256  # should be 512 in original paper
# Fixed batch of 8 latent vectors, moved to DEVICE
FIXED_NOISE = torch.randn(8, Z_DIM, 1, 1).to(DEVICE)

# --- Optimisation ------------------------------------------------------------
LEARNING_RATE = 1e-3
CRITIC_ITERATIONS = 1
LAMBDA_GP = 10

---

# Training

In [5]:
# Turn on cuDNN's benchmark flag for the whole kernel session.
# NOTE(review): presumably meant to let cuDNN auto-tune its convolution algorithms;
# input shapes change at every resolution stage, so confirm it helps on the target GPU.
torch.backends.cudnn.benchmark = True

In [None]:
def get_loder(image_size):
    """Build the data loader for one progressive stage.

    Thin wrapper around ``get_loader`` that looks up the batch size for the
    requested resolution and forwards that resolution to the loader. The
    global ``IMAGE_SIZE`` was forwarded before, so every stage asked for the
    final resolution regardless of ``image_size``.

    Args:
        image_size: Side length of the images for the current stage (4, 8, 16, ...).

    Returns:
        Whatever ``get_loader`` returns; ``main`` unpacks it as ``(loader, dataset)``.
    """
    # BATCH_SIZES holds one entry per resolution: index 0 is 4x4, 1 is 8x8, ...
    stage_index = int(log2(image_size / 4))
    # NOTE(review): assumes get_loader's `image_size` is the target resolution of
    # the returned batches - confirm against AdnGAN.get_loader.
    return get_loader(
        channels_img=CHANNELS_IMG,
        image_size=image_size,
        batch_size=BATCH_SIZES[stage_index],
    )

In [6]:
def train_fn(
        critic,
        gen,
        loader,
        dataset,
        step,
        alpha,
        opt_critic,
        opt_gen,
        tensorboard_step,
        writer,
        scaler_gen,
        scaler_critic,
    ):
    """Run one epoch over ``loader`` for the current progressive ``step``.

    NOTE: the optimisation step itself is not implemented yet; the loop only
    moves each batch to ``DEVICE``. The function returns the
    ``(tensorboard_step, alpha)`` pair because ``main`` unpacks exactly that
    from every call - an implicit ``None`` return made that unpacking fail.

    Args:
        critic, gen: The critic and generator networks.
        loader: Iterable yielding ``(real_images, _)`` batches.
        dataset: Dataset behind ``loader`` (unused so far).
        step: Current progressive step (resolution index).
        alpha: Current fade-in weight.
        opt_critic, opt_gen: Optimizers of the two networks.
        tensorboard_step: Running counter for TensorBoard logging.
        writer: TensorBoard writer.
        scaler_gen, scaler_critic: Gradient scalers for mixed-precision training.

    Returns:
        Tuple ``(tensorboard_step, alpha)``; currently the unchanged inputs.
    """
    loop = tqdm(loader, leave=True)
    for batch_idx, (real, _) in enumerate(loop):
        real = real.to(DEVICE)
        cur_batch_size = real.shape[0]  # reserved for sampling the noise batch in the critic step

        # Train Critic: max E[critic(real)] - E[critic(fake)] <-> min -E[critic(real)] + E[critic(fake)]
        # TODO: implement the critic update, the generator update, the alpha
        # schedule and the TensorBoard logging.

    return tensorboard_step, alpha
        
    

In [None]:
def main():
    """Train the generator/critic pair progressively, one resolution stage at a time.

    Starts at ``START_TRAIN_AT_IMG_SIZE`` and, for each remaining entry of
    ``PROGRESSIVE_EPOCHS``, builds the loader for that resolution, runs that
    many epochs of ``train_fn`` and then moves on to the next resolution.
    Checkpoints are written after every epoch when ``SAVE_MODEL`` is set.
    """
    gen = Generator_pro(Z_DIM, IN_CHANNELS, CHANNELS_IMG).to(DEVICE)
    critic = Discriminator_pro(IN_CHANNELS, CHANNELS_IMG).to(DEVICE)

    # Initialise optimizers and the gradient scalers for FP16 training
    opt_gen = torch.optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.99))
    opt_critic = torch.optim.Adam(critic.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.99))
    scaler_gen = torch.cuda.amp.GradScaler()
    scaler_critic = torch.cuda.amp.GradScaler()

    # For TensorBoard plotting
    writer = SummaryWriter("logs/ProGAN")

    if LOAD_MODEL:
        tools.load_checkpoint(CHECKPOINT_GEN, gen, opt_gen, scaler_gen, LEARNING_RATE)
        tools.load_checkpoint(CHECKPOINT_CRITIC, critic, opt_critic, scaler_critic, LEARNING_RATE)

    gen.train()
    critic.train()
    tensorboard_step = 0
    # Resolution index of the first stage to train: 0 is 4x4, 1 is 8x8, ...
    step = int(log2(START_TRAIN_AT_IMG_SIZE / 4))

    for num_epochs in PROGRESSIVE_EPOCHS[step:]:
        alpha = 1e-5  # each stage starts with the new block almost fully faded out
        image_size = 4 * (2 ** step)
        loader, dataset = get_loder(image_size)
        print(f"Working on img size : {image_size}")

        for epoch in range(num_epochs):
            print(f"Epoch [{epoch}/{num_epochs}]")
            tensorboard_step, alpha = train_fn(
                critic,
                gen,
                loader,
                dataset,
                step,
                alpha,
                opt_critic,
                opt_gen,
                tensorboard_step,
                writer,
                scaler_gen,
                scaler_critic,
            )

            if SAVE_MODEL:
                tools.save_checkpoint(gen, opt_gen, scaler_gen, filename=CHECKPOINT_GEN)
                tools.save_checkpoint(critic, opt_critic, scaler_critic, filename=CHECKPOINT_CRITIC)

        # Advance to the next resolution once per stage. Incrementing inside the
        # epoch loop pushed `step` ahead of the resolution `loader` was built for.
        step += 1
            