In [None]:
import numpy as np
import matplotlib.pyplot as plt

### my library

import utils
import images
import distances
import transforms

In [None]:
###

# distance over translations

###

In [None]:
### show stability of sliced 2-Wasserstein to translations in image using simple Gaussian example

L = 65
sigma = 3

# start from a centered Gaussian and move it off-center with utils.translate
img = utils.generate_centered_gaussian(L=L, d=2, sigma=sigma)
img = utils.translate(img, 0, 5)

n_shifts = 20  # largest shift, in pixels
shifts = np.arange(n_shifts+1)  # 0, 1, ..., n_shifts

# one (L, L) slot per shift; slot 0 is produced with a shift of 0 and is used
# as the reference image by the cells below
imgs = np.zeros((len(shifts), L, L))
for idx, shift in enumerate(shifts):
    # only the last argument varies (a shift in the x-direction, per the author's note)
    imgs[idx] = utils.translate(img, 0, shift)

# wrap the stack in images.Image and keep the preprocessed result under the same name
imgs = images.Image(imgs, normalize=True, mask=True).preprocess_images()

In [None]:
### compute various distances between the reference image imgs[0] and every shifted image

N, ny, nx = imgs.shape
p = ny + 1  # passed to the transforms as n_points
proj_angles = np.linspace(0, 360, ny, endpoint=False)  # ny angles in [0, 360), 360 excluded

# keyword arguments shared by every Transform built in this cell
slice_kwargs = dict(angles=proj_angles, n_points=p)
img_ids = range(N)

# ### sliced 2-Cramér distance (disabled)
# U_C = transforms.Transform(imgs, apply_ramp=False, angles=proj_angles, n_points=p).cdf_transform()
# dists_sc = np.array([distances.sliced_distance(U_C[0], U_C[idx]) for idx in range(N)])

### sliced 2-Wasserstein distance
U_I = transforms.Transform(imgs, apply_ramp=False, **slice_kwargs).inverse_cdf_transform()
dists_sw = np.array([distances.sliced_distance(U_I[0], U_I[k]) for k in img_ids])

### signed sliced 2-Wasserstein distance
# this transform is built with apply_ramp=True and returns two arrays, both passed to the distance
U_Ip, U_In = transforms.Transform(imgs, apply_ramp=True, **slice_kwargs).signed_inverse_cdf_transform()
dists_sw_sgn = np.array([
    distances.signed_sliced_distance(U_Ip[0], U_Ip[k], U_In[0], U_In[k])
    for k in img_ids
])

### Euclidean distance
dists_l2 = np.array([distances.l2_distance(imgs[0], imgs[k]) for k in img_ids])

### 2-Wasserstein
# the transport matrix is computed once from the reference image and reused for every pair
M = distances.compute_transport_matrix(imgs[0])
dists_w2 = np.array([distances.wasserstein_distance(imgs[0], imgs[k], M) for k in img_ids])

# ### wavelet EMD (disabled)
# dists_wemd = np.array([distances.wemd_distance(imgs[0], imgs[idx]) for idx in range(N)])

In [None]:
### show the first image, the last image, and their difference, one figure each
### dashed white lines are drawn at row/column L//2

for panel in (imgs[0], imgs[-1], imgs[0] - imgs[-1]):
    plt.figure(figsize=(4,4))
    plt.imshow(panel)
    for draw_line in (plt.axvline, plt.axhline):
        draw_line(L//2, color='white', linestyle='--', linewidth=2)
    plt.axis('off')
    # plt.savefig('')  # placeholder; give each panel its own path if enabled
    plt.show()

In [None]:
### plot each distance to the reference image as a function of the shift

# Font sizes are set first: rc settings only affect artists created after they
# are set, so setting them once the axes and labels exist would leave this
# figure partly on the previous sizes (and make a re-run look different).
plt.rc('axes', labelsize=11)
plt.rc('xtick', labelsize=8)
plt.rc('ytick', labelsize=8)
plt.rc('legend', fontsize=8)

plt.figure(figsize=(8,4))

# one color per distance; `c` is also read by the rotation plot later in the notebook
c = {'sliced 2-Wasserstein': 'dodgerblue',
     'signed sliced 2-Wasserstein': 'mediumpurple',
     'sliced 2-Cramér': 'orange',
     'Euclidean': 'tomato',
     'wavelet EMD': 'darkgrey',
     '2-Wasserstein': 'forestgreen'}

plt.plot(shifts, dists_l2, color=c['Euclidean'], linestyle='-', linewidth=3, label='Euclidean')
# NOTE(review): only the 2-Wasserstein curve is square-rooted here — presumably
# distances.wasserstein_distance returns a squared distance; confirm.
plt.plot(shifts, np.sqrt(dists_w2), color=c['2-Wasserstein'], linestyle='-', linewidth=3, label='2-Wasserstein')
plt.plot(shifts, dists_sw, color=c['sliced 2-Wasserstein'], linestyle='-', linewidth=3, label='sliced 2-Wasserstein')
plt.plot(shifts, dists_sw_sgn, color=c['signed sliced 2-Wasserstein'], linestyle='-', linewidth=3, label='signed sliced 2-Wasserstein')

plt.xticks(shifts)  # one tick per shift value
plt.grid(which='major', linestyle='--')
plt.ylabel('distance')
plt.xlabel('shift (pixels)')
plt.title('distance over translations')
plt.legend(loc='upper left')
# plt.savefig('')
plt.show()

In [None]:
###

# distance over rotations

###

In [None]:
### show stability of sliced 2-Wasserstein to rotations using simple Gaussian example

L = 65
sigma = 3
img = utils.generate_centered_gaussian(L=L, d=2, sigma=sigma)

# f: the Gaussian moved off-center via utils.translate(img, 0, 5)
f = utils.translate(img, 0, 5)

# g: the Gaussian translated with a larger offset (20), then passed through utils.rotate with 180
g_translated = utils.translate(img, 0, 20)
g = utils.rotate(g_translated, 180)

In [None]:
### show f, g, and their difference, one figure each
### dashed white lines are drawn at row/column L//2

for panel in (f, g, f - g):
    plt.figure(figsize=(4,4))
    plt.imshow(panel)
    for draw_line in (plt.axvline, plt.axhline):
        draw_line(L//2, color='white', linestyle='--', linewidth=2)
    plt.axis('off')
    # plt.savefig('')  # placeholder; give each panel its own path if enabled
    plt.show()

In [None]:
### compute various distances between f and g for each angle in proj_angles

ny, nx = f.shape
p = ny + 1  # passed to the transforms as n_points
proj_angles = np.linspace(0, 360, ny, endpoint=False)  # ny angles in [0, 360), 360 excluded

# keyword arguments shared by every Transform built in this cell
slice_kwargs = dict(angles=proj_angles, n_points=p)

### sliced 2-Wasserstein distance
# only entry [0] of each transform output is compared
U_I = transforms.Transform(f, apply_ramp=False, **slice_kwargs).inverse_cdf_transform()
V_I = transforms.Transform(g, apply_ramp=False, **slice_kwargs).inverse_cdf_transform()
dists_sw = distances.rotational_distances(U_I[0], V_I[0])

### signed sliced 2-Wasserstein distance
# these transforms are built with apply_ramp=True and each returns two arrays
U_Ip, U_In = transforms.Transform(f, apply_ramp=True, **slice_kwargs).signed_inverse_cdf_transform()
V_Ip, V_In = transforms.Transform(g, apply_ramp=True, **slice_kwargs).signed_inverse_cdf_transform()
dists_sw_sgn = distances.signed_rotational_distances(U_Ip[0], V_Ip[0], U_In[0], V_In[0])

### Euclidean distance
dists_l2 = distances.real_space_rotational_distances(f, g, proj_angles)

### 2-Wasserstein
# the transport matrix is computed from f and handed to the rotational distance routine
M = distances.compute_transport_matrix(f)
dists_w2 = distances.rotational_wasserstein_distances(f, g, M, proj_angles)

In [None]:
### plot each distance, normalized by its own maximum, as a function of the rotation angle

# Font sizes are set first: rc settings only affect artists created after they
# are set, so this figure should not depend on rc state left by an earlier cell.
plt.rc('axes', labelsize=11)
plt.rc('xtick', labelsize=8)
plt.rc('ytick', labelsize=8)
plt.rc('legend', fontsize=8)

plt.figure(figsize=(8,4))

# (values, label) pairs in drawing order; each label is also the key into the
# color dict `c` defined in the translation plot cell
curves = [(dists_l2, 'Euclidean'),
          (dists_w2, '2-Wasserstein'),
          (dists_sw, 'sliced 2-Wasserstein'),
          (dists_sw_sgn, 'signed sliced 2-Wasserstein')]

for values, name in curves:
    # dividing by the curve's own maximum puts all four curves on a common scale
    plt.plot(proj_angles, values / np.amax(values), color=c[name], linewidth=3, label=name)

plt.grid(which='major', linestyle='--')
plt.ylabel('normalized distance')
plt.xlabel('rotation (degrees)')
plt.title('distance over rotations')
plt.legend(loc='lower left')
plt.xticks(np.arange(0, 361, 20))  # a tick every 20 degrees, 0 through 360
plt.ylim(-0.1, 1.1)  # small margin around the normalized maximum of 1
# plt.savefig('')
plt.show()