In [9]:
import numpy as np
import matplotlib
import cv2
from typing import Tuple,Any,List
matplotlib.use('Agg')
from scipy.stats import multivariate_normal as mv_normal
import matplotlib.pyplot as plt
import time
from tqdm import tqdm

EPS = 0.1

def get_neighbors(matrix, i, j):
    neighbors = []
    if j < matrix.shape[1] - 1:
        neighbors.append(matrix[i, j+1])
    if i < matrix.shape[0] - 1:
        neighbors.append(matrix[i+1, j])
    if i > 0:
        neighbors.append(matrix[i-1, j])
    if j > 0:
        neighbors.append(matrix[i, j-1])
    
    return neighbors

def compute_obervation(label_matrix, image, mean, denominator,cov_inv):
    new_label_matrix = np.zeros((image.shape[0],image.shape[1]))
    for i in range(image.shape[0]-1):
        for j in range(image.shape[1]-1):
            neighbors = get_neighbors(label_matrix,i,j)
            factor = 1
            for neighbor in neighbors:
                if neighbor == label_matrix[i,j]:
                    factor*=EPS
                else: factor*=(1-EPS)
            x_m_1 = image[i,j] - mean[0]
            numerator_1 = np.exp(-(np.dot(np.dot(x_m_1.T,cov_inv[0]),x_m_1)) / 2)
            p1 = ((numerator_1/denominator[0])/100)*factor
            x_m_2 = image[i,j] - mean[1]
            numerator_2 = np.exp(-(np.dot(np.dot(x_m_2.T,cov_inv[1]),x_m_2)) / 2)
            p2 = ((numerator_2/denominator[1])/100)*factor
            p1 = p1/(p1+p2)
            p2 = p2/(p1+p2)
            if p1>=p2:
                new_label_matrix[i,j] = 0
            else:
                new_label_matrix[i,j] = 1
    return new_label_matrix

def EM_algo(img):
    p_0 = 0.5
    p_1 = 0.5
    mean_0 = np.random.randint(1,255,size=3)
    mean_1 = np.random.randint(1,255,size=3)
    cov_0 = np.random.randint(-1000,1000,size=(3,3))
    cov_0 = cov_0.T @ cov_0
    cov_1 = np.random.randint(-1000,1000,size=(3,3))
    cov_1 = cov_1.T @ cov_1
    for _ in tqdm(range(100)):
        dens_arr_0 = mv_normal.pdf(img, mean_0, cov_0)
        dens_arr_1 = mv_normal.pdf(img, mean_1, cov_1)

        alphas_0 = p_0*dens_arr_0
        alphas_1 = p_1*dens_arr_1
        sum_alphas = alphas_0 + alphas_1
        alphas_0 = alphas_0/sum_alphas
        alphas_1 = alphas_1/sum_alphas
        
        p_0 = alphas_0.mean()
        p_1 = alphas_1.mean()
        
        alphas_0_temp = np.zeros(img.shape)
        alphas_1_temp = np.zeros(img.shape)
        for s in [0,1,2]:
            alphas_0_temp[:,:,s] = alphas_0
            alphas_1_temp[:,:,s] = alphas_1
        mean_0 = ( (alphas_0_temp * img).sum( axis=(0,1) ))/(alphas_0.sum())
        mean_1 = ( (alphas_1_temp * img).sum( axis=(0,1) ))/(alphas_1.sum())
        
        alphas_0_temp = np.sqrt(alphas_0_temp)
        numerator0 = ((alphas_0_temp)*(img - mean_0)).reshape(3,-1)
        numerator0 = numerator0 @ numerator0.T
        cov_0 = numerator0/(alphas_0.sum())
        alphas_1_temp = np.sqrt(alphas_1_temp)
        numerator1 = ((alphas_1_temp)*(img - mean_1)).reshape(3,-1)
        numerator1 = numerator1 @ numerator1.T
        cov_1 = numerator1/(alphas_1.sum())

    return mean_0.astype(int),mean_1.astype(int),cov_0.astype(int),cov_1.astype(int)


def gibbs_sampler(image : np.ndarray,iterations: int = 20) -> Tuple[np.ndarray,List[np.ndarray]]:
    m_1,m_2,cov_1,cov_2 = EM_algo(image)
    cov_inverse_1 = np.linalg.inv(cov_1)
    cov_inverse_2 = np.linalg.inv(cov_2) 
    cov_det_1 = np.linalg.det(cov_1)
    cov_det_2 = np.linalg.det(cov_2)
    mean = [m_1,m_2]
    cov_inv = [cov_inverse_1,cov_inverse_2]
    cov_det = [cov_det_1,cov_det_2]
    denominator = [np.sqrt(((2 * np.pi)**3) * cov_det[0]),np.sqrt(((2 * np.pi)**3) * cov_det[1])]
    label_matrix = np.zeros((image.shape[0],image.shape[1]))
    for i in range(image.shape[0]-1):
        for j in range(image.shape[1]-1):
            x_m_1 = image[i,j] - mean[0]
            numerator_1 = np.exp(-(np.dot(np.dot(x_m_1.T,cov_inv[0]),x_m_1)) / 2)
            p1 = ((numerator_1/denominator[0])/100)
            x_m_2 = image[i,j] - mean[1]
            numerator_2 = np.exp(-(np.dot(np.dot(x_m_2.T,cov_inv[1]),x_m_2)) / 2)
            p2 = ((numerator_2/denominator[1])/100)
            if p1>=p2:
                label_matrix[i,j] = 0
            else:
                label_matrix[i,j] = 1
    results = []
    for _ in tqdm(range(iterations)):
        label_matrix = compute_obervation(label_matrix, image, mean, denominator, cov_inv)
        results.append(label_matrix)
    return label_matrix,results

if __name__ == '__main__':
    image_path = 'D:/lab 3/test.jpg'
    image = plt.imread(image_path)
    image = cv2.resize(image,(256,256))
    start = time.time()
    seg_image, results = gibbs_sampler(image)
    stacked_img = np.stack((seg_image,)*3, axis=-1)
    stacked_img[:,:,0]*=255
    cv2.imwrite("D:/lab 3/result.jpg",stacked_img)
    end = time.time()
    print("TIME = ",end-start)
    for i,mask in enumerate(results):
        stacked_img = np.stack((mask,)*3, axis=-1)*255
        cv2.imwrite(f'D:/lab 3/result/image_{i}.jpg',stacked_img)

100%|████████████████████████████████████████████████████████████████████████████████| 100/100 [00:01<00:00, 64.96it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:14<00:00,  1.35it/s]

TIME =  16.963226795196533



