# Imports

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.amp import autocast, GradScaler

from PIL import Image
import matplotlib.pyplot as plt
import os, datetime, shutil

# Device

In [None]:
# --- Check Device Usage ---

# Prefer the GPU whenever PyTorch can see one, otherwise fall back to the CPU
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Plain boolean flag reused by later cells (worker count, etc.)
isCUDA = device == torch.device('cuda')
print("Device:", device)

# --- Device Configuration ---

# Allocator setting must be in the environment before the first CUDA allocation below
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# A tiny tensor moved to the GPU forces the CUDA context to initialize now;
# on CPU a placeholder 0 is stored instead
if isCUDA:
    example = torch.randn(1).cuda()
else:
    example = 0

# Constants

In [None]:
# Folder that contains the training images
DATA_DIR = "/kaggle/input/animefacedataset/images"

# Fixed latent vectors, drawn once, so the validation grid shows the same
# samples after every epoch: a 3 x 5 grid of [1, 100] noise batches
valid_plot_noise = torch.randn((3, 5, 1, 100), dtype=torch.float32)
valid_plot_noise = valid_plot_noise.to(device)

# Model weights save

In [None]:
# Folder that receives the per-epoch weight files of both the generator and the critic
model_folder_path = "./model_save_files"

# Wipe the folder left over from a previous run so old checkpoints do not mix
# with new ones; ignore_errors=True makes a missing folder a no-op
shutil.rmtree(model_folder_path, ignore_errors=True)

# Re-create the (now empty) folder
os.makedirs(model_folder_path, exist_ok=True)

# Hyper parameters

In [None]:
# Optimisation
learning_rate = 0.0002
num_epochs = 40
betas = (0.0, 0.9) # Adam betas commonly used for WGAN-GP (no first-moment momentum)

# Data loading
num_workers = 4 if isCUDA else 0
# DataLoader raises ValueError for persistent_workers=True with num_workers=0,
# so only keep workers alive when there are workers at all
persistent_workers = num_workers > 0
batch_size = 128

# Size of the noise vector fed to the generator
latent_dim = 100 # Higher value = Higher diversity of images

# Utilities

In [None]:
# Custom dataset
class AnimeDataset(Dataset):
    '''
    Dataset over every entry of a single image folder.

    Each item is an (image, label) pair: the image is opened as RGB and
    passed through the optional transform, the label is a constant
    placeholder tensor [1.0].
    '''
    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        # File names only; they are joined with data_dir on access
        self.images = os.listdir(data_dir)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        file_name = self.images[idx]
        picture = Image.open(os.path.join(self.data_dir, file_name)).convert("RGB")
        if self.transform:
            picture = self.transform(picture)

        # Keep the usual (input, label) item shape; the WGAN training never reads the label
        placeholder_label = torch.tensor([1.0], dtype=torch.float32)
        return picture, placeholder_label

class PixelNorm(nn.Module):
    '''
    Divides the input by the root-mean-square of its values along dim 1
    (the channel axis of a [B, C, H, W] tensor).

    The statistic is computed per sample and per position, so unlike
    BatchNorm the result never depends on other images in the batch.
    '''
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Mean of squares over dim 1, kept as a size-1 axis for broadcasting
        mean_square = x.pow(2).mean(dim=1, keepdim=True)
        # Small epsilon avoids a division by zero for all-zero inputs
        return x / (mean_square + 1e-8).sqrt()

# Current time (in H:M:S format)
def now():
    '''
    Returns the current time as an 'HH:MM:SS' string in the fixed
    UTC+05:30 offset (IST), independent of the machine's own timezone.
    '''
    ist_offset = datetime.timedelta(hours=5, minutes=30)
    local_now = datetime.datetime.now()
    # astimezone() treats the naive value as system-local time and converts it
    ist_now = local_now.astimezone(datetime.timezone(ist_offset))
    return ist_now.strftime('%H:%M:%S')


# Plotting for visual validation
def valid_plot(model, epoch=None, rows=3, cols=5, latent_dim=100):
    '''
    Shows a rows x cols grid of images generated from the fixed noise in
    valid_plot_noise, so grids from different epochs are comparable.

    -> model: callable mapping a [1, latent] noise batch to a [1, 3, H, W]
       image batch with values in [-1, 1]
    -> epoch: shown in the figure title when given (0 is a valid value)
    -> rows, cols: grid size; limited by the first two dims of valid_plot_noise
    -> latent_dim: unused, kept so existing calls keep working
    '''
    with torch.no_grad():

        # squeeze=False keeps axs two-dimensional even for a single row or
        # column, so axs[i][j] below works for every grid size
        _, axs = plt.subplots(rows, cols, figsize=((10 / 3) * cols, rows*3), squeeze=False)

        # Explicit None check so that epoch 0 still gets a title
        if epoch is not None:
            plt.suptitle(f"Epoch: {epoch}")

        for i in range(rows):
            for j in range(cols):

                # Generate the image and move channels last for matplotlib
                noise = valid_plot_noise[i][j]
                image = model(noise)[0].permute(1, 2, 0).detach().cpu().numpy()
                image = (image + 1) / 2 # Denormalize from [-1, 1] to [0, 1]

                # Draw the image without axis ticks
                ax = axs[i][j]
                ax.imshow(image)
                ax.get_xaxis().set_visible(False)
                ax.get_yaxis().set_visible(False)
        plt.show()

# Custom weights initialization
def weights_init(m):
    '''
    Re-initializes a Conv2d / ConvTranspose2d module in place: weights from
    N(0, 0.02), bias (when present) to zero. Any other module is left
    untouched, so this can be passed to nn.Module.apply.
    '''
    if not isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        return

    # Small initial weights keep the first activations and gradients small
    nn.init.normal_(m.weight, mean=0.0, std=0.02)
    if m.bias is not None:
        nn.init.constant_(m.bias, 0)

# Custom padding function
def zero_pad(x, pad_len=3):
    '''
    Returns str(x) left-padded with '0' characters up to pad_len characters.

    Values whose string form is already pad_len or longer are returned
    unchanged (never truncated).
    '''
    # rjust does the padding in one call instead of prepending one character at a time
    return str(x).rjust(pad_len, "0")

# Data Pre-process

In [None]:
# Per-channel statistics that map ToTensor's [0, 1] range onto [-1, 1]
mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]

transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])

# Dataset
img_dataset = AnimeDataset(DATA_DIR, transform)

# Number of images; len(img_dataset[0]) would load one sample and report the
# length of its (image, label) tuple, i.e. always 2
dataset_size = len(img_dataset)

# Load the images
train_loader = DataLoader(
    img_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    # Pinned host memory only helps (and only avoids a warning) when copying to a GPU
    pin_memory=isCUDA,
    # DataLoader rejects persistent_workers=True when there are no workers
    persistent_workers=persistent_workers and num_workers > 0,
)

# Logging
print(f"Sizes of Train: {dataset_size}")

# Testing Images

In [None]:
# Sanity check of the input pipeline: show one training image and report its
# value range before and after undoing the normalization
first_batch = next(iter(train_loader), None)
if first_batch is not None:
    images, labels = first_batch

    # First image of the batch, channels moved last for matplotlib
    image = images[0].permute(1, 2, 0).cpu().numpy()
    image_min = image.min().item()
    image_max = image.max().item()

    # Denormalize and measure again
    image = (image + 1) / 2
    image_min2 = image.min().item()
    image_max2 = image.max().item()

    # Plot the image
    plt.imshow(image)
    plt.show()

    # Value range before, then after, denormalization
    print(f"Min: {image_min}, Max: {image_max}")
    print(f"Min: {image_min2}, Max: {image_max2}")

# Critic

In [None]:
class Critic(nn.Module):
    '''
    Critic network: maps a [B, 3, 64, 64] image batch to one unbounded
    score per image, shape [B, 1].

    Two convolutional paths look at the same input and are each flattened
    to 4096 features:
    -> convA halves the resolution with stride-2 4x4 convolutions
    -> convB uses stride-1 3x3 convolutions followed by 2x2 max pooling
    The concatenated 8192 features pass through dropout and a linear layer.
    '''
    def __init__(self):
        super().__init__()

        # Channel widths shared by both paths: 3 -> 64 -> 128 -> 256 -> 512 -> 1024
        widths = (3, 64, 128, 256, 512, 1024)
        stages = list(zip(widths[:-1], widths[1:]))

        # Path A: spatial size 64 -> 32 -> 16 -> 8 -> 4 -> 2 via strided convolutions
        strided = []
        for c_in, c_out in stages:
            strided.append(nn.Conv2d(c_in, c_out, 4, 2, 1))
            strided.append(nn.LeakyReLU(0.2))
        strided.append(nn.Flatten())                 # [B, 1024 * 2 * 2] = [B, 4096]
        self.convA = nn.Sequential(*strided)

        # Path B: each stage keeps the size in the convolution, then pooling halves it
        pooled = []
        for c_in, c_out in stages:
            pooled.append(nn.Conv2d(c_in, c_out, 3, 1, 1))
            pooled.append(nn.LeakyReLU(0.2))
            pooled.append(nn.MaxPool2d(2, 2))
        pooled.append(nn.Flatten())                  # [B, 1024 * 2 * 2] = [B, 4096]
        self.convB = nn.Sequential(*pooled)

        # Scoring head over both paths' features
        self.fc = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(2048*2*2, 1),                  # [B, 8192] -> [B, 1]
        )

    def forward(self, x):
        features = torch.cat((self.convA(x), self.convB(x)), dim=1)
        return self.fc(features)

# Generator

In [None]:
class Generator(nn.Module):
    '''
    Generator network: maps a [B, latent_dim] noise batch to a
    [B, 3, 64, 64] image batch with values in (-1, 1) (Tanh output).

    A fully connected stem produces a [B, 512, 4, 4] feature map that is
    upscaled to 64x64 by two parallel paths:
    -> convA: bilinear upsampling followed by a 3x3 convolution
    -> convB: stride-2 transposed convolutions
    Both paths end with 32 channels; they are concatenated and projected
    to 3 channels by a 1x1 convolution.
    '''
    def __init__(self, latent_dim = 100):
        super().__init__()

        # Channel widths of the four upscaling stages shared by both paths
        widths = (512, 256, 128, 64, 32)
        stages = list(zip(widths[:-1], widths[1:]))

        # Stem: noise -> [B, 512, 4, 4]
        self.fc = nn.Sequential(
            nn.Linear(latent_dim, 1024*1*1),         # [B, 1024]
            nn.LeakyReLU(0.2),
            nn.Linear(1024*1*1, 512*4*4),            # [B, 8192]
            nn.LeakyReLU(0.2),
            nn.Unflatten(1, (512, 4, 4)),            # [B, 512, 4, 4]
        )

        # Path A: every stage doubles the size by interpolation, then convolves
        upsample_path = []
        for c_in, c_out in stages:
            upsample_path.append(nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False))
            upsample_path.append(nn.Conv2d(c_in, c_out, 3, 1, 1))
            upsample_path.append(nn.LeakyReLU(0.2))
            upsample_path.append(PixelNorm())
        self.convA = nn.Sequential(*upsample_path)   # [B, 32, 64, 64]

        # Path B: every stage doubles the size with a transposed convolution
        transposed_path = []
        for c_in, c_out in stages:
            transposed_path.append(nn.ConvTranspose2d(c_in, c_out, 4, 2, 1))
            transposed_path.append(nn.LeakyReLU(0.2))
            transposed_path.append(PixelNorm())
        self.convB = nn.Sequential(*transposed_path) # [B, 32, 64, 64]

        # Output head over the 32 + 32 concatenated channels
        self.convOut = nn.Sequential(
            nn.Conv2d(64, 3, 1, 1, 0),               # [B, 3, 64, 64]
            nn.Tanh(),
        )

    def forward(self, x):
        feature_map = self.fc(x)
        merged = torch.cat((self.convA(feature_map), self.convB(feature_map)), dim=1)
        return self.convOut(merged)

# Generator Testing

In [None]:
# Smoke test: an untrained generator should already produce a full image grid
gen = Generator()
gen = gen.to(device)
valid_plot(gen)

# GAN

In [None]:
class GAN(nn.Module):
    '''
    WGAN-GP training wrapper that owns the generator, the critic, one Adam
    optimizer and one gradient scaler for each, and per-epoch step counters.

    -> forward(noise) generates images with the generator
    -> train_model(batch) runs 5 critic updates followed by 1 generator update
    '''
    def __init__(self, lr=learning_rate, latent_dim=100):
        super().__init__()
        self.latent_dim = latent_dim

        # Create both the model networks (compiled with TorchScript)
        self.G = torch.jit.script(Generator(latent_dim=self.latent_dim))
        self.C = torch.jit.script(Critic())

        # Define lr separately for each (both currently half of the base lr)
        lr_G = lr * 0.5
        lr_C = lr * 0.5

        # Optimizers
        self.optim_G = torch.optim.Adam(self.G.parameters(), lr=lr_G, betas=betas)
        self.optim_C = torch.optim.Adam(self.C.parameters(), lr=lr_C, betas=betas)

        # Step counters for the current epoch (see epoch_reset)
        self.batch_no_g = 0
        self.batch_no_c = 0

        # Loss scalers for mixed-precision training
        self.scaler_G = GradScaler('cuda')
        self.scaler_C = GradScaler('cuda')

    def forward(self, x):
        '''
        Generates images from the noise batch x.
        Side effect: switches the whole module to eval mode.
        '''
        self.eval()
        return self.G(x)

    def epoch_reset(self):
        '''Resets both step counters to 0.'''
        self.batch_no_g = 0
        self.batch_no_c = 0

    def batch_nos(self):
        '''Returns the step counters as (critic steps, generator steps).'''
        return self.batch_no_c, self.batch_no_g

    def gradient_penalty(self, real, fake, device):
        '''
        Gradient Penalty function
        -> Pushes the critic's gradient norm towards 1 on random
           interpolations between real and fake images (Lipschitz constraint)
        -> Stable gradients
        -> Reduces vanishing or exploding of gradients

        Returns the mean of (||grad|| - 1)^2 over the batch as a scalar
        tensor that stays differentiable w.r.t. the critic parameters.
        '''
        batch_size = real.size(0)

        # One interpolation factor per image, broadcast over C, H, W
        epsilon = torch.rand(batch_size, 1, 1, 1, device=device)
        mixed = (epsilon * real + (1 - epsilon) * fake).requires_grad_(True)

        # Gradient of the critic scores w.r.t. the interpolated images;
        # create_graph=True so the penalty itself can be back-propagated
        mixed_scores = self.C(mixed)
        grad_outputs = torch.ones_like(mixed_scores)
        gradients = torch.autograd.grad(
            outputs=mixed_scores,
            inputs=mixed,
            grad_outputs=grad_outputs,
            create_graph=True,
            retain_graph=True
        )[0]
        gradients = gradients.view(batch_size, -1)

        # Calculate penalty
        gp = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
        return gp

    def critic_loss(self, real, fake, gp, lambda_gp=10):
        '''
        Returns the critic loss: mean fake score - mean real score
        + lambda_gp * gp. Logs the individual terms every 250 critic steps.
        '''
        with autocast('cuda'):
            # Get scores of both images
            real_score = self.C(real).mean()
            fake_score = self.C(fake).mean()

        # Log gradient penalty, real and fake scores and the gap
        if self.batch_no_c % (50 * 5) == 0:
            print(f"Gradient Penality(* \u03BB): {gp * lambda_gp:8.4f}, Real score: {real_score:8.4f}, Fake Score: {fake_score:8.4f}, Gap: {real_score - fake_score:8.4f}")

        # Return total loss with gradient penalty
        return fake_score - real_score + lambda_gp * gp

    def generator_loss(self, fake):
        '''Returns the generator loss: the negated mean critic score of fake.'''
        with autocast('cuda'):
            loss = -self.C(fake).mean()
        return loss

    def train_model(self, batch):
        '''
        Trains on one batch: 5 critic updates, then 1 generator update.

        -> batch: (images, labels) pair; the labels are ignored
        Returns (critic loss of the last critic step * 5, generator loss)
        as detached scalar tensors meant for logging only.
        '''
        self.train()
        real_images, _ = batch # No labels required in WGANs
        real_images = real_images.to(device)
        num_images = len(real_images) # Batch size

        # No. of times critic should get trained per batch
        n_critic = 5

        # --- Critic ---

        for _ in range(n_critic):

            # Update counter
            self.batch_no_c += 1

            # Create fake images; the generator is not updated in this phase,
            # so skip building its autograd graph altogether
            noise = torch.randn(num_images, self.latent_dim, dtype=torch.float32).to(device)
            with torch.no_grad():
                fake_images = self.G(noise)

            # Calculate gradient penalty and critic loss
            gp = self.gradient_penalty(real_images, fake_images, device)
            c_loss = self.critic_loss(real_images, fake_images, gp, 10)

            # Backward pass
            self.optim_C.zero_grad()
            self.scaler_C.scale(c_loss).backward()
            self.scaler_C.step(self.optim_C)
            self.scaler_C.update()

        # --- Generator ---

        # Update counter
        self.batch_no_g += 1

        # Create fake images (this time with gradients flowing into the generator)
        noise = torch.randn(num_images, self.latent_dim, dtype=torch.float32).to(device)
        fake_images = self.G(noise)

        # Calculate generator loss
        g_loss = self.generator_loss(fake_images)

        # Backward pass
        self.optim_G.zero_grad()
        self.scaler_G.scale(g_loss).backward()
        self.scaler_G.step(self.optim_G)
        self.scaler_G.update()

        # Return detached losses so callers that accumulate them do not keep
        # autograd history alive; the critic loss is multiplied by n_critic
        # because the caller divides by the number of critic steps
        return c_loss.detach() * n_critic, g_loss.detach()

# Build the GAN (generator + critic + optimizers), then move its weights to the device
model = GAN()
model = model.to(device)

# Training and Validation Plotting

In [None]:
print(f"Training started! Time: {now()}")

# Model save file names (epoch number and extension are appended per epoch)
save_file_name = "gen_checkpoint_"
critic_file_name = "crit_checkpoint_"

for epoch in range(num_epochs):

    # Initializing losses to 0
    c_loss_total, g_loss_total = 0, 0

    # Reset model counters
    model.epoch_reset()

    for batch in train_loader:

        # Train model and get losses
        c_loss, g_loss = model.train_model(batch)

        # Accumulate detached values so the running sums do not hold on to
        # the autograd history of every batch
        c_loss_total += c_loss.detach()
        g_loss_total += g_loss.detach()

    # Get the batch no. counters
    batch_no_c, batch_no_g = model.batch_nos()

    # Calculate the mean of losses; max(..., 1) avoids a division by zero
    # when the loader produced no batch
    c_loss_total /= max(batch_no_c, 1)
    g_loss_total /= max(batch_no_g, 1)

    # Log
    print(f"\n\nEpoch: {epoch + 1}, Time: {now()}")
    print(f"c_loss: {c_loss_total:6.4f}")
    print(f"g_loss: {g_loss_total:6.4f}\n\n")
    valid_plot(model, epoch + 1)

    # Save generator weights every epoch
    save_file_name_epoch = save_file_name + zero_pad(epoch + 1) + ".pth"
    save_loc = os.path.join(model_folder_path, save_file_name_epoch)
    torch.save(model.G.state_dict(), save_loc)

    # Save critic weights every epoch
    critic_file_name_epoch = critic_file_name + zero_pad(epoch + 1) + ".pth"
    save_loc = os.path.join(model_folder_path, critic_file_name_epoch)
    torch.save(model.C.state_dict(), save_loc)

print("Training completed!")