In [1]:
import math
import torch
from torch import nn
from torchvision.models.vgg import vgg16

from os import listdir
from os.path import join

from PIL import Image
from torch.utils.data.dataset import Dataset
from torchvision.transforms import Compose, RandomCrop, ToTensor, ToPILImage, CenterCrop, Resize


In [2]:
def is_image_file(filename):
    """Return True when ``filename`` ends with a supported image extension.

    The comparison is case-insensitive, so mixed-case suffixes such as
    ``.Png`` or ``.JpEg`` are accepted alongside the all-lower and all-upper
    spellings.

    Args:
        filename: File name or path as a string.

    Returns:
        bool: True for names ending in ``.png``, ``.jpg`` or ``.jpeg``.
    """
    # str.endswith accepts a tuple, so a single call covers every extension.
    return filename.lower().endswith(('.png', '.jpg', '.jpeg'))

def calculate_valid_crop_size(crop_size, upscale_factor):
    """Round ``crop_size`` down to the nearest multiple of ``upscale_factor``.

    Args:
        crop_size: Requested crop edge length.
        upscale_factor: Factor the result must be divisible by.

    Returns:
        ``crop_size`` minus its remainder modulo ``upscale_factor``.
    """
    _, remainder = divmod(crop_size, upscale_factor)
    return crop_size - remainder

def train_hr_transform(crop_size):
    """Build the high-resolution training transform.

    The pipeline applies ``RandomCrop(crop_size)`` followed by ``ToTensor()``.

    Args:
        crop_size: Size passed to ``RandomCrop``.

    Returns:
        A ``Compose`` of the two steps.
    """
    steps = [RandomCrop(crop_size), ToTensor()]
    return Compose(steps)

def train_lr_transform(crop_size, upscale_factor):
    """Build the transform that derives a low-resolution image from an HR tensor.

    The pipeline converts its input to a PIL image, resizes it with bicubic
    interpolation to ``crop_size // upscale_factor`` and converts it back to
    a tensor.

    Args:
        crop_size: Edge length of the high-resolution crop.
        upscale_factor: Downscaling ratio between HR and LR.

    Returns:
        A ``Compose`` of the three steps.
    """
    lr_size = crop_size // upscale_factor
    to_pil = ToPILImage()
    shrink = Resize(lr_size, interpolation=Image.BICUBIC)
    to_tensor = ToTensor()
    return Compose([to_pil, shrink, to_tensor])

def display_transform():
    """Build the transform used to prepare tensors for display.

    The pipeline converts to a PIL image, applies ``Resize(400)`` and
    ``CenterCrop(400)``, then converts back to a tensor.

    Returns:
        A ``Compose`` of the four steps.
    """
    display_size = 400
    steps = [
        ToPILImage(),
        Resize(display_size),
        CenterCrop(display_size),
        ToTensor(),
    ]
    return Compose(steps)

class TrainDatasetFromFolder(Dataset):
    """Training dataset producing (low-res, high-res) tensor pairs.

    Every image file directly inside ``dataset_dir`` is one sample. The HR
    sample is produced by ``train_hr_transform`` and the LR sample is derived
    from that HR sample by ``train_lr_transform``.

    Args:
        dataset_dir: Directory that holds the training images.
        crop_size: Requested HR crop size; it is first reduced to a multiple
            of ``upscale_factor``.
        upscale_factor: Ratio between the HR and LR sizes.
    """

    def __init__(self, dataset_dir, crop_size, upscale_factor=2):
        super().__init__()
        image_names = filter(is_image_file, listdir(dataset_dir))
        self.image_filenames = [join(dataset_dir, name) for name in image_names]
        valid_size = calculate_valid_crop_size(crop_size, upscale_factor)
        self.hr_transform = train_hr_transform(valid_size)
        self.lr_transform = train_lr_transform(valid_size, upscale_factor)

    def __getitem__(self, index):
        """Return ``(lr_image, hr_image)`` for the file at ``index``."""
        path = self.image_filenames[index]
        hr_image = self.hr_transform(Image.open(path))
        return self.lr_transform(hr_image), hr_image

    def __len__(self):
        """Return the number of image files found in the directory."""
        return len(self.image_filenames)


class ValDatasetFromFolder(Dataset):
    """Validation dataset producing (LR, bicubic-restored, HR) tensor triples.

    Every image file directly inside ``dataset_dir`` is one sample. The crop
    size is derived per image from its shorter side, rounded down to a
    multiple of ``upscale_factor``.

    Args:
        dataset_dir: Directory that holds the validation images.
        crop_size: Stored on the instance as ``self.crop_size``; the crop
            actually applied is computed per image in ``__getitem__``.
        upscale_factor: Ratio between the HR and LR sizes.
    """

    def __init__(self, dataset_dir, crop_size, upscale_factor=2):
        super(ValDatasetFromFolder, self).__init__()
        self.upscale_factor = upscale_factor
        self.crop_size = crop_size
        self.image_filenames = [join(dataset_dir, x) for x in listdir(dataset_dir) if is_image_file(x)]

    def __getitem__(self, index):
        """Return ``(lr, restored_hr, hr)`` tensors for the file at ``index``.

        ``hr`` is the centre crop of the source image, ``lr`` is that crop
        downscaled by ``upscale_factor`` with bicubic interpolation, and
        ``restored_hr`` is ``lr`` resized back to the crop size.
        """
        hr_image = Image.open(self.image_filenames[index])
        w, h = hr_image.size
        # Keep the per-image size in a local: it differs between images, so
        # storing it on the instance would make items depend on access order.
        crop_size = calculate_valid_crop_size(min(w, h), self.upscale_factor)
        lr_scale = Resize(crop_size // self.upscale_factor, interpolation=Image.BICUBIC)
        hr_scale = Resize(crop_size, interpolation=Image.BICUBIC)
        hr_image = CenterCrop(crop_size)(hr_image)
        lr_image = lr_scale(hr_image)
        hr_restore_img = hr_scale(lr_image)
        to_tensor = ToTensor()
        return to_tensor(lr_image), to_tensor(hr_restore_img), to_tensor(hr_image)

    def __len__(self):
        """Return the number of image files found in the directory."""
        return len(self.image_filenames)


class TestDatasetFromFolder(Dataset):
    """Test dataset pairing pre-made LR images with their HR targets.

    Images are read from ``<dataset_dir>/SRF_<upscale_factor>/data/`` (LR)
    and ``<dataset_dir>/SRF_<upscale_factor>/target/`` (HR). The two lists
    are paired by index after sorting.

    Args:
        dataset_dir: Root directory that contains the ``SRF_<factor>`` folder.
        upscale_factor: Factor used in the folder name and for resizing.
    """

    def __init__(self, dataset_dir, upscale_factor=2):
        super(TestDatasetFromFolder, self).__init__()
        self.lr_path = dataset_dir + '/SRF_' + str(upscale_factor) + '/data/'
        self.hr_path = dataset_dir + '/SRF_' + str(upscale_factor) + '/target/'
        self.upscale_factor = upscale_factor
        # os.listdir returns entries in arbitrary order, so both listings are
        # sorted to keep LR/HR pairs aligned by index.
        # NOTE(review): assumes matching LR and HR files sort identically
        # (e.g. share a file name) - confirm against the dataset layout.
        self.lr_filenames = [join(self.lr_path, x) for x in sorted(listdir(self.lr_path)) if is_image_file(x)]
        self.hr_filenames = [join(self.hr_path, x) for x in sorted(listdir(self.hr_path)) if is_image_file(x)]

    def __getitem__(self, index):
        """Return ``(image_name, lr, restored_hr, hr)`` for the pair at ``index``.

        ``restored_hr`` is the LR image resized with bicubic interpolation to
        ``upscale_factor`` times its height and width.
        """
        image_name = self.lr_filenames[index].split('/')[-1]
        lr_image = Image.open(self.lr_filenames[index])
        w, h = lr_image.size
        hr_image = Image.open(self.hr_filenames[index])
        # Resize takes (height, width), whereas PIL's .size is (width, height).
        hr_scale = Resize((self.upscale_factor * h, self.upscale_factor * w), interpolation=Image.BICUBIC)
        hr_restore_img = hr_scale(lr_image)
        return image_name, ToTensor()(lr_image), ToTensor()(hr_restore_img), ToTensor()(hr_image)

    def __len__(self):
        """Return the number of LR image files."""
        return len(self.lr_filenames)

In [None]:
class GeneratorLoss(nn.Module):
    """Composite generator loss.

    The result is a weighted sum of four terms:

    * pixel-wise MSE between generated and target images (weight 1),
    * adversarial term ``mean(1 - out_labels)`` (weight 0.001),
    * MSE between features from a frozen network built from the first 31
      modules of ``vgg16(pretrained=True).features`` (weight 0.006),
    * total-variation loss of the generated images (weight 2e-8).
    """

    def __init__(self):
        super().__init__()
        feature_layers = list(vgg16(pretrained=True).features)[:31]
        loss_network = nn.Sequential(*feature_layers).eval()
        # The feature extractor is fixed: none of its weights are trained.
        for weight in loss_network.parameters():
            weight.requires_grad = False
        self.loss_network = loss_network
        self.mse_loss = nn.MSELoss()
        self.tv_loss = TVLoss()

    def forward(self, out_labels, out_images, target_images):
        """Return the combined loss.

        Args:
            out_labels: Discriminator outputs for the generated images.
            out_images: Generated images.
            target_images: Ground-truth images.
        """
        adversarial_loss = (1 - out_labels).mean()
        out_features = self.loss_network(out_images)
        target_features = self.loss_network(target_images)
        perception_loss = self.mse_loss(out_features, target_features)
        image_loss = self.mse_loss(out_images, target_images)
        tv_loss = self.tv_loss(out_images)
        return image_loss + 0.001 * adversarial_loss + 0.006 * perception_loss + 2e-8 * tv_loss


class TVLoss(nn.Module):
    """Total-variation loss over the last two axes of a 4-D tensor.

    Squared differences between neighbouring elements along axis 2 and
    axis 3 are summed, each sum is divided by the element count of one
    sample's difference slice, and the total is scaled by
    ``2 * tv_loss_weight / batch_size``.

    Args:
        tv_loss_weight: Multiplier applied to the result.
    """

    def __init__(self, tv_loss_weight=1):
        super().__init__()
        self.tv_loss_weight = tv_loss_weight

    def forward(self, x):
        """Return the scalar total-variation loss of ``x``."""
        n_images = x.shape[0]
        # Shifted views along axis 2 (rows) and axis 3 (columns).
        lower, upper = x[:, :, 1:, :], x[:, :, :-1, :]
        right, left = x[:, :, :, 1:], x[:, :, :, :-1]
        count_rows = self.tensor_size(lower)
        count_cols = self.tensor_size(right)
        tv_rows = ((lower - upper) ** 2).sum()
        tv_cols = ((right - left) ** 2).sum()
        return self.tv_loss_weight * 2 * (tv_rows / count_rows + tv_cols / count_cols) / n_images

    @staticmethod
    def tensor_size(t):
        """Return the product of the sizes of dimensions 1, 2 and 3 of ``t``."""
        return t.size(1) * t.size(2) * t.size(3)


In [4]:
class ResidualBlock(nn.Module):
    """Conv-BN-PReLU-Conv-BN branch added back onto its input.

    Both convolutions use a 3x3 kernel with padding 1 and keep the channel
    count at ``channels``.

    Args:
        channels: Number of input and output channels.
    """

    def __init__(self, channels):
        super().__init__()
        branch = [
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
            nn.PReLU(),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
        ]
        self.layers = nn.Sequential(*branch)

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        return x + self.layers(x)


class UpsampleBLock(nn.Module):
    """Convolution, pixel shuffle and PReLU, applied in that order.

    The 3x3 convolution expands ``in_channels`` to
    ``in_channels * up_scale ** 2`` channels before ``PixelShuffle(up_scale)``.

    Args:
        in_channels: Number of input channels.
        up_scale: Factor passed to ``nn.PixelShuffle``.
    """

    def __init__(self, in_channels, up_scale=2):
        super().__init__()
        expanded_channels = in_channels * up_scale ** 2
        self.conv = nn.Conv2d(in_channels, expanded_channels, kernel_size=3, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(up_scale)
        self.prelu = nn.PReLU()

    def forward(self, x):
        """Return ``x`` after the convolution, pixel shuffle and activation."""
        return self.prelu(self.pixel_shuffle(self.conv(x)))

class Conv_Dis(nn.Module):
    """Discriminator building block: 3x3 convolution, batch norm, LeakyReLU(0.2).

    Args:
        in_size: Number of input channels.
        out_size: Number of output channels.
    """

    def __init__(self, in_size:int, out_size:int):
        super(Conv_Dis, self).__init__()
        self.layers = nn.ModuleList([
            nn.Conv2d(in_size, out_size, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_size),
            nn.LeakyReLU(0.2),
        ])

    # forward must be a method of the class (not a function nested inside
    # __init__); otherwise nn.Module has no forward to dispatch to.
    def forward(self, x):
        """Apply the three layers in order and return the result."""
        for layer in self.layers:
            x = layer(x)
        return x

class Generator(nn.Module):
    """Residual super-resolution generator.

    Layout: a 9x9 input convolution (``block1``), a stack of residual blocks
    (``hidden_residuals``), a conv + batch-norm stage (``block7``), a skip
    connection from ``block1``, then the upsampling stage and 9x9 output
    convolution (``block8``).

    Args:
        scale_factor: Upscaling factor. ``int(log2(scale_factor))`` 2x
            ``UpsampleBLock`` modules are created, so a factor that is not a
            power of two is rounded down to one.
        num_hidden_residuals: Number of ``ResidualBlock`` modules.
    """

    def __init__(self, scale_factor:int=2, num_hidden_residuals:int=5):
        # math.log2 is documented as more accurate than math.log(x, 2), which
        # matters because the result is truncated with int().
        upsample_block_num = int(math.log2(scale_factor))

        super(Generator, self).__init__()
        self.block1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=9, padding=4),
            nn.PReLU()
        )
        self.hidden_residuals = nn.Sequential(*[ResidualBlock(64) for _ in range(num_hidden_residuals)])
        self.block7 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64)
        )
        block8 = [UpsampleBLock(64, 2) for _ in range(upsample_block_num)]
        block8.append(nn.Conv2d(64, 3, kernel_size=9, padding=4))
        self.block8 = nn.Sequential(*block8)

    def forward(self, x):
        """Return the upscaled image with values mapped into [0, 1]."""
        block1 = self.block1(x)
        block_n = self.hidden_residuals(block1)
        # block7 (conv + batch norm) post-processes the residual trunk before
        # the long skip connection from block1 is added.
        block7 = self.block7(block_n)
        block8 = self.block8(block1 + block7)

        # tanh yields [-1, 1]; shift and scale to [0, 1].
        return (torch.tanh(block8) + 1) / 2

class Discriminator(nn.Module):
    """Convolutional discriminator with a globally pooled sigmoid output.

    Layout: an input convolution with LeakyReLU (``start_net``), a list of
    ``Conv_Dis`` blocks (``hidden_layers``) whose width doubles on every
    odd-indexed block, then global average pooling and two 1x1 convolutions
    (``end_net``).

    Args:
        in_size: Number of input image channels.
        out_size: Number of output scores per sample.
        hidden_size: Channel width of the first layer.
        num_hidden_layers: Number of ``Conv_Dis`` blocks.
    """

    def __init__(self, in_size:int=3, out_size:int=1, hidden_size:int=64, num_hidden_layers=7):
        super(Discriminator, self).__init__()
        h = hidden_size
        self.start_net = nn.Sequential(
            nn.Conv2d(in_size, h, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2)
        )
        self.hidden_layers = nn.ModuleList()
        for layer_idx in range(num_hidden_layers):
            # Even-indexed blocks keep the width; odd-indexed blocks double it.
            h_next = h if layer_idx % 2 == 0 else h * 2
            self.hidden_layers.append(Conv_Dis(h, h_next))
            h = h_next

        self.end_net = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(h, h*2, kernel_size=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(h*2, out_size, kernel_size=1)
        )

    def forward(self, x):
        """Return sigmoid scores for ``x``.

        The result has shape ``(batch,)`` when ``out_size`` is 1 and
        ``(batch, out_size)`` otherwise.
        """
        batch_size = x.size(0)
        x = self.start_net(x)
        for layer in self.hidden_layers:
            x = layer(x)
        # end_net yields (batch, out_size, 1, 1). Flatten per sample, then
        # drop the score axis only when it has length 1.
        scores = self.end_net(x).view(batch_size, -1).squeeze(1)
        return torch.sigmoid(scores)


In [8]:
# Download the DIV2K training archives over HTTPS.
!wget https://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_LR_wild.zip
!wget https://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip
# Extract under datasets/; -o overwrites without prompting, so re-running the cell does not stall.
!unzip -o DIV2K_train_LR_wild.zip -d datasets/
!unzip -o DIV2K_train_HR.zip -d datasets/
# Delete only the two archives fetched above, not every *.zip in the working directory.
!rm -f DIV2K_train_LR_wild.zip DIV2K_train_HR.zip

--2025-06-16 15:37:45--  http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_LR_wild.zip
Resolving data.vision.ee.ethz.ch (data.vision.ee.ethz.ch)... 129.132.52.178, 2001:67c:10ec:36c2::178
Connecting to data.vision.ee.ethz.ch (data.vision.ee.ethz.ch)|129.132.52.178|:80... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_LR_wild.zip [following]
--2025-06-16 15:37:46--  https://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_LR_wild.zip
Connecting to data.vision.ee.ethz.ch (data.vision.ee.ethz.ch)|129.132.52.178|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1079873385 (1.0G) [application/zip]
Saving to: ‘DIV2K_train_LR_wild.zip’


2025-06-16 15:38:25 (26.8 MB/s) - ‘DIV2K_train_LR_wild.zip’ saved [1079873385/1079873385]

--2025-06-16 15:38:25--  http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip
Resolving data.vision.ee.ethz.ch (data.vision.ee.ethz.ch)... 129.132.52.178, 2001:67c:1

In [15]:
# Count the lines of each directory listing that contain "png".
!ls -l datasets/DIV2K_train_HR | grep -c png
!ls -l datasets/DIV2K_train_LR_wild | grep -c png

800
3200
