In [None]:
import math
import random
from os import listdir
from os.path import join

import cv2
import lpips
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torchmetrics
import torchvision.transforms.functional as TF
from PIL import Image
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
from torchvision.models.vgg import vgg16
from torchvision.transforms import (CenterCrop, Compose, RandomCrop, Resize,
                                    ToPILImage, ToTensor)
from tqdm.auto import tqdm


In [None]:

def initialize_metrics(device):
    """Create the four image-quality metric modules and move them to ``device``.

    Args:
        device: Target device passed to each module's ``.to()``.

    Returns:
        dict: Metric callables keyed by ``"mse"``, ``"psnr"``, ``"ssim"`` and
        ``"lpips"``. PSNR and SSIM are configured with ``data_range=1.0``;
        LPIPS uses the AlexNet backbone.
    """
    metric_modules = {
        "mse": nn.MSELoss(),
        "psnr": torchmetrics.PeakSignalNoiseRatio(data_range=1.0),
        "ssim": torchmetrics.StructuralSimilarityIndexMeasure(data_range=1.0),
        "lpips": lpips.LPIPS(net="alex"),
    }
    return {name: module.to(device) for name, module in metric_modules.items()}

def compute_metrics(metrics_fns, prediction, target):
    """Evaluate every metric for one prediction/target pair.

    Args:
        metrics_fns: Mapping with the keys ``"mse"``, ``"psnr"``, ``"ssim"``
            and ``"lpips"`` (as built by ``initialize_metrics``).
        prediction: Predicted image tensor with values in [0, 1].
        target: Reference image tensor with values in [0, 1].

    Returns:
        dict: Python floats keyed by metric name.
    """
    # LPIPS expects inputs in [-1, 1] unless normalize=True is passed; the
    # images in this notebook live in [0, 1], so let LPIPS rescale them.
    lpips_value = metrics_fns["lpips"](prediction, target, normalize=True)
    return {
        "mse": metrics_fns["mse"](prediction, target).item(),
        "psnr": metrics_fns["psnr"](prediction, target).item(),
        "ssim": metrics_fns["ssim"](prediction, target).item(),
        # LPIPS returns one value per sample; average so batches > 1 also work.
        "lpips": lpips_value.mean().item()
    }

def process_image_batch(lr_img, hr_img, model, device):
    """Prepare one LR/HR pair and produce the bicubic and model upscalings.

    Args:
        lr_img: Unbatched low-resolution image tensor (channels first).
        hr_img: Unbatched high-resolution image tensor (channels first).
        model: Callable mapping the batched LR tensor to an upscaled tensor.
        device: Device on which the tensors and the model run.

    Returns:
        tuple: ``(lr_img, hr_img, bicubic, model_output)``, each with a leading
        batch axis of size 1 and located on ``device``. ``bicubic`` and
        ``model_output`` are clamped to [0, 1].
    """
    lr_img = lr_img.unsqueeze(0).to(device)
    hr_img = hr_img.unsqueeze(0).to(device)

    # Upscale to the actual HR resolution instead of a hard-coded 256x256, so
    # the baseline stays comparable for any crop size / upscale factor.
    target_h, target_w = hr_img.shape[-2:]
    lr_np = lr_img.squeeze(0).cpu().permute(1, 2, 0).numpy()
    # cv2.resize takes the destination size as (width, height).
    bicubic_np = cv2.resize(lr_np, (target_w, target_h), interpolation=cv2.INTER_CUBIC)
    bicubic = torch.tensor(bicubic_np).permute(2, 0, 1).unsqueeze(0).to(device)
    # Bicubic interpolation can overshoot the input range; clamp it exactly
    # like the model output so both are scored on the same [0, 1] range.
    bicubic = torch.clip(bicubic, 0, 1)

    model_output = torch.clip(model(lr_img), 0, 1)
    return lr_img, hr_img, bicubic, model_output

def _resolution_label(img):
    """Return the spatial size of an image tensor as a ``WxH`` string."""
    height, width = img.shape[-2:]
    return f"{width}x{height}"


def _metrics_caption(metrics):
    """Format the four metric values as the two-line caption shown under a panel."""
    return (f"MSE: {metrics['mse']:.4f}, PSNR: {metrics['psnr']:.2f}\n"
            f"SSIM: {metrics['ssim']:.4f}, LPIPS: {metrics['lpips']:.4f}")


def visualize_image(lr_img, hr_img, bicubic, model_output, ax_row, metrics_bicubic, metrics_model):
    """Draw one comparison row: LR input, HR target, bicubic baseline, model output.

    Args:
        lr_img, hr_img, bicubic, model_output: Image tensors with a leading
            batch axis of size 1 (channels first).
        ax_row: Sequence of four axes to draw on, in the order above.
        metrics_bicubic: Dict with ``mse``/``psnr``/``ssim``/``lpips`` for the
            bicubic baseline; printed under the third panel.
        metrics_model: Same metrics for the model output; printed under the
            fourth panel.
    """
    # Titles are derived from the tensors so they stay correct for every
    # upscale factor (e.g. 32x32 inputs in the x8 experiment).
    titles = [
        f"Low Res ({_resolution_label(lr_img)})",
        f"High Res ({_resolution_label(hr_img)})",
        "Bicubic Interpolation",
        "Model Output",
    ]
    images = [lr_img, hr_img, bicubic, model_output]
    for ax, img, title in zip(ax_row, images, titles):
        img_np = img.detach().squeeze(0).cpu().permute(1, 2, 0).numpy()
        # imshow expects float RGB data in [0, 1]; clip explicitly so that
        # interpolation overshoot does not trigger clipping warnings.
        ax.imshow(np.clip(img_np, 0.0, 1.0))
        ax.set_title(title)
        ax.axis("off")

    for ax, metrics in ((ax_row[2], metrics_bicubic), (ax_row[3], metrics_model)):
        ax.text(0.5, -0.15, _metrics_caption(metrics),
                transform=ax.transAxes, ha="center", fontsize=10)

def evaluate_and_visualize(model, dataset, n=5, device="cuda" if torch.cuda.is_available() else "cpu"):
    """Score the model against bicubic upscaling and plot a few random samples.

    The whole ``dataset`` is evaluated to obtain average MSE / PSNR / SSIM /
    LPIPS for both methods; ``n`` randomly chosen samples are then drawn as a
    grid of LR / HR / bicubic / model panels.

    Args:
        model: Super-resolution network; it is switched to eval mode and moved
            to ``device``.
        dataset: Indexable dataset yielding ``(lr_img, hr_img)`` tensor pairs.
        n: Number of samples to visualise; capped at ``len(dataset)``.
        device: Device used for inference and metric computation.

    Raises:
        ValueError: If ``dataset`` is empty.
    """
    if len(dataset) == 0:
        raise ValueError("evaluate_and_visualize requires a non-empty dataset")

    model.eval()
    model.to(device)

    metrics_fns = initialize_metrics(device)
    metric_names = ("mse", "psnr", "ssim", "lpips")
    all_metrics = {method: {name: [] for name in metric_names} for method in ("bicubic", "model")}

    with torch.no_grad():
        for lr_img, hr_img in dataset:
            lr_img, hr_img, bicubic, model_output = process_image_batch(lr_img, hr_img, model, device)
            bicubic_metrics = compute_metrics(metrics_fns, bicubic, hr_img)
            model_metrics = compute_metrics(metrics_fns, model_output, hr_img)
            for metric in metric_names:
                all_metrics["bicubic"][metric].append(bicubic_metrics[metric])
                all_metrics["model"][metric].append(model_metrics[metric])

    avg_metrics = {
        method: {metric: np.mean(values) for metric, values in metrics.items()}
        for method, metrics in all_metrics.items()
    }

    # Sampling without replacement cannot draw more items than the dataset holds.
    n = min(n, len(dataset))
    fig = None
    if n > 0:
        indices = np.random.choice(len(dataset), n, replace=False)
        # squeeze=False keeps `axes` two-dimensional even for a single row.
        fig, axes = plt.subplots(n, 4, figsize=(20, 5 * n), squeeze=False)
        with torch.no_grad():
            for ax_row, idx in zip(axes, indices):
                lr_img, hr_img = dataset[int(idx)]
                lr_img, hr_img, bicubic, model_output = process_image_batch(lr_img, hr_img, model, device)
                bicubic_metrics = compute_metrics(metrics_fns, bicubic, hr_img)
                model_metrics = compute_metrics(metrics_fns, model_output, hr_img)
                visualize_image(lr_img, hr_img, bicubic, model_output, ax_row, bicubic_metrics, model_metrics)

    print("\nŚrednie metryki dla całego datasetu walidacyjnego:")
    for method in ["bicubic", "model"]:
        print(f"{method.capitalize()} Metrics:")
        print(f"MSE: {avg_metrics[method]['mse']:.4f}, PSNR: {avg_metrics[method]['psnr']:.2f}, "
              f"SSIM: {avg_metrics[method]['ssim']:.4f}, LPIPS: {avg_metrics[method]['lpips']:.4f}")

    if fig is not None:
        plt.tight_layout()
        plt.show()
        # This runs repeatedly during training; release the figure so that
        # open figures do not pile up in memory.
        plt.close(fig)


In [None]:

class Generator(nn.Module):
    """SRGAN-style generator: residual feature extractor plus pixel-shuffle upsampling.

    Args:
        scale_factor: Total upscaling factor. One x2 upsample block is stacked
            per power of two, i.e. ``floor(log2(scale_factor))`` blocks; a
            factor that is not a power of two is therefore rounded down to the
            nearest one.
    """

    def __init__(self, scale_factor):
        super().__init__()
        # math.log2 is exact for powers of two, whereas math.log(x, 2) can land
        # a hair below the integer and be truncated to one block too few.
        upsample_block_num = int(math.log2(scale_factor))

        # Wide 9x9 stem lifting RGB to 64 feature channels.
        self.block1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=9, padding=4), nn.PReLU()
        )

        self.residual_blocks = nn.Sequential(
            *[ResidualBlock(64) for _ in range(5)]
        )

        # Post-residual conv whose output is added back to the stem features.
        self.block7 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64)
        )

        # x2 upsample stages followed by the projection back to 3 channels.
        self.block8 = nn.Sequential(
            *[UpsampleBLock(64, 2) for _ in range(upsample_block_num)],
            nn.Conv2d(64, 3, kernel_size=9, padding=4)
        )

    def forward(self, x):
        """Upscale a batch of 3-channel images.

        Returns:
            Tensor with values in [0, 1] (``tanh`` output rescaled from [-1, 1]).
        """
        block1 = self.block1(x)
        residual = self.residual_blocks(block1)
        block7 = self.block7(residual)
        # Long skip connection around the whole residual trunk.
        block8 = self.block8(block1 + block7)
        return (torch.tanh(block8) + 1) / 2


class Discriminator(nn.Module):
    """Convolutional discriminator producing one probability per input image.

    Eight 3x3 conv blocks (alternating stride 1 / stride 2, widening from 64
    to 512 channels) are followed by global average pooling and two 1x1
    convolutions that reduce the features to a single logit.
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, stride, batch_norm) for each conv block.
        conv_specs = [
            (3, 64, 1, False),
            (64, 64, 2, True),
            (64, 128, 1, True),
            (128, 128, 2, True),
            (128, 256, 1, True),
            (256, 256, 2, True),
            (256, 512, 1, True),
            (512, 512, 2, True),
        ]
        feature_blocks = [
            self._conv_block(n_in, n_out, 3, stride, use_bn)
            for n_in, n_out, stride, use_bn in conv_specs
        ]
        classifier_head = [
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(512, 1024, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(1024, 1, 1),
        ]
        self.net = nn.Sequential(*feature_blocks, *classifier_head)

    def _conv_block(self, in_channels, out_channels, kernel_size, stride, batch_norm=True):
        """Build Conv -> [BatchNorm] -> LeakyReLU(0.2) as one ``nn.Sequential``."""
        stages = [nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=1)]
        if batch_norm:
            stages.append(nn.BatchNorm2d(out_channels))
        stages.append(nn.LeakyReLU(0.2, inplace=True))
        return nn.Sequential(*stages)

    def forward(self, x):
        """Return a 1-D tensor of sigmoid scores, one per image in the batch."""
        logits = self.net(x)
        per_image = logits.view(x.size(0))
        return torch.sigmoid(per_image)


class ResidualBlock(nn.Module):
    """Conv-BN-PReLU-Conv-BN branch added to its own input.

    Args:
        channels: Number of input and output feature channels.
    """

    def __init__(self, channels):
        super().__init__()
        layers = [
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.BatchNorm2d(channels),
            nn.PReLU(),
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.BatchNorm2d(channels),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        branch = self.block(x)
        return x + branch


class UpsampleBLock(nn.Module):
    """Sub-pixel upsampling stage: Conv -> PixelShuffle -> PReLU.

    Args:
        in_channels: Channel count of the input (and of the output).
        up_scale: Spatial upscaling factor applied by the pixel shuffle.
    """

    def __init__(self, in_channels, up_scale):
        super().__init__()
        # The conv expands channels by up_scale**2 so that PixelShuffle can
        # fold them back into an up_scale-times larger spatial grid.
        expanded_channels = in_channels * up_scale ** 2
        expand = nn.Conv2d(in_channels, expanded_channels, 3, padding=1)
        shuffle = nn.PixelShuffle(up_scale)
        activation = nn.PReLU()
        self.block = nn.Sequential(expand, shuffle, activation)

    def forward(self, x):
        """Return ``x`` upscaled spatially by ``up_scale``."""
        upscaled = self.block(x)
        return upscaled


In [None]:


class GeneratorLoss(nn.Module):
    """Composite generator objective: pixel MSE + adversarial + VGG perceptual + TV.

    The perceptual term compares activations of a frozen, ImageNet-pretrained
    VGG16 feature extractor for generated and target images.
    """

    def __init__(self):
        super().__init__()
        # `pretrained=True` is deprecated in recent torchvision releases; use
        # the weights enum when available and fall back for older versions.
        try:
            from torchvision.models import VGG16_Weights
        except ImportError:
            vgg = vgg16(pretrained=True)
        else:
            vgg = vgg16(weights=VGG16_Weights.IMAGENET1K_V1)
        self.loss_network = nn.Sequential(*list(vgg.features)[:31]).eval()
        # The VGG network is a fixed feature extractor and is never trained.
        for param in self.loss_network.parameters():
            param.requires_grad = False
        self.mse_loss = nn.MSELoss()
        self.tv_loss = TVLoss()

    def forward(self, out_labels, out_images, target_images):
        """Compute the weighted generator loss.

        Args:
            out_labels: Discriminator score(s) for the generated images.
            out_images: Generated images.
            target_images: Ground-truth high-resolution images.

        Returns:
            Scalar tensor:
            ``image + 0.001 * adversarial + 0.006 * perception + 2e-8 * tv``.
        """
        # The generator is rewarded when the discriminator scores approach 1.
        adversarial_loss = torch.mean(1 - out_labels)
        perception_loss = self.mse_loss(self.loss_network(out_images), self.loss_network(target_images))
        image_loss = self.mse_loss(out_images, target_images)
        tv_loss = self.tv_loss(out_images)
        return image_loss + 0.001 * adversarial_loss + 0.006 * perception_loss + 2e-8 * tv_loss


class TVLoss(nn.Module):
    """Total-variation penalty on squared differences of neighbouring pixels.

    Args:
        tv_loss_weight: Multiplier applied to the resulting loss value.
    """

    def __init__(self, tv_loss_weight=1):
        super().__init__()
        self.tv_loss_weight = tv_loss_weight

    def forward(self, x):
        """Return the weighted total variation of a 4-D image batch ``x``."""
        batch_size, _, _, _ = x.size()

        below = x[:, :, 1:, :]
        right = x[:, :, :, 1:]
        vertical_diff = below - x[:, :, :-1, :]
        horizontal_diff = right - x[:, :, :, :-1]

        h_tv = vertical_diff.pow(2).sum()
        w_tv = horizontal_diff.pow(2).sum()

        # Each direction is normalised by the per-sample number of differences.
        normalised = h_tv / self.tensor_size(below) + w_tv / self.tensor_size(right)
        return self.tv_loss_weight * 2 * normalised / batch_size

    @staticmethod
    def tensor_size(t):
        """Number of elements per sample, i.e. ``numel`` divided by the batch size."""
        total_elements = t.numel()
        return total_elements // t.size(0)


In [None]:
def is_image_file(filename):
    """Return True when ``filename`` ends in .png, .jpg or .jpeg (case-insensitive)."""
    image_suffixes = ('.png', '.jpg', '.jpeg')
    return filename.lower().endswith(image_suffixes)


def calculate_valid_crop_size(crop_size, upscale_factor):
    """Round ``crop_size`` down to the nearest multiple of ``upscale_factor``."""
    _, remainder = divmod(crop_size, upscale_factor)
    return crop_size - remainder


def train_hr_transform(crop_size):
    """Build the HR training pipeline: random ``crop_size`` crop, then tensor conversion."""
    steps = [
        RandomCrop(crop_size),
        ToTensor(),
    ]
    return Compose(steps)


def train_lr_transform(crop_size, upscale_factor):
    """Build the LR training pipeline applied to an HR image tensor.

    The tensor is converted to a PIL image, bicubically resized to
    ``crop_size // upscale_factor`` and converted back to a tensor.
    """
    lr_size = crop_size // upscale_factor
    steps = [
        ToPILImage(),
        Resize(lr_size, interpolation=Image.BICUBIC),
        ToTensor(),
    ]
    return Compose(steps)


class TrainDatasetFromFolder(Dataset):
    """Training pairs built on the fly from a folder of high-resolution images.

    Each item is a randomly rotated (multiple of 90 degrees) and randomly
    cropped HR patch together with its bicubically downscaled LR version.

    Args:
        dataset_dir: Directory containing the training images.
        crop_size: Requested HR patch size; rounded down to a multiple of
            ``upscale_factor``.
        upscale_factor: Ratio between HR and LR resolution.
    """

    def __init__(self, dataset_dir, crop_size, upscale_factor):
        super().__init__()
        # listdir order is filesystem dependent; sort so that index -> file is
        # stable and seeded runs are reproducible across machines.
        self.image_filenames = sorted(
            join(dataset_dir, x) for x in listdir(dataset_dir) if is_image_file(x)
        )
        crop_size = calculate_valid_crop_size(crop_size, upscale_factor)
        self.hr_transform = train_hr_transform(crop_size)
        self.lr_transform = train_lr_transform(crop_size, upscale_factor)
        self.rotation_angles = [0, 90, 180, 270]

    def __getitem__(self, index):
        """Return ``(lr_image, hr_image)`` tensors for the image at ``index``."""
        # The context manager closes the file handle as soon as pixels are read.
        with Image.open(self.image_filenames[index]) as source:
            hr_image = source.convert('RGB')
        # expand=True keeps the whole picture when a non-square image is turned
        # by 90/270 degrees instead of cropping it and filling corners with black.
        hr_image = TF.rotate(hr_image, random.choice(self.rotation_angles), expand=True)
        hr_image = self.hr_transform(hr_image)
        lr_image = self.lr_transform(hr_image)
        return lr_image, hr_image

    def __len__(self):
        """Number of image files found in the dataset directory."""
        return len(self.image_filenames)


class ValDatasetFromFolder(Dataset):
    """Validation pairs: centre-cropped HR image and its bicubic LR version.

    Args:
        dataset_dir: Directory containing the validation images.
        upscale_factor: Ratio between HR and LR resolution.
    """

    def __init__(self, dataset_dir, upscale_factor):
        super().__init__()
        self.upscale_factor = upscale_factor
        # Sorted so that a given index always refers to the same file.
        self.image_filenames = sorted(
            join(dataset_dir, x) for x in listdir(dataset_dir) if is_image_file(x)
        )

    def __getitem__(self, index):
        """Return ``(lr_image, hr_image)`` tensors for the image at ``index``."""
        # Force three channels (as the training dataset does) so grayscale or
        # RGBA files do not break the 3-channel generator and metrics.
        with Image.open(self.image_filenames[index]) as source:
            hr_image = source.convert('RGB')
        # Largest centred square whose side is divisible by the upscale factor.
        crop_size = calculate_valid_crop_size(min(hr_image.size), self.upscale_factor)
        lr_scale = Resize(crop_size // self.upscale_factor, interpolation=Image.BICUBIC)
        hr_image = CenterCrop(crop_size)(hr_image)
        lr_image = lr_scale(hr_image)
        return ToTensor()(lr_image), ToTensor()(hr_image)

    def __len__(self):
        """Number of image files found in the dataset directory."""
        return len(self.image_filenames)


In [None]:
def train_gan(config):
    """Train the generator/discriminator pair described by ``config``.

    Args:
        config: Dict with the keys ``TRAIN_DATA_PATH``, ``VAL_DATA_PATH``,
            ``CROP_SIZE``, ``UPSCALE_FACTOR``, ``NUM_WORKERS``, ``BATCH_SIZE``
            and ``NUM_EPOCHS``.

    Every fifth epoch the generator is evaluated and visualised on the
    validation set.
    """
    # --- data -------------------------------------------------------------
    train_dir = config["TRAIN_DATA_PATH"]
    crop_size = config["CROP_SIZE"]
    upscale_factor = config["UPSCALE_FACTOR"]
    train_set = TrainDatasetFromFolder(train_dir, crop_size=crop_size, upscale_factor=upscale_factor)
    val_set = ValDatasetFromFolder(config["VAL_DATA_PATH"], upscale_factor=upscale_factor)
    train_loader = DataLoader(
        dataset=train_set,
        num_workers=config["NUM_WORKERS"],
        batch_size=config["BATCH_SIZE"],
        shuffle=True,
    )

    # --- networks and loss --------------------------------------------------
    netG = Generator(upscale_factor)
    netD = Discriminator()
    generator_criterion = GeneratorLoss()

    if torch.cuda.is_available():
        for module in (netG, netD, generator_criterion):
            module.cuda()

    # --- optimisers (Adam with default hyper-parameters) --------------------
    optimizerG = optim.Adam(netG.parameters())
    optimizerD = optim.Adam(netD.parameters())

    # --- training loop --------------------------------------------------------
    total_epochs = config["NUM_EPOCHS"]
    for epoch in range(1, total_epochs + 1):
        train_epoch(train_loader, netG, netD, generator_criterion,
                    optimizerG, optimizerD, epoch, total_epochs)
        if epoch % 5 != 0:
            continue
        netG.eval()
        evaluate_and_visualize(netG, val_set)


def train_epoch(train_loader, netG, netD, generator_criterion, optimizerG, optimizerD, epoch, num_epochs):
    """Run one epoch of alternating generator / discriminator updates.

    For every batch the generator is updated first (using the discriminator's
    score of the fresh fake images), then the discriminator is updated on the
    real images and the detached fakes. Sample-weighted running averages of
    both losses and both scores are shown in the progress bar.

    Args:
        train_loader: Iterable of ``(lr_batch, hr_batch)`` tensor pairs.
        netG, netD: Generator and discriminator networks.
        generator_criterion: Callable ``(out_labels, out_images, target_images)``.
        optimizerG, optimizerD: Optimisers of the two networks.
        epoch: Current epoch number (progress-bar label only).
        num_epochs: Total number of epochs (progress-bar label only).
    """
    for network in (netG, netD):
        network.train()

    samples_seen = 0
    d_loss_total = g_loss_total = d_score_total = g_score_total = 0

    def to_training_device(batch):
        # Same placement rule for inputs and targets: GPU when one is
        # available, otherwise keep the batch on the CPU as float.
        if torch.cuda.is_available():
            return batch.cuda()
        return batch.float()

    progress = tqdm(train_loader, desc=f"Epoch {epoch}/{num_epochs}", leave=True, dynamic_ncols=True)
    with progress as train_bar:
        for lr_batch, hr_batch in train_bar:
            batch_size = lr_batch.size(0)
            samples_seen += batch_size

            real_img = to_training_device(hr_batch)
            z = to_training_device(lr_batch)

            # ---- generator step ----
            fake_img = netG(z)
            fake_out = netD(fake_img).mean()

            optimizerG.zero_grad()
            g_loss = generator_criterion(fake_out, fake_img, real_img)
            g_loss.backward()
            optimizerG.step()

            # ---- discriminator step ----
            # The fakes are detached so this update does not reach the generator.
            real_out = netD(real_img).mean()
            fake_out = netD(fake_img.detach()).mean()
            d_loss = 1 - real_out + fake_out

            optimizerD.zero_grad()
            d_loss.backward()
            optimizerD.step()

            # ---- running statistics (weighted by batch size) ----
            g_loss_total += g_loss.item() * batch_size
            d_loss_total += d_loss.item() * batch_size
            d_score_total += real_out.item() * batch_size
            g_score_total += fake_out.item() * batch_size

            running_totals = (
                ("Loss_D", d_loss_total),
                ("Loss_G", g_loss_total),
                ("D(x)", d_score_total),
                ("D(G(z))", g_score_total),
            )
            train_bar.set_postfix({
                label: f"{total / samples_seen:.4f}" for label, total in running_totals
            })


In [None]:
def set_random_seed(seed):
    """Seed the PyTorch, Python ``random`` and NumPy generators with ``seed``."""
    seeders = (torch.manual_seed, random.seed, np.random.seed)
    for apply_seed in seeders:
        apply_seed(seed)
    
# Seed the torch, random and NumPy generators (see set_random_seed) before the
# training cells below are executed.
set_random_seed(42)


In [None]:
# Experiment: x4 super-resolution (UPSCALE_FACTOR=4) on 256-pixel crops.
config = dict(
    CROP_SIZE=256,
    UPSCALE_FACTOR=4,
    NUM_EPOCHS=50,
    BATCH_SIZE=8,
    NUM_WORKERS=4,
    TRAIN_DATA_PATH='data/train/256',
    VAL_DATA_PATH='data/valid/256',
)
train_gan(config)


In [None]:

# Experiment: x8 super-resolution (UPSCALE_FACTOR=8) on 256-pixel crops.
config = dict(
    CROP_SIZE=256,
    UPSCALE_FACTOR=8,
    NUM_EPOCHS=50,
    BATCH_SIZE=8,
    NUM_WORKERS=4,
    TRAIN_DATA_PATH='data/train/256',
    VAL_DATA_PATH='data/valid/256',
)
train_gan(config)


# Podsumowanie

Zaproponowana architektura inspirowana podejściem SRGAN osiąga nieco lepsze wyniki niż interpolacja bikubiczna.
Obrazy generowane przez sieć są ostrzejsze niż obrazy interpolowane. Posiadają jednak widoczne artefakty: obramowanie, wzór szachownicy, zduplikowane krawędzie.
Trening GAN szybko przestał uwzględniać dyskryminator, który stawał się zbyt słaby i
zawsze klasyfikował obrazy jako realistyczne, przez co generator nie otrzymywał gradientów od dyskryminatora i
w zasadzie trening stawał się zwykłym treningiem generatora, a nie GAN. Uczenie dłuższe niż 50 epok mogłoby nieco poprawić generowane obrazy,
ponieważ metryki cały czas nieznacznie się poprawiały, ale zamiast tego zdecydowano się podjąć próbę poprawy treningu GAN.