In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image, make_grid
import matplotlib.pyplot as plt
from PIL import Image
import glob
import random
from tqdm.notebook import tqdm

In [None]:
# Set random seeds for reproducibility: Python, NumPy and PyTorch each keep their own RNG,
# so all three are seeded from a single named constant.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# Use the GPU when CUDA is available; models and batches below are moved to this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Configuration parameters
class Config:
    """Central hyper-parameters and filesystem locations for this notebook.

    Every value is a plain class attribute and is read directly as ``Config.NAME``.
    """

    # --- Dataset locations ---
    BASE_DATA_PATH = "dataset"
    DIV2K_PATH = os.path.join(BASE_DATA_PATH, "div2k")

    # --- Output locations ---
    BASE_OUTPUT_PATH = "output"
    # NOTE(review): these files are written with torch.save (PyTorch state_dicts);
    # the ".h5" suffix is Keras-style and may mislead readers.
    PRETRAINED_GEN_PATH = os.path.join(BASE_OUTPUT_PATH, "pretrained_generator.h5")
    GEN_PATH = os.path.join(BASE_OUTPUT_PATH, "esrgan_generator.h5")

    # --- Architecture ---
    FEATURE_MAPS = 64        # passed as `filters` to both Generator and Discriminator
    RESIDUAL_BLOCKS = 16     # number of RRDB blocks in the generator trunk
    LEAKY_ALPHA = 0.2        # negative slope shared by every LeakyReLU
    DISC_BLOCKS = 4          # NOTE(review): not read by any code shown here
    RESIDUAL_SCALAR = 0.2    # NOTE(review): not read; the blocks use their own res_scale=0.2 default

    # --- Optimisation ---
    PRETRAIN_LR = 1e-4
    FINETUNE_LR = 3e-5
    PRETRAIN_EPOCHS = 100    # reduced for demonstration
    FINETUNE_EPOCHS = 200    # reduced for demonstration
    BATCH_SIZE = 16
    SCALE_FACTOR = 4

# Make sure the directory for checkpoints and sample grids exists (no error if it already does).
os.makedirs(name=Config.BASE_OUTPUT_PATH, exist_ok=True)

# Dataset preparation
class DIV2KDataset(Dataset):
    """Synthesises (LR, HR) training pairs from a folder of high-resolution PNGs.

    Each item loads one HR image, applies ``transform`` to it, and creates the
    LR input by bicubic downsampling by ``1 / scale_factor``.

    Args:
        hr_dir: directory searched (non-recursively) for ``*.png`` files.
        scale_factor: HR/LR size ratio used for the downsampling.
        crop_size: HR patch size, used only when ``transform`` is None (a
            ``RandomCrop(crop_size)`` is applied in that case).
        transform: callable applied to the RGB PIL image. If its result is not
            a tensor, ``ToTensor`` is applied afterwards.
    """

    def __init__(self, hr_dir, scale_factor=4, crop_size=128, transform=None):
        self.hr_dir = hr_dir
        self.scale_factor = scale_factor
        self.crop_size = crop_size
        self.transform = transform
        # Sorted so the index -> file mapping is deterministic.
        self.hr_images = sorted(glob.glob(os.path.join(hr_dir, "*.png")))

    def __len__(self):
        return len(self.hr_images)

    def __getitem__(self, idx):
        """Return ``(lr_img, hr_img)`` tensors for the ``idx``-th HR file."""
        # The context manager releases the file handle once the pixels are decoded.
        with Image.open(self.hr_images[idx]) as img:
            hr_img = img.convert('RGB')

        # Without a user transform, fall back to a random crop so batches share one size
        # (previously transform=None crashed below because hr_img stayed a PIL image).
        augment = self.transform if self.transform is not None else transforms.RandomCrop(self.crop_size)
        hr_img = augment(hr_img)
        if not torch.is_tensor(hr_img):
            hr_img = transforms.ToTensor()(hr_img)

        # Create the LR image through bicubic downsampling. Bicubic kernels can overshoot,
        # so clamp back to the [0, 1] range that ToTensor produces for real images.
        lr_img = F.interpolate(
            hr_img.unsqueeze(0),
            scale_factor=1.0 / self.scale_factor,
            mode='bicubic',
            align_corners=False,
        ).squeeze(0).clamp(0.0, 1.0)

        return lr_img, hr_img

# Data transformations: a random square HR crop of SCALE_FACTOR * 32 pixels (128 with the
# default config), so the LR patch is an integer 32x32; random flips for augmentation; then
# conversion to a float tensor. Changing the crop size also changes the discriminator's
# output map size (four stride-2 convs -> 1/16 of the input).
transform = transforms.Compose(
    [
        transforms.RandomCrop(32 * Config.SCALE_FACTOR),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ToTensor(),
    ]
)

In [None]:

# ESRGAN Model Architecture
class ResidualDenseBlock(nn.Module):
    def __init__(self, filters=64, res_scale=0.2):
        super(ResidualDenseBlock, self).__init__()
        self.res_scale = res_scale
        
        self.conv1 = nn.Conv2d(filters, filters, 3, padding=1, bias=True)
        self.conv2 = nn.Conv2d(filters*2, filters, 3, padding=1, bias=True)
        self.conv3 = nn.Conv2d(filters*3, filters, 3, padding=1, bias=True)
        self.conv4 = nn.Conv2d(filters*4, filters, 3, padding=1, bias=True)
        self.conv5 = nn.Conv2d(filters*5, filters, 3, padding=1, bias=True)
        self.lrelu = nn.LeakyReLU(negative_slope=Config.LEAKY_ALPHA, inplace=True)
    
    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5 * self.res_scale + x

class RRDB(nn.Module):
    """Residual-in-Residual Dense Block.

    Runs three ``ResidualDenseBlock``s in sequence (each with the same
    ``res_scale``), then adds the block input back via an outer residual
    connection scaled by ``res_scale``.
    """

    def __init__(self, filters, res_scale=0.2):
        super(RRDB, self).__init__()
        self.res_scale = res_scale
        for k in (1, 2, 3):
            setattr(self, f"rdb{k}", ResidualDenseBlock(filters, res_scale))

    def forward(self, x):
        out = x
        for block in (self.rdb1, self.rdb2, self.rdb3):
            out = block(out)
        return x + out * self.res_scale

class Generator(nn.Module):
    """ESRGAN-style generator.

    Structure: shallow conv, RRDB trunk closed by a conv and a long skip
    connection, sub-pixel upsampling, output conv.

    Args:
        in_channels: channels of the LR input.
        out_channels: channels of the SR output.
        filters: feature width used throughout the network.
        num_res_blocks: number of RRDB blocks in the trunk.
        upscale_factor: total spatial magnification. Must be a positive power of
            two. It is realised as ``log2(upscale_factor)`` stages of
            Conv -> PixelShuffle(2) -> LeakyReLU, so the default 4 gives two stages.
            (Previously this argument was ignored and the model always upscaled 4x.)

    Raises:
        ValueError: if ``upscale_factor`` is not a positive power of two.
    """

    def __init__(self, in_channels=3, out_channels=3, filters=64, num_res_blocks=16, upscale_factor=4):
        super(Generator, self).__init__()

        # The upsampler can only double resolution per stage, so reject other factors early.
        factor = int(upscale_factor)
        if factor != upscale_factor or factor < 1 or factor & (factor - 1):
            raise ValueError(f"upscale_factor must be a positive power of two, got {upscale_factor!r}")
        num_upsample_stages = factor.bit_length() - 1  # log2(factor)

        # First conv layer
        self.conv_first = nn.Conv2d(in_channels, filters, kernel_size=3, stride=1, padding=1)

        # RRDB trunk
        self.rrdb_blocks = nn.Sequential(*[RRDB(filters) for _ in range(num_res_blocks)])

        # Conv closing the trunk before the long skip connection
        self.conv_after_blocks = nn.Conv2d(filters, filters, kernel_size=3, stride=1, padding=1)

        # Each stage doubles H and W: the conv expands channels 4x and PixelShuffle(2)
        # trades those channels for a 2x2 spatial block.
        upsampling = []
        for _ in range(num_upsample_stages):
            upsampling.extend([
                nn.Conv2d(filters, filters * 4, kernel_size=3, stride=1, padding=1),
                nn.PixelShuffle(2),
                nn.LeakyReLU(Config.LEAKY_ALPHA, inplace=True)
            ])
        self.upsampling = nn.Sequential(*upsampling)

        # Final output layer (no activation, so the output range is not constrained)
        self.conv_last = nn.Conv2d(filters, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        # Shallow features, kept for the long skip connection
        feat = self.conv_first(x)

        # Deep features from the RRDB trunk
        trunk = self.conv_after_blocks(self.rrdb_blocks(feat))
        feat = feat + trunk

        # Upsample to the target resolution and project to image channels
        feat = self.upsampling(feat)
        return self.conv_last(feat)

class Discriminator(nn.Module):
    def __init__(self, input_shape=(3, 128, 128), filters=64):
        super(Discriminator, self).__init__()
        
        def discriminator_block(in_filters, out_filters, stride=1, normalize=True):
            layers = [nn.Conv2d(in_filters, out_filters, 3, stride, 1)]
            if normalize:
                layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(Config.LEAKY_ALPHA, inplace=True))
            return layers
        
        self.model = nn.Sequential(
            *discriminator_block(input_shape[0], filters, normalize=False),
            *discriminator_block(filters, filters, stride=2),
            *discriminator_block(filters, filters * 2),
            *discriminator_block(filters * 2, filters * 2, stride=2),
            *discriminator_block(filters * 2, filters * 4),
            *discriminator_block(filters * 4, filters * 4, stride=2),
            *discriminator_block(filters * 4, filters * 8),
            *discriminator_block(filters * 8, filters * 8, stride=2),
            nn.Conv2d(filters * 8, 1, 3, stride=1, padding=1)
        )
    
    def forward(self, img):
        return self.model(img)

In [None]:

# VGG19 Feature Extractor for Perceptual Loss
class VGGFeatureExtractor(nn.Module):
    """Frozen ImageNet-pretrained VGG19 truncated to its first ``feature_layer`` layers.

    In torchvision's VGG19 layout, index 34 is conv5_4, so the default
    ``feature_layer=35`` returns conv5_4 activations *before* their ReLU.

    Inputs are expected in [0, 1] (as produced by ``ToTensor``). They are
    normalised with the ImageNet mean/std the weights were trained on; without
    this step the features, and hence the perceptual loss, are distorted.

    Args:
        feature_layer: number of leading ``vgg.features`` layers to keep.
        use_bn: accepted for API compatibility but currently ignored; the
            non-batch-norm VGG19 is always used.
    """

    def __init__(self, feature_layer=35, use_bn=False):
        super(VGGFeatureExtractor, self).__init__()
        from torchvision.models import vgg19
        try:
            # torchvision >= 0.13: explicit weights enum (``pretrained=`` is deprecated).
            from torchvision.models import VGG19_Weights
            vgg = vgg19(weights=VGG19_Weights.IMAGENET1K_V1)
        except ImportError:
            vgg = vgg19(pretrained=True)
        self.features = nn.Sequential(*list(vgg.features.children())[:feature_layer]).eval()
        for param in self.features.parameters():
            param.requires_grad = False
        # ImageNet statistics as buffers so they follow .to(device). Non-persistent keeps
        # the state_dict layout unchanged.
        self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1), persistent=False)
        self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1), persistent=False)

    def forward(self, x):
        return self.features((x - self.mean) / self.std)

# Loss Functions
class ContentLoss(nn.Module):
    """Pixel-wise mean absolute error (L1) between super-resolved and target images."""

    def __init__(self):
        super().__init__()
        self.l1 = nn.L1Loss()

    def forward(self, sr, hr):
        pixel_loss = self.l1(sr, hr)
        return pixel_loss

class PerceptualLoss(nn.Module):
    """L1 distance between VGG feature maps of the super-resolved and target images.

    The frozen feature extractor is built and moved to the global ``device``
    when this module is constructed.
    """

    def __init__(self):
        super().__init__()
        extractor = VGGFeatureExtractor()
        self.feature_extractor = extractor.to(device)
        self.l1 = nn.L1Loss()

    def forward(self, sr, hr):
        extract = self.feature_extractor
        return self.l1(extract(sr), extract(hr))

class AdversarialLoss(nn.Module):
    """Binary cross-entropy on raw discriminator logits (sigmoid applied internally)."""

    def __init__(self):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()

    def forward(self, pred, target):
        bce_value = self.bce(pred, target)
        return bce_value

# Training Functions
def pretrain_generator(generator, dataloader, epochs, optimizer, criterion, device):
    """Pre-train the generator with a single reconstruction criterion (no adversarial term).

    Every 10 epochs the first 4 LR inputs of a fresh batch and their SR outputs
    are saved as 2x2 grids in ``Config.BASE_OUTPUT_PATH``. After the last epoch
    the generator's state_dict is saved to ``Config.PRETRAINED_GEN_PATH``.

    Args:
        generator: model mapping LR batches to SR batches.
        dataloader: iterable yielding ``(lr, hr)`` tensor batches.
        epochs: number of passes over ``dataloader``.
        optimizer: optimizer over the generator's parameters.
        criterion: loss called as ``criterion(sr, hr)``.
        device: device each batch is moved to.
    """
    generator.train()

    for epoch in range(epochs):
        epoch_loss = 0.0
        progress_bar = tqdm(dataloader)

        for i, (lr, hr) in enumerate(progress_bar):
            lr = lr.to(device)
            hr = hr.to(device)

            optimizer.zero_grad()
            sr = generator(lr)
            loss = criterion(sr, hr)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            progress_bar.set_description(f"Epoch {epoch+1}/{epochs} | Loss: {epoch_loss/(i+1):.6f}")

        # Save sample images (skipped if the loader yields nothing, instead of raising StopIteration)
        if (epoch + 1) % 10 == 0:
            sample_batch = next(iter(dataloader), None)
            if sample_batch is not None:
                generator.eval()
                with torch.no_grad():
                    sample_lr = sample_batch[0][:4].to(device)
                    sample_sr = generator(sample_lr)
                generator.train()

                # Clip to the displayable [0, 1] range
                sample_lr = torch.clamp(sample_lr, 0, 1)
                sample_sr = torch.clamp(sample_sr, 0, 1)

                # Values are already in [0, 1]; normalize=True would min-max stretch each
                # grid and misrepresent the model's actual brightness/contrast.
                grid_lr = make_grid(sample_lr, nrow=2)
                grid_sr = make_grid(sample_sr, nrow=2)

                save_image(grid_lr, os.path.join(Config.BASE_OUTPUT_PATH, f"pretrain_lr_epoch_{epoch+1}.png"))
                save_image(grid_sr, os.path.join(Config.BASE_OUTPUT_PATH, f"pretrain_sr_epoch_{epoch+1}.png"))

    # Save the pretrained generator
    torch.save(generator.state_dict(), Config.PRETRAINED_GEN_PATH)
    print(f"Pretrained generator saved to {Config.PRETRAINED_GEN_PATH}")

In [None]:
def train_esrgan(generator, discriminator, dataloader, epochs, g_optimizer, d_optimizer,
                content_criterion, perceptual_criterion, adversarial_criterion, device,
                adv_weight=0.01, content_weight=1.0, percep_weight=1.0):
    """Adversarial fine-tuning of the generator against the discriminator.

    Per batch, the generator is updated with
    ``adv_weight * adversarial + content_weight * content + percep_weight * perceptual``.
    The discriminator is then updated on real HR images versus detached SR outputs.
    Every 10 epochs LR/SR sample grids are saved to ``Config.BASE_OUTPUT_PATH``. At the
    end, the generator's state_dict is saved to ``Config.GEN_PATH``.

    Args:
        generator, discriminator: models being trained.
        dataloader: iterable yielding ``(lr, hr)`` tensor batches.
        epochs: number of passes over ``dataloader``.
        g_optimizer, d_optimizer: optimizers for the generator / discriminator.
        content_criterion, perceptual_criterion: called as ``criterion(sr, hr)``.
        adversarial_criterion: called as ``criterion(logits, target)``.
        device: device each batch is moved to.
        adv_weight, content_weight, percep_weight: generator loss weights
            (defaults reproduce the previous hard-coded 0.01 / 1.0 / 1.0).
    """
    generator.train()
    discriminator.train()

    for epoch in range(epochs):
        epoch_g_loss = 0.0
        epoch_d_loss = 0.0
        progress_bar = tqdm(dataloader)

        for i, (lr, hr) in enumerate(progress_bar):
            # Move data to device
            lr = lr.to(device)
            hr = hr.to(device)

            # ------------------
            #  Train Generator
            # ------------------
            g_optimizer.zero_grad()

            sr = generator(lr)

            # Real/fake targets are shaped like the discriminator output, so any crop size
            # works (previously they were fixed at 8x8, i.e. only 128x128 HR crops).
            pred_fake = discriminator(sr)
            g_adv_loss = adversarial_criterion(pred_fake, torch.ones_like(pred_fake))
            g_content_loss = content_criterion(sr, hr)
            g_percep_loss = perceptual_criterion(sr, hr)

            g_loss = adv_weight * g_adv_loss + content_weight * g_content_loss + percep_weight * g_percep_loss

            g_loss.backward()
            g_optimizer.step()

            # ----------------------
            #  Train Discriminator
            # ----------------------
            # zero_grad also discards gradients the generator step left on D's parameters.
            d_optimizer.zero_grad()

            pred_real = discriminator(hr)
            d_real_loss = adversarial_criterion(pred_real, torch.ones_like(pred_real))

            # Detach so the discriminator update does not backpropagate into the generator.
            pred_fake = discriminator(sr.detach())
            d_fake_loss = adversarial_criterion(pred_fake, torch.zeros_like(pred_fake))

            d_loss = (d_real_loss + d_fake_loss) / 2

            d_loss.backward()
            d_optimizer.step()

            # Update progress bar
            epoch_g_loss += g_loss.item()
            epoch_d_loss += d_loss.item()
            progress_bar.set_description(
                f"Epoch {epoch+1}/{epochs} | G Loss: {epoch_g_loss/(i+1):.6f} | D Loss: {epoch_d_loss/(i+1):.6f}"
            )

        # Save sample images (skipped if the loader yields nothing)
        if (epoch + 1) % 10 == 0:
            sample_batch = next(iter(dataloader), None)
            if sample_batch is not None:
                generator.eval()
                with torch.no_grad():
                    sample_lr = sample_batch[0][:4].to(device)
                    sample_sr = generator(sample_lr)
                generator.train()

                # Clip to the displayable [0, 1] range
                sample_lr = torch.clamp(sample_lr, 0, 1)
                sample_sr = torch.clamp(sample_sr, 0, 1)

                # No min-max normalize: it would stretch each grid and hide real intensity errors.
                grid_lr = make_grid(sample_lr, nrow=2)
                grid_sr = make_grid(sample_sr, nrow=2)

                save_image(grid_lr, os.path.join(Config.BASE_OUTPUT_PATH, f"train_lr_epoch_{epoch+1}.png"))
                save_image(grid_sr, os.path.join(Config.BASE_OUTPUT_PATH, f"train_sr_epoch_{epoch+1}.png"))

    # Save the trained generator
    torch.save(generator.state_dict(), Config.GEN_PATH)
    print(f"ESRGAN generator saved to {Config.GEN_PATH}")

# Inference function
def inference(generator, lr_img_path, output_path):
    """Super-resolve one image file with ``generator`` and save the result.

    The input is converted to RGB and to a tensor with ``ToTensor``. The output
    is clamped to [0, 1] and written to ``output_path``. ``generator`` is left in
    eval mode.

    Returns:
        The super-resolved image as produced by ``ToPILImage``.
    """
    generator.eval()

    # Feed the input on the device holding the generator's weights; fall back to the
    # global `device` only for a parameter-less model.
    first_param = next(generator.parameters(), None)
    model_device = first_param.device if first_param is not None else device

    # Load and preprocess the low-resolution image (context manager closes the file)
    with Image.open(lr_img_path) as img:
        lr_tensor = transforms.ToTensor()(img.convert('RGB')).unsqueeze(0).to(model_device)

    with torch.no_grad():
        sr_tensor = torch.clamp(generator(lr_tensor), 0, 1)

    # Convert tensor to image and save
    sr_img = transforms.ToPILImage()(sr_tensor.squeeze(0).cpu())
    sr_img.save(output_path)

    return sr_img

In [None]:
# Main execution
if __name__ == "__main__":
    import multiprocessing as mp

    # Create dataset and dataloader
    dataset = DIV2KDataset(Config.DIV2K_PATH, scale_factor=Config.SCALE_FACTOR, transform=transform)
    if len(dataset) == 0:
        # Fail with an actionable message instead of DataLoader's generic num_samples=0 error.
        raise FileNotFoundError(f"No *.png images found in {Config.DIV2K_PATH!r}")

    # Workers started with "spawn"/"forkserver" (Windows, macOS, newer Linux Pythons) must
    # unpickle DIV2KDataset by importing __main__, which typically fails for classes defined
    # in a notebook kernel. Only use worker processes when they are forked.
    start_method = mp.get_start_method(allow_none=True) or mp.get_all_start_methods()[0]
    num_workers = 4 if start_method == "fork" else 0
    dataloader = DataLoader(
        dataset,
        batch_size=Config.BATCH_SIZE,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=(device.type == "cuda"),
    )

    # Initialize models
    generator = Generator(
        filters=Config.FEATURE_MAPS,
        num_res_blocks=Config.RESIDUAL_BLOCKS,
        upscale_factor=Config.SCALE_FACTOR
    ).to(device)

    discriminator = Discriminator(
        filters=Config.FEATURE_MAPS
    ).to(device)

    # Initialize loss functions
    content_criterion = ContentLoss().to(device)
    perceptual_criterion = PerceptualLoss().to(device)
    adversarial_criterion = AdversarialLoss().to(device)

    # Pretrain generator
    print("Pretraining generator...")
    pretrain_optimizer = optim.Adam(generator.parameters(), lr=Config.PRETRAIN_LR)
    pretrain_generator(
        generator=generator,
        dataloader=dataloader,
        epochs=Config.PRETRAIN_EPOCHS,
        optimizer=pretrain_optimizer,
        criterion=content_criterion,
        device=device
    )

    # Load pretrained generator; map_location keeps tensors on the active device regardless
    # of where the checkpoint was written.
    generator.load_state_dict(torch.load(Config.PRETRAINED_GEN_PATH, map_location=device))

    # Train full ESRGAN
    print("Training ESRGAN...")
    g_optimizer = optim.Adam(generator.parameters(), lr=Config.FINETUNE_LR)
    d_optimizer = optim.Adam(discriminator.parameters(), lr=Config.FINETUNE_LR)

    train_esrgan(
        generator=generator,
        discriminator=discriminator,
        dataloader=dataloader,
        epochs=Config.FINETUNE_EPOCHS,
        g_optimizer=g_optimizer,
        d_optimizer=d_optimizer,
        content_criterion=content_criterion,
        perceptual_criterion=perceptual_criterion,
        adversarial_criterion=adversarial_criterion,
        device=device
    )

    # Inference example
    print("Running inference on a sample image...")
    sample_lr_path = "sample_lr.png"  # Replace with your sample image path
    sample_sr_path = os.path.join(Config.BASE_OUTPUT_PATH, "sample_sr.png")

    if os.path.exists(sample_lr_path):
        sr_img = inference(generator, sample_lr_path, sample_sr_path)

        # Display results side by side
        with Image.open(sample_lr_path) as raw_lr:
            lr_img = raw_lr.convert('RGB')

        fig, (ax_lr, ax_sr) = plt.subplots(1, 2, figsize=(12, 6))
        ax_lr.set_title("Low Resolution")
        ax_lr.imshow(lr_img)
        ax_lr.axis('off')

        ax_sr.set_title("Super Resolution (ESRGAN)")
        ax_sr.imshow(sr_img)
        ax_sr.axis('off')

        fig.tight_layout()
        plt.show()
    else:
        print(f"Sample image not found at {sample_lr_path}")
