# In [None]:
import math
import random

from PIL import Image
import numpy as np
from scipy.stats import norm

# Grey level painted into the saved segmentation image for each class index.
class_colors = [
    0,
    127,
    255,
]
# Per-class images (joined to base_path by fit_class_dists) from which each
# class's feature distribution and pixel set are fitted.
class_original_image_filenames = [
    "test2-c1.png",
    "test2-c2.png",
    "test2-c3.png",
]
# For each class image, the feature value that marks a pixel as NOT belonging
# to that class; every other pixel is taken as a class member.
other_pixels_colors = [0, 0, 0]


def rgb2hsv(r, g, b):
    """Convert an 8-bit RGB triple to HSV.

    Returns ``(h, s, v)``: hue in degrees in ``[0, 360)`` and saturation and
    value in ``[0, 1]``. Greys (all channels equal) get hue ``0`` and black
    gets saturation ``0``.
    """
    red, green, blue = (channel / 255.0 for channel in (r, g, b))
    high = max(red, green, blue)
    low = min(red, green, blue)
    spread = high - low

    # The hue sector is chosen by whichever channel is largest.
    if high == low:
        hue = 0
    elif high == red:
        hue = (60 * ((green - blue) / spread) + 360) % 360
    elif high == green:
        hue = (60 * ((blue - red) / spread) + 120) % 360
    elif high == blue:
        hue = (60 * ((red - green) / spread) + 240) % 360

    saturation = 0 if high == 0 else spread / high
    return hue, saturation, high


class CustomImageSegmentationMRF:
    """Simulated-annealing segmenter for a Markov random field over pixel labels.

    ``state[i][j]`` is the class index of pixel ``(i, j)``. Each ``move``
    proposes a different label for one random, non-fixed pixel and accepts it
    with the Metropolis rule, scoring pixels with ``get_pix_total_energy``
    (per-feature Gaussian energies plus a label-agreement smoothness term).

    Several energy helpers that take ``self`` (``get_total_energy``,
    ``get_pix_total_energy``, ``get_pix_energy_just_for_*``, ...) are defined
    at module level after this class; ``__getattr__`` binds them so they can
    be called as ``self.<name>(...)``.
    """

    # Feature codes passed to fit_class_dists; initialize() stores the fitted
    # per-class parameters as `<code>_class_means` / `<code>_class_stds`
    # (in_, hu_, sa_, va_, te_, r_, g_, b_).
    _FEATURES = ('in', 'hu', 'sa', 'va', 'te', 'r', 'g', 'b')

    # Module-level functions (defined below this class) that take an instance
    # as their first argument and are looked up as `self.<name>`.
    _MODULE_LEVEL_HELPERS = frozenset((
        'get_neighbors_sum_energy_for_total_energy_computing',
        'get_total_energy',
        'get_pix_energy_just_for_intensity',
        'get_pix_energy_just_for_texture',
        'get_pix_energy_just_for_hue',
        'get_pix_energy_just_for_saturation',
        'get_pix_energy_just_for_value',
        'get_pix_energy_just_for_red',
        'get_pix_energy_just_for_green',
        'get_pix_energy_just_for_black',
        'get_pix_total_energy',
    ))

    # Neighbour offsets (di, dj): the 4-neighbourhood, plus the diagonals that
    # are added for an 8-neighbourhood.
    _AXIAL_OFFSETS = ((-1, 0), (0, -1), (1, 0), (0, 1))
    _DIAGONAL_OFFSETS = ((-1, -1), (-1, 1), (1, -1), (1, 1))

    def __getattr__(self, name):
        """Bind a whitelisted module-level energy helper to this instance.

        Python only calls this after normal attribute lookup fails, so real
        methods and instance attributes always take precedence.
        """
        if name in CustomImageSegmentationMRF._MODULE_LEVEL_HELPERS:
            helper = globals().get(name)
            if helper is not None:
                # Plain functions are descriptors: __get__ yields a bound method.
                return helper.__get__(self, type(self))
        raise AttributeError("%r object has no attribute %r" % (type(self).__name__, name))

    def initialize(self, test_image_filepath):
        """Load the test image and fit every per-class feature distribution.

        Sets ``im``/``pixels`` (grey-level 'LA' view), ``colored_pixels``
        (image as stored), ``hue_pixels`` (hue grid indexed ``[x][y]``) and
        ``<feature>_class_means`` / ``<feature>_class_stds`` for every code
        in ``_FEATURES``.
        """
        self.im = Image.open(test_image_filepath).convert('LA')
        self.pixels = self.im.load()
        colored_im = Image.open(test_image_filepath)
        self.colored_pixels = colored_im.load()

        width, height = colored_im.size
        # Hue of every pixel, starting from column 0.
        self.hue_pixels = [
            [rgb2hsv(self.colored_pixels[i, j][0], self.colored_pixels[i, j][1],
                     self.colored_pixels[i, j][2])[0]
             for j in range(height)]
            for i in range(width)
        ]

        for feature in self._FEATURES:
            _, class_means, class_stds = fit_class_dists(class_original_image_filenames,
                                                         test_image_filepath,
                                                         other_pixels_colors, feature)
            setattr(self, feature + '_class_means', class_means)
            setattr(self, feature + '_class_stds', class_stds)
            if feature == 'te':
                print(self.te_class_means, self.te_class_stds)

    def start(self, repeat_count, state, test_image_filepath, fixed_pixel_indexes, beta=5,
              t=100, neighbor_count=4, t_ratio=0.97):
        """Run simulated annealing and return the final label grid.

        Parameters
        ----------
        repeat_count :
            Number of ``move()`` steps to run.
        state :
            Initial label grid indexed ``state[i][j]``; modified in place and
            also returned.
        test_image_filepath :
            Image to segment; passed to ``initialize``.
        fixed_pixel_indexes :
            Collection of ``(i, j)`` pixels that are never relabelled.
        beta :
            Weight of the pairwise smoothness term.
        t :
            Initial temperature.
        neighbor_count :
            8 adds the diagonal neighbours to the smoothness term; any other
            value uses the 4-neighbourhood.
        t_ratio :
            Factor the temperature is multiplied by every 300 moves
            (starting with the first move).
        """
        self.initialize(test_image_filepath)

        self.t_ratio = t_ratio
        self.neighbor_count = neighbor_count
        self.fixed_pixel_indexes = fixed_pixel_indexes
        self.beta = beta
        self.state = state
        # Also sets self.i_max / self.j_max used by the neighbour energies.
        self.e = self.get_total_energy()
        self.t = t
        self.move_count = 0
        for _ in range(repeat_count):
            self.move()
            self.move_count += 1

        return self.state

    def move(self):
        """Perform one Metropolis step on a random non-fixed pixel."""
        if self.move_count % 300 == 0:
            self.t = self.t * self.t_ratio

        rows = len(self.state)
        cols = len(self.state[0])
        i = random.randint(0, rows - 1)
        j = random.randint(0, cols - 1)
        while (i, j) in self.fixed_pixel_indexes:
            i = random.randint(0, rows - 1)
            j = random.randint(0, cols - 1)

        class_count = len(self.in_class_means)
        new_label = random.randint(0, class_count - 1)
        while new_label == self.state[i][j]:
            new_label = random.randint(0, class_count - 1)

        old_label = self.state[i][j]
        # Only pixel (i, j)'s unary term and its neighbour pairs change, so the
        # new total is the old total minus/plus this pixel's local energy.
        new_e = self.e - self.get_pix_total_energy(i, j)
        self.state[i][j] = new_label
        new_e += self.get_pix_total_energy(i, j)

        # Metropolis rule: always accept downhill moves; accept uphill moves
        # with probability exp(-delta / t). uniform() is only drawn for
        # non-improving moves.
        if new_e < self.e or random.uniform(0, 1) <= math.exp(-(new_e - self.e) / float(self.t)):
            self.e = new_e
        else:
            self.state[i][j] = old_label

    def get_neighbors_sum_energy(self, i, j, neighbor_cnt):
        """Pairwise (doubleton) smoothness energy of pixel (i, j).

        Each in-bounds neighbour with the same label contributes -1 and each
        with a different label +1; the sum is scaled by ``self.beta``. This
        favours similar labels at neighbouring pixels. ``neighbor_cnt == 8``
        adds the diagonal neighbours to the 4-neighbourhood. Requires
        ``self.i_max`` / ``self.j_max`` (set by ``get_total_energy``).
        """
        offsets = self._AXIAL_OFFSETS
        if neighbor_cnt == 8:
            offsets = offsets + self._DIAGONAL_OFFSETS

        label = self.state[i][j]
        e = 0
        for di, dj in offsets:
            ni, nj = i + di, j + dj
            if 0 <= ni <= self.i_max and 0 <= nj <= self.j_max:
                e += -1 if self.state[ni][nj] == label else 1
        return e * self.beta


def get_neighbors_sum_energy_for_total_energy_computing(self, i, j, neighbor_count):
    """Pairwise smoothness energy of pixel (i, j), visiting each pair once.

    Only the neighbours (i-1, j) and (i, j-1) are compared, plus (i-1, j-1)
    and (i-1, j+1) when ``neighbor_count == 8``. Summing this over every
    pixel therefore counts each neighbouring pair exactly once. An agreeing
    neighbour contributes -1 and a disagreeing one +1; the total is scaled by
    ``self.beta``. Reads ``self.state``, ``self.beta`` and (8-neighbourhood
    only) ``self.j_max``.
    """
    neighbours = []
    if i != 0:
        neighbours.append(self.state[i - 1][j])
    if j != 0:
        neighbours.append(self.state[i][j - 1])
    if neighbor_count == 8:
        if i != 0 and j != 0:
            neighbours.append(self.state[i - 1][j - 1])
        if i != 0 and j != self.j_max:
            neighbours.append(self.state[i - 1][j + 1])

    agreeing = sum(1 for other in neighbours if self.state[i][j] == other)
    disagreeing = len(neighbours) - agreeing
    return (disagreeing - agreeing) * self.beta


def get_total_energy(self):
    """Record the label grid's last indexes and return a baseline energy.

    Sets ``self.i_max`` (last row index of ``self.state``) and ``self.j_max``
    (last column index), which the neighbour-energy helpers use for bounds
    checks. It then returns the fixed placeholder ``100000`` rather than a true
    total. The annealer only ever compares the running energy against
    updates of itself, so any constant works as a starting point.

    An exact total could be obtained by summing, over every pixel, a unary
    term plus ``get_neighbors_sum_energy_for_total_energy_computing`` (which
    counts each neighbouring pair once).
    """
    self.i_max, self.j_max = len(self.state) - 1, len(self.state[0]) - 1
    return 100000


def _class_gaussian_energy(class_means, class_stds, label, value):
    """Return -log N(value; mean, std**2) for class ``label``.

    Computed as ``log(sqrt(2*pi*std**2)) + (value - mean)**2 / (2*std**2)``;
    smaller results mean ``value`` is more typical of the class.
    """
    class_std = class_stds[label]
    class_mean = class_means[label]
    return math.log(get_pistd_term(class_std)) + get_exp_term(class_mean, class_std, value)


def _colored_pixel_hsv(self, i, j):
    """Return ``rgb2hsv`` of the first three channels of colour pixel (i, j)."""
    colored_pixel = self.colored_pixels[i, j]
    return rgb2hsv(colored_pixel[0], colored_pixel[1], colored_pixel[2])


# The functions below take a CustomImageSegmentationMRF instance as `self`
# (after `initialize`); each scores pixel (i, j) under its current label
# `self.state[i][j]` with that class's fitted normal for one feature.

def get_pix_energy_just_for_intensity(self, i, j):
    """Unary energy of pixel (i, j)'s grey level (``self.pixels`` channel 0)."""
    return _class_gaussian_energy(self.in_class_means, self.in_class_stds,
                                  self.state[i][j], self.pixels[i, j][0])


def get_pix_energy_just_for_texture(self, i, j):
    """Unary energy of pixel (i, j)'s 8-neighbour texture count (``get_texture``)."""
    texture = get_texture(self.pixels, self.i_max, self.j_max, i, j, 8)
    return _class_gaussian_energy(self.te_class_means, self.te_class_stds,
                                  self.state[i][j], texture)


def get_pix_energy_just_for_hue(self, i, j):
    """Unary energy of pixel (i, j)'s hue in degrees."""
    hue, _, _ = _colored_pixel_hsv(self, i, j)
    return _class_gaussian_energy(self.hu_class_means, self.hu_class_stds,
                                  self.state[i][j], hue)


def get_pix_energy_just_for_saturation(self, i, j):
    """Unary energy of pixel (i, j)'s HSV saturation."""
    _, saturation, _ = _colored_pixel_hsv(self, i, j)
    return _class_gaussian_energy(self.sa_class_means, self.sa_class_stds,
                                  self.state[i][j], saturation)


def get_pix_energy_just_for_value(self, i, j):
    """Unary energy of pixel (i, j)'s HSV value."""
    _, _, value = _colored_pixel_hsv(self, i, j)
    return _class_gaussian_energy(self.va_class_means, self.va_class_stds,
                                  self.state[i][j], value)


def get_pix_energy_just_for_red(self, i, j):
    """Unary energy of pixel (i, j)'s red channel."""
    return _class_gaussian_energy(self.r_class_means, self.r_class_stds,
                                  self.state[i][j], self.colored_pixels[i, j][0])


def get_pix_energy_just_for_green(self, i, j):
    """Unary energy of pixel (i, j)'s green channel."""
    return _class_gaussian_energy(self.g_class_means, self.g_class_stds,
                                  self.state[i][j], self.colored_pixels[i, j][1])


def get_pix_energy_just_for_black(self, i, j):
    """Unary energy of pixel (i, j)'s BLUE channel (index 2, the 'b' feature).

    The name is historical; it is kept for existing callers.
    """
    return _class_gaussian_energy(self.b_class_means, self.b_class_stds,
                                  self.state[i][j], self.colored_pixels[i, j][2])


def get_pix_total_energy(self, i, j):
    """Local energy of pixel (i, j) under its current label.

    Weighted sum of the hue (x3), saturation, value and grey-level unary
    energies plus the pairwise smoothness term of the pixel. The smoothness
    term covers every neighbour pair of the pixel, so the change in this value
    when only this pixel's label changes equals the change in the total
    energy. ``CustomImageSegmentationMRF.move`` relies on that.

    Other feature mixes have been tried: intensity only, hue only,
    intensity+hue, and adding ``get_pix_energy_just_for_texture`` to the
    intensity/hue(/saturation/value) mixes.

    ``self`` is a CustomImageSegmentationMRF instance after ``initialize`` and
    ``get_total_energy`` have run.
    """
    # The per-feature energies are module-level functions that take `self`
    # explicitly, so call them that way; get_neighbors_sum_energy is a real
    # method of the class. Summation order matches the original expression.
    return (3 * get_pix_energy_just_for_hue(self, i, j)
            + get_pix_energy_just_for_saturation(self, i, j)
            + get_pix_energy_just_for_value(self, i, j)
            + self.get_neighbors_sum_energy(i, j, self.neighbor_count)
            + get_pix_energy_just_for_intensity(self, i, j))


def _count_contrasting_neighbors(value_at, i_max, j_max, i, j, neighbor_count, threshold):
    """Count neighbours of (i, j) whose value differs from it by more than ``threshold``.

    ``value_at(x, y)`` returns the scalar at pixel (x, y). Neighbours outside
    ``0..i_max`` x ``0..j_max`` are skipped. ``neighbor_count == 8`` adds the
    four diagonal neighbours to the 4-neighbourhood.
    """
    offsets = [(-1, 0), (0, -1), (1, 0), (0, 1)]
    if neighbor_count == 8:
        offsets += [(-1, -1), (-1, 1), (1, -1), (1, 1)]

    centre = value_at(i, j)
    count = 0
    for di, dj in offsets:
        ni, nj = i + di, j + dj
        if 0 <= ni <= i_max and 0 <= nj <= j_max and abs(centre - value_at(ni, nj)) > threshold:
            count += 1
    return count


def get_texture(orig_pixels, width, height, i, j, neighbor_count=8, threshold=65):
    """Texture feature: number of neighbours with a sharp change in channel 0.

    Parameters
    ----------
    orig_pixels :
        Pixel access object indexed ``[x, y]`` whose first channel is compared
        (e.g. ``load()`` of an 'LA' image).
    width, height :
        Despite the names, the LAST valid x / y indexes (callers pass
        ``size - 1`` or ``i_max`` / ``j_max``); used only for bounds checks.
    i, j :
        Coordinates of the pixel being described.
    neighbor_count :
        8 for the full neighbourhood; any other value uses 4 neighbours.
    threshold :
        Absolute difference that must be exceeded for a neighbour to count.
    """
    return _count_contrasting_neighbors(lambda x, y: orig_pixels[x, y][0],
                                        width, height, i, j, neighbor_count, threshold)


def get_texture2(orig_pixels, width, height, i, j, neighbor_count=8, threshold=50):
    """Like ``get_texture`` but for a 2-D grid of scalars indexed ``[x][y]``.

    Intended for grids such as the per-pixel hue grid. ``width`` / ``height``
    are likewise the last valid indexes.
    """
    return _count_contrasting_neighbors(lambda x, y: orig_pixels[x][y],
                                        width, height, i, j, neighbor_count, threshold)


def fit_normal_dist_per_class(orgiginal_image_filepath, train_image_filepath, other_pixels_color, feature):
    """Fit a normal distribution to one feature over the pixels of a class image.

    Every pixel of ``orgiginal_image_filepath`` whose feature value differs
    from ``other_pixels_color`` is treated as a member of the class.

    Parameters
    ----------
    orgiginal_image_filepath :
        Path of the per-class image (presumably the class region with all
        other pixels painted ``other_pixels_color`` — confirm for new data).
    train_image_filepath :
        Not read; kept for backward compatibility.
    other_pixels_color :
        Feature value marking pixels that are NOT part of the class.
    feature :
        'in' (grey level), 'hu' / 'sa' / 'va' (HSV components), 'te'
        (``get_texture`` count on the grey image), or 'r' / 'g' / 'b'
        (colour channels).

    Returns
    -------
    ``(mean, std, pixels)``: the ``scipy.stats.norm.fit`` estimates and the
    set of ``(x, y)`` class pixels.

    Raises
    ------
    ValueError
        If ``feature`` is not one of the codes above.
    """
    channel_index = {'in': 0, 'r': 0, 'g': 1, 'b': 2}
    hsv_index = {'hu': 0, 'sa': 1, 'va': 2}
    if feature != 'te' and feature not in channel_index and feature not in hsv_index:
        raise ValueError("unknown feature: %r" % (feature,))

    image_per_label_pixels = set()
    data = []
    with Image.open(orgiginal_image_filepath) as source_im:
        # Intensity and texture are measured on the grey-level ('LA') image;
        # every other feature reads the image in its stored mode.
        orig_im = source_im.convert('LA') if feature in ('in', 'te') else source_im
        orig_pixels = orig_im.load()
        width, height = orig_im.size

        # Scan every pixel, column by column, starting from column 0.
        for i in range(width):
            for j in range(height):
                if feature == 'te':
                    pix = get_texture(orig_pixels, width - 1, height - 1, i, j, 8)
                elif feature in channel_index:
                    pix = orig_pixels[i, j][channel_index[feature]]
                else:
                    rgb = orig_pixels[i, j]
                    pix = rgb2hsv(rgb[0], rgb[1], rgb[2])[hsv_index[feature]]

                # Pixels showing the background value are not part of the class.
                if pix != other_pixels_color:
                    image_per_label_pixels.add((i, j))
                    data.append(pix)

    mean, std = norm.fit(data)
    return mean, std, image_per_label_pixels


def get_gaussian_naive_bayes_prob_per_class(mean, std, val):
    """Return the normal density N(val; mean, std**2)."""
    exp_term = get_exp_term(mean, std, val)
    pistd_term = get_pistd_term(std)
    return math.exp(-exp_term) / pistd_term


def get_pistd_term(std):
    """Return the normal normaliser ``sqrt(2 * pi * std**2)``."""
    return math.sqrt(2 * math.pi * math.pow(std, 2))


def get_exp_term(mean, std, val):
    """Return the normal exponent magnitude ``(val - mean)**2 / (2 * std**2)``."""
    return math.pow(val - mean, 2) / (2 * math.pow(std, 2))


def report_segment_image_mrf(base_filepath, test_image_filename, labeled_image_filenme, noisy_train_image_filename,
                             class_original_image_filenames, other_pixels_colors, beta=5, t=100, neighbor_count=4,
                             true_label_prob=0, t_ratio=0.97, repeat_count=500000):
    """Segment an image with the MRF annealer, save the result and print accuracy.

    1. Fits per-class grey-level normals and per-class pixel sets
       (``fit_class_dists`` with feature 'in').
    2. Builds the initial label grid. With probability ``1 - true_label_prob``
       a pixel gets a uniformly random label. Otherwise it takes the index
       of the first class pixel set containing it, or keeps its
       ``segment_image`` label if no set does.
    3. Runs ``CustomImageSegmentationMRF.start`` for ``repeat_count`` moves
       on ``base_filepath + test_image_filename``.
    4. Paints each pixel with ``class_colors[label]`` (alpha copied from the
       test image) and saves the result to
       ``base_filepath + labeled_image_filenme`` in PNG format. It then prints
       the fraction of pixels whose final label's pixel set contains them.
    """
    true_label_in_init_count = 0
    all_classes_image_per_label_pixels, class_means, class_stds = fit_class_dists(class_original_image_filenames,
                                                                                  noisy_train_image_filename,
                                                                                  other_pixels_colors, 'in')

    _, _, pixel_class_indexes = segment_image(base_filepath,
                                              test_image_filename, all_classes_image_per_label_pixels,
                                              class_means, class_stds)
    # No pixels are pinned at present; the annealer may relabel any pixel.
    fixed_pixel_indexes = set()

    for i in range(len(pixel_class_indexes)):
        for j in range(len(pixel_class_indexes[0])):
            if random.uniform(0, 1) > true_label_prob:
                pixel_class_indexes[i][j] = random.randint(0, len(class_means) - 1)
            else:
                for idx, class_pixels in enumerate(all_classes_image_per_label_pixels):
                    if (i, j) in class_pixels:
                        pixel_class_indexes[i][j] = idx
                        true_label_in_init_count += 1
                        break

    image_segmentation_mrf = CustomImageSegmentationMRF()
    new_pixel_class_indexes = image_segmentation_mrf.start(repeat_count, pixel_class_indexes,
                                                           base_filepath + test_image_filename,
                                                           fixed_pixel_indexes, beta, t, neighbor_count, t_ratio)

    width = len(new_pixel_class_indexes)
    height = len(new_pixel_class_indexes[0])
    saving_im = Image.new('LA', (width, height))
    im = Image.open(base_filepath + test_image_filename).convert('LA')
    pixels = im.load()
    all_pixels_count = 0
    all_trues_count = 0
    for i in range(width):
        for j in range(height):
            label = new_pixel_class_indexes[i][j]
            saving_im.putpixel((i, j), (int(class_colors[label]), pixels[i, j][1]))
            all_pixels_count += 1
            if (i, j) in all_classes_image_per_label_pixels[label]:
                all_trues_count += 1
    saving_im.save(base_filepath + labeled_image_filenme, 'png')
    print("accuracy: " + str(float(all_trues_count) / float(all_pixels_count)))
    print("true_label_prob: " + str(true_label_prob))
    print("true_label_in_init_count: " + str(true_label_in_init_count))
    print("all_pixels_count: " + str(all_pixels_count))


def segment_image(base_filepath,
                  test_image_filename, all_classes_image_per_label_pixels, class_means, class_stds):
    """Label each pixel by grey level alone with a Gaussian naive-Bayes rule.

    Every class's normal density is evaluated at the pixel's grey level and
    the first class with the highest density wins.

    Returns
    -------
    ``(all_pixels, all_trues, pixel_class_indexes)``: the pixel count, the
    number of pixels contained in their chosen class's pixel set, and the
    label grid indexed ``[x][y]``.
    """
    grey_im = Image.open(base_filepath + test_image_filename).convert('LA')
    grey_pixels = grey_im.load()
    width, height = grey_im.size
    labels = [[0] * height for _ in range(width)]

    pixel_count = 0
    correct_count = 0
    for x in range(width):
        for y in range(height):
            pixel_count += 1
            level = grey_pixels[x, y][0]
            best_prob = -1
            hit = False
            for class_idx in range(len(class_means)):
                candidate = get_gaussian_naive_bayes_prob_per_class(class_means[class_idx],
                                                                    class_stds[class_idx], level)
                if candidate > best_prob:
                    best_prob = candidate
                    hit = (x, y) in all_classes_image_per_label_pixels[class_idx]
                    labels[x][y] = class_idx
            if hit:
                correct_count += 1
    return pixel_count, correct_count, labels


def fit_class_dists(class_original_image_filenames, noisy_train_image_filename, other_pixels_colors, feature,
                    base_filepath=None):
    """Fit one normal distribution per class for ``feature``.

    Parameters
    ----------
    class_original_image_filenames :
        Per-class image filenames, relative to ``base_filepath``.
    noisy_train_image_filename :
        Joined to ``base_filepath`` and forwarded to
        ``fit_normal_dist_per_class``, which does not currently read it.
    other_pixels_colors :
        Background feature value for each class image, paired with the
        filenames (``zip`` stops at the shorter list).
    feature :
        Feature code understood by ``fit_normal_dist_per_class``.
    base_filepath :
        Directory prefix for the filenames. ``None`` falls back to the
        module-level ``base_path``, which is only defined when this file runs
        as a script.

    Returns
    -------
    ``(pixel_sets, means, stds)``: three lists, one entry per class.
    """
    if base_filepath is None:
        base_filepath = base_path

    class_means = []
    class_stds = []
    all_classes_image_per_label_pixels = []
    for class_original_image_filename, other_pixels_color in zip(class_original_image_filenames, other_pixels_colors):
        class_mean, class_std, class_image_per_label_pixels = fit_normal_dist_per_class(
            base_filepath + class_original_image_filename, base_filepath + noisy_train_image_filename,
            other_pixels_color, feature)
        class_means.append(class_mean)
        class_stds.append(class_std)
        all_classes_image_per_label_pixels.append(class_image_per_label_pixels)
    return all_classes_image_per_label_pixels, class_means, class_stds


if __name__ == "__main__":
    # Forward slashes act as path separators on both POSIX and Windows;
    # backslashes are ordinary filename characters on POSIX. This is a
    # module-level global that fit_class_dists falls back on.
    base_path = "/home/hajilo/PycharmProjects/PGM_P1/"
    train_image_filename = "test2.jpg"
    test_image_filename = "test2.jpg"

    b = 0.1  # beta: weight of the pairwise smoothness term
    t = 100  # initial annealing temperature
    neighbor_count = 4  # 4- or 8-neighbourhood for the smoothness term
    true_label_prob = 0  # chance a pixel is initialised with its true label
    t_ratio = 0.97  # temperature decay factor, applied every 300 moves

    # NOTE: report_segment_image_mrf writes this file in PNG format despite
    # the .jpg suffix.
    labeled_image_filenme = "test2-labeled-mrf-t" + str(t) + "-b" + str(
        b) + "-nc" + str(neighbor_count) + "-tp" + str(true_label_prob) + "-tr" + str(
        t_ratio) + "-hu-in-te.jpg"
    report_segment_image_mrf(base_path, test_image_filename, labeled_image_filenme,
                             train_image_filename,
                             class_original_image_filenames, other_pixels_colors, b, t,
                             neighbor_count, true_label_prob, t_ratio)
