# Jupyter notebook based on ImageD11 to process scanning 3DXRD data
# Written by Haixing Fang, Jon Wright and James Ball
## Date: 28/03/2024

In [None]:
# USER: Change the path below to point to your local copy of ImageD11:

import os
import sys

# Resolve the home directory in pure Python instead of shelling out to `echo $HOME`;
# this gives a plain str directly and does not depend on an IPython shell escape.
home_dir = os.path.expanduser("~")

# USER: You can change this location if you want

id11_code_path = os.path.join(home_dir, "Code/ImageD11")

# Put the local checkout at the front of sys.path so that it is found first by the imports below
sys.path.insert(0, id11_code_path)

In [None]:
# import functions we need

import concurrent.futures

%matplotlib ipympl

import h5py
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt

import ImageD11.columnfile
import ImageD11.sinograms.dataset
from ImageD11.grain import grain
from ImageD11.peakselect import select_ring_peaks_by_intensity
from ImageD11.sinograms.sinogram import GrainSinogram, build_slice_arrays, write_slice_recon, write_h5, get_2d_peaks_from_4d_peaks, run_astra
from ImageD11.sinograms.roi_iradon import run_iradon

from skimage.filters import threshold_otsu
from skimage.morphology import convex_hull_image

import ImageD11.nbGui.nb_utils as utils

import ipywidgets as widgets
from ipywidgets import interact

In [None]:
# NOTE: For old datasets before the new directory layout structure, we don't distinguish between RAW_DATA and PROCESSED_DATA

### USER: specify your experimental directory
# (this is also the directory searched for datasets by the bulk-processing cell at the end of the notebook)

rawdata_path = "/data/visitor/ihma439/id11/20231211/RAW_DATA"

# list the directory contents (long format, oldest first) as a quick check that the path is right
!ls -lrt {rawdata_path}

### USER: specify where you want your processed data to go
# (the per-dataset "*_dataset.h5" paths used below are built underneath this directory)

processed_data_root_dir = "/data/visitor/ihma439/id11/20231211/PROCESSED_DATA/James/nb_testing"

In [None]:
# USER: pick a sample and a dataset you want to reconstruct
# the two names are combined below to locate "<sample>_<dataset>_dataset.h5"

sample = "FeAu_0p5_tR_nscope"
dataset = "top_100um"

In [None]:
# destination of H5 files
# layout: <processed_data_root_dir>/<sample>/<sample>_<dataset>/<sample>_<dataset>_dataset.h5

dset_stem = f"{sample}_{dataset}"
dset_path = os.path.join(
    processed_data_root_dir,
    sample,
    dset_stem,
    f"{dset_stem}_dataset.h5",
)

In [None]:
# Load the dataset
# ds is the DataSet object that the following cells read peaks, grains, bin edges and file paths from
ds = ImageD11.sinograms.dataset.load(dset_path)

In [None]:
# If the sinograms are only half-sinograms (we scanned dty across half the sample rather than the full sample), set the below to true:
# this flag switches on the half-scan corrections in several cells below, including the bulk-processing cell
is_half_scan = False

In [None]:
# Only for half-scans: ask the dataset to correct its bins before any sinogram is built
# NOTE(review): what exactly is changed is defined inside ImageD11's DataSet — not visible from this notebook
if is_half_scan:
    ds.correct_bins_for_half_scan()

In [None]:
# Import 4D peaks

cf_4d = ds.get_cf_4d_from_disk()

# load the parameters from the dataset's parameter file, then recompute the geometry with them
# (cf_4d.ds, used in the next cell, is presumably one of the columns this provides — TODO confirm)
cf_4d.parameters.loadparameters(ds.parfile)
cf_4d.updateGeometry()

print(f"Read {cf_4d.nrows} 4D peaks")

In [None]:
# here we are filtering our peaks (cf_4d) to select only the strongest ones
# this time as opposed to indexing, our frac is slightly weaker but we are NOT filtering in dstar!!!!!
# (dsmax is set to the largest d-star present in cf_4d, so no peak is removed by the d-star cut)
# this means many more peaks per grain = stronger sinograms

# USER: modify the "frac" parameter below and re-run the cell until the orange dot sits nicely on the "elbow" of the blue line
# this indicates the fractional intensity cutoff we will select
# if the blue line does not look elbow-shaped in the logscale plot, try changing the "doplot" parameter (the y scale of the logscale plot) until it does

# these two values are reused unchanged by the bulk-processing cell at the end of the notebook
cf_strong_frac = 0.995
cf_strong_dstol = 0.005

cf_strong = select_ring_peaks_by_intensity(cf_4d, frac=cf_strong_frac, dstol=cf_strong_dstol, dsmax=cf_4d.ds.max(), doplot=0.9)
# compare the number of peaks before (printed) and after (cell output) the intensity filter
print(cf_4d.nrows)
cf_strong.nrows

In [None]:
# import the grains from disk
# (the returned grain objects are decorated further below with npks_4d and ref_unitcell)

grains = ds.get_grains_from_disk()
print(f"{len(grains)} grains imported")

In [None]:
# assign peaks to the grains

peak_assign_tol = 0.25
utils.assign_peaks_to_grains(grains, cf_strong, peak_assign_tol)

# after the assignment cf_strong.grain_id holds a grain label per peak;
# count how many strong 4D peaks carry each grain's label and store it on the grain
for grain_label, g in enumerate(grains):
    belongs_to_grain = cf_strong.grain_id == grain_label
    g.npks_4d = np.sum(belongs_to_grain)

In [None]:
# let's make a GrainSinogram object for each grain
# grainsinos[i] wraps grains[i], so the list index doubles as the grain label

grainsinos = []
for g in grains:
    grainsinos.append(GrainSinogram(g, ds))

In [None]:
# Now let's determine the positions of each grain from the 4D peaks
# the grain label passed in is simply the grain's index in grainsinos

for grain_label in range(len(grainsinos)):
    gs = grainsinos[grain_label]
    gs.update_lab_position_from_peaks(cf_strong, grain_label)

In [None]:
# We can also determine the RGB IPF colours of the grains which will be useful for plotting
# To do this, we first need to set a reference unitcell for each grain
# This will be used to determine the Orix Phase and therefore Orix Orientation

spacegroup = 229  # spacegroup for BCC iron

# start from the peak parameters and overwrite the lattice entry with the spacegroup number
cf_pars = cf_strong.parameters.get_parameters()
cf_pars["cell_lattice_[P,A,B,C,I,F,R]"] = spacegroup

ref_ucell = ImageD11.unitcell.unitcell_from_parameters(cf_pars)

# every grain shares the same reference unit cell object
for g in grains:
    setattr(g, "ref_unitcell", ref_ucell)

# Now colours should work

utils.get_rgbs_for_grains(grains)

In [None]:
# plot the inverse pole figures for all grains (uses the colours computed in the previous cell)
utils.plot_all_ipfs(grains)

In [None]:
# Now we can plot our grain positions and RGB colours:

# plt.style.use('dark_background')
fig, ax = plt.subplots(2,2, figsize=(12,12))
a = ax.ravel()

# grain centre positions, and a marker size proportional to the number of 4D peaks
x = [g.translation[0] for g in grains]
y = [g.translation[1] for g in grains]
s = [g.npks_4d/10 for g in grains]

# first three panels: same scatter, coloured by the IPF colour along Z, Y and X
ipf_panels = (
    (a[0], 'IPF color Z', 'rgb_z'),
    (a[1], 'IPF color Y', 'rgb_y'),
    (a[2], 'IPF color X', 'rgb_x'),
)
for panel, panel_title, rgb_attr in ipf_panels:
    panel.scatter(y, x, c=[getattr(g, rgb_attr) for g in grains], s=s)
    panel.set(title=panel_title, aspect='equal')

# last panel: colour by (scaled) peak count instead
a[3].scatter(y, x, c=s)
a[3].set(title='Number of 4D peaks', aspect='equal')

fig.supxlabel("<- Lab y (transverse)")
fig.supylabel("Lab x (beam) ->")

# lab y increases to the left in every panel, matching the x label
for a in ax.ravel():
    a.invert_xaxis()

plt.show()

In [None]:
# we need to determine what the value of dty is where the rotation axis intercepts the beam
# we'll call this y0
# should be the result of the centre-of-mass fit

# one y0 estimate per grain, taken from each GrainSinogram
sample_y0s = [gs.recon_y0 for gs in grainsinos]

fig, ax = plt.subplots()
ax.plot(sample_y0s)
plt.show()

# use the median over all grains as the single y0 for the sample
y0 = np.median(sample_y0s)
print('y0 is', y0)

# update the y0 for each grain with the median y0:
for gs in grainsinos:
    gs.update_recon_parameters(y0=y0)

# the shift we have to apply to the reconstructions is equal to -y0/ystep (in integer coords)
shift = -y0/ds.ystep
print('shift is', shift)

In [None]:
# now let's do a whole-sample tomographic reconstruction
# generate sinogram for whole sample

# 2D histogram of all 4D peaks: rows are dty bins, columns are omega bins
sino_bins = [ds.ybinedges, ds.obinedges]
whole_sample_sino, xedges, yedges = np.histogram2d(cf_4d.dty, cf_4d.omega, bins=sino_bins)

fig, ax = plt.subplots()
ax.imshow(whole_sample_sino, interpolation="nearest", vmin=0)
ax.set_aspect(4)
plt.show()

In [None]:
# "quick" whole-sample reconstruction

# padding handed to run_astra here and to the per-grain reconstruction parameters further down
pad = 50

# NOTE(review): niter is passed together with astra_method="FBP_CUDA"; presumably it only matters
# for iterative methods — confirm against run_astra
whole_sample_recon = run_astra(whole_sample_sino, ds.obincens, pad=pad, shift=shift, astra_method="FBP_CUDA", niter=100)

In [None]:
# without a mask, MLEM can introduce artifacts in the corners
# so we can manually mask those out

# we can incorporate our own mask too
# by modifying the below function

def apply_manual_mask(mask_in):
    """Return a copy of ``mask_in`` with user-chosen regions zeroed out.

    The input is never modified. As written, no region is masked; edit the
    commented example below to blank out parts of the image.
    """
    edited = mask_in.copy()

    # example: zero everything from row 200 and column 250 onwards
    # edited[200:, 250:] = 0

    return edited

# we should be able to easily segment this using scikit-image
recon_man_mask = apply_manual_mask(whole_sample_recon)

# we can also override the threshold if we don't like it:
# manual_threshold = 0.05
manual_threshold = None

# fall back to Otsu's threshold unless the user supplied one
thresh = manual_threshold if manual_threshold is not None else threshold_otsu(recon_man_mask)

binary = recon_man_mask > thresh

# the convex hull of the thresholded image is what we use as the sample mask
chull = convex_hull_image(binary)
whole_sample_mask = chull

fig, axs = plt.subplots(1, 3, sharex=True, sharey=True, constrained_layout=True)

# (image, title, extra imshow options) for each of the three panels
stages = (
    (recon_man_mask, "Reconstruction", {"vmin": 0}),
    (binary, "Binarised threshold", {}),
    (chull, "Convex hull", {}),
)
for panel, (image, panel_title, extra) in zip(axs, stages):
    panel.imshow(image, origin="lower", **extra)
    panel.set_title(panel_title)

fig.supxlabel("<-- Y axis")
fig.supylabel("Beam >")

plt.show()

In [None]:
# now we have a whole-sample reconstruction we can use as a sample mask
# let's build the sinograms for our grains
# before we do this, we need to determine our 2D peaks that will be used for the sinogram
# here we can get them from the 4D peaks:

# tolerance handed to prepare_peaks_from_4d (also reused by the bulk-processing cell)
hkltol = 0.25

# computed once for the whole dataset, then shared by every grain below
gord, inds = get_2d_peaks_from_4d_peaks(ds.pk2d, cf_strong)

# grain_label is the grain's index in grainsinos
for grain_label, gs in enumerate(tqdm(grainsinos)):
    gs.prepare_peaks_from_4d(cf_strong, gord, inds, grain_label, hkltol)

In [None]:
# now we can actually generate the sinograms
# (uses the peaks prepared for each grain in the previous cell; the result is shown below as gs.ssino)

for gs in tqdm(grainsinos):
    gs.build_sinogram()

In [None]:
# optionally correct the halfmask:
# only runs for half-scans (is_half_scan is set near the top of the notebook)

if is_half_scan:
    for gs in grainsinos:
        gs.correct_halfmask()

In [None]:
# Show sinogram of single grain (the first one)

gs = grainsinos[0]

fig, ax = plt.subplots()
ax.imshow(gs.ssino, aspect='auto')
ax.set(title="ssino")
plt.show()

In [None]:
# We can optionally correct each row of the sinogram by the ring current of that rotation
# This helps remove artifacts in the reconstruction

# this flag is also honoured by the bulk-processing cell at the end of the notebook
correct_sinos_with_ring_current = True
if correct_sinos_with_ring_current:
    # fetch the ring current once on the dataset before correcting each grain
    ds.get_ring_current_per_scan()
    
    for gs in grainsinos:
        gs.correct_ring_current(is_half_scan=is_half_scan)

In [None]:
# Show sinogram of single grain (the first one)

gs = grainsinos[0]

fig, ax = plt.subplots()
ax.imshow(gs.ssino, aspect='auto')
ax.set(title="ssino")
plt.show()

In [None]:
# go straight to ASTRA

gs = grainsinos[0]

# update the parameters used for the ASTRA reconstruction

niter = 500

# radius used when masking central zingers in half-scan reconstructions;
# defined unconditionally so that it always exists for the cells that follow
halfmask_radius = 25

gs.update_recon_parameters(pad=pad, shift=shift, y0=y0, niter=niter, mask=whole_sample_mask)

gs.recon(method="astra", astra_method="EM_CUDA")

if is_half_scan:
    # the reconstruction made above is the "astra" one (it is read back as gs.recons["astra"] below),
    # so that is the one whose central zingers must be masked
    gs.mask_central_zingers("astra", radius=halfmask_radius)

# view the result

fig, axs = plt.subplots(1,2, figsize=(10,5))
axs[0].imshow(gs.ssino, aspect='auto')
axs[0].set_title("ssino")
axs[1].imshow(gs.recons["astra"], vmin=0, origin="lower")
axs[1].set_title("Astra")

plt.show()

In [None]:
# once you're happy with the reconstruction parameters used, set them for all the grains
# (same arguments as used for the single test grain above)

for gs in grainsinos:
    gs.update_recon_parameters(pad=pad, shift=shift, y0=y0, niter=niter, mask=whole_sample_mask)

In [None]:
# reconstruct all grains
# each result is stored on the grain sinogram as gs.recons["astra"]

for gs in tqdm(grainsinos):
    gs.recon(method="astra", astra_method="EM_CUDA")

    # half-scans only: mask the central zingers of the reconstruction just made
    if is_half_scan:
        gs.mask_central_zingers("astra", radius=halfmask_radius)

In [None]:
# Interactive browser: left panel shows a grain's ASTRA reconstruction, right panel its sinogram
fig, a = plt.subplots(1,2,figsize=(10,5))
rec = a[0].imshow(grainsinos[0].recons["astra"], vmin=0, origin="lower")
sin = a[1].imshow(grainsinos[0].ssino, aspect='auto')


def update_frame(i):
    """Swap both images to grain ``i`` and redraw the figure."""
    chosen = grainsinos[i]
    rec.set_array(chosen.recons["astra"])
    sin.set_array(chosen.ssino)
    a[0].set(title=str(i))
    fig.canvas.draw()


# Slider covering every grain index
slider_options = dict(value=0, min=0, max=len(grains) - 1, step=1, description='Grain:')
frame_slider = widgets.IntSlider(**slider_options)

interact(update_frame, i=frame_slider)

plt.show()

In [None]:
# Let's assemble all the recons into one map

# keep the complete return value under the name slice_arrays: that is what gets passed to
# write_slice_recon when the maps are written to disk
slice_arrays = build_slice_arrays(grainsinos, cutoff_level=0.3, method="astra")

# and unpack the individual maps for plotting
rgb_x_array, rgb_y_array, rgb_z_array, grain_labels_array, raw_intensity_array = slice_arrays

In [None]:
# plot initial output
# (the assembled map coloured by the IPF colour along Z)

fig, ax = plt.subplots(constrained_layout=True)
ax.imshow(rgb_z_array, origin="lower")
plt.show()

In [None]:
# map of which grain label was assigned to each pixel
fig, ax = plt.subplots(constrained_layout=True)
ax.imshow(grain_labels_array, origin="lower")  # originally 1,2,0
ax.set(title="Grain label map")
plt.show()

In [None]:
# map of the raw reconstructed intensity
fig, ax = plt.subplots(constrained_layout=True)
ax.imshow(raw_intensity_array, origin="lower")
ax.set(title="Raw intensity array")
plt.show()

In [None]:
# look at all our recons in a grid

n_grains_to_plot = 25

# stride through the grains so that roughly n_grains_to_plot of them are shown;
# clamp to 1 because with fewer grains than n_grains_to_plot the integer division gives 0,
# and a slice step of 0 is an error
grains_step = max(1, len(grainsinos)//n_grains_to_plot)
grainsinos_to_plot = grainsinos[::grains_step]

# near-square grid big enough for every selected grain (never smaller than 1x1)
grid_size = max(1, np.ceil(np.sqrt(len(grainsinos_to_plot))).astype(int))
nrows = max(1, (len(grainsinos_to_plot)+grid_size-1)//grid_size)

# squeeze=False keeps axs a 2D array even for a 1x1 grid, so axs.ravel() always works
fig, axs = plt.subplots(grid_size, nrows, figsize=(10,10), layout="constrained", sharex=True, sharey=True, squeeze=False)
for i, ax in enumerate(axs.ravel()):
    # the grid can have more panels than grains; leave the surplus panels empty
    if i < len(grainsinos_to_plot):
        gs = grainsinos_to_plot[i]
        ax.imshow(gs.recons["astra"], vmin=0, origin="lower")
        # ax.invert_yaxis()
        ax.set_title(i)

plt.show()

In [None]:
# write our results to disk
# the grain sinograms go into the dataset's grains file; write_grains_too=True asks for the grains to be written as well

write_h5(ds.grainsfile, grainsinos, write_grains_too=True)

In [None]:
# write the slice maps to disk too
# they go into the same grains file as the sinograms; the bulk-processing cell checks that file
# for a "slice_recon" entry to decide whether a dataset has already been reconstructed

write_slice_recon(ds.grainsfile, slice_arrays)

In [None]:
# save the DataSet object itself
# NOTE(review): presumably this writes back to the dataset H5 file it was loaded from — confirm in ImageD11
ds.save()

In [None]:
# Deliberate stop: raising here halts "Run all cells" so the bulk-processing cell below only runs on purpose
if 1:
    raise ValueError("Change the 1 above to 0 to allow 'Run all cells' in the notebook")

In [None]:
# Now that we're happy with our reconstruction parameters, we can run the below cell to do this in bulk for many samples/datasets
# by default this will do all samples in sample_list, all datasets with a prefix of dset_prefix
# you can add samples and datasets to skip in skips_dict
#
# NOTE: this cell reuses the settings chosen interactively above: is_half_scan, cf_strong_frac, cf_strong_dstol,
# peak_assign_tol, spacegroup, pad, apply_manual_mask, manual_threshold, hkltol,
# correct_sinos_with_ring_current and niter. Every dataset is processed with those same values.

skips_dict = {
    "FeAu_0p5_tR_nscope": ["top_-50um", "top_-100um"]
}

dset_prefix = "top"

sample_list = ["FeAu_0p5_tR_nscope"]

samples_dict = utils.find_datasets_to_process(rawdata_path, skips_dict, dset_prefix, sample_list)

# manual override:
# samples_dict = {"FeAu_0p5_tR_nscope": ["top_100um", "top_200um"]}

# cutoff used when assembling the slice maps: same value as in the interactive map assembly
cutoff_level = 0.3

# radius used when masking central zingers in half-scan reconstructions
halfmask_radius = 25

# now we have our samples_dict, we can process our data:

for sample, datasets in samples_dict.items():
    for dataset in datasets:
        print(f"Processing dataset {dataset} in sample {sample}")
        dset_path = os.path.join(processed_data_root_dir, sample, f"{sample}_{dataset}", f"{sample}_{dataset}_dataset.h5")
        if not os.path.exists(dset_path):
            print(f"Missing DataSet file for {dataset} in sample {sample}, skipping")
            continue

        print("Importing DataSet object")

        ds = ImageD11.sinograms.dataset.load(dset_path)
        print(f"I have a DataSet {ds.dset} in sample {ds.sample}")

        if not os.path.exists(ds.grainsfile):
            print(f"Missing grains file for {dataset} in sample {sample}, skipping")
            continue

        # check grains file for existence of slice_recon, skip if it's there
        # (the file is closed again before we decide, so nothing is held open while processing)
        with h5py.File(ds.grainsfile, "r") as hin:
            already_reconstructed = "slice_recon" in hin.keys()
        if already_reconstructed:
            print(f"Already reconstructed {dataset} in {sample}, skipping")
            continue

        if is_half_scan:
            ds.correct_bins_for_half_scan()

        print("Peaks")
        cf_4d = ds.get_cf_4d_from_disk()
        cf_4d.parameters.loadparameters(ds.parfile)
        cf_4d.updateGeometry()

        # same intensity filter as the interactive cell, without the diagnostic plot
        cf_strong = select_ring_peaks_by_intensity(cf_4d, frac=cf_strong_frac, dstol=cf_strong_dstol, dsmax=cf_4d.ds.max())

        print("Grains")
        grains = ds.get_grains_from_disk()
        utils.assign_peaks_to_grains(grains, cf_strong, peak_assign_tol)

        # record the number of assigned 4D peaks on each grain, as the interactive workflow does
        for grain_label, g in enumerate(grains):
            g.npks_4d = np.sum(cf_strong.grain_id == grain_label)

        grainsinos = [GrainSinogram(g, ds) for g in grains]

        print("Fitting grain positions")
        for grain_label, gs in enumerate(grainsinos):
            gs.update_lab_position_from_peaks(cf_strong, grain_label)

        # one y0 for the whole sample: the median of the per-grain values
        sample_y0s = [gs.recon_y0 for gs in grainsinos]
        y0 = np.median(sample_y0s)
        shift = -y0/ds.ystep

        cf_pars = cf_strong.parameters.get_parameters()
        cf_pars["cell_lattice_[P,A,B,C,I,F,R]"] = spacegroup
        ref_ucell = ImageD11.unitcell.unitcell_from_parameters(cf_pars)

        for g in grains:
            g.ref_unitcell = ref_ucell

        print("Determining RGB colours")
        utils.get_rgbs_for_grains(grains)

        print("Whole sample mask recon")
        whole_sample_sino, xedges, yedges = np.histogram2d(cf_4d.dty, cf_4d.omega, bins=[ds.ybinedges, ds.obinedges])
        # use the same ASTRA FBP call as the interactive whole-sample reconstruction, so that
        # apply_manual_mask / manual_threshold act on the same kind of image they were tuned on
        whole_sample_recon = run_astra(whole_sample_sino, ds.obincens, pad=pad, shift=shift, astra_method="FBP_CUDA", niter=100)

        recon_man_mask = apply_manual_mask(whole_sample_recon)

        if manual_threshold is None:
            thresh = threshold_otsu(recon_man_mask)
        else:
            thresh = manual_threshold

        binary = recon_man_mask > thresh
        whole_sample_mask = convex_hull_image(binary)

        gord, inds = get_2d_peaks_from_4d_peaks(ds.pk2d, cf_strong)

        print("Building sinograms")
        for grain_label, gs in enumerate(tqdm(grainsinos)):
            gs.prepare_peaks_from_4d(cf_strong, gord, inds, grain_label, hkltol)
            gs.build_sinogram()

            if is_half_scan:
                gs.correct_halfmask()

        if correct_sinos_with_ring_current:
            print("Correcting for ring current")
            ds.get_ring_current_per_scan()
            for gs in grainsinos:
                gs.correct_ring_current(is_half_scan=is_half_scan)

        for gs in grainsinos:
            gs.update_recon_parameters(pad=pad, shift=shift, mask=whole_sample_mask, niter=niter, y0=y0)

        for gs in tqdm(grainsinos):
            # same ASTRA method as the interactive per-grain reconstruction
            gs.recon(method="astra", astra_method="EM_CUDA")

            if is_half_scan:
                gs.mask_central_zingers("astra", radius=halfmask_radius)

        print("Final save")
        # the reconstructions were made with method="astra", so the maps must be assembled from those
        slice_arrays = build_slice_arrays(grainsinos, cutoff_level=cutoff_level, method="astra")
        write_h5(ds.grainsfile, grainsinos, write_grains_too=True)
        write_slice_recon(ds.grainsfile, slice_arrays)

        ds.save()

print("Done!")