In [1]:
import argparse
import glob
import os
from abc import ABCMeta, abstractmethod

import cv2
import numpy as np
import torch
import torch.nn as nn
from albumentations import (
    Compose, Resize, Normalize
)
from albumentations.pytorch import ToTensor
from timm.models.layers import Mish
from torch.utils.data import DataLoader, Dataset
from torchvision.models import vgg19
from torchvision.utils import save_image

torch.backends.cudnn.benchmark = True

ModuleNotFoundError: No module named 'timm'

In [None]:
### PARAMETERS ###

# Normalization parameters for pre-trained PyTorch models
# Per-channel statistics (three values each). They are applied by the
# dataset's Normalize transform and undone again by denormalize().
# NOTE(review): these look like the usual ImageNet RGB statistics - confirm
# the channel order matches how images are loaded.
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])



In [None]:
class FeatureExtractor(nn.Module):
    """Truncated pretrained VGG19 used to compare images in feature space.

    Only the first 35 children of the VGG19 ``features`` stack are kept; the
    training loop feeds both real and generated images through it to build
    the content loss.
    """

    def __init__(self):
        super().__init__()
        backbone = vgg19(pretrained=True)
        kept_layers = list(backbone.features.children())[:35]
        self.vgg19_54 = nn.Sequential(*kept_layers)

    def forward(self, img):
        """Return the truncated-VGG19 feature maps for ``img``."""
        features = self.vgg19_54(img)
        return features

In [None]:
class DenseResidualBlock(nn.Module):
    """
    The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)

    Five 3x3 convolutions are chained; each one receives the concatenation of
    the block input and every previous convolution output. The output of the
    last convolution is scaled by ``res_scale`` and added back to the input.

    Args:
        filters: number of channels of the input and of every convolution output.
        res_scale: factor applied to the residual branch before the skip addition.
        use_LeakyReLU_Mish: ``True`` selects ``nn.LeakyReLU`` after the first
            four convolutions, ``False`` selects ``Mish``.
    """

    def __init__(self, filters, res_scale=0.2, use_LeakyReLU_Mish=True):
        super().__init__()
        self.res_scale = res_scale

        def block(in_features, non_linearity=True):
            layers = [nn.Conv2d(in_features, filters, 3, 1, 1, bias=True)]
            if non_linearity:
                # Mish is only instantiated when it is actually requested, so the
                # LeakyReLU configuration does not depend on Mish being importable.
                layers.append(nn.LeakyReLU() if use_LeakyReLU_Mish else Mish())
            return nn.Sequential(*layers)

        # Input width grows by ``filters`` after every stage because of the concatenation.
        self.b1 = block(in_features=1 * filters)
        self.b2 = block(in_features=2 * filters)
        self.b3 = block(in_features=3 * filters)
        self.b4 = block(in_features=4 * filters)
        self.b5 = block(in_features=5 * filters, non_linearity=False)
        # Plain list on purpose: b1..b5 are already registered as submodules above,
        # so the parameter names in the state dict stay ``b1.*`` .. ``b5.*``.
        self.blocks = [self.b1, self.b2, self.b3, self.b4, self.b5]

    def forward(self, x):
        """Run the dense stack on ``x`` and return ``res_scale * out + x``."""
        inputs = x
        for block in self.blocks:
            out = block(inputs)
            inputs = torch.cat([inputs, out], 1)
        return out.mul(self.res_scale) + x

In [None]:
class ResidualInResidualDenseBlock(nn.Module):
    """Three ``DenseResidualBlock`` modules wrapped in an outer scaled skip connection.

    Args:
        filters: channel count passed to every inner dense block.
        res_scale: factor applied to the output of the three dense blocks
            before it is added to the input of this module.
        use_LeakyReLU_Mish: activation selector forwarded to the inner dense blocks.
    """

    def __init__(self, filters, res_scale=0.2, use_LeakyReLU_Mish=True):
        super().__init__()
        self.res_scale = res_scale
        # The flag must be passed by keyword: the second positional parameter of
        # DenseResidualBlock is ``res_scale``, so passing the boolean positionally
        # turned it into a residual scale of 1 or 0 and never selected the activation.
        self.dense_blocks = nn.Sequential(
            DenseResidualBlock(filters, use_LeakyReLU_Mish=use_LeakyReLU_Mish),
            DenseResidualBlock(filters, use_LeakyReLU_Mish=use_LeakyReLU_Mish),
            DenseResidualBlock(filters, use_LeakyReLU_Mish=use_LeakyReLU_Mish),
        )

    def forward(self, x):
        """Return ``res_scale * dense_blocks(x) + x``."""
        return self.dense_blocks(x).mul(self.res_scale) + x

In [None]:
class GeneratorRRDB(nn.Module):
    """Super-resolution generator built from residual-in-residual dense blocks.

    Args:
        channels: number of channels of the input and output images.
        filters: feature width used throughout the network.
        num_res_blocks: number of ``ResidualInResidualDenseBlock`` modules.
        num_upsample: number of x2 pixel-shuffle stages (total scale is
            ``2 ** num_upsample``).
        use_LeakyReLU_Mish: activation selector forwarded to the residual blocks.
    """

    def __init__(self, channels, filters=64, num_res_blocks=16, num_upsample=2, use_LeakyReLU_Mish=True):
        super().__init__()

        # First layer
        self.conv1 = nn.Conv2d(channels, filters, kernel_size=3, stride=1, padding=1)
        # Residual blocks. The flag is passed by keyword because the second
        # positional parameter of ResidualInResidualDenseBlock is ``res_scale``;
        # passing it positionally silently replaced the residual scale with 1 or 0.
        self.res_blocks = nn.Sequential(*[
            ResidualInResidualDenseBlock(filters, use_LeakyReLU_Mish=use_LeakyReLU_Mish)
            for _ in range(num_res_blocks)
        ])
        # Second conv layer post residual blocks
        self.conv2 = nn.Conv2d(filters, filters, kernel_size=3, stride=1, padding=1)
        # Upsampling layers: each stage expands to 4x channels, then PixelShuffle
        # trades them for a 2x larger spatial grid.
        upsample_layers = []
        for _ in range(num_upsample):
            upsample_layers += [
                nn.Conv2d(filters, filters * 4, kernel_size=3, stride=1, padding=1),
                nn.LeakyReLU(),
                nn.PixelShuffle(upscale_factor=2),
            ]
        self.upsampling = nn.Sequential(*upsample_layers)
        # Final output block
        self.conv3 = nn.Sequential(
            nn.Conv2d(filters, filters, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(filters, channels, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, x):
        """Map a low-resolution batch ``x`` to its upsampled reconstruction."""
        out1 = self.conv1(x)
        out = self.res_blocks(out1)
        out2 = self.conv2(out)
        # Global skip connection around the whole residual trunk.
        out = torch.add(out1, out2)
        out = self.upsampling(out)
        out = self.conv3(out)
        return out


In [None]:
class Discriminator(nn.Module):
    """Patch discriminator that outputs one logit per spatial patch.

    Args:
        input_shape: ``(channels, height, width)`` of the images to classify.
        use_LeakyReLU_Mish: ``True`` selects ``nn.LeakyReLU(0.2)``, ``False``
            selects ``Mish`` as the activation.

    Attributes:
        output_shape: ``(1, patch_h, patch_w)``, the per-sample shape produced
            by ``forward`` for inputs of ``input_shape``.
    """

    def __init__(self, input_shape, use_LeakyReLU_Mish=True):
        super().__init__()

        self.input_shape = input_shape
        in_channels, in_height, in_width = self.input_shape
        stage_filters = [64, 128, 256, 512]

        # Every stage contains one stride-2 convolution (kernel 3, padding 1),
        # which maps a side of length n to ceil(n / 2). Replaying that here keeps
        # output_shape correct for sizes that are not multiples of 16, where a
        # plain ``int(n / 16)`` under-reports the real output size.
        patch_h, patch_w = int(in_height), int(in_width)
        for _ in stage_filters:
            patch_h, patch_w = (patch_h + 1) // 2, (patch_w + 1) // 2
        self.output_shape = (1, patch_h, patch_w)

        def activation():
            # Mish is only built when requested, so the LeakyReLU configuration
            # does not require Mish to be importable.
            return nn.LeakyReLU(0.2, inplace=True) if use_LeakyReLU_Mish else Mish()

        def discriminator_block(in_filters, out_filters, first_block=False):
            layers = [nn.Conv2d(in_filters, out_filters, kernel_size=3, stride=1, padding=1)]
            if not first_block:
                layers.append(nn.BatchNorm2d(out_filters))
            layers.append(activation())
            layers.append(nn.Conv2d(out_filters, out_filters, kernel_size=3, stride=2, padding=1))
            layers.append(nn.BatchNorm2d(out_filters))
            layers.append(activation())
            return layers

        layers = []
        in_filters = in_channels
        for i, out_filters in enumerate(stage_filters):
            layers.extend(discriminator_block(in_filters, out_filters, first_block=(i == 0)))
            in_filters = out_filters

        # Final 3x3 convolution collapses the features to a single logit map.
        layers.append(nn.Conv2d(in_filters, 1, kernel_size=3, stride=1, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Return the patch logits for ``img``."""
        return self.model(img)

In [None]:
class train_loop(metaclass=ABCMeta):
    """Abstract base for GAN super-resolution training runs.

    Construction creates the output directories, builds the option namespace,
    the three networks, the losses, the optimizers and the data loader.
    Subclasses implement ``__call__`` to run the actual training.

    The constructor arguments are the defaults of the command-line options of
    the same name; see ``__commandline_interface`` for their meaning.
    """

    def __init__(self, epoch=0, n_epochs=200, dataset_name="../dataset/TurkishPlates", batch_size=3, lr=0.0002, b1=0.9,
                 b2=0.999, decay_epoch=100, n_cpu=8, hr_height=256, hr_width=256,
                 channels=3, sample_interval=100, checkpoint_interval=100, residual_blocks=23, warmup_batches=500,
                 lambda_adv=5e-3, lambda_pixel=1e-2):
        os.makedirs("images/training", exist_ok=True)
        os.makedirs("saved_models", exist_ok=True)

        self.reload(b1, b2, batch_size, channels, checkpoint_interval, dataset_name, decay_epoch, epoch, hr_height,
                    hr_width, lambda_adv, lambda_pixel, lr, n_cpu, n_epochs, residual_blocks, sample_interval,
                    warmup_batches)

    def reload(self, b1, b2, batch_size, channels, checkpoint_interval, dataset_name, decay_epoch, epoch, hr_height,
               hr_width, lambda_adv, lambda_pixel, lr, n_cpu, n_epochs, residual_blocks, sample_interval,
               warmup_batches):
        """(Re)build options, device, networks, losses, optimizers and data loader."""
        self.parser = argparse.ArgumentParser()
        self.opt = self.__commandline_interface(epoch=epoch, n_epochs=n_epochs, dataset_name=dataset_name,
                                                batch_size=batch_size, lr=lr, b1=b1, b2=b2, decay_epoch=decay_epoch,
                                                n_cpu=n_cpu, hr_height=hr_height, hr_width=hr_width, channels=channels,
                                                sample_interval=sample_interval,
                                                checkpoint_interval=checkpoint_interval,
                                                residual_blocks=residual_blocks, warmup_batches=warmup_batches,
                                                lambda_adv=lambda_adv, lambda_pixel=lambda_pixel)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Initialize generator and discriminator
        hr_shape = (self.opt.hr_height, self.opt.hr_width)
        self.discriminator, self.feature_extractor, self.generator = self.network_initializers(hr_shape)
        # Losses
        self.criterion_GAN, self.criterion_content, self.criterion_pixel = self.losses()
        # Optimizers
        self.optimizer_D, self.optimizer_G = self.optimizers()
        # Data
        self.dataset = ImageDataset_superresolution(device=self.device, root=self.opt.dataset_name, hr_shape=hr_shape)
        self.dataloader = DataLoader(
            self.dataset,
            batch_size=self.opt.batch_size,
            num_workers=self.opt.n_cpu,
            shuffle=True,
            pin_memory=True,
        )

    @abstractmethod
    def __call__(self, *args, **kwargs):
        """Run the training procedure; implemented by subclasses."""
        pass

    def __commandline_interface(self, epoch=0, n_epochs=200, dataset_name="dataset", batch_size=4, lr=0.0002, b1=0.9,
                                b2=0.999, decay_epoch=100, n_cpu=8, hr_height=256, hr_width=256,
                                channels=3, sample_interval=100, checkpoint_interval=100, residual_blocks=23,
                                warmup_batches=500, lambda_adv=5e-3, lambda_pixel=1e-2):
        """Register all options on ``self.parser`` and return the parsed namespace.

        Each keyword argument becomes the default of the option with the same
        name, so values given on the command line override the constructor values.
        """
        # (flag, type, default, help) for every supported option.
        option_table = [
            ("--epoch", int, epoch, "epoch to start training from"),
            ("--n_epochs", int, n_epochs, "number of epochs of training"),
            ("--dataset_name", str, dataset_name, "name of the dataset"),
            ("--batch_size", int, batch_size, "size of the batches"),
            ("--lr", float, lr, "adam: learning rate"),
            ("--b1", float, b1, "adam: decay of first order momentum of gradient"),
            ("--b2", float, b2, "adam: decay of second order momentum of gradient"),
            ("--decay_epoch", int, decay_epoch, "epoch from which to start lr decay"),
            ("--n_cpu", int, n_cpu, "number of cpu threads to use during batch generation"),
            ("--hr_height", int, hr_height, "high res. image height"),
            ("--hr_width", int, hr_width, "high res. image width"),
            ("--channels", int, channels, "number of image channels"),
            ("--sample_interval", int, sample_interval, "interval between saving image samples"),
            ("--checkpoint_interval", int, checkpoint_interval, "batch interval between model checkpoints"),
            ("--residual_blocks", int, residual_blocks, "number of residual blocks in the generator"),
            ("--warmup_batches", int, warmup_batches, "number of batches with pixel-wise loss only"),
            ("--lambda_adv", float, lambda_adv, "adversarial loss weight"),
            ("--lambda_pixel", float, lambda_pixel, "pixel-wise loss weight"),
        ]
        for flag, arg_type, default, help_text in option_table:
            self.parser.add_argument(flag, type=arg_type, default=default, help=help_text)

        # parse_known_args instead of parse_args: inside a Jupyter kernel sys.argv
        # carries the kernel's own arguments (e.g. ``-f <connection file>``), which
        # parse_args rejects with SystemExit. Unknown arguments are ignored here.
        opt, _unknown = self.parser.parse_known_args()
        print(opt)
        return opt

    def network_initializers(self, hr_shape, use_LeakyReLU_Mish=False):
        """Build generator, discriminator and feature extractor on ``self.device``.

        Returns:
            ``(discriminator, feature_extractor, generator)``.
        """
        generator = GeneratorRRDB(self.opt.channels, filters=64, num_res_blocks=self.opt.residual_blocks,
                                  use_LeakyReLU_Mish=use_LeakyReLU_Mish).to(self.device, non_blocking=True)
        discriminator = Discriminator(input_shape=(self.opt.channels, *hr_shape),
                                      use_LeakyReLU_Mish=use_LeakyReLU_Mish).to(self.device, non_blocking=True)
        feature_extractor = FeatureExtractor().to(self.device, non_blocking=True)
        # Set feature extractor to inference mode
        feature_extractor.eval()
        return discriminator, feature_extractor, generator

    def losses(self):
        """Return ``(criterion_GAN, criterion_content, criterion_pixel)``."""
        criterion_GAN = torch.nn.BCEWithLogitsLoss()
        criterion_content = torch.nn.L1Loss()
        criterion_pixel = torch.nn.L1Loss()
        return criterion_GAN, criterion_content, criterion_pixel

    def optimizers(self):
        """Return ``(optimizer_D, optimizer_G)``, both Adam with the configured lr and betas."""
        optimizer_G = torch.optim.Adam(self.generator.parameters(), lr=self.opt.lr, betas=(self.opt.b1, self.opt.b2))
        optimizer_D = torch.optim.Adam(self.discriminator.parameters(), lr=self.opt.lr,
                                       betas=(self.opt.b1, self.opt.b2))
        return optimizer_D, optimizer_G

    def __train(self):
        # Placeholder. The name is mangled per class, so subclasses that define
        # their own ``__train`` do not override this one.
        pass

In [None]:
class ImageDataset_superresolution(Dataset):
    """Dataset yielding paired low-/high-resolution versions of each image.

    Every item is a dict ``{"lr": tensor, "hr": tensor}`` where ``hr`` is the
    image resized to ``hr_shape`` and ``lr`` is the image resized to a quarter
    of that size; both are normalized with the module-level ``mean``/``std``.

    Args:
        device: stored on the instance; not used by the dataset itself.
        root: directory path whose ``*.*`` entries are loaded (sorted), or an
            already prepared sequence of file paths / image arrays.
        hr_shape: ``(height, width)`` of the high-resolution output.
        interpolation: accepted for compatibility; currently unused.
        load_all: when true, every image is transformed once at construction
            time and kept in memory; otherwise images are loaded on access.
    """

    def __init__(self, device, root=None, hr_shape=None, interpolation='bicubic', load_all=True):
        hr_height, hr_width = hr_shape
        self.hr_height = hr_height
        self.hr_width = hr_width
        self.device = device
        self.load_all = load_all

        ## Albumentation
        # Resize takes (height, width); the width must come from hr_width so that
        # non-square targets work and match the preallocated tensors.
        self.lr_transform = Compose([
            Resize(hr_height // 4, hr_width // 4),
            Normalize(mean, std),
            ToTensor(),
        ])

        self.hr_transform = Compose([
            Resize(hr_height, hr_width),
            Normalize(mean, std),
            ToTensor(),
        ])

        if isinstance(root, str):
            self.files = sorted(glob.glob(root + "/*.*"))
        else:
            self.files = root

        if load_all:
            self.images_lr, self.images_hr = self.__load_all_images()

    def _read_image(self, index):
        """Return the image for ``index`` as an array.

        Path entries are read from disk and converted from OpenCV's BGR channel
        order to RGB, which is the order the normalization statistics and the
        image-saving code assume. Non-path entries are returned unchanged.

        Raises:
            ValueError: if a path entry cannot be decoded as an image.
        """
        entry = self.files[index % len(self.files)]
        if not isinstance(entry, str):
            return entry
        img = cv2.imread(entry)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise ValueError("Could not read image file: %s" % entry)
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    def _transform_pair(self, img):
        """Apply the low- and high-resolution transforms to one image."""
        img_lr = self.lr_transform(image=img)['image']
        img_hr = self.hr_transform(image=img)['image']
        return img_lr, img_hr

    def __load_all_images(self):
        """Transform every image once and return ``(images_lr, images_hr)`` stacks.

        Raises:
            ValueError: if there are no images to load.
        """
        files_len = 0 if self.files is None else len(self.files)
        if files_len == 0:
            raise ValueError("ImageDataset_superresolution: no input images were found")

        # The first image fixes the per-item shapes used for preallocation.
        first_lr, first_hr = self._transform_pair(self._read_image(0))
        images_lr = torch.zeros((files_len, *first_lr.shape))
        images_hr = torch.zeros((files_len, *first_hr.shape))
        images_lr[0, ...] = first_lr
        images_hr[0, ...] = first_hr

        for index in range(1, files_len):
            img_lr, img_hr = self._transform_pair(self._read_image(index))
            images_lr[index, ...] = img_lr
            images_hr[index, ...] = img_hr
        return images_lr, images_hr

    def load_one_image(self, index):
        """Load and transform the single image at ``index``."""
        img_lr, img_hr = self._transform_pair(self._read_image(index))
        return {"lr": img_lr, "hr": img_hr}

    def __getitem__(self, index):
        if self.load_all:
            return {"lr": self.images_lr[index], "hr": self.images_hr[index]}
        else:
            return self.load_one_image(index)

    def __len__(self):
        return len(self.files)


# Smoke test: build the dataset from a local folder and print its first item.
# NOTE(review): this guard is also true when the cell is executed in a notebook,
# so running the cell loads every image under ``dataset_name`` into memory and
# prints full tensors - confirm this is intended.
if __name__ == "__main__":
    dataset_name = "../TurkishPlates"
    hr_shape = (256, 256)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataset = ImageDataset_superresolution(device=device, root=dataset_name, hr_shape=hr_shape)
    print(dataset[0])

In [None]:
class esrgan(train_loop):
    """ESRGAN training run: pixel-loss warm-up followed by adversarial training.

    Calling an instance starts training. When ``opt.epoch`` is non-zero the
    generator and discriminator weights saved for that epoch are restored first.
    """

    def __call__(self, *args, **kwargs):
        self.esrgan_train()

    def esrgan_train(self):
        """Optionally restore checkpoints, then run the training loop."""
        if self.opt.epoch != 0:
            # Load pretrained models. map_location lets checkpoints written on a
            # GPU be restored on whatever device this run uses.
            self.generator.load_state_dict(
                torch.load("saved_models/generator_%d.pth" % self.opt.epoch, map_location=self.device))
            self.discriminator.load_state_dict(
                torch.load("saved_models/discriminator_%d.pth" % self.opt.epoch, map_location=self.device))
        self.__train()

    def __train(self):
        for epoch in range(self.opt.epoch, self.opt.n_epochs):
            for i, imgs in enumerate(self.dataloader):

                batches_done = epoch * len(self.dataloader) + i

                # Configure model input
                imgs_lr = imgs["lr"].to(self.device, non_blocking=True)
                imgs_hr = imgs["hr"].to(self.device, non_blocking=True)

                # Adversarial ground truths, allocated directly on the target device
                target_shape = (imgs_lr.size(0), *self.discriminator.output_shape)
                valid = torch.ones(target_shape, device=self.device)
                fake = torch.zeros(target_shape, device=self.device)

                # ------------------
                #  Train Generators
                # ------------------

                self.optimizer_G.zero_grad()

                # Generate a high resolution image from low resolution input
                gen_hr = self.generator(imgs_lr)

                # Measure pixel-wise loss against ground truth
                loss_pixel = self.criterion_pixel(gen_hr, imgs_hr)  # L1Loss

                if batches_done < self.opt.warmup_batches:
                    # Warm-up (pixel-wise loss only)
                    loss_pixel.backward()
                    self.optimizer_G.step()
                    print(
                        "[Epoch %d/%d] [Batch %d/%d] [G pixel: %f]"
                        % (epoch, self.opt.n_epochs, i, len(self.dataloader), loss_pixel.item())
                    )
                    continue

                # Extract validity predictions from discriminator. The real-image
                # prediction is a constant for the generator update, so it is
                # computed without recording an autograd graph.
                with torch.no_grad():
                    pred_real = self.discriminator(imgs_hr)
                pred_fake = self.discriminator(gen_hr)

                # Adversarial loss (relativistic average GAN)
                loss_GAN = self.criterion_GAN(pred_fake - pred_real.mean(0, keepdim=True), valid)

                # Content loss; the real-image features are targets only.
                with torch.no_grad():
                    real_features = self.feature_extractor(imgs_hr)
                gen_features = self.feature_extractor(gen_hr)
                loss_content = self.criterion_content(gen_features, real_features)  # L1Loss

                # Total generator loss
                loss_G = loss_content + self.opt.lambda_adv * loss_GAN + self.opt.lambda_pixel * loss_pixel

                loss_G.backward()
                self.optimizer_G.step()

                # ---------------------
                #  Train Discriminator
                # ---------------------

                loss_D = self.train_discriminator(gen_hr, imgs_hr, fake, valid)

                # --------------
                #  Log Progress
                # --------------

                self.__log_progress(i, batches_done, epoch, gen_hr, imgs_lr, loss_D, loss_G, loss_GAN, loss_content,
                                    loss_pixel)

    def train_discriminator(self, gen_hr, imgs_hr, fake, valid):
        """Run one discriminator update and return its loss.

        ``gen_hr`` is detached so no gradient reaches the generator.
        """
        self.optimizer_D.zero_grad()
        pred_real = self.discriminator(imgs_hr)
        pred_fake = self.discriminator(gen_hr.detach())
        # Adversarial loss for real and fake images (relativistic average GAN)
        loss_real = self.criterion_GAN(pred_real - pred_fake.mean(0, keepdim=True), valid)
        loss_fake = self.criterion_GAN(pred_fake - pred_real.mean(0, keepdim=True), fake)
        # Total loss
        loss_D = (loss_real + loss_fake) / 2
        loss_D.backward()
        self.optimizer_D.step()
        return loss_D

    def __log_progress(self, i, batches_done, epoch, gen_hr, imgs_lr, loss_D, loss_G, loss_GAN, loss_content,
                       loss_pixel):
        """Print the loss summary and periodically save sample images / checkpoints."""
        self.summary_string = "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, content: %f, adv: %f, pixel: %f]"\
            % (
                epoch,
                self.opt.n_epochs,
                i,
                len(self.dataloader),
                loss_D.item(),
                loss_G.item(),
                loss_content.item(),
                loss_GAN.item(),
                loss_pixel.item(),
            )
        print(self.summary_string)
        if batches_done % self.opt.sample_interval == 0:
            # Save image grid with upsampled inputs and ESRGAN outputs.
            # Detached / no_grad: this is visualisation only and must not extend
            # the autograd graph of the generator output.
            with torch.no_grad():
                imgs_lr = nn.functional.interpolate(imgs_lr, scale_factor=4)
                img_grid = denormalize(torch.cat((imgs_lr, gen_hr.detach()), -1))
            save_image(img_grid, "images/training/%d.png" % batches_done, nrow=1, normalize=False)
        if batches_done % self.opt.checkpoint_interval == 0:
            # Save model checkpoints
            torch.save(self.generator.state_dict(), "saved_models/generator_%d.pth" % epoch)
            torch.save(self.discriminator.state_dict(), "saved_models/discriminator_%d.pth" % epoch)


In [None]:
def denormalize(tensors, channel_mean=None, channel_std=None):
    """Undo per-channel normalization and clamp the result to ``[0, 1]``.

    Computes ``tensors * std + mean`` along dimension 1 (the channel dimension)
    and returns a new tensor; the input is left untouched.

    Args:
        tensors: batch of normalized images with channels in dimension 1.
        channel_mean: per-channel means; defaults to the module-level ``mean``.
        channel_std: per-channel standard deviations; defaults to the
            module-level ``std``.

    Returns:
        A tensor of the same shape as ``tensors`` with values in ``[0, 1]``.
    """
    channel_mean = mean if channel_mean is None else channel_mean
    channel_std = std if channel_std is None else channel_std

    # Shape (1, C, 1, ..., 1) so the statistics broadcast over batch and spatial dims.
    stats_shape = [1, -1] + [1] * (tensors.dim() - 2)
    shift = torch.as_tensor(channel_mean, dtype=tensors.dtype, device=tensors.device).reshape(stats_shape)
    scale = torch.as_tensor(channel_std, dtype=tensors.dtype, device=tensors.device).reshape(stats_shape)

    # Denormalized pixel values live in [0, 1] (the range save_image expects when
    # normalize=False), so that is the clamp range rather than [0, 255].
    return torch.clamp(tensors * scale + shift, 0, 1)

In [None]:
# Build the ESRGAN trainer with its default options (this creates the output
# folders, networks, optimizers and data loader), then call it to start training.
superres = esrgan()
superres()