In [None]:
import os
os.environ["OMP_NUM_THREADS"]="1"
os.environ["MKL_NUM_THREADS"]="1"
os.environ["CUDA_VISIBLE_DEVICES"]="1"

import sys
from skimage import io
from matplotlib import pyplot as plt
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from torchvision import transforms
import warnings

# RANSAC-Flow Path
sys.path.append('./RANSAC-Flow')
sys.path.append('./RANSAC-Flow/utils')
sys.path.append('./RANSAC-Flow/model')
sys.path.append('./RANSAC-Flow/quick_start')
from coarseAlignFeatMatch import CoarseAlign
import outil
import model as model
import PIL.Image as Image 
import kornia.geometry as tgm
if not sys.warnoptions:
    warnings.simplefilter("ignore")

# metrics
from skimage.metrics import mean_squared_error as mse 
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim
import datetime
# psnr
from skimage.util.dtype import dtype_range
from skimage._shared.utils import warn, check_shape_equality
# ssim
from skimage.util.arraycrop import crop
from scipy.ndimage import uniform_filter, gaussian_filter


In [None]:
def show_input_images(scene_images, target_images):
    """Plot the first and last (scene, target) pairs in a 2x2 grid.

    Top row: first scene image / first target image.
    Bottom row: last scene image / target image at the same index
    (the index is derived from len(scene_images)).
    """
    plt.figure(figsize=(15,15))
    last = len(scene_images) - 1
    layout = (
        (221, scene_images, 0),
        (222, target_images, 0),
        (223, scene_images, last),
        (224, target_images, last),
    )
    for position, images, index in layout:
        plt.subplot(position)
        plt.imshow(images[index])
        plt.axis('off')
    plt.show()

In [None]:
def blend(out, source, scene, blend_type, mask=None):
    """Fill the masked pixels of a warped image from a fallback image.

    Args:
        out: warped image (array-like; the default mask assumes H x W x C).
        source: fallback image used when blend_type == 'light'.
        scene: fallback image used for every other blend_type.
        blend_type: 'light' selects `source` as fallback, anything else `scene`.
        mask: optional array; where it is 1/True the fallback pixel is taken.
            When omitted, pixels of `out` whose channel vector is all zero are
            masked (bool array of shape H x W x 1).
    Returns:
        (result, mask): the uint8 blended image and the mask that was applied.
    """
    warped = np.array(out)
    fallback_source = np.array(source)
    fallback_scene = np.array(scene)

    if mask is None:
        # A pixel is "empty" when all of its channels are zero.
        mask = np.expand_dims(np.linalg.norm(warped, axis=2) == 0, axis=-1)

    fill = fallback_source if blend_type == 'light' else fallback_scene
    keep = 1 - mask
    result = (warped * keep + fill * mask).astype(np.uint8)
    return result, mask


In [None]:
def PSNR(image_true, image_test, mask=None):
    """Peak signal-to-noise ratio, optionally restricted to unmasked elements.

    The data_range logic follows skimage's (older) peak_signal_noise_ratio:
    the dtype range of `image_true`, using dmax only when image_true has no
    negative values.

    Args:
        image_true: reference image; its dtype determines data_range.
        image_test: image to compare; must have the same shape as image_true.
        mask: optional array, truthy where elements must be EXCLUDED (e.g. the
            hole mask returned by ``blend``). Shape must broadcast to the image
            shape; an (H, W) mask for an (H, W, C) image is expanded to
            (H, W, 1). The MSE is averaged over the non-excluded elements only.
    Returns:
        PSNR in dB. inf when the evaluated elements match exactly; nan when the
        mask excludes every element.
    Raises:
        ValueError: image_true has values outside its dtype's range (shape
            mismatches are rejected by check_shape_equality).
    """
    check_shape_equality(image_true, image_test)

    if image_true.dtype != image_test.dtype:
        warn("Inputs have mismatched dtype. Setting data_range based on "
                "im_true.", stacklevel=2)
    dmin, dmax = dtype_range[image_true.dtype.type]
    true_min, true_max = np.min(image_true), np.max(image_true)
    if true_max > dmax or true_min < dmin:
        raise ValueError(
            "im_true has intensity values outside the range expected for "
            "its data type. Please manually specify the data_range")
    if true_min >= 0:
        # most common case (255 for uint8, 1 for float)
        data_range = dmax
    else:
        data_range = dmax - dmin

    image_true = image_true.astype(np.float64)
    image_test = image_test.astype(np.float64)
    if mask is None:
        err = np.mean((image_true - image_test) ** 2, dtype=np.float64)
    else:
        mask = np.asarray(mask)
        if mask.ndim == image_true.ndim - 1:
            # (H, W) mask for an (H, W, C) image: apply it to every channel.
            mask = mask[..., np.newaxis]
        # Weight per element; counting on the broadcast shape counts every channel.
        keep = np.broadcast_to(1 - mask, image_true.shape)
        count = np.count_nonzero(keep)
        if count == 0:
            # Nothing left to compare.
            return float('nan')
        squared_error = np.sum((image_true * keep - image_test * keep) ** 2)
        err = squared_error / count
    return 10 * np.log10((data_range ** 2) / err)

In [None]:
def SSIM(im1, im2, multichannel=True, mask=None):
    """Global (single-window) SSIM, optionally over unmasked pixels only.

    skimage's ``structural_similarity`` averages a sliding-window SSIM map.
    This function instead computes one SSIM value from the mean, variance
    and covariance of all evaluated pixels.

    Args:
        im1, im2: images of equal shape. With ``multichannel=True`` the last
            axis is treated as channels. im1's dtype sets the data range.
        multichannel: score each channel separately and return their mean.
        mask: optional (H, W) or (H, W, k) array. A pixel is excluded when
            ``1 - mask`` is zero for it; for (H, W, k) masks it is kept if any
            of its k entries keeps it. ``None`` evaluates every pixel.
    Returns:
        SSIM as a numpy float; nan when the mask excludes every pixel.
    """
    check_shape_equality(im1, im2)

    if multichannel:
        # Score channels independently; each channel uses the same spatial mask.
        nch = im1.shape[-1]
        channel_scores = np.empty(nch)
        for ch in range(nch):
            channel_scores[ch] = SSIM(im1[..., ch], im2[..., ch], multichannel=False, mask=mask)
        return channel_scores.mean()

    if im1.dtype != im2.dtype:
        warn("Inputs have mismatched dtype.  Setting data_range based on im1.dtype.", stacklevel=2)
    dmin, dmax = dtype_range[im1.dtype.type]
    data_range = dmax - dmin

    # Standard SSIM stabilising constants.
    K1 = 0.01
    K2 = 0.03
    C1 = (K1 * data_range) ** 2
    C2 = (K2 * data_range) ** 2

    im1 = im1.astype(np.float64)
    im2 = im2.astype(np.float64)

    if mask is None:
        keep = np.ones(im1.shape, dtype=bool)
    else:
        keep = (1 - np.asarray(mask)) != 0
        if keep.ndim == im1.ndim + 1:
            # (H, W, k) mask for one (H, W) channel: collapse the trailing axis.
            keep = keep.any(axis=-1)
    if not keep.any():
        return np.float64(np.nan)

    x = im1[keep]
    y = im2[keep]

    ux = np.mean(x)
    uy = np.mean(y)
    uxy = np.mean(x * y)
    vx = np.var(x)
    vy = np.var(y)
    vxy = uxy - ux * uy

    score = (2 * ux * uy + C1) * (2 * vxy + C2) / ((ux ** 2 + uy ** 2 + C1) * (vx + vy + C2))
    return score


In [None]:
def metrics(id, output_images, masks, save_root, scene_images=None, source=None, blend_type='scale', with_mask=False, save=True):
    """Compute PSNR / SSIM for each blended output and log the values.

    Args:
        id: target image id; printed zero-padded to 4 digits.
        output_images: blended outputs; ``None`` entries count as failures.
        masks: per-output masks, read only when ``with_mask`` is True.
        save_root: folder whose ``output.txt`` is appended to when ``save``.
        scene_images: per-output references for every blend_type except
            'light'; must be as long as ``output_images``.
        source: single reference used for every output when blend_type == 'light'.
        blend_type: 'light' switches the reference from scene_images to source.
        with_mask: use the masked ``PSNR`` / ``SSIM`` helpers instead of
            skimage's full-image psnr / ssim.
        save: also append every log line to ``save_root/output.txt``.
    Returns:
        dict with 'p' / 's' (PSNR / SSIM rounded to 2 decimals, -1 for a
        failure), 'ce' (one -1 per failure only; census error is not
        computed) and 'fail' (number of failures).
    Raises:
        ValueError: scene_images is missing or has a different length than
            output_images (non-'light' blend types only).
    """
    if blend_type != 'light':
        if scene_images is None:
            raise ValueError('scene_images is required unless blend_type is light')
        if len(output_images) != len(scene_images):
            raise ValueError('output images should have the same number as scene images')

    save_file = os.path.join(save_root, 'output.txt') if save else None

    def log(line):
        # Echo to stdout and mirror the same line into output.txt when saving.
        print(line)
        if save_file is not None:
            with open(save_file, 'a') as f:
                f.write(line + '\n')

    result = {'p': [], 's': [], 'ce': [], 'fail': 0}
    log('id: {}'.format(str(id).zfill(4)))

    for i, output in enumerate(output_images):
        if output is None:
            result['p'].append(-1)
            result['s'].append(-1)
            result['ce'].append(-1)
            result['fail'] += 1
            log('no.{:2d}\t Failed'.format(i))
            continue

        reference = np.array(source if blend_type == 'light' else scene_images[i])
        output = np.array(output)
        if with_mask:
            mask = masks[i]
            p = PSNR(output, reference, mask=mask)
            s = SSIM(output, reference, multichannel=True, mask=mask)
        else:
            p = psnr(output, reference)
            # NOTE(review): skimage deprecated `multichannel` (0.19) in favour of
            # `channel_axis=-1`; newer releases may reject this keyword.
            s = ssim(output, reference, multichannel=True)

        result['p'].append(round(p, 2))
        result['s'].append(round(s, 2))
        log('no.{:2d}\t PSNR:{:.2f}\t SSIM:{:.2f}\t'.format(i, p, s))
    return result

In [None]:
def blend_RANSAC(scene_images, target, coarseModel=None, network=None, source=None, blend_type='scale', save_root=None):
    """Warp `target` onto every scene with RANSAC-Flow and blend the empty pixels.

    For each scene the pipeline is:
    1. coarse homography alignment via ``coarseModel.getCoarse``;
    2. fine flow from the networks in `network`;
    3. ``blend`` fills the pixels the warp left all-zero.

    Args:
        scene_images: iterable of PIL images to align to.
        target: PIL image to warp; resized to each scene's size first.
        coarseModel: RANSAC-Flow ``CoarseAlign`` instance (see ``run``).
        network: dict holding at least 'netFeatCoarse', 'netCorr' and
            'netFlowCoarse' modules; tensors here are moved with ``.cuda()``.
        source: fallback image for ``blend`` when blend_type == 'light'; the
            warped output is then resized to ``source.size``.
        blend_type: forwarded to ``blend``.
        save_root: unused; kept for call compatibility.
    Returns:
        (blends, masks): one uint8 blended array and one mask per scene.
    """
    blends = []
    masks = []
    # Pure inference: skip autograd bookkeeping to save GPU memory.
    with torch.no_grad():
        for scene in scene_images:
            # Always resize from the caller's target. Reusing the previous
            # iteration's resized copy compounds resampling losses when
            # scene sizes differ (e.g. portrait vs landscape).
            target_resized = target.resize(scene.size)
            coarseModel.setSource(target_resized)
            coarseModel.setTarget(scene)

            I2w, I2h = coarseModel.It.size
            featt = F.normalize(network['netFeatCoarse'](coarseModel.ItTensor))

            # Identity sampling grid in [-1, 1], shape (1, I2h, I2w, 2).
            gridY = torch.linspace(-1, 1, steps = I2h).view(1, -1, 1, 1).expand(1, I2h,  I2w, 1)
            gridX = torch.linspace(-1, 1, steps = I2w).view(1, 1, -1, 1).expand(1, I2h,  I2w, 1)
            grid = torch.cat((gridX, gridY), dim=3).cuda()
            warper = tgm.HomographyWarper(I2h,  I2w)

            # Coarse stage: best homography -> sampling grid -> coarsely warped source.
            bestPara, InlierMask = coarseModel.getCoarse(np.zeros((I2h, I2w)))
            bestPara = torch.from_numpy(bestPara).unsqueeze(0).cuda()

            flowCoarse = warper.warp_grid(bestPara)
            I1_coarse = F.grid_sample(coarseModel.IsTensor, flowCoarse)

            # Fine stage: correlate features and predict a residual flow.
            featsSample = F.normalize(network['netFeatCoarse'](I1_coarse.cuda()))

            corr12 = network['netCorr'](featt, featsSample)
            flowDown8 = network['netFlowCoarse'](corr12, False) ## output is with dimension B, 2, W, H

            flowUp = F.interpolate(flowDown8, size=(grid.size()[1], grid.size()[2]), mode='bilinear')
            flowUp = flowUp.permute(0, 2, 3, 1)

            flowUp = flowUp + grid

            # Compose the fine flow with the coarse grid, then sample the source once.
            flow12 = F.grid_sample(flowCoarse.permute(0, 3, 1, 2), flowUp).permute(0, 2, 3, 1).contiguous()

            I1_fine = F.grid_sample(coarseModel.IsTensor, flow12)
            I1_fine_pil = transforms.ToPILImage()(I1_fine.cpu().squeeze())
            if blend_type == 'light':
                I1_fine_pil = I1_fine_pil.resize(source.size)
            else:
                I1_fine_pil = I1_fine_pil.resize(scene.size)

            blend_i, mask_i = blend(I1_fine_pil, source, scene, blend_type)
            blends.append(blend_i)
            masks.append(mask_i)

    return blends, masks

In [None]:
def eval(id,
         scene_images,
         target,
         save_root,
         blend_type='scale',
         coarseModel=None,
         network=None,
         source=None,
         with_mask = False,
         draw = True,
         save = True,
         ):
    '''Align `target` to every scene with RANSAC-Flow, score the blends, and optionally plot them.

    Args:
        @id: (int) target image id, for logging and the per-target figure folder name
        @scene_images: (list of PIL.Image) test scene images
        @target: (PIL.Image) target image
        @save_root: (string) folder receiving output.txt and, with draw+save, a {id:04d} figure folder
        @blend_type: (string, 'scale') e.g. ['scale', 'light', 'viewpoint', 'deformation', 'occlusion'];
            'light' scores every output against `source` instead of its scene
        @coarseModel & network: RANSAC-Flow coarse aligner and network dict
        @source: reference image for metrics (and blending fallback) when blend_type=='light'
        @with_mask: (bool) compute metrics only where the warped output is not all-zero
        @draw: (bool) display a target / scene / output figure per scene
        @save: (bool) save metrics (output.txt) and figures to disk
    Returns:
        (out, result): blended images (one per scene) and the dict returned by `metrics`.
    '''
    out, masks = blend_RANSAC(scene_images, target, coarseModel=coarseModel, network=network, source=source, blend_type=blend_type, save_root=save_root)

    if blend_type == 'light':
        result = metrics(id, out, masks, save_root=save_root, source=source, blend_type=blend_type, with_mask=with_mask, save=save)
    else:
        result = metrics(id, out, masks, save_root=save_root, scene_images=scene_images, blend_type=blend_type, with_mask=with_mask, save=save)

    if draw:
        fig_root = os.path.join(save_root, str(id).zfill(4))
        if save:
            os.makedirs(fig_root, exist_ok=True)
        for i in range(len(scene_images)):
            if out[i] is None:
                continue
            title = 'PSNR: '+str(result['p'][i])+' SSIM: '+str(result['s'][i])
            shown = out[i] if not with_mask else out[i] * (1 - masks[i])

            plt.figure(facecolor='white')
            plt.subplot(1, 3, 1)
            plt.imshow(target)
            plt.axis('off')
            plt.subplot(1, 3, 2)
            plt.imshow(scene_images[i])
            plt.axis('off')
            plt.subplot(1, 3, 3)
            plt.imshow(shown)
            plt.axis('off')
            if save:
                # Save before showing: show() may release the figure.
                plt.savefig(os.path.join(fig_root, str(i)+'_'+title+'.png'), dpi=200, bbox_inches='tight')
            # A figure closed without being shown never renders in the notebook,
            # so display it before closing (previously draw=True showed nothing).
            plt.show()
            plt.close()

    return (out, result)

def rescale(img):
    """Rotate a portrait image 90 degrees counter-clockwise so it becomes landscape.

    Args:
        img: image array whose first two axes are (height, width).
    Returns:
        The rotated image when height > width; otherwise `img` itself, unchanged.
        (Previously non-portrait input raised UnboundLocalError.)
    """
    if img.shape[0] > img.shape[1]:
        # transpose followed by a flip around the x-axis == 90 deg CCW rotation
        return cv2.flip(cv2.transpose(img), 0)
    return img


In [None]:
def _load_rgb(path, H, W):
    """Open `path` as RGB and resize it to (W, H), or to (H, W) when it is portrait."""
    with Image.open(path) as im:
        img = im.convert('RGB')
    if img.size[1] > img.size[0]:  # portrait (height > width): keep it portrait
        return img.resize((H, W))
    return img.resize((W, H))


def _load_ransac_flow():
    """Build the pretrained RANSAC-Flow networks (CUDA, eval mode) and the coarse aligner."""
    resumePth = './RANSAC-Flow/model/pretrained/MegaDepth_Theta1_Eta001_Grad1_0.774.pth' ## model for visualization
    kernelSize = 7
    network = {'netFeatCoarse' : model.FeatureExtractor(),
               'netCorr'       : model.CorrNeigh(kernelSize),
               'netFlowCoarse' : model.NetFlowCoarse(kernelSize),
               'netMatch'      : model.NetMatchability(kernelSize),
               }
    for net in network.values():
        net.cuda()

    # NOTE(review): torch.load unpickles the checkpoint; only load trusted files.
    param = torch.load(resumePth)
    print('Loading pretrained model from {}'.format(resumePth))
    for key, state in param.items():
        network[key].load_state_dict(state)
        network[key].eval()

    nbScale = 7
    coarseIter = 10000
    coarsetolerance = 0.05
    minSize = 400
    imageNet = True # we can also use MOCO feature here
    scaleR = 1.2
    coarseModel = CoarseAlign(nbScale, coarseIter, coarsetolerance, 'Homography', minSize, 1, True, imageNet, scaleR)
    return coarseModel, network


def run(root = './assets/',
        blend_type = 'scale',
        img_num = 5,
        start_img_id = 0,
        start_scene_id = 0,
        scn_num = 10,
        source_id = -1,
        H = 480,
        W = 640,
        with_mask = False,
        draw = True,
        save = True,
        multisample=True,
        folder="",
        ):
    '''Evaluate RANSAC-Flow alignment over a test set and print PSNR / SSIM statistics.

    Args:
        @root: (string, './assets/') image data root; target images are {root}/{id:04d}.jpg
        @blend_type: (string, 'scale') ['scale', 'light', 'viewpoint', 'deformation', 'occlusion'];
            also the scene sub-folder name and the scene file-name infix
        @img_num: (int, 5) number of target images
        @start_img_id: (int, 0) id of the first target image
        @start_scene_id: (int, 0) index of the first scene image per target
        @scn_num: (int, 10) number of scene images per target image
        @source_id: position, in each loaded scene list, of the image used as metric
            reference when blend_type=='light' (required in that case)
        @(W, H): resize size for landscape images; portrait images get (H, W)
        @with_mask: compute PSNR/SSIM only where the warped output is not all-zero
        @draw: (bool) display per-scene result figures
        @save: (bool) write output.txt and figures
        @multisample: (bool, True) read scenes from {root}/{blend_type}; if False, from {root}/{folder}
        @folder: (string) scene folder name, required when multisample is False
    Output:
        folder 'ransac_output_{MMDDHHMM}' ('ransac_output_mask_{MMDDHHMM}' with with_mask)
        under the scene folder, holding output.txt and figures (only when save is True)
    Raises:
        FileExistsError: blend_type=='light' without source_id.
        FileNotFoundError: multisample is False and folder is empty.
        ValueError: img_num or scn_num is not positive.
    '''
    # path setting
    if blend_type == 'light' and source_id == -1:
        raise FileExistsError('no source image id specified.')
    if img_num <= 0 or scn_num <= 0:
        raise ValueError('img_num and scn_num must be positive')
    if multisample:
        scene_root = os.path.join(root, blend_type)
    else:
        if folder == "":
            raise FileNotFoundError("input folder name")
        scene_root = os.path.join(root, folder)
    suffix = datetime.datetime.now().strftime('%m%d%H%M')
    prefix = 'ransac_output_mask_' if with_mask else 'ransac_output_'
    save_root = os.path.join(scene_root, prefix + suffix)
    if save:
        os.makedirs(save_root, exist_ok=True)

    # data loading; (W, H) is the landscape size, so make sure W >= H
    if H > W:
        H, W = W, H
    img_ids = list(range(start_img_id, start_img_id + img_num))
    # scn_num is a count, matching how img_num / start_img_id are handled
    scene_ids = range(start_scene_id, start_scene_id + scn_num)
    scene_images = []
    target_images = []
    for img_id in img_ids:
        name = str(img_id).zfill(4)
        scene_images.append([
            _load_rgb(os.path.join(scene_root, name + "_" + blend_type + "_" + str(i) + ".jpg"), H, W)
            for i in scene_ids
        ])
        target_images.append(_load_rgb(os.path.join(root, name + ".jpg"), H, W))
    total = len(scene_images) * len(scene_images[0])
    print('input {}x{} scene images'.format(len(scene_images), len(scene_images[0])))
    print('input {} target images'.format(len(target_images)))

    # model loading
    coarseModel, network = _load_ransac_flow()

    p_valid = []
    s_valid = []
    fail = 0
    for img_id, scene, target in zip(img_ids, scene_images, target_images):
        # Pass the real image id (not the loop index) so logs and folders match the files.
        out, result = eval(img_id, scene, target, save_root=save_root, source=scene[source_id], blend_type=blend_type, coarseModel=coarseModel, network=network, with_mask=with_mask, draw=draw, save=save)
        # -1 marks a failed alignment; keep failures out of the statistics.
        p_valid.extend(v for v in result['p'] if v != -1)
        s_valid.extend(v for v in result['s'] if v != -1)
        fail += result['fail']

    p_array = np.asarray(p_valid)
    s_array = np.asarray(s_valid)
    print('RANSAC\t PSNR: {:.2f}/{:.2f}\t SSIM: {:.2f}/{:.2f}\t fail: {:.2f}%({}/{})'
        .format(np.mean(p_array),
                np.median(p_array),
                np.mean(s_array),
                np.median(s_array),
                fail*1.0/total*100,
                fail, total))

In [None]:
''' 数据集格式
|-deformation (scene images)
    |-0000_deformation_0.jpg
    |-0000_deformation_1.jpg
    |-...
    |-0001_deformation_0.jpg
    |-...
|-light
|-occlusion
|-scale
|-viewpoint
|-0000.jpg (target image)

|-0001.jpg
|-...
'''
# The string above describes the expected dataset layout under `root`
# (its heading means "dataset format").
# RANSAC
# Requires an old kornia: pip install kornia==0.1.4.post2
# Newer kornia versions no longer provide the warp_grid function used by blend_RANSAC.
run(
    blend_type='occlusion',
    img_num=10,
    scn_num=10,
    source_id=1,
    with_mask=True,
    multisample=True,
    folder="minions",
    start_img_id=0,
    start_scene_id=0,
    draw=False,
    save=False,
)