In [2]:
import astra
import os
import numpy as np
import matplotlib.pyplot as plt
from scipy.io import loadmat
import scipy
import matplotlib.pyplot as plt
import plotly.graph_objects as go

def astra_reconstruction(fn_path, angles_fn_path, save_path, num_iter=50, algorithm="SIRT3D_CUDA", blur=False, visualize=False):
    """Reconstruct a 3D volume from a tilt series with ASTRA and save it to disk.

    Parameters
    ----------
    fn_path : str
        Path to a .mat file whose ``"stackNorm"`` variable holds the projection
        stack. The code reads axis 0 as detector rows, axis 1 as detector
        columns and axis 2 as the tilt-angle index.
    angles_fn_path : str
        Path to a .mat file whose ``"tiltdegrees"`` variable holds the tilt
        angles in degrees (any shape; it is flattened).
    save_path : str
        Destination passed to ``np.save``. Its parent directory is created if
        it does not exist yet.
    num_iter : int, optional
        Number of iterations handed to ``astra.algorithm.run``.
    algorithm : str, optional
        ASTRA algorithm name handed to ``astra.astra_dict``.
    blur : bool, optional
        If truthy, apply a Gaussian filter (sigma 0.5) along the first two
        axes of the stack before reconstructing; the third axis is untouched.
    visualize : bool, optional
        If truthy, display the middle slice (along axis 0) of the result.

    Raises
    ------
    ValueError
        If the stack is not 3D or the number of tilt angles does not match the
        size of the stack's third axis.
    """
    # load files and convert degrees to radians
    proj = loadmat(fn_path)["stackNorm"]
    tilt_series = np.deg2rad(loadmat(angles_fn_path)['tiltdegrees'].ravel())

    # Fail early, before touching the GPU, with a readable message.
    if proj.ndim != 3:
        raise ValueError(
            f"expected a 3D projection stack, got an array of shape {proj.shape}"
        )

    # set the shape of the projections
    detector_rows = int(proj.shape[0])
    detector_cols = int(proj.shape[1])
    n_angles = int(proj.shape[2])

    if tilt_series.size != n_angles:
        raise ValueError(
            f"projection stack has {n_angles} projections along axis 2 "
            f"but {tilt_series.size} tilt angles were loaded"
        )

    # apply gaussian blur within each projection image only (sigma 0 along the angle axis)
    if blur:
        # Import the submodule explicitly: a bare ``import scipy`` does not
        # expose ``scipy.ndimage`` on every SciPy release.
        from scipy import ndimage
        proj = ndimage.gaussian_filter(proj, sigma=(0.5, 0.5, 0))

    # Move the angle axis to the middle, as the sinogram data object expects.
    # In terms of the input axes the result is ordered (cols, angles, rows).
    # NOTE(review): the geometry below is declared as (detector_rows, detector_cols)
    # taken from the *input* axes, so data and geometry only agree for square
    # projections - confirm whether swapping the two detector axes is intended.
    proj_for_astra = np.transpose(proj, (1, 2, 0))

    # set reconstruction volume: a cube with detector_rows voxels per side
    vol_geom = astra.create_vol_geom(detector_rows, detector_rows, detector_rows)

    detector_spacing_x, detector_spacing_y = 1.0, 1.0  # affects pixel size

    # set the projection geometry and tilt series
    proj_geom = astra.create_proj_geom('parallel3d',
                                       detector_spacing_x,
                                       detector_spacing_y,
                                       detector_rows,
                                       detector_cols,
                                       tilt_series
                                       )

    # Make sure the destination exists before spending time on the reconstruction.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # Track every ASTRA handle so the finally-clause can release whatever was
    # actually created, even when a later step raises.
    sino_id = rec_id = alg_id = None
    try:
        # sinogram data object
        sino_id = astra.data3d.create('-sino', proj_geom, proj_for_astra.astype(np.float32))

        # reconstruction data object
        rec_id = astra.data3d.create('-vol', vol_geom)

        # set the reconstruction algorithm and what to use
        cfg = astra.astra_dict(algorithm)
        cfg['ReconstructionDataId'] = rec_id
        cfg['ProjectionDataId'] = sino_id

        alg_id = astra.algorithm.create(cfg)  # create algorithm object
        astra.algorithm.run(alg_id, num_iter)  # number of iterations

        reconstruction = astra.data3d.get(rec_id)  # get the resulting reconstruction
    finally:
        # clean up after, in the same order as before: algorithm, volume, sinogram
        if alg_id is not None:
            astra.algorithm.delete(alg_id)
        if rec_id is not None:
            astra.data3d.delete(rec_id)
        if sino_id is not None:
            astra.data3d.delete(sino_id)

    np.save(save_path, reconstruction)

    if visualize:
        # show the middle slice of the reconstruction along axis 0
        plt.imshow(reconstruction[reconstruction.shape[0] // 2, :, :])
        plt.show()
    # try realignment 
    # align all images using cross correlation py4dstem utils
    #def align_and_shift_images(

In [3]:
# Input files: the projection stack and the matching tilt angles.
fn_path = "images_-2.2deg.mat"
angles_fn_path = "tiltangles.mat"

# expandvars leaves "$PSCRATCH" in the path verbatim when the variable is
# unset, which would silently point the outputs at a bogus relative directory.
if "PSCRATCH" not in os.environ:
    raise RuntimeError("PSCRATCH is not set; cannot resolve the output directory")

# Resolve and create the output directory once, so np.save does not fail
# with FileNotFoundError after a reconstruction has already been computed.
output_dir = os.path.expandvars("$PSCRATCH/astra_test_iterations/Reconstructions_CGLS")
os.makedirs(output_dir, exist_ok=True)

# Sweep the CGLS iteration count (5, 10, 15), writing one .npy file per run.
for num_iter in range(5, 20, 5):
    save_path = os.path.join(output_dir, f"astra_recon5_{num_iter}iter.npy")
    print(f"Running reconstruction for {num_iter} iterations")
    astra_reconstruction(fn_path, angles_fn_path, save_path, num_iter=num_iter, algorithm="CGLS3D_CUDA", blur=False)


Running reconstruction for 5 iterations
Running reconstruction for 10 iterations
Running reconstruction for 15 iterations
