In [None]:
# -*- coding: utf-8 -*-
"""
Example script to serve as starting point for estimating scatter and evaluating scatter results

This notebook reads in simulated data, runs the scatter estimation via Python (equivalent of the
shell script `run_scatter_0.sh`) and compares the results
with the truth (i.e. simulation input and simulation scatter output).

With this notebook you get an idea
of what the scatter looks like and how accurate the scatter estimation is
in the ideal case (i.e. where the model exactly matches the actual scatter generation).

Prerequisite:
You should have executed the following on your command prompt
    ./run_simulations_thorax.sh

Author: Kris Thielemans
"""

In [None]:
# interactive matplotlib figures (this backend requires the ipympl package)
%matplotlib widget

# Initial imports

In [None]:
import matplotlib.pyplot as plt
import stir
from stirextra import *
import os

# go to directory with input/output files

In [None]:
# adapt this path to your situation (or start everything in the exercises directory)
# If STIR_exercises_PATH is not set, stay in the current directory instead of
# failing with an obscure TypeError from `os.chdir(None)`.
exercises_path = os.getenv('STIR_exercises_PATH')
if exercises_path:
    os.chdir(exercises_path)

In [None]:
# All input/output files of this exercise live in the GATE1 working sub-folder.
gate1_dir = 'working_folder/GATE1'
os.chdir(gate1_dir)

# read in data from GATE1

In [None]:
# Simulated prompts plus everything needed to model them (attenuation, norm, ACFs, randoms).
read_proj_data = stir.ProjData.read_from_file
prompts = read_proj_data('my_prompts_g1.hs')
atten_image = stir.FloatVoxelsOnCartesianGrid.read_from_file('CTAC_g1.hv')
norm_proj_data = read_proj_data('my_norm_g1.hs')
norm = stir.BinNormalisationFromProjData(norm_proj_data)
acf_factors = read_proj_data('my_acfs_g1.hs')
randoms = read_proj_data('my_randoms_g1.hs')

# Perform scatter estimation

In [None]:
# Scatter estimator object; it is configured in the cells below and run with `process_data()`.
scatter_estimator = stir.ScatterEstimation()

In [None]:
# Reconstruction used by the scatter estimator (passed in via set_reconstruction_method_sptr).
# The filter is named `post_filter` rather than `filter` so the Python builtin is not shadowed.
post_filter = stir.SeparableGaussianImageFilter3DFloat()
post_filter.set_fwhms(stir.make_FloatCoordinate(15, 15, 15))  # FWHM per axis (units presumably mm - confirm)

recon = stir.OSMAPOSLReconstruction3DFloat()
recon.set_num_subsets(4)
recon.set_num_subiterations(7)
recon.set_disable_output(True)  # presumably suppresses writing of intermediate images
recon.set_post_processor_sptr(post_filter)
# No input data is set on this objective function here; presumably the scatter estimator fills it in.
objfunc = stir.PoissonLogLikelihoodWithLinearModelForMeanAndProjData3DFloat()
recon.set_objective_function(objfunc)

In [None]:
# Configure the scatter estimator with the data read above and the reconstruction `recon`.
scatter_estimator.set_input_data(prompts)  # measured (simulated) prompts
scatter_estimator.set_attenuation_image_sptr(atten_image)  # CTAC attenuation image
scatter_estimator.set_background_proj_data_sptr(randoms)  # randoms as background term
scatter_estimator.set_normalisation_sptr(norm)
scatter_estimator.set_reconstruction_method_sptr(recon)  # presumably used for intermediate activity images
scatter_estimator.set_attenuation_correction_proj_data_sptr(acf_factors)
scatter_estimator.set_output_scatter_estimate_prefix('p_scatter')  # presumably a filename prefix for written estimates
scatter_estimator.set_num_iterations(3)  # number of scatter-estimation iterations
scatter_estimator.set_up()

In [None]:
# Run the scatter estimation; the result is retrieved with `get_output()` in the next cell.
scatter_estimator.process_data()

In [None]:
# Keep the final estimate in memory and also save it to disk (it is re-read from file further down).
scatter_estimate_filename = 'scatter_estimate_run0.hs'
estimated_scatter = scatter_estimator.get_output()
estimated_scatter.write_to_file(scatter_estimate_filename)

# read in data from simulation and compare

In [None]:
# Truth: the scatter sinogram produced by the simulation, as a numpy array
org_scatter = stir.ProjData.read_from_file('my_scatter_g1.hs')
org_scatter_arr = to_numpy(org_scatter)
# Estimate: read back from the file written after process_data()
estimated_scatter_arr = to_numpy(
    stir.ProjData.read_from_file('scatter_estimate_run0.hs')
)

In [None]:
# Simulated and estimated scatter side by side, for index [0, 10] of the first two array axes,
# on a common colour scale so the two panels can be compared directly.
maxforplot = org_scatter_arr.max() * 1.1

plt.figure()
for panel, (sino, title) in enumerate(
        [(org_scatter_arr, 'Original simulated scatter'),
         (estimated_scatter_arr, 'estimated scatter')], start=1):
    ax = plt.subplot(1, 2, panel)
    plt.imshow(sino[0, 10, :, :])
    plt.clim(0, maxforplot)
    ax.set_title(title)
    plt.axis('off')

# Display profiles through the sinogram

In [None]:
# central (over views): profile along the third array axis through the middle of the last axis.
# The centre index is derived from the data shape instead of hard-coding 192//2.
central_bin = org_scatter_arr.shape[-1] // 2
plt.figure()
plt.plot(org_scatter_arr[0, 10, :, central_bin], 'b', label='simulated')
plt.plot(estimated_scatter_arr[0, 10, :, central_bin], 'c', label='estimated')
plt.legend();

In [None]:
# horizontal (one view): profile along the last array axis for view index 1
view_index = 1
plt.figure()
for sino, colour, curve_label in ((org_scatter_arr, 'b', 'simulated'),
                                  (estimated_scatter_arr, 'c', 'estimated')):
    plt.plot(sino[0, 10, view_index, :], colour, label=curve_label)
plt.legend();

The above plot seems to indicate that (at least in STIR 6.2) the default template used for the scatter estimation is too narrow for this scanner. This could be corrected by passing an explicit (wider) template.

# Reconstruct images with estimated and original scatter

Construct a "normalisation" object that also contains the ACFs. Currently, `ChainedBinNormalisation` is not yet available in STIR Python, so we need to do this manually.

In [None]:
# Build a single BinNormalisation combining detector normalisation and ACFs:
# start from an in-memory version of the ACF sinogram and apply `norm` to it in place.
full_norm_projdata =  stir.ProjDataInMemory(acf_factors)
norm.apply(full_norm_projdata)
full_norm = stir.BinNormalisationFromProjData(full_norm_projdata)
full_norm.set_up(prompts.get_exam_info(), prompts.get_proj_data_info())

Construct the additive sinogram (randoms + estimated scatter).

In [None]:
# Additive term for the reconstruction: randoms + estimated scatter, with `full_norm` applied.
background = stir.ProjDataInMemory(randoms) + stir.ProjDataInMemory(estimated_scatter)
# apply on a separate ProjDataInMemory so `background` itself is presumably left un-normalised
additive_term = stir.ProjDataInMemory(background)
full_norm.apply(additive_term)

In [None]:
# OSEM reconstruction of the prompts, using the additive term with the *estimated* scatter.
osem = stir.OSMAPOSLReconstruction3DFloat()
osem.set_num_subsets(8)
osem.set_num_subiterations(96)  # 96 sub-iterations with 8 subsets = 12 full iterations
osem.set_disable_output(True)
# no post-filter here (unlike the reconstruction used for the scatter estimation)
# a new objective function object; `recon` keeps the one it was given earlier
objfunc = stir.PoissonLogLikelihoodWithLinearModelForMeanAndProjData3DFloat()
osem.set_objective_function(objfunc)
objfunc.set_input_data(prompts)
objfunc.set_normalisation_sptr(full_norm)  # detector normalisation combined with the ACFs
objfunc.set_additive_proj_data_sptr(additive_term)

In [None]:
# Uniform (all ones) starting image, with geometry derived from the prompts' exam/projection info.
initial_image = stir.FloatVoxelsOnCartesianGrid(prompts.get_exam_info(), prompts.get_proj_data_info())
initial_image.fill(1)

In [None]:
# Set up the reconstruction with `initial_image` (presumably used as template and starting estimate).
osem.set_up(initial_image)

In [None]:
# Run the OSEM sub-iterations configured on `osem` above.
osem.reconstruct()

In [None]:
# Reconstructed image using the estimated scatter.
osem_image = osem.get_target_image()

In [None]:
# Reconstructions written to disk outside this notebook (file names suggest 96 sub-iterations,
# with actual vs. estimated scatter).
# NOTE(review): `org_image` and `result` are not used further in this notebook (`org_image_arr`
# below is recomputed from an in-notebook reconstruction); only `recon_image` is used later.
org_image=to_numpy(stir.FloatVoxelsOnCartesianGrid.read_from_file('OSEM_recon_with_actual_scatter_96.hv'))
recon_image = stir.FloatVoxelsOnCartesianGrid.read_from_file('OSEM_recon_with_estimated_scatter_96_run0.hv')
result=to_numpy(recon_image)

In [None]:
# Take a numpy copy now: the next cell reconstructs again with the same `osem` object,
# which may overwrite the image that `osem_image` refers to.
osem_image_arr = to_numpy(osem_image)

In [None]:
# Reconstruct again, now with the additive term from the simulation ("correct" scatter),
# for comparison with the estimated-scatter reconstruction above.
org_additive_term = stir.ProjData.read_from_file('my_additive_sinogram_g1.hs')
objfunc.set_additive_proj_data_sptr(org_additive_term)
# `reconstruct()` without an argument presumably continues from the current target image,
# which now holds the estimated-scatter reconstruction. Reset it to the same uniform start
# so both reconstructions have identical initialisation and number of sub-iterations.
osem.get_target_image().fill(1)
osem.set_start_subset_num(0)
osem.set_num_subiterations(96)
osem.reconstruct()
org_recon = osem.get_target_image()

In [None]:
# numpy copy of the reconstruction that used the simulated ("correct") scatter
org_image_arr = to_numpy(org_recon)

## bitmap display of images

In [None]:
# Plane `slice_num` (first array axis) of both reconstructions on a common colour scale.
maxforplot = org_image_arr.max() * 1.1
slice_num = 10  # not `slice`, to avoid shadowing the Python builtin

plt.figure()
for panel, (image_arr, title) in enumerate(
        [(org_image_arr, 'OSEM with correct scatter'),
         (osem_image_arr, 'OSEM with scatter estimation')], start=1):
    ax = plt.subplot(1, 2, panel)
    # set the colour limits when drawing so each colour bar matches the displayed range
    plt.imshow(image_arr[slice_num, :, :], vmin=0, vmax=maxforplot)
    plt.colorbar()
    ax.set_title(title)
    plt.axis('off')

## horizontal profiles through images

In [None]:
# Horizontal profiles through the middle row of plane 10 of both reconstructions.
profile_plane = 10
middle_row = org_image_arr.shape[1] // 2  # derived from the data instead of hard-coding 154//2
plt.figure()
plt.plot(org_image_arr[profile_plane, middle_row, :], 'b')
plt.plot(osem_image_arr[profile_plane, middle_row, :], 'c')
plt.legend(('actual scatter', 'estimated scatter'));

# Check by taking difference with measured data

We will compare the original and estimated scatter with the "remainder", i.e. prompts - randoms - forward(osem_image). This remainder will generally
- be closer to the original scatter, as the reconstruction tries to fit the image to the prompts
- have high-frequency differences due to under-convergence of OSEM (and noise, of course)

In [None]:
# Forward projector based on a ray-tracing projection matrix.
rt_matrix = stir.ProjMatrixByBinUsingRayTracing()
forward_projector = stir.ForwardProjectorByBinUsingProjMatrixByBin(rt_matrix)
# output container with the prompts' geometry; its contents are presumably overwritten by forward_project
forward_projection = stir.ProjDataInMemory(prompts)
# osem_image is presumably only used here as a geometry template
forward_projector.set_up(forward_projection.get_proj_data_info(), osem_image)
# NOTE(review): the fill(0) below is only needed if forward_project accumulates instead of overwriting - confirm
#forward_projection.fill(0)

In [None]:
# NOTE(review): this forward-projects `recon_image` (read from file), whereas the text above
# describes forward(osem_image) - confirm which image is intended.
forward_projector.forward_project(forward_projection, recon_image)

# Re-introduce normalisation/attenuation so the result can be compared with the measured
# prompts (presumably `undo` reverses `apply`, which was used on the additive term).
full_norm.undo(forward_projection)

In [None]:
# remainder = prompts - randoms - (forward model of the reconstructed image), as a numpy array
prompts_minus_randoms = stir.ProjDataInMemory(prompts) - stir.ProjDataInMemory(randoms)
diff = to_numpy(prompts_minus_randoms - forward_projection)

In [None]:
# Truth, estimate and data-driven remainder for index [0, 10], on one colour scale.
maxforplot = org_scatter_arr.max() * 1.1

panels = (
    (org_scatter_arr, 'Original simulated scatter'),
    (estimated_scatter_arr, 'estimated scatter'),
    (diff, 'remainder'),
)
plt.figure()
for panel_idx, (sino_arr, panel_title) in enumerate(panels, start=1):
    ax = plt.subplot(1, 3, panel_idx)
    plt.imshow(sino_arr[0, 10, :, :])
    plt.clim(0, maxforplot)
    ax.set_title(panel_title)
    plt.axis('off')

In [None]:
from scipy.ndimage import gaussian_filter

In [None]:
# inspect the array dimensions before choosing per-axis smoothing widths below
diff.shape

In [None]:
# Smooth the noisy remainder; sigma is given per array axis (0 = no smoothing along the first axis).
filtered = gaussian_filter(diff, (0,5,2,8))

In [None]:
# Smoothed remainder for the same index [0, 10] as in the comparison figure above.
plt.figure()
plt.imshow(filtered[0, 10, :, :])
plt.title('remainder after filtering');

In [None]:
# horizontal (one view): simulated vs estimated scatter vs smoothed remainder
plt.figure()
plt.plot(org_scatter_arr[0, 10, 1, :], 'b', label='simulated')
plt.plot(estimated_scatter_arr[0, 10, 1, :], 'c', label='estimated')
# distinct colour so the remainder cannot be confused with the cyan estimate
plt.plot(filtered[0, 10, 1, :], 'r', label='remainder after filtering')
plt.legend();