In [None]:
import time
import pickle

import numpy as np
import matplotlib.pyplot as plt

### my library

import plots
import utils
import images
import distances
import transforms

In [None]:
### prepare images

### inputs

L = 65  # side length in pixels; every image stack is zero-padded to L x L
n_imgs = False  # set to a positive int to cap the number of images aligned per digit (falsy keeps all)

### load original mnist test set data

mnist_image_file = 'Data/t10k-images-idx3-ubyte.gz'
mnist_label_file = 'Data/t10k-labels-idx1-ubyte.gz'

mnist_images = utils.load_mnist_images(mnist_image_file)
mnist_labels = utils.load_mnist_labels(mnist_label_file)

digits = [2, 3, 4, 5, 6, 7, 9]  # the non-symmetric digits

# group the test images by label: digit -> stack of that digit's images, zero-padded to L x L
digit_images = {}

for digit in digits:
    label_idx = np.argwhere(mnist_labels == digit).ravel()
    digit_images[digit] = utils.zero_pad_image_stack_to_size(mnist_images[label_idx], L)

# # choose which image of each digit serves as the alignment reference
# (exactly one of the strategies below should be active; the others are kept as alternatives)

# use the image that minimizes the Euclidean distance to the mean image of its digit
reference_digit = {}
for digit, stack in digit_images.items():
    mean_img = np.mean(stack, axis=0)
    l2_to_mean = np.array([np.linalg.norm(mean_img - img) for img in stack])
    reference_digit[digit] = np.argmin(l2_to_mean)  # index into digit_images[digit]

# # use the digit that minimizes the sliced 2-Wasserstein distance to the mean
# reference_digit = {}
# for digit, imgs in digit_images.items():
#     img_mean = np.mean(imgs, axis=0)
#     U = transforms.Transform(img_mean, apply_ramp=False).inverse_cdf_transform()
#     V = transforms.Transform(imgs, apply_ramp=False).inverse_cdf_transform()
#     dists_sw2 = np.array([distances.sliced_distance(U, V[idx]) for idx in range(len(imgs))])
#     reference_digit[digit] = np.argmin(dists_sw2)

# # use the first image for each digit in test set
# reference_digit = {digit: 0 for digit in digits}

# # use a random image for each digit
# np.random.seed(0)
# reference_digit = {digit: np.random.randint(0, imgs.shape[0]) for digit, imgs in digit_images.items()}

# pull each reference image out of its stack so it is never aligned against itself,
# then optionally keep only the first n_imgs of the remaining images
reference_images = {}

for digit, ref_idx in reference_digit.items():
    stack = digit_images[digit]
    reference_images[digit] = stack[ref_idx]
    remaining = np.delete(stack, ref_idx, axis=0)
    digit_images[digit] = remaining[:n_imgs] if n_imgs else remaining

In [None]:
### compute rotational distances for all digits at random rotations and shifts


def _angles_at_min(dists_by_img, angles):
    """Pick, for every entry of `dists_by_img`, the element of `angles` at the argmin of its distances.

    `dists_by_img` is walked with `.items()`; the result follows its iteration order.
    """
    return np.array([angles[np.argmin(dists)] for _, dists in dists_by_img.items()])


t = time.time()

np.random.seed(0)  # fixes the rotations and translations drawn below

shifts = [0, 2, 4, 6]  # pixel shifts
# metrics = ['sliced 2-Wasserstein', 'signed sliced 2-Wasserstein']
# metrics = ['sliced 2-Wasserstein', 'signed sliced 2-Wasserstein', 'Euclidean', 'wavelet EMD']
metrics = ['sliced 2-Wasserstein', 'signed sliced 2-Wasserstein', 'sliced 2-Cramér', 'Euclidean', 'wavelet EMD']

# estimated angles, keyed by (digit, shift) and then by metric name; entries stay [] until computed
alignment_est = {
    (d, s): {m: [] for m in metrics}
    for d in digits
    for s in shifts
}

# one integer rotation in [0, 360) and one pair of -s/+s offsets per image index;
# the same draws are reused for every digit (image idx of each digit gets rotations[idx])
N_max = np.amax([stack.shape[0] for stack in digit_images.values()])
rotations = np.random.randint(0, 360, N_max)
translations = {s: np.random.choice((-s, s), (N_max, 2)) for s in shifts}

for d in digits:

    imgs = digit_images[d]
    # zip stops at len(imgs), so only the first len(imgs) rotations are consumed
    imgs_rot = np.array([utils.rotate(img, theta) for img, theta in zip(imgs, rotations)])

    for s in shifts:

        imgs_rot_shift = np.array([
            utils.translate(img, shift[0], shift[1])
            for img, shift in zip(imgs_rot, translations[s])
        ])

        ### compute various rotational distances
        N, ny, nx = imgs_rot_shift.shape
        proj_angles = np.linspace(0, 360, ny, endpoint=False)  # projection angles
        p = ny + 1  # number of points to sample
        sampling = dict(angles=proj_angles, n_points=p)  # keyword arguments shared by every Transform below

        f = images.Image(reference_images[d]).preprocess_images()  # normalize images
        g = images.Image(imgs_rot_shift).preprocess_images()

        estimates = alignment_est[d, s]

        if 'sliced 2-Wasserstein' in metrics:
            U_I = transforms.Transform(f, apply_ramp=False, **sampling).inverse_cdf_transform()
            V_I = transforms.Transform(g, apply_ramp=False, **sampling).inverse_cdf_transform()
            dists_dict_sw = distances.reference_rotational_distances(U_I[0], V_I)
            estimates['sliced 2-Wasserstein'] = _angles_at_min(dists_dict_sw, proj_angles)

        if 'signed sliced 2-Wasserstein' in metrics:
            U_Ip, U_In = transforms.Transform(f, apply_ramp=True, **sampling).signed_inverse_cdf_transform()
            V_Ip, V_In = transforms.Transform(g, apply_ramp=True, **sampling).signed_inverse_cdf_transform()
            dists_dict_sw_sgn = distances.reference_signed_rotational_distances(U_Ip[0], V_Ip, U_In[0], V_In)
            estimates['signed sliced 2-Wasserstein'] = _angles_at_min(dists_dict_sw_sgn, proj_angles)

        if 'sliced 2-Cramér' in metrics:
            U_C = transforms.Transform(f, apply_ramp=False, **sampling).cdf_transform()
            V_C = transforms.Transform(g, apply_ramp=False, **sampling).cdf_transform()
            dists_dict_sc = distances.reference_rotational_distances(U_C[0], V_C)
            estimates['sliced 2-Cramér'] = _angles_at_min(dists_dict_sc, proj_angles)

        if 'Euclidean' in metrics:
            dists_dict_l2 = {
                idx: distances.real_space_rotational_distances(f[0], g[idx], proj_angles)
                for idx in range(N)
            }
            estimates['Euclidean'] = _angles_at_min(dists_dict_l2, proj_angles)

        if 'wavelet EMD' in metrics:
            dists_dict_wemd = {
                idx: distances.wemd_rotational_distances(f[0], g[idx], proj_angles)
                for idx in range(N)
            }
            estimates['wavelet EMD'] = _angles_at_min(dists_dict_wemd, proj_angles)

        # NOTE(review): `angles` in the disabled block below is not defined in this notebook;
        # presumably `proj_angles` was meant -- confirm before re-enabling.
        # if '2-Wasserstein' in metrics:
        #     M = distances.compute_transport_matrix(imgs[0], metric='sqeuclidean')
        #     dists_dict_w2 = {idx: distances.rotational_wasserstein_distances(f[0], g[idx], proj_angles, M) for idx in range(N)}
        #     angles_est_w2 = np.array([angles[np.argmin(dists)] for idx, dists in dists_dict_w2.items()])
        #     alignment_est[d, s]['2-Wasserstein'] = angles_est_w2

print(time.time() - t)  # elapsed wall-clock seconds for the whole experiment

In [None]:
# save = False
# load = False

# if save:
#     with open("", "wb") as file:
#         pickle.dump(alignment_est, file)

# if load:
#     with open("", "rb") as file:
#         alignment_est = pickle.load(file)

In [None]:
### calculate cumulative percent of digits aligned within n_deg of ground truth

n_deg_tol = 46  # tolerances 0, 1, ..., n_deg_tol - 1 degrees are evaluated

# (digit, shift) -> metric -> array indexed by tolerance in degrees
cumulative_alignment = {
    (d, s): {m: np.zeros(n_deg_tol) for m in metrics}
    for d in digits
    for s in shifts
}

for d in digits:
    # ground-truth rotations of this digit's images (image idx was rotated by rotations[idx])
    rotations_gt = rotations[:len(digit_images[d])]
    for s in shifts:
        for m in metrics:
            curve = cumulative_alignment[(d, s)][m]
            estimated = alignment_est[(d, s)][m]
            # note: the loop variable n_deg stays defined at module level after this loop
            for n_deg in range(n_deg_tol):
                curve[n_deg] = plots.within_n_degrees(rotations_gt, estimated, n_deg) * 100

In [None]:
### plot the results (rows = digits, cols = shifts)
# each panel draws one curve per metric: x = tolerance in degrees (array index),
# y = cumulative_alignment value for that tolerance

nrows = len(digits)
ncols = len(shifts)

# squeeze=False keeps `axs` two-dimensional even when there is a single digit or a single shift,
# so the axs[row, col] indexing below always works
fig, axs = plt.subplots(nrows, ncols, figsize=(12, 16), sharex=True, sharey=True, squeeze=False)

# line color per metric (also read by later cells)
c = {'sliced 2-Wasserstein': 'dodgerblue',
     'signed sliced 2-Wasserstein': 'mediumpurple',
     'sliced 2-Cramér': 'orange',
     'Euclidean': 'tomato',
     'wavelet EMD': 'darkgrey'}

for idx1, d in enumerate(digits):
    for idx2, s in enumerate(shifts):
        ax = axs[idx1, idx2]

        for m in metrics:
            ax.plot(cumulative_alignment[(d, s)][m], linewidth=2, color=c[m], label=m)

        ax.set_ylim(-3, 103)
        ax.set_yticks([0, 50, 100])
        # tick every 10 degrees across the evaluated tolerances; derived from n_deg_tol so the
        # cell does not depend on a loop variable left over from another cell
        ax.set_xticks(np.arange(0, n_deg_tol, step=10))
        ax.tick_params(axis="both", rotation=45, pad=-1.2)
        ax.grid(which='major', linestyle='--')

plt.subplots_adjust(hspace=0.15, wspace=0.07)
# plt.savefig('')
plt.show()

In [None]:
# for d in digits:
#     plt.figure(figsize=(3,3))
#     plt.imshow(reference_images[d], cmap='gray')
#     plt.axis('off')
#     # plt.savefig('Output/MNIST_figures/Reference_digits/{}.png'.format(d), transparent=True)
#     plt.show()

In [None]:
### re-create the rotated + shifted images for one (digit, shift) and undo the estimated rotation

n_align = 30  # number of images requested

d = 2  # digit
s = 6  # pixel shift
m = 'signed sliced 2-Wasserstein'  # metric whose angle estimates are applied

ref = reference_images[d].astype('float32')
imgs = digit_images[d][:n_align].astype('float32')

# the stack may hold fewer than the requested number of images (e.g. when n_imgs is small);
# clamp so that everything below, and later cells reading n_align, stay in range
n_align = len(imgs)

### rotate images
thetas_gt = rotations[:n_align]
imgs_rot = np.array([utils.rotate(img, theta) for img, theta in zip(imgs, thetas_gt)])

### translate images
shifts_gt = translations[s][:n_align]
imgs_rot_shift = np.array([utils.translate(img, shift[0], shift[1]) for img, shift in zip(imgs_rot, shifts_gt)])

### rotate by estimated alignment
thetas_est = alignment_est[(d, s)][m][:n_align]
aligned_imgs = np.array([utils.rotate(img, -theta) for img, theta in zip(imgs_rot_shift, thetas_est)])

In [None]:
# tile the residuals (aligned image minus reference), every tile outlined in the metric's plot color
save = False  # forwarded unchanged as save_path
residuals = aligned_imgs - ref
spine_colors = [c[m]] * n_align
plots.rectangular_tile_plot(
    residuals,
    rows=3,
    cols=5,
    spine_colors=spine_colors,
    spine_size=2,
    size_x=5,
    size_y=3,
    wpad=-0.5,
    hpad=0.5,
    save_path=save,
)