# Jupyter notebook based on ImageD11 to process scanning 3DXRD data
# Written by Haixing Fang, Jon Wright and James Ball
## Date: 12/10/2024

This notebook will try to reconstruct grain shapes and positions from the grain orientations you found in the first notebook.  
This notebook is optimised for weak and noisy minor phase data.  
This notebook (and the tomo route in general) works best for low levels of deformation.  
If it doesn't seem to work well, try the point-by-point route instead!

In [None]:
# Run the shared ID11 bootstrap script in this namespace; the call below expects it to define setup_ImageD11_from_git().
# NOTE(review): exec() runs the file's contents verbatim, so only point this at a trusted, facility-managed path.
# The context manager makes sure the script file handle is closed once it has been read.
with open('/data/id11/nanoscope/install_ImageD11_from_git.py') as install_script:
    exec(install_script.read())
PYTHONPATH = setup_ImageD11_from_git( ) # ( os.path.join( os.environ['HOME'],'Code'), 'ImageD11_git' )

In [None]:
# import functions we need

import concurrent.futures
import os

%matplotlib ipympl

import h5py
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
from skimage.filters import threshold_otsu
from skimage.morphology import convex_hull_image

import ImageD11.columnfile
from ImageD11.grain import grain
from ImageD11.peakselect import select_ring_peaks_by_intensity, rings_mask
from ImageD11.sinograms.sinogram import GrainSinogram, build_slice_arrays, write_slice_recon, write_h5, read_h5, get_2d_peaks_from_4d_peaks
from ImageD11.sinograms.roi_iradon import run_iradon
from ImageD11.sinograms.tensor_map import TensorMap
import ImageD11.sinograms.dataset
import ImageD11.nbGui.nb_utils as utils

import ipywidgets as widgets
from ipywidgets import interact

In [None]:
# USER: Pass path to dataset file

dset_file = 'si_cube_test/processed/Si_cube/Si_cube_S3DXRD_nt_moves_dty/Si_cube_S3DXRD_nt_moves_dty_dataset.h5'

# load the DataSet object from the file given above
ds = ImageD11.sinograms.dataset.load(dset_file)
   
# keep a few DataSet attributes as plain variables
# rawdata_path and processed_data_root_dir are used again by the bulk-processing cell at the end of the notebook
sample = ds.sample
dataset = ds.dsname
rawdata_path = ds.dataroot
processed_data_root_dir = ds.analysisroot

# print the dataset and its shape as a quick sanity check
print(ds)
print(ds.shape)

In [None]:
# load phases from parameter file

ds.phases = ds.get_phases_from_disk()
ds.phases.unitcells

In [None]:
# USER: pick the two phases (by their names in the parameters json) that this notebook works with
# the major phase supplies the reconstruction mask and geometry, the minor phase is the one we reconstruct
major_phase_str, minor_phase_str = 'Fe', 'Au'

unitcells = ds.phases.unitcells
major_phase_unitcell = unitcells[major_phase_str]
minor_phase_unitcell = unitcells[minor_phase_str]

# summarise the lattice parameters and spacegroup of each phase, major phase first
for phase_name, phase_unitcell in (
    (major_phase_str, major_phase_unitcell),
    (minor_phase_str, minor_phase_unitcell),
):
    print(phase_name, phase_unitcell.lattice_parameters, phase_unitcell.spacegroup)

In [None]:
# If the sinograms are only half-sinograms (we scanned dty across half the sample rather than the full sample), set the below to true:
is_half_scan = False

In [None]:
if is_half_scan:
    ds.correct_bins_for_half_scan()

In [None]:
# Import 4D peaks

cf_4d = ds.get_cf_4d_from_disk()
ds.update_colfile_pars(cf_4d, phase_name=major_phase_str)

print(f"Read {cf_4d.nrows} 4D peaks")

In [None]:
# here we are filtering our peaks (cf_4d) to select only the strongest ones for the major phase
# in the cells of this notebook, cf_major_phase is only used for the ring positions and the intensity plot further down

# USER: modify the "frac" parameter below and re-run the cell until the orange dot sits nicely on the "elbow" of the blue line
# this indicates the fractional intensity cutoff we will select
# if the blue line does not look elbow-shaped in the logscale plot, try changing the "doplot" parameter (the y scale of the logscale plot) until it does

major_phase_cf_frac = 0.994
major_phase_cf_dstol = 0.005

cf_major_phase = select_ring_peaks_by_intensity(cf_4d, frac=major_phase_cf_frac, dstol=major_phase_cf_dstol, doplot=0.95)
# number of peaks before and after the selection
print(cf_4d.nrows)
print(cf_major_phase.nrows)

In [None]:
# Update geometry for minor phase peaks

ds.update_colfile_pars(cf_4d, phase_name=minor_phase_str)

In [None]:
# select the strongest peaks again, this time with the cutoffs chosen for the minor phase
# USER: tune "frac" with the same elbow plot approach as for the major phase
# minor_phase_cf_frac and minor_phase_cf_dstol are reused by the bulk-processing cell at the end of the notebook
minor_phase_cf_frac = 0.9975
minor_phase_cf_dstol = 0.005

cf_minor_phase = select_ring_peaks_by_intensity(cf_4d, frac=minor_phase_cf_frac, dstol=minor_phase_cf_dstol, doplot=0.95)
# number of peaks before and after the selection
print(cf_4d.nrows)
print(cf_minor_phase.nrows)

In [None]:
major_phase_unitcell.makerings(cf_major_phase.ds.max())
minor_phase_unitcell.makerings(cf_minor_phase.ds.max())

In [None]:
# compare the intensities of all 4D peaks with the peaks that were kept for each phase

fig, ax = plt.subplots(figsize=(16, 9), constrained_layout=True)

# (columnfile, legend label, colour) for each set of peaks, drawn in this order
peak_sets = (
    (cf_4d, 'cf_4d', 'blue'),
    (cf_major_phase, 'major phase', 'orange'),
    (cf_minor_phase, 'minor phase', 'green'),
)
for peaks, peaks_label, peaks_colour in peak_sets:
    ax.plot(peaks.ds, peaks.sum_intensity, ',', label=peaks_label, c=peaks_colour)

# (unit cell, marker height, colour) for the ring positions of each phase
ring_sets = (
    (major_phase_unitcell, 5e4, "red"),
    (minor_phase_unitcell, 1e4, "brown"),
)
for phase_cell, marker_height, ring_colour in ring_sets:
    ax.plot(phase_cell.ringds, [marker_height] * len(phase_cell.ringds), '|', ms=90, c=ring_colour)

ax.semilogy()
ax.set(xlabel="Dstar", ylabel="Intensity")
ax.legend()

plt.show()

In [None]:
cf_strong = cf_minor_phase

In [None]:
# import the grains from disk

grains = ds.get_grains_from_disk(minor_phase_str)
print(f"{len(grains)} grains imported")

In [None]:
# read the major phase GrainSinograms back from the grains file
# the first one carries the mask, shift, pad and y0 that we reuse for the minor phase reconstructions

major_phase_grainsinos = read_h5(ds.grainsfile, ds, major_phase_str)
reference_gs = major_phase_grainsinos[0]

whole_sample_mask = reference_gs.recon_mask
shift = reference_gs.recon_shift
pad = reference_gs.recon_pad
y0 = reference_gs.recon_y0

print(shift, y0, pad)

In [None]:
# assign peaks to the grains

# tolerance handed to assign_peaks_to_grains; also reused by the bulk-processing cell at the end of the notebook
peak_assign_tol = 0.25
utils.assign_peaks_to_grains(grains, cf_strong, peak_assign_tol)

# for each grain, count the peaks in cf_strong whose grain_id equals that grain's position in the grains list
for grain_label, g in enumerate(grains):
    g.npks_4d = np.sum(cf_strong.grain_id == grain_label)

In [None]:
# let's make a GrainSinogram object for each grain

grainsinos = [GrainSinogram(g, ds) for g in grains]

In [None]:
# Now let's determine the positions of each grain from the 4D peaks

for grain_label, gs in enumerate(grainsinos):
    gs.update_lab_position_from_peaks(cf_strong, grain_label)

In [None]:
# We can also determine the RGB IPF colours of the grains which will be useful for plotting

utils.get_rgbs_for_grains(grains)

In [None]:
utils.plot_all_ipfs(grains)

In [None]:
# Scatter the grain positions four times: coloured by IPF Z, Y and X, and by the number of 4D peaks

# plt.style.use('dark_background')
fig, ax = plt.subplots(2, 2, figsize=(12, 12))
panels = ax.ravel()

# horizontal axis is translation[1] (lab y), vertical axis is translation[0] (lab x)
horizontal = [g.translation[1] for g in grains]
vertical = [g.translation[0] for g in grains]
# marker size scales with the number of 4D peaks assigned to each grain
marker_sizes = [g.npks_4d/10 for g in grains]

ipf_panels = (
    ('rgb_z', 'IPF color Z'),
    ('rgb_y', 'IPF color Y'),
    ('rgb_x', 'IPF color X'),
)
for panel, (rgb_attr, panel_title) in zip(panels, ipf_panels):
    panel.scatter(horizontal, vertical, c=[getattr(g, rgb_attr) for g in grains], s=marker_sizes)
    panel.set(title=panel_title, aspect='equal')

# last panel: colour (not size) encodes the scaled peak count
panels[3].scatter(horizontal, vertical, c=marker_sizes)
panels[3].set(title='Number of 4D peaks', aspect='equal')

fig.supxlabel("Lab y (transverse)")
fig.supylabel("Lab x (beam)")

for panel in panels:
    panel.invert_xaxis()

plt.show()

In [None]:
# let's build the sinograms for our grains
# before we do this, we need to determine our 2D peaks that will be used for the sinogram
# here we can get them from the 4D peaks:

# hkltol is passed to prepare_peaks_from_4d for every grain, and is reused by the bulk-processing cell at the end of the notebook
hkltol = 0.25

# computed once from ds.pk2d and the filtered 4D peaks, then shared by all grains
gord, inds = get_2d_peaks_from_4d_peaks(ds.pk2d, cf_strong)

# grain_label is the position of the GrainSinogram in the grainsinos list
for grain_label, gs in enumerate(tqdm(grainsinos)):
    gs.prepare_peaks_from_4d(cf_strong, gord, inds, grain_label, hkltol)

In [None]:
# now we can actually generate the sinograms

for gs in tqdm(grainsinos):
    gs.build_sinogram()

In [None]:
# optionally correct the halfmask:

if is_half_scan:
    for gs in grainsinos:
        gs.correct_halfmask()

In [None]:
# Show sinogram of single grain

gs = grainsinos[0]

fig, ax = plt.subplots()

ax.imshow(gs.ssino, aspect='auto')
ax.set_title("ssino")

plt.show()

In [None]:
# We can optionally correct each row of the sinogram by the ring current of that rotation
# This helps remove artifacts in the reconstruction

correct_sinos_with_ring_current = True
if correct_sinos_with_ring_current:
    
    ds.get_ring_current_per_scan()
    
    for gs in grainsinos:
        gs.correct_ring_current(is_half_scan=is_half_scan)

In [None]:
# Show sinogram of single grain

gs = grainsinos[0]

fig, ax = plt.subplots()

ax.imshow(gs.ssino, aspect='auto')
ax.set_title("ssino")

plt.show()

In [None]:
# let's try out an iradon reconstruction

# use the first grain as the test case
gs = grainsinos[0]

# update the parameters used for the iradon reconstruction

gs.update_recon_parameters(pad=pad, shift=shift, mask=whole_sample_mask)

# perform the reconstruction

gs.recon()

# halfmask_radius is only defined when is_half_scan is True
# later cells that use it are guarded by the same is_half_scan check
if is_half_scan:
    halfmask_radius = 25
    gs.mask_central_zingers("iradon", radius=halfmask_radius)

# view the result

# left: the sinogram that went in, right: the iradon reconstruction that came out
fig, axs = plt.subplots(1,2, figsize=(10,5))
axs[0].imshow(gs.ssino, aspect='auto')
axs[0].set_title("ssino")
axs[1].imshow(gs.recons["iradon"], vmin=0, origin="lower")
axs[1].set_title("ID11 iradon")

plt.show()

In [None]:
# once you're happy with the reconstruction parameters used, set them for all the grains

for gs in grainsinos:
    gs.update_recon_parameters(pad=pad, shift=shift, mask=whole_sample_mask)

In [None]:
# run the iradon reconstruction of every grain on a thread pool

# number of CPUs this process is allowed to run on (also reused by the bulk-processing cell)
nthreads = len(os.sched_getaffinity(os.getpid()))
# keep one CPU free, but never drop below a single worker
n_workers = max(1, nthreads - 1)

with concurrent.futures.ThreadPoolExecutor(max_workers=n_workers) as pool:
    recon_results = pool.map(GrainSinogram.recon, grainsinos)
    # iterating over the results is what drives the progress bar
    for _result in tqdm(recon_results, total=len(grainsinos)):
        pass

# half scans only: mask the centre of every iradon reconstruction with the radius chosen in the trial cell
if is_half_scan:
    for gs in grainsinos:
        gs.mask_central_zingers("iradon", radius=halfmask_radius)

In [None]:
# Interactive viewer: browse the iradon reconstruction and the sinogram of every grain with a slider

fig, a = plt.subplots(1,2,figsize=(10,5))
rec = a[0].imshow(grainsinos[0].recons["iradon"], vmin=0, origin="lower")
sin = a[1].imshow(grainsinos[0].ssino, aspect='auto')

# Function to update the displayed image based on the selected frame
def update_frame(i):
    """Show the reconstruction and sinogram of grainsinos[i] in the viewer figure.

    The axes and figure are reached through the image artist (rec.axes, rec.figure)
    rather than through the notebook globals `a` and `fig`: later plotting cells
    rebind those names, which would otherwise break this callback or redraw the
    wrong figure when the slider is moved afterwards.
    """
    rec.set_array(grainsinos[i].recons["iradon"])
    sin.set_array(grainsinos[i].ssino)
    rec.axes.set(title=str(i))
    rec.figure.canvas.draw()

# Create a slider widget to select the frame number
# the upper bound comes from grainsinos, the list that update_frame actually indexes
frame_slider = widgets.IntSlider(
    value=0,
    min=0,
    max=len(grainsinos) - 1,
    step=1,
    description='Grain:'
)

interact(update_frame, i=frame_slider)

plt.show()

In [None]:
# # you can pick a grain and investigate the effects of changing y0 that gets passed to iradon
# # it's best to pick the grain AFTER reconstructing all grains, so you can pick a grain of interest

# g = grains[5]
    
# vals = np.linspace(-8.5, -7.5, 9)

# grid_size = np.ceil(np.sqrt(len(vals))).astype(int)
# nrows = (len(vals)+grid_size-1)//grid_size

# fig, axs = plt.subplots(grid_size, nrows, sharex=True, sharey=True)

# for inc, val in enumerate(tqdm(vals)):
    
#     crop = utils.run_iradon_id11(g.ssino, g.sinoangles, pad, y0=val, workers=1, apply_halfmask=is_half_scan, mask_central_zingers=is_half_scan)

    
#     axs.ravel()[inc].imshow(crop, origin="lower", vmin=0)
#     axs.ravel()[inc].set_title(val)
    
# plt.show()

In [None]:
# let's assemble all the recons together into a TensorMap

tensor_map = TensorMap.from_grainsinos(grainsinos, cutoff_level=0.4)

In [None]:
# plot initial output

tensor_map.plot("ipf_z")
tensor_map.plot("labels")
tensor_map.plot("intensity")

In [None]:
# There will likely be many streaks, indicating a few grains have dodgy reconstructions and are probably not to be trusted
# You could optionally run ASTRA:


# choose the number of iterations
# experience shows 500 is good, and pretty quick on the GPU

# niter is also reused by the bulk-processing cell at the end of the notebook
niter = 500

# store the reconstruction parameters (now including niter and y0) on every GrainSinogram before saving
for gs in grainsinos:
    gs.update_recon_parameters(pad=pad, shift=shift, mask=whole_sample_mask, niter=niter, y0=y0)

# save the GrainSinogram objects to disk
# this overwrites the existing group for the minor phase in the grains file
write_h5(ds.grainsfile, grainsinos, overwrite_grains=True, group_name=minor_phase_str)

# prepare ASTRA bash scripts to run it on the cluster

bash_script_path = utils.prepare_astra_bash(ds, ds.grainsfile, PYTHONPATH, group_name=minor_phase_str)

# submit the job
# NOTE(review): the second argument (10) is presumably a polling interval in seconds - confirm against nb_utils

utils.slurm_submit_and_wait(bash_script_path, 10)

# re-import our reconstructed grains
# this replaces the in-memory grainsinos list with the objects read back from the grains file

grainsinos = read_h5(ds.grainsfile, ds, group_name=minor_phase_str)
# re-associate grainsino grain objects to existing grain objects
# pairing is by position, so this relies on the order read back matching the order of the grains list

for gs, g in zip(grainsinos, grains):
    gs.grain = g
    gs.ds = ds

In [None]:
# Let's assemble all the recons into one map

cutoff_level = 0.5

tensor_map_astra = TensorMap.from_grainsinos(grainsinos, cutoff_level=cutoff_level, method="astra")

In [None]:
tensor_map_astra.plot("ipf_z")
tensor_map_astra.plot("labels")
tensor_map_astra.plot("intensity")

In [None]:
# A dodgy reconstruction normally shows up as a grain that claims too many pixels in the label map
# To find such grains, we can count how many pixels in the grain labels array each grain has
# It can be helpful to run this filtration more than once

# labels holds each distinct value in the label map, counts holds the number of pixels with that value
labels, counts = np.unique(tensor_map_astra["labels"], return_counts=True)

# only labels greater than 0 are plotted
fig, ax = plt.subplots()
ax.plot(labels[labels > 0], counts[labels > 0])
plt.show()

In [None]:
# filter out grains with more than 10 pixels in the label map
# this normally indicates a dodgy reconstruction for this grain
# only really applies if the grains are very small!

# USER: pixel-count threshold; also reused by the bulk-processing cell at the end of the notebook
grain_too_many_px = 10

# labels above 0 whose pixel count exceeds the threshold
bad_gids = [int(label) for (label, count) in zip(labels, counts) if count > grain_too_many_px and label > 0]

In [None]:
# work out which grain labels survive the filter first
# the surviving labels are positions in the original grainsinos list, so the filtered
# grain map can keep labels that still agree with the original grain order

print(f"{len(grainsinos)} grains before filtration")
grain_labels_clean = [label for label in range(len(grainsinos)) if label not in bad_gids]
grainsinos_clean = [grainsinos[label] for label in grain_labels_clean]
print(f"{len(grainsinos_clean)} grains after filtration")

In [None]:
# set the gids for grainsinos_clean based on their original labels

for gs, label in zip(grainsinos_clean, grain_labels_clean):
    gs.grain.gid = label

In [None]:
cutoff_level = 0.5

tensor_map_astra = TensorMap.from_grainsinos(grainsinos_clean, cutoff_level=cutoff_level, method="astra")

In [None]:
tensor_map_astra.plot("ipf_z")
tensor_map_astra.plot("labels")
tensor_map_astra.plot("intensity")

In [None]:
# we can determine improved positions of our grains from the positions of their reconstructions
# note that this loops over all of grainsinos, including the grains that were filtered out of grainsinos_clean

for gs in tqdm(grainsinos):
    gs.update_lab_position_from_recon()
    
# change the y0 back to what we imported at the beginning:
# (y0 was read from the first major phase GrainSinogram earlier in the notebook)

for gs in grainsinos:
    gs.update_recon_parameters(y0=y0)

In [None]:
# Scatter the positions of the grains that survived filtration:
# coloured by IPF Z, Y and X, and by the number of 4D peaks

# plt.style.use('dark_background')
fig, ax = plt.subplots(2, 2, figsize=(12, 12))
panels = ax.ravel()

clean_grains = [gs.grain for gs in grainsinos_clean]

# horizontal axis is translation[1] (lab y), vertical axis is translation[0] (lab x)
horizontal = [g.translation[1] for g in clean_grains]
vertical = [g.translation[0] for g in clean_grains]
# marker size scales with the number of 4D peaks assigned to each grain
marker_sizes = [g.npks_4d/10 for g in clean_grains]

ipf_panels = (
    ('rgb_z', 'IPF color Z'),
    ('rgb_y', 'IPF color Y'),
    ('rgb_x', 'IPF color X'),
)
for panel, (rgb_attr, panel_title) in zip(panels, ipf_panels):
    panel.scatter(horizontal, vertical, c=[getattr(g, rgb_attr) for g in clean_grains], s=marker_sizes)
    panel.set(title=panel_title, aspect='equal')

# last panel: colour (not size) encodes the scaled peak count
panels[3].scatter(horizontal, vertical, c=marker_sizes)
panels[3].set(title='Number of 4D peaks', aspect='equal')

fig.supxlabel("Lab y (transverse)")
fig.supylabel("Lab x (beam)")

for panel in panels:
    panel.invert_xaxis()

plt.show()

In [None]:
# write our results to disk

write_h5(ds.grainsfile, grainsinos, overwrite_grains=True, group_name=minor_phase_str)

In [None]:
# write the TensorMap to disk too

tensor_map_astra.to_h5(ds.grainsfile, h5group='TensorMap_' + minor_phase_str)

In [None]:
# we can also write an XDMF file so you can visualise the TensorMap with ParaView

tensor_map_astra.to_paraview(ds.grainsfile, h5group='TensorMap_' + minor_phase_str)

In [None]:
ds.save()

In [None]:
if 1:
    raise ValueError("Change the 1 above to 0 to allow 'Run all cells' in the notebook")

In [None]:
# Now that we're happy with our reconstruction parameters, we can run the below cell to do this in bulk for many samples/datasets
# by default this will do all samples in sample_list, all datasets with a prefix of dset_prefix
# you can add samples and datasets to skip in skips_dict
#
# This cell reuses the settings chosen interactively earlier in the notebook:
# major_phase_str, minor_phase_str, is_half_scan, minor_phase_cf_frac, minor_phase_cf_dstol, peak_assign_tol, hkltol,
# correct_sinos_with_ring_current, nthreads, halfmask_radius (half scans only), niter, cutoff_level and grain_too_many_px

skips_dict = {
    "FeAu_0p5_tR_nscope": ["top_-50um", "top_-100um"]
}

dset_prefix = "top"

sample_list = ["FeAu_0p5_tR_nscope"]

samples_dict = utils.find_datasets_to_process(rawdata_path, skips_dict, dset_prefix, sample_list)

# manual override:
# samples_dict = {"FeAu_0p5_tR_nscope": ["top_100um", "top_150um"]}

# now we have our samples_dict, we can process our data:


for sample, datasets in samples_dict.items():
    for dataset in datasets:
        print(f"Processing dataset {dataset} in sample {sample}")
        dset_path = os.path.join(processed_data_root_dir, sample, f"{sample}_{dataset}", f"{sample}_{dataset}_dataset.h5")
        if not os.path.exists(dset_path):
            print(f"Missing DataSet file for {dataset} in sample {sample}, skipping")
            continue

        print("Importing DataSet object")

        ds = ImageD11.sinograms.dataset.load(dset_path)

        if not os.path.exists(ds.grainsfile):
            print(f"Missing grains file for {dataset} in sample {sample}, skipping")
            continue

        # check grains file for existence of output, skip if it's there
        with h5py.File(ds.grainsfile, "r") as hin:
            if 'TensorMap_' + minor_phase_str in hin.keys():
                print(f"Already reconstructed {dataset} in {sample}, skipping")
                continue

        ds.phases = ds.get_phases_from_disk()

        if is_half_scan:
            ds.correct_bins_for_half_scan()

        print("Importing peaks")
        cf_4d = ds.get_cf_4d_from_disk()
        ds.update_colfile_pars(cf_4d, phase_name=minor_phase_str)

        print("Filtering peaks")
        cf_strong = select_ring_peaks_by_intensity(cf_4d, frac=minor_phase_cf_frac, dstol=minor_phase_cf_dstol)

        print("Importing grains")
        grains = ds.get_grains_from_disk(minor_phase_str)

        # take the reconstruction geometry from this dataset's own major phase reconstruction
        # y0 is re-read here together with mask, shift and pad, so that each dataset uses
        # its own value instead of the one left over from the interactive cells
        major_phase_grainsinos = read_h5(ds.grainsfile, ds, major_phase_str)
        whole_sample_mask = major_phase_grainsinos[0].recon_mask
        shift = major_phase_grainsinos[0].recon_shift
        pad = major_phase_grainsinos[0].recon_pad
        y0 = major_phase_grainsinos[0].recon_y0

        utils.assign_peaks_to_grains(grains, cf_strong, tol=peak_assign_tol)
        for grain_label, g in enumerate(grains):
            g.npks_4d = np.sum(cf_strong.grain_id == grain_label)

        grainsinos = [GrainSinogram(g, ds) for g in grains]

        for grain_label, gs in enumerate(grainsinos):
            gs.update_lab_position_from_peaks(cf_strong, grain_label)

        utils.get_rgbs_for_grains(grains)

        print("Building sinograms")

        gord, inds = get_2d_peaks_from_4d_peaks(ds.pk2d, cf_strong)

        for grain_label, gs in enumerate(tqdm(grainsinos)):
            gs.prepare_peaks_from_4d(cf_strong, gord, inds, grain_label, hkltol)
            gs.build_sinogram()

            if is_half_scan:
                gs.correct_halfmask()

        if correct_sinos_with_ring_current:
            ds.get_ring_current_per_scan()
            for gs in grainsinos:
                gs.correct_ring_current(is_half_scan=is_half_scan)

        for gs in grainsinos:
            gs.update_recon_parameters(pad=pad, shift=shift, mask=whole_sample_mask)

        print("Running iradon reconstructions")
        with concurrent.futures.ThreadPoolExecutor(max_workers= max(1,nthreads-1)) as pool:
            # iterating over the results is what drives the progress bar
            for i in tqdm(pool.map(GrainSinogram.recon, grainsinos), total=len(grainsinos)):
                pass

        if is_half_scan:
            for gs in grainsinos:
                gs.mask_central_zingers("iradon", radius=halfmask_radius)

        # store the ASTRA settings (niter, y0) on every GrainSinogram before they are written to disk
        for gs in grainsinos:
            gs.update_recon_parameters(pad=pad, shift=shift, mask=whole_sample_mask, niter=niter, y0=y0)

        print("Submitting ASTRA recon job")
        write_h5(ds.grainsfile, grainsinos, overwrite_grains=True, group_name=minor_phase_str)
        bash_script_path = utils.prepare_astra_bash(ds, ds.grainsfile, PYTHONPATH, group_name=minor_phase_str)
        utils.slurm_submit_and_wait(bash_script_path, 10)

        print("Importing reconstructed grains")
        grainsinos = read_h5(ds.grainsfile, ds, group_name=minor_phase_str)
        # re-associate the GrainSinograms read from disk with the in-memory grains (paired by position)
        for gs, g in zip(grainsinos, grains):
            gs.grain = g
            gs.ds = ds

        print("Filtering noisy recons")
        tensor_map_astra = TensorMap.from_grainsinos(grainsinos, cutoff_level=cutoff_level, method="astra")
        labels, counts = np.unique(tensor_map_astra["labels"], return_counts=True)
        # labels above 0 that claim more than grain_too_many_px pixels are treated as bad reconstructions
        bad_gids = [int(label) for (label, count) in zip(labels, counts) if count > grain_too_many_px and label > 0]

        grainsinos_clean = [gs for (inc, gs) in enumerate(grainsinos) if inc not in bad_gids]
        grain_labels_clean = [inc for (inc, gs) in enumerate(grainsinos) if inc not in bad_gids]

        # keep the original labels on the surviving grains so the filtered map agrees with the grain order
        for gs, label in zip(grainsinos_clean, grain_labels_clean):
            gs.grain.gid = label

        tensor_map_astra = TensorMap.from_grainsinos(grainsinos_clean, cutoff_level=cutoff_level, method="astra")

        for gs in tqdm(grainsinos):
            gs.update_lab_position_from_recon()

        # change the y0 back to the value read from this dataset's major phase reconstruction
        for gs in grainsinos:
            gs.update_recon_parameters(y0=y0)

        print("Exporting")
        write_h5(ds.grainsfile, grainsinos, overwrite_grains=True, group_name=minor_phase_str)

        tensor_map_astra.to_h5(ds.grainsfile, h5group='TensorMap_' + minor_phase_str)
        tensor_map_astra.to_paraview(ds.grainsfile, h5group='TensorMap_' + minor_phase_str)

        ds.save()

print("Done!")