In [None]:
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from skimage import io
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode as IMode

### Data Preparation

In [None]:
class DatasetSuperResolution(Dataset):
    """
    Dataset of paired low-/high-resolution crops for super-resolution.

    Every entry found directly inside ``path_to_data`` is treated as an image file.
    An item is a square high-resolution crop together with a bicubically
    downscaled copy of that same crop.
    """

    def __init__(
            self,
            path_to_data: str,
            mode: str = 'train',
            image_size: int = 1080,
            upscale_factor: int = 4

    ):
        """
        :param path_to_data: str
            Directory whose entries are all loaded as images
        :param mode: str
            "train" selects a random crop; any other value selects a center crop
        :param image_size: int
            Side length (in pixels) of the square high-resolution crop
        :param upscale_factor: int
            Ratio between the high- and low-resolution side lengths
        """
        super(DatasetSuperResolution, self).__init__()

        # Sorted so that an index maps to the same file on every run
        # (os.listdir returns entries in arbitrary order).
        self.files = sorted(
            os.path.join(path_to_data, x) for x in os.listdir(path_to_data)
        )

        # Built once here instead of on every __getitem__ call.
        self.to_tensor = transforms.ToTensor()

        if mode == "train":
            self.hr_transforms = transforms.RandomCrop(image_size, pad_if_needed=True)
        else:
            # CenterCrop does not accept `pad_if_needed` (passing it raises TypeError);
            # it already zero-pads images that are smaller than the crop size.
            self.hr_transforms = transforms.CenterCrop(image_size)

        self.lr_transforms = transforms.Resize(
            image_size // upscale_factor,
            interpolation=IMode.BICUBIC,
            antialias=True
        )

    def __getitem__(self, _index: int) -> "tuple[torch.Tensor, torch.Tensor]":
        """
        Load one image and build its (low-resolution, high-resolution) pair.
        :param _index: int
            Position of the file in ``self.files``
        :return: tuple of torch.Tensor
            ``(lr_image, hr_image)``; the low-resolution tensor is derived
            from the high-resolution crop, so both show the same region
        """
        image = self.to_tensor(io.imread(self.files[_index]))

        hr_image = self.hr_transforms(image)
        lr_image = self.lr_transforms(hr_image)

        return lr_image, hr_image

    def __len__(self) -> int:
        """
        :return: int
            Number of files discovered in the data directory
        """
        return len(self.files)

In [None]:
# Training split: 256x256 high-resolution crops, each paired with a 4x smaller input.
train_set = DatasetSuperResolution(
    path_to_data='/kaggle/input/div2k/DIV2K_train_HR/DIV2K_train_HR/',
    image_size=256,
    upscale_factor=4,
)

In [None]:
# Validation split, built with the same crop size and scale factor as the training split.
# NOTE(review): `mode` is left at its default ('train'), so these crops are random
# rather than centered -- confirm whether that is intended for validation.
valid_set = DatasetSuperResolution(
    path_to_data='/kaggle/input/div2k/DIV2K_valid_HR/DIV2K_valid_HR/',
    image_size=256,
    upscale_factor=4,
)

In [None]:
# Training batches: shuffled, loaded by two worker processes into pinned memory.
dataloader = DataLoader(
    dataset=train_set,
    batch_size=4,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)

# Validation batches: shuffled as well, so each preview shows a different sample.
dataloader_valid = DataLoader(
    dataset=valid_set,
    batch_size=2,
    shuffle=True,
)

In [None]:
# Sanity check of the data pipeline: draw the first sample of the first batch,
# low-resolution input on the left and high-resolution target on the right.
for x, y in dataloader:
    fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10, 10))
    for ax, batch in zip(axs, (x, y)):
        # imshow expects channels last, the tensors are channels first.
        ax.imshow(batch[0].squeeze().permute(1, 2, 0))
    break

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

In [None]:
class ConvResBlock(nn.Module):
    """
    This is an implementation of a Residual Convolution Block from the article:
        https://arxiv.org/pdf/1609.04802.pdf

    Layout: Conv3x3 -> BatchNorm -> PReLU -> Conv3x3 (no bias) -> BatchNorm,
    with the block input added to the result (identity skip connection).
    """

    def __init__(
            self,
            channels: int = 64
    ):
        """
        :param channels: int
            How many channels should be in convolutional blocks
        """
        # nn.Module must be initialised before any attribute is assigned on it.
        super(ConvResBlock, self).__init__()
        self.channels = channels

        self.conv_res_block = nn.Sequential(
            nn.Conv2d(
                in_channels=self.channels,
                out_channels=self.channels,
                kernel_size=3,
                stride=1,
                padding=1
            ),
            nn.BatchNorm2d(self.channels),
            nn.PReLU(),
            nn.Conv2d(
                in_channels=self.channels,
                out_channels=self.channels,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False
            ),
            nn.BatchNorm2d(self.channels)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of Residual Convolution Block
        :param x: torch.Tensor
            Input tensor with ``channels`` channels
        :return: torch.Tensor
            Output tensor of the same shape as the input
            (3x3 kernels with padding 1 and stride 1 keep the spatial size)
        """
        return torch.add(self.conv_res_block(x), x)


class Generator(nn.Module):
    """
    Generator implementation of SRGAN from the article:
        https://arxiv.org/pdf/1609.04802.pdf

    Pipeline: input convolution -> stack of residual blocks -> convolution with a
    long skip connection -> two PixelShuffle(2) upsampling stages (4x in total)
    -> output convolution producing 3 channels.
    """

    def __init__(
            self,
            input_channels: int = 3,
            out_channels: int = 64,
            input_kernel_size: int = 9,
            input_stride: int = 1,
            num_of_res_layers: int = 5
    ):
        """
        :param input_channels: int
            Number of channels of the low-resolution input image
        :param out_channels: int
            Width (number of feature maps) used throughout the network
        :param input_kernel_size: int
            Kernel size of the first and the last convolution
        :param input_stride: int
            Stride of the first and the last convolution
        :param num_of_res_layers: int
            How many residual blocks are stacked
        """
        super(Generator, self).__init__()

        self.input_channels = input_channels
        self.out_channels = out_channels
        self.input_kernel = input_kernel_size
        self.input_stride = input_stride
        # Padding chosen so the first/last convolution keeps the spatial size
        # for the default 9x9 kernel with stride 1.
        self.first_conv_padding = int(np.ceil((self.input_kernel - self.input_stride) / 2))

        self.conv_1 = nn.Sequential(
            nn.Conv2d(
                in_channels=self.input_channels,
                out_channels=self.out_channels,
                kernel_size=self.input_kernel,
                stride=self.input_stride,
                padding=self.first_conv_padding
            ),
            nn.PReLU()
        )

        self.residual_block = nn.Sequential(
            *[ConvResBlock(self.out_channels) for _ in range(num_of_res_layers)]
        )

        self.conv_2 = nn.Sequential(
            nn.Conv2d(
                in_channels=self.out_channels,
                out_channels=self.out_channels,
                kernel_size=3,
                stride=1,
                padding=1
            ),
            nn.BatchNorm2d(self.out_channels)
        )

        # PixelShuffle(2) divides the channel count by 2 * 2 = 4, so each upsampling
        # convolution must emit 4 * out_channels maps for the next layer to receive
        # out_channels again. This equals the previous fixed 256 when out_channels is 64.
        upsample_channels = self.out_channels * 4
        self.conv_3 = nn.Sequential(
            nn.Conv2d(
                in_channels=self.out_channels,
                out_channels=upsample_channels,
                kernel_size=3,
                stride=1,
                padding=1
            ),
            nn.PixelShuffle(2),
            nn.PReLU(),
            nn.Conv2d(
                in_channels=self.out_channels,
                out_channels=upsample_channels,
                kernel_size=3,
                stride=1,
                padding=1
            ),
            nn.PixelShuffle(2),
            nn.PReLU(),
        )

        self.conv_4 = nn.Conv2d(
            in_channels=self.out_channels,
            out_channels=3,
            kernel_size=self.input_kernel,
            stride=self.input_stride,
            padding=self.first_conv_padding
        )

        self._initialize_weights()

    def _initialize_weights(self) -> None:
        """
        Weights initialization.
        For convolutional blocks there is "He initialization" with zeroed biases;
        batch-norm scales are set to 1.
        :return:
            None
        """
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of Generator
        :param x: torch.tensor
            Input Tensor (low-resolution batch with ``input_channels`` channels)
        :return: torch.Tensor
            Output tensor with 3 channels; no output activation is applied,
            so values are not restricted to any range
        """
        shallow_features = self.conv_1(x)
        deep_features = self.conv_2(self.residual_block(shallow_features))
        # Long skip connection around the whole residual stack.
        output = torch.add(deep_features, shallow_features)
        output = self.conv_3(output)
        output = self.conv_4(output)

        return output

"""
Discriminator's block of the model
"""

class convBlockDiscriminator(nn.Module):
    '''
    Building unit of the discriminator: an unpadded 3x3 convolution followed by
    batch normalisation and an in-place LeakyReLU. Stride and channel widths
    are configurable per block.
    '''
    def __init__(self,
               stride_size: int = 1,
               in_channels_size: int = 64,
               out_channels_size: int = 64
               ):
        '''
        :param stride_size: int
            Stride of the convolution
        :param in_channels_size: int
            Number of channels the block receives
        :param out_channels_size: int
            Number of channels the block produces
        '''
        super().__init__()

        convolution = nn.Conv2d(
            in_channels=in_channels_size,
            out_channels=out_channels_size,
            kernel_size=3,
            stride=stride_size,
        )
        normalisation = nn.BatchNorm2d(out_channels_size)
        activation = nn.LeakyReLU(inplace=True)

        self.conv_block = nn.Sequential(convolution, normalisation, activation)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''
        :param x: torch.Tensor
            Input tensor with ``in_channels_size`` channels
        :return: torch.Tensor
            Activated feature map with ``out_channels_size`` channels
        '''
        features = self.conv_block(x)
        return features


class Discriminator(nn.Module):
    """
    Discriminator implementation of SRGAN from the article:
        https://arxiv.org/pdf/1609.04802.pdf

    Eight unpadded 3x3 convolutions (strides 1, 2, 1, 2, 1, 2, 1, 2) followed by a
    two-layer classifier. The forward pass ends with a sigmoid, so the module
    returns probabilities, not logits: pair it with ``nn.BCELoss`` rather than
    ``nn.BCEWithLogitsLoss`` (which would apply a second sigmoid).
    """

    # Stride of every convolution in order (start block, then conv_block_1..7).
    _CONV_STRIDES = (1, 2, 1, 2, 1, 2, 1, 2)

    def __init__(self,
               in_channels_size: int = 3,
               out_channels_size: int = 64,
               layer_size: int = 1000,
               input_size: int = 256
               ):
        """
        :param in_channels_size: int
            Number of channels of the input image
        :param out_channels_size: int
            Number of channels produced by the first convolution
        :param layer_size: int
            Width of the hidden fully connected layer
        :param input_size: int
            Side length of the (square) input images; it determines the number of
            features entering the classifier. The default 256 yields a 13x13 map,
            i.e. 512 * 13 * 13 features.
        :raises ValueError:
            If ``input_size`` is too small to survive all convolutions
        """
        super(Discriminator, self).__init__()

        self.conv_start_block = nn.Sequential(nn.Conv2d(kernel_size=3,
                                                        in_channels=in_channels_size,
                                                        out_channels=out_channels_size,
                                                        stride=1),
                                              nn.LeakyReLU(inplace=True))

        # The first block consumes whatever the start block produces, so its input
        # width follows `out_channels_size` instead of a fixed 64.
        self.conv_block_1 = convBlockDiscriminator(2, out_channels_size, 128)
        self.conv_block_2 = convBlockDiscriminator(1, 128, 128)
        self.conv_block_3 = convBlockDiscriminator(2, 128, 256)
        self.conv_block_4 = convBlockDiscriminator(1, 256, 256)
        self.conv_block_5 = convBlockDiscriminator(2, 256, 512)
        self.conv_block_6 = convBlockDiscriminator(1, 512, 512)
        self.conv_block_7 = convBlockDiscriminator(2, 512, 512)

        # Spatial size after each unpadded 3x3 convolution: floor((size - 3) / stride) + 1.
        feature_side = input_size
        for stride in self._CONV_STRIDES:
            feature_side = (feature_side - 3) // stride + 1
        if feature_side < 1:
            raise ValueError(
                "input_size=%d is too small for the discriminator" % input_size
            )

        self.linear_block = nn.Sequential(nn.Linear(512 * feature_side * feature_side, layer_size),
                                          nn.LeakyReLU(inplace=True),
                                          nn.Linear(layer_size, 1))

        self._initialize_weights()

    def _initialize_weights(self) -> None:
        """
        Weights initialization.
        For convolutional blocks there is "He initialization" with zeroed biases;
        batch-norm scales are set to 1. Linear layers keep PyTorch's default init.
        :return:
            None
        """
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''
        Forward pass of Discriminator
        :param x: torch.tensor
            Input Tensor; its spatial size must match ``input_size`` given at
            construction, otherwise the flattened features do not fit the classifier
        :return: torch.Tensor
            Tensor of shape (batch, 1) holding sigmoid probabilities
        '''
        out = self.conv_start_block(x)

        out = self.conv_block_1(out)
        out = self.conv_block_2(out)
        out = self.conv_block_3(out)
        out = self.conv_block_4(out)
        out = self.conv_block_5(out)
        out = self.conv_block_6(out)
        out = self.conv_block_7(out)

        out = out.flatten(start_dim=1)
        out = self.linear_block(out)

        return torch.sigmoid(out)

In [None]:
# Instantiate both networks on the selected device.
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# Pixel-wise content loss for the generator.
generator_loss = nn.MSELoss()
# Discriminator.forward already ends with a sigmoid, so it outputs probabilities.
# BCELoss is the matching criterion; BCEWithLogitsLoss would apply a second
# sigmoid on top and distort the loss and its gradients.
discriminator_loss = nn.BCELoss()

# Separate Adam optimizers with the same learning rate for the two networks.
generator_optimizer = torch.optim.Adam(generator.parameters(),
                                  lr=1e-4)
discriminator_optimizer = torch.optim.Adam(discriminator.parameters(),
                                  lr=1e-4)

In [None]:
def show_picture(dataload, generator, device, num_batches: int = 2):
    """
    Plot generator predictions next to their targets for the first batches of a loader.

    For each of the first ``num_batches`` batches one figure is created: the
    prediction for the first sample on the left (titled with the batch MSE) and
    the corresponding target on the right.

    :param dataload:
        Iterable yielding ``(input, target)`` tensor batches
    :param generator:
        Model used to produce the predictions
    :param device:
        Device the batches are moved to before inference
    :param num_batches: int
        How many batches to visualise (default 2)
    :return:
        None
    """
    if num_batches <= 0:
        return

    # Inference must run in eval mode (batch-norm uses running statistics);
    # the previous mode is restored afterwards so a later training step is unaffected.
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            for batch_number, (x, y) in enumerate(dataload, start=1):
                pred = generator(x.to(device))
                # .item() gives a plain float, so the title shows a number
                # instead of a tensor repr.
                loss = F.mse_loss(pred, y.to(device)).item()

                fig, axs = plt.subplots(1, 2, figsize=(10, 10))
                # The generator output is unbounded; clamp to the displayable [0, 1] range.
                axs[0].imshow(pred[0].clamp(0, 1).permute(1, 2, 0).cpu().numpy())
                axs[0].set_title('loss = %s' % loss)
                axs[1].imshow(y[0].permute(1, 2, 0).cpu().numpy())

                # Stop before fetching another batch from the loader.
                if batch_number >= num_batches:
                    break
    finally:
        generator.train(was_training)

In [None]:
from IPython.display import clear_output
from tqdm.notebook import tqdm
def _plot_training_history(history_loss_hr, history_loss_sr, history_mse):
    """Draw the three per-batch loss curves side by side and release the figure."""
    fig, axs = plt.subplots(1, 3, figsize=(10, 10), dpi=100)
    axs[0].plot(history_loss_hr)
    axs[0].set_title('Discriminator with real labels')
    axs[1].plot(history_loss_sr)
    axs[1].set_title('Discriminator with fake labels')
    axs[2].plot(history_mse)
    axs[2].set_title('MSE loss generator')
    plt.show()
    # One figure is created per epoch; close it so figures do not pile up in memory.
    plt.close(fig)


def _show_validation_sample(valid_loader, generator, device):
    """Plot prediction vs. target for the first sample of the first validation batch."""
    # Evaluate with batch-norm in inference mode, then restore the previous mode.
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            for x, y in valid_loader:
                pred = generator(x.to(device))
                loss = F.mse_loss(pred, y.to(device)).item()
                fig, axs = plt.subplots(1, 2, figsize=(10, 10))
                # The generator output is unbounded; clamp to the displayable [0, 1] range.
                axs[0].imshow(pred[0].clamp(0, 1).permute(1, 2, 0).cpu().numpy())
                axs[0].set_title('loss = %s' % loss)
                axs[1].imshow(y[0].permute(1, 2, 0).cpu().numpy())
                plt.show()
                plt.close(fig)
                break  # only the first batch is shown
    finally:
        generator.train(was_training)


def train(
    train_loader,
    valid_loader,
    generator,
    discriminator,
    gen_loss,
    discr_loss,
    gen_optim,
    discr_optim,
    device,
    epochs: int
):
    """
    Train the generator (content loss) and the discriminator (real vs. generated).

    Each batch performs one discriminator update on real and generated images and
    one generator update on the content loss between generated and real images.
    The adversarial term is not part of the generator loss. After every epoch the
    loss curves and one validation sample are displayed, each for five seconds.

    :param train_loader:
        Iterable yielding ``(low_res, high_res)`` training batches
    :param valid_loader:
        Iterable yielding ``(low_res, high_res)`` validation batches
    :param generator:
        Super-resolution model
    :param discriminator:
        Model scoring images as real or generated
    :param gen_loss:
        Criterion called as ``gen_loss(generated, high_res)``
    :param discr_loss:
        Criterion called as ``discr_loss(discriminator_output, label)``
    :param gen_optim:
        Optimizer of the generator
    :param discr_optim:
        Optimizer of the discriminator
    :param device:
        Device on which batches and labels are placed
    :param epochs: int
        Number of passes over ``train_loader``
    :return: tuple of lists
        Per-batch histories ``(discriminator loss on real images,
        discriminator loss on generated images, generator content loss)``
    """
    history_loss_hr = list()
    history_loss_sr = list()
    history_mse = list()

    # Both learning rates are multiplied by 0.1 every 20 epochs.
    scheduler_dicr = lr_scheduler.StepLR(discr_optim, step_size=20, gamma=0.1)
    scheduler_gener = lr_scheduler.StepLR(gen_optim, step_size=20, gamma=0.1)

    # Make sure the discriminator is trainable even if its parameters were frozen before.
    discriminator.requires_grad_(True)

    for epoch in range(epochs):
        print("epochs =  ", epoch + 1)
        generator.train()
        discriminator.train()

        for lr_batch, hr_batch in tqdm(train_loader):
            lr_batch = lr_batch.to(device)
            hr_batch = hr_batch.to(device)

            batch_size = lr_batch.size(0)
            real_label = torch.full([batch_size, 1], 1.0, dtype=lr_batch.dtype, device=device)
            fake_label = torch.full([batch_size, 1], 0.0, dtype=lr_batch.dtype, device=device)

            sr_batch = generator(lr_batch)

            # ---- discriminator update ----
            discr_optim.zero_grad()

            discriminator_hr_loss = discr_loss(discriminator(hr_batch), real_label)
            history_loss_hr.append(discriminator_hr_loss.item())
            discriminator_hr_loss.backward()

            # detach(): the discriminator step must not backpropagate into the generator.
            # Without it this backward pass frees the generator graph and the generator
            # backward below fails with "backward through the graph a second time".
            discriminator_sr_loss = discr_loss(discriminator(sr_batch.detach()), fake_label)
            history_loss_sr.append(discriminator_sr_loss.item())
            discriminator_sr_loss.backward()

            discr_optim.step()

            # ---- generator update (content loss only) ----
            gen_optim.zero_grad()
            generator_loss_tr = gen_loss(sr_batch, hr_batch)
            history_mse.append(generator_loss_tr.item())
            generator_loss_tr.backward()
            gen_optim.step()

        # Step the schedulers once per epoch. Stepping per batch would shrink the
        # learning rate tenfold every 20 batches and stall training within one epoch.
        scheduler_dicr.step()
        scheduler_gener.step()

        clear_output(True)
        _plot_training_history(history_loss_hr, history_loss_sr, history_mse)
        time.sleep(5)  # keep the curves visible before they are replaced

        clear_output(True)
        _show_validation_sample(valid_loader, generator, device)
        time.sleep(5)  # keep the sample visible before the next epoch's output

    return history_loss_hr, history_loss_sr, history_mse

In [None]:
# Run 40 epochs and keep the three per-batch loss histories for later inspection.
h_l_h, h_l_s, h_mse = train(
    train_loader=dataloader,
    valid_loader=dataloader_valid,
    generator=generator,
    discriminator=discriminator,
    gen_loss=generator_loss,
    discr_loss=discriminator_loss,
    gen_optim=generator_optimizer,
    discr_optim=discriminator_optimizer,
    device=device,
    epochs=40,
)

In [None]:
import time
# Scratch cell: prints a marker line, pauses for ten seconds, then prints a second
# marker line. It computes nothing that other cells read.
# NOTE(review): looks like leftover debugging output -- confirm and remove the cell
# before sharing the notebook.
print('kirik lox')
time.sleep(10)
print('shuchy')