In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
from gco import pygco 
import cv2

def dist(patch1, patch2):
    """Return the mean absolute difference between two patches.

    Both inputs are flattened and converted to float64 before subtracting,
    so unsigned integer images (e.g. uint8) do not wrap around.

    Parameters
    ----------
    patch1, patch2 : array_like
        Patches to compare. ``patch1`` determines how many elements are
        compared; if ``patch2`` has more elements, only its first
        ``patch1.size`` (in flattened order) are used.

    Returns
    -------
    float
        Sum of absolute element differences divided by the number of
        elements in ``patch1``. Returns 0.0 when ``patch1`` is empty
        (instead of dividing by zero).

    Raises
    ------
    IndexError
        If ``patch2`` has fewer elements than ``patch1``.
    """
    flat1 = np.asarray(patch1, dtype=np.float64).ravel()
    flat2 = np.asarray(patch2, dtype=np.float64).ravel()
    pixel_count = flat1.size

    if flat2.size < pixel_count:
        raise IndexError("patch2 has fewer elements than patch1")
    if pixel_count == 0:
        # Nothing to compare: define the distance of empty patches as zero.
        return 0.0

    return float(np.abs(flat1 - flat2[:pixel_count]).mean())

#=======================================================
# get camera matrix
# cameras.txt is sliced as: rows 0-2 -> K1, 3-5 -> R1, 6 -> T1 (stored as a row),
# rows 7-9 -> K2, 10-12 -> R2, 13 -> T2 (stored as a row).
cam_matrix = np.loadtxt('cameras.txt', dtype=np.float32)
K1 = cam_matrix[0:3, 0:3]
R1 = cam_matrix[3:6, 0:3]
T1 = cam_matrix[6:7, 0:3].T   # 3x1 column vector
K2 = cam_matrix[7:10, 0:3]
R2 = cam_matrix[10:13, 0:3]
T2 = cam_matrix[13:14, 0:3].T  # 3x1 column vector

# read images
img1 = cv2.imread('Img/test00.jpg')
img2 = cv2.imread('Img/test09.jpg')
if img1 is None or img2 is None:
    # cv2.imread returns None instead of raising when a file cannot be read.
    raise FileNotFoundError("could not read 'Img/test00.jpg' and/or 'Img/test09.jpg'")
img1 = cv2.cvtColor(img1, cv2.COLOR_BGR2RGB)
img2 = cv2.cvtColor(img2, cv2.COLOR_BGR2RGB)
# Independent buffer for the output visualisation, so writing to it never
# overwrites the input image.
newImg = img1.copy()
H1 = img1.shape[0]
W1 = img1.shape[1]
H2 = img2.shape[0]
W2 = img2.shape[1]

# parameters setting
d_min = 0
d_max = 0.01
m = 50                          # number of disparity labels
d = np.linspace(d_min, d_max, m)  # candidate disparity values
m_lambda = 0.1                  # weight of the pairwise (smoothness) term

# calculate epipolar line
# For a pixel x (homogeneous) of image 1 and disparity d, the matching pixel in
# image 2 is   x' ~ term1 @ x + d * term2.
unary = np.zeros([H1, W1, m])
# Potts-style pairwise cost: 0 for equal labels, m_lambda otherwise.
pairwise = (1 - np.eye(m)) * m_lambda
term1 = K2 @ R2 @ R1.T @ np.linalg.inv(K1)
# Plain ndarray (3x1); np.mat is not available in NumPy 2.x.
term2 = K2 @ R2 @ (T1 - T2)

# Homogeneous coordinates of every pixel of image 1 as a 3 x (H1*W1) array,
# columns in row-major (y, x) order so they reshape back to (H1, W1).
ys, xs = np.mgrid[0:H1, 0:W1]
pixels_h = np.stack([xs.ravel(), ys.ravel(), np.ones(H1 * W1)])
# The disparity-independent part of the projection is computed once.
base_proj = term1 @ pixels_h

# Work in float so per-channel differences cannot wrap around as they do
# when subtracting uint8 values directly.
img1_f = img1.astype(np.float32)
img2_f = img2.astype(np.float32)

for i in range(0, m):
    proj = base_proj + d[i] * term2  # 3 x (H1*W1)
    with np.errstate(divide='ignore', invalid='ignore'):
        u = proj[0] / proj[2]
        v = proj[1] / proj[2]

    # Clamp to the valid index range of image 2 (the image being sampled),
    # then truncate to integer pixel indices. NaNs (0/0) are mapped to 0.
    u = np.clip(np.nan_to_num(u, nan=0.0), 0, W2 - 1).astype(np.intp).reshape(H1, W1)
    v = np.clip(np.nan_to_num(v, nan=0.0), 0, H2 - 1).astype(np.intp).reshape(H1, W1)

    # Sum of absolute RGB differences, scaled by 1/255.
    unary[:, :, i] = np.abs(img1_f - img2_f[v, u]).sum(axis=2) / 255




In [None]:
# Solve the grid labelling problem defined by the unary and pairwise costs.
# NOTE(review): n_iter=-1 presumably means "iterate until convergence" -
# confirm against the installed pygco version.
n_labels = pygco.cut_grid_graph_simple(unary, pairwise, n_iter=-1)
n_labels = n_labels.reshape(H1, W1)

# Map label k to the gray level k * 255 / m (truncated to uint8) and replicate
# it over three channels. A fresh array is built so that re-running this cell
# never overwrites another image that might share memory with newImg.
gray_levels = (n_labels.astype(np.float64) * 255 / m).astype(np.uint8)
newImg = np.repeat(gray_levels[:, :, np.newaxis], 3, axis=2)

plt.imshow(newImg)
plt.show()