# Let's train a GAN that has all of its layers instantiated up front, but only pass x through the layers that are active at the current stage

In [124]:
# All the imports required for this implementation
import torch
import torchvision

import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.nn.utils import spectral_norm
import torch.autograd as autograd

import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder

from torch.utils.data import TensorDataset, ConcatDataset, random_split, DataLoader, Dataset

from torchinfo import summary # Allows us to summarise the params and layers

import numpy as np
import matplotlib.pyplot as plt
import copy
import math
import random

In [2]:
# We can make use of a GPU if you have one on your computer. This works for Nvidia and M series GPU's.
# Preference order: Apple-silicon MPS, then CUDA, then the CPU as a fallback.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    # `torch.backends.cuda.is_built()` only reports that PyTorch was *compiled* with CUDA
    # support; `torch.cuda.is_available()` also checks that a usable GPU and driver exist.
    device = torch.device("cuda")
else:
    # Always bind a torch.device (not a bare string) so `device` has one type on every path.
    device = torch.device("cpu")

# Allocate a tiny tensor on the chosen device and print it. The output confirms
# that we have set the intended device.
x = torch.ones(1, device=device)
print(x)

tensor([1.], device='cuda:0')


In [None]:
# Release the throwaway tensor that was only created to confirm the device.
del x

In [3]:
# function to show an image
def imshow(img):
    """Display a single image tensor with matplotlib.

    Args:
        img: Tensor indexed as (channels, height, width); it is transposed to
            channels-last for plotting. Values are shifted with ``img / 2 + 0.5``
            before display, i.e. the inverse of a (0.5, 0.5) normalisation.
            The tensor may live on any device and may be part of an autograd graph.
    """
    # `Tensor.numpy()` raises for tensors on an accelerator or tensors that
    # require grad, so detach and move to host memory first.
    img = img.detach().cpu()
    img = img / 2 + 0.5  # unnormalize
    npimg = img.numpy()
    # Matplotlib expects the channel axis last.
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

def show_images(images, num_images=16, figsize=(10,10)):
    """Plot the first ``num_images`` images of a batch as a grid with 4 images per row.

    Args:
        images: Batch tensor; it is copied to the CPU and detached before use.
            Values are mapped with ``(images + 1) / 2`` and clamped to [0, 1].
        num_images: Maximum number of images taken from the front of the batch.
        figsize: Size passed to ``plt.figure``.
    """
    # CPU copy without autograd history, rescaled from [-1, 1] to [0, 1] and clipped
    batch = torch.clamp((images.cpu().detach() + 1) / 2, 0, 1)

    # Tile the selected images, then switch to channels-last for matplotlib
    tiled = torchvision.utils.make_grid(batch[:num_images], nrow=4)
    canvas = tiled.numpy().transpose((1, 2, 0))

    plt.figure(figsize=figsize)
    plt.imshow(canvas)
    plt.axis('off')
    plt.show()

In [200]:
# Build the preprocessing pipeline, the dataset and the DataLoader used for training.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((32,32)),  # Resize images to 32x32
    # Per channel (x - 0.5) / 0.5; the plotting helpers undo this with img / 2 + 0.5
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# In the paper the batch size changes as the model scales up the images to save memory;
# here a single fixed batch size is used for every resolution.
# NOTE(review): confirm this fits in memory if higher resolutions are added.
batch_size = 16

# To load the data you must move the images to a folder within the dir they are in
# (ImageFolder reads images from sub-directories of `root`).
dataset = ImageFolder(root='./celeba_hq_256', transform=transform)
# drop_last=True discards a final short batch, so every batch has exactly batch_size images.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True)

In [110]:
# Let's define the equalized learning-rate conv and linear layers, adapted from https://github.com/KimRass/PGGAN/blob/main/model.py#L26
class EqualLRLinear(nn.Module):
    """Fully connected layer whose weight is rescaled at run time.

    The weight is initialised from N(0, 1) and multiplied by the constant
    ``sqrt(c / in_features)`` on every forward pass, rather than being scaled
    once at initialisation. The bias starts at zero and is not scaled.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        c: Numerator of the per-layer scaling constant.
    """
    def __init__(self, in_features, out_features, c=0.2):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.c = c

        # Per-layer normalisation constant, stored as a plain Python float
        self.scale = float(np.sqrt(c / in_features))

        # torch.empty replaces the legacy torch.Tensor(...) constructor; the
        # contents are overwritten by the initialisers just below.
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        self.bias = nn.Parameter(torch.empty(out_features))

        nn.init.normal_(self.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        """Return ``F.linear(x, weight * scale, bias)``."""
        return F.linear(x, weight=self.weight * self.scale, bias=self.bias)

    def extra_repr(self):
        """Describe the layer when the module is printed (otherwise it shows as ``EqualLRLinear()``)."""
        return f'in_features={self.in_features}, out_features={self.out_features}, c={self.c}'

class EqualLRConv2d(nn.Module):
    """2-D convolution whose weight is rescaled at run time.

    The weight is initialised from N(0, 1) and multiplied by
    ``sqrt(c / (in_channels * kH * kW))`` on every forward pass. The bias starts
    at zero and is not scaled.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: ``(kH, kW)`` tuple, or a single int for a square kernel.
        stride: Stride forwarded to ``F.conv2d``.
        padding: Padding forwarded to ``F.conv2d``.
        c: Numerator of the per-layer scaling constant.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, c=0.2):
        super().__init__()

        # Accept an int as shorthand for a square kernel; indexing an int below would raise.
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.c = c

        # The denominator is the fan-in: the number of inputs feeding one output value
        self.scale = (c / (in_channels * kernel_size[0] * kernel_size[1])) ** 0.5

        # torch.empty replaces the legacy torch.Tensor(...) constructor; the
        # contents are overwritten by the initialisers just below.
        self.weight = nn.Parameter(torch.empty(out_channels, in_channels, kernel_size[0], kernel_size[1]))
        self.bias = nn.Parameter(torch.empty(out_channels))

        nn.init.normal_(self.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        """Convolve ``x`` with the rescaled weight."""
        return F.conv2d(x, weight=self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding)

    def extra_repr(self):
        """Describe the layer when the module is printed (otherwise it shows as ``EqualLRConv2d()``)."""
        return (
            f'{self.in_channels}, {self.out_channels}, kernel_size={tuple(self.kernel_size)}, '
            f'stride={self.stride}, padding={self.padding}, c={self.c}'
        )
    

In [111]:
# Lets create some tooling to assess how our GAN is doing 
# First idea compare the stddev and mean and the overall distribution of the Real and Gen pixels (using hist)
def real_and_gen_stats(real_images, gen_images):
    """Compare the per-channel pixel distributions of real and generated batches.

    Draws one overlaid histogram per colour channel and prints the per-channel
    mean and standard deviation of both batches.

    Args:
        real_images: Batch tensor indexed as (batch, channels, height, width).
        gen_images: Batch tensor with the same layout. Only the first three
            channels are plotted.
    """
    def _pixels_by_channel(batch):
        # One row per pixel, one column per channel.
        arr = batch.cpu().detach().numpy()
        # The channel axis must be moved last BEFORE flattening. Reshaping the
        # (batch, channels, H, W) array directly would fill each row with
        # consecutive values of a single channel plane, so the columns would not
        # correspond to channels.
        return arr.transpose(0, 2, 3, 1).reshape(-1, arr.shape[1])

    # Reshape to (num_images * height * width, channels)
    real = _pixels_by_channel(real_images)
    gen = _pixels_by_channel(gen_images)

    # Plot histograms
    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    channel_names = ['Red', 'Green', 'Blue']
    for i in range(3):
        axs[i].hist(real[:, i], bins=50, alpha=0.5, label='Real', density=True)
        axs[i].hist(gen[:, i], bins=50, alpha=0.5, label='Generated', density=True)
        axs[i].set_title(f'{channel_names[i]} Channel Distribution')
        axs[i].legend()
    plt.show()

    # Calculate mean and std per channel (axis 0 runs over pixels)
    real_mean = np.mean(real, axis=0)
    real_std = np.std(real, axis=0)
    gen_mean = np.mean(gen, axis=0)
    gen_std = np.std(gen, axis=0)

    print("Real images - Mean:", real_mean, "Std:", real_std)
    print("Generated images - Mean:", gen_mean, "Std:", gen_std)


In [132]:
# Let's define a function which can generate the conv block
def d_conv_block(in_channels, out_channels, kernel_size1=None, kernel_size2=None):
    """Assemble one discriminator block as an ``nn.Sequential``.

    Two layouts are produced, selected by ``kernel_size2``:

    * ``kernel_size2 is None``: conv(in -> in), LeakyReLU, conv(in -> out),
      LeakyReLU, then a 2x2 average pool that halves height and width. Both
      convolutions use ``kernel_size1`` with padding (1, 1).
    * otherwise: a minibatch-stddev layer, conv(in -> out) with ``kernel_size1``
      and padding (1, 1), LeakyReLU, an unpadded conv(out -> out) with
      ``kernel_size2``, LeakyReLU. ``in_channels`` is what the first convolution
      receives, so it has to count the map added by the stddev layer.

    Args:
        in_channels: Channels seen by the first convolution.
        out_channels: Channels produced by the block.
        kernel_size1: Kernel size tuple of the padded convolution(s).
        kernel_size2: Kernel size tuple of the closing unpadded convolution, or None.

    Returns:
        The assembled ``nn.Sequential``.
    """
    pad = (1, 1)

    if kernel_size2 is None:
        layers = [
            EqualLRConv2d(in_channels, in_channels, kernel_size1, padding=pad),
            nn.LeakyReLU(0.2),
            EqualLRConv2d(in_channels, out_channels, kernel_size1, padding=pad),
            nn.LeakyReLU(0.2),
            # Downsample
            nn.AvgPool2d(kernel_size=(2,2)),
        ]
    else:
        layers = [
            Mbatch_stddev(),
            EqualLRConv2d(in_channels, out_channels, kernel_size1, padding=pad),
            nn.LeakyReLU(0.2),
            EqualLRConv2d(out_channels, out_channels, kernel_size2),
            nn.LeakyReLU(0.2),
        ]

    return nn.Sequential(*layers)

def g_conv_block(in_channels, out_channels, kernel_size1=None, kernel_size2=None, upsample=False):
    """Assemble one generator block: two (conv, LeakyReLU, PixelNorm) triples.

    The ``upsample`` flag only selects the layer configuration; no resizing layer
    is part of the returned block.

    * ``upsample=True``: both convolutions use ``kernel_size1`` with padding
      (1, 1); ``kernel_size2`` is ignored.
    * ``upsample=False``: the first convolution uses ``kernel_size1`` with
      padding (3, 3), the second uses ``kernel_size2`` with padding (1, 1).

    Args:
        in_channels: Channels entering the block.
        out_channels: Channels leaving the block.
        kernel_size1: Kernel size tuple of the first convolution.
        kernel_size2: Kernel size tuple of the second convolution when ``upsample`` is False.
        upsample: Selects between the two configurations above.

    Returns:
        The assembled ``nn.Sequential``.
    """
    # The two variants differ only in the first padding and the second kernel size
    first_padding = (1, 1) if upsample else (3, 3)
    second_kernel = kernel_size1 if upsample else kernel_size2

    return nn.Sequential(
        EqualLRConv2d(in_channels, out_channels, kernel_size1, padding=first_padding),
        nn.LeakyReLU(0.2),
        PixelNorm(),
        EqualLRConv2d(out_channels, out_channels, second_kernel, padding=(1,1)),
        nn.LeakyReLU(0.2),
        PixelNorm(),
    )

def d_output_layer(input_dim):
    """Return the discriminator's scoring layer: an equalized-LR linear map from ``input_dim`` features to 1 output."""
    return EqualLRLinear(input_dim, 1)

def from_to_RGB(in_channels=None, out_channels=None):
    """Build a 1x1 equalized-LR convolution followed by LeakyReLU(0.2).

    The same builder serves both directions: image channels -> feature channels
    for the discriminator and feature channels -> image channels for the generator.

    Args:
        in_channels: Channels of the incoming tensor.
        out_channels: Channels of the produced tensor.

    Returns:
        An ``nn.Sequential`` of the convolution and its activation.
    """
    layers = [
        EqualLRConv2d(in_channels, out_channels, kernel_size=(1,1)),
        nn.LeakyReLU(0.2),
    ]
    return nn.Sequential(*layers)

def upsample(x):
    """Build a learnable 2x upsampling layer.

    Args:
        x: Number of feature-map channels. The layer keeps the channel count
            unchanged. (The parameter keeps its original name for compatibility.)

    Returns:
        An ``nn.ConvTranspose2d`` with kernel size 2 and stride 2.
    """
    # The argument is the channel count; the body previously read an undefined
    # name `channels` and raised NameError.
    channels = x
    return nn.ConvTranspose2d(in_channels=channels, out_channels=channels, kernel_size=2, stride=2)

class Mbatch_stddev(nn.Module):
    """Minibatch standard-deviation layer.

    Appends one constant feature map to the input so that later layers can see
    how much the samples in the batch vary.
    """
    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x`` with one extra channel holding the batch-wide mean stddev.

        Args:
            x: Tensor indexed as (batch, channels, height, width).

        Returns:
            Tensor of shape (batch, channels + 1, height, width).
        """
        b, _, h, w = x.shape
        # "We compute the standard deviation for each feature in each spatial location over the minibatch.
        # We then average these estimates over all features and spatial locations to arrive at a single value.
        # We replicate the value and concatenate it to all spatial locations and over the minibatch,
        # yielding one additional (constant) feature map."
        if b > 1:
            per_location_std = x.std(dim=0, keepdim=True)
        else:
            # The sample standard deviation of a single element is NaN, which
            # would poison every later layer; one sample has no spread, so use 0.
            per_location_std = torch.zeros_like(x)
        feat_map = per_location_std.mean(dim=(1, 2, 3), keepdim=True)
        return torch.cat([x, feat_map.repeat(b, 1, h, w)], dim=1)

class PixelNorm(nn.Module):
    """Normalise every pixel's feature vector to unit root-mean-square across channels.

    Args:
        epsilon: Small constant added under the square root to avoid dividing by zero.
    """
    def __init__(self, epsilon=1e-8):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, x):
        """Divide ``x`` by ``sqrt(mean over dim 1 of x**2 + epsilon)``."""
        # Mean of squares over the channel axis, kept broadcastable against x
        mean_square = x.pow(2).mean(dim=1, keepdim=True)
        return x / (mean_square + self.epsilon).sqrt()

        

In [None]:
class Discriminator_32(nn.Module):
    """Progressively grown critic for images of up to 32x32 pixels.

    Blocks are numbered from the output end: ``block1`` is the final block that
    reduces a 4x4 map to 1x1, and ``block4`` is the first block applied to a
    32x32 image. Stage ``layer_num`` runs blocks ``layer_num .. 1`` and therefore
    works on images of side ``4 * 2 ** (layer_num - 1)``.

    All sub-modules are always built; ``layer_num`` decides how many are used.
    """
    def __init__(self):
        super().__init__()

        # Only the four 512-channel stages needed for 32x32 are built. Larger
        # resolutions would need further, narrower blocks in front of block4.
        self.block4 = d_conv_block(in_channels=512, out_channels=512, kernel_size1=(3,3))
        self.block3 = d_conv_block(in_channels=512, out_channels=512, kernel_size1=(3,3))
        self.block2 = d_conv_block(in_channels=512, out_channels=512, kernel_size1=(3,3))
        # 513 input channels: 512 features plus the minibatch-stddev map the block adds itself
        self.block1 = d_conv_block(in_channels=513, out_channels=512, kernel_size1=(3,3), kernel_size2=(4,4))

        # Not used by the blocks: it halves the raw image for the fade-in shortcut
        self.down = nn.AvgPool2d(kernel_size=(2,2), stride=2)

        self.from_rgb4 = from_to_RGB(in_channels=3, out_channels=512)
        self.from_rgb3 = from_to_RGB(in_channels=3, out_channels=512)
        self.from_rgb2 = from_to_RGB(in_channels=3, out_channels=512)
        self.from_rgb1 = from_to_RGB(in_channels=3, out_channels=512)

        # block1 ends in a 512 x 1 x 1 map, so the scoring layer always sees 512
        # features. It is created once here so that its weights are registered,
        # handed to the optimiser and trained. Building it inside forward() would
        # produce a new, randomly initialised layer on every call.
        self.FC1 = d_output_layer(512)

        # Plain lists for indexing by stage. The modules are already registered
        # through the attributes above, so parameter names stay block1.*, from_rgb1.* ...
        self.blocks = [
            self.block1, self.block2, self.block3, self.block4,
        ]
        self.from_rgbs = [
            self.from_rgb1, self.from_rgb2, self.from_rgb3, self.from_rgb4,
        ]

        # Keep the original behaviour of placing the network on the notebook-level device
        self.to(device)

    def forward(self, x, alpha=1, layer_num=0):
        """Score a batch of images at the given growth stage.

        Args:
            x: Image batch indexed as (batch, 3, height, width). If its size does
                not match the stage resolution ``4 * 2 ** (layer_num - 1)`` it is
                resized to it with adaptive average pooling.
            alpha: Fade-in weight of the newest block. Below 1 (and for
                ``layer_num > 1``) the newest block's output is blended with a
                shortcut that downsamples the image and enters one stage later.
            layer_num: Number of blocks to run, from 1 to ``len(self.blocks)``.

        Returns:
            Tensor of shape (batch, 1) with one unbounded score per image.

        Raises:
            ValueError: If ``layer_num`` is outside ``1 .. len(self.blocks)``
                (this includes the signature default of 0).
        """
        if not 1 <= layer_num <= len(self.blocks):
            raise ValueError(
                f"layer_num must be between 1 and {len(self.blocks)}, got {layer_num}"
            )

        # Each non-final block halves the resolution and block1 expects 4x4, so the
        # input has to enter at exactly this size for the flattened output to be 512.
        resolution = 4 * 2 ** (layer_num - 1)
        if tuple(x.shape[-2:]) != (resolution, resolution):
            x = F.adaptive_avg_pool2d(x, resolution)

        in_x = x
        x = self.from_rgbs[layer_num-1](x)

        for i in reversed(range(layer_num)):
            x = self.blocks[i](x)
            if i == layer_num-1 and alpha < 1 and layer_num > 1:
                # Fade in the new layer: blend its output with the previous
                # stage's from_rgb applied to the downsampled image
                downscaled = self.down(in_x)
                from_rgb = self.from_rgbs[layer_num-2](downscaled)
                x = (alpha * x) + ((1 - alpha) * from_rgb)

        # Last FC layer on the flattened features
        x = x.flatten(1)
        return self.FC1(x)

d_32 = Discriminator_32().to(device)

In [None]:
class Generator_32(nn.Module):
    """Progressively grown generator producing images of up to 32x32 pixels.

    ``block1`` is built with a 4x4 first kernel and padding 3, and blocks 2-4 are
    the 3x3 variants. Between consecutive blocks the features are upsampled 2x
    with nearest-neighbour interpolation. Every stage has its own ``to_rgb`` head.

    All sub-modules are always built; ``layer_num`` decides how many are used.
    """
    def __init__(self):
        super().__init__()

        stages = (1, 2, 3, 4)

        # First block, then the three blocks that follow an upsampling step.
        # Only the stages needed for 32x32 output are built.
        self.block1 = g_conv_block(in_channels=512, out_channels=512, kernel_size1=(4,4), kernel_size2=(3,3)).to(device)
        for idx in stages[1:]:
            setattr(
                self,
                f'block{idx}',
                g_conv_block(in_channels=512, out_channels=512, kernel_size1=(3,3), kernel_size2=(3,3), upsample=True).to(device),
            )

        # One feature -> RGB head per stage
        for idx in stages:
            setattr(self, f'to_rgb{idx}', from_to_RGB(in_channels=512, out_channels=3).to(device))

        self.tanh = nn.Tanh()

        # Plain lists for indexing by stage; the modules are registered via the attributes above
        self.blocks = [getattr(self, f'block{idx}') for idx in stages]
        self.to_rgbs = [getattr(self, f'to_rgb{idx}') for idx in stages]

    def forward(self, x, alpha=1, layer_num=0):
        """Generate images at the given growth stage.

        Args:
            x: Latent batch; the calling cells pass shape (batch, 512, 1, 1).
            alpha: Fade-in weight of the newest block. Below 1 (and for
                ``layer_num > 1``) the output blends the newest head with the
                previous stage's head applied to the upsampled previous features.
            layer_num: Number of blocks to run.

        Returns:
            Image batch after ``tanh``.
        """
        for stage in range(layer_num):
            x = self.blocks[stage](x)
            if stage < layer_num - 1:
                # Double the resolution before every block except after the last one
                x = F.interpolate(x, scale_factor=2, mode="nearest")
                if stage == layer_num - 2:
                    # Upsampled features of the previous stage, kept for the fade-in
                    res_x = x.clone()

        out = self.to_rgbs[layer_num-1](x)

        if layer_num > 1 and alpha < 1:
            # Interpolate between the previous stage's image and the new one
            prev_rgb = self.to_rgbs[layer_num-2](res_x)
            out = (1 - alpha) * prev_rgb + alpha * out

        return self.tanh(out)

g_32 = Generator_32().to(device)

In [None]:
# Smoke test: push a random batch of two 32x32 RGB images through the discriminator
# and print the scores. Uncomment one of the calls below to try another stage.
d_in = torch.randn(2, 3, 32, 32).to(device)

# For 4x4 resolution (first stage)
# NOTE(review): d_in is 32x32 while this stage is described as 4x4 - confirm the intended input size.
out = d_32(d_in, alpha=0.5, layer_num=1)

# For 8x8 resolution with alpha=0.5
#out = d_32(d_in, alpha=0.5, layer_num=2)

# For 16x16 resolution
#out = d_32(d_in, alpha=0.5, layer_num=3)

# For  32x32 resolution
#out = d_32(d_in, alpha=0.5, layer_num=4)

# For 256x256 res
# NOTE(review): Discriminator_32 only defines four blocks, so layer_num=7 is out of range for it.
#out = d_32(d_in, alpha=0.5, layer_num=7)

print(out)

In [None]:
# Smoke test: run one random latent through the generator and print the output shape.
# Uncomment one of the calls below to try another stage.
g_in = torch.randn((1, 512, 1, 1), device=device)

# For 4x4 resolution (first stage)
#out = g_32(g_in, alpha=0, layer_num=1)

# For 8x8 resolution with alpha=0.5
#out = g_32(g_in, alpha=0.5, layer_num=2)

# For 16x16 resolution
#out = g_32(g_in, alpha=0.5, layer_num=3)

# For full 32x32 resolution
out = g_32(g_in, alpha=0.5, layer_num=4)

# NOTE(review): Generator_32 only defines four blocks, so layer_num=7 is out of range for it.
#out = g_32(g_in, alpha=0.5, layer_num=7)

print(out.shape)

In [8]:
class WGAN_GP_Loss(nn.Module):
    """Wasserstein critic and generator losses with gradient penalty and drift penalty.

    Args:
        lambda_gp: Weight of the gradient penalty in the critic loss.
        epsilon_drift: Weight of the penalty on squared critic scores for real images.
    """
    def __init__(self, lambda_gp=10, epsilon_drift=0.001):
        super().__init__()
        self.lambda_gp = lambda_gp
        self.epsilon_drift = epsilon_drift

    def compute_gradient_penalty(self, discriminator, real_samples, fake_samples, alpha, layer_num):
        """Penalise critic gradient norms that differ from 1 on random real/fake mixes.

        Args:
            discriminator: Callable invoked as ``discriminator(images, alpha, layer_num)``.
            real_samples: Real image batch, 4-D.
            fake_samples: Generated image batch of the same shape.
            alpha: Fade-in weight forwarded to the critic.
            layer_num: Growth stage forwarded to the critic.

        Returns:
            Scalar tensor: mean over the batch of ``(||grad||_2 - 1) ** 2``.
        """
        batch_size = real_samples.size(0)
        # One mixing coefficient per sample, broadcast over channels and pixels
        epsilon = torch.rand(batch_size, 1, 1, 1).to(real_samples.device)
        interpolates = (epsilon * real_samples + ((1 - epsilon) * fake_samples)).requires_grad_(True)
        d_interpolates = discriminator(interpolates, alpha, layer_num)
        # Match whatever shape the critic returns instead of assuming (batch, 1)
        grad_outputs = torch.ones_like(d_interpolates)
        gradients = autograd.grad(
            outputs=d_interpolates,
            inputs=interpolates,
            grad_outputs=grad_outputs,
            create_graph=True,  # the penalty itself is differentiated during backward
            retain_graph=True,
            only_inputs=True,
        )[0]
        # reshape (not view) so a non-contiguous gradient tensor is also accepted
        gradients = gradients.reshape(batch_size, -1)
        gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
        return gradient_penalty

    def forward(self, discriminator, real_imgs, fake_imgs, alpha, layer_num):
        """Compute both losses from one pair of real and generated batches.

        Args:
            discriminator: Callable invoked as ``discriminator(images, alpha, layer_num)``.
            real_imgs: Real image batch.
            fake_imgs: Generated image batch of the same shape.
            alpha: Fade-in weight forwarded to the critic.
            layer_num: Growth stage forwarded to the critic.

        Returns:
            ``(d_loss, g_loss)``. ``d_loss`` is
            ``-mean(D(real)) + mean(D(fake)) + lambda_gp * GP + epsilon_drift * mean(D(real) ** 2)``
            and ``g_loss`` is ``-mean(D(fake))``.
        """
        real_validity = discriminator(real_imgs, alpha, layer_num)
        fake_validity = discriminator(fake_imgs, alpha, layer_num)

        gradient_penalty = self.compute_gradient_penalty(discriminator, real_imgs, fake_imgs, alpha, layer_num)

        # Add drift penalty
        drift_penalty = self.epsilon_drift * torch.mean(real_validity**2)

        d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + self.lambda_gp * gradient_penalty + drift_penalty
        g_loss = -torch.mean(fake_validity)

        return d_loss, g_loss

In [9]:
def weights_init(m):
    """Re-initialise a standard conv/linear module; intended for ``Module.apply``.

    ``nn.Conv2d`` and ``nn.Linear`` instances get weights drawn from N(0, 0.02)
    and a zeroed bias (when they have one). Every other module type is left
    untouched, including the custom equalized-LR layers, which derive from
    ``nn.Module`` directly.

    Args:
        m: The module visited by ``apply``.
    """
    if not isinstance(m, (nn.Conv2d, nn.Linear)):
        return
    nn.init.normal_(m.weight, mean=0.0, std=0.02)
    if m.bias is None:
        return
    nn.init.zeros_(m.bias)

In [None]:
class SimpleGenerator(nn.Module):
    """Minimal stand-in generator: a single linear map from 512 latent values to a 3x4x4 image."""
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(512, 3*4*4)

    def forward(self, x, alpha=1, layer_num=1):
        """Flatten ``x`` per sample, project it and reshape to (batch, 3, 4, 4).

        ``alpha`` and ``layer_num`` are accepted for call compatibility and ignored.
        """
        flat = x.view(x.size(0), -1)
        return self.fc(flat).view(-1, 3, 4, 4)

# NOTE(review): this rebinds g_32 to the debugging stand-in; the training setup cell
# further down rebinds it to Generator_32 again, so run order matters here.
g_32 = SimpleGenerator().to(device)

In [201]:
# Set up the loss, fresh copies of both networks and their optimisers for the training loop.
# The loss is WGAN-GP, which the actual paper uses: https://arxiv.org/abs/1704.00028
# (an initial experiment used BCEWithLogitsLoss, left commented out below).
#criterion = nn.BCEWithLogitsLoss()
criterion = WGAN_GP_Loss()

d_32 = Discriminator_32() 
# weights_init only re-initialises nn.Conv2d / nn.Linear modules
d_32.apply(weights_init)
d_32 = d_32.to(device)

g_32 = Generator_32() 
#g_32 = SimpleGenerator()
g_32.apply(weights_init)
g_32 = g_32.to(device)


# Initialise two optimisers, one per network
optim_D = torch.optim.Adam(d_32.parameters(), lr=0.0001, betas=(0, 0.99), eps=10**(-8))
optim_G = torch.optim.Adam(g_32.parameters(), lr=0.0001, betas=(0, 0.99), eps=10**(-8))

# Despite the name this is the full shape of one noise batch, not a single dimension
latent_dim = (batch_size, 512, 1, 1)

In [202]:
def check_zero_biases(model):
    """Print whether every bias parameter of ``model`` is exactly zero.

    A parameter counts as a bias when its name contains ``'bias'``. Each
    offending parameter is reported with the sum of its entries, followed by a
    one-line verdict.

    Args:
        model: Module whose ``named_parameters()`` are inspected.
    """
    zero_bias = True
    for name, param in model.named_parameters():
        if 'bias' not in name:
            continue
        # Test the elements themselves: a plain sum can cancel to zero
        # (e.g. [1, -1]) and hide a non-zero bias.
        if (param.data != 0).any():
            print(f"Non-zero bias found in {name}: sum = {param.data.sum().item()}")
            zero_bias = False

    if zero_bias:
        print("All biases are zero.")
    else:
        print("Some biases are non-zero.")

def compute_gradient_norm(model):
    """Return the global L2 norm of all gradients currently stored on ``model``.

    Parameters without a gradient are skipped. The result is the square root of
    the sum of squared per-parameter gradient norms.

    Args:
        model: Module whose parameters are inspected.

    Returns:
        The norm as a Python float (0.0 when no parameter has a gradient).
    """
    squared_norms = [
        p.grad.data.norm(2).item() ** 2
        for p in model.parameters()
        if p.grad is not None
    ]
    return sum(squared_norms) ** 0.5

def check_gradients(model):
    """Print the model, then one gradient status line per named parameter.

    Each parameter is reported as having no gradient, an all-zero gradient, or a
    gradient together with its mean absolute value.

    Args:
        model: Module to inspect.
    """
    print(model)
    for name, param in model.named_parameters():
        grad = param.grad
        if grad is None:
            message = f"No gradient for {name}"
        elif grad.abs().sum() == 0:
            message = f"Zero gradient for {name}"
        else:
            message = f"Gradient present for {name}: {grad.abs().mean().item()}"
        print(message)
            
# Usage: confirm that both freshly built networks start with all-zero biases
check_zero_biases(g_32)  # Check generator
check_zero_biases(d_32)  # Check discriminator

All biases are zero.
All biases are zero.


In [203]:
def hypersphere(z, radius=1, eps=1e-8):
    """Project each row of ``z`` onto the sphere of the given radius.

    Args:
        z: Tensor whose dimension 1 holds the vector components.
        radius: Target L2 norm along dimension 1.
        eps: Lower bound applied to the norm before dividing. Vectors with a
            norm of at least ``eps`` are scaled exactly as without the bound.

    Returns:
        Tensor of the same shape as ``z``. An all-zero vector stays zero instead
        of becoming NaN.
    """
    norm = z.norm(p=2, dim=1, keepdim=True).clamp_min(eps)
    return z * radius / norm

In [None]:
# Progressive training schedule. For every growth stage (`layer` = number of active blocks):
#   1. "grow" phase: 50 epochs during which alpha is raised by 1/50 after each epoch,
#   2. "train" phase: 50 more epochs with alpha left at the value reached in phase 1.
# Each iteration does one critic update followed by one generator update.
for layer in range(1,5):
    print(f'Training layer: {layer}')
    # alpha is the fade-in weight handed to both networks; it restarts at 0 for every stage
    alpha = 0
    
    for epoch_grow in range(50):
        for i, data in enumerate(dataloader):
            real_images, _ = data
            real_images = real_images.to(device)
            #n_critic = 5
            #for n in range(n_critic):
            #    real_images = sample_batch(dataloader.dataset, dataloader.batch_size).to(device)
                
            noise_tensor = torch.randn(latent_dim, device=device)

            # These fakes only feed the critic update, so no generator graph is needed
            with torch.no_grad():
                gen_images = g_32(noise_tensor, alpha=alpha, layer_num=layer)

            # Resize the real batch to the spatial size the generator currently outputs
            real_images = F.interpolate(real_images, size=gen_images.shape[2:], mode='area')

            #gen_labels = torch.zeros((batch_size, 1)).to(device)
            #real_labels = torch.ones((batch_size, 1)).to(device)

            #combined_images = torch.cat((real_images, gen_images))
            #combined_labels = torch.cat((real_labels, gen_labels))

            # First update the D model
            d_32.zero_grad(set_to_none=False)
            #d_outputs_combined = d_32(combined_images, alpha=alpha, layer_num=layer)
            #loss_d = criterion(d_outputs_combined, combined_labels)
            # criterion returns (d_loss, g_loss); only the critic loss is used here
            loss_d, _ = criterion(d_32, real_images, gen_images, alpha, layer)
            loss_d.backward()
            optim_D.step()
            
            # Gradients are still stored after the step, so the norm can be read here
            d_grad_norm = compute_gradient_norm(d_32)
            
            # Generate new images for updating G
            noise_tensor = torch.randn(latent_dim, device=device)
            
            # Next update the G model, 
            g_32.zero_grad(set_to_none=False)
            gen_images = g_32(noise_tensor, alpha=alpha, layer_num=layer)  # This needs to be on
            #d_outputs_generated = d_32(gen_images, alpha=alpha, layer_num=layer)
            #loss_g = criterion(d_outputs_generated, real_labels)
            # NOTE(review): criterion also evaluates the critic loss and gradient penalty here
            # although only g_loss is used - looks like avoidable work; confirm before changing.
            _, loss_g = criterion(d_32, real_images, gen_images, alpha, layer)
            #print(f'Loss_D: {loss_d.item()}, Loss_G: {loss_g.item()}')
            
            #print(f"G loss before backward: {loss_g.item()}")            
            loss_g.backward()
            #print(f"G loss after backward: {loss_g.item()}")
        
            #check_gradients(g_32)
            optim_G.step()
            
            g_grad_norm = compute_gradient_norm(g_32)
            
        #imshow(torchvision.utils.make_grid(gen_images.cpu()))
        
        # Every fifth epoch, report on the last batch processed in that epoch
        if epoch_grow % 5 == 0:
            print(f'Epoch: {epoch_grow} Outputting statistics: ')
            real_and_gen_stats(real_images, gen_images)
            show_images(gen_images)
            print(f'Layer {layer}: Loss_D: {loss_d.item()}, Loss_G: {loss_g.item()}')
            print(f'D Grad Norm : {d_grad_norm:.4f}, G Grad Norm: {g_grad_norm:.4f}')
        
        # Advance the fade-in once per epoch; rounding keeps float drift out of alpha
        alpha += 1/50
        alpha = round(alpha, 2)
        
    print(f'Alpha after grow: {alpha}')
    # Train phase: same update steps as above, alpha is no longer changed
    for epoch_train in range(50):
        for i, data in enumerate(dataloader):
            real_images, _ = data
            real_images = real_images.to(device)

            noise_tensor = torch.randn(latent_dim, device=device)
            
            # Fakes for the critic update only, no generator graph needed
            with torch.no_grad():
                gen_images = g_32(noise_tensor, alpha=alpha, layer_num=layer)

            # Resize the real batch to the spatial size the generator currently outputs
            real_images = F.interpolate(real_images, size=gen_images.shape[2:], mode='area')

            #gen_labels = torch.zeros((batch_size, 1)).to(device)
            #real_labels = torch.ones((batch_size, 1)).to(device)

            #combined_images = torch.cat((real_images, gen_images))
            #combined_labels = torch.cat((real_labels, gen_labels))
            
            # First update the D model
            d_32.zero_grad()   
            #d_outputs_combined = d_32(combined_images, alpha=alpha, layer_num=layer)
            #loss_d = criterion(d_outputs_combined, combined_labels)
            loss_d, _ = criterion(d_32, real_images, gen_images, alpha, layer)
            loss_d.backward()
            optim_D.step()
            
            # Generate new images for updating G
            noise_tensor = torch.randn(latent_dim, device=device)

            # Next update the G model, 
            g_32.zero_grad()
            gen_images = g_32(noise_tensor, alpha=alpha, layer_num=layer)
            #d_outputs_generated = d_32(gen_images, alpha=alpha, layer_num=layer)
            #loss_g = criterion(d_outputs_generated, real_labels)
            _, loss_g = criterion(d_32, real_images, gen_images, alpha, layer)
            loss_g.backward()
            optim_G.step()
            
    
    # End-of-stage summary based on the last batch of the stage
    print(f'FINAL | Layer {layer}: Loss_D: {loss_d.item()}, Loss_G: {loss_g.item()}')
    #imshow(torchvision.utils.make_grid(real_images.cpu()))
    #imshow(torchvision.utils.make_grid(gen_images.cpu()))
    show_images(real_images)
    show_images(gen_images)
    

Training layer: 1


In [None]:
def test_generator_grad_flow(g_32):
    noise = torch.randn(1, 512, 1, 1, requires_grad=True).to(device)
    output = g_32(noise, alpha=1, layer_num=4)
    loss = output.sum()
    loss.backward()
    for name, param in g_32.named_parameters():
        print(f"{name}: {param.grad is not None}")

# Run the gradient-flow check on the current generator
test_generator_grad_flow(g_32)

In [None]:
# Sanity check: the generator optimiser should hold as many parameter tensors as the
# generator exposes. A mismatch means the optimiser was built from a different model object.
print(f"Optimizer params: {len(optim_G.param_groups[0]['params'])}")
print(f"Generator params: {len(list(g_32.parameters()))}")

In [None]:
# Debug cell: inspect one generator output.
# NOTE(review): relies on noise_tensor, alpha and layer left behind by the training-loop
# cell, so it only works after that cell has run.
gen_images = g_32(noise_tensor, alpha=alpha, layer_num=layer)
print(f"Gen images shape: {gen_images.shape}")
print(f"Gen images requires grad: {gen_images.requires_grad}")
print(f"Gen images contains NaN: {torch.isnan(gen_images).any()}")

In [None]:
# List every parameter of the generator, then of the discriminator, that has no
# gradient or an all-zero gradient. Parameters with a non-zero gradient are not printed.
for header, model in ((("g_32 grads",), g_32), (("\n", "d_32 grads"), d_32)):
    print(*header)
    for name, param in model.named_parameters():
        if param.grad is None:
            print(f"No gradient for {name}")
        elif param.grad.abs().sum() == 0:
            print(f"Zero gradient for {name}")


In [175]:
# Display the discriminator's module tree (a bare last expression is rendered as cell output)
d_32

Discriminator_32(
  (block4): Sequential(
    (0): EqualLRConv2d()
    (1): LeakyReLU(negative_slope=0.2)
    (2): EqualLRConv2d()
    (3): LeakyReLU(negative_slope=0.2)
    (4): AvgPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0)
  )
  (block3): Sequential(
    (0): EqualLRConv2d()
    (1): LeakyReLU(negative_slope=0.2)
    (2): EqualLRConv2d()
    (3): LeakyReLU(negative_slope=0.2)
    (4): AvgPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0)
  )
  (block2): Sequential(
    (0): EqualLRConv2d()
    (1): LeakyReLU(negative_slope=0.2)
    (2): EqualLRConv2d()
    (3): LeakyReLU(negative_slope=0.2)
    (4): AvgPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0)
  )
  (block1): Sequential(
    (0): Mbatch_stddev()
    (1): EqualLRConv2d()
    (2): LeakyReLU(negative_slope=0.2)
    (3): EqualLRConv2d()
    (4): LeakyReLU(negative_slope=0.2)
  )
  (down): AvgPool2d(kernel_size=(2, 2), stride=2, padding=0)
  (from_rgb4): Sequential(
    (0): EqualLRConv2d()
    (1): LeakyReLU(negati

Let's write some checks to make sure training is going smoothly; these come from https://github.com/soumith/ganhacks

![image.png](attachment:image.png)