## Setup

In [None]:
%matplotlib qt
from tensorflow.keras.datasets import mnist
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import os

In [None]:
# Write every figure into a dedicated output folder. The guard stops a re-run
# of this cell from descending into mnist_distribution/mnist_distribution/...
output_dir = Path('mnist_distribution')
if Path.cwd().name != output_dir.name:
    output_dir.mkdir(exist_ok=True)
    os.chdir(output_dir)

In [None]:
# Load MNIST (labels are discarded) and stack the train and test images into a
# single array, so the statistics below cover every image.
(x_train, _), (x_test, _) = mnist.load_data()
data = np.vstack((x_train, x_test))

## 1 Per-pixel mean and standard deviation

In [None]:
# Per-pixel statistics over all images (axis 0 indexes the images).
# NOTE: despite its name, `var` holds the standard deviation (the square root
# of the variance), which keeps it comparable to the 0-255 pixel scale.
mean = data.mean(axis=0)
var = data.std(axis=0)

In [None]:
# Side-by-side maps of the per-pixel mean and spread. Both panels use the same
# 0-255 grayscale and share a single colorbar.
fig, axs = plt.subplots(1, 2)

ax = axs[0]
ax.imshow(mean, cmap='gray', vmin=0, vmax=255, interpolation='nearest')
ax.axis(False)
ax.set_title('Mean')

# `var` holds the per-pixel standard deviation (square root of the variance),
# so the panel is titled accordingly.
ax = axs[1]
pcm = ax.imshow(var, cmap='gray', vmin=0, vmax=255, interpolation='nearest')
ax.axis(False)
ax.set_title('Standard Deviation')

plt.colorbar(pcm, ax=axs, shrink=0.5)
fig.savefig('mnist_mean_var.pdf', bbox_inches='tight', pad_inches=0)

## 2 Pixel value probability distribution

### 2.1 Plot single pixel distribution

In [None]:
# Pixel to inspect: row 14, column 14 (the centre of a 28x28 image).
px, py = 14, 14
pixels = data[:, px, py]  # that pixel's value in every image

In [None]:
# Empirical distribution of the selected pixel, in percent, over all 256
# possible values. Values that never occur keep probability 0.
values = np.arange(256)
unique, count = np.unique(pixels, return_counts=True)
probs = np.zeros(256)
probs[unique] = 100 * count / data.shape[0]

In [None]:
# Line plot of the single-pixel distribution, saved under the pixel's coordinates.
fig = plt.figure()
ax = fig.add_subplot()
ax.plot(values, probs, linewidth=1)
ax.set_xlabel('Pixel Value')
ax.set_ylabel('Probability (%)')
ax.grid()
fig.savefig('mnist_dist_pixel_%dx%d.pdf' % (px, py), bbox_inches='tight')

### 2.2 Plot the distributions of a single column

In [None]:
def get_column_distribution(data, column_index):
    """Estimate the value distribution of every pixel in one image column.

    Args:
        data: Stack of images indexed as ``data[image, line, column]``. Pixel
            values are used directly as row indices into a 256-row table, so
            they are expected to be integers in ``[0, 255]``.
        column_index: Index of the column to analyse.

    Returns:
        ``(x, y, z)`` where ``x`` is ``arange(n_lines)`` (line indices), ``y``
        is ``arange(256)`` (pixel values) and ``z`` has shape
        ``(256, n_lines)``. ``z[v, line]`` is the fraction of images whose
        pixel at ``(line, column_index)`` equals ``v``.
    """
    column = data[:, :, column_index]
    n_images = column.shape[0]
    n_lines = column.shape[1]
    frequencies = np.zeros((256, n_lines))

    # One histogram per line, i.e. per pixel of the chosen column.
    for line in range(n_lines):
        seen, counts = np.unique(column[:, line], return_counts=True)
        frequencies[seen, line] = counts / n_images

    return np.arange(n_lines), np.arange(256), frequencies

In [None]:
def plot_column_distribution(x, y, z):
    """Draw a 3D contour plot of one column's per-line value distributions.

    Args:
        x: 1-D array of line indices (its length sets the number of contour
            levels).
        y: 1-D array of pixel values.
        z: Probabilities as fractions, shaped ``(len(y), len(x))`` to match
            ``np.meshgrid(x, y)``.

    Returns:
        The newly created figure.
    """
    grid_lines, grid_values = np.meshgrid(x, y)
    percent = z * 100  # fractions -> percent for the z axis

    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')
    ax.view_init(10, 35)
    ax.contour3D(grid_lines, grid_values, percent, len(x), cmap='viridis', zdir='x')
    ax.set_xlabel('Line')
    ax.set_ylabel('Pixel Value')
    ax.set_zlabel('Probability (%)')
    ax.set_zlim((0, 100))
    return fig

In [None]:
for column_index in [0, 12, 15]:
    x, y, z = get_column_distribution(data, column_index)
    fig = plot_column_distribution(x, y, z)
    fig.savefig('mnist_dist_column_%d.pdf' % column_index, bbox_inches='tight', pad_inches=0)

### 2.3 Plot column distributions next to a reference image

In [None]:
def high_light_mnist_column(image, column_index):
    """Return an RGBA copy of a 2-D grayscale image with one column emphasised.

    The gray level is copied into the R, G and B channels. Alpha is 255 on the
    selected column and 50 everywhere else, so the other columns look faded.
    The result has the same dtype as ``image`` and shape ``image.shape + (4,)``.
    """
    rgb = np.stack((image, image, image), axis=-1)
    alpha = np.full_like(image, 50)
    alpha[:, column_index] = 255
    return np.concatenate((rgb, alpha[..., np.newaxis]), axis=2)

In [None]:
def plot_column_distribution_and_highlight(x, y, z, highlight):
    """Plot a column's per-line value distributions beside a reference image.

    The image ``highlight`` appears in the middle of the left column of the
    figure. The 3D contour plot (same styling as ``plot_column_distribution``)
    fills the right half.

    Args:
        x: 1-D array of line indices (its length sets the number of contour
            levels).
        y: 1-D array of pixel values.
        z: Probabilities as fractions, shaped ``(len(y), len(x))`` to match
            ``np.meshgrid(x, y)``.
        highlight: Image passed to ``imshow``.

    Returns:
        The newly created figure.
    """
    n_lines = x.shape[0]
    X, Y = np.meshgrid(x, y)
    Z = 100 * z  # fractions -> percent

    fig = plt.figure(figsize=(10, 10))

    # Cell 3 of a 3x2 grid is the middle of the left column, so the reference
    # image sits vertically centred beside the 3D plot.
    plt.subplot(323)
    # cmap/vmin/vmax only affect single-channel images; RGB(A) input is drawn
    # with its own colours.
    plt.imshow(highlight, cmap='gray', vmin=0, vmax=255, interpolation='nearest')
    plt.axis('off')

    # Cell 2 of a 1x2 grid: the 3D plot spans the whole right half.
    ax = plt.subplot(122, projection='3d')
    ax.view_init(10, 35)
    ax.contour3D(X, Y, Z, n_lines, cmap='viridis', zdir='x')
    ax.set_xlabel('Line')
    ax.set_ylabel('Pixel Value')
    ax.set_zlabel('Probability (%)')
    ax.set_zlim((0, 100))
    return fig

In [None]:
# Export one figure per image column: PDF for print quality, PNG frames for a
# video or GIF. Interactive mode is switched off so a window does not pop up
# for every figure. It is restored afterwards so later cells still display
# their figures.
was_interactive = plt.isinteractive()
plt.ioff()
try:
    image = data[0]
    for column_index in range(image.shape[1]):
        x, y, z = get_column_distribution(data, column_index)
        highlight = high_light_mnist_column(image, column_index)
        fig = plot_column_distribution_and_highlight(x, y, z, highlight)

        # Save as pdf to get the nicest quality
        fig.savefig('mnist_highlight_dist_column_%d.pdf' % column_index, bbox_inches='tight', pad_inches=0)
        # Save as png to convert images to video or gif
        fig.savefig('mnist_highlight_dist_column_%d.png' % column_index, bbox_inches='tight', pad_inches=0, dpi=196)
        plt.close(fig)
finally:
    if was_interactive:
        plt.ion()

## 3 Sampling from pixel distributions

In [None]:
def get_cumulative_distribution(data):
    """Compute the empirical cumulative distribution of every pixel.

    Args:
        data: Stack of images indexed as ``data[image, line, column]``. Pixel
            values are used as indices into a 256-entry table, so they are
            expected to be integers in ``[0, 255]``.

    Returns:
        Array of shape ``(n_lines, n_columns, 256)``. ``dist[i, j, v]`` is the
        fraction of images whose pixel ``(i, j)`` is ``<= v``.
    """
    total, n_lines, n_columns = data.shape
    dist = np.zeros((n_lines, n_columns, 256))

    for i, j in np.ndindex(n_lines, n_columns):
        # Histogram of this pixel across all images, then its running sum.
        histogram = np.zeros(256)
        seen, counts = np.unique(data[:, i, j], return_counts=True)
        histogram[seen] = counts
        dist[i, j] = np.cumsum(histogram) / total
    return dist

In [None]:
def sample_dist(dist, rng=None):
    """Draw one value from a discrete distribution given by its CDF.

    This is inverse-transform sampling: draw ``p ~ U[0, 1)`` and return the
    first index whose cumulative probability is strictly greater than ``p``.

    Args:
        dist: 1-D non-decreasing array of cumulative probabilities, e.g.
            ``get_cumulative_distribution(data)[i, j]``.
        rng: Optional object with a ``uniform()`` method (such as a
            ``numpy.random.Generator``). Defaults to the global ``np.random``
            state, so ``np.random.seed`` keeps controlling the result.

    Returns:
        The sampled index, i.e. the pixel value.
    """
    p = np.random.uniform() if rng is None else rng.uniform()
    # side='right' never selects a zero-probability value. With the default
    # side='left', a draw of exactly p == 0.0 would return index 0 even when
    # value 0 never occurs (dist[0] == 0).
    return np.searchsorted(dist, p, side='right')

In [None]:
# Per-pixel CDFs over the whole dataset: dist[i, j] is the CDF of pixel (i, j).
dist = get_cumulative_distribution(data=data)

In [None]:
SEED = 279923  # https://youtu.be/nWSFlqBOgl8?t=86  -  I love this song
np.random.seed(SEED)

# Build three synthetic 28x28 images. Each pixel is sampled independently from
# its own empirical distribution. Pixels are visited in row-major order, so the
# seeded random stream is consumed in a fixed order.
images = np.zeros((3, 28, 28))
for img in images:
    for i, j in np.ndindex(img.shape):
        img[i, j] = sample_dist(dist[i, j])

In [None]:
# Show the sampled images side by side (one row, three panels) and save them.
fig = plt.figure()
for position, sample in enumerate(images, start=1):
    plt.subplot(1, 3, position)
    plt.imshow(sample, cmap='gray', vmin=0, vmax=255, interpolation='nearest')
    plt.axis(False)
fig.savefig('mnist_simple_samples.pdf', bbox_inches='tight', pad_inches=0)