# 2 Patch-Level Analysis

## 2.1 Imports & Constants

In [1]:
import itertools
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patch

from PIL import Image
from sklearn import feature_extraction

from library import generator

%config InlineBackend.figure_format='retina'

In [2]:
# Path of the test image analysed in this notebook.
BARBARA = 'images/portillo/barbara.png'
# Original misspelled name, kept as an alias so cells that still read
# BARABARA continue to work.
BARABARA = BARBARA

# Fixed display range for pixel values in [0, 255]. Pass it to
# ``imshow(..., **COLOR_SCALE)`` so every figure uses the same intensity
# mapping instead of autoscaling each image independently.
COLOR_SCALE = {
    'vmin': 0,
    'vmax': 255
}

## 2.2 Utilities

In [3]:
def read_image(address):
    """Load an image file as a float64 NumPy array.

    Parameters
    ----------
    address : str or path-like
        Path of the image file to read.

    Returns
    -------
    numpy.ndarray
        Pixel values converted to float64. The shape is whatever
        ``np.array`` produces for the image's mode (e.g. 2-D for a
        single-channel image).
    """
    # The context manager closes the file handle that Image.open keeps open.
    # np.array forces PIL's lazy decoder to read the pixels before the close.
    with Image.open(address) as image:
        pixels = np.array(image, dtype=np.float64)
    return pixels

def to_grayscale_patches(pixels, patch_size):
    """Extract all overlapping patches of ``pixels`` as column vectors.

    Despite the name, no colour conversion is performed. A 2-D
    (single-channel) image is presumably expected.

    Parameters
    ----------
    pixels : numpy.ndarray
        Image to cut into patches.
    patch_size : tuple of int
        ``(height, width)`` of each patch.

    Returns
    -------
    numpy.ndarray
        Matrix with one flattened patch per column. Columns follow the
        order returned by ``extract_patches_2d``.
    """
    windows = feature_extraction.image.extract_patches_2d(pixels, patch_size)
    window_count = windows.shape[0]
    flattened = windows.reshape(window_count, -1)
    return flattened.T

def get_patch_index(patch_matrix, snapshot):
    """Return the index of the first column of ``patch_matrix`` equal to ``snapshot``.

    Parameters
    ----------
    patch_matrix : numpy.ndarray
        2-D array whose columns are flattened patches.
    snapshot : array-like
        Patch to look up. It is flattened in C (row-major) order before
        comparison, matching a row-major flattening of the patches.

    Returns
    -------
    int
        Index of the first column that is exactly equal to the flattened
        snapshot.

    Raises
    ------
    ValueError
        If the flattened snapshot length differs from the number of rows
        of ``patch_matrix``.
    KeyError
        If no column matches exactly.
    """
    target = np.ravel(snapshot)
    if target.size != patch_matrix.shape[0]:
        raise ValueError(
            f'snapshot has {target.size} values but patches have '
            f'{patch_matrix.shape[0]} rows'
        )
    # Compare all columns at once instead of looping over them in Python.
    # The temporary boolean matrix has the same shape as patch_matrix.
    column_matches = np.all(patch_matrix == target[:, np.newaxis], axis=0)
    hits = np.flatnonzero(column_matches)
    if hits.size == 0:
        raise KeyError('snapshot does not match any patch column')
    # Return the first match, as a plain int, like the original linear scan.
    return int(hits[0])

## 2.3 Figure Components

In [4]:
PATCH_SIZE = 16
# Standard deviation of the additive Gaussian noise, in pixel-intensity units.
NOISE_SIGMA = 20
# Fixed seed so the noisy image, and every figure derived from it, is
# reproducible under Restart & Run All.
NOISE_SEED = 0

barbara = read_image(BARABARA)
noise_rng = np.random.default_rng(NOISE_SEED)
noisy_barbara = barbara + noise_rng.normal(scale=NOISE_SIGMA, size=barbara.shape)

clean_patches = to_grayscale_patches(barbara, (PATCH_SIZE, PATCH_SIZE))
noisy_patches = to_grayscale_patches(noisy_barbara, (PATCH_SIZE, PATCH_SIZE))

In [5]:
### HIGHLIGHTED CLEAN BARBARA

# Top-left (x, y) = (column, row) pixel coordinates of the highlighted
# patches. Each patch is barbara[y:y + PATCH_SIZE, x:x + PATCH_SIZE].
LOCATIONS = [(260, 280), (25, 271), (460, 20)]
COLORS = ['red', 'purple', 'blue']

fig, ax = plt.subplots(1, figsize=(64, 64))
ax.imshow(barbara, **COLOR_SCALE, cmap='gray')

for (x_start, y_start), color in zip(LOCATIONS, COLORS):
    # imshow centres pixel (row r, column c) at (c, r), so pixel edges lie on
    # half-integer coordinates. Shift by half a pixel so the outline encloses
    # exactly the PATCH_SIZE x PATCH_SIZE block of pixels.
    rect = patch.Rectangle(
        (x_start - 0.5, y_start - 0.5),
        PATCH_SIZE,
        PATCH_SIZE,
        linewidth=30,
        edgecolor=color,
        facecolor='none'
    )
    ax.add_patch(rect)
ax.axis('off')
fig.savefig('02-highlighted-patches.pdf', bbox_inches='tight')
plt.close(fig)

In [6]:
### PATCH SNAPSHOTS

def save_snapshot(snapshot, filename):
    """Render one patch as a borderless grayscale image and save it to ``filename``.

    Every snapshot uses the same fixed intensity range (COLOR_SCALE), so
    clean and noisy versions are directly comparable. Nearest-neighbour
    interpolation keeps each pixel of the small patch a sharp square.
    """
    fig, ax = plt.subplots(1, figsize=(64, 64))
    ax.imshow(snapshot, **COLOR_SCALE, cmap='gray', interpolation='nearest')
    ax.axis('off')
    fig.savefig(filename, bbox_inches='tight')
    plt.close(fig)


for index, (x_start, y_start) in enumerate(LOCATIONS):
    rows = slice(y_start, y_start + PATCH_SIZE)
    cols = slice(x_start, x_start + PATCH_SIZE)

    clean_snapshot = barbara[rows, cols]
    save_snapshot(clean_snapshot, f'02-clean-snapshot-{index}.pdf')

    noisy_snapshot = noisy_barbara[rows, cols]
    save_snapshot(noisy_snapshot, f'02-noisy-snapshot-{index}.pdf')

In [7]:
### SNAPSHOT BASES

# 0-based index of the iterate taken from the dictionary-learning generator.
ITERATIONS = 100

# Column index, within clean_patches, of each highlighted patch.
correct_indices = []
for x_start, y_start in LOCATIONS:
    clean_snapshot = barbara[y_start:y_start + PATCH_SIZE, x_start:x_start + PATCH_SIZE]
    correct_indices.append(get_patch_index(clean_patches, clean_snapshot))


def learn_dictionary(patches, iterations=ITERATIONS):
    """Run dictionary learning on ``patches`` and encode them with the result.

    Parameters
    ----------
    patches : numpy.ndarray
        Matrix with one flattened patch per column.
    iterations : int
        0-based index of the element to take from
        ``generator.get_dictionary_learning_iterates``.

    Returns
    -------
    dictionary : numpy.ndarray
        The selected iterate, transposed. Its columns are the atoms that
        are later reshaped to PATCH_SIZE x PATCH_SIZE images.
    encoding : numpy.ndarray
        ``dictionary.T @ patches``, with one column of coefficients per
        patch. NOTE(review): projecting with the transpose presumably
        assumes the learned dictionary is orthonormal; confirm against
        ``library.generator``.

    Raises
    ------
    ValueError
        If the generator stops before yielding the requested iterate.
    """
    updates = generator.get_dictionary_learning_iterates(patches)
    iterate = next(itertools.islice(updates, iterations, None), None)
    if iterate is None:
        raise ValueError(
            f'dictionary learning yielded fewer than {iterations + 1} iterates'
        )
    dictionary = iterate.T
    encoding = dictionary.T @ patches
    return dictionary, encoding


clean_dictionary, clean_encoding = learn_dictionary(clean_patches)
noisy_dictionary, noisy_encoding = learn_dictionary(noisy_patches)

In [8]:
PACKAGED = [
    ('clean', clean_dictionary, clean_encoding),
    ('noisy', noisy_dictionary, noisy_encoding)
]

# Number of strongest atoms shown for each highlighted patch.
BASES = 6


def atom_to_image(atom, sign=1.0, patch_size=PATCH_SIZE):
    """Orient an atom by ``sign``, stretch it to [0, 255] and reshape it to a square patch.

    An atom with (numerically) no contrast, such as a constant/DC atom, has
    no range to stretch. It is drawn as uniform mid-gray instead of
    dividing by a zero range (which yields NaNs) or amplifying round-off
    noise to full contrast.
    """
    oriented = np.asarray(atom, dtype=np.float64) * sign
    shifted = oriented - oriented.min()
    span = shifted.max()
    if span <= 1e-8 * np.abs(oriented).max():
        image = np.full_like(shifted, 127.5)
    else:
        image = shifted / span * 255
    return image.reshape(patch_size, patch_size)


for index, correct_index in enumerate(correct_indices):
    for label, dictionary, encoding in PACKAGED:
        patch_coding = encoding[:, correct_index]
        # Atoms ordered by |coefficient|, largest first. The stable sort keeps
        # the original atom order among ties.
        strongest = np.argsort(-np.abs(patch_coding), kind='stable')[:BASES]

        # squeeze=False keeps `axs` a 2-D array even when BASES == 1.
        fig, axs = plt.subplots(1, BASES, figsize=(60, 10), squeeze=False)
        fig.subplots_adjust(wspace=0.05, hspace=0.05)
        # Hide every panel first, including any left empty when there are
        # fewer atoms than BASES.
        for ax in axs.flat:
            ax.axis('off')
        for col, ax in zip(strongest, axs.flat):
            # Draw the atom with the polarity it contributes to this patch.
            # A zero coefficient keeps the atom as-is rather than zeroing it.
            sign = -1.0 if patch_coding[col] < 0 else 1.0
            ax.imshow(atom_to_image(dictionary[:, col], sign), **COLOR_SCALE, cmap='gray')
        fig.savefig(f'02-patch-{label}-bases-{index}.pdf', bbox_inches='tight')
        plt.close(fig)