In [None]:
import torch
import torch.nn as nn
from torch.nn.utils import spectral_norm

# ------------------------ Residual Block ------------------------ #
class ResBlock(nn.Module):
    """Two 3x3 conv + BatchNorm layers with an identity shortcut.

    Spatial size and channel count are preserved (padding=1, channels -> channels),
    so the shortcut can be added without a projection.
    """

    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return relu(bn2(conv2(relu(bn1(conv1(x))))) + x)."""
        shortcut = x
        h = self.relu(self.bn1(self.conv1(x)))
        h = self.bn2(self.conv2(h))
        # Final activation is applied after the residual sum.
        return self.relu(h + shortcut)

# ------------------------ Generator ------------------------ #
class Generator(nn.Module):
    """Residual super-resolution generator that upscales H and W by 4x.

    Pipeline: 3x3 conv -> ``num_res_blocks`` ResBlocks -> two stages of
    (ConvTranspose2d k=4, s=2, p=1, which doubles H and W, followed by a ResBlock)
    -> 3x3 conv back to ``in_channels``.

    Args:
        in_channels: channels of both the input and the output image.
        num_features: feature width used by every internal layer.
        num_res_blocks: number of ResBlocks in the trunk (default 8, the
            previously hard-coded depth).

    The output is NOT passed through any activation, so its range is unbounded.
    """

    def __init__(self, in_channels=3, num_features=64, num_res_blocks=8):
        super().__init__()
        self.conv1       = nn.Conv2d(in_channels, num_features, 3, padding=1)
        self.res_blocks  = nn.Sequential(*[ResBlock(num_features) for _ in range(num_res_blocks)])

        # first upsample (2x) + sharpen
        self.upconv1     = nn.ConvTranspose2d(num_features, num_features, 4, stride=2, padding=1)
        self.res_up1     = ResBlock(num_features)

        # second upsample (2x) + sharpen
        self.upconv2     = nn.ConvTranspose2d(num_features, num_features, 4, stride=2, padding=1)
        self.res_up2     = ResBlock(num_features)

        # final projection back to image channels
        self.conv_final  = nn.Conv2d(num_features, in_channels, 3, padding=1)
        # NOTE: registered for backward compatibility only; forward() does not apply it.
        # The HR targets produced by DIV2KDataset are ToTensor outputs in [0, 1], which a
        # tanh output range of (-1, 1) would not match.
        self.tanh        = nn.Tanh()

    def forward(self, x):
        """Map an N x C x H x W batch to N x C x 4H x 4W (raw, unclamped values)."""
        # 1) initial conv
        feat = self.conv1(x)

        # 2) deep residual trunk
        out = self.res_blocks(feat)

        # 3) upsample #1 + sharpening
        out = self.res_up1(self.upconv1(out))

        # 4) upsample #2 + sharpening
        out = self.res_up2(self.upconv2(out))

        # 5) back to image channels
        return self.conv_final(out)



# ------------------------ Discriminator ------------------------ #
class Discriminator(nn.Module):
    """Spectrally-normalized PatchGAN discriminator.

    Four (spectral-norm Conv 4x4 -> LeakyReLU(0.2)) stages with strides 2, 2, 2, 1 and
    widths base_features * (1, 2, 4, 8), then a spectral-norm Conv 4x4 (stride 1) down to
    one channel. Each output score has a 70x70 receptive field in the input. For H, W
    divisible by 8 the output map is (H/8 - 2) x (W/8 - 2), e.g. 30x30 for a 256x256 input.
    No sigmoid is applied.
    """

    def __init__(self, in_channels=3, base_features=64):
        super().__init__()
        widths = [in_channels] + [base_features * mult for mult in (1, 2, 4, 8)]
        strides = (2, 2, 2, 1)  # three halvings, then a stride-1 stage (H/8 -> H/8 - 1)

        layers = []
        for c_in, c_out, stride in zip(widths[:-1], widths[1:], strides):
            layers.append(spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=4, stride=stride, padding=1)))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        # Final conv to a 1-channel "realness" map (H/8 - 1 -> H/8 - 2).
        layers.append(spectral_norm(nn.Conv2d(widths[-1], 1, kernel_size=4, stride=1, padding=1)))

        # Flat Sequential: convs sit at indices 0, 2, 4, 6, 8 of self.model, so
        # state_dict keys look like "model.<idx>.weight_orig".
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the N x 1 x H' x W' patch score map for ``x`` (no sigmoid)."""
        return self.model(x)

In [None]:
import torch
from torchsummary import summary 

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize the model. eval() so this shape check does not update BatchNorm running stats.
model = Generator().to(device).eval()

# Dummy input - adjust size based on your expected LR image size
dummy_input = torch.randn(1, 3, 64, 64).to(device)

# Forward pass
with torch.no_grad():
    output = model(dummy_input)

# Print input and output shapes
print(f"Input shape: {dummy_input.shape}")
print(f"Output shape: {output.shape}")

# Two stride-2 ConvTranspose2d stages => H and W are upscaled 4x; channel count is preserved.
expected_shape = (dummy_input.size(0), dummy_input.size(1), 4 * dummy_input.size(2), 4 * dummy_input.size(3))
assert output.shape == expected_shape, f"Unexpected generator output shape: {tuple(output.shape)}"

# Optional: summary of model (input_size should match the LR size used above)
# summary(model, input_size=(3, 64, 64))


In [None]:
import os
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision import transforms
import random

class DIV2KDataset(Dataset):
    """DIV2K HR images paired with bicubic-downsampled LR versions for super-resolution.

    Each item is a dict ``{'lr': Tensor, 'hr': Tensor}`` built from one square crop:
      * ``hr`` -- a ``patch_size * scale_factor`` crop (random for training, center for
        validation) converted with ``ToTensor`` (no further normalization).
      * ``lr`` -- the same crop bicubic-resized to ``patch_size // scale_factor`` pixels,
        then normalized with mean 0.5 / std 0.5 per channel.

    NOTE(review): the HR/LR size ratio is therefore ``scale_factor ** 2`` -- with the
    defaults that is 256 -> 64, i.e. 4x, which is what the Generator's two 2x upsampling
    stages produce. Confirm this is the intended meaning of ``scale_factor``.
    """

    def __init__(self, root_dir="/kaggle/input/div2k-dataset/", train=True, scale_factor=2, patch_size=128):
        """
        Args:
            root_dir (str): Directory containing ``DIV2K_train_HR`` / ``DIV2K_valid_HR``.
            train (bool): If True, use the training split (first 800 files, random crops);
                otherwise the validation split (all files, center crops).
            scale_factor (int): HR crop is ``patch_size * scale_factor``, LR side is
                ``patch_size // scale_factor`` (default 2, giving a 4x HR/LR ratio).
            patch_size (int): Base size the crop sizes are derived from (default 128).
        """
        self.root_dir = root_dir
        self.train = train
        self.scale_factor = scale_factor
        self.patch_size = patch_size

        # Get list of image files
        self.image_files = self._get_image_files()

        hr_size = patch_size * scale_factor
        lr_size = patch_size // scale_factor

        # Random crops augment training; validation uses a deterministic center crop.
        self.hr_crop = transforms.RandomCrop(hr_size) if train else transforms.CenterCrop(hr_size)
        self.hr_transform = transforms.ToTensor()
        self.lr_transform = transforms.Compose([
            transforms.Resize(size=(lr_size, lr_size),
                              interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
        ])

    def _image_dir(self):
        """Directory holding the HR images of the selected split."""
        split = "DIV2K_train_HR" if self.train else "DIV2K_valid_HR"
        return os.path.join(self.root_dir, split)

    def _get_image_files(self):
        """Sorted image file names of the split (first 800 for training, all for validation)."""
        # Extension match is case-insensitive so e.g. "0001.PNG" is not silently skipped.
        files = sorted(
            f for f in os.listdir(self._image_dir())
            if f.lower().endswith(('.png', '.jpg', '.jpeg'))
        )
        return files[:800] if self.train else files

    def __len__(self):
        """Return the total number of images in the dataset"""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return a dictionary containing both LR and HR images"""
        img_path = os.path.join(self._image_dir(), self.image_files[idx])

        # The context manager closes the file; convert() returns an independent copy.
        with Image.open(img_path) as img:
            hr_image = img.convert('RGB')

        hr_cropped = self.hr_crop(hr_image)
        return {
            'lr': self.lr_transform(hr_cropped),  # Low resolution image
            'hr': self.hr_transform(hr_cropped)   # High resolution image
        }

# Example usage: smoke-test the training split and inspect the first LR/HR pair.
if __name__ == "__main__":
    dataset = DIV2KDataset(train=True, root_dir="/kaggle/input/div2k_train_hr/")
    print(f"Dataset size: {len(dataset)}")

    sample = dataset[0]  # dict with 'lr' and 'hr' tensors
    print(f"LR image shape: {sample['lr'].shape}")
    print(f"HR image shape: {sample['hr'].shape}")


In [None]:
import torchvision.models as models
import torch.nn.functional as F
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import os
import kornia
import kornia.losses
import kornia.filters
from torchvision.utils import save_image
import matplotlib.pyplot as plt
from tqdm import tqdm
import os


class PerceptualLoss(nn.Module):
    """Sum of L1 distances between frozen VGG16 feature maps of two image batches.

    Args:
        layers: indices into ``vgg16().features`` whose outputs are compared
            (default ``(3, 8, 15, 22)``; presumably the relu1_2 / relu2_2 / relu3_3 /
            relu4_3 activations in torchvision's layout -- verify if changing weights).

    NOTE(review): inputs are fed to VGG as-is, with no normalization. The
    ``IMAGENET1K_FEATURES`` weights were trained with their own input standardization
    (see torchvision docs), so confirm this is intentional.
    """

    def __init__(self, layers=(3, 8, 15, 22)):
        super().__init__()
        vgg = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_FEATURES).features
        # Tuple: immutable default, and a private copy of whatever the caller passed.
        self.selected_layers = tuple(layers)
        self.vgg = vgg.eval()
        for param in self.vgg.parameters():
            param.requires_grad = False

    def forward(self, x, y):
        """Return sum over selected layers of mean |vgg_i(x) - vgg_i(y)|; 0 if none selected."""
        loss = 0
        deepest = max(self.selected_layers, default=-1)
        for i, layer in enumerate(self.vgg):
            # Layers past the deepest selected one cannot affect the loss; skip them.
            if i > deepest:
                break
            x = layer(x)
            y = layer(y)
            if i in self.selected_layers:
                loss += F.l1_loss(x, y)
        return loss

# Edge Loss
def edge_loss(pred, target):
    """Mean L1 distance between the kornia Sobel edge maps of ``pred`` and ``target``."""
    sobel = kornia.filters.sobel
    return nn.functional.l1_loss(sobel(pred), sobel(target))

def tv_loss(img):
    """Anisotropic total variation: mean |horizontal diff| + mean |vertical diff|.

    Indexes dims 2 and 3 explicitly, so ``img`` must be 4-D (presumably N x C x H x W).
    """
    diff_w = img[:, :, :, 1:] - img[:, :, :, :-1]
    diff_h = img[:, :, 1:, :] - img[:, :, :-1, :]
    return diff_w.abs().mean() + diff_h.abs().mean()

def color_loss(fake, real):
    """Mean L1 distance between per-image, per-channel means (averaged over dims 2 and 3)."""
    channel_means_fake = torch.mean(fake, dim=(2, 3))
    channel_means_real = torch.mean(real, dim=(2, 3))
    return nn.functional.l1_loss(channel_means_fake, channel_means_real)


In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
import torchmetrics

def _unwrap(model):
    """Return the wrapped module when ``model`` is an nn.DataParallel, else ``model`` itself."""
    return model.module if isinstance(model, nn.DataParallel) else model


def _show_samples(gen_imgs, real_imgs, epoch, max_images=4):
    """Display up to ``max_images`` generated (top row) vs. real (bottom row) images."""
    num_images = min(gen_imgs.size(0), max_images)
    gen_images = gen_imgs[:num_images].detach().cpu()
    real_images = real_imgs[:num_images].detach().cpu()

    # squeeze=False keeps `axes` 2-D even when only one image is shown.
    fig, axes = plt.subplots(2, num_images, figsize=(12, 6), squeeze=False)
    for idx in range(num_images):
        for row, images, title in ((0, gen_images, "Generated"), (1, real_images, "Real")):
            ax = axes[row, idx]
            ax.imshow(images[idx].permute(1, 2, 0).clamp(0, 1))
            ax.set_title(title)
            ax.axis("off")

    fig.suptitle(f"Epoch {epoch+1}")
    fig.tight_layout()
    plt.show()
    plt.close(fig)  # avoid accumulating one open figure per epoch


def _plot_history(series, ylabel, title, path):
    """Plot ``(values, label, color)`` curves against epoch index and save the figure to ``path``."""
    fig, ax = plt.subplots(figsize=(10, 6))
    for values, label, color in series:
        ax.plot(values, label=label, color=color)
    ax.set_xlabel("Epochs")
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True)
    fig.tight_layout()
    fig.savefig(path)
    plt.show()
    plt.close(fig)


def train(epochs=5, batch_size=8, lr=0.0002, b1=0.5, b2=0.999,
          root_dir="/kaggle/input/div2k_train_hr/", checkpoint_every=50):
    """Train the Generator/Discriminator pair on DIV2K and plot loss and metric curves.

    Args:
        epochs: number of passes over the training set.
        batch_size: DataLoader batch size.
        lr, b1, b2: Adam learning rate and betas, shared by both optimizers.
        root_dir: DIV2K root passed to DIV2KDataset.
        checkpoint_every: save generator/discriminator weights every this many epochs
            to ``saved_models/``; a falsy value disables checkpointing.

    Side effects: creates ``results/`` and ``saved_models/``, shows one preview figure per
    epoch, and saves ``results/loss_plot.png`` and ``results/metrics_plot.png``.

    Raises:
        ValueError: if the dataset under ``root_dir`` contains no images.
    """
    os.makedirs("results", exist_ok=True)
    os.makedirs("saved_models", exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # --- Create Models and Wrap with DataParallel ---
    generator = Generator().to(device)
    discriminator = Discriminator().to(device)

    if torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs")
        generator = nn.DataParallel(generator)
        discriminator = nn.DataParallel(discriminator)
    else:
        print("Using single GPU or CPU")

    # --- Losses and Optimizers ---
    criterion_GAN = nn.MSELoss()  # least-squares GAN objective on the patch map
    criterion_content = nn.L1Loss()
    criterion_edge = edge_loss
    criterion_perceptual = PerceptualLoss().to(device)

    optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

    # --- Metrics: HR targets are ToTensor outputs, so both use a fixed [0, 1] range ---
    psnr_metric = torchmetrics.PeakSignalNoiseRatio(data_range=1.0).to(device)
    ssim_metric = torchmetrics.StructuralSimilarityIndexMeasure(data_range=1.0).to(device)

    dataset = DIV2KDataset(root_dir=root_dir, train=True)
    if len(dataset) == 0:
        raise ValueError(f"No training images found under {root_dir!r}")
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    num_batches = len(dataloader)

    generator_losses = []
    discriminator_losses = []
    psnr_scores = []
    ssim_scores = []

    for epoch in range(epochs):
        g_loss_epoch = 0.0
        d_loss_epoch = 0.0
        psnr_epoch = 0.0
        ssim_epoch = 0.0

        for batch in tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}"):
            lr_imgs = batch['lr'].to(device)
            hr_imgs = batch['hr'].to(device)

            # --- Train Generator ---
            optimizer_G.zero_grad()
            gen_hr = generator(lr_imgs)
            pred_fake = discriminator(gen_hr)

            # Smoothed targets shaped like the actual patch map, so they follow the input size
            # instead of assuming a fixed 30x30 output.
            valid = torch.empty_like(pred_fake).uniform_(0.9, 1.0)
            fake = torch.empty_like(pred_fake).uniform_(0.0, 0.1)

            # Generator Loss
            loss_GAN = criterion_GAN(pred_fake, valid)
            loss_content = criterion_content(gen_hr, hr_imgs)
            loss_perceptual = criterion_perceptual(gen_hr, hr_imgs)
            loss_edge = criterion_edge(gen_hr, hr_imgs)
            loss_tv = tv_loss(gen_hr)
            loss_color = color_loss(gen_hr, hr_imgs)

            loss_G = (
                0.8 * loss_content +
                5e-3 * loss_GAN +
                0.05 * loss_perceptual +
                0.05 * loss_edge +
                0.001 * loss_tv +
                0.01 * loss_color
            )

            loss_G.backward()
            optimizer_G.step()

            # --- Train Discriminator ---
            # zero_grad also discards the D gradients produced by loss_G.backward() above.
            optimizer_D.zero_grad()
            pred_real = discriminator(hr_imgs)
            loss_real = criterion_GAN(pred_real, valid)

            pred_fake = discriminator(gen_hr.detach())
            loss_fake = criterion_GAN(pred_fake, fake)

            loss_D = (loss_real + loss_fake) / 2
            loss_D.backward()
            optimizer_D.step()

            g_loss_epoch += loss_G.item()
            d_loss_epoch += loss_D.item()

            # --- Calculate Metrics (outside autograd, on the clamped prediction) ---
            with torch.no_grad():
                gen_hr_clamped = gen_hr.detach().clamp(0, 1)
                hr_imgs_clamped = hr_imgs.clamp(0, 1)
                psnr_epoch += psnr_metric(gen_hr_clamped, hr_imgs_clamped).item()
                ssim_epoch += ssim_metric(gen_hr_clamped, hr_imgs_clamped).item()

        # Visualize the last batch of this epoch.
        _show_samples(gen_hr, hr_imgs, epoch)

        # Save Average Loss and Metrics for Epoch
        avg_g_loss = g_loss_epoch / num_batches
        avg_d_loss = d_loss_epoch / num_batches
        avg_psnr = psnr_epoch / num_batches
        avg_ssim = ssim_epoch / num_batches

        generator_losses.append(avg_g_loss)
        discriminator_losses.append(avg_d_loss)
        psnr_scores.append(avg_psnr)
        ssim_scores.append(avg_ssim)

        print(f"[Epoch {epoch+1}/{epochs}] Generator Loss: {avg_g_loss:.4f}, Discriminator Loss: {avg_d_loss:.4f}")
        print(f"[Epoch {epoch+1}/{epochs}] PSNR: {avg_psnr:.2f}, SSIM: {avg_ssim:.4f}")

        # Save model checkpoints; _unwrap works with and without DataParallel.
        if checkpoint_every and (epoch + 1) % checkpoint_every == 0:
            torch.save(_unwrap(generator).state_dict(), f"saved_models/generator_epoch_{epoch+1}.pth")
            torch.save(_unwrap(discriminator).state_dict(), f"saved_models/discriminator_epoch_{epoch+1}.pth")

        torch.cuda.empty_cache()

    # --- Plot Losses ---
    _plot_history(
        [(generator_losses, "Generator Loss", 'blue'),
         (discriminator_losses, "Discriminator Loss", 'red')],
        ylabel="Loss",
        title="Training Losses Over Epochs",
        path="results/loss_plot.png",
    )

    # --- Plot PSNR and SSIM ---
    _plot_history(
        [(psnr_scores, "PSNR", 'green'),
         (ssim_scores, "SSIM", 'purple')],
        ylabel="Score",
        title="Image Quality Metrics Over Epochs",
        path="results/metrics_plot.png",
    )


In [None]:
# Free cached GPU memory left by earlier cells, then launch the full training run.
torch.cuda.empty_cache()
train(batch_size=8, epochs=200)

In [None]:
# Free cached GPU memory once training is done.
torch.cuda.empty_cache()