In [6]:
from svd_replacement import SRConv
import numpy as np
import os
import torch

In [7]:
import PIL
from PIL import Image
def load_img(path):
    """Load an image file as a float32 tensor in ``[-1, 1]`` of shape ``(1, 3, H, W)``.

    The image is converted to RGB, then resized with a Lanczos filter so that
    its width and height become the largest multiples of 8 that do not exceed
    the original size (a no-op resize when they already are).

    Args:
        path: path of the image file to read.

    Returns:
        torch.Tensor of shape ``(1, 3, H, W)`` with values mapped from
        ``[0, 255]`` to ``[-1, 1]``.
    """
    # Close the underlying file deterministically; convert() returns an
    # independent, fully loaded copy, so it stays valid after the with-block.
    with Image.open(path) as raw:
        image = raw.convert("RGB")
    w, h = image.size  # PIL reports size as (width, height)
    print(f"loaded input image of size (width:{w}, height:{h}) from {path}")
    # Round each side down to an integer multiple of 8.
    w, h = (dim - dim % 8 for dim in (w, h))
    image = image.resize((w, h), resample=PIL.Image.LANCZOS)
    array = np.array(image).astype(np.float32) / 255.0
    array = array[None].transpose(0, 3, 1, 2)  # HWC -> NCHW with N = 1
    tensor = torch.from_numpy(array)
    return 2. * tensor - 1.

# Super-resolution scale factor; also selects the X{factor}.00 LR folder.
factor = 2

# Folder holding the SR results to post-process (ERP outputs live in erp_output/).
src = f"results/sun-test-x{factor}/test_OmniSSR_gamma-latent-0.5_gamma-erp-1.0_input-size-512_pre-upscale-4_nrows-4_fov-75-75_patchsize-512-512"

# Low-resolution ERP inputs. For the ODI-SR test set use instead:
#   f"datasets/lau_dataset_resize_clean/odisr/testing/LR_erp/X{factor}.00"
lr_img_path = f"datasets/lau_dataset_resize_clean/sun_test/LR_erp/X{factor}.00"

# gamma_p: weight of the projected image when blended with the raw SR output
# (1 keeps only the projection, 0 keeps only the raw output).
gamma = 1

init_img_path, out_path = (
    os.path.join(src, "erp_output"),
    os.path.join(src, f"post-gamma-{gamma}"),
)
os.makedirs(out_path, exist_ok=True)

def bicubic_kernel(x, a=-0.5):
    """Evaluate the piecewise-cubic convolution kernel with parameter ``a``.

    Returns the inner cubic for ``|x| <= 1``, the outer cubic for
    ``1 < |x| < 2``, and ``0`` everywhere else (including NaN input).
    """
    t = abs(x)
    if t <= 1:
        return (a + 2)*t**3 - (a + 3)*t**2 + 1
    if t < 2:
        return a*t**3 - 5*a*t**2 + 8*a*t - 4*a
    return 0

# 1-D bicubic downsampling filter: factor * 4 taps sampled at
# (i - floor(taps / 2) + 0.5) / factor, then normalised to sum to 1.
n_taps = factor * 4
positions = (1/factor) * (np.arange(n_taps) - np.floor(n_taps/2) + 0.5)
k = np.array([bicubic_kernel(p) for p in positions], dtype=np.float64)
k = k / k.sum()

kernel = torch.from_numpy(k).float().to('cuda')
# NOTE(review): 3 is presumably the channel count; 256 appears to become
# H_funcs.img_dim, which null_space_decomposition uses as the HR tile size
# -- confirm against SRConv.
H_funcs = SRConv(kernel / kernel.sum(), 3, 256, 'cuda', stride=factor)

def _sorted_png_names(folder):
    """Return the sorted names of the entries in ``folder`` ending in ``.png``."""
    # Suffix match (not substring) so e.g. "x.png.bak" is not picked up.
    return sorted(name for name in os.listdir(folder) if name.endswith(".png"))

init_img_list = _sorted_png_names(init_img_path)
lr_img_list = _sorted_png_names(lr_img_path)

# Load every image up front as CPU tensors; SR outputs and LR inputs are
# later paired by their position in these sorted lists.
init_images = [load_img(os.path.join(init_img_path, name)) for name in init_img_list]
lr_images = [load_img(os.path.join(lr_img_path, name)) for name in lr_img_list]



loaded input image of size (width:2048, height:1024) from results/erp-bicubic/sun-test-x4/test_PanoStableSR_V3_gamma-latent-0.5_gamma-erp-1_input-size-512_pre-upscale-1_nrows-4_fov-75-75_patchsize-1024-1024/erp_output/0000.png
loaded input image of size (width:2048, height:1024) from results/erp-bicubic/sun-test-x4/test_PanoStableSR_V3_gamma-latent-0.5_gamma-erp-1_input-size-512_pre-upscale-1_nrows-4_fov-75-75_patchsize-1024-1024/erp_output/0001.png
loaded input image of size (width:2048, height:1024) from results/erp-bicubic/sun-test-x4/test_PanoStableSR_V3_gamma-latent-0.5_gamma-erp-1_input-size-512_pre-upscale-1_nrows-4_fov-75-75_patchsize-1024-1024/erp_output/0002.png
loaded input image of size (width:2048, height:1024) from results/erp-bicubic/sun-test-x4/test_PanoStableSR_V3_gamma-latent-0.5_gamma-erp-1_input-size-512_pre-upscale-1_nrows-4_fov-75-75_patchsize-1024-1024/erp_output/0003.png
loaded input image of size (width:2048, height:1024) from results/erp-bicubic/sun-test-x4/te

In [8]:

from einops import rearrange, repeat
from overlapping_tile import partion_overlapping_window, reverse_overlapping_window

def null_space_decomposition(x0:torch.Tensor, y:torch.Tensor, H_funcs):
    """DDNM-style data-consistency correction, applied tile by tile.

    Each non-overlapping high-resolution tile ``p`` of ``x0`` is replaced by
    ``p - H^+(H p - y_p)``, where ``y_p`` is the co-located low-resolution
    tile of ``y`` and ``H`` / ``H^+`` are ``H_funcs.H`` / ``H_funcs.H_pinv``.

    Args:
        x0: 4-D tensor ``(B, C, H, W)`` to correct.
        y: low-resolution observation; presumably ``(B, C, H / ratio, W / ratio)``
            -- TODO confirm against callers.
        H_funcs: operator exposing ``img_dim`` (HR tile size), ``ratio``
            (scale factor), ``H`` and ``H_pinv``.

    Returns:
        The corrected tiles reassembled by ``reverse_overlapping_window``
        for target shape ``(B, C, H, W)``.
    """
    batch, channels, height, width = x0.shape
    hr_tile = H_funcs.img_dim
    scale = H_funcs.ratio
    lr_tile = int(hr_tile/scale)
    overlap = 0  # tiles do not overlap

    # The rearrange patterns below treat the unfolded tiles as (B, C*Hp*Wp, Np).
    hr_cols, hr_pad = partion_overlapping_window(x0, hr_tile, overlap)
    lr_cols, _ = partion_overlapping_window(y, lr_tile, overlap)

    cols_to_imgs = 'B (C Hp Wp) Np -> (B Np) C Hp Wp'
    num_tiles = hr_cols.shape[-1]
    hr_imgs = rearrange(hr_cols, cols_to_imgs, C=channels, Hp=hr_tile, Np=num_tiles)
    lr_imgs = rearrange(lr_cols, cols_to_imgs, C=channels, Hp=lr_tile, Np=lr_cols.shape[-1])

    # Residual between the re-degraded HR tiles and the observed LR tiles,
    # lifted back to HR space with the pseudo-inverse and subtracted.
    degraded = H_funcs.H(hr_imgs.reshape(hr_imgs.size(0), -1))
    residual = degraded - lr_imgs.reshape(lr_imgs.size(0), -1)
    correction = H_funcs.H_pinv(residual).reshape(*hr_imgs.size())
    corrected = hr_imgs - correction

    corrected_cols = rearrange(corrected, '(B Np) C Hp Wp -> B (C Hp Wp) Np',
                               C=channels, Hp=hr_tile, Np=num_tiles)
    return reverse_overlapping_window(corrected_cols, (batch, channels, height, width),
                                      hr_pad, hr_tile, overlap)

In [9]:
import torchvision.utils as tvu
# Apply the DDNM-style correction (x_hat = x0 - H^+(H x0 - y)) to every SR
# output and blend it with the uncorrected output using gamma.
# SR outputs and LR inputs are paired purely by position in their sorted file
# lists, so a count mismatch would misalign pairs or fail midway after partial
# writes -- reject it before doing any work.
if len(init_images) != len(lr_images):
    raise ValueError(
        f"cannot pair {len(init_images)} SR outputs from {init_img_path} "
        f"with {len(lr_images)} LR inputs from {lr_img_path}"
    )

for img_name, init_cpu, lr_cpu in zip(init_img_list, init_images, lr_images):
    lr_image = lr_cpu.to('cuda')
    init_image = init_cpu.to('cuda')
    basename = os.path.splitext(os.path.basename(img_name))[0]

    new_image = null_space_decomposition(x0=init_image, y=lr_image, H_funcs=H_funcs)

    # gamma = 1 keeps only the corrected image; gamma = 0 keeps the raw SR output.
    new_image = (1 - gamma) * init_image + gamma * new_image

    # Map [-1, 1] back to [0, 1] before writing the PNG.
    tvu.save_image(torch.clamp((new_image + 1.0) / 2.0, min=0, max=1),
                   os.path.join(out_path, f"{basename}.png"))

In [10]:
# f"{out_path}/{basename}.png"