In [31]:
import glob
import numpy as np
import os
from os.path import join
from os import listdir
import argparse
import math
from functools import partial
import time

from skimage.metrics import peak_signal_noise_ratio as compare_psnr
from skimage.metrics import structural_similarity as compare_ssim

from torchsummary import summary

from RealESRGAN.rrdbnet_arch import RRDBNet
from RealESRGAN.model import RealESRGAN

import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.autograd import Variable
import torch.nn.utils.parametrize as parametrize
from torch.nn.utils import spectral_norm
import torch.nn.functional as F
from torchvision.models import vgg19
from torchvision.utils import save_image
from torch import nn
from PIL import Image
from torchvision.transforms import Compose, RandomHorizontalFlip, RandomCrop, ToPILImage, ToTensor, Resize, CenterCrop

device = 'cuda' if torch.cuda.is_available() else 'cpu'


In [32]:
!nvidia-smi

Wed Dec 13 13:39:57 2023       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.30.02              Driver Version: 530.30.02    CUDA Version: 12.1     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                  Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf            Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla V100-PCIE-32GB            Off| 00000000:41:00.0 Off |                    0 |
| N/A   44C    P0               40W / 250W|  24062MiB / 32768MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

## Base Classes and Dataloader Methods

In [33]:
def calculate_valid_crop_size(crop_size, upscale_factor):
    return crop_size - (crop_size % upscale_factor)

def is_image_file(filename):
    return any(filename.endswith(extension) for extension in ['.png', '.jpg', '.jpeg', '.PNG', '.JPG', '.JPEG'])

def train_hr_transform(crop_size):
    return Compose([
        RandomCrop(crop_size),
        RandomHorizontalFlip(),
        ToTensor()
    ])

def train_lr_transform(crop_size, upscale_factor):
    return Compose([
        ToPILImage(),
        Resize(crop_size // upscale_factor, interpolation=Image.BICUBIC),
        ToTensor()
    ])

def display_transform():
    return Compose([
        ToPILImage(),
        Resize(400),
        CenterCrop(400),
        ToTensor()
    ])

In [34]:
class TrainDatasetFromFolder(Dataset):
    def __init__(self, dataset_dir, crop_size, upscale_factor):
        super().__init__()
        self.image_filenames = [join(dataset_dir, x) for x in listdir(dataset_dir) if is_image_file(x)]
        self.hr_transform = train_hr_transform(crop_size)
        self.lr_transform = train_lr_transform(crop_size, upscale_factor)

    def __getitem__(self, index):
        hr_image = self.hr_transform(Image.open(self.image_filenames[index]))
        lr_image = self.lr_transform(hr_image)
        return lr_image, hr_image

    def __len__(self):
        return len(self.image_filenames)


class ValDatasetFromFolder(Dataset):
    def __init__(self, dataset_dir, upscale_factor):
        super().__init__()
        self.upscale_factor = upscale_factor
        self.image_filenames = [join(dataset_dir, x) for x in listdir(dataset_dir) if is_image_file(x)]

    def __getitem__(self, index):
        hr_image = Image.open(self.image_filenames[index])
        w, h = hr_image.size
        crop_size = calculate_valid_crop_size(min(w, h), self.upscale_factor)
        lr_scale = Resize(crop_size // self.upscale_factor, interpolation=Image.BICUBIC)
        hr_scale = Resize(crop_size, interpolation=Image.BICUBIC)
        hr_image = CenterCrop(crop_size)(hr_image)
        lr_image = lr_scale(hr_image)
        hr_restore_img = hr_scale(lr_image)
        return ToTensor()(lr_image), ToTensor()(hr_restore_img), ToTensor()(hr_image)

    def __len__(self):
        return len(self.image_filenames)


class TestDatasetFromFolder(Dataset):
    def __init__(self, dataset_dir, upscale_factor):
        super().__init__()
        self.lr_path = dataset_dir + '/SRF_' + str(upscale_factor) + '/data/'
        self.hr_path = dataset_dir + '/SRF_' + str(upscale_factor) + '/target/'
        self.upscale_factor = upscale_factor
        self.lr_filenames = [join(self.lr_path, x) for x in listdir(self.lr_path) if is_image_file(x)]
        self.hr_filenames = [join(self.hr_path, x) for x in listdir(self.hr_path) if is_image_file(x)]

    def __getitem__(self, index):
        image_name = self.lr_filenames[index].split('/')[-1]
        lr_image = Image.open(self.lr_filenames[index])
        w, h = lr_image.size
        hr_image = Image.open(self.hr_filenames[index])
        hr_scale = Resize((self.upscale_factor * h, self.upscale_factor * w), interpolation=Image.BICUBIC)
        hr_restore_img = hr_scale(lr_image)
        return image_name, ToTensor()(lr_image), ToTensor()(hr_restore_img), ToTensor()(hr_image)

    def __len__(self):
        return len(self.lr_filenames)


### Class for Exponential Moving Average

In [35]:
class EMA():
    def __init__(self, model, decay):
        self.model = model
        self.decay = decay
        self.shadow = {}
        self.backup = {}

    def register(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    def update(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                assert name in self.shadow
                new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name].to(param.device)
                self.shadow[name] = new_average.clone()

    def apply_shadow(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                assert name in self.shadow
                self.backup[name] = param.data
                param.data = self.shadow[name]

    def restore(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                assert name in self.backup
                param.data = self.backup[name]
        self.backup = {}

### Base LoRA Parametrization class

In [36]:
class LoRAParametrization(nn.Module):
    def __init__(self, fan_in, fan_out, fan_in_fan_out=False, rank=4, lora_dropout_p=0.0, lora_alpha=1):
        super().__init__()
        # if weight is stored as (fan_out, fan_in), the memory layout of A & B follows (W + BA)x
        # otherwise, it's x(W + AB). This allows us to tie the weights between linear layers and embeddings

        self.swap = (lambda x: (x[1], x[0])) if fan_in_fan_out else (lambda x: x)
        self.lora_A = nn.Parameter(torch.zeros(self.swap((rank, fan_in))))
        self.lora_B = nn.Parameter(torch.zeros(self.swap((fan_out, rank))))
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        self.lora_alpha, self.rank = lora_alpha, rank
        self.scaling = lora_alpha / rank
        self.lora_dropout = nn.Dropout(p=lora_dropout_p) if lora_dropout_p > 0 else lambda x: x
        self.dropout_fn = self._dropout if lora_dropout_p > 0 else lambda x: x
        self.register_buffer("lora_dropout_mask", torch.ones(self.swap((1, fan_in)), dtype=self.lora_A.dtype))
        self.forward_fn = self.lora_forward

    def _dropout(self, A):
        # to mimic the original implementation: A @ dropout(x), we do (A * dropout(ones)) @ x
        return A * self.lora_dropout(self.lora_dropout_mask)

    def lora_forward(self, X):
        return X + torch.matmul(*self.swap((self.lora_B, self.dropout_fn(self.lora_A)))).view(X.shape) * self.scaling

    def forward(self, X):
        return self.forward_fn(X)

    def disable_lora(self):
        self.forward_fn = lambda x: x

    def enable_lora(self):
        self.forward_fn = self.lora_forward

    @classmethod
    def from_linear(cls, layer, rank=4, lora_dropout_p=0.0, lora_alpha=1):
        fan_out, fan_in = layer.weight.shape
        return cls(
            fan_in, fan_out, fan_in_fan_out=False, rank=rank, lora_dropout_p=lora_dropout_p, lora_alpha=lora_alpha
        )

    @classmethod
    def from_conv2d(cls, layer, rank=4, lora_dropout_p=0.0, lora_alpha=1):
        fan_out, fan_in = layer.weight.view(layer.weight.shape[0], -1).shape
        return cls(
            fan_in, fan_out, fan_in_fan_out=False, rank=rank, lora_dropout_p=lora_dropout_p, lora_alpha=lora_alpha
        )


default_lora_config = {
    nn.Conv2d: {
        "weight": partial(LoRAParametrization.from_conv2d, rank=4),
    },
    nn.Linear: {
        "weight": partial(LoRAParametrization.from_linear, rank=4),
    },
}


def apply_lora(layer, register=True, merge=False, lora_config=default_lora_config):
    """add lora parametrization to a layer, designed to be used with model.apply"""
    # print("in apply_lora function: ", layer)
    if register:
        if type(layer) in lora_config:
            for attr_name, parametrization in lora_config[type(layer)].items():
                parametrize.register_parametrization(layer, attr_name, parametrization(layer))
    # else:  # this will remove all parametrizations, use with caution
    #     if hasattr(layer, "parametrizations"):
    #         for attr_name in layer.parametrizations.keys():
    #             parametrize.remove_parametrizations(layer, attr_name, leave_parametrized=merge)


def add_lora(model, lora_config=default_lora_config):
    """add lora parametrization to all layers in a model. Calling it twice will add lora twice"""
    model.apply(partial(apply_lora, lora_config=lora_config))

def merge_lora(model):
    """merge lora parametrization to all layers in a model. This will remove all parametrization"""
    model.apply(partial(apply_lora, register=False, merge=True))


def remove_lora(model):
    """remove lora parametrization to all layers in a model. This will remove all parametrization"""
    model.apply(partial(apply_lora, register=False, merge=False))

def apply_to_lora(fn):
    """apply a function to LoRAParametrization layers, designed to be used with model.apply"""

    def apply_fn(layer):
        if isinstance(layer, LoRAParametrization):
            fn(layer)

    return apply_fn


enable_lora = lambda model: model.apply(apply_to_lora(lambda x: x.enable_lora()))
disable_lora = lambda model: model.apply(apply_to_lora(lambda x: x.disable_lora()))


# ------------------- helper function for collecting parameters for training/saving -------------------


def name_is_lora(name):
    return (
        len(name.split(".")) >= 4
        and (name.split(".")[-4]) == "parametrizations"
        and name.split(".")[-1] in ["lora_A", "lora_B"]
    )


def name_is_bias(name):
    return name.split(".")[-1] == "bias"


def get_params_by_name(model, print_shapes=False, name_filter=None):
    for n, p in model.named_parameters():
        # print('n: ',n,' p: ',p)
        if name_filter is None or name_filter(n):
            if print_shapes:
                print(n, p.shape)
            yield p


def get_lora_params(model, print_shapes=False):
    return get_params_by_name(model, print_shapes=print_shapes, name_filter=name_is_lora)


def get_bias_params(model, print_shapes=False):
    return get_params_by_name(model, print_shapes=print_shapes, name_filter=name_is_bias)


def get_lora_state_dict(model):
    return {k: v for k, v in model.state_dict().items() if name_is_lora(k)}


# ------------------- helper function for inferencing with multiple lora -------------------


def _prepare_for_multiple_lora(lora_layer):
    lora_layer.lora_As = []
    lora_layer.lora_Bs = []


def _append_lora(lora_layer):
    lora_layer.lora_As.append(nn.Parameter(lora_layer.lora_A.clone()))
    lora_layer.lora_Bs.append(nn.Parameter(lora_layer.lora_B.clone()))


def load_multiple_lora(model, lora_state_dicts):
    model.apply(apply_to_lora(_prepare_for_multiple_lora))
    for state_dict in lora_state_dicts:
        _ = model.load_state_dict(state_dict, strict=False)
        model.apply(apply_to_lora(_append_lora))
    return model


def _select_lora(lora_layer, index):
    lora_layer.lora_A = lora_layer.lora_As[index]
    lora_layer.lora_B = lora_layer.lora_Bs[index]


def select_lora(model, index):
    model.apply(apply_to_lora(lambda x: _select_lora(x, index)))
    return model


# ------------------- helper function for tying and untieing weights -------------------


def tie_weights(linear: nn.Linear, embedding: nn.Embedding):
    """tie the weights of the linear layer and the embedding layer both with the same lora"""
    # this line below is optional if the original is already tied
    embedding.parametrizations.weight.original = linear.parametrizations.weight.original
    embedding.parametrizations.weight[0].lora_A = linear.parametrizations.weight[0].lora_B
    embedding.parametrizations.weight[0].lora_B = linear.parametrizations.weight[0].lora_A


def untie_weights(linear: nn.Linear, embedding: nn.Embedding):
    """untie the weights of the linear layer and the embedding layer"""
    embedding.parametrizations.weight.original = nn.Parameter(embedding.weight.original.clone())
    embedding.parametrizations.weight[0].lora_A = nn.Parameter(embedding.parametrizations.weight[0].lora_A.clone())
    embedding.parametrizations.weight[0].lora_B = nn.Parameter(embedding.parametrizations.weight[0].lora_B.clone())

### Load Generator with pretrained Weights (default = weights for DIV2K, Outdoor Scene, Flikr2k)

In [37]:
def load_generator(weight_path = './weights/RealESRGAN_x4.pth'):
    
    generator =  RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, 
                         num_block=23, num_grow_ch=32, scale=4
                        )
    if weight_path:
        loadnet = torch.load(weight_path)
        if 'params' in loadnet:
            generator.load_state_dict(loadnet['params'], strict=True)
        elif 'params_ema' in loadnet:
            generator.load_state_dict(loadnet['params_ema'], strict=True)
        else:
            generator.load_state_dict(loadnet, strict=True)
            
    return generator

### Load VGG Conceptual loss feature extractor and Discriminator Class Definition

In [38]:
class FeatureExtractor(nn.Module):
    def __init__(self, layers=[0, 5, 10, 19, 28]):
        super(FeatureExtractor, self).__init__()
        vgg19_model = vgg19(pretrained=True)
        self.layers = layers
        self.net = nn.Sequential(*[vgg19_model.features[i] for i in range(max(layers) + 1)])

    def forward(self, img):
        features = []
        for i in range(len(self.net)):
            img = self.net[i](img)
            if i in self.layers:
                features.append(img)
        return features
    
class UNetDiscriminatorSN(nn.Module):
    def __init__(self, num_in_ch, num_feat=64, skip_connection=True):
        super().__init__()
        norm = spectral_norm
        self.skip_connection = skip_connection
        self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1)
        # downsample  
        self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False)) # (w/h + 2 - 4 + 2) / 2 = (w/h) / 2
        self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False)) # --> (w/h) / 4
        self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False)) # --> (w/h) / 8
        # upsample  
        self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False)) # (w/h + 2 - 3 + 1) = w/h
        self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False))
        self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False))  
        # extra convolutions  
        self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
        self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
        self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1)

    def forward(self, x):
        # downsample  
        x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True)
        x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True)
        x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True)
        x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True)

        # upsample  
        x3 = F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=False)
        x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True)

        if self.skip_connection:
            x4 = x4 + x2
        x4 = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
        x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True)

        if self.skip_connection:
            x5 = x5 + x1
        x5 = F.interpolate(x5, scale_factor=2, mode='bilinear', align_corners=False)
        x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True)

        if self.skip_connection:
            x6 = x6 + x0
        
        # extra convolutions  
        out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True)
        out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True)
        out = self.conv9(out)

        return out


### Calculation of Parameter difference between pretrained and LORA

In [39]:
def linear_layer_parameterization(layer, device, rank=4, lora_alpha=1):
    # Only add the parameterization to the weight matrix, ignore the Bias

    # From section 4.2 of the paper:
    #   We limit our study to only adapting the attention weights for downstream tasks and freeze the MLP modules (so they are not trained in downstream tasks) both for simplicity and parameter-efficiency.
    #   [...]
    #   We leave the empirical investigation of [...], and biases to a future work.

    features_in, features_out = layer.weight.shape
    return LoRAParametrization(
        features_in, features_out, rank=rank, lora_alpha=lora_alpha
    )

def conv_layer_parameterization(layer,device,rank = 4, lora_alpha = 1):
    #     Conv layer parameterization

    features_in, features_out = layer.weight.view(layer.weight.shape[0], -1).shape
    return LoRAParametrization(
        features_in, features_out, rank=rank, lora_alpha=lora_alpha
    )

# Function to enable/disable LoRA
def enable_disable_lora(generator, enabled=True):
    for layer in generator.modules():
        if isinstance(layer, nn.Conv2d):
            layer.parametrizations["weight"][0].enabled = enabled

# Register parametrization for all convolutional layers

def register_parameterization_conv_layers(generator):
    for layer in generator.modules():
        if isinstance(layer, nn.Conv2d):
            parametrize.register_parametrization(
                layer, "weight", conv_layer_parameterization(layer, device)
            )
        
        
def display_parameter_difference(generator):        
        
    total_parameters_original = 0

    # Iterate through all layers in the model
    for index, layer in enumerate([module for module in generator.modules() if isinstance(module, nn.Conv2d)]):
        total_parameters_original += layer.weight.nelement() + layer.bias.nelement()
        # print(f'Conv Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape}')

    print(f'Total number of convolutional parameters: {total_parameters_original:,}')
    
    print('\n\n\n')

    total_parameters_lora = 0
    total_parameters_non_lora = 0
    for index, layer in enumerate([module for module in generator.modules() if isinstance(module, nn.Conv2d)]):
        total_parameters_lora += layer.parametrizations["weight"][0].lora_A.nelement() + layer.parametrizations["weight"][0].lora_B.nelement()
        total_parameters_non_lora += layer.weight.nelement() + layer.bias.nelement()
        # print(
            # f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape} + Lora_A: {layer.parametrizations["weight"][0].lora_A.shape} + Lora_B: {layer.parametrizations["weight"][0].lora_B.shape}'
        # )
    # The non-LoRA parameters count must match the original network
    assert total_parameters_non_lora == total_parameters_original
    print(f'Total number of parameters (original): {total_parameters_non_lora:,}')
    print(f'Total number of parameters (original + LoRA): {total_parameters_lora + total_parameters_non_lora:,}')
    print(f'Parameters introduced by LoRA: {total_parameters_lora:,}')
    parameters_incremment = (total_parameters_lora / total_parameters_non_lora) * 100
    print(f'Parameters incremment: {parameters_incremment:.3f}%')

## Setup Training 

In [40]:
generator = load_generator(weight_path = './weights/RealESRGAN_x4.pth')

# print("model summary before adding lora:")

# print(summary(generator, (3, 192, 192)))

# register_parameterization_conv_layers(generator)

# display_parameter_difference(generator)

add_lora(generator)

# print("model summary after adding lora:")

# print(summary(generator, (3, 192, 192)))

# Freeze the non-Lora parameters
for name, param in generator.named_parameters():
    if 'lora' not in name:
        # print(f'Freezing non-LoRA parameter {name}')
        param.requires_grad = False

discriminator = UNetDiscriminatorSN(3).to(device)
discriminator.load_state_dict(torch.load('./discriminator_4005.pth'))
feature_extractor = FeatureExtractor().to(device)

# set feature extractor to inference mode
feature_extractor.eval()

# Losses
criterion_GAN = torch.nn.BCEWithLogitsLoss().to(device)
criterion_content = torch.nn.L1Loss().to(device)
criterion_pixel = torch.nn.L1Loss().to(device)

# initialize optimzier

parameters = [
    {"params": list(get_lora_params(generator))},
]

print(parameters)

lr =0.0002

optimizer_G = torch.optim.Adam(parameters, lr=lr)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=1e-4)

Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.Tensor

# initialize ema
ema_G = EMA(generator, 0.999)
ema_D = EMA(discriminator, 0.999)
ema_G.register()
ema_D.register()



[{'params': [Parameter containing:
tensor([[-0.0497, -0.0002,  0.1327, -0.1323,  0.0276, -0.0889,  0.0467,  0.1698,
          0.1325,  0.0057,  0.0240, -0.1209,  0.1132, -0.0782, -0.0769, -0.1646,
         -0.1489,  0.1296,  0.0282,  0.1908,  0.0277, -0.1319,  0.1208,  0.0036,
         -0.0803,  0.1746,  0.0210],
        [ 0.1095, -0.1120,  0.1893,  0.1721, -0.0211, -0.1092,  0.1432, -0.0683,
         -0.0782, -0.1222, -0.1512,  0.0321, -0.1520,  0.1683, -0.1867, -0.1118,
          0.1052, -0.0019, -0.1779,  0.1529, -0.1633, -0.1892, -0.1337,  0.0991,
          0.1030, -0.1551, -0.0389],
        [-0.0137, -0.0226,  0.0794, -0.1134, -0.1190, -0.1841,  0.0603,  0.0386,
          0.0333, -0.0502, -0.0436,  0.1169,  0.1633,  0.1262, -0.0251,  0.0481,
          0.0840, -0.1384,  0.1446,  0.1706, -0.0554, -0.1491, -0.1706,  0.1628,
          0.1024,  0.1136,  0.0880],
        [ 0.0072, -0.0127, -0.0109,  0.0765,  0.1331,  0.0102,  0.0238,  0.1179,
         -0.0184, -0.0555,  0.0532, -0.0993,

In [41]:
generator.to(device)

RRDBNet(
  (conv_first): ParametrizedConv2d(
    3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
    (parametrizations): ModuleDict(
      (weight): ParametrizationList(
        (0): LoRAParametrization()
      )
    )
  )
  (body): Sequential(
    (0): RRDB(
      (rdb1): ResidualDenseBlock(
        (conv1): ParametrizedConv2d(
          64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
          (parametrizations): ModuleDict(
            (weight): ParametrizationList(
              (0): LoRAParametrization()
            )
          )
        )
        (conv2): ParametrizedConv2d(
          96, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
          (parametrizations): ModuleDict(
            (weight): ParametrizationList(
              (0): LoRAParametrization()
            )
          )
        )
        (conv3): ParametrizedConv2d(
          128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
          (parametrizations): ModuleDict(
            (we

In [42]:
def train(dataset_path, crop_size, upscale_factor, batch_size, warmup_batches, 
          n_batches, batch, lr,optimizer_G, sample_interval,save_path_suffix, device = 'cuda'):

    hr_shape = (crop_size, crop_size)
    channels = 3

    train_set = TrainDatasetFromFolder(dataset_path, crop_size=crop_size, upscale_factor = 4)
    train_loader = DataLoader(dataset=train_set, num_workers=4, batch_size=batch_size, shuffle=True)
    
    batch = 0
    total_psnr = []
    total_ssim = []


    while batch < n_batches:
        for i, (data, target) in enumerate(train_loader):
            batches_done = batch + i

            imgs_lr = data.to(device)
            imgs_hr = target.to(device)

            valid = torch.ones((imgs_lr.size(0), 1, *imgs_hr.shape[-2:]), requires_grad=False).to(device)
            fake = torch.zeros((imgs_lr.size(0), 1, *imgs_hr.shape[-2:]), requires_grad=False).to(device)

            # ---------------------
            # Training Generator
            # ---------------------

            optimizer_G.zero_grad()

            gen_hr = generator(imgs_lr)

            # Compute PSNR and SSIM
            with torch.no_grad():
                gen_hr_np = gen_hr.detach().cpu().numpy().transpose(0, 2, 3, 1)  # Convert tensor to NumPy (NCHW to NHWC)
                imgs_hr_np = imgs_hr.cpu().numpy().transpose(0, 2, 3, 1)  # Convert tensor to NumPy (NCHW to NHWC)

                psnr_values = []
                ssim_values = []

                for gen_img, hr_img in zip(gen_hr_np, imgs_hr_np):
                    # Ensure pixel values are in the range [0, 1]
                    gen_img = np.clip(gen_img, 0.0, 1.0)
                    hr_img = np.clip(hr_img, 0.0, 1.0)

                    psnr = compare_psnr(hr_img, gen_img)
                    ssim = compare_ssim(hr_img, gen_img, multichannel=True)

                    psnr_values.append(psnr)
                    ssim_values.append(ssim)

                avg_psnr = np.mean(psnr_values)
                avg_ssim = np.mean(ssim_values)

                total_psnr.append(avg_psnr)
                total_ssim.append(avg_ssim)

                # Print or log the average PSNR and SSIM for this batch
                print('[Iteration %d/%d] [Batch %d/%d] [PSNR: %.4f] [SSIM: %.4f]' % (batches_done, n_batches, i, len(train_loader), avg_psnr, avg_ssim))

            loss_pixel = criterion_pixel(gen_hr, imgs_hr)

            if batches_done < warmup_batches:
                loss_pixel.backward()
                optimizer_G.step()
                ema_G.update()
                print(
                    '[Iteration %d/%d] [Batch %d/%d] [G pixel: %f]' %
                    (batches_done, n_batches, i, len(train_loader), loss_pixel.item())
                )
                continue
            elif batches_done == warmup_batches:
                optimizer_G = torch.optim.Adam(generator.parameters(), lr=1e-4)

            pred_real = discriminator(imgs_hr).detach()
            pred_fake = discriminator(gen_hr)

            loss_GAN = (
                criterion_GAN(pred_fake - pred_real.mean(0, keepdim=True), valid) +
                criterion_GAN(pred_real - pred_fake.mean(0, keepdim=True), fake)
            ) / 2

            gen_features = feature_extractor(gen_hr)
            real_features = feature_extractor(imgs_hr)
            real_features = [real_f.detach() for real_f in real_features]
            loss_content = sum(criterion_content(gen_f, real_f) * w for gen_f, real_f, w in zip(gen_features, real_features, [0.1, 0.1, 1, 1, 1]))

            loss_G = loss_content + 0.1 * loss_GAN + loss_pixel

            loss_G.backward()
            optimizer_G.step()
            ema_G.update()

            # ---------------------
            # Train Discriminator
            # ---------------------

            optimizer_D.zero_grad()

            pred_real = discriminator(imgs_hr)
            pred_fake = discriminator(gen_hr.detach())

            loss_real = criterion_GAN(pred_real - pred_fake.mean(0, keepdim=True), valid)
            loss_fake = criterion_GAN(pred_fake - pred_real.mean(0, keepdim=True), fake)

            loss_D = (loss_real + loss_fake) / 2

            loss_D.backward()
            optimizer_D.step()
            ema_D.update()

            # -------------------------
            # Log Progress
            # -------------------------

            print(
                '[Iteration %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, content: %f, adv: %f, pixel: %f]' %
                (
                    batches_done,
                    n_batches,
                    i,
                    len(train_loader),
                    loss_D.item(),
                    loss_G.item(),
                    loss_content.item(),
                    loss_GAN.item(),
                    loss_pixel.item()
                )
            )

            if batches_done % sample_interval == 0:
                imgs_lr = nn.functional.interpolate(imgs_lr, scale_factor=4, mode='bicubic')
                img_grid = torch.clamp(torch.cat((imgs_lr, gen_hr, imgs_hr), -1), min=0, max=1)
                save_image(img_grid, f'images/training/{save_path_suffix}_%d.png' % batches_done, nrow=1, normalize=False)

        batch = batches_done + 1

        ema_G.apply_shadow()
        ema_D.apply_shadow()

        torch.save(generator.state_dict(), f'saved_models/generator_{save_path_suffix}_%d.pth' % batch)
        torch.save(discriminator.state_dict(), f'saved_models/discriminator_{save_path_suffix}_%d.pth' % batch)

        ema_G.restore()
        ema_D.restore()

    print("total_psnr: ", np.mean(total_psnr))
    print("total_ssim: ", np.mean(total_ssim))

In [None]:
dataset_path='./Set14/high_res_resized'
crop_size = 384
upscale_factor = 4
batch_size = 4
warmup_batches = 100
n_batches = 500
residual_blocks = 23
batch = 0
lr = 0.0002
sample_interval = 5
channels = 3

save_path_suffix = 'lora_set14_96inp'

start_time = time.time()
train(dataset_path, crop_size, upscale_factor, batch_size, warmup_batches, n_batches, batch, lr,optimizer_G, sample_interval, save_path_suffix)
end_time = time.time()
finetuning_time = end_time - start_time
print(f"Finetuning time for image {num}: {finetuning_time} seconds")

  ssim = compare_ssim(hr_img, gen_img, multichannel=True)


[Iteration 0/500] [Batch 0/4] [PSNR: 26.0389] [SSIM: 0.8574]
[Iteration 0/500] [Batch 0/4] [G pixel: 0.034460]
[Iteration 1/500] [Batch 1/4] [PSNR: 22.5977] [SSIM: 0.8015]
[Iteration 1/500] [Batch 1/4] [G pixel: 0.052259]
[Iteration 2/500] [Batch 2/4] [PSNR: 25.3881] [SSIM: 0.8575]
[Iteration 2/500] [Batch 2/4] [G pixel: 0.036899]
[Iteration 3/500] [Batch 3/4] [PSNR: 21.2031] [SSIM: 0.6801]
[Iteration 3/500] [Batch 3/4] [G pixel: 0.059131]
[Iteration 4/500] [Batch 0/4] [PSNR: 26.9871] [SSIM: 0.8716]
[Iteration 4/500] [Batch 0/4] [G pixel: 0.031095]
[Iteration 5/500] [Batch 1/4] [PSNR: 24.4160] [SSIM: 0.8263]
[Iteration 5/500] [Batch 1/4] [G pixel: 0.043334]
[Iteration 6/500] [Batch 2/4] [PSNR: 23.4324] [SSIM: 0.7633]
[Iteration 6/500] [Batch 2/4] [G pixel: 0.050119]
[Iteration 7/500] [Batch 3/4] [PSNR: 23.4738] [SSIM: 0.8819]
[Iteration 7/500] [Batch 3/4] [G pixel: 0.042753]
[Iteration 8/500] [Batch 0/4] [PSNR: 26.2222] [SSIM: 0.8660]
[Iteration 8/500] [Batch 0/4] [G pixel: 0.036261]
[

In [16]:
lora_state_dict = get_lora_state_dict(generator)

torch.save(lora_state_dict, './lora_weights/500Iter_esrgan_384hr_set14_finetune.pth')

In [17]:
summary(generator,(3,192,192))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
LoRAParametrization-1              [-1, 3, 3, 3]               0
ParametrizedConv2d-2         [-1, 64, 192, 192]           1,792
LoRAParametrization-3              [-1, 3, 3, 3]               0
LoRAParametrization-4              [-1, 3, 3, 3]               0
LoRAParametrization-5              [-1, 3, 3, 3]               0
LoRAParametrization-6              [-1, 3, 3, 3]               0
LoRAParametrization-7             [-1, 64, 3, 3]               0
ParametrizedConv2d-8         [-1, 32, 192, 192]          18,464
LoRAParametrization-9             [-1, 64, 3, 3]               0
LoRAParametrization-10             [-1, 64, 3, 3]               0
LoRAParametrization-11             [-1, 64, 3, 3]               0
LoRAParametrization-12             [-1, 64, 3, 3]               0
        LeakyReLU-13         [-1, 32, 192, 192]               0
LoRAParametrization-14    

In [18]:
merge_lora(generator)

In [19]:
summary(generator, (3, 192, 192))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
LoRAParametrization-1              [-1, 3, 3, 3]               0
ParametrizedConv2d-2         [-1, 64, 192, 192]           1,792
LoRAParametrization-3              [-1, 3, 3, 3]               0
LoRAParametrization-4              [-1, 3, 3, 3]               0
LoRAParametrization-5              [-1, 3, 3, 3]               0
LoRAParametrization-6              [-1, 3, 3, 3]               0
LoRAParametrization-7             [-1, 64, 3, 3]               0
ParametrizedConv2d-8         [-1, 32, 192, 192]          18,464
LoRAParametrization-9             [-1, 64, 3, 3]               0
LoRAParametrization-10             [-1, 64, 3, 3]               0
LoRAParametrization-11             [-1, 64, 3, 3]               0
LoRAParametrization-12             [-1, 64, 3, 3]               0
        LeakyReLU-13         [-1, 32, 192, 192]               0
LoRAParametrization-14    

In [20]:
generator = load_generator(weight_path = './weights/RealESRGAN_x4.pth')

In [21]:
add_lora(generator)

In [22]:
_ = generator.load_state_dict(torch.load('./lora_weights/500Iter_esrgan_384hr_set14_finetune.pth'), strict=False)

In [23]:
merge_lora(generator)

In [24]:
import numpy as np
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
from PIL import Image


def calculate_psnr(hr_image, sr_image):
    hr_array = np.array(hr_image).astype(np.uint8)
    sr_array = np.array(sr_image).astype(np.uint8)
    return peak_signal_noise_ratio(hr_array, sr_array)

def calculate_ssim(hr_image, sr_image):
    hr_array = np.array(hr_image)
    sr_array = np.array(sr_image)
    return structural_similarity(hr_array, sr_array, multichannel=True)


In [27]:
def get_inference_time(hr_dataset_path, path_of_lora_weights = None):
    
    generator = load_generator(weight_path = './weights/RealESRGAN_x4.pth')
    
    if path_of_lora_weights != None:
        
        add_lora(generator)
        
        _ = generator.load_state_dict(torch.load(path_of_lora_weights), strict=False)
        
        merge_lora(generator)
        
    generator.to(device)
    
    generator.eval()

    model = RealESRGAN(device, generator, scale=4)

    image_files = [f for f in os.listdir(hr_dataset_path) if os.path.isfile(os.path.join(hr_dataset_path, f))]

    total_time = 0
    
    total_psnr = []
    total_ssim = []

    # Loop through each image in the dataset
    for num, image_file in enumerate(image_files):
        path_to_image = os.path.join(hr_dataset_path, image_file)
        hr_image = Image.open(path_to_image).convert('RGB')
        lr_image = hr_image.resize((hr_image.size[0]//4, hr_image.size[1]//4))
    

        start_time = time.time()
        sr_image = model.predict(lr_image)
        end_time = time.time()

        inference_time = end_time - start_time
        total_time += inference_time
        
#         hr_image.show()
        
#         sr_image.show()
        
        psnr = calculate_psnr(hr_image, sr_image)
        ssim = calculate_ssim(hr_image, sr_image)

        total_psnr.append(psnr)
        total_ssim.append(ssim)
        

        print(f"Inference time for image {num}: {inference_time} seconds")

    return total_time / len(image_files), np.mean(total_psnr), np.mean(total_ssim)

In [26]:
print(get_inference_time('./Set14/high_res_resized', None))

xdg-open: no method available for opening '/tmp/tmp46rjc5eq.PNG'
  return structural_similarity(hr_array, sr_array, multichannel=True)
xdg-open: no method available for opening '/tmp/tmpnfctr28f.PNG'


Inference time for image 0: 1.3366646766662598 seconds


xdg-open: no method available for opening '/tmp/tmpuiyka6mk.PNG'
  return structural_similarity(hr_array, sr_array, multichannel=True)
xdg-open: no method available for opening '/tmp/tmp_ssccnq5.PNG'


Inference time for image 1: 0.9315950870513916 seconds


xdg-open: no method available for opening '/tmp/tmpbzobloja.PNG'
  return structural_similarity(hr_array, sr_array, multichannel=True)
xdg-open: no method available for opening '/tmp/tmp050j9vw0.PNG'


Inference time for image 2: 1.0434257984161377 seconds


xdg-open: no method available for opening '/tmp/tmpt4me4z3t.PNG'
  return structural_similarity(hr_array, sr_array, multichannel=True)
xdg-open: no method available for opening '/tmp/tmp75zkxoo4.PNG'


Inference time for image 3: 0.8649723529815674 seconds


xdg-open: no method available for opening '/tmp/tmpokq13xs4.PNG'
  return structural_similarity(hr_array, sr_array, multichannel=True)
xdg-open: no method available for opening '/tmp/tmp80_eug2k.PNG'


Inference time for image 4: 0.824530839920044 seconds


xdg-open: no method available for opening '/tmp/tmpfxmbutom.PNG'
  return structural_similarity(hr_array, sr_array, multichannel=True)
xdg-open: no method available for opening '/tmp/tmp53iiwaxt.PNG'


Inference time for image 5: 0.8324079513549805 seconds


xdg-open: no method available for opening '/tmp/tmp06943a0g.PNG'
  return structural_similarity(hr_array, sr_array, multichannel=True)
xdg-open: no method available for opening '/tmp/tmph5_bao9x.PNG'


Inference time for image 6: 0.8285553455352783 seconds


xdg-open: no method available for opening '/tmp/tmpsm7dedib.PNG'
  return structural_similarity(hr_array, sr_array, multichannel=True)
xdg-open: no method available for opening '/tmp/tmpg5zm3xuk.PNG'


Inference time for image 7: 0.9285309314727783 seconds


xdg-open: no method available for opening '/tmp/tmp5jrt7tb5.PNG'
  return structural_similarity(hr_array, sr_array, multichannel=True)
xdg-open: no method available for opening '/tmp/tmp3c_etzfs.PNG'


Inference time for image 8: 1.038693904876709 seconds


xdg-open: no method available for opening '/tmp/tmply0isamk.PNG'
  return structural_similarity(hr_array, sr_array, multichannel=True)
xdg-open: no method available for opening '/tmp/tmpcnzt43s6.PNG'


Inference time for image 9: 0.9774129390716553 seconds


xdg-open: no method available for opening '/tmp/tmp5rv2myhu.PNG'
  return structural_similarity(hr_array, sr_array, multichannel=True)
xdg-open: no method available for opening '/tmp/tmphjzyn1ie.PNG'


Inference time for image 10: 1.014695405960083 seconds


xdg-open: no method available for opening '/tmp/tmpiznzb5vh.PNG'
  return structural_similarity(hr_array, sr_array, multichannel=True)
xdg-open: no method available for opening '/tmp/tmpllwdjzye.PNG'


Inference time for image 11: 1.0021007061004639 seconds


xdg-open: no method available for opening '/tmp/tmpnbdr9f_w.PNG'
  return structural_similarity(hr_array, sr_array, multichannel=True)
xdg-open: no method available for opening '/tmp/tmpemuhdadq.PNG'


Inference time for image 12: 0.9534690380096436 seconds
(0.9674657674936148, 24.507952482070102, 0.7107710654165188)


In [28]:
get_inference_time('./Set14/high_res_resized', './lora_weights/500Iter_esrgan_384hr_set14_finetune.pth')

  return structural_similarity(hr_array, sr_array, multichannel=True)


Inference time for image 0: 1.3527185916900635 seconds
Inference time for image 1: 1.2986574172973633 seconds
Inference time for image 2: 1.3045752048492432 seconds
Inference time for image 3: 1.4482848644256592 seconds
Inference time for image 4: 1.2403860092163086 seconds
Inference time for image 5: 1.302368402481079 seconds
Inference time for image 6: 1.377288579940796 seconds
Inference time for image 7: 1.308201789855957 seconds
Inference time for image 8: 1.2459073066711426 seconds
Inference time for image 9: 1.303875207901001 seconds
Inference time for image 10: 1.3234174251556396 seconds
Inference time for image 11: 1.313504934310913 seconds
Inference time for image 12: 1.3203206062316895 seconds


(1.3184235646174505, 28.257841413022486, 0.7987105151211928)

In [30]:
total_parameters_original = 0

# Iterate through all layers in the model
for index, layer in enumerate([module for module in generator.modules() if isinstance(module, nn.Conv2d)]):
    total_parameters_original += layer.weight.nelement() + layer.bias.nelement()
    # print(f'Conv Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape}')

print(f'Total number of convolutional parameters: {total_parameters_original:,}')

total_parameters_lora = 0
total_parameters_non_lora = 0
for index, layer in enumerate([module for module in generator.modules() if isinstance(module, nn.Conv2d)]):
    total_parameters_lora += layer.parametrizations["weight"][0].lora_A.nelement() + layer.parametrizations["weight"][0].lora_B.nelement()
    total_parameters_non_lora += layer.weight.nelement() + layer.bias.nelement()
    # print(
        # f'Layer {index+1}: W: {layer.weight.shape} + B: {layer.bias.shape} + Lora_A: {layer.parametrizations["weight"][0].lora_A.shape} + Lora_B: {layer.parametrizations["weight"][0].lora_B.shape}'
    # )
# The non-LoRA parameters count must match the original network
assert total_parameters_non_lora == total_parameters_original
print(f'Total number of parameters (original): {total_parameters_non_lora:,}')
print(f'Total number of parameters (original + LoRA): {total_parameters_lora + total_parameters_non_lora:,}')
print(f'Parameters introduced by LoRA: {total_parameters_lora:,}')
parameters_incremment = (total_parameters_lora / total_parameters_non_lora) * 100
print(f'Parameters incremment: {parameters_incremment:.3f}%')

Total number of convolutional parameters: 16,697,987
Total number of parameters (original): 16,697,987
Total number of parameters (original + LoRA): 18,353,659
Parameters introduced by LoRA: 1,655,672
Parameters incremment: 9.915%
