In [1]:
import numpy as np
from matplotlib import pyplot as plt
import csv

In [None]:
# Image geometry: the flat pixel arrays are reshaped to (rows, columns).
rows, columns = 321, 265
pixels = columns * rows  # number of entries in each flat image array

# Reward added to the objective for each pair of equal orthogonal neighbours
# (diagonal pairs are rewarded 0.7 * W).
W = 10

In [None]:
def get_image_data(path="img.csv"):
    """ Read a binary image from a CSV file into a flat boolean array.

    The file is expected to hold up to `rows` lines of up to `columns`
    comma-separated values. A cell equal to "1" (surrounding whitespace
    ignored) becomes True, anything else False. Cells are stored row-major:
    cell (i, j) goes to index i * columns + j of the returned array, which has
    length `pixels`. Blank lines are skipped (they still count as a row).

    Raises ValueError if a non-blank line lies beyond the last image row or
    holds more than `columns` values, instead of silently wrapping the extra
    values into the next image row. """
    data = np.zeros(pixels, dtype=bool)
    # newline="" is what the csv module expects for files passed to csv.reader.
    with open(path, newline="") as f:
        for i, row in enumerate(csv.reader(f, delimiter=",")):
            if not row:
                continue
            if i >= rows or len(row) > columns:
                raise ValueError(
                    f"{path}: row {i} with {len(row)} values does not fit "
                    f"in a {rows}x{columns} image"
                )
            start = i * columns
            data[start:start + len(row)] = [value.strip() == "1" for value in row]
    return data
    
# def are_neighbours(a, b):
#     """ Returns True if given indices a and b are neighbours; the check is made
#     by determining whether the distance between the two is 1 and they are on the
#     same row (left and right) or the distance is the number of columns (up and down). """
#     d = abs(a - b) # the distance between the two pixels
#     rd = abs(int(a / columns) - int(b / columns)) # the row distance
#     return (d == 1 and rd == 0) or (d == columns)

def get_pairs():
    """ Collect every unordered pair of neighbouring pixels exactly once.

    Returns two sets of (i, j) index tuples:
      * the pairs built from the first set returned by get_neighbours
        (orthogonal neighbours);
      * the pairs built from the second set (diagonal neighbours).
    Pixels are scanned in increasing index order. A pair is kept in the
    orientation in which it is first met, and its mirror image is never
    added afterwards. """
    orthogonal, diagonal = set(), set()
    for pixel in range(pixels):
        near, corner = get_neighbours(pixel)
        for collected, candidates in ((orthogonal, near), (diagonal, corner)):
            for other in candidates:
                # Already recorded from the other side -> skip; re-adding the
                # same orientation is a no-op for a set anyway.
                if (other, pixel) not in collected:
                    collected.add((pixel, other))
    return orthogonal, diagonal

# def get_neighbours(x):
#     neighbours = set()
#     # up
#     if x >= columns:
#         neighbours.add(x - columns)
#     # down
#     if x < pixels - columns:
#         neighbours.add(x + columns)
#     # right
#     if (x + 1) % columns != 0:
#         neighbours.add(x + 1)
#     # left
#     if x % (columns ) != 0:
#         neighbours.add(x - 1)
#     return neighbours

# def get_neighbours(x):
#     neighbours = set()
#     if x == 0:
#         # Top left corner
#         return {1, columns}
#     elif x == columns - 1:
#         # Top right corner
#         return {x - 1, x + columns}
#     elif x == pixels - columns:
#         # Bottom left
#         return {x + 1, x - columns}
#     elif x == pixels - 1:
#         # Bottom right
#         return {x - 1, x - columns}
#     elif x < columns:
#         # Top row
#         return {x - 1, x + 1, x + columns}
#     elif x > pixels - columns:
#         # Bottom row
#         return {x - 1, x + 1, x - columns}
#     elif x % columns == 0:
#         # left column
#         return {x - columns, x + columns, x + 1}
#     elif x % (columns - 1) == 0:
#         # left column
#         return {x - columns, x + columns, x - 1}
#     else:
#         return {x - columns, x + columns, x - 1, x + 1}

def get_neighbours(x, n_rows=None, n_cols=None):
    """ Returns the indices of the pixels neighbouring the pixel at index x.

    Pixels are stored row-major in a flat array, so pixel x lies at row
    x // n_cols and column x % n_cols. The result is a pair of sets
    (primary, secondary):
      * primary   -- the orthogonal neighbours: above, below, left and right;
      * secondary -- the diagonal neighbours in the four corners.
    Neighbours that would fall outside the image are left out, so pixels on an
    edge or in a corner get fewer of them.

    n_rows and n_cols default to the module-level `rows` and `columns`.
    Passing them explicitly lets the function work on any image shape,
    including single-row or single-column images.

    Raises IndexError if x is not a valid index for an n_rows x n_cols image. """
    if n_rows is None:
        n_rows = rows
    if n_cols is None:
        n_cols = columns
    if not 0 <= x < n_rows * n_cols:
        raise IndexError(f"pixel index {x} out of range for a {n_rows}x{n_cols} image")

    r, c = divmod(x, n_cols)
    has_up, has_down = r > 0, r < n_rows - 1
    has_left, has_right = c > 0, c < n_cols - 1

    primary, secondary = set(), set()
    if has_left:
        primary.add(x - 1)
    if has_right:
        primary.add(x + 1)
    # The pixel directly above/below, plus its left/right neighbour, which are
    # the diagonal neighbours of x.
    for exists, vertical in ((has_up, x - n_cols), (has_down, x + n_cols)):
        if not exists:
            continue
        primary.add(vertical)
        if has_left:
            secondary.add(vertical - 1)
        if has_right:
            secondary.add(vertical + 1)
    return primary, secondary

In [None]:
noisy_img = get_image_data()
# The clean image starts out identical to the noisy one and is then modified
# in place by clean(); copying avoids reading and parsing the CSV a second time.
clean_img = noisy_img.copy()
#clean_img = np.random.randint(0, 2, (pixels), dtype=bool)

def plot_images():
    """ Show the noisy image (left) and the current clean image (right).

    Both flat arrays are reshaped to (rows, columns) for display. The figure is
    closed after being shown so that repeated calls, e.g. the periodic plots
    made while cleaning, do not keep accumulating open figures. """
    fig = plt.figure(figsize=(12, 12))
    for position, flat in enumerate((noisy_img, clean_img), start=1):
        fig.add_subplot(1, 2, position)
        # Pass cmap="gist_gray" to imshow for a greyscale rendering.
        plt.imshow(flat.reshape((rows, columns)))
        plt.axis('off')
    plt.show()
    plt.close(fig)

In [None]:
def obj():
    """ Returns the value of the objective function for the current state
    of the noisy image and clean image.

    The objective is the sum of:
      * a data term: for every pixel set in clean_img, +1 if the same pixel is
        set in noisy_img and -1 otherwise;
      * W for every pair of orthogonally adjacent pixels (left/right, up/down)
        that have the same value in clean_img;
      * 0.7 * W for every pair of diagonally adjacent pixels that have the same
        value in clean_img.
    Each unordered pair of neighbours is counted once. The terms are computed
    with vectorised comparisons on the (rows, columns) view of the image. """
    # np.where maps noisy False/True to -1/+1; multiplying by the boolean clean
    # image keeps only the pixels that are set in it.
    total = np.sum(np.where(noisy_img, 1, -1) * clean_img)

    img = clean_img.reshape((rows, columns))
    orthogonal_matches = (
        np.count_nonzero(img[:, :-1] == img[:, 1:])    # left/right pairs
        + np.count_nonzero(img[:-1, :] == img[1:, :])  # up/down pairs
    )
    diagonal_matches = (
        np.count_nonzero(img[:-1, :-1] == img[1:, 1:])    # top-left/bottom-right
        + np.count_nonzero(img[:-1, 1:] == img[1:, :-1])  # top-right/bottom-left
    )
    total += W * orthogonal_matches
    total += 0.7 * W * diagonal_matches
    return total

def updated_obj(obj, i):
    """ Return the objective value that would result from flipping pixel i.

    `obj` must be the objective value for the current clean_img. Only the
    terms that involve pixel i can change, so the new value is obtained from
    `obj` incrementally: its data term and its agreement with each orthogonal
    (weight W) and diagonal (weight 0.7 * W) neighbour. clean_img itself is
    not modified. """
    current = clean_img[i]
    sign = 2 * noisy_img[i] - 1  # -1 or +1 depending on the noisy pixel
    # Data term: flipping a set pixel removes its contribution; flipping an
    # unset pixel adds it.
    obj += -sign if current else sign
    orthogonal, diagonal = get_neighbours(i)
    for weight, neighbours in ((W, orthogonal), (0.7 * W, diagonal)):
        for j in neighbours:
            # A neighbour that agrees now will disagree after the flip, and
            # vice versa.
            obj += -weight if current == clean_img[j] else weight
    return obj

In [None]:
def clean(max_sweeps=None, plot_every=1000):
    """ Denoise clean_img in place by greedy single-pixel flips.

    Every sweep visits all pixels in a freshly shuffled order. A pixel is
    flipped whenever doing so strictly increases the objective. The search
    stops once a complete sweep makes no flip, meaning no single flip can
    improve the objective any more. It also stops after `max_sweeps` sweeps,
    if that is given.

    plot_every: show both images every this many visited pixels (0 or None
    disables the periodic plots). The final state is always plotted.

    Returns the objective value of the final clean_img. """
    b_obj = obj()
    n = 0
    sweeps = 0
    while max_sweeps is None or sweeps < max_sweeps:
        flipped = False
        # Every pixel index 0 .. pixels - 1, including the last one.
        order = np.arange(pixels)
        np.random.shuffle(order)
        for i in order:
            new_obj = updated_obj(b_obj, i)
            if new_obj > b_obj:
                b_obj = new_obj
                clean_img[i] = not clean_img[i]
                flipped = True
            if plot_every and n % plot_every == 0:
                plot_images()
            n += 1
        sweeps += 1
        if not flipped:
            break  # local optimum: no single flip improves the objective
    plot_images()
    return b_obj

clean()