In [None]:
import os
import cv2
import csv
import time
import torch
import pickle
import numpy as np
import matplotlib.pyplot as plt

from PIL import Image
from tqdm import tqdm
from pathlib import Path
from datetime import datetime
from cartoon_gan_origin.models.generator import Generator  

import torch.nn as nn
import torch.nn.utils.prune as prune

import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import CyclicLR
import torchvision.utils as vutils

from cartoon_gan_origin.utils.loss import ContentLoss, AdversialLoss
from cartoon_gan_origin.utils.transforms import get_default_transforms, get_no_aug_transform
from cartoon_gan_origin.utils.datasets import get_dataloader
from cartoon_gan_origin.utils.transforms import get_pair_transforms
from torch.utils.tensorboard import SummaryWriter
from cartoon_gan_origin.models.discriminator import Discriminator

In [None]:
WEIGHTS = 'cartoon_gan_origin/checkpoints/trained_netG.pth'
IMG_SIZE = 256
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Reference (unpruned) generator, kept untouched for before/after comparisons.
netG_orig = Generator().to(DEVICE)
netG_orig.eval()
# map_location remaps tensors saved on another device (e.g. CUDA) onto DEVICE,
# so the checkpoint also loads on a CPU-only machine.
netG_orig.load_state_dict(torch.load(WEIGHTS, map_location=DEVICE))

In [None]:
import torchvision.transforms.functional as TF

def inv_normalize(img):
    """Undo ImageNet mean/std normalization and clamp to the displayable range.

    Args:
        img: image tensor whose channel axis is dim 1 with 3 channels, e.g.
            (N, 3, H, W); the per-channel constants are broadcast via
            ``view(1, 3, 1, 1)``.

    Returns:
        ``img * std + mean`` per channel, clamped to [0, 1].
    """
    # Build the constants on the input's own device (instead of the global
    # DEVICE) so CPU and GPU tensors both work.
    mean = torch.tensor([0.485, 0.456, 0.406], device=img.device).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=img.device).view(1, 3, 1, 1)

    return (img * std + mean).clamp(0, 1)


# Cartoonize one training photo with the unpruned reference generator.
input_path = 'cartoon_gan_origin/data/trainA/world_0012.jpg'
trf = get_no_aug_transform()
image = Image.open(input_path).convert('RGB')
# trf yields a tensor (presumably C, H, W); prepend the batch axis and move to DEVICE.
image_list = trf(image).unsqueeze(0).to(DEVICE)

with torch.no_grad():
    generated_images = inv_normalize(netG_orig(image_list))

generated_image = generated_images[0].cpu()
TF.to_pil_image(generated_image)

## Mask-based structured pruning (zeroes whole filters; no CPU speedup)

In [None]:
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune


def structured_pruning_safe(model, amount=0.2):
    """Zero whole output filters of selected generator convolutions (mask-based).

    Each targeted conv is pruned with ``prune.ln_structured`` (L2 norm over
    output channels, ``dim=0``). The mask is then folded into ``weight`` with
    ``prune.remove``. Tensor shapes do not change, so the model gets sparser
    but not faster on dense kernels.

    Targets, visited in this order:
      * ``model.down``: every ``nn.Sequential`` child with more than one element
        whose element ``[1]`` is an ``nn.Conv2d``;
      * ``model.res[0]``: for each child exposing a ``block`` attribute, the
        module at ``block[0][1]``.

    Args:
        model: generator exposing the ``down`` / ``res`` containers above;
            modified in place.
        amount: forwarded to ``ln_structured`` (fraction of channels if float,
            absolute count if int).
    """
    print(f"Pruning {amount*100}%")

    def _zero_weakest_filters(conv):
        # Rank output channels by L2 norm, zero the weakest, then make it permanent.
        prune.ln_structured(conv, name="weight", amount=amount, n=2, dim=0)
        prune.remove(conv, 'weight')

    down_convs = (
        stage[1]
        for stage in model.down
        if isinstance(stage, nn.Sequential) and len(stage) > 1 and isinstance(stage[1], nn.Conv2d)
    )
    for conv in down_convs:
        _zero_weakest_filters(conv)

    residual_convs = (blk.block[0][1] for blk in model.res[0] if hasattr(blk, 'block'))
    for conv in residual_convs:
        _zero_weakest_filters(conv)
        
# `netG` is otherwise first created in the dependency-aware section further down,
# so on a fresh kernel this call raised NameError. Build a separate copy of the
# trained generator here (leaving `netG_orig` untouched) and mask-prune that.
netG = Generator().to(DEVICE)
netG.load_state_dict(torch.load(WEIGHTS, map_location=DEVICE))
netG.eval()

structured_pruning_safe(netG, amount=0.05)
# Number of zeroed weight entries in the first residual block's first conv.
print(f"Zero weights: {torch.sum(netG.res[0][0].block[0][1].weight == 0)}")

In [None]:
# Same photo, now through the mask-pruned generator, to eyeball the damage.
input_path = 'cartoon_gan_origin/data/trainA/world_0012.jpg'
trf = get_no_aug_transform()
image = Image.open(input_path).convert('RGB')
image_list = trf(image).unsqueeze(0).to(DEVICE)  # add batch axis

with torch.no_grad():
    generated_images = inv_normalize(netG(image_list))

generated_image = generated_images[0].cpu()
TF.to_pil_image(generated_image)

## Dependency-aware structured pruning

In [None]:
import torch
import torch.nn as nn
import torch_pruning as tp

# Start dependency-aware pruning from a fresh copy of the trained weights,
# discarding the mask-pruned model built above.
netG = Generator().to(DEVICE)
netG.eval()
# map_location keeps this loadable on CPU-only machines.
netG.load_state_dict(torch.load(WEIGHTS, map_location=DEVICE))

### Load datasets

In [None]:
# Training configuration shared by distillation and fine-tuning below.
batch_size = 8
image_size = 256

# Seed before building the loaders so any randomness in them (e.g. shuffling)
# is reproducible.
torch.manual_seed(1)
real_dataloader = get_dataloader("cartoon_gan_origin/data/trainA", size=image_size, bs=batch_size)
cartoon_dataloader = get_dataloader("cartoon_gan_origin/data/trainB_ghibli", size=image_size, bs=batch_size)
# NOTE(review): this reads the same directory as `cartoon_dataloader`. If
# AdversialLoss treats edge_pred as "fake" (as in CartoonGAN), the discriminator
# is told the same images are both real and fake. Presumably this should point
# at an edge-smoothed copy of the cartoon set. TODO confirm the intended path.
edge_dataloader = get_dataloader("cartoon_gan_origin/data/trainB_ghibli", size=image_size, bs=batch_size)

# Resume bookkeeping consumed by the fine-tuning loop.
last_epoch, last_i = 0, 0

In [None]:
# Fix one batch of real photos as a visual benchmark for the rest of the notebook.
tracked_images = next(iter(real_dataloader)).to(DEVICE)
original_images = tracked_images.detach().cpu()

# The snapshots below are written to cartoon_gan_origin/data/test/, so make sure
# that folder exists. The cell previously created an unrelated "results/" dir,
# and imsave failed whenever data/test/ was missing.
os.makedirs('cartoon_gan_origin/data/test', exist_ok=True)

grid = vutils.make_grid(original_images, padding=2, normalize=True, nrow=4)
plt.imshow(grid.permute(1, 2, 0).numpy())
plt.imsave('cartoon_gan_origin/data/test/gt.jpg', grid.permute(1, 2, 0).numpy())
plt.figure()

# Output of the current (still unpruned) `netG` on the same batch, for comparison later.
with torch.no_grad():
    real = inv_normalize(netG(tracked_images)).cpu()

grid = vutils.make_grid(real, padding=2, normalize=True, nrow=4)
plt.imshow(grid.permute(1, 2, 0).numpy())
plt.imsave('cartoon_gan_origin/data/test/orig.jpg', grid.permute(1, 2, 0).numpy())

### Pruning

In [None]:
# Number of innermost down/up stages left prunable; default for prune_generator_modern.
depth = 5


def prune_generator_modern(model, pruning_ratio=0.3, depth=depth, input_size=256):
    """Physically remove channels from the generator's inner stages with torch_pruning.

    Every Conv2d / ConvTranspose2d in ``model.up[depth:]`` and in all but the
    last ``depth`` modules of ``model.down`` is passed to the pruner as
    ignored. Everything else in ``model`` (the innermost stages and any other
    submodules) may be pruned.

    Args:
        model: generator exposing ``down`` and ``up`` containers; pruned in place.
        pruning_ratio: channel pruning ratio forwarded to ``tp.pruner.BasePruner``.
        depth: how many inner down/up stages may be pruned (defaults to the
            notebook-level ``depth``).
        input_size: side length of the square dummy input used to trace the graph.

    Returns:
        The same ``model`` object, after ``pruner.step()``.
    """
    # Trace on the model's own device so callers do not have to move it to CPU first.
    device = next(model.parameters()).device
    example_inputs = torch.randn(1, 3, input_size, input_size, device=device)
    imp = tp.importance.MagnitudeImportance(p=2)

    # Select ignored layers from `model` itself; the previous version read the
    # global `netG`, so pruning any other model ignored the wrong modules.
    # len()-based slicing keeps depth=0 meaning "all of `down` is ignored"
    # (the old `down[:-0]` slice was empty).
    n_outer_down = max(len(model.down) - depth, 0)
    outer_stages = list(model.up[depth:]) + list(model.down[:n_outer_down])
    ignored_layers = [
        m
        for stage in outer_stages
        for m in stage.modules()
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d))
    ]

    pruner = tp.pruner.BasePruner(
        model,
        example_inputs,
        importance=imp,
        pruning_ratio=pruning_ratio,
        ignored_layers=ignored_layers,
        round_to=8,  # round pruned channel counts to multiples of 8
    )

    _, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
    print(f"Before Pruning: {base_nparams} params")

    pruner.step()

    _, new_nparams = tp.utils.count_ops_and_params(model, example_inputs)
    print(f"After Pruning:  {new_nparams} params")
    print(f"Reduction: {100 - (new_nparams/base_nparams)*100}%")

    return model

# Trace and prune on CPU; the model is moved back to DEVICE afterwards.
netG.to("cpu")
netG = prune_generator_modern(netG, pruning_ratio=0.1)

# Smoke test: pruning must not break the forward pass or change the output size.
with torch.no_grad():
    dummy_input = torch.randn(1, 3, 256, 256)
    output = netG(dummy_input)
print("Output shape:", output.shape)
assert tuple(output.shape) == (1, 3, 256, 256), f"unexpected output shape {tuple(output.shape)}"

netG = netG.to(DEVICE)

# Visual check of the dependency-pruned generator before any retraining.
input_path = 'cartoon_gan_origin/data/trainA/world_0012.jpg'
trf = get_no_aug_transform()
image = Image.open(input_path).convert('RGB')
image_list = trf(image).unsqueeze(0).to(DEVICE)

with torch.no_grad():
    generated_images = inv_normalize(netG(image_list))

generated_image = generated_images[0].cpu()
TF.to_pil_image(generated_image)

### Model distillation

In [None]:
import copy

epochs = 15

# Frozen teacher: the unpruned generator with the original trained weights.
teacher_G = Generator().to(DEVICE)
teacher_G.load_state_dict(torch.load(WEIGHTS, map_location=DEVICE))
teacher_G.eval()
for param in teacher_G.parameters():
    param.requires_grad = False

# Student: the pruned generator. This is the same object as `netG`, trained in place.
student_G = netG
student_G.train()

optimizerG = torch.optim.AdamW(student_G.parameters(), lr=1e-5, betas=(0.5, 0.999))
distillation_criterion = nn.L1Loss().to(DEVICE)

print("Starting Knowledge Distillation...")

for epoch in range(epochs):
    for i, real_data in enumerate(tqdm(real_dataloader)):
        if i % 100 == 0:
            # Preview in eval mode so normalization layers (if any) neither use
            # batch statistics nor update running stats from the tracked batch.
            student_G.eval()
            with torch.no_grad():
                fake = inv_normalize(student_G(tracked_images)).cpu()
            student_G.train()

            grid = vutils.make_grid(fake, padding=2, normalize=True, nrow=3)
            os.makedirs("results", exist_ok=True)
            plt.imsave(f"results/tune_ep_{epoch}_i{i}.png", grid.permute(1, 2, 0).numpy())

        real_data = real_data.to(DEVICE)

        # The teacher's output is the regression target; no gradients needed.
        with torch.no_grad():
            target_cartoon = teacher_G(real_data)

        predicted_cartoon = student_G(real_data)
        loss = distillation_criterion(predicted_cartoon, target_cartoon)

        optimizerG.zero_grad()
        loss.backward()
        optimizerG.step()

        if i % 100 == 0:
            print(f"Epoch {epoch} Iter {i} | Distillation Loss: {loss.item():.4f}")

# Snapshot of the student after the LAST epoch. No validation-based selection
# happens, despite the variable name.
best_model_state = copy.deepcopy(student_G.state_dict())

In [None]:
# Distilled student on an out-of-training-set test image.
input_path = 'cartoon_gan_origin/data/test/ce2.jpg'
trf = get_no_aug_transform()
image = Image.open(input_path).convert('RGB')
image_list = trf(image).unsqueeze(0).to(DEVICE)

with torch.no_grad():
    generated_images = inv_normalize(netG(image_list))

generated_image = generated_images[0].cpu()
TF.to_pil_image(generated_image)

In [None]:
# WEIGHTS1 = 'cartoon_gan_origin/checkpoints/prun_distilation_hard_G.pth'
# netG = torch.load(WEIGHTS1, map_location='cpu', weights_only=False).to(DEVICE)

In [None]:
# Distilled student on the tracked batch (inference mode for the preview).
netG.eval()
with torch.no_grad():
    fake = inv_normalize(netG(tracked_images)).cpu()

# Grid the generated images. The cell previously gridded `tracked_images`
# (the inputs), so the computed `fake` was never shown.
grid = vutils.make_grid(fake, padding=2, normalize=True, nrow=4)

plt.figure(figsize=(8, 5))
plt.imshow(grid.permute(1, 2, 0).numpy())
plt.title("Distilled student on tracked batch")
plt.axis("off");

### Load distilled model

In [None]:
# Restore the student weights snapshotted at the end of distillation.
netG.load_state_dict(state_dict=best_model_state)

In [None]:
# Reference photo through the restored distilled generator.
input_path = 'cartoon_gan_origin/data/trainA/world_0012.jpg'
trf = get_no_aug_transform()
image = Image.open(input_path).convert('RGB')
image_list = trf(image).unsqueeze(0).to(DEVICE)

with torch.no_grad():
    generated_images = inv_normalize(netG(image_list))

generated_image = generated_images[0].cpu()
TF.to_pil_image(generated_image)

In [None]:
# Pickle the whole module rather than a state_dict: after channel pruning its
# layer shapes presumably no longer match a freshly constructed Generator().
torch.save(obj=netG, f="cartoon_gan_origin/checkpoints/prun_distilation_hard_G.pth")

### Fine-tuning

In [None]:
# Optimizer hyper-parameters for adversarial fine-tuning of the pruned generator.
learning_rate_G = 1e-5
learning_rate_D = 1e-6
beta1, beta2 = (.5, .99)
weight_decay = 1e-4

# Discriminator warm-started from the original training checkpoint;
# map_location keeps this loadable on CPU-only machines.
netD = Discriminator().to(DEVICE)
netD.load_state_dict(torch.load("cartoon_gan_origin/checkpoints/trained_netD.pth", map_location=DEVICE))

optimizerD = AdamW(netD.parameters(), lr=learning_rate_D, betas=(beta1, beta2), weight_decay=weight_decay)
optimizerG = AdamW(netG.parameters(), lr=learning_rate_G, betas=(beta1, beta2), weight_decay=weight_decay)

# Each learning rate cycles between its base value and 10x that value.
schedulerD = CyclicLR(optimizer=optimizerD, base_lr=learning_rate_D, max_lr=learning_rate_D*1e1, cycle_momentum=False)
schedulerG = CyclicLR(optimizer=optimizerG, base_lr=learning_rate_G, max_lr=learning_rate_G*1e1, cycle_momentum=False)

# Patch-level targets, assuming the discriminator outputs
# (batch, 1, image_size/4, image_size/4). TODO confirm against Discriminator.
# NOTE(review): sized for a full batch; a smaller final batch would mismatch
# unless get_dataloader drops incomplete batches.
cartoon_labels = torch.ones(batch_size, 1, image_size // 4, image_size // 4, device=DEVICE)
fake_labels = torch.zeros(batch_size, 1, image_size // 4, image_size // 4, device=DEVICE)

# Loss functions
content_loss = ContentLoss(omega = 10).to(DEVICE)
adv_loss     = AdversialLoss(cartoon_labels, fake_labels)
BCE_loss     = nn.BCEWithLogitsLoss().to(DEVICE)

In [None]:
img_list = []   # preview grids captured during training
G_losses = []
D_losses = []

# Resume from where a previous (possibly interrupted) run of this cell stopped.
start_epoch = last_epoch
start_i = last_i

epochs = 5

for epoch in range(start_epoch, epochs):
    print(f"Epoch {epoch}")
    real_dl_iter = iter(real_dataloader)
    cartoon_dl_iter = iter(cartoon_dataloader)
    edge_dl_iter = iter(edge_dataloader)
    # Bound by the shortest of ALL three loaders so no next() raises StopIteration.
    iterations = min(len(real_dataloader), len(cartoon_dataloader), len(edge_dataloader))

    for i in tqdm(range(start_i, iterations)):
        real_data = next(real_dl_iter).to(DEVICE)
        cartoon_data = next(cartoon_dl_iter).to(DEVICE)
        edge_data = next(edge_dl_iter).to(DEVICE)

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        netD.train()
        netG.eval()
        netD.zero_grad()

        # G is frozen for this step: generate without recording a graph through G.
        # Previously errD.backward() also filled G's grads, which were then discarded.
        with torch.no_grad():
            generated_data = netG(real_data)

        cartoon_pred = netD(cartoon_data)
        edge_pred = netD(edge_data)
        generated_pred = netD(generated_data)

        errD = adv_loss(cartoon_pred, generated_pred, edge_pred)
        errD.backward()
        # Mean raw D output on real cartoons. D outputs are fed to
        # BCEWithLogitsLoss below, so these are presumably logits, not probabilities.
        D_x = cartoon_pred.mean().item()
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z))) plus a content term
        ###########################
        netG.train()
        netD.eval()
        netG.zero_grad()

        # D was just updated, so run a fresh forward pass through G and D.
        generated_data = netG(real_data)
        generated_pred = netD(generated_data)

        cl = content_loss(generated_data, real_data)
        bce = BCE_loss(generated_pred, cartoon_labels)
        errG = bce + cl * 0.05
        errG.backward()
        D_G_z2 = generated_pred.mean().item()  # mean raw D output on generated images
        optimizerG.step()

        if i % 25 == 0:
            # Preview in eval mode; the next iteration sets modes explicitly again.
            netG.eval()
            with torch.no_grad():
                fake = inv_normalize(netG(tracked_images)).cpu()
            netG.train()

            grid = vutils.make_grid(fake, padding=2, normalize=True, nrow=3)
            # Named `timestamp`, not `time`, so the imported `time` module is not shadowed.
            timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

            os.makedirs("results", exist_ok=True)
            plt.imsave(f"results/E{epoch}_i{i}_{timestamp}.png", grid.permute(1, 2, 0).numpy())
            img_list.append(grid)

            print(('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f\t%s'
                % (epoch, epochs, i, iterations, errD.item(), errG.item(), D_x, D_G_z2, timestamp)).expandtabs(25))

        # Save losses for plotting later.
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Only G's learning rate is cycled. schedulerD is not stepped, so D's
        # learning rate stays at its base value.
        schedulerG.step()

        # Record the NEXT iteration to run, so a resumed run does not repeat this one.
        last_i = i + 1

    # Epoch finished: a resumed run starts at the following epoch, iteration 0.
    start_i = 0
    last_epoch = epoch + 1
    last_i = 0

In [None]:
# Generator: whole pickled module, since its pruned layer shapes presumably
# differ from Generator(). Discriminator: unpruned, so a plain state_dict suffices.
torch.save(obj=netG, f="cartoon_gan_origin/checkpoints/tune_trained_netG_05prun.pth")
torch.save(obj=netD.state_dict(), f="cartoon_gan_origin/checkpoints/tune_trained_netD_05prun.pth")

In [None]:
# Final check: reference photo through the fine-tuned pruned generator.
input_path = 'cartoon_gan_origin/data/trainA/world_0012.jpg'
trf = get_no_aug_transform()
image = Image.open(input_path).convert('RGB')
image_list = trf(image).unsqueeze(0).to(DEVICE)

with torch.no_grad():
    generated_images = inv_normalize(netG(image_list))

generated_image = generated_images[0].cpu()
TF.to_pil_image(generated_image)