In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torchvision import transforms
from torchvision.utils import save_image, make_grid
from torch.utils.data import DataLoader, Dataset
import time
import os
from PIL import Image
from math import log2

In [2]:
# Dataset class, data-loading helpers, and checkpoint save/load utilities
class CelebAHQDataset(Dataset):
    """Dataset that serves every regular file in ``root_dir`` as an RGB image.

    Args:
        root_dir: Directory that contains the image files.
        transform: Callable applied to each loaded PIL image; a falsy value
            (e.g. ``None``) returns the PIL image unchanged.
    """

    def __init__(self, root_dir, transform):
        self.root_dir = root_dir
        self.transform = transform
        # os.listdir also returns sub-directories (for example the
        # ``.ipynb_checkpoints`` folder Jupyter creates); those can never be
        # opened as images, so keep regular files only. Sorting keeps the
        # index -> file mapping stable between runs.
        self.image_filenames = sorted(
            name for name in os.listdir(root_dir)
            if os.path.isfile(os.path.join(root_dir, name))
        )

    def __len__(self):
        """Return the number of image files found in ``root_dir``."""
        return len(self.image_filenames)

    def __getitem__(self, idx):
        """Load the ``idx``-th image as RGB and apply the transform, if any."""
        img_name = os.path.join(self.root_dir, self.image_filenames[idx])
        # The context manager releases the file handle deterministically;
        # convert() returns an independent, fully loaded image.
        with Image.open(img_name) as raw_image:
            image = raw_image.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image
    
def get_transform(resolution):
    """Build the preprocessing pipeline for real images at ``resolution``.

    The pipeline resizes to ``resolution x resolution``, applies a random
    horizontal flip, converts to a tensor and rescales pixel values from
    [0, 1] to [-1, 1].

    Args:
        resolution: Target height and width in pixels.

    Returns:
        A ``transforms.Compose`` to be applied to PIL images.
    """
    return transforms.Compose([
        transforms.Resize((resolution, resolution)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        # ToTensor yields values in [0, 1], while the generator's fade-in ends
        # in tanh ([-1, 1]) and generate_examples maps [-1, 1] back to [0, 1].
        # Normalising with mean 0.5 / std 0.5 puts real images in that same
        # [-1, 1] range so real and generated samples are comparable.
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

def get_data_loader(resolution):
    """Create the shuffled training loader for one image resolution.

    The batch size is read from ``BATCH_SIZES`` at index ``log2(resolution / 4)``
    (4 -> 0, 8 -> 1, 16 -> 2, ...).

    Args:
        resolution: Side length the images are resized to.

    Returns:
        Tuple ``(loader, dataset)``.
    """
    stage_index = int(log2(resolution / 4))
    per_batch = BATCH_SIZES[stage_index]
    images = CelebAHQDataset(
        root_dir='data/celeba-hq/images/',
        transform=get_transform(resolution=resolution),
    )
    batches = DataLoader(
        dataset=images,
        batch_size=per_batch,
        shuffle=True,
        num_workers=4,
    )
    return batches, images

def save_checkpoint(model, optimizer, filename):
    """Write the model and optimizer state dicts to ``filename``.

    The file holds a dict with the keys ``'state_dict'`` (model) and
    ``'optimizer_state_dict'`` (optimizer).
    """
    print('=> Saving Checkpoint')

    payload = dict(
        state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
    )
    torch.save(payload, filename)

def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore model and optimizer state from a checkpoint file.

    Args:
        checkpoint_file: Path to a file written by ``save_checkpoint``.
        model: Module whose parameters are overwritten in place.
        optimizer: Optimizer whose state is overwritten in place.
        lr: Learning rate forced onto every parameter group after loading, so
            the checkpoint's stored learning rate does not win.

    Raises:
        KeyError: If the checkpoint contains no optimizer state.
    """
    print('=> Loading Checkpoint')

    # Fall back to CPU so the checkpoint can also be opened on hosts without CUDA.
    map_location = 'cuda' if torch.cuda.is_available() else 'cpu'
    checkpoint = torch.load(checkpoint_file, map_location=map_location)
    model.load_state_dict(checkpoint['state_dict'])

    # save_checkpoint stores the optimizer under 'optimizer_state_dict';
    # 'optimizer' is still accepted for checkpoints using the older key.
    if 'optimizer_state_dict' in checkpoint:
        optimizer_state = checkpoint['optimizer_state_dict']
    elif 'optimizer' in checkpoint:
        optimizer_state = checkpoint['optimizer']
    else:
        raise KeyError(
            "checkpoint has neither 'optimizer_state_dict' nor 'optimizer'"
        )
    optimizer.load_state_dict(optimizer_state)

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

In [3]:
# Overall Model architecture implementation 
# https://github.com/aladdinpersson/Machine-Learning-Collection/tree/master/ML/Pytorch/GANs/ProGAN

class WeightedConv2D(nn.Module):
    """Conv2d whose input is rescaled at call time by ``sqrt(gain / fan_in)``.

    Weights are drawn from a standard normal and the bias starts at zero.
    The bias is kept on this wrapper rather than on the inner convolution so
    that it is added after the convolution and is not affected by the scale.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Side length of the square kernel (an int).
        stride: Convolution stride.
        padding: Zero padding on each side.
        gain: Numerator of the scale factor.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, gain=2):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)

        fan_in = in_channels * (kernel_size ** 2)
        self.scale = (gain / fan_in) ** 0.5

        # Move the bias parameter from the convolution onto this module.
        self.bias = self.conv.bias
        self.conv.bias = None

        nn.init.normal_(self.conv.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        """Convolve the rescaled input and add the per-channel bias."""
        convolved = self.conv(x * self.scale)
        channel_bias = self.bias.view(1, self.bias.shape[0], 1, 1)
        return convolved + channel_bias
    
class PixelNorm(nn.Module):
    """Divide each pixel's feature vector by its root-mean-square over channels."""

    def __init__(self):
        super().__init__()
        # Keeps the division finite when every channel of a pixel is zero.
        self.epsilon = 1e-8

    def forward(self, x):
        """Normalise ``x`` along dim 1; the output has the same shape as ``x``."""
        mean_square = (x ** 2).mean(dim=1, keepdim=True)
        return x / (mean_square + self.epsilon).sqrt()
    
class ConvBlock(nn.Module):
    """Two 3x3 weighted convolutions, each followed by LeakyReLU(0.2).

    Args:
        in_channels: Channels entering the first convolution.
        out_channels: Channels produced by both convolutions.
        use_pixelnorm: When true, apply PixelNorm after each activation.
    """

    def __init__(self, in_channels, out_channels, use_pixelnorm=True):
        super().__init__()
        self.use_pn = use_pixelnorm
        self.conv1 = WeightedConv2D(in_channels, out_channels)
        self.conv2 = WeightedConv2D(out_channels, out_channels)
        self.leaky = nn.LeakyReLU(0.2)
        self.pn = PixelNorm()

    def forward(self, x):
        """Run conv -> LeakyReLU -> (optional PixelNorm) twice."""
        for conv in (self.conv1, self.conv2):
            x = self.leaky(conv(x))
            if self.use_pn:
                x = self.pn(x)
        return x

class Generator(nn.Module):
    """Progressively grown generator.

    ``initial`` turns a latent tensor into the first feature map; each entry
    of ``prog_blocks`` is applied after a 2x nearest-neighbour upsample, and
    ``rgb_layers[k]`` converts the features available after ``k`` blocks into
    an image with ``img_channels`` channels.

    Args:
        z_dim: Channel count of the latent input.
        in_channels: Base feature width, scaled per block by ``factors``.
        img_channels: Channel count of the produced image.
    """

    def __init__(self, z_dim, in_channels, img_channels=3):
        super().__init__()
        self.initial = nn.Sequential(
            PixelNorm(),
            nn.ConvTranspose2d(in_channels=z_dim, out_channels=in_channels, kernel_size=4, stride=1, padding=0),
            nn.LeakyReLU(0.2),
            WeightedConv2D(in_channels=in_channels, out_channels=in_channels, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2),
            PixelNorm(),
        )

        self.initial_rgb = WeightedConv2D(in_channels=in_channels, out_channels=img_channels, kernel_size=1, stride=1, padding=0)
        self.prog_blocks = nn.ModuleList([])
        self.rgb_layers = nn.ModuleList([self.initial_rgb])

        # Consecutive entries of ``factors`` give each block's in/out width.
        for level in range(len(factors) - 1):
            width_in = int(in_channels * factors[level])
            width_out = int(in_channels * factors[level + 1])
            self.prog_blocks.append(ConvBlock(width_in, width_out))
            self.rgb_layers.append(
                WeightedConv2D(width_out, img_channels, kernel_size=1, stride=1, padding=0)
            )

    def fade_in(self, alpha, upscaled, generated):
        """Blend the two images with weight ``alpha`` on ``generated``, then apply tanh."""
        blended = alpha * generated + (1 - alpha) * upscaled
        return torch.tanh(blended)

    def forward(self, x, alpha, steps):
        """Generate images using the first ``steps`` progressive blocks.

        With ``steps == 0`` the output of ``initial_rgb`` is returned directly
        (no tanh). Otherwise the image from the newest block is faded in
        against the image made from that block's upsampled input.
        """
        features = self.initial(x)
        if steps == 0:
            return self.initial_rgb(features)

        for level in range(steps):
            upscaled = F.interpolate(features, scale_factor=2, mode='nearest')
            features = self.prog_blocks[level](upscaled)

        # ``upscaled`` is the input of the last block, so it still has the
        # channel width expected by the previous level's to-RGB layer.
        rgb_from_previous = self.rgb_layers[steps - 1](upscaled)
        rgb_from_current = self.rgb_layers[steps](features)

        return self.fade_in(alpha, rgb_from_previous, rgb_from_current)

class Discriminator(nn.Module):
    """Progressively grown critic that mirrors the generator.

    ``prog_blocks`` and ``rgb_layers`` are stored from the highest resolution
    down to the lowest; ``initial_rgb`` (the from-RGB layer of the lowest
    resolution) is the last entry of ``rgb_layers``.

    Args:
        in_channels: Base feature width, scaled per block by ``factors``.
        img_channels: Channel count of the input image.
    """

    def __init__(self, in_channels, img_channels=3):
        super().__init__()
        self.prog_blocks, self.rgb_layers= nn.ModuleList(), nn.ModuleList()
        self.leaky = nn.LeakyReLU(0.2)

        # Walk ``factors`` backwards so index 0 handles the largest resolution.
        for i in range(len(factors) -1, 0, -1):
            conv_in_c = int(in_channels * factors[i])
            conv_out_c = int(in_channels * factors[i -1])
            self.prog_blocks.append(ConvBlock(in_channels=conv_in_c, out_channels=conv_out_c, use_pixelnorm=False))
            self.rgb_layers.append(WeightedConv2D(in_channels=img_channels, out_channels=conv_in_c, kernel_size=1, stride=1, padding=0))

        self.initial_rgb = WeightedConv2D(img_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.rgb_layers.append(self.initial_rgb)
        self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)

        # ``in_channels + 1``: minibatch_std appends one statistics channel.
        self.final_block = nn.Sequential(
            WeightedConv2D(in_channels=in_channels + 1, out_channels= in_channels, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2),
            WeightedConv2D(in_channels=in_channels, out_channels=in_channels, kernel_size=4, stride=1, padding=0),
            nn.LeakyReLU(0.2),
            WeightedConv2D(in_channels=in_channels, out_channels=1, kernel_size=1, stride=1, padding=0),
        )

    def fade_in(self, alpha, downscaled, out):
        """Blend the two feature maps with weight ``alpha`` on ``out``."""
        return alpha * out + (1 - alpha) * downscaled

    def minibatch_std(self, x):
        """Append one channel holding the mean standard deviation over the batch.

        For a batch of a single sample the unbiased standard deviation is
        undefined (NaN) and would poison every later activation, so the
        statistic is set to zero in that case.
        """
        if x.shape[0] > 1:
            mean_std = torch.std(x, dim=0).mean()
        else:
            mean_std = x.new_zeros(())
        batch_statistics = mean_std.repeat(x.shape[0], 1, x.shape[2], x.shape[3])
        return torch.cat([x, batch_statistics], dim=1)

    def forward(self, x, alpha, steps):
        """Score images that were produced at progressive step ``steps``.

        Returns:
            Tensor reshaped to ``(batch, -1)`` holding the critic scores.
        """
        # Index of the block / from-RGB layer that matches the input resolution.
        cur_step = len(self.prog_blocks) - steps
        out = self.leaky(self.rgb_layers[cur_step](x))

        if steps == 0:
            out = self.minibatch_std(out)
            return self.final_block(out).view(out.shape[0], -1)

        # Fade between the pooled image fed to the next from-RGB layer and the
        # pooled output of the newest block.
        downscaled = self.leaky(self.rgb_layers[cur_step + 1](self.avg_pool(x)))
        out = self.avg_pool(self.prog_blocks[cur_step](out))
        out = self.fade_in(alpha, downscaled, out)

        for step in range(cur_step + 1, len(self.prog_blocks)):
            out = self.prog_blocks[step](out)
            out = self.avg_pool(out)

        out = self.minibatch_std(out)
        return self.final_block(out).view(out.shape[0], -1)

In [4]:
# Wasserstein loss with gradient penalty (WGAN-GP)
def wasserstein_loss(d_real, d_fake):
    """Critic loss: mean score on fakes minus mean score on reals."""
    fake_mean = torch.mean(d_fake)
    real_mean = torch.mean(d_real)
    return fake_mean - real_mean

def gradient_panelty(discriminator, real_samples, fake_samples, alpha, step, device):
    """Compute the WGAN-GP gradient penalty on random real/fake interpolations.

    Args:
        discriminator: Callable ``(images, alpha, step) -> scores``.
        real_samples: Batch of real images, batch dimension first.
        fake_samples: Batch of generated images with the same shape; it is
            detached, so no gradient reaches the generator through this term.
        alpha: Fade-in weight forwarded to the discriminator.
        step: Progressive step forwarded to the discriminator.
        device: Device on which the interpolation weights are drawn.

    Returns:
        Scalar tensor ``mean((||grad||_2 - 1) ** 2)`` over the batch.
    """
    batch_size = real_samples.shape[0]
    # One mixing weight per sample, broadcast over the remaining dimensions.
    beta_shape = (batch_size,) + (1,) * (real_samples.dim() - 1)
    beta = torch.rand(beta_shape, device=device)
    interpolates = beta * real_samples + (1 - beta) * fake_samples.detach()
    interpolates.requires_grad_(True)

    d_interpolate = discriminator(interpolates, alpha, step)
    # ones_like follows the critic output's shape, dtype and device, so the
    # call does not depend on the scores being a (batch, 1) float tensor.
    gradients = torch.autograd.grad(
        outputs=d_interpolate,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolate),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradients = gradients.reshape(gradients.size(0), -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()

In [5]:

def trainer(num_epochs, generator, discriminator, dataloader, dataset, optimizer_gen, optimizer_critic, z_dim, step, alpha, scaler_gen, scaler_critic, logging_interval, save_interval):
    """Train generator and critic for ``num_epochs`` at one progressive step.

    Each batch performs one critic update (Wasserstein loss + gradient penalty
    + a small drift term on the real scores) followed by one generator update,
    both under CUDA autocast with their own gradient scaler.

    Args:
        num_epochs: Number of passes over ``dataloader``.
        generator: Generator, called as ``generator(z, alpha, step)``.
        discriminator: Critic, called as ``discriminator(images, alpha, step)``.
        dataloader: Iterable yielding batches of real images.
        dataset: Only its length is used, to pace the ``alpha`` ramp.
        optimizer_gen: Optimizer of the generator.
        optimizer_critic: Optimizer of the critic.
        z_dim: Channel count of the latent input.
        step: Progressive step; the image side length is ``4 * 2 ** step``.
        alpha: Initial fade-in weight; grows per batch and is capped at 1.
        scaler_gen: Gradient scaler for the generator update.
        scaler_critic: Gradient scaler for the critic update.
        logging_interval: Print the losses every this many batches.
        save_interval: Save a 16-image sample grid every this many batches.

    Returns:
        Dict with the keys ``'Generator_losses'`` and ``'Discriminator_losses'``,
        each a list with one float per processed batch across all epochs.
    """
    # Created once, outside the epoch loop: losses of every epoch are kept and
    # the dict exists even when ``num_epochs`` is 0.
    data_dict = {'Generator_losses': [],
                 'Discriminator_losses': []}
    # Derived from ``step`` so the sample file name does not depend on a
    # notebook-level variable being defined by another cell.
    resolution = 4 * 2 ** step

    for epoch in range(num_epochs):
        for batch_idx, real_images in enumerate(dataloader):
            real_images = real_images.to(DEVICE)
            cur_batch_size = real_images.shape[0]

            z = torch.randn(cur_batch_size, z_dim, 1, 1).to(DEVICE)

            # Train the critic
            with torch.cuda.amp.autocast():
                fake_image = generator(z, alpha, step)
                real_output = discriminator(real_images, alpha, step)
                # detach(): the critic update must not backpropagate into the generator.
                fake_output = discriminator(fake_image.detach(), alpha, step)

                # Wasserstein loss with gradient penalty
                gp = gradient_panelty(discriminator=discriminator, real_samples=real_images, fake_samples=fake_image, alpha=alpha, step=step, device=DEVICE)
                w_loss = wasserstein_loss(d_real=real_output, d_fake=fake_output)
                # The last term penalises large real scores (drift term).
                loss_critic = w_loss + LAMBDA_GP * gp + (0.001 * torch.mean(real_output ** 2))

            optimizer_critic.zero_grad()
            scaler_critic.scale(loss_critic).backward()
            scaler_critic.step(optimizer_critic)
            scaler_critic.update()

            # Train the generator
            with torch.cuda.amp.autocast():
                gen_fake = discriminator(fake_image, alpha, step)
                gen_loss = -torch.mean(gen_fake)

            optimizer_gen.zero_grad()
            scaler_gen.scale(gen_loss).backward()
            scaler_gen.step(optimizer_gen)
            scaler_gen.update()

            # alpha reaches 1 after half of PROGRESSIVE_EPOCHS[step] epochs worth of samples.
            alpha += cur_batch_size / (
                (PROGRESSIVE_EPOCHS[step] * 0.5) * len(dataset)
            )
            alpha = min(alpha, 1)

            if batch_idx % logging_interval == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}],Of Batch [{batch_idx}/{len(dataloader)}], "
                      f"Genenerator Loss: {gen_loss.item():.4f}, Discriminator Loss: {loss_critic.item():.4f}")

            if batch_idx % save_interval == 0:
                with torch.no_grad():
                    sample_z = torch.randn(16, z_dim, 1, 1).to(DEVICE)
                    generated_imgs = generator(sample_z, alpha, step).detach().cpu()
                    save_image(generated_imgs, fp='data/generated_img_256/Resolution_'+ str(resolution) + '_epoch_'+ str(epoch) +'.png', nrow=4, normalize=True)

            data_dict['Generator_losses'].append(gen_loss.item())
            data_dict['Discriminator_losses'].append(loss_critic.item())
    return data_dict

In [6]:
# Hyperparameters and construction of models, optimizers and gradient scalers
CUDA_DEVICE_NUM = 0
# Use the selected GPU when CUDA is available, otherwise the CPU.
DEVICE = torch.device(f'cuda:{CUDA_DEVICE_NUM}' if torch.cuda.is_available() else 'cpu')
print('Device:', DEVICE)
LR = 0.001
# Batch size per progressive step; get_data_loader indexes this list with
# log2(resolution / 4), i.e. entry i belongs to resolution 4 * 2 ** i.
BATCH_SIZES = [32, 32, 16, 16, 16, 16, 8, 4]
# Epochs per progressive step; also paces the alpha ramp inside trainer.
PROGRESSIVE_EPOCHS = [10] * len(BATCH_SIZES)
# Resolution of the first training stage (converted to a step by the training cell).
START_TRAIN_AT_IMG_SIZE = 256
Z_DIM = 256
IN_CHANNELS = 256
# Weight of the gradient penalty in the critic loss.
LAMBDA_GP = 10
# The training cell stops before any stage whose resolution exceeds this value.
MAX_RESOLUTION = 512
# Channel multipliers relative to IN_CHANNELS; consecutive entries give the
# in/out width of each progressive block in Generator and Discriminator.
factors = [1, 1, 1, 1, 1 / 2, 1 / 4, 1 / 8, 1 / 16, 1 / 32]


gen = Generator(z_dim=Z_DIM, in_channels=IN_CHANNELS,img_channels=3).to(DEVICE)
critic = Discriminator(in_channels=IN_CHANNELS, img_channels=3).to(DEVICE)

optimizer_gen = optim.Adam(gen.parameters(), lr=LR, betas=(0.0, 0.99))
optimizer_critic = optim.Adam(critic.parameters(), lr=LR, betas=(0.0,0.99))


gen.train()
critic.train()

# One gradient scaler per optimizer, used with autocast inside trainer.
scaler_gen = torch.cuda.amp.GradScaler()
scaler_critic = torch.cuda.amp.GradScaler()

Device: cuda:0
=> Loading Checkpoint


In [None]:
# Progressive training: one stage per resolution, starting at
# START_TRAIN_AT_IMG_SIZE and stopping once MAX_RESOLUTION has been trained.
step = int(log2(START_TRAIN_AT_IMG_SIZE / 4))

# Create the output folders up front so a missing directory cannot make
# save_image / torch.save fail after a long training stage.
for out_dir in ('data/generated_img_256', 'data/pro_gan_models'):
    os.makedirs(out_dir, exist_ok=True)

epoch_counter = 0
start_time = time.time()
for num_epochs in PROGRESSIVE_EPOCHS[step:]:
    current_resolution = 4 * 2 ** step
    if current_resolution > MAX_RESOLUTION:
        break
    # Counted only after the break check, so skipped stages are not reported as trained.
    epoch_counter += num_epochs
    # Every stage starts with the new block almost fully faded out.
    alpha = 1e-5
    loader, dataset = get_data_loader(current_resolution) # 4->0, 8->1, 16->2, 32->3, 64 -> 4
    print(f"Current image size: {current_resolution}")

    # NOTE: ``train`` is overwritten per stage; it holds the losses of the last trained resolution.
    train = trainer(num_epochs=num_epochs, generator=gen, discriminator=critic, dataloader=loader, dataset=dataset, optimizer_gen=optimizer_gen,
                    optimizer_critic= optimizer_critic, z_dim=Z_DIM, step=step, alpha=alpha, scaler_gen=scaler_gen, scaler_critic=scaler_critic,
                    logging_interval=700, save_interval=2500)
    step+=1
    print(f"Time elapsed: {(time.time() - start_time)/60:.2f} minutes")

    save_checkpoint(gen, optimizer_gen, filename=f'data/pro_gan_models/Checkpoint_Gen_Resolution_{current_resolution}_.pth')
    save_checkpoint(critic, optimizer_critic, filename=f'data/pro_gan_models/Checkpoint_Critic_Resolution_{current_resolution}_.pth')
end_time = time.time()

print(f"Training finished. Total time: {(end_time - start_time)/60:.2f} minutes")
print(f"Total epochs trained: {epoch_counter}")

In [None]:
# Plot generator and discriminator losses
generator_losses = train['Generator_losses']
discriminator_losses = train['Discriminator_losses']


def _running_mean(values):
    """Return the cumulative mean of ``values``: entry i averages values[0..i]."""
    means = []
    total = 0.0
    for count, value in enumerate(values, start=1):
        total += value
        means.append(total / count)
    return means


# Running mean over the recorded batches, computed in one pass per curve
# (re-summing a growing slice for every point is quadratic in the batch count).
avg_generator_losses = _running_mean(generator_losses)
avg_discriminator_losses = _running_mean(discriminator_losses)

# Plot the learning curve
plt.figure(figsize=(10, 6))
plt.plot(avg_generator_losses, label='Average Generator Loss', color='blue')
plt.plot(avg_discriminator_losses, label='Average Discriminator Loss', color='orange')
plt.xlabel('Epochs or Batches')
plt.ylabel('Average Loss')
plt.title('Learning Curve')
plt.legend()
plt.grid(True)
plt.show()

In [26]:
def generate_examples(gen, steps, n=50, z_dim=256):
    """Sample ``n`` images from ``gen`` and save them as one image grid.

    The grid is written to ``data/generated_examples.png`` with three images
    per row.

    Args:
        gen: Generator, called as ``gen(latent, alpha, steps)``.
        steps: Progressive step to generate at.
        n: Number of images to sample.
        z_dim: Channel count of the latent input; must match the generator.
    """
    # Remember the current mode so it can be restored even if sampling fails.
    was_training = gen.training
    gen.eval()
    try:
        # Generate random latent vectors
        latent_vectors = torch.randn(n, z_dim, 1, 1, device=DEVICE)

        # alpha = 1.0: use only the newest block's output, no fade-in.
        with torch.no_grad():
            images = gen(latent_vectors, 1.0, steps)

        # Map the tanh output range [-1, 1] to [0, 1] for saving.
        images = (images * 0.5) + 0.5

        save_image(images, 'data/generated_examples.png', nrow=3)
    finally:
        gen.train(was_training)

In [37]:
# Sample a grid at the progressive step that corresponds to START_TRAIN_AT_IMG_SIZE.
generate_examples(gen, steps=int(log2(START_TRAIN_AT_IMG_SIZE / 4)))