# Jupyter notebook based on ImageD11 to process scanning 3DXRD data
# Written by Haixing Fang, Jon Wright and James Ball
## Date: 12/10/2024

This notebook will try to reconstruct grain shapes and positions from the grain orientations you found in the first notebook.  
This notebook (and the tomo route in general) works best for low levels of deformation.  
If it doesn't seem to work well, try the point-by-point route instead!

In [None]:
import os

# Set the OpenMP, OpenBLAS and MKL thread-count environment variables to '1'
os.environ.update(
    dict.fromkeys(('OMP_NUM_THREADS', 'OPENBLAS_NUM_THREADS', 'MKL_NUM_THREADS'), '1')
)

In [None]:
# Run the ImageD11 install helper script inside this notebook's namespace.
# NOTE(review): presumably this script defines setup_ImageD11_from_git(), which the parameters cell calls - confirm
# The context manager makes sure the file handle is closed once the script has been read.
with open('/data/id11/nanoscope/install_ImageD11_from_git.py') as install_script:
    exec(install_script.read())

In [None]:
# this cell is tagged with 'parameters'
# to view the tag, select the cell, then find the settings gear icon (right or left sidebar) and look for Cell Tags
# every value below is a plain module-level assignment that later cells read by name

# python environment stuff
PYTHONPATH = setup_ImageD11_from_git( ) # ( os.path.join( os.environ['HOME'],'Code'), 'ImageD11_git' )

# dataset file to import
dset_path = 'si_cube_test/processed/Si_cube/Si_cube_S3DXRD_nt_moves_dty/Si_cube_S3DXRD_nt_moves_dty_dataset.h5'

# which phase to index
# (used to look up the unit cell, to load the grains, and as the group name when reading/writing the grains file)
phase_str = 'Si'

# peak filtration parameters
# passed as frac= and dstol= to select_ring_peaks_by_intensity
cf_strong_frac = 0.993
cf_strong_dstol = 0.005

# If the sinograms are only half-sinograms (we scanned dty across half the sample rather than the full sample), set the below to true:
is_half_scan = False
# Only used when is_half_scan is True: radius passed to mask_central_zingers
# to mask the centre of the reconstruction (normally hot pixels)
halfmask_radius = 25

# assign peaks to the grains with hkl tolerance peak_assign_tol
peak_assign_tol = 0.05

# We can interactively draw a mask
draw_mask_interactive = True
# or we can threshold with Otsu, or a manual threshold value:
# e.g. manual_threshold = 0.006
# (only passed to threshold_mask when draw_mask_interactive is False)
manual_threshold = None

# tolerance for building sinograms from assigned peaks
hkltol = 0.25

# We can optionally correct each row of the sinogram by the ring current of that rotation
# This helps remove artifacts in the reconstruction
correct_sinos_with_ring_current = True

# cutoff_level used when building the first TensorMap, right after the iradon reconstructions
first_tmap_cutoff_level = 0.4

# how many iterations for Astra reconstruction?
niter = 500

# cutoff_level used when building the second TensorMap from the "astra" reconstructions
second_tmap_cutoff_level = 0.05

In [None]:
# import functions we need

import concurrent.futures

%matplotlib ipympl

import h5py
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt

import ImageD11.columnfile
from ImageD11.grain import grain
from ImageD11.peakselect import select_ring_peaks_by_intensity
from ImageD11.sinograms.sinogram import GrainSinogram, build_slice_arrays, write_slice_recon, read_h5, write_h5, get_2d_peaks_from_4d_peaks
from ImageD11.sinograms.roi_iradon import run_iradon
from ImageD11.sinograms.tensor_map import TensorMap
from ImageD11.sinograms.geometry import sino_shift_and_pad
import ImageD11.sinograms.dataset
import ImageD11.nbGui.nb_utils as utils
from ImageD11.nbGui.draw_mask import InteractiveMask, threshold_mask

import ipywidgets as widgets
from ipywidgets import interact

In [None]:
# USER: Pass path to dataset file (dset_path is set in the parameters cell)

ds = ImageD11.sinograms.dataset.load(dset_path)

# keep a few dataset attributes under shorter names
sample, dataset = ds.sample, ds.dsname
rawdata_path, processed_data_root_dir = ds.dataroot, ds.analysisroot

# show the dataset and its shape, one per line
print(ds, ds.shape, sep="\n")

In [None]:
# load phases from parameter file

# attach the phases read from disk to the dataset object
ds.phases = ds.get_phases_from_disk()
# last expression of the cell, so the notebook displays the unit cells
# (the next cell looks one up by phase name)
ds.phases.unitcells

In [None]:
# pick a phase

# unit cell of the phase named by phase_str (set in the parameters cell)
ucell = ds.phases.unitcells[phase_str]

In [None]:
# Import 4D peaks

# read the 4D peaks (cf_4d) from disk via the dataset
cf_4d = ds.get_cf_4d_from_disk()

# update the parameters of cf_4d for the chosen phase
# NOTE(review): presumably this is what makes derived columns such as cf_4d.ds available to later cells - confirm
ds.update_colfile_pars(cf_4d, phase_name=phase_str)

print(f"Read {cf_4d.nrows} 4D peaks")

In [None]:
# here we are filtering our peaks (cf_4d) to select only the strongest ones
# this time, as opposed to indexing, our frac is slightly weaker and we are NOT filtering in dstar
# (dsmax is set to the largest dstar found in cf_4d)
# this means many more peaks per grain = stronger sinograms

# USER: modify the "cf_strong_frac" parameter and re-run the cell until the orange dot sits nicely on the "elbow" of the blue line
# this indicates the fractional intensity cutoff we will select
# if the blue line does not look elbow-shaped in the logscale plot, try changing the "doplot" parameter (the y scale of the logscale plot) until it does

cf_strong = select_ring_peaks_by_intensity(cf_4d, frac=cf_strong_frac, dstol=cf_strong_dstol, dsmax=cf_4d.ds.max(), doplot=0.5)
# number of peaks before filtering (printed) and after filtering (displayed as the cell output)
print(cf_4d.nrows)
cf_strong.nrows

In [None]:
# compare the intensities of all the 4D peaks with those of the peaks kept by the filter

# plot only every skip-th peak; raise this above 1 to speed up plotting if needed
skip = 1

fig, ax = plt.subplots(figsize=(10, 5), constrained_layout=True)

# make the rings of the unit cell up to the largest dstar in the data
ucell.makerings(cf_4d.ds.max())

# one pixel marker per peak, the full set first and the strong subset on top
for peaks, name in ((cf_4d, 'cf_4d'), (cf_strong, 'cf_strong')):
    ax.plot(peaks.ds[::skip], peaks.sum_intensity[::skip], ',', label=name)

# red ticks, all drawn at a height of 1e4, mark the ring dstar values
ring_heights = [1e4] * len(ucell.ringds)
ax.plot(ucell.ringds, ring_heights, '|', ms=90, c="red")

ax.semilogy()
ax.set(xlabel="Dstar", ylabel="Intensity")
ax.legend()

plt.show()

In [None]:
# import the grains from disk

# load the grains stored for the phase named by phase_str
grains = ds.get_grains_from_disk(phase_str)
print(f"{len(grains)} grains imported")

In [None]:
# assign the strong peaks to the grains, with peak_assign_tol as the hkl tolerance

utils.assign_peaks_to_grains(grains, cf_strong, peak_assign_tol)

# store on each grain how many peaks in cf_strong carry its index as grain_id
for grain_label, g in enumerate(grains):
    belongs_to_grain = cf_strong.grain_id == grain_label
    g.npks_4d = np.sum(belongs_to_grain)

In [None]:
# let's make a GrainSinogram object for each grain

# one GrainSinogram per grain, in the same order as grains, all sharing the dataset ds
grainsinos = [GrainSinogram(g, ds) for g in grains]

In [None]:
# Now let's determine the positions of each grain from the 4D peaks

# grain_label is the position of the grain in the list, the same index used when counting cf_strong.grain_id matches
for grain_label, gs in enumerate(grainsinos):
    gs.update_lab_position_from_peaks(cf_strong, grain_label)

In [None]:
# We can also determine the RGB IPF colours of the grains which will be useful for plotting
# (the plotting cell below reads rgb_x, rgb_y and rgb_z from each grain)

utils.get_rgbs_for_grains(grains)

In [None]:
# plot the IPFs of all the grains
utils.plot_all_ipfs(grains)

In [None]:
# Now we can plot our grain positions, coloured by IPF colour and by peak count:

# plt.style.use('dark_background')
fig, ax = plt.subplots(2, 2, figsize=(12, 12))
panels = ax.ravel()

# grain positions, and marker sizes scaled with the number of assigned 4D peaks
x = [g.translation[0] for g in grains]
y = [g.translation[1] for g in grains]
s = [g.npks_4d/10 for g in grains]

# first three panels: the same scatter, coloured by the IPF colour along Z, Y and X
ipf_panels = (
    ('IPF color Z', [g.rgb_z for g in grains]),
    ('IPF color Y', [g.rgb_y for g in grains]),
    ('IPF color X', [g.rgb_x for g in grains]),
)
for panel, (panel_title, colours) in zip(panels, ipf_panels):
    panel.scatter(y, x, c=colours, s=s)
    panel.set(title=panel_title, aspect='equal')

# last panel: colour by the scaled peak count, with default marker size
panels[3].scatter(y, x, c=s)
panels[3].set(title='Number of 4D peaks', aspect='equal')

fig.supxlabel("<- Sample y (transverse)")
fig.supylabel("Sample x (beam) ->")

# sample y increases to the left on every panel
for panel in panels:
    panel.invert_xaxis()

plt.show()

In [None]:
# we need to determine what the value of dty is where the rotation axis intercepts the beam
# we'll call this y0
# should be the result of the centre-of-mass fit

fig, ax = plt.subplots()

# one y0 estimate per grain, read from each GrainSinogram
sample_y0s = [gs.recon_y0 for gs in grainsinos]

ax.plot(sample_y0s)

plt.show()

# use the median over all the grains as the single y0 for the sample
y0 = np.median(sample_y0s)

print('y0 is', y0)

# NOTE(review): this runs before sino_shift_and_pad, which reads ds.ybincens - presumably the half-scan
# correction changes those bins, so the order matters - confirm
if is_half_scan:
    ds.correct_bins_for_half_scan(y0=y0)

# try to automatically determine the sinogram shift and the padding from the y0 values
shift, pad = sino_shift_and_pad(y0, len(ds.ybincens), min(ds.ybincens), ds.ystep)

print('shift is', shift)
print('pad is', pad)

# update the grainsinogram parameters accordingly:

for gs in grainsinos:
    gs.update_recon_parameters(y0=y0, shift=shift, pad=pad)

Our next task is to determine a reconstruction mask for the entire sample.

This should adequately differentiate between sample and air.

In [None]:
# build a sinogram of the whole sample from the 2D peaks
pk2d = ds.pk2d

# weight each peak by its summed intensity raised to the power 0.1
sino_weights = np.power(pk2d['sum_intensity'], 0.1)

# histogram the peaks over omega and dty, then transpose the result
whole_sample_sino = ds.sinohist(omega=pk2d['omega'], dty=pk2d['dty'], weights=sino_weights).T

fig, ax = plt.subplots()
ax.imshow(whole_sample_sino, aspect='auto', vmin=0)
plt.show()

In [None]:
# now perform the tomographic reconstruction of the whole-sample sinogram:

# one worker per CPU that this process is allowed to run on
nthreads = len(os.sched_getaffinity(os.getpid()))

# the half-mask and the central zinger masking are only switched on for half scans
whole_sample_recon = run_iradon(
    whole_sample_sino,
    ds.obincens,
    pad,
    shift,
    workers=nthreads,
    apply_halfmask=is_half_scan,
    mask_central_zingers=is_half_scan,
)

In [None]:
# Now we generate a whole-sample mask for the image

if draw_mask_interactive:
    # start the interactive mask tool on the reconstruction
    # the mask itself is only retrieved in the next cell, via masker.get_mask
    masker = InteractiveMask(whole_sample_recon)
else:
    # threshold the reconstruction instead
    # NOTE(review): presumably Otsu is used when manual_threshold is None - confirm
    whole_sample_mask = threshold_mask(whole_sample_recon, manual_threshold=manual_threshold, doplot=True)

In [None]:
# collect the mask from the interactive tool
# (in the non-interactive case whole_sample_mask was already set by threshold_mask)
if draw_mask_interactive:
    whole_sample_mask = masker.get_mask(doplot=True)

In [None]:
# now we have a whole-sample reconstruction we can use as a sample mask
# let's build the sinograms for our grains
# before we do this, we need to determine our 2D peaks that will be used for the sinogram
# here we can get them from the 4D peaks:

# gord and inds are computed once for all grains and handed to every GrainSinogram below
gord, inds = get_2d_peaks_from_4d_peaks(ds.pk2d, cf_strong)

# grain_label is the position of the grain in the list; hkltol comes from the parameters cell
for grain_label, gs in enumerate(tqdm(grainsinos)):
    gs.prepare_peaks_from_4d(cf_strong, gord, inds, grain_label, hkltol)

In [None]:
# now we can actually generate the sinograms
# (one build_sinogram call per grain, with a progress bar)

for gs in tqdm(grainsinos):
    gs.build_sinogram()

In [None]:
# optionally correct the halfmask:
# only done when is_half_scan is set in the parameters cell

if is_half_scan:
    for gs in grainsinos:
        gs.correct_halfmask()

In [None]:
# Show the sinogram (ssino) of a single grain - here the first one

gs = grainsinos[0]

fig, ax = plt.subplots()
ax.imshow(gs.ssino, aspect='auto')
ax.set(title="ssino")

plt.show()

In [None]:
# We can optionally correct each row of the sinogram by the ring current of that rotation
# This helps remove artifacts in the reconstruction

if correct_sinos_with_ring_current:
    # fetch the ring current for each scan onto the dataset first
    # NOTE(review): presumably correct_ring_current reads these values from gs.ds - confirm
    ds.get_ring_current_per_scan()
    
    for gs in grainsinos:
        gs.correct_ring_current(is_half_scan=is_half_scan)

In [None]:
# Show the sinogram (ssino) of the first grain again, after the optional ring current correction

gs = grainsinos[0]

fig, ax = plt.subplots()
ax.imshow(gs.ssino, aspect='auto')
ax.set(title="ssino")

plt.show()

In [None]:
# let's try out an iradon reconstruction on the first grain

gs = grainsinos[0]

# update the parameters used for the iradon reconstruction
recon_pars = dict(pad=pad, shift=shift, mask=whole_sample_mask, y0=y0)
gs.update_recon_parameters(**recon_pars)

# perform the reconstruction
gs.recon()

# for half scans, mask the centre of the iradon reconstruction
if is_half_scan:
    gs.mask_central_zingers("iradon", radius=halfmask_radius)

# view the result: sinogram on the left, reconstruction on the right
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
sino_ax, recon_ax = axs

sino_ax.imshow(gs.ssino, aspect='auto')
sino_ax.set_title("ssino")

recon_ax.imshow(gs.recons["iradon"], vmin=0, origin="lower")
recon_ax.set(title="ID11 iradon", xlabel="<-- Sample Y", ylabel="Sample X")

plt.show()

In [None]:
# once you're happy with the reconstruction parameters used, set them for all the grains
# (the same pad, shift, mask and y0 that were tried on the first grain above)

for gs in grainsinos:
    gs.update_recon_parameters(pad=pad, shift=shift, mask=whole_sample_mask, y0=y0)

In [None]:
# reconstruct all grains in parallel, using a pool of threads

nthreads = len(os.sched_getaffinity(os.getpid()))

# use one worker fewer than the CPUs available to this process, but never fewer than one
n_workers = max(1, nthreads - 1)

with concurrent.futures.ThreadPoolExecutor(max_workers=n_workers) as pool:
    recon_results = pool.map(GrainSinogram.recon, grainsinos)
    # the results are not needed; iterating just drives the progress bar
    for _recon_result in tqdm(recon_results, total=len(grainsinos)):
        pass

# for half scans, mask the centre of every iradon reconstruction
if is_half_scan:
    for gs in grainsinos:
        gs.mask_central_zingers("iradon", radius=halfmask_radius)

In [None]:
# Browse the iradon reconstruction (left) and the sinogram (right) of every grain with a slider

fig, a = plt.subplots(1,2,figsize=(10,5))
rec = a[0].imshow(grainsinos[0].recons["iradon"], vmin=0, origin="lower")
sin = a[1].imshow(grainsinos[0].ssino, aspect='auto')
a[0].set_xlabel("<-- Sample Y")
a[0].set_ylabel("Sample X")

# Function to update the displayed image based on the selected frame
def update_frame(i):
    """Show the iradon reconstruction and the sinogram of grain number i.

    i is an index into grainsinos; the slider below supplies it.
    """
    rec.set_array(grainsinos[i].recons["iradon"])
    sin.set_array(grainsinos[i].ssino)
    # set_array keeps the colour limits of the previously shown image, so rescale
    # to the new data; the reconstruction keeps its lower limit of 0 as in imshow above
    rec.autoscale()
    rec.set_clim(vmin=0)
    sin.autoscale()
    a[0].set(title=str(i))
    fig.canvas.draw()

# Create a slider widget to select the frame number
# the upper bound follows grainsinos, which is the list that update_frame indexes
frame_slider = widgets.IntSlider(
    value=0,
    min=0,
    max=len(grainsinos) - 1,
    step=1,
    description='Grain:'
)

interact(update_frame, i=frame_slider)

plt.show()

In [None]:
# let's assemble all the recons together into a TensorMap
# first_tmap_cutoff_level (parameters cell) is passed as the cutoff_level
# NOTE(review): no method is given here, presumably so the iradon recons are used by default - confirm

tensor_map = TensorMap.from_grainsinos(grainsinos, cutoff_level=first_tmap_cutoff_level)

In [None]:
# plot initial output
# three views of the same TensorMap: "ipf_z", "labels" and "intensity"

tensor_map.plot("ipf_z")
tensor_map.plot("labels")
tensor_map.plot("intensity")

In [None]:
# we can clean up these reconstructions using an MLEM iterative recon
# we will carry this out using ASTRA on the GPU on the cluster
# the ASTRA EM_CUDA method will be used
# note that the mask will not be applied - normally not needed for ASTRA EM_CUDA

In [None]:
# choose the number of iterations
# experience shows 500 is good, and pretty quick on the GPU
# niter comes from the parameters cell; the other parameters are the same as for the iradon reconstruction

for gs in grainsinos:
    gs.update_recon_parameters(pad=pad, shift=shift, mask=whole_sample_mask, niter=niter, y0=y0)

In [None]:
# save the GrainSinogram objects to disk
# they go into the grains file of the dataset, in the group named by phase_str; overwrite_grains=True is passed

write_h5(ds.grainsfile, grainsinos, overwrite_grains=True, group_name=phase_str)

In [None]:
# prepare ASTRA bash scripts to run it on the cluster
# the returned script path is what the next cell submits

bash_script_path = utils.prepare_astra_bash(ds, ds.grainsfile, PYTHONPATH, group_name=phase_str)

In [None]:
# submit ASTRA jobs to cluster
# NOTE(review): the second argument (10) is presumably a polling interval in seconds - confirm

utils.slurm_submit_and_wait(bash_script_path, 10)

In [None]:
# re-import our reconstructed grains
# (this replaces the in-memory grainsinos list with the one read from the grains file)

grainsinos = read_h5(ds.grainsfile, ds, group_name=phase_str)
# re-associate grainsino grain objects to existing grain objects
# NOTE(review): zip pairs the two lists by position and stops at the shorter one, so this assumes
# read_h5 returns as many grainsinos as there are grains, in the same order - confirm

for gs, g in zip(grainsinos, grains):
    gs.grain = g
    gs.ds = ds

In [None]:
# look at all our ASTRA recons in a grid

# aim for about 25 panels by taking every grains_step-th grain
# (the stride is an integer division, so somewhat more than 25 grains can end up selected)
n_grains_to_plot = min(25, len(grainsinos))

if n_grains_to_plot == 0:
    # without this guard the stride computation below would divide by zero
    print("No grains to plot")
else:
    grains_step = len(grainsinos)//n_grains_to_plot

    # take the subset once and reuse it, instead of re-slicing the list for every axis
    grainsinos_to_plot = grainsinos[::grains_step]

    # grid_size x nrows panels: a roughly square grid with room for every selected grain
    grid_size = np.ceil(np.sqrt(len(grainsinos_to_plot))).astype(int)
    nrows = (len(grainsinos_to_plot)+grid_size-1)//grid_size

    if grid_size == 1:
        # a single grain: plt.subplots returns one Axes rather than an array of them
        fig, ax = plt.subplots(figsize=(10,10), layout="constrained")
        gs = grainsinos_to_plot[0]
        ax.imshow(gs.recons["astra"], vmin=0, origin="lower")
    else:
        fig, axs = plt.subplots(grid_size, nrows, figsize=(10,10), layout="constrained", sharex=True, sharey=True)
        # the grid can hold more axes than there are grains; zip stops after the last
        # selected grain, which leaves the spare axes empty
        for i, (ax, gs) in enumerate(zip(axs.ravel(), grainsinos_to_plot)):
            ax.imshow(gs.recons["astra"], vmin=0, origin="lower")
            # ax.invert_yaxis()
            # the title is the index within the plotted subset, not within grainsinos
            ax.set_title(i)

    plt.show()

In [None]:

# assemble the "astra" recons into a second TensorMap, with second_tmap_cutoff_level as the cutoff_level
tensor_map_astra = TensorMap.from_grainsinos(grainsinos, cutoff_level=second_tmap_cutoff_level, method="astra")

In [None]:
# plot the ASTRA TensorMap with the same three views as the first TensorMap
tensor_map_astra.plot("ipf_z")
tensor_map_astra.plot("labels")
tensor_map_astra.plot("intensity")

In [None]:
# write our results to disk
# same file and group as the earlier save, again passing overwrite_grains=True

write_h5(ds.grainsfile, grainsinos, overwrite_grains=True, group_name=phase_str)

In [None]:
# write the TensorMap to disk too
# it goes into the grains file, under a group named 'TensorMap_' followed by the phase name

tensor_map_astra.to_h5(ds.grainsfile, h5group='TensorMap_' + phase_str)

In [None]:
# we can also write an XDMF file so you can visualise the TensorMap with ParaView
# the call is given the same file and h5group that the TensorMap was written to above

tensor_map_astra.to_paraview(ds.grainsfile, h5group='TensorMap_' + phase_str)

In [None]:
# save the dataset object
# NOTE(review): presumably this writes back to the dataset file that was loaded at the start - confirm
ds.save()