# Jupyter notebook based on ImageD11 to process scanning 3DXRD data
# Written by Haixing Fang, Jon Wright and James Ball
## Date: 26/02/2024

In [None]:
# USER: Change the path below to point to your local copy of ImageD11:

import os
import sys

# Resolve the home directory in pure Python instead of shelling out with
# `!echo $HOME`, so this cell does not depend on IPython shell capture.
home_dir = os.path.expanduser("~")

# USER: You can change this location if you want

id11_code_path = os.path.join(home_dir, "Code/ImageD11")

# Put the local checkout first on the module search path so that it is the
# ImageD11 that gets imported by the following cells.
sys.path.insert(0, id11_code_path)

In [None]:
# import functions we need

import concurrent.futures
import timeit
import glob
import pprint
import time
from functools import partial

%matplotlib ipympl

import h5py
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt

import ImageD11.columnfile
import ImageD11.sinograms.dataset
from ImageD11.grain import grain

from skimage.filters import threshold_otsu
from skimage.morphology import convex_hull_image

import ImageD11.nbGui.nb_utils as utils

import ipywidgets as widgets
from ipywidgets import interact

In [None]:
# NOTE: For old datasets before the new directory layout structure, we don't distinguish between RAW_DATA and PROCESSED_DATA

### USER: specify your experimental directory

rawdata_path = "/data/visitor/ihma439/id11/20231211/RAW_DATA"

# list the raw data directory (oldest entries first) as a quick check that the path above is right
!ls -lrt {rawdata_path}

### USER: specify where you want your processed data to go

processed_data_root_dir = "/data/visitor/ihma439/id11/20231211/PROCESSED_DATA/James/nb_testing"

In [None]:
# USER: pick a sample and a dataset you want to process
# (the two names are combined further down to build the path of the processed dataset file)

sample = "FeAu_0p5_tR_nscope"
dataset = "top_250um"

In [None]:
# destination of H5 files
# layout: <processed_data_root_dir>/<sample>/<sample>_<dataset>/<sample>_<dataset>_dataset.h5

dset_path = os.path.join(processed_data_root_dir, sample, f"{sample}_{dataset}", f"{sample}_{dataset}_dataset.h5")

In [None]:
# Load the dataset object saved at dset_path.
# It is needed for the motor positions (these are not stored with the peaks) and,
# further down, for file paths (col4dfile, parfile, grainsfile) and the sinogram bin edges.
ds = ImageD11.sinograms.dataset.load(dset_path)

In [None]:
# determine ring currents for sinogram row-by-row intensity correction
# NOTE(review): the return value is discarded, so the helper presumably stores its result on ds — confirm in nb_utils

utils.get_ring_current_per_scan(ds)

In [None]:
# Import 4D peaks

# read the 4D peaks columnfile whose path is stored on the dataset
cf_4d = ImageD11.columnfile.columnfile(ds.col4dfile)

# load the parameter file named by the dataset, then update the geometry
# (the "ds" column used as dsmax in the peak selection is presumably produced here — TODO confirm)
cf_4d.parameters.loadparameters(ds.parfile)
cf_4d.updateGeometry()

print(f"Read {cf_4d.nrows} 4D peaks")

In [None]:
# read the grains for this dataset
grains = utils.read_s3dxrd_grains_for_recon(ds)

# attach an "a" attribute to every grain: the cube root of det(UBI)
for g in grains:
    g.a = np.cbrt(np.linalg.det(g.ubi))

print(f"{len(grains)} grains imported")

In [None]:
# here we are filtering our peaks (cf_4d) to select only the strongest ones
# this time, as opposed to indexing, our frac is slightly weaker but we are NOT filtering in dstar!
# this means many more peaks per grain = stronger sinograms

# USER: modify the "frac" parameter below and re-run the cell until the orange dot sits nicely on the "elbow" of the blue line
# this indicates the fractional intensity cutoff we will select
# if the blue line does not look elbow-shaped in the logscale plot, try changing the "doplot" parameter (the y scale of the logscale plot) until it does

cf_strong_frac = 0.995
cf_strong_dstol = 0.005

# dsmax is set to the largest ds value present in the data, so no peaks are cut on dstar
cf_strong = utils.selectpeaks(cf_4d, frac=cf_strong_frac, dstol=cf_strong_dstol, dsmax=cf_4d.ds.max(), doplot=0.9)
# compare the number of peaks before (printed) and after (cell output) the selection
print(cf_4d.nrows)
cf_strong.nrows

In [None]:
# # now let's do a whole-sample tomographic reconstruction

In [None]:
# If the sinograms are only half-sinograms (we scanned dty across half the sample rather than the full sample), set the below to True:
# this flag is also passed as apply_halfmask and mask_central_zingers to the reconstruction calls further down
is_half_scan = False

In [None]:
# apply the half-scan correction to the dataset (skipped entirely for full scans)
if is_half_scan:
    utils.correct_half_scan(ds)

In [None]:
# tolerance used when assigning the strong 4D peaks to grains
peak_assign_tol = 0.25
utils.assign_peaks_to_grains(grains, cf_strong, peak_assign_tol)

print("Storing peak data in grains")
# record on every grain which of the strong 4D peaks carry its grain id
for g in tqdm(grains):
    belongs_to_grain = cf_strong.grain_id == g.gid
    # boolean mask over cf_strong, and the matching peak indices
    g.mask_4d = belongs_to_grain
    g.peaks_4d = cf_strong.index[belongs_to_grain]

In [None]:
# Get grain translations from sinograms:
# NOTE(review): the fit presumably sets grain.dx and grain.dy (read on the next line) and grain.cen (read in the next cell) — confirm in nb_utils
for grain in tqdm(grains):
    utils.fit_grain_position_from_sino(grain, cf_strong)
    # in-plane position from the fit; the third component is fixed at 0
    grain.translation = np.array([grain.dx, grain.dy, 0])

# Get grain IPF colours:
# (the rgb_x / rgb_y / rgb_z attributes plotted below are presumably set here — confirm)

utils.get_rgbs_for_grains(grains)

In [None]:
# make sure we get centre right (centre of rotation should be the middle of dty)
# gather the fitted centre of every grain once, then plot and summarise it
cen_per_grain = [g.cen for g in grains]

fig, ax = plt.subplots()
ax.plot(cen_per_grain)

plt.show()

# the median over grains is taken as the centre of rotation
c0 = np.median(cen_per_grain)
print('Center of rotation in dty', c0)

# y0 is half of c0; it is passed to the reconstruction calls below
y0 = c0/2
print('y0 is', y0)

In [None]:
# plt.style.use('dark_background')
# 2x2 overview of the fitted grain positions: three IPF colourings plus the peak count
fig, ax = plt.subplots(2, 2, figsize=(12, 12))
panels = ax.ravel()

# grain positions, and marker sizes scaled by the number of assigned 4D peaks
x = [g.dy for g in grains]
y = [g.dx for g in grains]
s = [g.mask_4d.sum()/10 for g in grains]

# first three panels: same scatter, coloured by the IPF colour along Z, Y and X
ipf_panels = (
    ("rgb_z", 'IPF color Z'),
    ("rgb_y", 'IPF color Y'),
    ("rgb_x", 'IPF color X'),
)
for panel, (colour_attr, panel_title) in zip(panels, ipf_panels):
    panel.scatter(x, y, c=[getattr(g, colour_attr) for g in grains], s=s)
    panel.set(title=panel_title, aspect='equal')

# last panel: coloured by the (scaled) peak count instead
panels[3].scatter(x, y, c=s)
panels[3].set(title='Number of 4D peaks', aspect='equal')

fig.supxlabel("Lab y (transverse)")
fig.supylabel("Lab x (beam)")

# flip the horizontal axis on every panel
for panel in panels:
    panel.invert_xaxis()


plt.show()

In [None]:
# generate sinogram for whole sample
# 2D histogram of all 4D peaks (cf_4d, not only the strong subset): rows are dty bins, columns are omega bins

whole_sample_sino, xedges, yedges = np.histogram2d(cf_4d.dty, cf_4d.omega, bins=[ds.ybinedges, ds.obinedges])

fig, ax = plt.subplots()
ax.imshow(whole_sample_sino, interpolation="nearest", vmin=0)
# non-square pixels so the sinogram is easier to read
ax.set_aspect(4)
plt.show()

In [None]:
# "quick" whole-sample reconstruction

# padding passed to the reconstruction helpers (also reused for the per-grain reconstructions)
pad = 50
# number of CPUs this process is allowed to run on (os.sched_getaffinity is not available on every platform)
nthreads = len(os.sched_getaffinity(os.getpid()))

# reconstruct the whole-sample sinogram; the half-scan options follow is_half_scan
whole_sample_recon = utils.run_iradon_id11(whole_sample_sino, ds.obincens, pad, y0, workers=nthreads, apply_halfmask=is_half_scan, mask_central_zingers=is_half_scan)

In [None]:
# without a mask, MLEM can introduce artifacts in the corners
# so we can manually mask those out

# we can incorporate our own mask too
# by modifying the below function

def apply_manual_mask(mask_in):
    """Return a copy of mask_in with user-chosen regions zeroed out.

    As written no region is zeroed, so the result is an unmodified copy.
    Uncomment and adapt the slice assignment below to blank out part of the array.
    The input array itself is never modified.
    """
    mask_out = mask_in.copy()
    
    # mask_out[200:, 250:] = 0
    
    return mask_out

# we should be able to easily segment this using scikit-image
recon_man_mask = apply_manual_mask(whole_sample_recon)

# we can also override the threshold if we don't like it:
# manual_threshold = 0.05
manual_threshold = None

# Use Otsu's threshold unless the user supplied one.
# (An unconditional threshold_otsu call used to follow this if/else, which
# silently discarded any manual_threshold; it has been removed.)
if manual_threshold is None:
    thresh = threshold_otsu(recon_man_mask)
else:
    thresh = manual_threshold

# everything above the threshold counts as sample
binary = recon_man_mask > thresh

# the convex hull of the thresholded image is used as the sample mask
chull = convex_hull_image(binary)

whole_sample_mask = chull

# show the three stages side by side on shared axes
fig, axs = plt.subplots(1, 3, sharex=True, sharey=True, constrained_layout=True)
axs[0].imshow(recon_man_mask, vmin=0, origin="lower")
axs[1].imshow(binary, origin="lower")
axs[2].imshow(chull, origin="lower")

axs[0].set_title("Reconstruction")
axs[1].set_title("Binarised threshold")
axs[2].set_title("Convex hull")

fig.supxlabel("<-- Y axis")
fig.supylabel("Beam >")

plt.show()

In [None]:
# scatter of the strong peaks in (omega, dty), coloured by grain id
fig, ax = plt.subplots()
# keep only peaks with a non-negative grain id (presumably the ones assigned to a grain — confirm)
m = cf_strong.grain_id >= 0
ax.scatter(cf_strong.omega[m], cf_strong.dty[m], c=cf_strong.grain_id[m])
plt.show()

In [None]:
# get corresponding 2D peaks from 4D peaks so we can build the sinograms with them

gord, inds, p2d = utils.get_2d_peaks_from_4d_peaks(ds, cf_strong)

# now our 2D peak assignments are known, let's populate our grain objects with our 2D peaks
# each grain takes the slice of gord delimited by inds[gid+1] and inds[gid+2]
# NOTE(review): the +1 offset presumably skips a leading slot for unassigned peaks (grain id -1) — confirm in nb_utils

for grain in tqdm(grains):
    i = grain.gid
    grain.peaks_2d = gord[inds[i+1] : inds[i+2]]

In [None]:
# Determine sinograms of all grains

# number of CPUs this process is allowed to run on
nthreads = len(os.sched_getaffinity(os.getpid()))

# bind the shared 2D peaks and dataset so the pool only has to supply the grain
do_sinos_partial = partial(utils.do_sinos, p2d=p2d, ds=ds)

# run over all grains in a thread pool with one worker fewer than the available CPUs (at least one)
# the loop body is empty: iterating just drains the results so tqdm can show progress
with concurrent.futures.ThreadPoolExecutor(max_workers= max(1,nthreads-1)) as pool:
    for i in tqdm(pool.map(do_sinos_partial, grains), total=len(grains)):
        pass

In [None]:
# we can optionally correct the grain sinograms by scaling each row by the ring current:
# USER: set the flag below to False to skip the correction (the bulk-processing cell reads the same flag)


correct_sinos_with_ring_current = True

if correct_sinos_with_ring_current:
    for grain in tqdm(grains):
        utils.correct_sinogram_rows_with_ring_current(grain, ds)

In [None]:
# Show sinogram and reconstruction of single grain

g = grains[0]

# reconstruct just this grain, restricted to the whole-sample mask
# NOTE(review): max(nthreads, 20) requests at least 20 workers even when fewer CPUs are available;
# if 20 was meant as a cap this should be min(nthreads, 20) — confirm the intent
utils.iradon_grain(g, pad=pad, y0=y0, workers=max(nthreads, 20), sample_mask=whole_sample_mask, apply_halfmask=is_half_scan, mask_central_zingers=is_half_scan)

# left: the reconstruction (g.recon); right: the sinogram it was computed from (g.ssino)
fig, axs = plt.subplots(1,2, figsize=(10,5))
axs[0].imshow(g.recon, vmin=0, origin="lower")
axs[0].set_title("ID11 iradon")
axs[1].imshow(g.ssino, aspect='auto')
axs[1].set_title("ssino")

plt.show()

In [None]:
# Now compute reconstructions for all grains

# number of CPUs this process is allowed to run on
nthreads = len(os.sched_getaffinity(os.getpid()))

# same settings as the single-grain test, but one worker per grain because the parallelism comes from the pool below
run_this_iradon = partial(utils.iradon_grain, pad=pad, y0=y0, sample_mask=whole_sample_mask, workers=1, apply_halfmask=is_half_scan, mask_central_zingers=is_half_scan)

# one worker fewer than the available CPUs (at least one); the empty loop just drains the results so tqdm can show progress
with concurrent.futures.ThreadPoolExecutor( max_workers= max(1,nthreads-1) ) as pool:
    for i in tqdm(pool.map(run_this_iradon, grains), total=len(grains)):
        pass

In [None]:
# keep a handle on the iradon reconstruction, because grain.recon is replaced by the MLEM result further down
# (this stores a reference to the same array, not a copy)
for grain in grains:
    grain.og_recon = grain.recon

In [None]:
# Interactive viewer: browse the iradon reconstruction and sinogram of each grain.

# Start on the first grain. The initial image used to be a hard-coded grains[8],
# which raised IndexError with fewer than nine grains and took the colour limits
# from a different grain than the one the slider starts on.
initial_grain_index = 0

fig, a = plt.subplots(1,2,figsize=(10,5))
rec = a[0].imshow(grains[initial_grain_index].og_recon, vmin=0, origin="lower")
sin = a[1].imshow(grains[initial_grain_index].ssino, aspect='auto')
a[0].set(title=grains[initial_grain_index].gid)

# Function to update the displayed image based on the selected frame
def update_frame(i):
    """Show the reconstruction and sinogram of grains[i], titled with its gid."""
    rec.set_array(grains[i].og_recon)
    sin.set_array(grains[i].ssino)
    a[0].set(title=grains[i].gid)
    fig.canvas.draw()

# Create a slider widget to select the frame number
frame_slider = widgets.IntSlider(
    value=initial_grain_index,
    min=0,
    max=len(grains) - 1,
    step=1,
    description='Grain:'
)

interact(update_frame, i=frame_slider)

plt.show()

In [None]:
# plot inverse pole figures for the grains (plotting is done by the nb_utils helper)
utils.plot_ipfs(grains)

In [None]:
# combine the per-grain iradon reconstructions into whole-slice maps: IPF colours along x/y/z, grain labels and raw intensity
# NOTE(review): how cutoff_level is applied is decided inside nb_utils — confirm there; note the post-MLEM call uses 0.3 instead of 0.4
rgb_x_array, rgb_y_array, rgb_z_array, grain_labels_array, raw_intensity_array = utils.build_slice_arrays(grains, cutoff_level=0.4)

In [None]:
# plot initial output
# (IPF-Z colour map built from the iradon reconstructions)

fig, ax = plt.subplots(constrained_layout=True)
ax.imshow(rgb_z_array, origin="lower")
plt.show()

In [None]:
# grain label map built from the iradon reconstructions
fig, ax = plt.subplots(constrained_layout=True)
ax.imshow(grain_labels_array, origin="lower")  # originally 1,2,0
ax.set_title("Grain label map")
plt.show()

In [None]:
# raw intensity map built from the iradon reconstructions
fig, ax = plt.subplots(constrained_layout=True)
ax.imshow(raw_intensity_array, origin="lower")
ax.set_title("Raw intensity array")
plt.show()

In [None]:
# we can clean up these reconstructions using an MLEM iterative recon
# we can use the whole sample shape mask for this

In [None]:
# save the grains together with the 2D peak ordering (gord, inds), the whole-sample mask and y0, ready for the MLEM step
utils.save_s3dxrd_grains_for_mlem(grains, ds, gord, inds, whole_sample_mask, y0=y0)

In [None]:
# settings handed to the MLEM job script generator
# NOTE(review): meanings inferred from the names (jobs running at once, CPU cores per job, MLEM iterations) — confirm in nb_utils
n_simultaneous_jobs = 1000
cores_per_task = 8
niter = 50

# returns the path of the generated bash script and the directory the reconstructions are later read back from
bash_script_path, recons_path = utils.prepare_mlem_bash(ds, grains, pad, is_half_scan, id11_code_path, n_simultaneous_jobs, cores_per_task, niter)

In [None]:
# submit the generated script and wait for it to finish before continuing
# NOTE(review): 30 is presumably a polling interval in seconds — confirm in nb_utils
utils.slurm_submit_and_wait(bash_script_path, 30)

In [None]:
# collect results into grain attributes
# the filenames are element position not gid
# (i.e. the number in the file name is the position of the grain in the grains list)
# this overwrites grain.recon; the earlier iradon result stays available as grain.og_recon

for i, grain in enumerate(tqdm(grains)):
    grain.recon = np.loadtxt(os.path.join(recons_path, ds.dsname + f"_mlem_recon_{i}.txt"))

In [None]:
# look at all our grains

# upper target for the number of grains shown; every grains_step-th grain is plotted
n_grains_to_plot = 25

# Clamp the stride to at least 1: with fewer grains than n_grains_to_plot the
# integer division gives 0, and slicing with a zero step raises ValueError.
grains_step = max(1, len(grains)//n_grains_to_plot)
grains_to_plot = grains[::grains_step]

# smallest near-square grid that holds all selected grains
grid_size = np.ceil(np.sqrt(len(grains_to_plot))).astype(int)
nrows = (len(grains_to_plot)+grid_size-1)//grid_size

# squeeze=False keeps axs two-dimensional even for a 1x1 grid, so .ravel() always works
fig, axs = plt.subplots(grid_size, nrows, figsize=(10,10), layout="constrained", sharex=True, sharey=True, squeeze=False)

# pair each selected grain with an axis; any surplus axes are left empty
for ax, g in zip(axs.ravel(), grains_to_plot):
    ax.imshow(g.recon, vmin=0, origin="lower")
    # ax.invert_yaxis()
    ax.set_title(g.gid)

plt.show()

In [None]:
# rebuild the whole-slice maps, this time from the MLEM reconstructions now stored in grain.recon
# cutoff_level is also reused by the bulk-processing cell
cutoff_level = 0.3

rgb_x_array, rgb_y_array, rgb_z_array, grain_labels_array, raw_intensity_array = utils.build_slice_arrays(grains, cutoff_level)

In [None]:
# IPF-Z colour map after MLEM
fig, ax = plt.subplots(constrained_layout=True)
ax.imshow(rgb_z_array, origin="lower")
plt.show()

In [None]:
# raw intensity map after MLEM
fig, ax = plt.subplots(constrained_layout=True)
ax.imshow(raw_intensity_array, origin="lower")
ax.set_title("Sinogram raw intensity map")
plt.show()

In [None]:
# grain label map after MLEM
fig, ax = plt.subplots(constrained_layout=True)
ax.imshow(grain_labels_array, origin="lower")
ax.set_title("Grain label map")
plt.show()

In [None]:
# save the grains together with the final slice maps
# NOTE(review): presumably written to ds.grainsfile under "slice_recon", which the bulk cell checks for before re-processing — confirm in nb_utils
utils.save_s3dxrd_grains_after_recon(grains, ds, raw_intensity_array, grain_labels_array, rgb_x_array, rgb_y_array, rgb_z_array)

In [None]:
# write the dataset object back to disk
ds.save()

In [None]:
# deliberate stop: prevents "Run all cells" from running straight into the bulk-processing cell that follows
if 1:
    raise ValueError("Change the 1 above to 0 to allow 'Run all cells' in the notebook")

In [None]:
# Now that we're happy with our reconstruction parameters, we can run the below cell to do this in bulk for many samples/datasets
# by default this will do all samples in sample_list, all datasets with a prefix of dset_prefix
# you can add samples and datasets to skip in skips_dict
#
# This cell reuses the settings chosen in the cells above: cf_strong_frac, cf_strong_dstol,
# is_half_scan, peak_assign_tol, pad, nthreads, apply_manual_mask, manual_threshold,
# correct_sinos_with_ring_current, n_simultaneous_jobs, cores_per_task, niter and cutoff_level.

skips_dict = {
    "FeAu_0p5_tR_nscope": ["top_-50um", "top_-100um"]
}

dset_prefix = "top"

sample_list = ["FeAu_0p5_tR_nscope"]
    
samples_dict = utils.find_datasets_to_process(rawdata_path, skips_dict, dset_prefix, sample_list)
    
# manual override:
# samples_dict = {"FeAu_0p5_tR_nscope": ["top_100um", "top_200um"]}
    
# now we have our samples_dict, we can process our data:

for sample, datasets in samples_dict.items():
    for dataset in datasets:
        print(f"Processing dataset {dataset} in sample {sample}")
        dset_path = os.path.join(processed_data_root_dir, sample, f"{sample}_{dataset}", f"{sample}_{dataset}_dataset.h5")
        if not os.path.exists(dset_path):
            print(f"Missing DataSet file for {dataset} in sample {sample}, skipping")
            continue
        
        print("Importing DataSet object")
        
        ds = ImageD11.sinograms.dataset.load(dset_path)
        print(f"I have a DataSet {ds.dset} in sample {ds.sample}")
        
        if not os.path.exists(ds.grainsfile):
            print(f"Missing grains file for {dataset} in sample {sample}, skipping")
            continue
            
        # check grains file for existence of slice_recon, skip if it's there
        # (the continue applies to the dataset loop; the file is closed on leaving the with block)
        with h5py.File(ds.grainsfile, "r") as hin:
            if "slice_recon" in hin.keys():
                print(f"Already reconstructed {dataset} in {sample}, skipping")
                continue
        
        # determine ring currents for sinogram row-by-row intensity correction
        utils.get_ring_current_per_scan(ds)
            
        # import the 4D peaks and update their geometry from the parameter file
        cf_4d = ImageD11.columnfile.columnfile(ds.col4dfile)
        cf_4d.parameters.loadparameters(ds.parfile)
        cf_4d.updateGeometry()
        
        grains = utils.read_s3dxrd_grains_for_recon(ds)
        
        # Set grain.a exactly as the interactive cell does right after reading the grains.
        # The bulk loop used to omit this, so grains processed in bulk lacked the attribute.
        for grain in grains:
            grain.a = np.cbrt(np.linalg.det(grain.ubi))
        
        # select the strongest peaks, without filtering in dstar (dsmax is the largest ds present)
        cf_strong = utils.selectpeaks(cf_4d, frac=cf_strong_frac, dsmax=cf_4d.ds.max(), dstol=cf_strong_dstol)
        
        if is_half_scan:
            utils.correct_half_scan(ds)
            
        utils.assign_peaks_to_grains(grains, cf_strong, tol=peak_assign_tol)
        
        # store each grain's 4D peak mask and indices, then fit its position from the sinogram
        for grain in tqdm(grains):
            grain.mask_4d = cf_strong.grain_id == grain.gid
            grain.peaks_4d = cf_strong.index[cf_strong.grain_id == grain.gid]
            utils.fit_grain_position_from_sino(grain, cf_strong)
            grain.translation = np.array([grain.dx, grain.dy, 0])
            
        utils.get_rgbs_for_grains(grains)

        # centre of rotation: median of the fitted grain centres; y0 is half of it
        c0 = np.median([g.cen for g in grains])
        y0 = c0/2
        
        # whole-sample sinogram from all 4D peaks (rows: dty bins, columns: omega bins)
        whole_sample_sino, xedges, yedges = np.histogram2d(cf_4d.dty, cf_4d.omega, bins=[ds.ybinedges, ds.obinedges])
        
        print("Whole sample mask")
        whole_sample_recon = utils.run_iradon_id11(whole_sample_sino, ds.obincens, pad, y0, workers=nthreads, apply_halfmask=is_half_scan, mask_central_zingers=is_half_scan)
        
        # threshold (Otsu unless manual_threshold is set) and take the convex hull as the sample mask
        recon_man_mask = apply_manual_mask(whole_sample_recon)
        if manual_threshold is None:
            thresh = threshold_otsu(recon_man_mask)
        else:
            thresh = manual_threshold
            
        binary = recon_man_mask > thresh
        whole_sample_mask = convex_hull_image(binary)
        
        print("Peak 2D organise")
        gord, inds, p2d = utils.get_2d_peaks_from_4d_peaks(ds, cf_strong)
        
        # each grain takes the slice of gord delimited by inds[gid+1] and inds[gid+2]
        for grain in tqdm(grains):
            i = grain.gid
            grain.peaks_2d = gord[inds[i+1] : inds[i+2]]
        
        print("Making sinograms")
        do_sinos_partial = partial(utils.do_sinos, p2d=p2d, ds=ds)

        # the empty loop just drains the pool results so tqdm can show progress
        with concurrent.futures.ThreadPoolExecutor(max_workers= max(1,nthreads-1)) as pool:
            for i in tqdm(pool.map(do_sinos_partial, grains), total=len(grains)):
                pass
        
        if correct_sinos_with_ring_current:
            for grain in tqdm(grains):
                utils.correct_sinogram_rows_with_ring_current(grain, ds)
        
        print("Running iradon")
        
        run_this_iradon = partial(utils.iradon_grain, pad=pad, y0=y0, sample_mask=whole_sample_mask, workers=1, apply_halfmask=is_half_scan, mask_central_zingers=is_half_scan)

        with concurrent.futures.ThreadPoolExecutor( max_workers= max(1,nthreads-1) ) as pool:
            for i in tqdm(pool.map(run_this_iradon, grains), total=len(grains)):
                pass
            
        # keep a handle on the iradon result before grain.recon is replaced by the MLEM result
        for grain in grains:
            grain.og_recon = grain.recon
            
        utils.save_s3dxrd_grains_for_mlem(grains, ds, gord, inds, whole_sample_mask, y0)
        
        bash_script_path, recons_path = utils.prepare_mlem_bash(ds, grains, pad, is_half_scan, id11_code_path, n_simultaneous_jobs, cores_per_task, niter)
        
        utils.slurm_submit_and_wait(bash_script_path, 30)
        
        # the number in each file name is the position of the grain in the list, not its gid
        for i, grain in enumerate(tqdm(grains)):
            grain.recon = np.loadtxt(os.path.join(recons_path, ds.dsname + f"_mlem_recon_{i}.txt"))
            
        rgb_x_array, rgb_y_array, rgb_z_array, grain_labels_array, raw_intensity_array = utils.build_slice_arrays(grains, cutoff_level)
        
        utils.save_s3dxrd_grains_after_recon(grains, ds, raw_intensity_array, grain_labels_array, rgb_x_array, rgb_y_array, rgb_z_array)
        
        ds.save()

print("Done!")