In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Lazily build the full path of every file below the input directory, then echo each one
input_file_paths = (
    os.path.join(folder, name)
    for folder, _, names in os.walk('/kaggle/input')
    for name in names
)
for file_path in input_file_paths:
    print(file_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/celebdata/img_align_celeba/000304.jpg
/kaggle/input/celebdata/img_align_celeba/001940.jpg
/kaggle/input/celebdata/img_align_celeba/001229.jpg
/kaggle/input/celebdata/img_align_celeba/000875.jpg
/kaggle/input/celebdata/img_align_celeba/001989.jpg
/kaggle/input/celebdata/img_align_celeba/001071.jpg
/kaggle/input/celebdata/img_align_celeba/000780.jpg
/kaggle/input/celebdata/img_align_celeba/001506.jpg
/kaggle/input/celebdata/img_align_celeba/003113.jpg
/kaggle/input/celebdata/img_align_celeba/003224.jpg
/kaggle/input/celebdata/img_align_celeba/002056.jpg
/kaggle/input/celebdata/img_align_celeba/004100.jpg
/kaggle/input/celebdata/img_align_celeba/000705.jpg
/kaggle/input/celebdata/img_align_celeba/000748.jpg
/kaggle/input/celebdata/img_align_celeba/000900.jpg
/kaggle/input/celebdata/img_align_celeba/002793.jpg
/kaggle/input/celebdata/img_align_celeba/000651.jpg
/kaggle/input/celebdata/img_align_celeba/000857.jpg
/kaggle/input/celebdata/img_align_celeba/000853.jpg
/kaggle/inpu

In [2]:
import glob
import random
import os
import numpy as np

import torch
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms

# Per-channel normalization parameters for pre-trained PyTorch models.
# They are passed to transforms.Normalize in ImageDataset and inverted by denormalize().
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])


def denormalize(tensors):
    """Undo the per-channel mean/std normalization on a batch of image tensors.

    Args:
        tensors: Tensor indexed as (batch, channel, ...) with at least three
            channels. Only channels 0-2 are rescaled. The argument itself is
            left unmodified.

    Returns:
        A new tensor in which channels 0-2 are mapped back with
        ``x * std + mean`` and every value is clamped to [0, 1].
    """
    # Work on a copy: the in-place ops used to overwrite the caller's tensor
    restored = tensors.clone()
    for c in range(3):
        restored[:, c].mul_(float(std[c])).add_(float(mean[c]))
    # Images are scaled to [0, 1] by ToTensor() before Normalize, so [0, 1] is the
    # valid range once the normalization has been undone (not [0, 255]).
    return torch.clamp(restored, 0, 1)


class ImageDataset(Dataset):
    """Dataset yielding a normalized low-res / high-res tensor pair per image file.

    Every path matching ``root/*.*`` is treated as an image. Each item is a dict
    with the keys ``"lr"`` and ``"hr"``: the same source image bicubically
    resized to a quarter of ``hr_shape`` and to ``hr_shape`` respectively, then
    converted with ``ToTensor`` and normalized with the module-level
    ``mean`` / ``std``.
    """

    def __init__(self, root, hr_shape):
        """
        Args:
            root: Directory that contains the image files.
            hr_shape: ``(height, width)`` of the high-resolution target.
        """
        hr_height, hr_width = hr_shape
        # Transforms for low resolution images and high resolution images.
        # The second size entry must be the width: using hr_height for both
        # sides ignored hr_width and always produced square images.
        self.lr_transform = transforms.Compose(
            [
                transforms.Resize((hr_height // 4, hr_width // 4), Image.BICUBIC),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ]
        )
        self.hr_transform = transforms.Compose(
            [
                transforms.Resize((hr_height, hr_width), Image.BICUBIC),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ]
        )

        # Sorted so that an index always refers to the same file
        self.files = sorted(glob.glob(root + "/*.*"))

    def __getitem__(self, index):
        """Load image ``index`` (wrapping around) and return ``{"lr": ..., "hr": ...}``."""
        path = self.files[index % len(self.files)]
        # Normalize expects exactly three channels, so force RGB; grayscale,
        # palette or RGBA files would otherwise fail inside the transform.
        with Image.open(path) as source:
            img = source.convert("RGB")
        img_lr = self.lr_transform(img)
        img_hr = self.hr_transform(img)

        return {"lr": img_lr, "hr": img_hr}

    def __len__(self):
        """Number of image files found under ``root``."""
        return len(self.files)


In [3]:
import torch.nn as nn
import torch.nn.functional as F
import torch
from torchvision.models import vgg19
import math


class FeatureExtractor(nn.Module):
    """Frozen, truncated pre-trained VGG19 used to compare images in feature space.

    Keeps the first 35 modules of ``vgg19(pretrained=True).features``.
    """

    def __init__(self):
        super(FeatureExtractor, self).__init__()
        vgg19_model = vgg19(pretrained=True)
        self.vgg19_54 = nn.Sequential(*list(vgg19_model.features.children())[:35])
        # This network is only used as a fixed feature map for the content loss.
        # Left trainable, every backward pass would also compute and accumulate
        # gradients for all of its weights. Freezing them avoids that work;
        # gradients still flow through to the input image.
        for param in self.parameters():
            param.requires_grad_(False)

    def forward(self, img):
        """Return the truncated-VGG19 feature maps for ``img``."""
        return self.vgg19_54(img)


class DenseResidualBlock(nn.Module):
    """
    Five densely connected 3x3 convolutions closed by a scaled residual connection.
    Core module of the paper "Residual Dense Network for Image Super-Resolution" (CVPR 18).
    """

    def __init__(self, filters, res_scale=0.2):
        super().__init__()
        self.res_scale = res_scale

        def make_stage(width_multiple, activate=True):
            # A stage reads the block input concatenated with all earlier stage
            # outputs, hence `width_multiple * filters` input channels.
            conv = nn.Conv2d(width_multiple * filters, filters, 3, 1, 1, bias=True)
            if activate:
                return nn.Sequential(conv, nn.LeakyReLU())
            return nn.Sequential(conv)

        self.b1 = make_stage(1)
        self.b2 = make_stage(2)
        self.b3 = make_stage(3)
        self.b4 = make_stage(4)
        # The last stage has no activation
        self.b5 = make_stage(5, activate=False)
        self.blocks = [self.b1, self.b2, self.b3, self.b4, self.b5]

    def forward(self, x):
        """Run the dense stages and return ``last_stage_output * res_scale + x``."""
        # dense_inputs holds x followed by every stage output produced so far
        dense_inputs = [x]
        for stage in self.blocks:
            dense_inputs.append(stage(torch.cat(dense_inputs, 1)))
        return dense_inputs[-1].mul(self.res_scale) + x


class ResidualInResidualDenseBlock(nn.Module):
    """Three chained DenseResidualBlock modules wrapped in an outer scaled skip connection."""

    def __init__(self, filters, res_scale=0.2):
        super().__init__()
        self.res_scale = res_scale
        # The inner blocks are built with their own default res_scale
        self.dense_blocks = nn.Sequential(*(DenseResidualBlock(filters) for _ in range(3)))

    def forward(self, x):
        """Return ``dense_blocks(x) * res_scale + x``."""
        refined = self.dense_blocks(x)
        return refined.mul(self.res_scale) + x


class GeneratorRRDB(nn.Module):
    """Super-resolution generator: conv stem, RRDB trunk, pixel-shuffle upsampling, conv head.

    Each of the ``num_upsample`` stages doubles the spatial resolution, so the
    output is ``2 ** num_upsample`` times larger than the input per side and has
    ``channels`` channels.
    """

    def __init__(self, channels, filters=64, num_res_blocks=16, num_upsample=2):
        super().__init__()
        # Every convolution in this network is 3x3, stride 1, padding 1
        conv3x3 = dict(kernel_size=3, stride=1, padding=1)

        # Stem
        self.conv1 = nn.Conv2d(channels, filters, **conv3x3)
        # Trunk of residual-in-residual dense blocks
        self.res_blocks = nn.Sequential(
            *(ResidualInResidualDenseBlock(filters) for _ in range(num_res_blocks))
        )
        # Convolution applied to the trunk output
        self.conv2 = nn.Conv2d(filters, filters, **conv3x3)
        # Upsampling: expand to 4x the channels, then PixelShuffle trades them for 2x resolution
        stages = []
        for _ in range(num_upsample):
            stages.extend(
                (
                    nn.Conv2d(filters, filters * 4, **conv3x3),
                    nn.LeakyReLU(),
                    nn.PixelShuffle(upscale_factor=2),
                )
            )
        self.upsampling = nn.Sequential(*stages)
        # Head mapping features back to image channels
        self.conv3 = nn.Sequential(
            nn.Conv2d(filters, filters, **conv3x3),
            nn.LeakyReLU(),
            nn.Conv2d(filters, channels, **conv3x3),
        )

    def forward(self, x):
        """Upscale ``x``; the stem output is added back to the trunk output before upsampling."""
        stem = self.conv1(x)
        trunk = self.conv2(self.res_blocks(stem))
        return self.conv3(self.upsampling(torch.add(stem, trunk)))


class Discriminator(nn.Module):
    """Convolutional patch discriminator that scores each spatial patch of an image.

    Four blocks (64, 128, 256 and 512 filters) each halve the resolution with a
    stride-2 convolution; a final 3x3 convolution reduces to one channel without
    any activation.

    Attributes:
        input_shape: ``(channels, height, width)`` as passed to the constructor.
        output_shape: ``(1, patch_h, patch_w)``, the per-sample shape returned by
            ``forward`` for inputs of ``input_shape``.
    """

    def __init__(self, input_shape):
        super(Discriminator, self).__init__()

        self.input_shape = input_shape
        in_channels, in_height, in_width = self.input_shape
        block_filters = [64, 128, 256, 512]

        def halved(size):
            # A 3x3 / stride 2 / padding 1 convolution maps `size` to
            # floor((size - 1) / 2) + 1, i.e. ceil(size / 2).
            return (int(size) + 1) // 2

        # One halving per block. Plain `int(size / 2 ** 4)` rounded down and so
        # disagreed with the real output whenever a side was not a multiple of 16.
        patch_h, patch_w = in_height, in_width
        for _ in block_filters:
            patch_h, patch_w = halved(patch_h), halved(patch_w)
        self.output_shape = (1, patch_h, patch_w)

        def discriminator_block(in_filters, out_filters, first_block=False):
            """Conv (stride 1) -> [BatchNorm] -> LeakyReLU -> Conv (stride 2) -> BatchNorm -> LeakyReLU."""
            layers = []
            layers.append(nn.Conv2d(in_filters, out_filters, kernel_size=3, stride=1, padding=1))
            # The very first convolution is not followed by batch norm
            if not first_block:
                layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            layers.append(nn.Conv2d(out_filters, out_filters, kernel_size=3, stride=2, padding=1))
            layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        layers = []
        in_filters = in_channels
        for i, out_filters in enumerate(block_filters):
            layers.extend(discriminator_block(in_filters, out_filters, first_block=(i == 0)))
            in_filters = out_filters

        # in_filters now holds the width of the last block
        layers.append(nn.Conv2d(in_filters, 1, kernel_size=3, stride=1, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Return the patch score map for ``img``."""
        return self.model(img)


In [4]:
# Importing necessary libraries and modules
import os
import numpy as np
import itertools

import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid

from torch.utils.data import DataLoader
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

# Helper function to denormalize images before saving
def denormalize(img):
    """Map image tensors normalized by ImageDataset back to the [0, 1] range.

    Args:
        img: Tensor shaped (C, H, W) or (N, C, H, W) whose channel count matches
            the module-level ``mean`` / ``std`` arrays.

    Returns:
        A new tensor ``img * std + mean`` (per channel), clamped to [0, 1].
    """
    # ImageDataset normalizes with the per-channel `mean` / `std` arrays, so the
    # inverse has to use the same statistics. A fixed 0.5 / 0.5 inverse belongs
    # to a different normalization and distorts the colours of saved samples.
    channel_std = torch.as_tensor(std, dtype=img.dtype, device=img.device).view(-1, 1, 1)
    channel_mean = torch.as_tensor(mean, dtype=img.dtype, device=img.device).view(-1, 1, 1)
    return img.mul(channel_std).add(channel_mean).clamp(0, 1)

# GeneratorRRDB, Discriminator and FeatureExtractor are defined in the previous cell (In [3]),
# and ImageDataset in In [2]; run those cells before this one.

# Training configuration; every value is looked up as opt["<name>"]
opt = dict(
    # Epoch range: training runs for epochs epoch .. n_epochs - 1
    epoch=0,
    n_epochs=5,
    # Data source and batching
    dataset_name="/kaggle/input/celebdata/img_align_celeba",
    batch_size=4,
    # Adam learning rate and betas
    lr=0.0002,
    b1=0.9,
    b2=0.999,
    # Not read anywhere in the visible code
    decay_epoch=100,
    # DataLoader worker processes
    n_cpu=8,
    # High-resolution image size and channel count
    hr_height=256,
    hr_width=256,
    channels=3,
    # Save a sample grid / a checkpoint every this many batches
    sample_interval=100,
    checkpoint_interval=5000,
    # Number of residual blocks in the generator
    residual_blocks=23,
    # Number of initial batches trained with the pixel loss only
    warmup_batches=500,
    # Weights of the adversarial and pixel terms in the generator loss
    lambda_adv=5e-3,
    lambda_pixel=1e-2,
)

# Output folders for the sample grids and the model checkpoints
for out_dir in ("/kaggle/working/images/training", "/kaggle/working/saved_models"):
    os.makedirs(out_dir, exist_ok=True)

# Use the first GPU when one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Networks: generator, discriminator and the feature extractor for the content loss
hr_shape = (opt["hr_height"], opt["hr_width"])
generator = GeneratorRRDB(opt["channels"], filters=64, num_res_blocks=opt["residual_blocks"]).to(device)
discriminator = Discriminator(input_shape=(opt["channels"], *hr_shape)).to(device)
feature_extractor = FeatureExtractor().to(device)

# The feature extractor is only used for inference
feature_extractor.eval()

# Losses: adversarial (on raw discriminator scores), content (feature space), pixel
criterion_GAN = torch.nn.BCEWithLogitsLoss().to(device)
criterion_content = torch.nn.L1Loss().to(device)
criterion_pixel = torch.nn.L1Loss().to(device)

# One Adam optimizer per trained network, sharing learning rate and betas
adam_betas = (opt["b1"], opt["b2"])
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt["lr"], betas=adam_betas)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt["lr"], betas=adam_betas)

# Shuffled loader over the low-res / high-res image pairs
train_dataset = ImageDataset(opt["dataset_name"], hr_shape=hr_shape)
dataloader = DataLoader(
    train_dataset,
    batch_size=opt["batch_size"],
    shuffle=True,
    num_workers=opt["n_cpu"],
)

# Start the training loop with your desired number of epochs
for epoch in range(opt["epoch"], opt["n_epochs"]):
    for i, imgs in enumerate(dataloader):

        # Global step across epochs; drives warm-up, sampling and checkpointing
        batches_done = epoch * len(dataloader) + i

        # Configure model input (plain tensors; the deprecated Variable wrapper is not needed)
        imgs_lr = imgs["lr"].to(device)
        imgs_hr = imgs["hr"].to(device)

        # Adversarial ground truths, sized from the batch actually received
        target_shape = (imgs_lr.size(0), *discriminator.output_shape)
        valid = torch.ones(target_shape, device=device)
        fake = torch.zeros(target_shape, device=device)

        # Train Generators
        optimizer_G.zero_grad()

        # Generate a high-resolution image from the low-resolution input
        gen_hr = generator(imgs_lr)

        # Measure pixel-wise loss against ground truth
        loss_pixel = criterion_pixel(gen_hr, imgs_hr)

        if batches_done < opt["warmup_batches"]:
            # Warm-up (pixel-wise loss only); the discriminator update, sampling
            # and checkpointing below are all skipped for these batches
            loss_pixel.backward()
            optimizer_G.step()
            print("[Epoch %d/%d] [Batch %d/%d] [G pixel: %f]" % (epoch, opt["n_epochs"], i, len(dataloader), loss_pixel.item()))
            continue

        # Extract validity predictions from discriminator. The real-image scores
        # are only a constant reference here, so no autograd graph is built for them.
        with torch.no_grad():
            pred_real = discriminator(imgs_hr)
        pred_fake = discriminator(gen_hr)

        # Adversarial loss (relativistic average GAN)
        loss_GAN = criterion_GAN(pred_fake - pred_real.mean(0, keepdim=True), valid)

        # Content loss; the real-image features are a fixed target (no graph needed)
        gen_features = feature_extractor(gen_hr)
        with torch.no_grad():
            real_features = feature_extractor(imgs_hr)
        loss_content = criterion_content(gen_features, real_features)

        # Total generator loss
        loss_G = loss_content + opt["lambda_adv"] * loss_GAN + opt["lambda_pixel"] * loss_pixel

        loss_G.backward()
        optimizer_G.step()

        # Train Discriminator
        optimizer_D.zero_grad()

        # gen_hr is detached so this update does not backpropagate into the generator
        pred_real = discriminator(imgs_hr)
        pred_fake = discriminator(gen_hr.detach())

        # Adversarial loss for real and fake images (relativistic average GAN)
        loss_real = criterion_GAN(pred_real - pred_fake.mean(0, keepdim=True), valid)
        loss_fake = criterion_GAN(pred_fake - pred_real.mean(0, keepdim=True), fake)

        # Total loss
        loss_D = (loss_real + loss_fake) / 2

        loss_D.backward()
        optimizer_D.step()

        # Log Progress
        print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, content: %f, adv: %f, pixel: %f]" % (
            epoch, opt["n_epochs"], i, len(dataloader), loss_D.item(), loss_G.item(),
            loss_content.item(), loss_GAN.item(), loss_pixel.item()))

        if batches_done % opt["sample_interval"] == 0:
            # Save image grid with upsampled inputs and ESRGAN outputs side by side.
            # The generator output is detached: the grid is for inspection only.
            imgs_lr_upsampled = nn.functional.interpolate(imgs_lr, scale_factor=4)
            img_grid = denormalize(torch.cat((imgs_lr_upsampled, gen_hr.detach()), -1))
            save_image(img_grid, "/kaggle/working/images/training/%d.png" % batches_done, nrow=1, normalize=False)

        if batches_done % opt["checkpoint_interval"] == 0:
            # Save model checkpoints
            # NOTE(review): the trigger is batch-based but the file name uses the epoch,
            # so checkpoints written within the same epoch overwrite each other.
            torch.save(generator.state_dict(), "/kaggle/working/saved_models/generator_%d.pth" % epoch)
            torch.save(discriminator.state_dict(), "/kaggle/working/saved_models/discriminator_%d.pth" % epoch)
# Always keep the weights reached at the end of training
torch.save(generator.state_dict(), "/kaggle/working/saved_models/generator_final.pth")
torch.save(discriminator.state_dict(), "/kaggle/working/saved_models/discriminator_final.pth")

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth
100%|██████████| 548M/548M [00:03<00:00, 177MB/s]


[Epoch 0/5] [Batch 0/1000] [G pixel: 1.175385]
[Epoch 0/5] [Batch 1/1000] [G pixel: 0.958973]
[Epoch 0/5] [Batch 2/1000] [G pixel: 0.966229]
[Epoch 0/5] [Batch 3/1000] [G pixel: 0.845876]
[Epoch 0/5] [Batch 4/1000] [G pixel: 0.677188]
[Epoch 0/5] [Batch 5/1000] [G pixel: 0.517680]
[Epoch 0/5] [Batch 6/1000] [G pixel: 0.527534]
[Epoch 0/5] [Batch 7/1000] [G pixel: 0.442639]
[Epoch 0/5] [Batch 8/1000] [G pixel: 0.421113]
[Epoch 0/5] [Batch 9/1000] [G pixel: 0.384654]
[Epoch 0/5] [Batch 10/1000] [G pixel: 0.359078]
[Epoch 0/5] [Batch 11/1000] [G pixel: 0.396437]
[Epoch 0/5] [Batch 12/1000] [G pixel: 0.349398]
[Epoch 0/5] [Batch 13/1000] [G pixel: 0.317381]
[Epoch 0/5] [Batch 14/1000] [G pixel: 0.309264]
[Epoch 0/5] [Batch 15/1000] [G pixel: 0.270890]
[Epoch 0/5] [Batch 16/1000] [G pixel: 0.286889]
[Epoch 0/5] [Batch 17/1000] [G pixel: 0.286446]
[Epoch 0/5] [Batch 18/1000] [G pixel: 0.336552]
[Epoch 0/5] [Batch 19/1000] [G pixel: 0.298718]
[Epoch 0/5] [Batch 20/1000] [G pixel: 0.233191]
[E