In [None]:
from scipy.sparse import csr_matrix
from skimage.io import imread

import cv2
import matplotlib.pyplot as plt
import numpy as np
import os

In [None]:
PATH_TO_DATA = '../data/depth_superres/'


def _load_image(file_name):
    """Read ``file_name`` from PATH_TO_DATA with skimage's ``imread``."""
    return imread(os.path.join(PATH_TO_DATA, file_name))


# Loaded one after another in this order: reference, target, confidence.
reference = _load_image('reference.png')
target = _load_image('target.png')
confidence = _load_image('confidence.png')

In [None]:
# Dimensions of the target image (presumably expected to match the reference's H x W -- confirm).
print(np.shape(target))

In [None]:
# Show the three inputs stacked vertically, each titled with its variable name.
fig, axes = plt.subplots(3, 1, figsize=(20, 20))
panels = (('reference', reference), ('confidence', confidence), ('target', target))
for ax, (title, image) in zip(axes, panels):
    ax.imshow(image)
    ax.set_title(title)
plt.show()

In [None]:
# RGB <-> YCbCr ("YUV") conversion using the full-range JPEG/JFIF coefficients
# (BT.601 luma weights 0.299 / 0.587 / 0.114, chroma offset 128); see ITU-T T.871
# (JFIF). YUV_TO_RGB uses the spec's rounded inverse coefficients, so a round
# trip is exact only up to ~1e-3 per channel.
RGB_TO_YUV = np.array([
    [ 0.299,     0.587,     0.114],
    [-0.168736, -0.331264,  0.5],
    [ 0.5,      -0.418688, -0.081312]])
YUV_TO_RGB = np.array([
    [1.0,  0.0,      1.402],
    [1.0, -0.34414, -0.71414],
    [1.0,  1.772,    0.0]])
# Added after the matrix product: Y has no offset, U (Cb) and V (Cr) are shifted by 128.
YUV_OFFSET = np.array([0, 128.0, 128.0]).reshape(1, 1, -1)


def rgb2yuv(im):
    """Convert an image from RGB to YUV (JFIF YCbCr).

    Args:
        im: array of shape (H, W, C) with C >= 3, or a 2-D grayscale image
            of shape (H, W). Only the first three channels are used, so an
            alpha channel (e.g. from an RGBA PNG) is ignored.

    Returns:
        float ndarray of shape (H, W, 3) holding (Y, U, V).
    """
    im = np.asarray(im, dtype=float)
    if im.ndim == 2:
        # Grayscale: R = G = B, which yields Y = intensity and U = V = 128.
        im = np.repeat(im[:, :, np.newaxis], 3, axis=2)
    # Extra channels (alpha) would make tensordot raise a shape mismatch.
    return np.tensordot(im[:, :, :3], RGB_TO_YUV, ([2], [1])) + YUV_OFFSET


def yuv2rgb(im):
    """Convert an (H, W, 3) YUV image produced by ``rgb2yuv`` back to RGB.

    Returns a float ndarray of shape (H, W, 3); values are not clipped or rounded.
    """
    return np.tensordot(np.asarray(im, dtype=float) - YUV_OFFSET, YUV_TO_RGB, ([2], [1]))

In [None]:
# Reference image in YUV: channel 0 is luma (Y), channels 1-2 are chroma (U, V).
reference_yuv = rgb2yuv(np.asarray(reference))

In [None]:
# Inspect the converted image: its shape, then the luma (channel 0) at row 20, column 1200.
# NOTE(review): the hard-coded column index assumes the image is at least 1201 px wide.
print(f'{reference_yuv.shape}')
print(f'{reference_yuv[20, 1200, 0]}')

In [None]:
# TODO: name these (and their arguments) after the paper's notation -- presumably
# Barron & Poole, "The Fast Bilateral Solver"; confirm.
def build_b():
    """Not implemented yet (TODO); presumably assembles the right-hand side ``b``."""
    return None


def initialise_x():
    """Not implemented yet (TODO); presumably returns the solver's initial guess ``x``."""
    return None


def preconditioner():
    """Not implemented yet (TODO); presumably builds a preconditioner for ``A``."""
    return None


def solver():
    """Not implemented yet (TODO); presumably solves the system whose matrix comes from ``build_A``."""
    return None

In [None]:
def build_W(im, param):
    """Build the splat matrix ``S`` and blur matrix ``B_`` of a 5-D bilateral grid.

    Every pixel is mapped to the integer grid coordinate
    ``(col / sigma_spatial, row / sigma_spatial, Y / sigma_luma,
    U / sigma_chroma, V / sigma_chroma)``, each truncated to an int, where
    Y, U, V come from ``rgb2yuv``. Pixels sharing a coordinate share a grid
    vertex; vertices are numbered in order of first appearance.

    Args:
        im: RGB image of shape (n_row, n_col, 3).
        param: dict providing 'sigma_spatial', 'sigma_luma' and 'sigma_chroma'.

    Returns:
        S: ``scipy.sparse.csr_matrix`` of shape (n_vertices, n_row * n_col)
            with ``S[v, p] == 1`` iff pixel ``p`` falls on vertex ``v``.
            Pixels are numbered column by column: ``p = col * n_row + row``.
        B_: dense ndarray of shape (n_vertices, n_vertices) holding
            ``2 * 5 = 10`` on the diagonal and 1 between any two vertices
            whose coordinates differ by exactly +/-1 along a single axis.
            NOTE(review): dense storage needs O(n_vertices**2) memory.
    """
    n_dims = 5
    yuv_im = rgb2yuv(im)
    n_row, n_col = yuv_im.shape[:2]
    nb_pixels = n_row * n_col

    # Integer grid coordinates of every pixel, shape (n_row, n_col, 5);
    # astype(int) truncates toward zero exactly like int().
    col_grid, row_grid = np.meshgrid(np.arange(n_col), np.arange(n_row))
    coords = np.stack([
        col_grid / param['sigma_spatial'],
        row_grid / param['sigma_spatial'],
        yuv_im[..., 0] / param['sigma_luma'],
        yuv_im[..., 1] / param['sigma_chroma'],
        yuv_im[..., 2] / param['sigma_chroma'],
    ], axis=-1).astype(int)
    # Column-major pixel order (outer loop over columns, inner over rows).
    coords = coords.transpose(1, 0, 2).reshape(nb_pixels, n_dims)

    # Exact tuple keys instead of a base-255 hash: that hash collided as soon as
    # a coordinate reached 255, and its +/-1 neighbour lookup wrapped into the
    # adjacent "digit", producing false neighbours.
    vertex_of = {}
    splat_idxs = np.empty(nb_pixels, dtype=np.intp)
    for pixel_idx, coord in enumerate(map(tuple, coords.tolist())):
        splat_idxs[pixel_idx] = vertex_of.setdefault(coord, len(vertex_of))
    n_vertices = len(vertex_of)

    S = csr_matrix((np.ones(nb_pixels), (splat_idxs, np.arange(nb_pixels))),
                   shape=(n_vertices, nb_pixels))

    B_ = np.eye(n_vertices) * (2 * n_dims)
    for coord, vertex in vertex_of.items():
        for dim in range(n_dims):
            for step in (-1, 1):
                neighbor = coord[:dim] + (coord[dim] + step,) + coord[dim + 1:]
                neighbor_vertex = vertex_of.get(neighbor)
                if neighbor_vertex is not None:
                    B_[vertex, neighbor_vertex] = 1

    return S, B_
    

In [None]:
def bistochastization(S, B_, max_iter=10):
    """Compute the bistochastization diagonals ``D_n`` and ``D_m`` of the grid.

    Starting from ``n = 1`` and ``m = S 1`` (pixels splatted onto each vertex),
    iterates ``n <- sqrt(n * m / (B_ n))`` ``max_iter`` times. Afterwards ``m``
    is recomputed as ``n * (B_ n)``, so ``D_n B_ D_n`` has row sums exactly
    ``m`` and ``D_m - D_n B_ D_n`` has zero row sums however many iterations ran.

    Args:
        S: splat matrix of shape (n_vertices, n_pixels); dense ndarray or
            ``scipy.sparse`` matrix.
        B_: blur matrix of shape (n_vertices, n_vertices); dense or sparse.
        max_iter: number of fixed-point iterations.

    Returns:
        (D_n, D_m): dense diagonal ndarrays of shape (n_vertices, n_vertices).
    """
    # `@` rather than np.dot: np.dot does not perform matrix products on
    # scipy.sparse matrices (build_W returns S as a csr_matrix).
    m = np.asarray(S @ np.ones(S.shape[1])).ravel()
    n = np.ones(S.shape[0])

    for _ in range(max_iter):
        n = np.sqrt(n * m / np.asarray(B_ @ n).ravel())

    # Make m consistent with n so the bistochastic identity holds exactly.
    m = n * np.asarray(B_ @ n).ravel()

    return np.diag(n), np.diag(m)
    
def build_A(reference_im, conf, param):
    """Assemble the system matrix ``A = lambda * (D_m - D_n B_ D_n) + diag(S c)``.

    Args:
        reference_im: RGB reference image; passed to ``build_W``.
        conf: per-pixel confidence, either a flat vector of length
            n_row * n_col already in ``build_W``'s pixel order, or a 2-D image
            of shape (n_row, n_col). A 2-D image is flattened column-major
            (``order='F'``) to match ``build_W``'s column-outer pixel numbering.
        param: dict with ``build_W``'s sigma keys plus 'max_iter_bisto'
            (iterations for ``bistochastization``) and 'lambda' (weight of the
            smoothness term).

    Returns:
        Dense ndarray of shape (n_vertices, n_vertices).
    """
    S, B_ = build_W(reference_im, param)

    D_n, D_m = bistochastization(S, B_, param['max_iter_bisto'])

    conf = np.asarray(conf, dtype=float)
    if conf.ndim == 2:
        conf = conf.ravel(order='F')
    # S is a scipy.sparse matrix: use `@`, which np.dot does not support for it.
    conf_splat = np.asarray(S @ conf).ravel()

    Lam = param['lambda']

    A = Lam * (D_m - D_n @ B_ @ D_n) + np.diag(conf_splat)

    return A