# Bayes classifier algorithm

## Importing required libraries

In [1]:
import numpy as np
import cv2
import math

## Getting training samples

In [2]:
def get_samples(img_test, patch_size=50):
    """Cut background and object training patches out of an image.

    Background samples are the top-left and bottom-right
    ``patch_size`` x ``patch_size`` corners, placed side by side. The object
    sample is the ``2*patch_size`` x ``2*patch_size`` square centred on the
    image.

    Args:
        img_test: image array indexed as ``[row, col, ...]``.
        patch_size: side length of each background corner patch and half the
            side length of the central object patch. Defaults to 50, the
            value that was previously hard-coded.

    Returns:
        ``(img_back, img_obj)``. ``img_back`` has shape
        ``(patch_size, 2*patch_size, ...)`` and is a new array. ``img_obj``
        has shape ``(2*patch_size, 2*patch_size, ...)`` and is a view into
        ``img_test``.

    Raises:
        ValueError: if ``patch_size`` is not positive, or if the image is
            smaller than ``2*patch_size`` along either spatial axis. Without
            this check the centre slice would start at a negative index and
            silently wrap around.
    """
    if patch_size <= 0:
        raise ValueError(f"patch_size must be positive, got {patch_size}")
    rows, cols = img_test.shape[0], img_test.shape[1]
    if rows < 2 * patch_size or cols < 2 * patch_size:
        raise ValueError(
            f"image of size {rows}x{cols} is too small for patch_size={patch_size}"
        )

    top_left = img_test[:patch_size, :patch_size]
    bottom_right = img_test[rows - patch_size:, cols - patch_size:]
    img_back = np.concatenate((top_left, bottom_right), axis=1)

    # Floor division matches the original int(x / 2) for non-negative sizes.
    centre_row, centre_col = rows // 2, cols // 2
    img_obj = img_test[centre_row - patch_size:centre_row + patch_size,
                       centre_col - patch_size:centre_col + patch_size]
    return img_back, img_obj

## Computing means of object and background

In [3]:
def compute_means(img_back,img_obj):
    """Return the per-channel mean colour of the background and object samples.

    Each sample's pixel values are summed over both spatial axes and then
    divided by that sample's pixel count (rows * cols).

    Returns:
        ``(mean_back, mean_obj)``, each a ``(3, 1)`` column vector with one
        entry per channel. The samples must have exactly three channels;
        otherwise the reshape to ``(3, 1)`` fails.
    """
    def _column_mean(sample):
        # Reduce rows first, then columns, so the summation order matches
        # a row-then-column accumulation.
        channel_totals = sample.sum(axis=0).sum(axis=0)
        pixel_count = sample.shape[0] * sample.shape[1]
        return channel_totals.reshape(3, 1) / pixel_count

    background_mean = _column_mean(img_back)
    object_mean = _column_mean(img_obj)
    return background_mean, object_mean

## Computing the covariance matrices and their determinants

In [4]:
def _sample_covariance(img, mean):
    """Return the 3x3 covariance (divided by pixel count) of img's first three channels about mean."""
    pixels = img[:, :, 0:3].reshape(-1, 3).astype(np.float64)
    centred = pixels - np.asarray(mean, dtype=np.float64).reshape(1, 3)
    # Sum of outer products (x_k - mu)(x_k - mu)^T over all pixels, as one matmul.
    return centred.T @ centred / (img.shape[0] * img.shape[1])


def compute_covariance(mean_back,mean_obj,img_back,img_obj):
    """Compute the colour covariance of each class and its determinant.

    The covariance is the average of (x - mu)(x - mu)^T over every pixel of
    a sample. It is divided by that sample's own pixel count, not N - 1.

    Args:
        mean_back: 3-element (e.g. ``(3, 1)``) background mean colour.
        mean_obj: 3-element (e.g. ``(3, 1)``) object mean colour.
        img_back: background sample image ``[row, col, channel]``.
        img_obj: object sample image ``[row, col, channel]``.

    Returns:
        ``(covar_back, covar_obj, covar_det_back, covar_det_obj)``: two 3x3
        matrices followed by their determinants.
    """
    covar_back = _sample_covariance(img_back, mean_back)
    # Bug fix: the object covariance is normalised by the object's own pixel
    # count (it was previously divided by the background sample's size).
    covar_obj = _sample_covariance(img_obj, mean_obj)
    covar_det_back = np.linalg.det(covar_back)
    covar_det_obj = np.linalg.det(covar_obj)
    return covar_back, covar_obj, covar_det_back, covar_det_obj

## Classifying each pixel in the image as object or background

In [5]:
def _gaussian_log_likelihood(pixels, mean, covar, covar_det):
    """Return the 3-D Gaussian log-density of each row of the (N, 3) array pixels."""
    if not covar_det > 0 and not np.isnan(covar_det):
        raise ValueError(f"covariance determinant must be positive, got {covar_det}")
    diff = pixels - np.asarray(mean, dtype=np.float64).reshape(1, 3)
    # The inverse is the same for every pixel: compute it once, not per pixel.
    covar_inv = np.linalg.inv(covar)
    # Row-wise Mahalanobis term diff_k^T @ covar_inv @ diff_k.
    mahalanobis = np.sum((diff @ covar_inv) * diff, axis=1)
    return (-(3 / 2) * math.log(2 * math.pi)
            - 0.5 * math.log(covar_det)
            - 0.5 * mahalanobis)


def compute_pixels_classification(img_test,mean_back,mean_obj,covar_back,covar_obj,covar_det_back,covar_det_obj):
    """Label every pixel as object or background with a Gaussian Bayes rule.

    Each class is modelled as a 3-D Gaussian. A pixel is assigned to the
    object class only when its object log-likelihood is strictly greater
    than its background log-likelihood, so ties go to background. No class
    priors are applied.

    Args:
        img_test: image of shape ``(rows, cols, 3)``; it is not modified.
        mean_back, mean_obj: 3-element class mean colours (e.g. ``(3, 1)``).
        covar_back, covar_obj: 3x3 class covariance matrices.
        covar_det_back, covar_det_obj: determinants of those matrices.

    Returns:
        A copy of ``img_test`` in which each pixel is replaced by the mean
        colour of its assigned class. The mean is truncated with
        ``astype(int)`` and stored in ``img_test``'s dtype.

    Raises:
        ValueError: if ``img_test`` is not a 3-channel image, or if a
            determinant is not positive (its logarithm is undefined).
    """
    if img_test.ndim != 3 or img_test.shape[2] != 3:
        raise ValueError(f"expected a (rows, cols, 3) image, got shape {img_test.shape}")
    rows, cols = img_test.shape[0], img_test.shape[1]
    pixels = img_test.reshape(-1, 3).astype(np.float64)

    log_prob_back = _gaussian_log_likelihood(pixels, mean_back, covar_back, covar_det_back)
    log_prob_obj = _gaussian_log_likelihood(pixels, mean_obj, covar_obj, covar_det_obj)
    is_obj = (log_prob_obj > log_prob_back).reshape(rows, cols)

    img_test_result = img_test.copy()
    img_test_result[is_obj] = np.asarray(mean_obj).reshape(3).astype(int)
    img_test_result[~is_obj] = np.asarray(mean_back).reshape(3).astype(int)
    return img_test_result

## Model integration

In [6]:
import time

# End-to-end run: load one photo, learn background/object colour Gaussians
# from fixed patches, classify every pixel, then display and save the result.
image_path = 'F:/images/gym/20210917_184222.jpg'
date = time.strftime('%m-%d-%y-%H_%M_%S')
img_test = cv2.imread(image_path)
if img_test is None:
    # cv2.imread reports a missing/unreadable file by returning None rather
    # than raising; fail here with a clear message instead of inside resize.
    raise FileNotFoundError(f"could not read image {image_path!r}")
img_test = cv2.resize(img_test,(480,480))
img_back,img_obj = get_samples(img_test)
mean_back,mean_obj = compute_means(img_back,img_obj)
covar_back,covar_obj,covar_det_back,covar_det_obj = compute_covariance(mean_back,mean_obj,img_back,img_obj)
img_test_result = compute_pixels_classification(img_test,mean_back,mean_obj,covar_back,covar_obj,covar_det_back,covar_det_obj)
cv2.imshow("img_test", img_test)
cv2.imshow("img_test_result", img_test_result)
# NOTE(review): "F:" + name (no separator) is a drive-relative path on
# Windows, i.e. relative to drive F's current directory. "F:/" may have been
# intended; confirm before changing, since it moves the output file.
path = "F:" + date + ".jpg"
print(path)
# cv2.imwrite returns False instead of raising when it cannot write.
saved = cv2.imwrite(path,img_test_result)
cv2.waitKey()
cv2.destroyAllWindows()
if not saved:
    raise OSError(f"failed to write result image to {path!r}")

F:01-14-22-02_18_53.jpg


In [9]:
# Display the shape tuple of the test image (rows, cols, channels).
# NOTE(review): the recorded output (2592, 1944, 3) does not match the
# 480x480 cv2.resize in the pipeline cell. It was presumably produced from a
# different kernel state; re-run the notebook top to bottom to confirm.
img_test.shape

(2592, 1944, 3)