In [None]:
%load_ext lab_black
%matplotlib inline

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from torchvision import transforms
from PIL import Image
from tqdm import tqdm as tqdm
import seaborn as sns

sns.set_style(style="whitegrid", rc={"axes.grid": False})

# Energy weights used by calculate_total_energy:
ETA = 2.1  # coupling between each reconstructed pixel and its observed value
BETA = 1  # coupling between vertically/horizontally adjacent reconstructed pixels

### Load the image and binarize it to ±1 pixels

In [None]:
impath = "images/markov.jpg"
# Pixels whose ToTensor intensity exceeds this become +1, the rest -1.
BINARY_THRESHOLD = 0.4

# Equivalent to the original nested calls, applied in list order:
# grayscale -> vertical flip -> horizontal flip -> resize (shorter side to 100 px).
# p=1 makes both "random" flips deterministic; together they rotate the image 180°.
preprocess = transforms.Compose(
    [
        transforms.Grayscale(),
        transforms.RandomVerticalFlip(1),
        transforms.RandomHorizontalFlip(1),
        transforms.Resize(100),
    ]
)
# The context manager closes the file handle. Grayscale converts to a new PIL
# image, so `im` stays usable after the file is closed.
with Image.open(impath) as raw_image:
    im = preprocess(raw_image)

im_tensor = transforms.ToTensor()(im)  # (1, H, W) float tensor in [0, 1]
im_array = im_tensor[0].numpy()  # drop the single grayscale channel -> (H, W)
y = np.where(im_array > BINARY_THRESHOLD, 1, -1)
plt.imshow(y, cmap="gray")  # -1 -> black, +1 -> white

In [None]:
NOISE_PROB = 0.1  # probability that each pixel of the observation is flipped
NOISE_SEED = 0  # fixed seed so Restart & Run All reproduces the same noisy image

# Vectorized Bernoulli(NOISE_PROB) mask instead of one np.random.binomial call per pixel.
rng = np.random.default_rng(NOISE_SEED)
flip_mask = rng.random(y.shape) < NOISE_PROB
# NOTE: this reads and reassigns `y`, so re-running this cell on its own stacks
# fresh noise on top of the already-noisy image; re-run the loading cell first.
y = np.where(flip_mask, -y, y)

x = y.copy()  # ICM starts from the noisy observation

In [None]:
# Noisy observation y after the random pixel flips above.
plt.imshow(y, cmap="gray")  # -1 -> black, +1 -> white
plt.title("Noisy observation y")
plt.show()

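### Step 1: compute the total energy

$$E(x, y) = -\eta \sum_i x_i y_i \;-\; \beta \sum_{\langle i,j \rangle} x_i x_j$$

where $\langle i,j \rangle$ ranges over horizontally and vertically adjacent pixel pairs, $\eta$ is `ETA` and $\beta$ is `BETA`.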

In [None]:
def calculate_total_energy(x, y, eta=None, beta=None):
    """Return the Ising-model energy of a reconstruction ``x`` given observation ``y``.

    E(x, y) = -eta * sum_i x_i * y_i - beta * sum_{<i,j>} x_i * x_j

    where <i,j> ranges over vertically and horizontally adjacent pixel pairs
    (4-connectivity, no wrap-around at the borders).

    Args:
        x: 2-D array, the current reconstruction (pixel values -1/+1 in this notebook).
        y: 2-D array, the observed image; same shape as ``x``.
        eta: weight of the data term x_i * y_i; ``None`` uses the global ``ETA``.
        beta: weight of the neighbour term x_i * x_j; ``None`` uses the global ``BETA``.

    Returns:
        The scalar energy; lower is better under the model.
    """
    # Resolve at call time so edits to ETA/BETA in the config cell are picked up.
    if eta is None:
        eta = ETA
    if beta is None:
        beta = BETA
    data_term = (x * y).sum()
    vertical_pairs = (x[:-1, :] * x[1:, :]).sum()  # pixel with the one below it
    horizontal_pairs = (x[:, :-1] * x[:, 1:]).sum()  # pixel with the one to its right
    # Same accumulation order as before, so results are bit-identical.
    return -eta * data_term - beta * vertical_pairs - beta * horizontal_pairs

In [None]:
def set_best_state(x, y, row, col, eta=None, beta=None):
    """Set ``x[row, col]`` in place to whichever of -1/+1 gives the lower total energy.

    Only the energy terms that involve pixel (row, col) depend on its value s:

        E(s) = const - s * (eta * y[row, col] + beta * sum of x's 4-neighbours)

    So comparing this "local field" to zero gives the same choice as computing
    the whole-image energy for both states, at O(1) instead of O(H*W) per pixel.
    +1 is chosen only when the field is strictly positive. On a tie, -1 is chosen,
    matching the former ``np.argmin`` over ``[-1, 1]``, which returns the first minimum.

    Args:
        x: 2-D array, the reconstruction being updated (modified in place).
        y: 2-D array, the observed image; same shape as ``x``.
        row, col: non-negative in-range pixel indices (as produced by the sweep loop).
        eta: data-term weight; ``None`` uses the global ``ETA``.
        beta: neighbour-term weight; ``None`` uses the global ``BETA``.

    Returns:
        None. ``x[row, col]`` is overwritten.
    """
    if eta is None:
        eta = ETA
    if beta is None:
        beta = BETA
    n_rows, n_cols = x.shape
    # Sum of the 4-connected neighbours that exist (no wrap-around at borders).
    neighbour_sum = 0
    if row > 0:
        neighbour_sum += x[row - 1, col]
    if row < n_rows - 1:
        neighbour_sum += x[row + 1, col]
    if col > 0:
        neighbour_sum += x[row, col - 1]
    if col < n_cols - 1:
        neighbour_sum += x[row, col + 1]
    local_field = eta * y[row, col] + beta * neighbour_sum
    x[row, col] = 1 if local_field > 0 else -1

In [None]:
MAX_ROUNDS = 100  # safety cap on the number of full ICM sweeps

# Restart from the noisy observation so re-running this cell is idempotent
# (on a fresh kernel `x` already equals `y.copy()` at this point).
x = y.copy()
energies = [calculate_total_energy(x, y)]  # energies[k] = energy after k sweeps
num_rounds = 0
while num_rounds < MAX_ROUNDS:
    # One ICM sweep: greedily set every pixel to its lower-energy state.
    for row in tqdm(range(x.shape[0]), desc=f"sweep {num_rounds + 1}", leave=False):
        for col in range(x.shape[1]):
            set_best_state(x, y, row, col)
    energies.append(calculate_total_energy(x, y))
    num_rounds += 1
    delta_energy = energies[-1] - energies[-2]
    # Each pixel update never raises the energy, so an unchanged total means
    # the sweep found no strict improvement -> converged.
    if delta_energy == 0:
        break

print(f"stopped after {num_rounds} sweeps, final energy {energies[-1]:.1f}")

In [None]:
# energies[0] is the energy of the noisy initialization; entry k is after sweep k.
energy_df = pd.DataFrame({"energy": energies}).rename_axis("n_epoch")
ax = energy_df.plot(
    figsize=(12, 8), legend=False, title="Total energy after each ICM sweep"
)
ax.set_xlabel("sweep (0 = initial state)")
ax.set_ylabel("total energy");

### Observed image

In [None]:
plt.imshow(y, cmap="gray")  # -1 -> black, +1 -> white
plt.title("Observed (noisy) image y")
plt.show()

### Reconstructed image

In [None]:
plt.imshow(x, cmap="gray")  # -1 -> black, +1 -> white
plt.title("Reconstructed image x (ICM)")
plt.show()