In [6]:
import cv2
import numpy as np
import matplotlib.pyplot as plt

from scipy import sparse
from scipy.sparse import linalg

In [7]:
def get_neighbor_pixels(image, coord, T):
    """Collect the pixels in the (2T+1) x (2T+1) window around ``coord``.

    The center pixel itself is excluded, and positions falling outside the
    image are skipped (the window is clipped at the borders).

    Parameters
    ----------
    image : 2-D indexable (numpy array or list of rows)
        Image indexed as ``image[row][col]``.
    coord : tuple[int, int]
        ``(row, col)`` of the center pixel.
    T : int
        Window radius.

    Returns
    -------
    tuple[list[tuple[int, int]], list]
        Neighbor coordinates and their values, in the same order: row offset
        outer, column offset inner, both ascending.
    """
    x, y = coord
    n_rows = len(image)
    s_coord, s_value = [], []

    for i in range(-T, T + 1):
        nx = x + i
        # Explicit bounds check instead of relying on IndexError; negative
        # indices must be rejected explicitly because Python would wrap them.
        if not 0 <= nx < n_rows:
            continue
        row = image[nx]
        n_cols = len(row)
        for j in range(-T, T + 1):
            ny = y + j
            if (i == 0 and j == 0) or not 0 <= ny < n_cols:
                continue
            s_coord.append((nx, ny))
            s_value.append(row[ny])
    return s_coord, s_value


def weight_f(mean, var, r, s):  # option 2
    """Un-normalized affinity between center value ``r`` and neighbor ``s``.

    Computes ``1 + (r - mean) * (s - mean) / (var + 1e-6)``; the small
    constant keeps the division finite when the neighborhood variance is 0.
    Only arithmetic operators are used, so ``s`` (or any argument) may also
    be a numpy array, in which case the result is elementwise.
    """
    centered_r = r - mean
    centered_s = s - mean
    denominator = var + 1e-6
    return 1 + centered_r * centered_s / denominator


def get_weight(r, values):
    """Normalized affinities between a center pixel and its neighbors.

    Parameters
    ----------
    r : number
        Intensity of the center pixel.
    values : sequence of numbers
        Intensities of the neighboring pixels. The mean and variance passed to
        ``weight_f`` are computed over these neighbors only (``r`` excluded).

    Returns
    -------
    numpy.ndarray
        One weight per entry of ``values``, divided by their sum. Empty when
        ``values`` is empty.
    """
    neighbors = np.asarray(values, dtype=float)
    if neighbors.size == 0:
        # np.mean/np.var of an empty array warn and return nan; nothing to weight.
        return np.empty(0)
    n_mean, n_var = neighbors.mean(), neighbors.var()
    # weight_f is pure arithmetic, so one vectorized call replaces the
    # per-element Python loop.
    weight_neighbor = weight_f(n_mean, n_var, r, neighbors)
    return weight_neighbor / weight_neighbor.sum()


def get_neighbor_matrix(image, T):
    """Build the sparse pixel-affinity matrix W for a 2-D grayscale image.

    Pixels are numbered row-major (``index = row * width + col``). Row ``p``
    of W holds, at the columns of ``p``'s neighbors within radius ``T``, the
    weights returned by ``get_weight`` for that pixel; all other entries are 0.

    Parameters
    ----------
    image : numpy.ndarray, shape (height, width)
        Grayscale image.
    T : int
        Neighborhood radius passed to ``get_neighbor_pixels``.

    Returns
    -------
    scipy.sparse.lil_matrix, shape (height*width, height*width)
    """
    height, width = image.shape
    n_pixels = height * width  # e.g. 337,500 for the example image (per the original note)

    # Collect COO triplets per pixel and assemble once: element-wise
    # assignment into a lil_matrix is far too slow for tens of millions of
    # entries (T=5 gives up to 120 neighbors per pixel).
    rows, cols, data = [], [], []
    for i in range(height):
        for j in range(width):
            coords, values = get_neighbor_pixels(image, (i, j), T)
            if not coords:
                continue
            weights = get_weight(image[i][j], values)
            neighbor_idx = [x * width + y for x, y in coords]
            rows.append(np.full(len(neighbor_idx), i * width + j))
            cols.append(np.asarray(neighbor_idx))
            data.append(np.asarray(weights, dtype=float))

    if not data:
        return sparse.lil_matrix((n_pixels, n_pixels))

    neighborhood = sparse.coo_matrix(
        (np.concatenate(data), (np.concatenate(rows), np.concatenate(cols))),
        shape=(n_pixels, n_pixels),
    ).tocsr()
    # lil assignment never stored zeros; mirror that here.
    neighborhood.eliminate_zeros()
    return neighborhood.tolil()


def get_scribbles(scribbles_img):
    """Load the scribble image and flatten it to one integer label per pixel.

    The colour channels of each pixel are summed. Presumably the scribble file
    encodes class ``k`` as a pixel whose channel values add up to ``k``, and
    unlabelled pixels as black (sum 0) -- TODO confirm against how the
    scribbles were drawn.

    Parameters
    ----------
    scribbles_img : str
        Path to the scribble image.

    Returns
    -------
    numpy.ndarray
        1-D array of length ``height * width`` in row-major pixel order.

    Raises
    ------
    FileNotFoundError
        If OpenCV cannot read the file (``cv2.imread`` returns ``None``
        rather than raising).
    """
    scribbles = cv2.imread(scribbles_img)
    if scribbles is None:
        raise FileNotFoundError(f'could not read scribble image: {scribbles_img}')
    return np.sum(scribbles, axis=2).reshape(-1)

# I - w
def get_identity_weights(neighbor, scribbles_flat, height, width):
    """Build the system matrix ``I - W'`` without modifying ``neighbor``.

    ``W'`` is ``neighbor`` with every row belonging to a scribbled pixel
    (``scribbles_flat[i] != 0``) set to zero. Those rows of the result become
    identity rows, so solving ``(I - W') x = b`` pins ``x_i = b_i`` at
    scribbled pixels.

    Parameters
    ----------
    neighbor : scipy sparse matrix, shape (height*width, height*width)
        Affinity matrix W. It is left untouched; the original implementation
        zeroed its rows in place.
    scribbles_flat : array-like, length height*width
        Per-pixel scribble label (0 = unlabelled).
    height, width : int
        Image dimensions.

    Returns
    -------
    scipy.sparse.csr_matrix

    Raises
    ------
    ValueError
        If the number of scribble labels differs from the number of rows.
    """
    labels = np.asarray(scribbles_flat).reshape(-1)
    if labels.shape[0] != neighbor.shape[0]:
        raise ValueError(
            f'scribbles have {labels.shape[0]} pixels, matrix has {neighbor.shape[0]} rows'
        )
    # Left-multiplying by a 0/1 diagonal zeroes whole rows in one sparse op,
    # replacing a Python loop over every pixel.
    keep_rows = sparse.diags((labels == 0).astype(float))
    constrained = keep_rows @ sparse.csr_matrix(neighbor)
    return sparse.identity(height * width, format='csr') - constrained


def least_sq(i_minus_weight, scribbles_flat, h, w, n_classes=7):
    """Solve one sparse system per class and label each pixel by argmax.

    For every scribble label ``k`` in ``1..n_classes`` the right-hand side is
    the indicator vector of pixels scribbled with ``k``. ``lsqr`` finds
    ``x_k`` minimising ``||i_minus_weight @ x_k - b_k||``, and each pixel is
    assigned the class whose solution is largest there.

    Parameters
    ----------
    i_minus_weight : sparse matrix or ndarray, shape (h*w, h*w)
        System matrix.
    scribbles_flat : array-like of int, length h*w
        Scribble label per pixel (0 = unlabelled).
    h, w : int
        Image height and width, used to reshape the result.
    n_classes : int, default 7
        Number of classes; scribble labels are ``1..n_classes``.

    Returns
    -------
    numpy.ndarray, shape (h, w)
        0-based class index per pixel (index ``c`` corresponds to scribble
        label ``c + 1``).

    Raises
    ------
    ValueError
        If ``n_classes`` is less than 1.
    """
    if n_classes < 1:
        raise ValueError('n_classes must be at least 1')
    labels_flat = np.asarray(scribbles_flat).reshape(-1)
    halfway_label = (n_classes + 1) // 2

    print('lstsq start')
    solutions = []
    for label in range(1, n_classes + 1):
        rhs = np.where(labels_flat == label, 1, 0)
        solutions.append(linalg.lsqr(i_minus_weight, rhs)[0])
        if label == halfway_label:
            print('half way')
    print('lstsq finish')

    class_index = np.stack(solutions, axis=0).argmax(axis=0)
    return np.reshape(class_index, (h, w))


def get_ground_truth(gt_img):
    """Load the ground-truth label image as a 2-D array of per-pixel labels.

    Channel values are summed per pixel, matching how scribble labels are
    decoded. ``get_iou_score`` compares these values against ``1..7``.

    Parameters
    ----------
    gt_img : str
        Path to the ground-truth image.

    Returns
    -------
    numpy.ndarray, shape (height, width)

    Raises
    ------
    FileNotFoundError
        If OpenCV cannot read the file (``cv2.imread`` returns ``None``).
    """
    # The original passed cv2.COLOR_BGR2RGB, which is a cvtColor code, not an
    # imread flag; imread reinterpreted it as a different read mode that can
    # yield a 2-D array and break sum(axis=2). Channel order is irrelevant
    # here because the channels are summed, so the default colour read is
    # used instead.
    gt = cv2.imread(gt_img)
    if gt is None:
        raise FileNotFoundError(f'could not read ground-truth image: {gt_img}')
    return np.sum(gt, axis=2)


def get_iou_score(gt, spm, n_classes=7):
    """Per-class intersection-over-union between ground truth and prediction.

    Class ``i`` (0-based) corresponds to ground-truth label ``i + 1`` and to
    prediction value ``i``. Per-class statistics are printed, followed by the
    mean IoU.

    Parameters
    ----------
    gt : array-like
        Ground-truth labels (``1..n_classes``).
    spm : array-like, same shape as ``gt``
        Predicted 0-based class indices.
    n_classes : int, default 7
        Number of classes to score.

    Returns
    -------
    tuple[list, list, list]
        ``(intersections, unions, scores)``: boolean masks per class and the
        IoU per class in percent. A class absent from both maps has an empty
        union, so its score is ``nan``.
    """
    gt = np.asarray(gt)
    spm = np.asarray(spm)
    intersections, unions, scores = [], [], []

    for i in range(n_classes):
        gt_ind = gt == i + 1
        spm_ind = spm == i
        intersection = np.logical_and(gt_ind, spm_ind)
        union = np.logical_or(gt_ind, spm_ind)
        n_inter = int(np.sum(intersection))
        n_union = int(np.sum(union))
        # Guard the 0/0 case instead of letting numpy warn.
        score = n_inter * 100 / n_union if n_union else float('nan')

        intersections.append(intersection)
        unions.append(union)
        scores.append(score)

        print('class =', i)
        print('Intersection =', n_inter)
        print('Union =', n_union)
        print('IoU score =', score)
        print()

    # Average only over classes that actually occur in gt or spm.
    valid = [s for s in scores if not np.isnan(s)]
    miou = sum(valid) / len(valid) if valid else float('nan')
    print('mIoU = ', miou)
    return intersections, unions, scores


def _save_image(image, title, filename):
    """Draw one image in its own titled figure, save it to ``filename``, close it."""
    fig, ax = plt.subplots()
    ax.set_title(title)
    ax.imshow(image)
    fig.savefig(filename, facecolor='#eeeeee', edgecolor='blue', bbox_inches='tight')
    # Close so repeated calls neither stack images on one axes nor leak figures.
    plt.close(fig)


def make_plot(gt, spm, intersections, unions, classes=None):
    """Save the ground truth, the prediction, and per-class masks as PNG files.

    Parameters
    ----------
    gt : numpy.ndarray
        Ground-truth label image (class ``i`` is label ``i + 1``).
    spm : numpy.ndarray
        Predicted 0-based class indices.
    intersections, unions : sequence of numpy.ndarray
        Per-class boolean masks, e.g. from ``get_iou_score``.
    classes : list[str], optional
        Class names, in class-index order; also used as file-name prefixes.
        Defaults to the seven classes of the example image.

    Raises
    ------
    ValueError
        If fewer intersection/union masks are given than class names.
    """
    if classes is None:
        classes = ['Sky', 'Buildings', 'Tree', 'Hair', 'Skin', 'Phone', 'Clothes']
    if len(intersections) < len(classes) or len(unions) < len(classes):
        raise ValueError('need one intersection and one union mask per class')

    _save_image(gt, 'Multi-Label Ground Truth', 'ml-gt.png')
    _save_image(spm, 'Multi-Label Output', 'ml-output.png')

    for i, name in enumerate(classes):
        _save_image(np.where(gt == i + 1, 1, 0), f'{name} Ground Truth', f'{name}-gt.png')
        _save_image(np.where(spm == i, 1, 0), f'{name} Output', f'{name}-output.png')
        _save_image(intersections[i], f'{name} Intersection', f'{name}-intersection.png')
        _save_image(unions[i], f'{name} Union', f'{name}-union.png')

    print('Done')
    

def all_in_one(original_img, scribble_img, gt_img, T):
    """Run scribble-based segmentation end to end and evaluate it.

    Steps: read the grayscale image, build the neighbor affinity matrix with
    radius ``T``, constrain the scribbled pixels, solve per-class sparse
    systems, then score the result against the ground truth and save plots.

    Parameters
    ----------
    original_img, scribble_img, gt_img : str
        Paths to the grayscale input, the scribble image, and the ground truth.
    T : int
        Neighborhood radius.

    Returns
    -------
    tuple
        ``(spm, scores)``: the predicted class map and the per-class IoU
        scores. The original version returned ``None``.

    Raises
    ------
    FileNotFoundError
        If the input image cannot be read.
    ValueError
        If the scribble image does not have the same number of pixels.
    """
    img = cv2.imread(original_img, cv2.IMREAD_GRAYSCALE)
    if img is None:
        raise FileNotFoundError(f'could not read image: {original_img}')
    height, width = img.shape

    # Load and validate scribbles before the expensive matrix construction.
    scribbles_flat = get_scribbles(scribble_img)
    if scribbles_flat.size != height * width:
        raise ValueError(
            f'scribble image has {scribbles_flat.size} pixels, expected {height * width}'
        )

    neighbor = get_neighbor_matrix(img, T)
    i_minus_weight = get_identity_weights(neighbor, scribbles_flat, height, width)

    spm = least_sq(i_minus_weight, scribbles_flat, height, width)
    gt = get_ground_truth(gt_img)
    intersections, unions, scores = get_iou_score(gt, spm)
    make_plot(gt, spm, intersections, unions)
    return spm, scores

In [None]:
# Inputs for the "Emily in Paris" example: the grayscale source image, the
# colour-coded scribbles, and the ground-truth label image.
original = 'Emily-In-Paris-gray.png'
scribble = 'Emily-In-Paris-scribbles.png'
gt_img = 'Emily-In-Paris-gt-plus.png'

# T=5: each pixel is linked to the other pixels of its 11x11 window.
all_in_one(original_img=original, scribble_img=scribble, gt_img=gt_img, T=5);

lstsq start
