# STEM-EELS tomography of Ge-rich GST material. 
### Example of 3D reconstruction.
### Note: the tilt axis should be horizontal.

In [None]:
%matplotlib inline

import os

import numpy as np
import matplotlib.pyplot as plt
from skimage.io import imread
from skimage.transform import rotate, radon, iradon, iradon_sart
from skimage.external.tifffile import imsave

from modopt.opt.proximity import SparseThreshold
from modopt.opt.cost import costObj

from pyetomo import utils, linear, gradient, fourier, reconstruct 

In [None]:
# Read the aligned projection stack and build the list of angles:
cartos_ali = imread('Data/EELS_GST_1_5_ali.tif')
# 18 evenly spaced values over [0, 180), end point excluded.
# NOTE(review): presumably one angle per projection in the stack - confirm
# that the stack really holds 18 images.
theta = np.linspace(0, 180, 18, endpoint=False)

# In the current version, images have to be square for the 3D reconstruction,
# so every projection is copied into the top-left corner of a zero-filled
# square canvas whose side is the larger of the two image dimensions.
nb_proj, Nx, Ny = cartos_ali.shape
dim_sq = max(Nx, Ny)
cartos_ali_sq = np.zeros((nb_proj, dim_sq, dim_sq))
cartos_ali_sq[:, :Nx, :Ny] = cartos_ali

In [None]:
# Step 1: Define 3D NUFFT sampling:

# Data handed to the gradient operator below, derived from the padded stack.
kspace_lib = utils.generate_kspace_etomo_3D(cartos_ali_sq)

# Locations produced by the helper are multiplied by 2*pi before being
# passed to the NUFFT operator.
unscaled_locations = utils.generate_locations_etomo_3D(dim_sq, dim_sq, theta)
samples = 2 * np.pi * unscaled_locations

# The reconstructed volume is a cube with side dim_sq.
volume_shape = [dim_sq] * 3
fourier_op = fourier.NUFFT3(samples=samples, shape=volume_shape)

# Step 2: Define the gradient calculation:

gradient_op = gradient.GradAnalysis(data=kspace_lib, fourier_op=fourier_op)

# Step 3: Define the sparsity operator linear_op, the threshold proximity operation prox_op and the cost function cost_op:


def _soft_prox_and_cost(sparsity_op, grad_op):
    """Build the proximity operator and cost object for one sparsity operator.

    Parameters
    ----------
    sparsity_op : the linear (sparsifying) operator to threshold with.
    grad_op : the gradient operator paired with the proximity operator in
        the cost object.

    Returns
    -------
    tuple
        ``(prox, cost)`` where ``prox`` is a soft ``SparseThreshold`` built
        with weight 1 and ``cost`` is the non-verbose ``costObj`` over
        ``(grad_op, prox)``.
    """
    prox = SparseThreshold(sparsity_op, 1, thresh_type="soft")
    cost = costObj((grad_op, prox), verbose=False)
    return prox, cost


# Bior4.4 undecimated:
linear_op_bior4_undecimated = linear.pyWavelet(
    'bior4.4', nb_scale=2, undecimated=True)
prox_op_bior4_undecimated, cost_op_bior4_undecimated = _soft_prox_and_cost(
    linear_op_bior4_undecimated, gradient_op)

# Bior4.4 decimated:
linear_op_bior4_decimated = linear.pyWavelet(
    'bior4.4', nb_scale=2, undecimated=False)
prox_op_bior4_decimated, cost_op_bior4_decimated = _soft_prox_and_cost(
    linear_op_bior4_decimated, gradient_op)

# Haar:
linear_op_haar = linear.pyWavelet('haar', nb_scale=3)
prox_op_haar, cost_op_haar = _soft_prox_and_cost(linear_op_haar, gradient_op)

# Both finite-difference operators below share the same geometry, taken from
# the padded projection stack: per-image shape and number of slices.
img_shape = cartos_ali_sq.shape[1:3]
nb_slices = cartos_ali_sq.shape[1]

#TV:
linear_op_tv = linear.HOTV_3D(img_shape, nb_slices, order=1)
prox_op_tv, cost_op_tv = _soft_prox_and_cost(linear_op_tv, gradient_op)

#HOTV, order = 3:
linear_op_hotv3 = linear.HOTV_3D(img_shape, nb_slices, order=3)
prox_op_hotv3, cost_op_hotv3 = _soft_prox_and_cost(linear_op_hotv3, gradient_op)

# Step 4: Run the Condat-Vu sparse reconstruction:

# Collects one magnitude volume per regularizer, in the order of sparse_op.
rec_ = []

# Parallel lists: entry i of each list belongs to the regularizer sparse_op[i].
sparse_op = ['TV', 'HOTV_3', 'Haar', 'Bior4.4_Undecimated', 'Bior4.4_Decimated']
regul_param = [0.0003, 0.0003, 0.0007, 0.0007, 0.0007]
linear_op = [linear_op_tv, linear_op_hotv3, linear_op_haar, linear_op_bior4_undecimated, linear_op_bior4_decimated]
prox_op = [prox_op_tv, prox_op_hotv3, prox_op_haar, prox_op_bior4_undecimated, prox_op_bior4_decimated]
cost_op = [cost_op_tv, cost_op_hotv3, cost_op_haar, cost_op_bior4_undecimated, cost_op_bior4_decimated]

# The string below is a disabled snippet (kept as an inert literal) that
# sweeps several mu values for the first regularizer.
"""
# For selecting the regularization parameter:

mu=[100,200,300,400,500]

for k,m in enumerate(mu):
    reconstruction, wt_coeff, costs, metrics = reconstruct.sparse_rec_condatvu(
        gradient_op,
        linear_op[0],
        prox_op[0],
        cost_op[0],
        mu=m,
        max_nb_of_iter=300,
        nb_of_reweights=2,
        add_positivity=True,
        verbose=0)
    rec_.append(np.abs(reconstruction))
"""

# One reconstruction per regularizer, each with its own operators and mu.
per_operator_settings = zip(linear_op, prox_op, cost_op, regul_param)
for lin_op_k, prox_op_k, cost_op_k, mu_k in per_operator_settings:
    reconstruction, coeff, costs, metrics = reconstruct.sparse_rec_condatvu(
        gradient_op,
        lin_op_k,
        prox_op_k,
        cost_op_k,
        mu=mu_k,
        max_nb_of_iter=300,
        nb_of_reweights=1,
        add_positivity=True,
        verbose=0,
    )
    # Only the magnitude of the returned volume is kept.
    rec_.append(np.abs(reconstruction))

## For saving the 3D reconstructions:
#for k,m in enumerate(sparse_op):
#    imsave('EELS_GST_1_5_ali_'+ '_' + str(sparse_op[k]) + '_mu_'+ str(regul_param[k])+ '.tif', np.asarray(np.abs(rec_[k])).astype('float32'))
    

In [None]:
# Plot of the reconstructions:
# One panel per regularizer, each showing the same slice (fixed index along
# the last axis) of the corresponding volume in rec_, titled with the
# matching name from sparse_op.

fig, ((ax0, ax1, ax2), (ax3, ax4, ax5)) = plt.subplots(2, 3, figsize=(15, 10))

# Index of the displayed slice along the last axis of each volume.
slice_index = 44

for ax, volume, label in zip((ax0, ax1, ax2, ax3, ax4), rec_, sparse_op):
    ax.imshow(volume[:, :, slice_index], cmap=plt.cm.gray)
    ax.axis('off')
    ax.set_title(str(label), fontsize=20)

# The 2x3 grid has six cells but there are only five reconstructions; hide
# the spare axes so it is not drawn as an empty framed plot.
ax5.axis('off')