In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as T
import matplotlib.pyplot as plt
import torchvision.utils as vutils
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim

In [None]:
# Module-level preprocessing: resize to 256x256, convert to a [0, 1] tensor,
# then map each RGB channel to [-1, 1] via (x - 0.5) / 0.5.
# NOTE(review): get_dataset() builds its own transform, so this global appears
# unused in this notebook.
transform = T.Compose(
    [
        T.Resize((256, 256)),
        T.ToTensor(),
        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)


In [None]:
def get_image_filenames(face_dir, extensions=None):
    """Return the sorted image filenames found directly inside ``face_dir``.

    The same names are later used to look up the paired comic images, so this
    assumes ``face/`` and ``comics/`` share filenames (not verified here).

    Args:
        face_dir: Directory containing the real-face images.
        extensions: Optional iterable of file extensions (e.g. ``{'.jpg', '.png'}``)
            to keep, compared case-insensitively. ``None`` keeps every regular file.

    Returns:
        Sorted list of bare filenames (no directory component).
    """
    allowed = None if extensions is None else {ext.lower() for ext in extensions}
    names = []
    for name in os.listdir(face_dir):
        # Skip hidden entries (.DS_Store, .ipynb_checkpoints, ...) and
        # sub-directories: Image.open would fail on them at load time.
        if name.startswith('.') or not os.path.isfile(os.path.join(face_dir, name)):
            continue
        if allowed is not None and os.path.splitext(name)[1].lower() not in allowed:
            continue
        names.append(name)
    return sorted(names)


In [None]:
def get_dataset(root_dir, image_size=256):
    """Build an index-based loader for paired face/comic images.

    Expects ``root_dir/face`` and ``root_dir/comics`` to hold images with
    identical filenames; pairing is by filename only.

    Args:
        root_dir: Dataset root containing the ``face`` and ``comics`` folders.
        image_size: Side length both images are resized to (square output).

    Returns:
        ``(dataset_fn, length)``. ``dataset_fn(index)`` returns a dict with
        ``'face'`` and ``'comic'`` tensors normalized to [-1, 1] and the
        ``'filename'`` they were loaded from.

    Note:
        ``dataset_fn`` is a closure and cannot be pickled, so DataLoader
        workers started with the 'spawn' method cannot use it.
    """
    face_dir = os.path.join(root_dir, 'face')
    comic_dir = os.path.join(root_dir, 'comics')
    filenames = get_image_filenames(face_dir)

    transform = T.Compose([
        T.Resize((image_size, image_size)),
        T.ToTensor(),
        T.Normalize([0.5] * 3, [0.5] * 3),  # [0, 1] -> [-1, 1]
    ])

    def _load_rgb(path):
        # The context manager closes the file handle deterministically;
        # convert() reads the pixel data before the file is closed.
        with Image.open(path) as img:
            return img.convert("RGB")

    def dataset_fn(index):
        fname = filenames[index]
        face = transform(_load_rgb(os.path.join(face_dir, fname)))
        comic = transform(_load_rgb(os.path.join(comic_dir, fname)))
        return {
            'face': face,
            'comic': comic,
            'filename': fname
        }

    return dataset_fn, len(filenames)


In [None]:
def get_dataloader_from_fn(dataset_fn, dataset_len, batch_size=16, shuffle=True, num_workers=2, subset_indices=None):
    """Wrap an index -> sample function in a map-style Dataset and a DataLoader.

    Args:
        dataset_fn: Callable mapping an integer index to one sample.
        dataset_len: Number of valid indices for ``dataset_fn``.
        batch_size, shuffle, num_workers: Forwarded to ``DataLoader``.
        subset_indices: Optional sequence of indices into ``dataset_fn``. When
            given, the loader only sees those samples, in that order.

    Returns:
        A ``DataLoader`` over the (optionally subset) dataset.

    Note:
        The Dataset class is local to this function and cannot be pickled.
        ``num_workers > 0`` therefore only works with the 'fork' start method
        (the Linux default). Use ``num_workers=0`` on spawn-based platforms
        such as Windows and macOS.
    """
    class FunctionalDataset(Dataset):
        def __len__(self):
            if subset_indices is None:
                return dataset_len
            return len(subset_indices)

        def __getitem__(self, idx):
            if subset_indices is not None:
                idx = subset_indices[idx]
            return dataset_fn(idx)

    loader = DataLoader(
        FunctionalDataset(),
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )
    return loader


In [None]:
# Get dataset function and length
dataset_fn, dataset_len = get_dataset('archive/face2comics', image_size=256)
assert dataset_len > 0, "no images found under archive/face2comics/face"

# Use a small subset for quick experiments (set to dataset_len for the full set).
NUM_SUBSET_IMAGES = 100
# min() keeps the subset valid if the dataset has fewer images than requested.
subset_indices = list(range(min(NUM_SUBSET_IMAGES, dataset_len)))

# Create a DataLoader for the subset
dataloader = get_dataloader_from_fn(dataset_fn, dataset_len, batch_size=4, shuffle=True, num_workers=0, subset_indices=subset_indices)

# Sanity check: load one batch and print the tensor shapes
batch = next(iter(dataloader))
print(batch['face'].shape, batch['comic'].shape)

In [None]:
def show_real_vs_comic(real, comic, max_images=8):
    """Plot real faces (top row) above their comic counterparts (bottom row).

    Args:
        real: Batch of face images; assumed (N, 3, H, W) normalized to [-1, 1].
        comic: Batch of comic images with the same layout as ``real``.
        max_images: Upper bound on the number of columns. Fewer are drawn
            when the batch is smaller.
    """
    n = min(max_images, len(real), len(comic))
    if n == 0:
        return

    # De-normalize from [-1, 1] back to [0, 1]. detach() allows generator
    # outputs; clamp avoids imshow clipping warnings from float round-off.
    real = (real[:n].detach() * 0.5 + 0.5).clamp(0, 1)
    comic = (comic[:n].detach() * 0.5 + 0.5).clamp(0, 1)

    # squeeze=False keeps a 2-D axes array even when n == 1.
    fig, axes = plt.subplots(2, n, figsize=(3 * n, 6), squeeze=False)

    for i in range(n):
        for row, (images, title) in enumerate(((real, "Real"), (comic, "Comic"))):
            ax = axes[row, i]
            ax.imshow(images[i].permute(1, 2, 0).cpu().numpy())
            ax.axis('off')
            ax.set_title(title)

    plt.tight_layout()
    plt.show()

In [None]:
# Preview one shuffled batch: real faces on top, comic targets below.
batch = next(iter(dataloader), None)
if batch is not None:
    show_real_vs_comic(batch['face'], batch['comic'], max_images=4)


In [None]:
def plot_pixel_histogram(image_tensor, title="Histogram", index=0):
    """Plot per-channel (R, G, B) pixel-intensity histograms for one image.

    Args:
        image_tensor: Batch of images; assumed (N, 3, H, W) normalized to [-1, 1].
        title: Figure title.
        index: Which image in the batch to plot (defaults to the first).
    """
    img = image_tensor[index].detach()  # detach() so generator outputs work too
    img = (img * 0.5) + 0.5  # De-normalize to [0, 1]
    img = img.permute(1, 2, 0).cpu().numpy()

    fig, ax = plt.subplots(figsize=(8, 4))
    for i, color in enumerate(['r', 'g', 'b']):
        ax.hist(img[..., i].ravel(), bins=256, color=color, alpha=0.5, label=color)
    ax.set_title(title)
    ax.set_xlabel("Pixel Intensity")
    ax.set_ylabel("Count")
    ax.legend()
    plt.show()


In [None]:
# Compare the intensity distributions of the two domains on one batch.
batch = next(iter(dataloader), None)
if batch is not None:
    for images, label in ((batch['face'], "Real Face Histogram"),
                          (batch['comic'], "Comic Histogram")):
        plot_pixel_histogram(images, title=label)


U-Net Generator

In [None]:
# Downsampling block
def down_block(in_channels, out_channels, normalize=True):
    """Halve spatial size: 4x4 stride-2 conv, optional BatchNorm, LeakyReLU(0.2).

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the convolution.
        normalize: Insert BatchNorm2d between the conv and the activation.

    Returns:
        nn.Sequential of [Conv2d, (BatchNorm2d,) LeakyReLU].
    """
    conv = nn.Conv2d(in_channels, out_channels, 4, stride=2, padding=1)
    norm = [nn.BatchNorm2d(out_channels)] if normalize else []
    activation = nn.LeakyReLU(0.2, inplace=True)
    return nn.Sequential(conv, *norm, activation)

# Upsampling block
def up_block(in_channels, out_channels, dropout=0.0):
    """Double spatial size: 4x4 stride-2 transposed conv, BatchNorm, ReLU, optional Dropout.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the transposed convolution.
        dropout: Dropout probability; a falsy value (0 / 0.0) adds no Dropout layer.

    Returns:
        nn.Sequential of [ConvTranspose2d, BatchNorm2d, ReLU, (Dropout)].
    """
    layers = [
        nn.ConvTranspose2d(in_channels, out_channels, 4, stride=2, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    ]
    layers.extend([nn.Dropout(dropout)] if dropout else [])
    return nn.Sequential(*layers)

# Full U-Net Generator
class UNetGenerator(nn.Module):
    """Five-level U-Net mapping a 3-channel image to a 3-channel image in [-1, 1].

    Encoder: five stride-2 ``down_block`` stages (3->64->128->256->512->512).
    Decoder: four ``up_block`` stages. After each stage the output is
    concatenated with the encoder feature map of the same resolution. A final
    transposed conv + Tanh restores full resolution. Because of the five 2x
    downsamplings and the skip concatenations, input height and width should
    be divisible by 32.
    """

    def __init__(self):
        super().__init__()
        # Encoder (first stage deliberately has no BatchNorm)
        self.d1 = down_block(3, 64, normalize=False)
        self.d2 = down_block(64, 128)
        self.d3 = down_block(128, 256)
        self.d4 = down_block(256, 512)
        self.d5 = down_block(512, 512)

        # Decoder: the preceding skip concat doubles the input channels of u2..u4
        self.u1 = up_block(512, 512)
        self.u2 = up_block(1024, 256)
        self.u3 = up_block(512, 128)
        self.u4 = up_block(256, 64)

        # Output head: 64 (u4) + 64 (d1 skip) = 128 channels -> 3, squashed to [-1, 1]
        self.final = nn.Sequential(
            nn.ConvTranspose2d(128, 3, 4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Encode, then decode with skip connections; returns an image the size of ``x``."""
        skips = []
        h = x
        for stage in (self.d1, self.d2, self.d3, self.d4):
            h = stage(h)
            skips.append(h)
        h = self.d5(h)

        # The deepest skip (d4) pairs with the first decoder stage (u1), and so on.
        for stage, skip in zip((self.u1, self.u2, self.u3, self.u4), reversed(skips)):
            h = torch.cat((stage(h), skip), dim=1)

        return self.final(h)

LightResNet Generator

In [None]:
# Residual block
class _ResidualSequential(nn.Sequential):
    """nn.Sequential whose output is added to its input (identity skip connection).

    Subclassing Sequential instead of wrapping it keeps the parameter names
    ('0.weight', '1.weight', ...) the same as a plain Sequential.
    """

    def forward(self, x):
        return x + super().forward(x)


def residual_block(in_channels):
    """Two 3x3 conv + BatchNorm layers wrapped in an identity skip connection.

    Computes ``x + BN(Conv(ReLU(BN(Conv(x)))))``. Stride 1 with padding 1 keeps
    the spatial size and channel count unchanged, so the addition is shape-valid.
    (The previous version returned a plain Sequential with no skip, so the
    "residual" blocks were ordinary stacked convolutions.)

    Args:
        in_channels: Channel count of the input (and output) feature map.

    Returns:
        An ``nn.Sequential`` subclass that applies the residual update.
    """
    return _ResidualSequential(
        nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(in_channels),
        nn.ReLU(inplace=True),
        nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(in_channels)
    )

# Full LightResNet Generator
class LightResNetGenerator(nn.Module):
    """Encoder -> three 256-channel residual_block stages -> decoder; output in [-1, 1].

    The encoder has three stride-2 convs (3->64->128->256) and the decoder has
    three stride-2 transposed convs (256->128->64->3) ending in Tanh. The output
    has the same spatial size as the input when H and W are divisible by 8.
    """

    def __init__(self):
        super().__init__()

        def conv_bn_relu(cin, cout):
            return [nn.Conv2d(cin, cout, 4, stride=2, padding=1),
                    nn.BatchNorm2d(cout),
                    nn.ReLU(inplace=True)]

        def deconv_bn_relu(cin, cout):
            return [nn.ConvTranspose2d(cin, cout, 4, stride=2, padding=1),
                    nn.BatchNorm2d(cout),
                    nn.ReLU(inplace=True)]

        # Encoder (the first conv has no BatchNorm)
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, 4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            *conv_bn_relu(64, 128),
            *conv_bn_relu(128, 256),
        )

        # Residual blocks at the bottleneck resolution
        self.res_blocks = nn.Sequential(*(residual_block(256) for _ in range(3)))

        # Decoder
        self.decoder = nn.Sequential(
            *deconv_bn_relu(256, 128),
            *deconv_bn_relu(128, 64),
            nn.ConvTranspose2d(64, 3, 4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Apply encoder, residual stack and decoder in sequence."""
        return self.decoder(self.res_blocks(self.encoder(x)))

PatchGAN Discriminator

In [None]:
class PatchGANDiscriminator(nn.Module):
    """Conditional PatchGAN: scores an (input, candidate) image pair per patch.

    The two 3-channel images are stacked into 6 channels and passed through
    three stride-2 and two stride-1 4x4 convolutions. The result is a map of
    raw logits, one per receptive-field patch (e.g. (N, 1, 30, 30) for 256x256
    inputs). No sigmoid is applied; pair this with BCEWithLogitsLoss.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Conv2d(6, 64, 4, stride=2, padding=1),  # 6 = condition (3) + candidate (3)
            nn.LeakyReLU(0.2, inplace=True),
        ]
        for cin, cout, stride in ((64, 128, 2), (128, 256, 2), (256, 512, 1)):
            layers += [
                nn.Conv2d(cin, cout, 4, stride=stride, padding=1),
                nn.BatchNorm2d(cout),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        layers.append(nn.Conv2d(512, 1, 4, stride=1, padding=1))
        self.model = nn.Sequential(*layers)

    def forward(self, x_real, x_fake):
        """Return per-patch logits for the pair.

        Despite the parameter names, ``x_real`` is the conditioning image (the
        input face in train_model). ``x_fake`` is the candidate being judged:
        either the ground-truth comic or a generated one.
        """
        paired = torch.cat([x_real, x_fake], dim=1)
        return self.model(paired)


GAN Loss Function (used alone for LightResNet; combined with L1 for U-Net)

In [None]:
# Adversarial (GAN) loss: the only loss for LightResNet, and combined with L1 for U-Net.

bce_loss = nn.BCEWithLogitsLoss()

def gan_loss(predicted, target_is_real, label_smooth=0.1):
    """Binary cross-entropy GAN loss on raw discriminator logits, with label smoothing.

    Args:
        predicted: Discriminator logits (any shape). BCEWithLogitsLoss applies
            the sigmoid internally.
        target_is_real: If True, compare against "real" labels; otherwise
            against "fake" labels.
        label_smooth: Smoothing amount. Real targets become ``1 - label_smooth``
            and fake targets become ``label_smooth``. The default of 0.1 is the
            previously hard-coded value; pass 0.0 for hard 0/1 labels.

    Returns:
        Scalar tensor: the mean BCE over all elements of ``predicted``.

    Note:
        Smoothing is two-sided, and it also applies when the generator calls
        this with ``target_is_real=True``, so G is pushed toward D outputs of
        ``1 - label_smooth`` rather than 1. Salimans et al. (2016) recommend
        one-sided smoothing (real labels only).
    """
    if target_is_real:
        target = torch.full_like(predicted, 1.0 - label_smooth)
    else:
        target = torch.full_like(predicted, label_smooth)
    return bce_loss(predicted, target)


U-Net Generator Loss Function (GAN + L1)

In [None]:
# Loss functions for U-Net (GAN + L1)

l1_loss = nn.L1Loss()

def unet_combined_loss(pred_fake, pred_real, fake_image, target_image, lambda_L1=100):
    """Generator objective: adversarial term plus lambda-weighted L1 reconstruction.

    Args:
        pred_fake: Discriminator logits for the (input, generated) pair.
        pred_real: Discriminator logits for the (input, target) pair. Kept for
            interface compatibility; it is not used in the computation.
        fake_image: Generator output.
        target_image: Ground-truth comic image.
        lambda_L1: Weight applied to the mean absolute pixel error.

    Returns:
        Scalar tensor ``gan_loss(pred_fake, True) + lambda_L1 * L1(fake, target)``.
    """
    # The generator is rewarded when D classifies its output as real.
    adversarial_term = gan_loss(pred_fake, True)
    reconstruction_term = l1_loss(fake_image, target_image) * lambda_L1
    return adversarial_term + reconstruction_term


In [None]:
# Visualization helper functions
def plot_loss_curves(loss_D_list, loss_G_list):
    """Draw per-epoch discriminator and generator losses in two side-by-side panels.

    Both series are plotted against ``range(len(loss_D_list))``, so the two
    lists are expected to have the same length.
    """
    epochs = range(len(loss_D_list))
    fig, (ax_d, ax_g) = plt.subplots(1, 2, figsize=(10, 5))

    panels = (
        (ax_d, loss_D_list, 'Discriminator Loss'),
        (ax_g, loss_G_list, 'Generator Loss'),
    )
    for ax, values, name in panels:
        ax.plot(epochs, values, label=name)
        ax.set_title(name)
        ax.set_xlabel('Epochs')
        ax.set_ylabel('Loss')

    plt.tight_layout()
    plt.show()

In [None]:
# Side-by-side comparison of input, target and generated images
# (train_model calls this every 5 epochs).
def denormalize_for_display(images):
    """Convert an (N, C, H, W) array normalized to [-1, 1] into (N, H, W, C) in [0, 1]."""
    return np.clip(images.transpose(0, 2, 3, 1) * 0.5 + 0.5, 0.0, 1.0)


def display_images(input_image, real_image, fake_image, epoch, num_images=1):
    """Show input faces, target comics and generated comics in three rows.

    The tensors are de-normalized from [-1, 1] to [0, 1] before plotting.
    Passing raw [-1, 1] data to imshow clips everything below 0 and gives
    misleading previews.

    Args:
        input_image: Batch of input faces (tensor), assumed normalized to [-1, 1].
        real_image: Batch of ground-truth comics, same layout.
        fake_image: Batch of generator outputs, same layout.
        epoch: Epoch number, shown in the generated-image titles.
        num_images: Maximum number of columns; capped at the batch size.
    """
    num_images = min(num_images, len(input_image), len(real_image), len(fake_image))
    if num_images == 0:
        return

    rows = [
        (denormalize_for_display(input_image[:num_images].detach().cpu().numpy()), 'Input Image'),
        (denormalize_for_display(real_image[:num_images].detach().cpu().numpy()), 'Real Image'),
        (denormalize_for_display(fake_image[:num_images].detach().cpu().numpy()),
         f'Generated Image (Epoch {epoch})'),
    ]

    plt.figure(figsize=(15, 5))
    for row, (images, title) in enumerate(rows):
        for i in range(num_images):
            plt.subplot(3, num_images, row * num_images + i + 1)
            plt.imshow(images[i])
            plt.title(title)
            plt.axis('off')

    plt.show()

In [None]:
def train_model(generator, discriminator, dataloader, num_epochs=100, device='cuda', lambda_L1=100, learning_rate=2e-4, use_l1=True):
    """Train a conditional GAN (generator + PatchGAN discriminator) on face->comic pairs.

    Each batch performs one discriminator update (real pair vs. detached fake
    pair), then one generator update. New Adam optimizers are created on every
    call, but model weights are NOT reset: calling this twice with the same
    models continues training from where the previous call stopped.

    Args:
        generator: Maps a face batch to a comic batch.
        discriminator: Called as ``discriminator(face, candidate_comic)``.
        dataloader: Yields dicts with ``'face'`` and ``'comic'`` tensors.
        num_epochs: Number of passes over ``dataloader``.
        device: Device the models and batches are moved to.
        lambda_L1: Weight of the L1 term (only used when ``use_l1`` is True).
        learning_rate: Adam learning rate for both networks.
        use_l1: If True, the generator loss is GAN + lambda_L1 * L1; otherwise GAN only.

    Returns:
        ``(loss_D_list, loss_G_list)``: per-epoch mean discriminator and
        generator losses.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        # Without this guard the per-epoch averages divide by zero and the
        # preview references batch variables that were never assigned.
        raise ValueError("dataloader is empty; nothing to train on")

    generator = generator.to(device)
    discriminator = discriminator.to(device)
    # Ensure BatchNorm layers are in training mode even if the models were
    # switched to eval() elsewhere.
    generator.train()
    discriminator.train()

    # Optimizers
    opt_G = optim.Adam(generator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
    opt_D = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.999))

    # Loss tracking lists
    loss_D_list = []
    loss_G_list = []

    # Training loop
    for epoch in tqdm(range(num_epochs), desc="Training Epochs"):
        epoch_loss_D = 0.0
        epoch_loss_G = 0.0
        for idx, data in enumerate(dataloader):
            input_image = data['face'].to(device)
            target_image = data['comic'].to(device)

            # === Train Discriminator ===
            opt_D.zero_grad()

            # Real pair
            pred_real = discriminator(input_image, target_image)
            loss_real = gan_loss(pred_real, True)

            # Fake pair; detach() so this update does not backprop into G
            fake_image = generator(input_image)
            pred_fake = discriminator(input_image, fake_image.detach())
            loss_fake = gan_loss(pred_fake, False)

            # Total discriminator loss
            loss_D = (loss_real + loss_fake) * 0.5
            loss_D.backward()
            opt_D.step()

            epoch_loss_D += loss_D.item()

            # === Train Generator ===
            # This backward also writes gradients into D's parameters; they are
            # cleared by opt_D.zero_grad() at the start of the next batch.
            opt_G.zero_grad()

            # Generator tries to fool discriminator
            pred_fake_for_g = discriminator(input_image, fake_image)

            if use_l1:
                loss_G = unet_combined_loss(pred_fake_for_g, pred_real, fake_image, target_image, lambda_L1)
            else:
                loss_G = gan_loss(pred_fake_for_g, True)

            loss_G.backward()
            opt_G.step()

            epoch_loss_G += loss_G.item()

            if idx % 50 == 0:
                print(f"Epoch [{epoch}/{num_epochs}] Batch [{idx}/{num_batches}] Loss_D: {loss_D.item():.4f} Loss_G: {loss_G.item():.4f}")

        # Store average losses for this epoch
        loss_D_list.append(epoch_loss_D / num_batches)
        loss_G_list.append(epoch_loss_G / num_batches)

        # Display a comparison from the last batch every 5 epochs
        if epoch % 5 == 0:
            display_images(input_image, target_image, fake_image, epoch)

    # Plot loss curves after training
    plot_loss_curves(loss_D_list, loss_G_list)

    return loss_D_list, loss_G_list


In [None]:
# Instantiate models (seeded so weight initialization is reproducible)
SEED = 42
torch.manual_seed(SEED)

unet_generator = UNetGenerator()
lightresnet_generator = LightResNetGenerator()
discriminator = PatchGANDiscriminator()

# Use the GPU when available; train_model() moves the models to this device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Device: {device}")

In [None]:
# Grid searched by hyperparameter_tuning(): every lambda_L1 value is paired
# with every learning rate (3 x 3 = 9 training runs).
lambda_L1_values = [50, 100, 200]
learning_rate_values = [0.0001, 0.0002, 0.0005]

In [None]:
# Main function to perform hyperparameter tuning
def hyperparameter_tuning(generator, discriminator, dataloader, device='cuda', num_epochs=50, use_l1=True, lambda_values=None, lr_values=None):
    """Grid-search lambda_L1 x learning rate and keep the lowest final generator loss.

    Every configuration starts from the same weights. The models' state at
    call time (parameters and BatchNorm buffers) is snapshotted and restored
    before each run, so a run does not inherit training from earlier
    configurations. After the search, the models hold the weights of the
    LAST configuration, not the best one.

    NOTE(review): the selection metric is the final epoch's mean generator
    loss. With ``use_l1`` it includes ``lambda_L1 * L1``, so losses for
    different lambda_L1 values are on different scales and this comparison
    favours small lambda_L1. A lambda-independent validation metric would be
    a fairer criterion.

    Args:
        generator, discriminator: Models to train (re-initialized per run).
        dataloader: Training data, forwarded to ``train_model``.
        device: Device to train on.
        num_epochs: Epochs per configuration.
        use_l1: Forwarded to ``train_model``.
        lambda_values: lambda_L1 values to try; defaults to ``lambda_L1_values``.
        lr_values: Learning rates to try; defaults to ``learning_rate_values``.

    Returns:
        Dict ``{'lambda_L1': ..., 'learning_rate': ...}`` for the best run, or
        ``None`` if no run beat an infinite loss (e.g. every run ended in NaN,
        or a grid was empty).
    """
    import copy

    if lambda_values is None:
        lambda_values = lambda_L1_values
    if lr_values is None:
        lr_values = learning_rate_values

    # Snapshot the starting weights so each configuration trains from scratch.
    initial_G = copy.deepcopy(generator.state_dict())
    initial_D = copy.deepcopy(discriminator.state_dict())

    best_loss = float('inf')
    best_params = None

    # Iterate over all combinations of hyperparameters
    for lambda_L1 in lambda_values:
        for learning_rate in lr_values:
            print(f"\nTraining with lambda_L1={lambda_L1}, learning_rate={learning_rate}")
            generator.load_state_dict(initial_G)
            discriminator.load_state_dict(initial_D)

            # train_model creates fresh optimizers and plots this run's loss curves.
            loss_D_list, loss_G_list = train_model(generator, discriminator, dataloader, num_epochs, device, lambda_L1, learning_rate, use_l1)

            final_loss_G = loss_G_list[-1]

            # NaN never compares less-than, so diverged runs are never selected.
            if final_loss_G < best_loss:
                best_loss = final_loss_G
                best_params = {
                    'lambda_L1': lambda_L1,
                    'learning_rate': learning_rate
                }

    if best_params is None:
        print("\nNo configuration produced a usable generator loss.")
    else:
        print(f"\nBest Parameters: lambda_L1={best_params['lambda_L1']}, learning_rate={best_params['learning_rate']}")
    return best_params

In [None]:
# Grid-search lambda_L1 and learning rate for the U-Net generator (GAN + L1 loss).
best_params = hyperparameter_tuning(
    unet_generator,
    discriminator,
    dataloader,
    device=device,
    use_l1=True,
)


In [None]:
# Train the LightResNet generator (GAN loss only) against a fresh discriminator,
# so it does not inherit the weights trained alongside the U-Net above.
# (`patchgan_discriminator` was never defined; `device` falls back to CPU.)
lightresnet_discriminator = PatchGANDiscriminator()
loss_D_resnet, loss_G_resnet = train_model(lightresnet_generator, lightresnet_discriminator, dataloader, num_epochs=200, device=device, use_l1=False)
