In [None]:
import mrcfile

import time
import numpy as np
import matplotlib.pyplot as plt

### my library

import utils
import images
import distances
import transforms

In [None]:
# Input maps -- download these from EMDB before running:
#   volume: https://www.ebi.ac.uk/emdb/EMD-11657
#   mask:   https://www.ebi.ac.uk/emdb/EMD-11657?tab=interpretation
vol_file = 'Data/emd_11657.map'
mask_file = 'Data/emd_11657_msk_1.map'

# The `with` block closes each file on exit, so no explicit mrc.close() is needed.
with mrcfile.open(vol_file) as mrc:
    vol = mrc.data
    vox = mrc.voxel_size.x  # x component of the voxel size reported by mrcfile

with mrcfile.open(mask_file) as mrc:
    mask = mrc.data

# Guard against accidentally pairing a map with a mask from a different box.
assert vol.shape == mask.shape, f'volume {vol.shape} and mask {mask.shape} shapes differ'

# Apply the mask, then clamp negative densities to zero.
vol = vol * mask
vol = np.where(vol < 0, 0, vol)

# Crop the same index range on all three axes (211 voxels per side).
# NOTE(review): presumably this trims empty padding around the masked particle -- TODO confirm bounds.
CROP = slice(70, 281)
vol = vol[CROP, CROP, CROP]

In [None]:
# Tilt series: `n_projections` view angles spread evenly over [0, a_max] degrees,
# both endpoints included.
# NOTE(review): with a_max = 45 and n_projections = 45 the step is 45/44 ~= 1.023 deg,
# not 1 deg; use n_projections = a_max + 1 if whole-degree steps were intended.
a_max = 45
n_projections = 45
view_angles = np.linspace(start=0, stop=a_max, num=n_projections, endpoint=True)

# Project the volume at every view angle (axis=0 is forwarded to the helper -- TODO
# confirm its rotation-axis convention), then run the image preprocessing on the stack.
projections = utils.generate_projections(vol, view_angles, axis=0)
imgs = images.Image(projections).preprocess_images()

In [None]:
def show_image(img, figsize=(4, 4)):
    """Display one 2-D image in its own figure with the axes hidden.

    Parameters
    ----------
    img : array-like
        Image data handed directly to ``imshow``.
    figsize : tuple of float, optional
        Figure size in inches; defaults to (4, 4).
    """
    fig, ax = plt.subplots(figsize=figsize)
    ax.imshow(img)
    ax.axis('off')
    # fig.savefig('')
    plt.show()


# First view, last view, and their pixel-wise difference, each in its own figure.
for panel in (imgs[0], imgs[-1], imgs[0] - imgs[-1]):
    show_image(panel)

In [None]:
# Distance from the first image (imgs[0]) to every image in the stack, including itself.
N, ny, nx = imgs.shape
p = ny + 1  # forwarded as n_points to transforms.Transform -- TODO confirm its meaning there
proj_angles = np.linspace(0, 360, ny, endpoint=False)  # ny angles evenly spaced over [0, 360)

# Durations use time.perf_counter(): it is monotonic and high-resolution, whereas
# time.time() follows the wall clock and can jump if the system clock is adjusted.

## Euclidean Distance
t = time.perf_counter()
dists_l2 = np.array([distances.l2_distance(imgs[0], imgs[idx]) for idx in range(N)])
print('  Euclidean distance runtime: ', time.perf_counter() - t)

### sliced 2-Wasserstein (timing includes transforming the whole image stack)
t = time.perf_counter()
U = transforms.Transform(imgs, apply_ramp=False, angles=proj_angles, n_points=p).inverse_cdf_transform()
dists_sw_nrf = np.array([distances.sliced_distance(U[0], U[idx]) for idx in range(N)])
print('sliced 2-Wasserstein runtime: ', time.perf_counter() - t)

# ### Wasserstein 2 (disabled; re-enable together with the commented-out
# ### 2-Wasserstein plot line in the plotting cell)
# t = time.perf_counter()
# M = distances.compute_transport_matrix(imgs[0], metric='sqeuclidean')
# dists_w2 = np.array([distances.wasserstein_distance(imgs[0], imgs[idx], M) for idx in range(N)])
# print('       2-Wasserstein runtime: ', time.perf_counter() - t)

In [None]:
# Set font sizes before creating the figure. Matplotlib reads these rcParams when it
# builds the axis labels, tick labels and legend, so changing them afterwards would not
# affect this figure. plt.rc updates the global rcParams, so later figures in the
# session use the same sizes.
plt.rc('axes', labelsize=11)
plt.rc('xtick', labelsize=8)
plt.rc('ytick', labelsize=8)
plt.rc('legend', fontsize=8)

# One fixed color per distance so curves stay comparable across figures.
c = {'Euclidean': 'tomato',
     '2-Wasserstein': 'forestgreen',
     'sliced 2-Wasserstein': 'dodgerblue'}

fig, ax = plt.subplots(figsize=(8, 5))

ax.plot(view_angles, dists_l2, color=c['Euclidean'], linewidth=3, label='Euclidean')
# ax.plot(view_angles, np.sqrt(dists_w2), color=c['2-Wasserstein'], linewidth=3, label='2-Wasserstein')
ax.plot(view_angles, dists_sw_nrf, color=c['sliced 2-Wasserstein'], linewidth=3, label='sliced 2-Wasserstein')

ax.set_xlabel('out of plane rotation (degrees)')
ax.set_ylabel('distance')
ax.grid(which='major', linestyle='--')
ax.legend(loc='upper left')  # a single legend, pinned to the upper-left corner
# fig.savefig('')
plt.show()