In [None]:
#libraries
import os
from glob import glob
from PIL import Image
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision.transforms import ToTensor, Resize, Normalize, Compose
import matplotlib.pyplot as plt
import torchvision.transforms.functional as TF
import torch.nn as nn
import torch
from torchvision.models import vgg19, VGG19_Weights
import tensorflow as tf
from tensorflow.keras import layers, Model
from torch.cuda.amp import autocast, GradScaler
from google.colab import drive
import zipfile
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
import gc







In [None]:
# Extract the CelebA archive from Google Drive into local Colab storage.
# A marker file is written only after a successful extraction, so re-running this
# cell skips the (slow) unzip, while an interrupted extraction is retried.
# Delete the marker file to force re-extraction.

drive.mount('/content/drive')

zip_path = '/content/drive/MyDrive/Wider-Face/train/sharp/CelebA.zip'
extracted_path = '/content/data/CelebA'
extraction_marker = os.path.join(extracted_path, '.extraction_complete')

if os.path.exists(extraction_marker):
    print(f"Already extracted, skipping: {extracted_path}")
else:
    os.makedirs(extracted_path, exist_ok=True)
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(extracted_path)
    # Only reached if extractall finished without raising.
    with open(extraction_marker, 'w') as marker:
        marker.write(zip_path)
    print(f"Extracted files to: {extracted_path}")


In [None]:
import torch

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)


In [None]:


#Data class
class PairedSRDataset(Dataset):
    """Paired low-/high-resolution images for super-resolution training.

    Files in ``root_dir/lr_folder`` and ``root_dir/hr_folder`` are paired by base
    filename (extension ignored); LR files without an HR counterpart are dropped.
    Pairs keep the sorted order of the LR filenames.

    Each item is ``(lr, hr)``: RGB float tensors resized to ``lr_size x lr_size`` and
    ``(lr_size*scale) x (lr_size*scale)``, normalized with mean=std=0.5 per channel,
    i.e. mapped from [0, 1] (``ToTensor``) to [-1, 1].

    NOTE(review): pretrained Real-ESRGAN weights are commonly trained on [0, 1]
    inputs; confirm that the [-1, 1] normalization is intended for fine-tuning.

    Args:
        root_dir: directory containing the LR and HR sub-folders.
        lr_folder: name of the low-resolution sub-folder.
        hr_folder: name of the high-resolution sub-folder.
        lr_size: side length LR images are resized to.
        scale: upscaling factor; HR side length is ``lr_size * scale`` (default 4).
    """

    def __init__(self, root_dir, lr_folder='blur', hr_folder='sharp', lr_size=32, scale=4):
        self.lr_size = lr_size
        self.hr_size = lr_size * scale

        lr_files = sorted(glob(os.path.join(root_dir, lr_folder, '*')))
        hr_files = sorted(glob(os.path.join(root_dir, hr_folder, '*')))
        self.lr_paths, self.hr_paths = self._pair_files(lr_files, hr_files)

        print(f"Found {len(self.lr_paths)} paired images.")

        self.transform_lr = Compose([Resize((self.lr_size, self.lr_size)), ToTensor(), Normalize((0.5,)*3, (0.5,)*3)])
        self.transform_hr = Compose([Resize((self.hr_size, self.hr_size)), ToTensor(), Normalize((0.5,)*3, (0.5,)*3)])

    @staticmethod
    def _pair_files(lr_files, hr_files):
        """Match LR and HR paths by base filename without extension.

        Returns two parallel lists ``(lr_paths, hr_paths)`` in ``lr_files`` order.
        If several HR files share a base name, the last one in ``hr_files`` wins.
        """
        def stem(path):
            return os.path.splitext(os.path.basename(path))[0]

        hr_by_stem = {stem(f): f for f in hr_files}
        lr_paths, hr_paths = [], []
        for lr_path in lr_files:
            key = stem(lr_path)
            if key in hr_by_stem:
                lr_paths.append(lr_path)
                hr_paths.append(hr_by_stem[key])
        return lr_paths, hr_paths

    def __len__(self):
        return len(self.lr_paths)

    def __getitem__(self, idx):
        # Context managers close the underlying file handles even if decoding raises;
        # convert() returns an independent image that outlives the `with` block.
        with Image.open(self.lr_paths[idx]) as lr_file:
            lr = lr_file.convert('RGB')
        with Image.open(self.hr_paths[idx]) as hr_file:
            hr = hr_file.convert('RGB')
        return self.transform_lr(lr), self.transform_hr(hr)



# Build the training subset and its DataLoader.
SUBSET_SIZE = 5000   # cap on training pairs (first N pairs in sorted filename order)
BATCH_SIZE = 8
LOADER_SEED = 42     # fixes the shuffle order so runs are reproducible

train_ds = PairedSRDataset('/content/data/CelebA/CelebA/train')
if len(train_ds) == 0:
    # DataLoader(shuffle=True) on an empty dataset fails with an obscure sampler error.
    raise RuntimeError("No paired LR/HR images found under /content/data/CelebA/CelebA/train")

# Create subset of the first SUBSET_SIZE pairs.
subset_size = min(SUBSET_SIZE, len(train_ds))
train_subset = Subset(train_ds, indices=range(subset_size))

# Dedicated RNG for shuffling (named to avoid clashing with the `generator` model later).
loader_rng = torch.Generator()
loader_rng.manual_seed(LOADER_SEED)

train_loader = DataLoader(
    train_subset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    pin_memory=torch.cuda.is_available(),  # pinning only helps host->GPU copies
    generator=loader_rng,
)

In [None]:
# Sanity-check one training pair: show the LR input next to its HR target.
print(len(train_subset))

lr, hr = train_subset[0]  # tensors normalized to [-1, 1] by the dataset transforms

# Undo Normalize(mean=0.5, std=0.5); clamp guards against float round-off
# pushing values just outside [0, 1] before the uint8 conversion in to_pil_image.
lr_img = TF.to_pil_image((lr * 0.5 + 0.5).clamp(0, 1))
hr_img = TF.to_pil_image((hr * 0.5 + 0.5).clamp(0, 1))

fig, axes = plt.subplots(1, 2, figsize=(12, 6))
for ax, img, title in zip(axes, (lr_img, hr_img), ("Low Resolution", "High Resolution")):
    ax.imshow(img)
    ax.set_title(title)
    ax.axis("off")
plt.show()


In [None]:
#Generator
class ResidualDenseBlock(nn.Module):
    """Five-conv densely connected block with a 0.2-scaled residual connection.

    Conv k (k = 1..4) sees the concatenation of the block input and every earlier
    conv output and emits ``num_grow_ch`` channels (followed by LeakyReLU 0.2);
    conv5 maps the full concatenation back to ``num_feat`` channels. The result is
    scaled by 0.2 and added to the input, so the output shape equals the input shape.

    The attribute names conv1..conv5 must not change: pretrained state dicts are
    loaded against them with ``strict=True``.
    """

    def __init__(self, num_feat=64, num_grow_ch=32):
        super().__init__()
        # Registered in order conv1..conv4, conv5, lrelu (same as the reference layout).
        for k in range(1, 5):
            in_channels = num_feat + (k - 1) * num_grow_ch
            setattr(self, f"conv{k}", nn.Conv2d(in_channels, num_grow_ch, 3, 1, 1))
        self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        collected = [x]
        dense_input = x
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            collected.append(self.lrelu(conv(dense_input)))
            dense_input = torch.cat(collected, 1)
        return self.conv5(dense_input) * 0.2 + x

class RRDB(nn.Module):
    """Residual-in-Residual Dense Block.

    Chains three ResidualDenseBlocks (rdb1 -> rdb2 -> rdb3) and adds the input back
    with the chain's output scaled by 0.2. Attribute names rdb1..rdb3 are part of the
    state-dict key layout and must not change.
    """

    def __init__(self, num_feat=64, num_grow_ch=32):
        super().__init__()
        for k in (1, 2, 3):
            setattr(self, f"rdb{k}", ResidualDenseBlock(num_feat, num_grow_ch))

    def forward(self, x):
        out = x
        for rdb in (self.rdb1, self.rdb2, self.rdb3):
            out = rdb(out)
        return out * 0.2 + x

class RRDBNet(nn.Module):
    """RRDB-based 4x super-resolution generator (ESRGAN / Real-ESRGAN x4 layout).

    Layer names match the Real-ESRGAN state-dict keys so pretrained weights load with
    ``strict=True``. Input ``(B, in_ch, H, W)`` produces ``(B, out_ch, 4H, 4W)`` via two
    nearest-neighbour 2x upsampling stages.

    Args:
        in_ch: input image channels.
        out_ch: output image channels.
        num_feat: feature width of the trunk.
        num_block: number of RRDB blocks in the trunk.
        num_grow_ch: growth channels inside each ResidualDenseBlock.
    """

    def __init__(self, in_ch=3, out_ch=3, num_feat=64, num_block=23, num_grow_ch=32):
        super().__init__()
        self.conv_first = nn.Conv2d(in_ch, num_feat, 3, 1, 1)
        self.body = nn.Sequential(*[RRDB(num_feat, num_grow_ch) for _ in range(num_block)])
        self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_last = nn.Conv2d(num_feat, out_ch, 3, 1, 1)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        feat = self.conv_first(x)
        # Long skip connection around the RRDB trunk.
        body_feat = self.conv_body(self.body(feat))
        feat = feat + body_feat
        # Two nearest-neighbour 2x upsampling stages -> 4x total.
        feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
        feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
        # The reference architecture applies LeakyReLU between conv_hr and conv_last;
        # without it, pretrained conv_last weights see inputs they were never trained on.
        out = self.conv_last(self.lrelu(self.conv_hr(feat)))
        return out

#Discriminator
class VGGStyleDiscriminator(nn.Module):
    def __init__(self, in_ch=3, num_feat=64):
        super().__init__()
        self.conv0_0 = nn.Conv2d(in_ch, num_feat, 3, 1, 1)
        self.conv0_1 = nn.Conv2d(num_feat, num_feat, 4, 2, 1, bias=False)
        self.lrelu = nn.LeakyReLU(0.2, True)

        # Downsampling blocks
        self.conv1_0 = nn.Conv2d(num_feat, num_feat * 2, 3, 1, 1, bias=False)
        self.conv1_1 = nn.Conv2d(num_feat * 2, num_feat * 2, 4, 2, 1, bias=False)

        self.conv2_0 = nn.Conv2d(num_feat * 2, num_feat * 4, 3, 1, 1, bias=False)
        self.conv2_1 = nn.Conv2d(num_feat * 4, num_feat * 4, 4, 2, 1, bias=False)

        self.conv3_0 = nn.Conv2d(num_feat * 4, num_feat * 8, 3, 1, 1, bias=False)
        self.conv3_1 = nn.Conv2d(num_feat * 8, num_feat * 8, 4, 2, 1, bias=False)

        # Final layers
        self.conv4 = nn.Conv2d(num_feat * 8, 1, 3, 1, 1)

    def forward(self, x):
        x = self.lrelu(self.conv0_0(x))
        x = self.lrelu(self.conv0_1(x))
        x = self.lrelu(self.conv1_0(x))
        x = self.lrelu(self.conv1_1(x))
        x = self.lrelu(self.conv2_0(x))
        x = self.lrelu(self.conv2_1(x))
        x = self.lrelu(self.conv3_0(x))
        x = self.lrelu(self.conv3_1(x))
        x = self.conv4(x)
        return x

# Loss Functions, perceptual loss
class PerceptualLoss(nn.Module):
    """L1 distance between frozen VGG19 feature maps of two image batches.

    Uses torchvision's ImageNet-pretrained VGG19 ``features[:35]`` (in torchvision's
    layer indexing this ends at conv5_4, before its ReLU); VGG parameters are frozen.

    Inputs are first rescaled from ``input_range`` to [0, 1], then normalized with the
    ImageNet mean/std the pretrained VGG expects. The default ``(-1, 1)`` matches the
    ``Normalize((0.5,)*3, (0.5,)*3)`` transform used by the dataset in this notebook;
    pass ``input_range=(0, 1)`` for images already in [0, 1].

    Args:
        input_range: ``(low, high)`` value range of images passed to ``forward``.

    Raises:
        ValueError: if ``high <= low``.
    """

    def __init__(self, input_range=(-1.0, 1.0)):
        super().__init__()
        # Using pretrained VGG to extract features.
        vgg = vgg19(weights=VGG19_Weights.IMAGENET1K_V1).features[:35].eval()
        for param in vgg.parameters():
            param.requires_grad = False
        self.vgg = vgg
        self.criterion = nn.L1Loss()
        low, high = input_range
        if high <= low:
            raise ValueError(f"input_range must satisfy low < high, got {input_range}")
        self.input_low = float(low)
        self.input_span = float(high) - float(low)
        # Non-persistent buffers follow module.to(device) but stay out of state_dict,
        # so saved checkpoint keys are unchanged.
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1), persistent=False)
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1), persistent=False)

    def _to_vgg_input(self, x):
        """Map images from ``input_range`` to ImageNet-normalized VGG input."""
        x01 = (x - self.input_low) / self.input_span
        # .to() is a no-op once the module has been moved; it keeps standalone use safe.
        return (x01 - self.mean.to(x.device)) / self.std.to(x.device)

    def forward(self, fake, real):
        """Return the mean absolute difference between VGG features of ``fake`` and ``real``."""
        features_fake = self.vgg(self._to_vgg_input(fake))
        features_real = self.vgg(self._to_vgg_input(real))
        return self.criterion(features_fake, features_real)

# Relativistic average Discriminator Loss
class RelativisticDiscriminatorLoss(nn.Module):
    """Relativistic average GAN loss for the discriminator (ESRGAN).

    Real logits are compared against the mean fake logit (target 1), and fake logits
    against the mean real logit (target 0). Returns the sum of the two
    BCE-with-logits terms.
    """

    def __init__(self):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, pred_real, pred_fake):
        rel_real = pred_real - pred_fake.mean()
        rel_fake = pred_fake - pred_real.mean()
        real_term = self.bce(rel_real, torch.ones_like(rel_real))
        fake_term = self.bce(rel_fake, torch.zeros_like(rel_fake))
        return real_term + fake_term

# Relativistic average Generator Loss
class RelativisticGeneratorLoss(nn.Module):
    """Relativistic average GAN loss for the generator (ESRGAN).

    The mirror image of the discriminator loss: real logits relative to the mean fake
    logit get target 0, and fake logits relative to the mean real logit get target 1.
    Returns the sum of the two BCE-with-logits terms.
    """

    def __init__(self):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, pred_real, pred_fake):
        rel_real = pred_real - pred_fake.mean()
        rel_fake = pred_fake - pred_real.mean()
        real_term = self.bce(rel_real, torch.zeros_like(rel_real))
        fake_term = self.bce(rel_fake, torch.ones_like(rel_fake))
        return real_term + fake_term

class CombinedGeneratorLoss(nn.Module):
    """Weighted sum of adversarial, perceptual (VGG) and pixel-wise L1 generator losses.

    Args:
        w_adv: weight of the relativistic adversarial term.
        w_perceptual: weight of the VGG perceptual term.
        w_l1: weight of the pixel-wise L1 term.
    """

    def __init__(self, w_adv=0.005, w_perceptual=1.0, w_l1=0.01):
        super().__init__()
        self.w_adv, self.w_perceptual, self.w_l1 = w_adv, w_perceptual, w_l1
        self.perceptual = PerceptualLoss()
        self.adversarial = RelativisticGeneratorLoss()
        self.l1 = nn.L1Loss()

    def forward(self, fake_hr, real_hr, pred_fake, pred_real):
        # Note the argument order: the adversarial loss takes (pred_real, pred_fake).
        weighted_adv = self.w_adv * self.adversarial(pred_real, pred_fake)
        weighted_perceptual = self.w_perceptual * self.perceptual(fake_hr, real_hr)
        weighted_pixel = self.w_l1 * self.l1(fake_hr, real_hr)
        return weighted_adv + weighted_perceptual + weighted_pixel


In [None]:

# Training setup: models, pretrained weights, layer freezing, optimizers, losses, AMP scalers.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
use_amp = device.type == "cuda"  # fp16 mixed precision is only used on CUDA

# 1. Initialize models
generator = RRDBNet(in_ch=3, out_ch=3, num_feat=64, num_block=23, num_grow_ch=32).to(device)
discriminator = VGGStyleDiscriminator(in_ch=3, num_feat=64).to(device)

# 2. Load pretrained generator weights.
weights_path = "/content/drive/MyDrive/RealESRGAN_x4.pth"
# map_location lets a checkpoint saved on GPU load on a CPU-only runtime;
# weights_only=True refuses to unpickle arbitrary Python objects from the file.
state_dict = torch.load(weights_path, map_location=device, weights_only=True)

# Real-ESRGAN/BasicSR checkpoints typically nest the weights under 'params_ema' or 'params'.
if 'params_ema' in state_dict:
    clean_state_dict = state_dict['params_ema']
elif 'params' in state_dict:
    clean_state_dict = state_dict['params']
else:
    clean_state_dict = state_dict

generator.load_state_dict(clean_state_dict, strict=True)
print("Loaded")

# 3. Freeze the whole generator, then unfreeze only the layers being fine-tuned.
generator.requires_grad_(False)
for name, param in generator.named_parameters():
    if 'conv_last' in name:
        param.requires_grad = True
        print(f"Unfrozen for training: {name}")

# 4. Optimizers
opt_g = optim.Adam(
    [p for p in generator.parameters() if p.requires_grad],
    lr=1e-5,  # small learning rate for fine-tuning
    betas=(0.9, 0.999)
)
# The discriminator is trained from scratch.
opt_d = optim.Adam(discriminator.parameters(), lr=2e-5, betas=(0.9, 0.999))

# 5. Loss functions
g_loss_fn = CombinedGeneratorLoss().to(device)
d_loss_fn = RelativisticDiscriminatorLoss().to(device)

# 6. AMP GradScalers; with enabled=False (CPU) they pass losses and steps through unchanged.
g_scaler = GradScaler(enabled=use_amp)
d_scaler = GradScaler(enabled=use_amp)



In [None]:

# Training loop: one discriminator update and one generator update per batch.
NUM_EPOCHS = 25
CHECKPOINT_EVERY = 5  # epochs between checkpoint saves

# torch.cuda.amp.autocast takes `enabled` as its first positional argument, so passing
# the device string there is a misuse; torch.autocast takes the device type explicitly.
use_amp = device.type == "cuda"

for epoch in range(NUM_EPOCHS):
    generator.train()
    discriminator.train()

    progress_bar = tqdm(train_loader, total=len(train_loader), desc=f"Epoch {epoch+1}/{NUM_EPOCHS}")

    for lr_imgs, hr_imgs in progress_bar:
        lr_imgs = lr_imgs.to(device, non_blocking=True)
        hr_imgs = hr_imgs.to(device, non_blocking=True)

        # Run the generator once per batch. Its weights do not change during the
        # discriminator step, so the same output serves both updates.
        with torch.autocast(device_type=device.type, enabled=use_amp):
            fake_hr = generator(lr_imgs)

        # --- Train discriminator (fake is detached: no gradient reaches the generator) ---
        opt_d.zero_grad()
        with torch.autocast(device_type=device.type, enabled=use_amp):
            pred_real = discriminator(hr_imgs)
            pred_fake = discriminator(fake_hr.detach())
            d_loss = d_loss_fn(pred_real, pred_fake)

        d_scaler.scale(d_loss).backward()
        d_scaler.step(opt_d)
        d_scaler.update()

        # --- Train generator ---
        # Freeze D so the generator loss does not build up unused gradients in D's weights.
        discriminator.requires_grad_(False)
        opt_g.zero_grad()
        with torch.autocast(device_type=device.type, enabled=use_amp):
            pred_fake = discriminator(fake_hr)
            with torch.no_grad():
                pred_real = discriminator(hr_imgs)
            g_loss = g_loss_fn(fake_hr, hr_imgs, pred_fake, pred_real)

        g_scaler.scale(g_loss).backward()
        g_scaler.step(opt_g)
        g_scaler.update()
        discriminator.requires_grad_(True)

        progress_bar.set_postfix(G_Loss=f'{g_loss.item():.4f}', D_Loss=f'{d_loss.item():.4f}')

    print(f"\n End of Epoch {epoch+1}")

    # Optional: release cached GPU memory between epochs.
    gc.collect()
    if device.type == "cuda":
        torch.cuda.empty_cache()

    # Save checkpoint (including AMP scaler state so training can resume exactly).
    if (epoch + 1) % CHECKPOINT_EVERY == 0:
        torch.save({
            'generator': generator.state_dict(),
            'discriminator': discriminator.state_dict(),
            'opt_g': opt_g.state_dict(),
            'opt_d': opt_d.state_dict(),
            'g_scaler': g_scaler.state_dict(),
            'd_scaler': d_scaler.state_dict(),
            'epoch': epoch,
        }, f"finetuned_realesrgan_epoch_{epoch+1}.pth")
        print(f"Saved checkpoint for epoch {epoch+1}")