In [77]:
import numpy as np
import matplotlib.pyplot as plt
import math
import itertools
import time

In [55]:
# get possible edges in a grid
def get_edges(num_rows, num_cols):
    edges = {}
    for row in range(num_rows):
        for col in range(num_cols):
            neighbors = []
            if row - 1 >= 0:
                neighbors.append((row - 1, col))
            if row + 1 < num_rows:
                neighbors.append((row + 1, col))
            if col - 1 >=0:
                neighbors.append((row, col - 1))
            if col + 1 < num_cols:
                neighbors.append((row, col + 1))
            
            edges[row*num_cols + col] = neighbors
    return edges

In [52]:
# genrate x = (x1,x2, ..., x25) using gibbs sampling
# order: x1 x2 x3 x4 x5
#        x6 x7 x8 x9 x10
#        x11 x12 x13 x14 x15
#        x16 x17 x18 x19 x20
#        x21 x22 x23 x24 x25
def gibbs_generate_x(num_rows, num_cols, num_iterations):
    # draw random values from [0,1] and place them on a grid
    grid = np.random.randint(0, 2, (num_rows,num_cols))
    
    # per each node get possible edges 
    edges = get_edges(num_rows, num_cols)
    
    # save tuples of all rows and columns for convenient iteration
    rows = [i for i in range(num_rows) for _ in range(num_cols)]
    cols = list(range(num_cols))*num_rows
    
    # iterations
    for i in range(0, num_iterations):
        # iterate over grid indices
        for row, col in zip(rows, cols):
            neighbors = edges[row*num_cols + col]
            # iterate over neighbors
            x_eq_0 = 0  # sum indicators of x_i with all the neighbors when x_i=0 
            x_eq_1 = 0  # sum indicators of x_i with all the neighbors when x_i=1
            for neighbor_row, neighbor_col in neighbors:
                x_eq_0 += 1 if grid[neighbor_row, neighbor_col] == 0 else 0
                x_eq_1 += 1 if grid[neighbor_row, neighbor_col] == 1 else 0
            
            p_x_eq_0_given_neighbors = np.exp(x_eq_0)/(np.exp(x_eq_0) + np.exp(x_eq_1))  # p(xi=0|x1,...,x25)
            grid[row,col] = 0 if np.random.rand() < p_x_eq_0_given_neighbors else 1  # generate value

    return grid

In [53]:
# generate y_i from normal distribution with mu=x_i and sigma_i=1
def generate_y(num_rows, num_cols, grid):
    y_grid = np.zeros(grid.shape)
    
    # save tuples of all rows and columns for convenient iteration
    rows = [i for i in range(num_rows) for _ in range(num_cols)]
    cols = list(range(num_cols))*num_rows
    
    # generate y_grid
    for row, col in zip(rows, cols): 
        y_grid[row,col] = np.random.normal(grid[row,col], 1)
    
    return y_grid

In [80]:
# computing the correct marginal distribution according to p(x_i|y) = sigma_over_X\x_i{p(x_i,X|y)}
# fixed_index and vixed val - the x_i index and value that will not be changed
def correct_marginal(fixed_index, xi_val, y_grid, num_rows, num_cols):
    
    # per each node get possible edges 
    edges = get_edges(num_rows, num_cols)
    
    # save tuples of all rows and columns for convenient iteration
    rows = [i for i in range(num_rows) for _ in range(num_cols)]
    cols = list(range(num_cols))*num_rows
    
    # get all combinations of X\x_i
    combinations = list(map(list, itertools.product([0, 1], repeat=num_rows*num_cols-1)))
    exp_sum_0 = exp_sum_1 = 0
    
    start_time = time.time()
    
    # per each X compute p(X|Y)
    for iteration, comb in enumerate(combinations):
        comb_0 = list(comb)
        comb_1 = list(comb)
        
        comb_0.insert(fixed_index, 0)  # insert x_i=0 at the relevant position
        comb_1.insert(fixed_index, 1)  # insert x_i=1 at the relevant position
        
        grid_0 = np.asarray(comb_0).reshape(num_rows,num_cols)  # build grid from the current permutation with x_i=0
        grid_1 = np.asarray(comb_1).reshape(num_rows,num_cols)  # build grid from the current permutation with x_i=1
        
        psi_xi_xj_0 = psi_xi_xj_1 = 0

        # iterate over grid indices
        for row, col in zip(rows, cols):
            neighbors = edges[row*num_cols + col]
                        
            # iterate over neighbors
            for neighbor_row, neighbor_col in neighbors:
                psi_xi_xj_0 += 1 if grid_0[neighbor_row, neighbor_col] == grid_0[row, col] else 0
                psi_xi_xj_1 += 1 if grid_1[neighbor_row, neighbor_col] == grid_1[row, col] else 0
                    
        psi_xi_yi_0 = np.sum(-0.5*(grid_0 - y_grid) ** 2)
        psi_xi_yi_1 = np.sum(-0.5*(grid_1 - y_grid) ** 2)

        exp_sum_0 += np.exp(psi_xi_xj_0 + psi_xi_yi_0)  # exp^(sigama_psi_ij(x_i,x_j) + sigama_psi_i(x_i,y_i))
        exp_sum_1 += np.exp(psi_xi_xj_1 + psi_xi_yi_1)  # exp^(sigama_psi_ij(x_i,x_j) + sigama_psi_i(x_i,y_i))
        
        if iteration % 100000 == 0:
            elapsed_time = time.time() - start_time
            elapsed_time = time.strftime("%H:%M:%S", time.gmtime(elapsed_time))
            print(str(elapsed_time) + ": " + str(iteration) + " iterations passed")
        
    p_xi_given_y = exp_sum_0/(exp_sum_0 + exp_sum_1) if xi_val == 0 else exp_sum_1/(exp_sum_0 + exp_sum_1)  # p(x_i=0 | y1,...,y25) or p(x_i=1 | y1,...,y25)
    return p_xi_given_y

In [75]:
grid_0 = np.random.randint(0, 2, (3,3))
y_grid = np.random.rand(3,3)
print(grid_0)
print(y_grid)
print(-0.5*(grid_0 - y_grid) ** 2)
print(np.sum(-0.5*(grid_0 - y_grid) ** 2))

[[1 1 0]
 [0 1 0]
 [1 1 1]]
[[0.13504349 0.27818506 0.44410675]
 [0.91764942 0.3202238  0.85412648]
 [0.1141354  0.85190199 0.47738424]]
[[-0.37407488 -0.26050841 -0.0986154 ]
 [-0.42104023 -0.23104784 -0.36476602]
 [-0.39237804 -0.01096651 -0.13656362]]
-2.289960953824007


In [83]:
np.random.seed(999)
num_rows = 5
num_cols = 5
num_iterations = 10000

# generate a random picture 
grid = gibbs_generate_x(num_rows, num_cols, num_iterations)
y_grid = generate_y(num_rows, num_cols, grid)

print(grid)
print(y_grid)
print(correct_marginal(0, grid[0][0], y_grid, num_rows, num_cols))

# now, it's like we are in the real world and we are receiving a picture

[[1 1 0 0 1]
 [1 1 0 0 0]
 [1 1 0 0 0]
 [0 1 1 1 0]
 [1 1 1 1 0]]
[[ 2.47089137  1.80035099  0.7296911  -1.07243159  1.33548536]
 [ 3.06151232  1.52832455 -1.15760754 -0.4899489  -1.32276818]
 [ 2.08479065  1.90488301  1.20166761  0.27954848  0.66653454]
 [-0.23913622  0.97126084  0.61148317  0.7524438  -0.08196441]
 [ 1.42112959  0.12474655  0.91836101  0.9292069   0.51232383]]
00:00:00: 0 iterations passed
00:00:06: 100000 iterations passed
00:00:13: 200000 iterations passed
00:00:20: 300000 iterations passed
00:00:27: 400000 iterations passed
00:00:35: 500000 iterations passed


KeyboardInterrupt: 