In [None]:
from cil.io import NEXUSDataWriter, NEXUSDataReader

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import AxesGrid
import numpy as np
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar

from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes
from mpl_toolkits.axes_grid1.inset_locator import mark_inset
from mpl_toolkits.axes_grid1.inset_locator import InsetPosition
import matplotlib.patches as patches

import numpy as np

In [None]:
# create a circular mask

def create_circular_mask(h, w, center=None, radius=None):
    """Return a boolean disc mask for an image with ``h`` rows and ``w`` columns.

    Parameters
    ----------
    h, w : int
        Number of rows and columns of the mask.
    center : sequence of two numbers, optional
        Disc centre given as ``(x, y)``, i.e. (column, row). Defaults to the
        middle of the image, ``(int(w/2), int(h/2))``.
    radius : number, optional
        Disc radius in pixels. Defaults to the distance from the centre to
        the nearest image edge.

    Returns
    -------
    numpy.ndarray
        Boolean array of shape ``(h, w)``; ``True`` where the Euclidean
        distance to the centre is less than or equal to ``radius``.
    """
    # default centre: the middle of the image
    if center is None:
        center = (int(w/2), int(h/2))
    cx, cy = center[0], center[1]

    # default radius: the largest disc around the centre that still fits
    if radius is None:
        radius = min(cx, cy, w - cx, h - cy)

    # open grids: rows has shape (h, 1), cols has shape (1, w); they broadcast to (h, w)
    rows, cols = np.ogrid[:h, :w]
    return np.sqrt((cols - cx)**2 + (rows - cy)**2) <= radius

# 2D disc of radius 110 pixels, centred in a 256 x 256 image
tmp_mask = create_circular_mask(256, 256, radius=110)

# Stack 17 copies of the 2D disc along a new leading axis -> shape (17, 256, 256).
# NOTE(review): 17 presumably matches the number of time frames in the
# reconstructions loaded below -- confirm against the data.
mask = np.tile(tmp_mask, (17, 1, 1))

In [None]:
# Load the FBP, Tikhonov, TV and dTV reconstructions for each projection count
# and multiply every loaded dataset by `mask`.

fbp_recons = []
tikhonov_recons = []
tv_recons = []
dtv_recons = []

# (destination list, file-name template) pairs, in the order the files are read
recon_sources = [
    (fbp_recons, "FBP_reconstructions/FBP_projections_{}.nxs"),
    (tikhonov_recons, "Tikhonov_reconstructions/TikhonovReconstruction_projections_{}.nxs"),
    (tv_recons, "TV_reconstructions/TVReconstruction_projections_{}.nxs"),
    (dtv_recons, "dTV_reconstructions/dTVReconstruction_projections_{}.nxs"),
]

# After the loop, index 0..3 of every list corresponds to 18, 36, 72 and 360 projections.
for i in [18, 36, 72, 360]:
    for store, template in recon_sources:
        reader = NEXUSDataReader(file_name = template.format(i))
        store.append(reader.load_data()*mask)
    

In [None]:
# Show the FBP, Tikhonov, TV and dTV reconstructions (one row per method) for
# 18, 36, 72 and 360 projections (one column each) at time-frame index 8.

ind = 8  # zero-based index into the first axis of each reconstruction

n_rows, n_cols = 4, 4

# Flatten method by method (row-major) so that recons[k] matches the k-th axes
# yielded by AxesGrid.
recons = [method_recons[col].as_array()[ind]
          for method_recons in (fbp_recons, tikhonov_recons, tv_recons, dtv_recons)
          for col in range(n_cols)]

labels_x = ["18 projections", "36 projections", "72 projections", "360 projections"]
labels_y = ["FBP", "Tikhonov", "TV", "dTV"]

# set font size of xticks/yticks
plt.rcParams['xtick.labelsize']=15
plt.rcParams['ytick.labelsize']=15

fig = plt.figure(figsize=(20, 20))

grid = AxesGrid(fig, 111,
                nrows_ncols=(n_rows, n_cols),
                axes_pad=0.08,
                cbar_mode='single',
                cbar_location='bottom',
                cbar_size = 0.5,
                cbar_pad=0.1
                )

# sub-region of the image shown in the inset (pixel coordinates); loop-invariant
x1, x2, y1, y2 = 45, 95, 105, 155

for k, ax in enumerate(grid):
    row, col = divmod(k, n_cols)

    im = ax.imshow(recons[k],vmin=0,vmax=0.07, cmap="inferno")

    # Inset axes; the InsetPosition locator places it inside the panel.
    axins = zoomed_inset_axes(ax, 3, loc=2)
    ip = InsetPosition(ax, [1-0.35, 1-0.40, 0.35, 0.35]) #posx, posy, width, height
    axins.set_axes_locator(ip)
    axins.imshow(recons[k], vmin=0, vmax=0.07, interpolation="none", cmap='inferno')

    axins.set_xlim(x1, x2)
    axins.set_ylim(y2, y1)  # y limits reversed (larger value first), as for the main image
    axins.tick_params(axis='both', which='both', left=False, top=False, bottom=False)
    # Hide the inset tick labels explicitly on `axins` rather than through
    # plt.xticks/plt.yticks, which act on whichever axes happens to be current.
    plt.setp(axins.get_xticklabels(), visible=False)
    plt.setp(axins.get_yticklabels(), visible=False)

    # outline the zoomed region on the main image
    rect = patches.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=1, edgecolor='w', facecolor='none')
    ax.add_patch(rect)

    # column titles on the first row, method names on the first column
    if row == 0:
        ax.set_title(labels_x[col],fontsize=25)
    if col == 0:
        ax.set_ylabel(labels_y[row],fontsize=25, labelpad=20)

    # scale bar on the first panel only: 44.13 data units are labelled as 5 mm
    if k == 0:
        scalebar = AnchoredSizeBar(ax.transData,
                                   44.13, '5 mm', 'lower right',
                                   pad=0.5,
                                   color='white',
                                   frameon=False,
                                   size_vertical=5)

        ax.add_artist(scalebar)

    # remove all ticks from the main image
    ax.set_xticks([])
    ax.set_yticks([])

# Single shared colorbar; addressed through the grid instead of the loop
# variable left over from the last iteration.
cbar = grid.cbar_axes[0].colorbar(im)
cbar.ax.set_xlabel('Attenuation', fontsize=25)

plt.show()






In [None]:
# Show, for time-frame indices 0, 5, 10 and 16 (one column each), the FBP
# reconstruction from 360 projections (first row) followed by the FBP,
# Tikhonov, TV and dTV reconstructions from 18 projections (one row each).

# time frames (zero-based indices)
ind1 = 0
ind2 = 5
ind3 = 10
ind4 = 16
time_frames = [ind1, ind2, ind3, ind4]

# Index into the per-method lists; 0 is the 18-projection case.
# The row labels below are hard-coded for proj_case = 0.
proj_case = 0

# One entry per row; fbp_recons[3] is the 360-projection FBP reconstruction.
row_recons = [fbp_recons[3],
              fbp_recons[proj_case],
              tikhonov_recons[proj_case],
              tv_recons[proj_case],
              dtv_recons[proj_case]]

# Flatten row by row so that recons[k] matches the k-th axes yielded by AxesGrid.
recons = [recon.as_array()[t] for recon in row_recons for t in time_frames]

labels_x = ["Time-frame {}".format(t) for t in time_frames]
labels_y = ["FBP \n 360 projections", "FBP \n 18 projections", "Tikhonov \n 18 projections", "TV \n 18 projections", "dTV \n 18 projections"]

# label padding per row; values kept from the original layout
labelpads = [15, 15, 20, 20, 20]

# set font size of xticks/yticks
plt.rcParams['xtick.labelsize']=15
plt.rcParams['ytick.labelsize']=15

fig = plt.figure(figsize=(20, 20))

n_rows, n_cols = 5, 4

grid = AxesGrid(fig, 111,
                nrows_ncols=(n_rows, n_cols),
                axes_pad=0.08,
                cbar_mode='single',
                cbar_location='bottom',
                cbar_size = 0.5,
                cbar_pad=0.1
                )

# sub-region of the image shown in the inset (pixel coordinates); loop-invariant
x1, x2, y1, y2 = 45, 95, 105, 155

# lower-left corner and side length of the small ROI marker drawn on the first panel
roi_x, roi_y, roi_size = 178, 123, 3

for k, ax in enumerate(grid):
    row, col = divmod(k, n_cols)

    im = ax.imshow(recons[k],vmin=0,vmax=0.07, cmap="inferno")

    # Inset axes; the InsetPosition locator places it inside the panel.
    axins = zoomed_inset_axes(ax, 3, loc=2)
    ip = InsetPosition(ax, [1-0.35, 1-0.4, 0.35, 0.35]) #posx, posy, width, height
    axins.set_axes_locator(ip)
    axins.imshow(recons[k], vmin=0, vmax=0.07, interpolation="none", cmap='inferno')

    axins.set_xlim(x1, x2)
    axins.set_ylim(y2, y1)  # y limits reversed (larger value first), as for the main image
    axins.tick_params(axis='both', which='both', left=False, top=False, bottom=False)
    # Hide the inset tick labels explicitly on `axins` rather than through
    # plt.xticks/plt.yticks, which act on whichever axes happens to be current.
    plt.setp(axins.get_xticklabels(), visible=False)
    plt.setp(axins.get_yticklabels(), visible=False)

    # outline the zoomed region on the main image
    rect = patches.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=1, edgecolor='w', facecolor='none')
    ax.add_patch(rect)

    if k == 0:
        # ROI marker: use edgecolor (not color) so that facecolor='none' is
        # honoured; `color` would override it and fill the patch.
        roi_rect = patches.Rectangle((roi_x, roi_y), roi_size, roi_size,
                                     linewidth=2, edgecolor='blue', facecolor='none')
        ax.add_patch(roi_rect)
        # lower-case `fontsize`: mixed-case property names such as `fontSize`
        # are rejected by recent matplotlib releases
        ax.annotate(r"$\mathrm{ROI}$", xy=(100,200), xytext=(100, 120), fontsize=20, color="blue")

        # scale bar on the first panel only: 44.13 data units are labelled as 5 mm
        scalebar = AnchoredSizeBar(ax.transData,
                           44.13, '5 mm', 'lower right',
                           pad=0.5,
                           color='white',
                           frameon=False,
                           size_vertical=5)

        ax.add_artist(scalebar)

    # column titles on the first row, row labels on the first column
    if row == 0:
        ax.set_title(labels_x[col],fontsize=25)
    if col == 0:
        ax.set_ylabel(labels_y[row],fontsize=25, labelpad=labelpads[row])

    # remove all ticks from the main image
    ax.set_xticks([])
    ax.set_yticks([])

# Single shared colorbar; addressed through the grid instead of the loop
# variable left over from the last iteration.
cbar = grid.cbar_axes[0].colorbar(im)
cbar.ax.set_xlabel('Attenuation', fontsize=25)
plt.show()




