<a href="https://colab.research.google.com/github/avyay10/Resources-for-SeSiGAN/blob/main/esrganmodel.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
import collections
import torch

class Model(object):
    """Minimal interface for a loadable upscaling model.

    Subclasses override ``name`` and ``load``; the base implementations do
    nothing and return ``None``.
    """

    def name(self):
        """Return an identifier for the model (``None`` in the base class)."""
        return None

    def load(self):
        """Return the loaded model data (``None`` in the base class)."""
        return None

class FileModel(Model):
    """A model whose weights are stored in a single file on disk.

    The file is read lazily on the first call to ``load`` and the result is
    cached. Both the new key layout (``conv_first``, ``RRDB_trunk.*``,
    ``upconvN``, ``HRconv``, ``conv_last``) and the legacy ``model.N.*``
    layout are accepted; legacy keys are renamed to the new layout on load.

    Args:
        path: path of the weights file.
    """

    def __init__(self, path):
        self._model = None  # cached (state_dict, scale_index) after first load
        self._path = path

    def _get_scale_index(self, state_dict):
        """Return the highest N among ``upconvN.weight`` keys (0 if there are none).

        Each upconv stage doubles the resolution, so the overall upscale
        factor is ``2 ** result``.
        """
        # this is more or less guesswork, since I haven't seen any
        # non-4x models using the new format in the wild, but it
        # should work in theory
        prefix, suffix = "upconv", ".weight"
        max_index = 0

        for k in state_dict.keys():
            if k.startswith(prefix) and k.endswith(suffix):
                max_index = max(max_index, int(k[len(prefix):-len(suffix)]))

        return max_index

    def _get_legacy_scale_index(self, state_dict):
        """Return the number of upsampling stages in a legacy state dict.

        Legacy keys look like ``model.X.<...>``. The last layer (conv_last)
        sits at index ``3 * n_upscale + 4`` (see ``_build_legacy_keymap``),
        so the stage count is recovered from the largest ``X``.

        Raises:
            RuntimeError: if the keys do not follow the ``model.X`` pattern
                or the dict is empty.
        """
        try:
            # get largest model index from keys like "model.X.weight"
            max_index = max(int(k.split(".")[1]) for k in state_dict.keys())
        except (AttributeError, IndexError, TypeError, ValueError) as err:
            # invalid model dict format (or no keys at all)
            raise RuntimeError("Unable to determine scale index for model") from err

        return (max_index - 4) // 3

    def _build_legacy_keymap(self, n_upscale):
        """Return an ordered mapping from legacy parameter names to new ones.

        Legacy layout, as encoded below:
          * ``model.0``                            -> ``conv_first``
          * ``model.1.sub.{i}.RDB{j}.conv{k}.0``   -> ``RRDB_trunk.{i}.RDB{j}.conv{k}``
            for a 23-block trunk
          * ``model.1.sub.23``                     -> ``trunk_conv``
          * ``model.3``, ``model.6``, ...          -> ``upconv1``..``upconv{n_upscale}``
          * two / four indices after the last upconv -> ``HRconv`` / ``conv_last``

        Every entry is emitted once with a ``.weight`` and once with a
        ``.bias`` suffix.
        """
        keymap = collections.OrderedDict()
        keymap["model.0"] = "conv_first"

        for i in range(23):
            for j in range(1, 4):
                for k in range(1, 6):
                    k1 = "model.1.sub.%d.RDB%d.conv%d.0" % (i, j, k)
                    k2 = "RRDB_trunk.%d.RDB%d.conv%d" % (i, j, k)
                    keymap[k1] = k2

        keymap["model.1.sub.23"] = "trunk_conv"

        n = 0
        for i in range(1, n_upscale + 1):
            n += 3
            keymap["model.%d" % n] = "upconv%d" % i

        keymap["model.%d" % (n + 2)] = "HRconv"
        keymap["model.%d" % (n + 4)] = "conv_last"

        # add ".weight" and ".bias" suffixes to all keys
        keymap_final = collections.OrderedDict()

        for k1, k2 in keymap.items():
            for k_type in ("weight", "bias"):
                keymap_final[k1 + "." + k_type] = k2 + "." + k_type

        return keymap_final

    def name(self):
        """Return the file name of the model without directory or extension."""
        return os.path.splitext(os.path.basename(self._path))[0]

    def _load(self):
        """Read the weights file and normalize it to the new key layout.

        Returns:
            tuple: ``(state_dict, scale_index)``.

        Raises:
            RuntimeError: for a legacy file whose scale cannot be determined.
            KeyError: for a legacy file containing a key outside the known layout.
        """
        # Load onto the CPU so that checkpoints saved from a GPU can be read
        # on CPU-only machines, and so that tensors from several files share
        # a device if they are blended together. The network built from these
        # weights is expected to be moved to its target device afterwards.
        # NOTE(review): torch.load unpickles the file; only load model files
        # from trusted sources.
        state_dict = torch.load(self._path, map_location="cpu")

        # check for legacy model format
        if "model.0.weight" in state_dict:
            # remap dict keys to new format
            scale_index = self._get_legacy_scale_index(state_dict)
            keymap = self._build_legacy_keymap(scale_index)
            state_dict = {keymap[k]: v for k, v in state_dict.items()}
        else:
            scale_index = self._get_scale_index(state_dict)

        return state_dict, scale_index

    def load(self):
        """Return ``(state_dict, scale_index)``, reading the file only once."""
        if self._model is None:
            self._model = self._load()
        return self._model

class WeightedFileListModel(Model):
    """A model whose weights are a weighted average of several weight files.

    Args:
        weight_map: mapping of weights-file path -> relative weight. Weights
            are normalized by their sum, so ``{a: 1, b: 3}`` blends 25% of
            ``a`` with 75% of ``b``.
    """

    def __init__(self, weight_map):
        self._models = {}
        # unused; retained so existing attribute access keeps working
        self._total_weigth = 0

        # name alternates model name and weight, e.g. "a_1_b_3"
        names = []
        for path, weight in weight_map.items():
            model = FileModel(path)
            self._models[model] = weight

            names.append(model.name())
            names.append(str(weight))

        self._name = "_".join(names)

    def name(self):
        """Return the combined ``<name>_<weight>_...`` identifier."""
        return self._name

    def load(self):
        """Blend the member models' weights.

        Returns:
            tuple: ``(state_dict, scale)``. Each tensor is the sum over members
            of ``weight / total_weight * tensor``. For an empty weight map the
            result is an empty dict and scale 0.

        Raises:
            ValueError: if the weights sum to zero, or if the member models
                have different scales (their upconv layers would not line up).
        """
        net_interp = collections.OrderedDict()
        if not self._models:
            return net_interp, 0

        total_weight = sum(self._models.values())
        if total_weight == 0:
            raise ValueError("Model weights must not sum to zero")

        scale = None
        for model, weight in self._models.items():
            alpha = weight / total_weight
            net, model_scale = model.load()

            # Interpolating models of different scales would silently mix
            # incompatible layer sets, so insist they agree.
            if scale is None:
                scale = model_scale
            elif model_scale != scale:
                raise ValueError(
                    "Cannot interpolate models with different scales "
                    "(%s has %d, expected %d)" % (model.name(), model_scale, scale))

            for k, v in net.items():
                # alpha * v is a new tensor, so the in-place += below never
                # modifies the member model's cached weights
                va = alpha * v
                if k in net_interp:
                    net_interp[k] += va
                else:
                    net_interp[k] = va

        return net_interp, scale

In [2]:
import functools
import collections
import torch
import torch.nn as nn
import torch.nn.functional as F

def make_layer(block, n_layers):
    """Stack ``n_layers`` fresh modules produced by ``block()`` into an ``nn.Sequential``.

    ``block`` is called once per layer, so every layer has its own parameters.
    """
    return nn.Sequential(*[block() for _ in range(n_layers)])

class ResidualDenseBlock_5C(nn.Module):
    """Five-convolution dense block with a residual connection scaled by 0.2.

    ``conv1``..``conv4`` each see the block input concatenated (along dim 1)
    with all previous intermediate outputs, and each adds ``gc`` channels.
    ``conv5`` maps the full concatenation back to ``nf`` channels. Its result
    is scaled by 0.2 and added to the input. Default PyTorch initialization
    is used.

    Args:
        nf: number of input/output channels.
        gc: growth channels, i.e. channels produced by each of conv1..conv4.
        bias: whether the convolutions have a bias term.
    """

    def __init__(self, nf=64, gc=32, bias=True):
        super(ResidualDenseBlock_5C, self).__init__()
        # Register conv1..conv5 in order. The attribute names are the
        # state-dict keys that pretrained weights are loaded into.
        for idx in range(1, 6):
            in_channels = nf + (idx - 1) * gc
            out_channels = nf if idx == 5 else gc
            setattr(self, "conv%d" % idx,
                    nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=bias))
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        features = [x]
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            stacked = features[0] if len(features) == 1 else torch.cat(features, 1)
            features.append(self.lrelu(conv(stacked)))
        return self.conv5(torch.cat(features, 1)) * 0.2 + x


class RRDB(nn.Module):
    '''Residual in Residual Dense Block.

    Three ResidualDenseBlock_5C modules are applied in series. Their output
    is scaled by 0.2 and added back to the block input.
    '''

    def __init__(self, nf, gc=32):
        super(RRDB, self).__init__()
        # RDB1..RDB3 are the state-dict names pretrained weights map onto.
        for idx in (1, 2, 3):
            setattr(self, "RDB%d" % idx, ResidualDenseBlock_5C(nf, gc))

    def forward(self, x):
        h = x
        for rdb in (self.RDB1, self.RDB2, self.RDB3):
            h = rdb(h)
        return h * 0.2 + x

class RRDBNet(nn.Module):
    """ESRGAN-style generator: an RRDB trunk followed by 2x upsampling stages.

    The number of ``upconvN`` upsampling stages is not fixed at construction
    time. It is set by ``load_state_dict`` from the model's scale.

    Args:
        in_nc: input channels.
        out_nc: output channels.
        nf: feature channels in the trunk.
        nb: number of RRDB blocks in the trunk.
        gc: growth channels inside each dense block.
    """

    def __init__(self, in_nc, out_nc, nf, nb, gc=32):
        super(RRDBNet, self).__init__()
        RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)

        self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
        self.RRDB_trunk = make_layer(RRDB_block_f, nb)
        self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        #### upsampling (the upconvN layers are created in load_state_dict)
        self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        self.n_upscale = 0
        self.nf = nf

    @staticmethod
    def _infer_scale(state_dict):
        """Return the highest N among ``upconvN.weight`` keys (0 if none)."""
        prefix, suffix = "upconv", ".weight"
        scale = 0
        for key in state_dict:
            if key.startswith(prefix) and key.endswith(suffix):
                scale = max(scale, int(key[len(prefix):-len(suffix)]))
        return scale

    def load_state_dict(self, state_dict, scale=None, strict=True):
        """(Re)build the upsampling layers for ``scale``, then load ``state_dict``.

        Unlike ``nn.Module.load_state_dict``, the second positional parameter
        is ``scale``, so pass ``strict`` by keyword.

        Args:
            state_dict: parameters in the new key layout.
            scale: number of 2x ``upconvN`` stages. When ``None`` it is
                inferred from the ``upconvN.weight`` keys in ``state_dict``.
            strict: forwarded to ``nn.Module.load_state_dict``.

        Returns:
            Whatever ``nn.Module.load_state_dict`` returns (missing/unexpected keys).
        """
        if scale is None:
            scale = self._infer_scale(state_dict)

        # Remove upconv layers from a previous load. Otherwise reloading with
        # a smaller scale leaves stale modules behind and a strict load fails
        # on their missing keys.
        for n in range(1, self.n_upscale + 1):
            name = "upconv%d" % n
            if hasattr(self, name):
                delattr(self, name)

        self.n_upscale = scale

        # build upconv layers based on model scale, matching the device/dtype
        # of the existing layers in case the net was already moved
        ref = self.conv_first.weight
        for n in range(1, self.n_upscale + 1):
            upconv = nn.Conv2d(self.nf, self.nf, 3, 1, 1, bias=True)
            setattr(self, "upconv%d" % n, upconv.to(device=ref.device, dtype=ref.dtype))

        return super().load_state_dict(state_dict, strict)

    def forward(self, x):
        fea = self.conv_first(x)
        trunk = self.trunk_conv(self.RRDB_trunk(fea))
        fea = fea + trunk  # global residual around the trunk

        # apply upconv layers: each doubles H and W (nearest) and then convolves
        for n in range(1, self.n_upscale + 1):
            upconv = getattr(self, "upconv%d" % n)
            fea = self.lrelu(upconv(F.interpolate(fea, scale_factor=2, mode="nearest")))

        out = self.conv_last(self.lrelu(self.HRconv(fea)))

        return out

In [4]:
import math
import numpy as np
import torch
#import rrdbnet

class Upscaler(object):
    """Base upscaler whose ``upscale`` is the identity transform."""

    def upscale(self, input_image):
        """Return ``input_image`` unchanged; subclasses do the real work."""
        unchanged = input_image
        return unchanged

class RRDBNetUpscaler(Upscaler):
    """Upscale a whole image with one RRDBNet forward pass.

    Args:
        model: object whose ``load()`` returns ``(state_dict, scale)``.
        device: torch device the network runs on.
    """

    def __init__(self, model, device):
        net, scale = model.load()

        # RRDBNet is defined in this notebook; the separate ``rrdbnet`` module
        # import is commented out, so ``rrdbnet.RRDBNet`` raised NameError.
        model_net = RRDBNet(3, 3, 64, 23)
        model_net.load_state_dict(net, scale, strict=True)
        model_net.eval()

        for _, v in model_net.named_parameters():
            v.requires_grad = False

        self.model = model_net.to(device)
        self.device = device
        # each upconv stage doubles the resolution
        self.scale_factor = 2 ** scale

    def upscale(self, input_image):
        """Upscale one image.

        Assumes ``input_image`` is an (H, W, 3) array of 0-255 values.
        Presumably it is BGR as read by OpenCV (TODO confirm at the call
        site). Channels are reversed before the network and reversed back
        afterwards.

        Returns:
            float array of shape (H * scale_factor, W * scale_factor, 3)
            holding values rounded to integers in [0, 255].
        """
        input_image = input_image * 1.0 / 255
        input_image = np.transpose(input_image[:, :, [2, 1, 0]], (2, 0, 1))
        input_image = torch.from_numpy(input_image).float()
        input_image = input_image.unsqueeze(0).to(self.device)

        # inference only: don't track autograd state
        with torch.no_grad():
            output = self.model(input_image)

        # squeeze only the batch dim; a bare squeeze() would also drop a
        # spatial dim of size 1 and break the transpose below
        output_image = output.squeeze(0).float().cpu().clamp_(0, 1).numpy()
        output_image = np.transpose(output_image[[2, 1, 0], :, :], (1, 2, 0))
        output_image = (output_image * 255.0).round()

        return output_image

class TiledUpscaler(Upscaler):
    """Upscale an image piecewise to bound the size of each wrapped call.

    The input is cut into a grid of tiles. Each tile is upscaled together
    with a margin of surrounding context pixels, and only the part of the
    result that corresponds to the un-padded tile is pasted into a uint8
    output image.

    Args:
        upscaler: wrapped upscaler exposing an integer ``scale_factor`` and
            an ``upscale(image)`` method.
        tile_size: target tile edge, measured in output pixels.
        tile_padding: context margin, as a fraction of ``tile_size``.
    """

    def __init__(self, upscaler, tile_size, tile_padding):
        self.upscaler = upscaler
        self.scale_factor = upscaler.scale_factor
        self.tile_size = tile_size
        self.tile_padding = tile_padding

    def upscale(self, input_image):
        factor = self.upscaler.scale_factor
        # Images are indexed [axis0, axis1, channel]. "x" walks axis 0 and
        # "y" walks axis 1, matching the labels printed for each tile.
        size_x, size_y, channels = input_image.shape

        # start with black image
        result = np.zeros((size_x * factor, size_y * factor, channels), np.uint8)

        # NOTE(review): the margin is derived from the output-space tile size
        # but applied in input pixels, so relative to the input tile it is
        # ``factor`` times larger than ``tile_padding`` suggests; confirm intent.
        margin = math.ceil(self.tile_size * self.tile_padding)
        step = math.ceil(self.tile_size / factor)  # tile edge in input pixels

        count_x = math.ceil(size_x / step)
        count_y = math.ceil(size_y / step)
        total = count_x * count_y

        for ty in range(count_y):
            # tile bounds along axis 1, without and with the context margin
            y0 = ty * step
            y1 = min(y0 + step, size_y)
            y0_pad = max(y0 - margin, 0)
            y1_pad = min(y1 + margin, size_y)

            for tx in range(count_x):
                # tile bounds along axis 0, without and with the context margin
                x0 = tx * step
                x1 = min(x0 + step, size_x)
                x0_pad = max(x0 - margin, 0)
                x1_pad = min(x1 + margin, size_x)

                print("  Tile %d/%d (x=%d y=%d %dx%d)"
                      % (ty * count_x + tx + 1, total, tx, ty, x1 - x0, y1 - y0))

                tile = self.upscaler.upscale(input_image[x0_pad:x1_pad, y0_pad:y1_pad])

                # offset of the un-padded region inside the upscaled tile
                crop_x = (x0 - x0_pad) * factor
                crop_y = (y0 - y0_pad) * factor
                result[x0 * factor:x1 * factor, y0 * factor:y1 * factor] = \
                    tile[crop_x:crop_x + (x1 - x0) * factor,
                         crop_y:crop_y + (y1 - y0) * factor]

        return result