In [None]:
import os

import numpy as np
import cv2 as cv
import matplotlib.pyplot as plt
%matplotlib inline
from tqdm.notebook import tqdm, trange

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn

from skimage.io import imread, imsave
from skimage.util import img_as_float32, img_as_ubyte

from models import PAT
from ext_utils import find_good_matching_points, _project_mesh_grid_to_indices_cube, move_pixel_value

In [None]:
# Load the four frames of the quad-camera sample. Index 0 becomes the
# reference frame `hr`; the remaining three are the per-channel (r, g, b per
# the file names) frames that get cropped and rescaled below.
data_folder = '../data/quad_camera_sample/'
quad = [imread(os.path.join(data_folder, fn)) for fn in \
        ['cam4_w_exp150.png', 'cam1_r_exp500.png', 'cam2_g_exp300.png', 'cam3_b_exp800.png']]

# Pre-crop and scale the low-res images for better inference: drop a fifth of
# the reference size from every border, then resize back to (w, h).
hr = quad[0]
h, w = hr.shape  # unpacking two values requires `hr` to be 2-D (single channel)
# NOTE: the negative bounds are parenthesised on purpose. `-h//5` parses as
# `(-h)//5`, which floors away from zero and would trim one extra pixel from
# the bottom/right edge whenever h (or w) is not a multiple of 5.
lrs = [cv.resize(img[h//5:-(h//5), w//5:-(w//5)], (w,h), interpolation=cv.INTER_CUBIC) \
       for img in quad[1:]]

In [None]:
# Preview the inputs on a 2x2 grid: each of the three low-res frames is painted
# into its own colour channel of an otherwise black RGB canvas, and the last
# panel shows the reference frame in grayscale.
fig, axes = plt.subplots(2, 2, figsize=(16, 12))
panels = axes.ravel()
for k, panel in enumerate(panels[:3]):
    _tmp_img = np.zeros((h, w, 3), dtype=np.uint8)
    _tmp_img[:, :, k] = lrs[k]
    panel.imshow(_tmp_img)
panels[-1].imshow(hr, cmap='gray')
plt.show()

In [None]:
# Build the PAT model on the GPU, wrap it in DataParallel, switch it to eval
# mode and restore the pretrained quad-input weights from the checkpoint's
# 'state_dict' entry.
cudnn.benchmark = True
net = nn.DataParallel(PAT(1, in_channel=1, num_input=4).to('cuda'))
net.eval()
pretrained_dict = torch.load('./weights/quad_input_weights.pth.tar')
net.load_state_dict(pretrained_dict['state_dict'])

In [None]:
# Compute the query matrix Q from the reference frame.
# The frame is converted to float32 and given leading batch and channel axes,
# i.e. a (1, 1, h, w) tensor on the GPU.
img_left = torch.from_numpy(img_as_float32(hr)[np.newaxis, np.newaxis]).to('cuda')
with torch.no_grad():
    x_left = net.module.init_feature(img_left)
    buffer_left = net.module.pam.rb(x_left)
    Q = net.module.pam.b1(buffer_left)

    # Allocate the fused feature volume and store the reference features in
    # its last 64 channels; the per-frame loop fills the leading channels.
    fused_feature = torch.zeros((1, 256, h, w), dtype=torch.float32, device='cuda')
    fused_feature[:, -64:] = x_left

In [None]:
# Attention receptive field and tiling parameters.

hw = 8  # half window width
s = 1   # strides

# Patch size as (rows, columns), matching how `ph` is paired with `h` and
# `pw` with `w` in the tile counts below.
ph, pw = 200, 320
m = np.ceil(h / ph).astype(int)  # number of tiles along the rows
n = np.ceil(w / pw).astype(int)  # number of tiles along the columns

# Enumerate every (i, j) tile index in row-major order as an (m*n, 2) array.
i_list, j_list = np.meshgrid(np.arange(m), np.arange(n), indexing='ij')
ij_list = np.column_stack((i_list.ravel(), j_list.ravel()))

In [None]:
# Process one low-res image at a time: align it to the reference frame,
# compute its key/value matrices, and write the attended features into that
# image's 64-channel slice of `fused_feature`, tile by tile.
for k in trange(3):
    # find rough correspondence with SIFT
    pts1, pts2 = find_good_matching_points(lrs[k], hr)
    M_h2l, _ = cv.findHomography(pts2.reshape(-1,1,2),
                                 pts1.reshape(-1,1,2),
                                 cv.RANSAC, 5.0)
    # cv.findHomography yields no matrix (None in Python) when a model cannot
    # be estimated; fail here with a clear message instead of passing None on.
    if M_h2l is None:
        raise RuntimeError(
            'could not estimate a homography for low-res image {}'.format(k))
    yys, xxs = _project_mesh_grid_to_indices_cube((w, h), (w, h), M_h2l, hw, s) # original PAT abuse x/y
                                                                                # for first/second index
    # get key and value matrices
    img_right = torch.from_numpy(img_as_float32(lrs[k])[..., np.newaxis].transpose((2, 0, 1))).unsqueeze(0).to('cuda')
    with torch.no_grad():
        x_right = net.module.init_feature(img_right)
        buffer_right = net.module.pam.rb(x_right)
        K = net.module.pam.b2s[k](buffer_right)
        V = net.module.pam.b3s[k](buffer_right)
        # apply attention block by block
        for (i, j) in tqdm(ij_list):
            # x* bounds index the first (row) axis and y* the second (column)
            # axis, following the x/y convention noted above.
            xl, xu, yl, yu = i*ph, i*ph+ph, j*pw, j*pw+pw
            Q_ = Q[:, :, xl:xu,yl:yu].contiguous()
            # Index maps for this tile, each with a leading axis of size 1.
            Po = (torch.from_numpy(xxs[xl:xu, yl:yu][np.newaxis]),
                  torch.from_numpy(yys[xl:xu, yl:yu][np.newaxis]))
            buffer, _ = net.module.pam.fe_pam(Q_, K, V, Po, False)
            fused_feature[:, k*64:(k+1)*64, xl:xu, yl:yu] = buffer

# Fuse all feature slices and run the upscale head to get the prediction.
with torch.no_grad():
    out = net.module.pam.fusion(fused_feature)
    out = net.module.upscale(out)

# Move to the CPU as a channels-last NumPy array.
pred = out.squeeze().cpu().numpy().transpose(1, 2, 0)
# Drop the last iteration's index maps so they no longer hold memory.
del yys, xxs

In [None]:
# White balance: normalize each predicted channel with the statistics of the
# matching low-res input image.
for k, lr in enumerate(lrs):
    # Convert once and reuse for both statistics.
    lr_flat = img_as_float32(lr).ravel()
    # Do NOT name the standard deviation `s`: that global holds the attention
    # stride, and overwriting it would corrupt a later re-run of the fusion cell.
    ch_mean = np.mean(lr_flat)
    ch_std = np.std(lr_flat)
    # presumably remaps the channel to the given mean/std; see ext_utils
    pred[..., k] = move_pixel_value(pred[..., k], ch_mean, ch_std)

# Clamp to [0, 1] before display.
pred = np.clip(pred, 0, 1)
plt.figure(figsize=(16, 12))
plt.imshow(pred)
plt.show()

In [None]:
# Write the inputs and the prediction to the results folder.
subfolder = os.path.join("../results", "pat_quad_camera")
# makedirs also creates a missing "../results" parent (os.mkdir would raise),
# and exist_ok=True avoids the check-then-create race.
os.makedirs(subfolder, exist_ok=True)

# Save each low-res input painted into its own colour channel of a black RGB
# canvas, the same visualisation used for the preview.
for k, lr in enumerate(lrs):
    _tmp_img = np.zeros((h,w,3), dtype=np.uint8)
    _tmp_img[..., k] = lr
    imsave(os.path.join(subfolder, 'lr_{}.png'.format(k)), _tmp_img, check_contrast=False)
imsave(os.path.join(subfolder, 'hr.png'), hr)
# The prediction is converted to 8-bit before writing.
imsave(os.path.join(subfolder, 'pred.png'), img_as_ubyte(pred))