# Jupyter notebook based on ImageD11 to process scanning 3DXRD data
# Written by Haixing Fang, Jon Wright and James Ball
## Date: 28/03/2024

In [None]:
# USER: Change the path below to point to your local copy of ImageD11.
# The path is inserted at the front of sys.path so that this checkout takes
# precedence over any installed ImageD11 when the next cell imports it.

import os
import sys

# Resolve the home directory in pure Python instead of the IPython-only
# `!echo $HOME` shell magic, so this cell also runs outside IPython.
home_dir = os.path.expanduser("~")

# USER: You can change this location if you want

id11_code_path = os.path.join(home_dir, "Code/ImageD11")

sys.path.insert(0, id11_code_path)

In [None]:
# import functions we need

import concurrent.futures

%matplotlib ipympl

import h5py
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt

import ImageD11.columnfile
import ImageD11.sinograms.dataset
from ImageD11.grain import grain
from ImageD11.peakselect import select_ring_peaks_by_intensity, rings_mask
from ImageD11.sinograms.sinogram import GrainSinogram, build_slice_arrays, write_slice_recon, write_h5, read_h5, get_2d_peaks_from_4d_peaks
from ImageD11.sinograms.roi_iradon import run_iradon

from skimage.filters import threshold_otsu
from skimage.morphology import convex_hull_image

import ImageD11.nbGui.nb_utils as utils

import ipywidgets as widgets
from ipywidgets import interact

In [None]:
# NOTE: For old datasets before the new directory layout structure, we don't distinguish between RAW_DATA and PROCESSED_DATA

### USER: specify your experimental directory

rawdata_path = "/data/visitor/ihma439/id11/20231211/RAW_DATA"

# IPython shell magic: long listing of the raw-data directory, oldest first, as a quick check that the path is right
!ls -lrt {rawdata_path}

### USER: specify where you want your processed data to go

processed_data_root_dir = "/data/visitor/ihma439/id11/20231211/PROCESSED_DATA/James/nb_testing"

In [None]:
# USER: pick a sample and a dataset you want to segment

sample = "FeAu_0p5_tR_nscope"
dataset = "top_100um"

In [None]:
# destination of H5 files
# layout: <processed_data_root_dir>/<sample>/<sample>_<dataset>/<sample>_<dataset>_dataset.h5

dset_path = os.path.join(processed_data_root_dir, sample, f"{sample}_{dataset}", f"{sample}_{dataset}_dataset.h5")

In [None]:
# Load the dataset (for motor positions, not sure why these are not in peaks)
# The DataSet also provides ds.parfile, ds.grainsfile and ds.pk2d, which the cells below use
ds = ImageD11.sinograms.dataset.load(dset_path)

In [None]:
# If the sinograms are only half-sinograms (we scanned dty across half the sample rather than the full sample), set the below to true:
is_half_scan = False

In [None]:
# Apply the DataSet's half-scan bin correction (only when is_half_scan is set above)
if is_half_scan:
    ds.correct_bins_for_half_scan()

In [None]:
# Import 4D peaks
# and recompute their geometry from ds.parfile, which at this point is still the major-phase parameter file.
# NOTE: the minor-phase cell below switches ds.parfile to the Au parameters; re-running this cell after that
# would load the Au parameters into cf_4d and break the major-phase ring mask.

cf_4d = ds.get_cf_4d_from_disk()

cf_4d.parameters.loadparameters(ds.parfile)
cf_4d.updateGeometry()

print(f"Read {cf_4d.nrows} 4D peaks")

In [None]:
# isolate main phase peaks, and remove them from the dataset
# (cf_4d itself is left untouched: minor_phase_peaks is a copy keeping only the
#  peaks that do NOT lie on a major-phase ring)

major_phase_cf_dstol = 0.0075
major_phase_peaks_mask = rings_mask(cf_4d, dstol=major_phase_cf_dstol, dsmax=cf_4d.ds.max())

minor_phase_peaks = cf_4d.copy()
minor_phase_peaks.filter(~major_phase_peaks_mask)

# Update geometry for minor phase peaks

minor_phase_par_file = os.path.join(processed_data_root_dir, '../../../SCRIPTS/James/S3DXRD/Au.par')

# ds.parfile is switched to the minor-phase (Au) file below, and is only switched
# back to major_phase_par_file just before ds.save(). Record the major-phase file
# only if ds.parfile has not already been switched by an earlier run of this cell.
# Otherwise a re-run would record Au.par as the "major" file, and the final save
# would write the minor-phase parameters into the DataSet.
if os.path.abspath(ds.parfile) != os.path.abspath(minor_phase_par_file):
    major_phase_par_file = ds.parfile

ds.parfile = minor_phase_par_file

minor_phase_peaks.parameters.loadparameters(ds.parfile)
minor_phase_peaks.updateGeometry()

# keep the strongest peaks (by intensity fraction) on the minor-phase rings
cf_strong_frac = 0.95
cf_strong_dstol = 0.005

cf_strong = select_ring_peaks_by_intensity(minor_phase_peaks, dstol=cf_strong_dstol, dsmax=minor_phase_peaks.ds.max(), frac=cf_strong_frac, doplot=0.01)
print(cf_strong.nrows)

In [None]:
phase_name = "Au"

minor_phase_grains_path = os.path.splitext(ds.grainsfile)[0] + f'_{phase_name}.h5'

grains = ImageD11.grain.read_grain_file_h5(minor_phase_grains_path)
print(f"{len(grains)} grains imported")

In [None]:
# load major phase grain reconstruction
# for pad and recon shifts

major_phase_grainsinos = read_h5(ds.grainsfile, ds)
whole_sample_mask = major_phase_grainsinos[0].recon_mask
shift = major_phase_grainsinos[0].recon_shift
pad = major_phase_grainsinos[0].recon_pad
y0 = major_phase_grainsinos[0].recon_y0

print(shift, y0, pad)

In [None]:
# assign peaks to the grains

peak_assign_tol = 0.25
utils.assign_peaks_to_grains(grains, cf_strong, peak_assign_tol)

for grain_label, g in enumerate(grains):
    g.npks_4d = np.sum(cf_strong.grain_id == grain_label)

In [None]:
# let's make a GrainSinogram object for each grain

grainsinos = [GrainSinogram(g, ds) for g in grains]

In [None]:
# Now let's determine the positions of each grain from the 4D peaks

for grain_label, gs in enumerate(grainsinos):
    gs.update_lab_position_from_peaks(cf_strong, grain_label)

In [None]:
# We can also determine the RGB IPF colours of the grains which will be useful for plotting
# To do this, we first need to set a reference unitcell for each grain
# This will be used to determine the Orix Phase and therefore Orix Orientation

cf_pars = cf_strong.parameters.get_parameters()
spacegroup = 225  # spacegroup for FCC Au
cf_pars["cell_lattice_[P,A,B,C,I,F,R]"] = spacegroup

ref_ucell = ImageD11.unitcell.unitcell_from_parameters(cf_pars)

for g in grains:
    g.ref_unitcell = ref_ucell

# Now colours should work

utils.get_rgbs_for_grains(grains)

In [None]:
utils.plot_all_ipfs(grains)

In [None]:
# Now we can plot our grain positions and RGB colours:
# horizontal axis = lab y (translation[1]), vertical axis = lab x (translation[0]);
# marker size scales with the number of strong 4D peaks assigned to each grain.

# plt.style.use('dark_background')
fig, ax = plt.subplots(2,2, figsize=(12,12))
a = ax.ravel()
x = [g.translation[1] for g in grains]
y = [g.translation[0] for g in grains]
s = [g.npks_4d/10 for g in grains]

# the three IPF-coloured panels differ only in the colour attribute and title
for axis, title, colours in zip(
    a[:3],
    ('IPF color Z', 'IPF color Y', 'IPF color X'),
    ([g.rgb_z for g in grains], [g.rgb_y for g in grains], [g.rgb_x for g in grains]),
):
    axis.scatter(x, y, c=colours, s=s)
    axis.set(title=title, aspect='equal')

a[3].scatter(x, y, c=s)
a[3].set(title='Number of 4D peaks', aspect='equal')

fig.supxlabel("Lab y (transverse)")
fig.supylabel("Lab x (beam)")

# Use a dedicated loop variable: `for a in ...` would rebind the global axes
# array `a` to a single Axes, breaking any later code that indexes `a[...]`.
for axis in a:
    axis.invert_xaxis()

plt.show()

In [None]:
# now we have a whole-sample reconstruction we can use as a sample mask
# (whole_sample_mask, taken from the major-phase reconstruction loaded above)
# let's build the sinograms for our grains
# before we do this, we need to determine our 2D peaks that will be used for the sinogram
# here we can get them from the 4D peaks:

# tolerance handed to prepare_peaks_from_4d (presumably an hkl tolerance, going by the name - confirm)
hkltol = 0.25

# relate the 4D peaks in cf_strong to the 2D peaks in ds.pk2d; computed once and shared by every grain below
gord, inds = get_2d_peaks_from_4d_peaks(ds.pk2d, cf_strong)

# grain_label is the grain's index in `grains`, matching the labels compared against cf_strong.grain_id earlier
for grain_label, gs in enumerate(tqdm(grainsinos)):
    gs.prepare_peaks_from_4d(cf_strong, gord, inds, grain_label, hkltol)

In [None]:
# now we can actually generate the sinograms

for gs in tqdm(grainsinos):
    gs.build_sinogram()

In [None]:
# optionally correct the halfmask:
# (only for half scans, see is_half_scan)

if is_half_scan:
    for gs in grainsinos:
        gs.correct_halfmask()

In [None]:
# Show sinogram of single grain
# (the first grain, before any ring-current correction)

gs = grainsinos[0]

fig, ax = plt.subplots()

ax.imshow(gs.ssino, aspect='auto')
ax.set_title("ssino")

plt.show()

In [None]:
# We can optionally correct each row of the sinogram by the ring current of that rotation
# This helps remove artifacts in the reconstruction

correct_sinos_with_ring_current = True
if correct_sinos_with_ring_current:
    
    # fetch the per-scan ring current on the DataSet (presumably stored on ds and read by correct_ring_current - confirm)
    ds.get_ring_current_per_scan()
    
    for gs in grainsinos:
        gs.correct_ring_current(is_half_scan=is_half_scan)

In [None]:
# Show sinogram of single grain
# (same grain again, after the optional ring-current correction, for comparison)

gs = grainsinos[0]

fig, ax = plt.subplots()

ax.imshow(gs.ssino, aspect='auto')
ax.set_title("ssino")

plt.show()

In [None]:
# let's try out an iradon reconstruction

gs = grainsinos[0]

# update the parameters used for the iradon reconstruction

gs.update_recon_parameters(pad=pad, shift=shift, mask=whole_sample_mask)

# perform the reconstruction

gs.recon()

if is_half_scan:
    halfmask_radius = 25
    gs.mask_central_zingers("iradon", radius=halfmask_radius)

# view the result

fig, axs = plt.subplots(1,2, figsize=(10,5))
axs[0].imshow(gs.ssino, aspect='auto')
axs[0].set_title("ssino")
axs[1].imshow(gs.recons["iradon"], vmin=0, origin="lower")
axs[1].set_title("ID11 iradon")

plt.show()

In [None]:
# once you're happy with the reconstruction parameters used, set them for all the grains

for gs in grainsinos:
    gs.update_recon_parameters(pad=pad, shift=shift, mask=whole_sample_mask)

In [None]:
# reconstruct all grains in parallel

nthreads = len(os.sched_getaffinity(os.getpid()))

with concurrent.futures.ThreadPoolExecutor(max_workers= max(1,nthreads-1)) as pool:
    for i in tqdm(pool.map(GrainSinogram.recon, grainsinos), total=len(grainsinos)):
        pass
    
if is_half_scan:
    for gs in grainsinos:
        gs.mask_central_zingers("iradon", radius=halfmask_radius)

In [None]:
# Browse each grain's iradon reconstruction (left) and sinogram (right) with a slider.

def _make_frame_updater(fig, axes, rec_im, sino_im, sinos):
    """Return a slider callback that displays grain ``i`` of ``sinos``.

    The figure, axes, image artists and GrainSinogram list are captured here
    instead of being looked up as notebook globals every time the slider moves.
    Later cells that rebind ``fig``, ``a`` or ``grainsinos`` therefore cannot
    break this plot.
    """
    def update_frame(i):
        recon = sinos[i].recons["iradon"]
        rec_im.set_array(recon)
        # set_array keeps the previous colour limits; rescale so that every
        # grain is not shown with grain 0's intensity range
        rec_im.set_clim(vmin=0, vmax=recon.max())
        sino_im.set_array(sinos[i].ssino)
        sino_im.autoscale()
        axes[0].set(title=str(i))
        fig.canvas.draw()
    return update_frame


fig, a = plt.subplots(1,2,figsize=(10,5))
rec = a[0].imshow(grainsinos[0].recons["iradon"], vmin=0, origin="lower")
sin = a[1].imshow(grainsinos[0].ssino, aspect='auto')

update_frame = _make_frame_updater(fig, a, rec, sin, grainsinos)

# Create a slider widget to select the grain to display
frame_slider = widgets.IntSlider(
    value=0,
    min=0,
    max=len(grainsinos) - 1,
    step=1,
    description='Grain:'
)

interact(update_frame, i=frame_slider)

plt.show()

In [None]:
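# # You can pick a grain and investigate the effect of changing the y0 that gets passed to iradon.
# # It's best to pick the grain AFTER reconstructing all grains, so you can pick a grain of interest.
# # NOTE(review): this snippet reads g.ssino / g.sinoangles from a grain object, whereas the rest of this
# # notebook keeps sinograms on GrainSinogram objects (e.g. gs.ssino) - update it before uncommenting.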

# g = grains[5]
    
# vals = np.linspace(-8.5, -7.5, 9)

# grid_size = np.ceil(np.sqrt(len(vals))).astype(int)
# nrows = (len(vals)+grid_size-1)//grid_size

# fig, axs = plt.subplots(grid_size, nrows, sharex=True, sharey=True)

# for inc, val in enumerate(tqdm(vals)):
    
#     crop = utils.run_iradon_id11(g.ssino, g.sinoangles, pad, y0=val, workers=1, apply_halfmask=is_half_scan, mask_central_zingers=is_half_scan)

    
#     axs.ravel()[inc].imshow(crop, origin="lower", vmin=0)
#     axs.ravel()[inc].set_title(val)
    
# plt.show()

In [None]:
# Let's assemble all the recons into one map
# First pass with a fixed cutoff of 0.9: its label map is used below to spot grains with dodgy reconstructions

rgb_x_array, rgb_y_array, rgb_z_array, grain_labels_array, raw_intensity_array = build_slice_arrays(grainsinos, cutoff_level=0.9)

In [None]:
# plot initial output

fig, ax = plt.subplots(constrained_layout=True)
ax.imshow(rgb_z_array, origin="lower")
plt.show()

In [None]:
fig, ax = plt.subplots(constrained_layout=True)
ax.imshow(grain_labels_array, origin="lower")  # originally 1,2,0
ax.set_title("Grain label map")
plt.show()

In [None]:
fig, ax = plt.subplots(constrained_layout=True)
ax.imshow(raw_intensity_array, origin="lower")
ax.set_title("Raw intensity array")
plt.show()

In [None]:
# There will likely be many streaks, indicating a few grains have dodgy reconstructions and are probably not to be trusted
# To fix this, we can count how many pixels in the grain labels array each grain has
# It can be helpful to run this filtration more than once

# labels: distinct values in the label map; counts: number of map pixels carrying each value
labels, counts = np.unique(grain_labels_array, return_counts=True)

fig, ax = plt.subplots()
ax.plot(labels[labels > 0], counts[labels > 0])
plt.show()

In [None]:
# filter out grains with more than 10 pixels in the label map
# this normally indicates a dodgy reconstruction for this grain
# only really applies if the grains are very small!

grain_too_many_px = 10

# NOTE(review): grain labels are enumerate() indices starting at 0 (see grain_labels_clean below), so
# `label > 0` can never flag grain 0. That is only intended if 0 is also the background value of
# grain_labels_array - confirm against build_slice_arrays (if the background is -1, use `label >= 0`).
bad_gids = [int(label) for (label, count) in zip(labels, counts) if count > grain_too_many_px and label > 0]

In [None]:
# before we filter, determine our grain labels
# this is so we know which labels to give our grains in the filtered grain map
# such that it still agrees with the grain order
# (each kept grain keeps its ORIGINAL index as its label, so labels still index into grains/grainsinos)

print(f"{len(grainsinos)} grains before filtration")
grainsinos_clean = [gs for (inc, gs) in enumerate(grainsinos) if inc not in bad_gids]
grain_labels_clean = [inc for (inc, gs) in enumerate(grainsinos) if inc not in bad_gids]
print(f"{len(grainsinos_clean)} grains after filtration")

In [None]:
# Let's assemble all the recons into one map
# Second pass, over the filtered grains only; slice_arrays is what gets written to disk later

cutoff_level = 0.7

slice_arrays = build_slice_arrays(grainsinos_clean, cutoff_level=cutoff_level, grain_labels=grain_labels_clean)
rgb_x_array, rgb_y_array, rgb_z_array, grain_labels_array, raw_intensity_array = slice_arrays

In [None]:
# plot initial output

fig, ax = plt.subplots(constrained_layout=True)
ax.imshow(rgb_z_array, origin="lower")  # originally 1,2,0
plt.show()

In [None]:
fig, ax = plt.subplots(constrained_layout=True)
ax.imshow(grain_labels_array, origin="lower")  # originally 1,2,0
ax.set_title("Grain label map")
plt.show()

In [None]:
fig, ax = plt.subplots(constrained_layout=True)
ax.imshow(raw_intensity_array, origin="lower")  # originally 1,2,0
ax.set_title("Raw intensity array")
plt.show()

In [None]:
# we can determine improved positions of our grains from the positions of their reconstructions

for gs in tqdm(grainsinos):
    gs.update_lab_position_from_recon()
    
# change the y0 back to what we imported at the beginning:

for gs in grainsinos:
    gs.update_recon_parameters(y0=y0)

In [None]:
# Now we can plot our grain positions and RGB colours:

# plt.style.use('dark_background')
fig, ax = plt.subplots(2,2, figsize=(12,12))
a = ax.ravel()
x = [g.translation[1] for g in grains]
y = [g.translation[0] for g in grains]
s = [g.npks_4d/10 for g in grains]
a[0].scatter(x, y, c=[g.rgb_z for g in grains], s=s)
a[0].set(title='IPF color Z',  aspect='equal')
a[1].scatter(x, y, c=[g.rgb_y for g in grains], s=s)
a[1].set(title='IPF color Y', aspect='equal')
a[2].scatter(x, y, c=[g.rgb_x for g in grains], s=s)
a[2].set(title='IPF color X',  aspect='equal')
a[3].scatter(x, y, c=s)
a[3].set(title='Number of 4D peaks', aspect='equal')

fig.supxlabel("Lab y (transverse)")
fig.supylabel("Lab x (beam)")

for a in ax.ravel():
    a.invert_xaxis()

plt.show()

In [None]:
# write our results to disk
# (grain sinograms/recons into the minor-phase grains file; write_grains_too=True asks write_h5 to also write the grains)

write_h5(minor_phase_grains_path, grainsinos, write_grains_too=True)

In [None]:
# write the slice maps to disk too
# (slice_arrays is the second-pass map built from the filtered grains)

write_slice_recon(minor_phase_grains_path, slice_arrays)

In [None]:
# Switch ds.parfile back from the minor-phase (Au) file to the major-phase file before saving,
# so the saved DataSet still points at the major-phase parameters
ds.parfile = major_phase_par_file

ds.save()

In [None]:
# Guard cell: while this is `if 1`, "Run all cells" stops here and the batch cell below is not run
if 1:
    raise ValueError("Change the 1 above to 0 to allow 'Run all cells' in the notebook")

In [None]:
# Now that we're happy with our reconstruction parameters, we can run the below cell to do this in bulk for many samples/datasets
# by default this will do all samples in sample_list, all datasets with a prefix of dset_prefix
# you can add samples and datasets to skip in skips_dict
#
# Every tunable value (tolerances, cutoffs, phase name, minor-phase par file,
# is_half_scan, halfmask_radius, nthreads, ...) comes from the interactive cells
# above, so run and tune those first.

# ImageD11.unitcell is used below but not imported by the imports cell
import ImageD11.unitcell

skips_dict = {
    "FeAu_0p5_tR_nscope": ["top_-50um", "top_-100um"]
}

dset_prefix = "top"

sample_list = ["FeAu_0p5_tR_nscope"]

samples_dict = utils.find_datasets_to_process(rawdata_path, skips_dict, dset_prefix, sample_list)

# manual override:
# samples_dict = {"FeAu_0p5_tR_nscope": ["top_100um", "top_200um"]}

# now we have our samples_dict, we can process our data:

for sample, datasets in samples_dict.items():
    for dataset in datasets:
        print(f"Processing dataset {dataset} in sample {sample}")
        dset_path = os.path.join(processed_data_root_dir, sample, f"{sample}_{dataset}", f"{sample}_{dataset}_dataset.h5")
        if not os.path.exists(dset_path):
            print(f"Missing DataSet file for {dataset} in sample {sample}, skipping")
            continue

        print("Importing DataSet object")

        ds = ImageD11.sinograms.dataset.load(dset_path)
        print(f"I have a DataSet {ds.dset} in sample {ds.sample}")

        minor_phase_grains_path = os.path.splitext(ds.grainsfile)[0] + f'_{phase_name}.h5'

        if not os.path.exists(minor_phase_grains_path):
            print(f"Missing grains file for {dataset} in sample {sample}, skipping")
            continue

        # check grains file for existence of slice_recon, skip if it's there
        with h5py.File(minor_phase_grains_path, "r") as hin:
            if "slice_recon" in hin.keys():
                print(f"Already reconstructed {dataset} in {sample}, skipping")
                continue

        if is_half_scan:
            ds.correct_bins_for_half_scan()

        print("Importing peaks")
        cf_4d = ds.get_cf_4d_from_disk()
        cf_4d.parameters.loadparameters(ds.parfile)
        cf_4d.updateGeometry()

        print("Filtering peaks")
        major_phase_peaks_mask = rings_mask(cf_4d, dstol=major_phase_cf_dstol, dsmax=cf_4d.ds.max())
        minor_phase_peaks = cf_4d.copy()
        minor_phase_peaks.filter(~major_phase_peaks_mask)
        # switch to the minor-phase parameters for this dataset; switched back just before ds.save()
        major_phase_par_file = ds.parfile
        ds.parfile = minor_phase_par_file
        minor_phase_peaks.parameters.loadparameters(ds.parfile)
        minor_phase_peaks.updateGeometry()
        cf_strong = select_ring_peaks_by_intensity(minor_phase_peaks, dstol=cf_strong_dstol, dsmax=minor_phase_peaks.ds.max(), frac=cf_strong_frac)

        print("Importing grains")
        grains = ImageD11.grain.read_grain_file_h5(minor_phase_grains_path)

        # reuse THIS dataset's major-phase reconstruction settings, including y0
        # (previously y0 was left over from the interactively processed dataset)
        major_phase_grainsinos = read_h5(ds.grainsfile, ds)
        whole_sample_mask = major_phase_grainsinos[0].recon_mask
        shift = major_phase_grainsinos[0].recon_shift
        pad = major_phase_grainsinos[0].recon_pad
        y0 = major_phase_grainsinos[0].recon_y0

        utils.assign_peaks_to_grains(grains, cf_strong, tol=peak_assign_tol)
        for grain_label, g in enumerate(grains):
            g.npks_4d = np.sum(cf_strong.grain_id == grain_label)

        grainsinos = [GrainSinogram(g, ds) for g in grains]

        for grain_label, gs in enumerate(grainsinos):
            gs.update_lab_position_from_peaks(cf_strong, grain_label)

        cf_pars = cf_strong.parameters.get_parameters()
        cf_pars["cell_lattice_[P,A,B,C,I,F,R]"] = spacegroup
        ref_ucell = ImageD11.unitcell.unitcell_from_parameters(cf_pars)
        for g in grains:
            g.ref_unitcell = ref_ucell

        utils.get_rgbs_for_grains(grains)

        print("Building sinograms")
        # same 4D -> 2D peak relation as the interactive cell; prepare_peaks_from_4d
        # takes gord and inds before the grain label
        gord, inds = get_2d_peaks_from_4d_peaks(ds.pk2d, cf_strong)
        for grain_label, gs in enumerate(tqdm(grainsinos)):
            gs.prepare_peaks_from_4d(cf_strong, gord, inds, grain_label, hkltol)
            gs.build_sinogram()

            if is_half_scan:
                gs.correct_halfmask()

        if correct_sinos_with_ring_current:
            ds.get_ring_current_per_scan()
            for gs in grainsinos:
                gs.correct_ring_current(is_half_scan=is_half_scan)

        for gs in grainsinos:
            gs.update_recon_parameters(pad=pad, shift=shift, mask=whole_sample_mask)

        with concurrent.futures.ThreadPoolExecutor(max_workers=max(1, nthreads - 1)) as pool:
            for _ in tqdm(pool.map(GrainSinogram.recon, grainsinos), total=len(grainsinos)):
                pass

        if is_half_scan:
            for gs in grainsinos:
                gs.mask_central_zingers("iradon", radius=halfmask_radius)

        # first pass uses the same fixed 0.9 cutoff as the interactive first pass;
        # its label map is only used to find grains that spill over too many pixels
        rgb_x_array, rgb_y_array, rgb_z_array, grain_labels_array, raw_intensity_array = build_slice_arrays(grainsinos, cutoff_level=0.9)
        labels, counts = np.unique(grain_labels_array, return_counts=True)
        bad_gids = [int(label) for (label, count) in zip(labels, counts) if count > grain_too_many_px and label > 0]

        # kept grains retain their original index as label
        grainsinos_clean = [gs for (inc, gs) in enumerate(grainsinos) if inc not in bad_gids]
        grain_labels_clean = [inc for (inc, gs) in enumerate(grainsinos) if inc not in bad_gids]

        slice_arrays = build_slice_arrays(grainsinos_clean, cutoff_level=cutoff_level, grain_labels=grain_labels_clean)

        for gs in tqdm(grainsinos):
            gs.update_lab_position_from_recon()

        for gs in grainsinos:
            gs.update_recon_parameters(y0=y0)

        write_h5(minor_phase_grains_path, grainsinos, write_grains_too=True)
        write_slice_recon(minor_phase_grains_path, slice_arrays)

        ds.parfile = major_phase_par_file
        ds.save()

print("Done!")