In [4]:
from scipy.sparse import lil_matrix
from scipy.sparse.linalg import lsqr
import enum

import numpy as np
import cv2 
import matplotlib.pyplot as plt

%matplotlib inline

class Directions(enum.Enum):
    """(row, col) offsets to the four 4-connected neighbours of a pixel.

    Members iterate in definition order: UP, LEFT, DOWN, RIGHT.
    """

    UP = -1, 0     # one row above
    LEFT = 0, -1   # one column to the left
    DOWN = 1, 0    # one row below
    RIGHT = 0, 1   # one column to the right


def generate_sparse_matrix(source, target, mask, alpha):
    """Build the sparse linear system for Poisson-blending one image channel.

    Every pixel of ``target`` is one unknown, indexed row-major as
    ``row * w + col``. Pixels with ``mask > 0`` get a 5-point Laplacian
    equation whose right-hand side mixes the source and target Laplacians.
    All other pixels get an identity row pinning them to their target value,
    which acts as the boundary condition for the masked region.

    Args:
        source: 2-D single-channel image supplying the guidance Laplacian;
            read at target coordinates, so it must be at least as large as
            ``target``.
        target: 2-D single-channel image being blended into; its height and
            width define the system size.
        mask: 2-D array indexed with target coordinates; pixels whose value
            is greater than zero are solved for.
        alpha: weight of the source Laplacian; ``1 - alpha`` weights the
            target Laplacian.

    Returns:
        ``(A, b)``: an ``N x N`` ``lil_matrix`` and a length-``N`` float64
        vector, where ``N = h * w`` of ``target``.
    """
    h, w = target.shape[:2]
    N = h * w

    # Promote to float64: with uint8 input (what cv2.imread returns) the
    # Laplacian's 4 * pixel and neighbour subtractions would wrap modulo 256.
    source = np.asarray(source, dtype=np.float64)
    target = np.asarray(target, dtype=np.float64)

    A = lil_matrix((N, N))
    b = np.zeros(N)

    for row in range(h):
        for col in range(w):
            idx = row * w + col

            if mask[row, col] > 0:
                src_grad = compute_gradient(source, row, col)
                tgt_grad = compute_gradient(target, row, col)
                b[idx] = alpha * src_grad + (1 - alpha) * tgt_grad

                # 5-point Laplacian row; out-of-image neighbours are skipped,
                # matching how compute_gradient treats them.
                A[idx, idx] = 4
                for direction in Directions:
                    d_row, d_col = direction.value
                    if 0 <= row + d_row < h and 0 <= col + d_col < w:
                        # Row-major index of the neighbour pixel.
                        A[idx, idx + w * d_row + d_col] = -1
            else:
                A[idx, idx] = 1
                b[idx] = target[row, col]

    return A, b

def compute_gradient(image, row, col):
    """Return the discrete 5-point Laplacian of ``image`` at ``(row, col)``.

    Computes ``4 * image[row, col]`` minus every in-bounds 4-connected
    neighbour listed in ``Directions``. Neighbours outside the image are
    skipped, i.e. treated as zero.

    The result is float64: a scalar for a 2-D image, or an array if
    ``image[row, col]`` is itself an array. Integer images such as uint8
    therefore cannot overflow.
    """
    h, w = image.shape[:2]
    # Promote before any arithmetic: in uint8, 4 * pixel and the subtractions
    # below would silently wrap modulo 256.
    grad = 4.0 * np.asarray(image[row, col], dtype=np.float64)

    for direction in Directions:
        n_row = row + direction.value[0]
        n_col = col + direction.value[1]
        if 0 <= n_row < h and 0 <= n_col < w:
            grad = grad - image[n_row, n_col]

    return grad

def blend_color_channels(source, target, mask, alpha):
    """Poisson-blend ``source`` into ``target`` inside ``mask``, channel by channel.

    Each channel's system is built by ``generate_sparse_matrix`` and solved
    with ``scipy.sparse.linalg.lsqr``. The solution is rounded and clipped to
    the 8-bit range [0, 255].

    Args:
        source: image supplying the guidance gradients, with the same
            (row, col[, channel]) layout as ``target``.
        target: H x W x C image, or an H x W single-channel image, to blend
            into.
        mask: H x W array; pixels with a value greater than zero are
            re-solved, all others are pinned to ``target``.
        alpha: weight of the source Laplacian; ``1 - alpha`` weights the
            target Laplacian.

    Returns:
        Array with ``target``'s shape and dtype holding the blended image.
    """
    if target.ndim == 2:
        # Treat a single-channel image as H x W x 1, then drop the axis again.
        return blend_color_channels(
            source[:, :, np.newaxis], target[:, :, np.newaxis], mask, alpha
        )[:, :, 0]

    h, w = target.shape[:2]
    blended_image = np.zeros_like(target)
    for channel in range(target.shape[2]):
        A, b = generate_sparse_matrix(
            source[:, :, channel], target[:, :, channel], mask, alpha
        )
        # LIL is a construction format; hand the iterative solver a CSR matrix
        # for its repeated matrix-vector products.
        solution = lsqr(A.tocsr(), b)[0]
        # lsqr is iterative, so pinned pixels come back as e.g. 199.9999.
        # A bare astype would truncate those to 199; round first, then clip.
        solution = np.clip(np.rint(solution), 0, 255).astype(np.uint8)
        blended_image[:, :, channel] = solution.reshape(h, w)
    return blended_image

SOURCE_PATH = 'blending/Ryan.jpg'
TARGET_PATH = 'blending/Matycha.jpg'
MASK_PATH = 'blending/Maska.jpg'
OUTPUT_PATH = 'blending/blended_image.jpg'

source_image = cv2.imread(SOURCE_PATH)
target_image = cv2.imread(TARGET_PATH)
# NOTE(review): JPEG is lossy, so this mask may contain faint non-zero values
# near its edges, and generate_sparse_matrix treats any value > 0 as inside.
# Consider a lossless (PNG) mask or thresholding if the seam looks ragged.
mask_image = cv2.imread(MASK_PATH, cv2.IMREAD_GRAYSCALE)

# cv2.imread reports a missing/unreadable file by returning None instead of
# raising, so check before starting the (slow) solve.
for path, image in (
    (SOURCE_PATH, source_image),
    (TARGET_PATH, target_image),
    (MASK_PATH, mask_image),
):
    if image is None:
        raise FileNotFoundError(f"Could not read image: {path}")

# Source and mask are indexed with target coordinates, so they must cover it.
target_h, target_w = target_image.shape[:2]
for path, image in ((SOURCE_PATH, source_image), (MASK_PATH, mask_image)):
    if image.shape[0] < target_h or image.shape[1] < target_w:
        raise ValueError(
            f"{path} is {image.shape[1]}x{image.shape[0]}, smaller than the "
            f"{target_w}x{target_h} target image"
        )

alpha = 0.83
blended_image = blend_color_channels(source_image, target_image, mask_image, alpha)

# To preview inline (OpenCV images are BGR; matplotlib expects RGB):
# fig, ax = plt.subplots(1, 1, figsize=(12, 8))
# ax.axis('off')
# ax.imshow(cv2.cvtColor(blended_image, cv2.COLOR_BGR2RGB))
# plt.show()
if not cv2.imwrite(OUTPUT_PATH, blended_image):
    raise OSError(f"Could not write blended image to {OUTPUT_PATH}")

True