# Jupyter notebook based on ImageD11 to process scanning 3DXRD data
# Written by Haixing Fang, Jon Wright and James Ball
## Date: 26/02/2024

In [None]:
# USER: Change the path below to point to your local copy of ImageD11:

import os
import sys

# Resolve the home directory with the standard library rather than shelling
# out through the IPython magic `!echo $HOME`, so this cell is plain Python
# and also runs outside IPython.
home_dir = os.path.expanduser("~")

# USER: You can change this location if you want

id11_code_path = os.path.join(home_dir, "Code/ImageD11")

# Insert at position 0 so this checkout is searched before any other
# ImageD11 that may be found further down sys.path.
sys.path.insert(0, id11_code_path)

In [None]:
# import functions we need

import os
import concurrent.futures
import timeit

import matplotlib
%matplotlib widget

from skimage.feature import blob_log

import h5py
from tqdm.notebook import tqdm
import numba
import pprint
import numpy as np
import skimage.transform
import ipywidgets as ipyw
import matplotlib.pyplot as plt

from functools import partial

import ImageD11.nbGui.nb_utils as utils

import ImageD11.refinegrains
import ImageD11.columnfile
import ImageD11.sinograms.properties
import ImageD11.sinograms.roi_iradon
from ImageD11.blobcorrector import eiger_spatial
from ImageD11.grain import grain

In [None]:
# NOTE: For old datasets before the new directory layout structure, we don't distinguish between RAW_DATA and PROCESSED_DATA

### USER: specify your experimental directory

rawdata_path = "/data/visitor/ihma439/id11/20231211/RAW_DATA"

# List the raw data directory (long format, oldest entries first) as a quick check that the path is right
!ls -lrt {rawdata_path}

### USER: specify where you want your processed data to go

# Root under which the per-sample / per-dataset output folders are looked up further down
processed_data_root_dir = "/data/visitor/ihma439/id11/20231211/PROCESSED_DATA/James/nb_testing"

In [None]:
# USER: pick a sample and a dataset you want to segment

sample = "FeAu_0p5_tR_nscope"
dataset = "top_250um"

In [None]:
# destination of H5 files
# layout: <processed_data_root_dir>/<sample>/<sample>_<dataset>/<sample>_<dataset>_dataset.h5

dset_path = os.path.join(processed_data_root_dir, sample, f"{sample}_{dataset}", f"{sample}_{dataset}_dataset.h5")

In [None]:
# Load the dataset (for motor positions, not sure why these are not in peaks)
# NOTE(review): ImageD11.sinograms.dataset is not imported explicitly in the import cell;
# this line relies on another ImageD11 import having loaded that submodule - confirm.
ds = ImageD11.sinograms.dataset.load(dset_path)

In [None]:
# determine ring currents for sinogram row-by-row intensity correction
# (the correction itself is applied to the grain sinograms in a later cell)

utils.get_ring_current_per_scan(ds)

In [None]:
# Import 4D peaks

# Remember the parameter file currently set on the dataset: ds.parfile is
# pointed at the minor phase further down and restored from this name
# before the dataset is saved.
major_phase_par_file = ds.parfile

# Read the 4D peaks columnfile, attach the parameters and compute geometry.
cf_4d = ImageD11.columnfile.columnfile(ds.col4dfile)
cf_4d.parameters.loadparameters(ds.parfile)
cf_4d.updateGeometry()

n_peaks_4d = cf_4d.nrows
print(f"Read {n_peaks_4d} 4D peaks")

In [None]:
# Name of the minor phase handled by this notebook; it is handed to the
# grain reading and saving helpers.
phase_name = "Au"

grains = utils.read_s3dxrd_grains_minor_phase_for_recon(ds, phase_name=phase_name)

# Give every grain one representative length `a`: the cube root of the
# determinant of its ubi matrix.
for grain in grains:
    ubi_determinant = np.linalg.det(grain.ubi)
    grain.a = np.cbrt(ubi_determinant)

n_imported = len(grains)
print(f"{n_imported} grains imported")

In [None]:
# isolate main phase peaks, and remove them from the dataset

major_phase_cf_dstol = 0.0075
major_phase_peaks_mask = utils.unitcell_peaks_mask(
    cf_4d,
    dstol=major_phase_cf_dstol,
    dsmax=cf_4d.ds.max(),
)

# keep only the peaks that fall outside the major phase mask
minor_phase_peaks = cf_4d.copy()
minor_phase_peaks.filter(~major_phase_peaks_mask)

# Update geometry for minor phase peaks:
# from here on the dataset points at the minor phase parameter file
minor_phase_par_file = os.path.join(processed_data_root_dir, '../../../SCRIPTS/James/S3DXRD/Au.par')
ds.parfile = minor_phase_par_file

minor_phase_peaks.parameters.loadparameters(ds.parfile)
minor_phase_peaks.updateGeometry()

# settings for picking the strong minor phase peaks
cf_strong_frac = 0.95
cf_strong_dstol = 0.005

cf_strong = utils.selectpeaks(
    minor_phase_peaks,
    dstol=cf_strong_dstol,
    dsmax=minor_phase_peaks.ds.max(),
    frac=cf_strong_frac,
    doplot=0.01,
)
print(cf_strong.nrows)

In [None]:
# If the sinograms are only half-sinograms (we scanned dty across half the sample rather than the full sample), set the below to true:
# (the same flag is later passed to the iradon calls as apply_halfmask and mask_central_zingers)
is_half_scan = False

In [None]:
# Only runs when is_half_scan was set to True in the cell above
if is_half_scan:
    utils.correct_half_scan(ds)

In [None]:
# load major phase grain reconstruction
# (only its first grain is used: it supplies the sample mask, y0 and pad)

major_phase_grains, _, _, _, _, _ = utils.read_s3dxrd_grains_after_recon(ds)
reference_grain = major_phase_grains[0]

whole_sample_mask = reference_grain.sample_mask
y0 = reference_grain.y0

# pad = how many more rows the reconstruction has than the sinogram
recon_rows = reference_grain.recon.shape[0]
sino_rows = reference_grain.ssino.shape[0]
pad = recon_rows - sino_rows
pad

In [None]:
# tolerance handed to the peak-to-grain assignment
peak_assign_tol = 0.25

utils.assign_peaks_to_grains(grains, cf_strong, tol=peak_assign_tol)

print("Storing peak data in grains")
# For every grain, remember which strong 4D peaks carry its grain id, both as
# a boolean mask and as peak indices, so we know which peaks fed its sinogram.
for g in tqdm(grains):
    belongs_to_grain = cf_strong.grain_id == g.gid
    g.mask_4d = belongs_to_grain
    g.peaks_4d = cf_strong.index[belongs_to_grain]

In [None]:
# Scatter of omega against dty for the strong peaks whose grain_id is not
# negative, coloured by grain id.
m = cf_strong.grain_id >= 0
omega_assigned = cf_strong.omega[m]
dty_assigned = cf_strong.dty[m]
ids_assigned = cf_strong.grain_id[m]

fig, ax = plt.subplots()
ax.scatter(omega_assigned, dty_assigned, c=ids_assigned)
plt.show()

In [None]:
# Collect the length `a` stored on every grain, plot it in list order and
# print the median.
mean_unit_cell_lengths = [each_grain.a for each_grain in grains]

fig, ax = plt.subplots()
ax.plot(mean_unit_cell_lengths)
ax.set(xlabel="Grain ID", ylabel="Unit cell length")
plt.show()

a0 = np.median(mean_unit_cell_lengths)

print(a0)

In [None]:
# Visual check of the grain sinograms built from the strong peaks.
# NOTE(review): the meaning of the third argument (25) is defined by utils.plot_grain_sinograms - presumably the number of grains to show; confirm.
utils.plot_grain_sinograms(grains, cf_strong, 25)

In [None]:
# Get grain translations from sinograms:
for grain in tqdm(grains):
    utils.fit_grain_position_from_sino(grain, cf_strong)
    # dx and dy are read back from the grain after the fit; the third
    # component of the translation is set to 0
    in_plane_position = [grain.dx, grain.dy]
    grain.translation = np.array(in_plane_position + [0])

# Get grain IPF colours:
utils.get_rgbs_for_grains(grains)

In [None]:
# plt.style.use('dark_background')
fig, ax = plt.subplots(2, 2, figsize=(12, 12))
a = ax.ravel()

# marker positions (dy horizontally, dx vertically) and marker sizes
# (a tenth of the number of 4D peaks of each grain)
x = [g.dy for g in grains]
y = [g.dx for g in grains]
s = [g.mask_4d.sum()/10 for g in grains]

# first three panels: same positions, coloured by rgb_z, rgb_y and rgb_x
ipf_panels = [
    ('IPF color Z', 'rgb_z'),
    ('IPF color Y', 'rgb_y'),
    ('IPF color X', 'rgb_x'),
]
for panel, (title, colour_attr) in zip(a, ipf_panels):
    panel.scatter(x, y, c=[getattr(g, colour_attr) for g in grains], s=s)
    panel.set(title=title, aspect='equal')

# last panel: colour carries the scaled peak count, marker size is left at its default
a[3].scatter(x, y, c=s)
a[3].set(title='Number of 4D peaks', aspect='equal')

fig.supxlabel("Lab y (transverse)")
fig.supylabel("Lab x (beam)")

# flip the horizontal axis of every panel
for a in ax.ravel():
    a.invert_xaxis()


plt.show()

In [None]:
# get corresponding 2D peaks from 4D peaks so we can build the sinograms with them
gord, inds, p2d = utils.get_2d_peaks_from_4d_peaks(ds, cf_strong)

# now our 2D peak assignments are known, let's populate our grain objects with our 2D peaks:
# the slice of gord for grain id `i` runs from inds[i + 1] up to inds[i + 2]
for grain in tqdm(grains):
    i = grain.gid
    first, last = inds[i + 1], inds[i + 2]
    grain.peaks_2d = gord[first:last]
    # grain.mask_2d = np.isin(cf_2d.index, grain.peaks_2d)

In [None]:
# Determine sinograms of all grains

# number of CPUs this process is allowed to run on
nthreads = len(os.sched_getaffinity(os.getpid()))
# one worker fewer than that, but never fewer than one
n_workers = max(1, nthreads - 1)

print("Making sinograms")
do_sinos_partial = partial(utils.do_sinos, p2d=p2d, ds=ds)

with concurrent.futures.ThreadPoolExecutor(max_workers=n_workers) as pool:
    sino_jobs = pool.map(do_sinos_partial, grains)
    # iterating drives the progress bar; the yielded values are not used
    for i in tqdm(sino_jobs, total=len(grains)):
        pass

In [None]:
# we can optionally correct the grain sinograms by scaling each row by the ring current:
# (set the flag to False to skip this step; the bulk-processing cell at the end of the notebook reads the same flag)

correct_sinos_with_ring_current = True

if correct_sinos_with_ring_current:
    for grain in tqdm(grains):
        utils.correct_sinogram_rows_with_ring_current(grain, ds)

In [None]:
# Show sinogram of single grain
g = grains[0]

# divide the sinogram by its mean along axis 0 and show it on a log colour scale
normalised_sino = g.ssino / g.ssino.mean(axis=0)
log_norm = matplotlib.colors.LogNorm()

fig, ax = plt.subplots()
ax.imshow(normalised_sino, norm=log_norm, interpolation='nearest', origin="lower")
plt.show()

In [None]:
# you can pick a grain and investigate the effects of changing y0 that gets passed to iradon
# it's best to pick the grain AFTER reconstructing all grains, so you can pick a grain of interest

g = grains[5]

# candidate y0 values, one reconstruction per value
vals = np.linspace(-8.5, -7.5, 9)

# Near-square grid with room for every value: n_rows is ceil(sqrt(N)) and
# n_cols is ceil(N / n_rows). These are passed to plt.subplots in its
# (nrows, ncols) order under names that match their role.
n_rows = np.ceil(np.sqrt(len(vals))).astype(int)
n_cols = (len(vals) + n_rows - 1) // n_rows

# squeeze=False makes plt.subplots always return a 2D array of axes, so the
# flat indexing below also works for a 1x1 grid (a single value in vals).
fig, axs = plt.subplots(n_rows, n_cols, sharex=True, sharey=True, squeeze=False)
flat_axs = axs.ravel()

for inc, val in enumerate(tqdm(vals)):
    # reconstruct the chosen grain with this trial y0
    crop = utils.run_iradon_id11(g.ssino, g.sinoangles, pad, y0=val, workers=1, apply_halfmask=is_half_scan, mask_central_zingers=is_half_scan)

    flat_axs[inc].imshow(crop, origin="lower", vmin=0)
    flat_axs[inc].set_title(val)

plt.show()

In [None]:
# Now compute reconstructions for all grains

# one worker fewer than the CPUs available to this process, at least one
nthreads = len(os.sched_getaffinity(os.getpid()))
n_workers = max(1, nthreads - 1)

# every grain is reconstructed with the same keyword settings
iradon_settings = dict(
    pad=pad,
    y0=y0,
    sample_mask=whole_sample_mask,
    workers=1,
    apply_halfmask=is_half_scan,
    mask_central_zingers=is_half_scan,
)
run_this_iradon = partial(utils.iradon_grain, **iradon_settings)

with concurrent.futures.ThreadPoolExecutor(max_workers=n_workers) as pool:
    recon_jobs = pool.map(run_this_iradon, grains)
    # iterating drives the progress bar; the yielded values are not used
    for i in tqdm(recon_jobs, total=len(grains)):
        pass

In [None]:
import ipywidgets as widgets
from ipywidgets import interact
%matplotlib ipympl

fig, a = plt.subplots(1,2,figsize=(10,5))
rec = a[0].imshow(grains[8].recon, vmin=0, origin="lower")
sin = a[1].imshow(grains[8].ssino, aspect='auto')

# Function to update the displayed image based on the selected frame
def update_frame(i):
    rec.set_array(grains[i].recon)
    sin.set_array(grains[i].ssino)
    a[0].set(title=str(grains[i].gid))
    fig.canvas.draw()

# Create a slider widget to select the frame number
frame_slider = widgets.IntSlider(
    value=0,
    min=0,
    max=len(grains) - 1,
    step=1,
    description='Grain:'
)

interact(update_frame, i=frame_slider)

plt.show()

In [None]:
# Fit grain positions from reconstructions

fit_pos_partial = partial(utils.fit_grain_position_from_recon, ds=ds, y0=y0)

# same pool size rule as for the reconstructions: nthreads - 1, at least one
n_workers = max(1, nthreads - 1)

with concurrent.futures.ThreadPoolExecutor(max_workers=n_workers) as pool:
    fit_jobs = pool.map(fit_pos_partial, grains)
    # iterating drives the progress bar; the yielded values are not used
    for i in tqdm(fit_jobs, total=len(grains)):
        pass

In [None]:
# remove bad recon grains from future analysis
print(f"{len(grains)} grains before filtration")
grains = [grain for grain in grains if not grain.bad_recon]
print(f"{len(grains)} grains after filtration")

In [None]:
for g in grains:
    g.translation = np.array([g.x_blob, g.y_blob, 0])

In [None]:
# plt.style.use('dark_background')
fig, ax = plt.subplots(2, 2, figsize=(12, 12))
a = ax.ravel()

# marker positions (translation[1] horizontally, translation[0] vertically)
# and marker sizes (a tenth of the number of 4D peaks of each grain)
x = [g.translation[1] for g in grains]
y = [g.translation[0] for g in grains]
s = [g.mask_4d.sum()/10 for g in grains]

# first three panels: same positions, coloured by rgb_z, rgb_y and rgb_x
ipf_panels = [
    ('IPF color Z', 'rgb_z'),
    ('IPF color Y', 'rgb_y'),
    ('IPF color X', 'rgb_x'),
]
for panel, (title, colour_attr) in zip(a, ipf_panels):
    panel.scatter(x, y, c=[getattr(g, colour_attr) for g in grains], s=s)
    panel.set(title=title, aspect='equal')

# last panel: colour carries the scaled peak count, marker size is left at its default
a[3].scatter(x, y, c=s)
a[3].set(title='Number of 4D peaks', aspect='equal')

fig.supxlabel("Lab y (transverse)")
fig.supylabel("Lab x (beam)")

# flip the horizontal axis of every panel
for a in ax.ravel():
    a.invert_xaxis()


plt.show()

In [None]:
utils.plot_ipfs(grains)

In [None]:
cutoff_level = 0.9

rgb_x_array, rgb_y_array, rgb_z_array, grain_labels_array, raw_intensity_array = utils.build_slice_arrays(grains, cutoff_level=cutoff_level)

In [None]:
# plot initial output
# (three figures follow: the rgb_z colour array, the grain label map and the raw intensity map)
# NOTE(review): the trailing "originally 1,2,0" remarks look like leftovers from an earlier version - confirm whether they can go.

fig, ax = plt.subplots(constrained_layout=True)
ax.imshow(rgb_z_array, origin="lower")  # originally 1,2,0
plt.show()

In [None]:
# grain label map
fig, ax = plt.subplots(constrained_layout=True)
ax.imshow(grain_labels_array, origin="lower")  # originally 1,2,0
ax.set_title("Grain label map")
plt.show()

In [None]:
# raw intensity map
fig, ax = plt.subplots(constrained_layout=True)
ax.imshow(raw_intensity_array, origin="lower")  # originally 1,2,0
ax.set_title("Raw intensity array")
plt.show()

In [None]:
# There will likely be many streaks, indicating a few grains have dodgy reconstructions and are probably not to be trusted
# To fix this, we can count how many pixels in the grain labels array each grain has
# It can be helpful to run this filtration more than once

labels, counts = np.unique(grain_labels_array, return_counts=True)

# only labels greater than zero are plotted
positive_labels = labels > 0

fig, ax = plt.subplots()
ax.plot(labels[positive_labels], counts[positive_labels])
plt.show()

In [None]:
# filter out grains with more than grain_too_many_px pixels in the label map
# this normally indicates a dodgy reconstruction for this grain
# only really applies if the grains are very small!

grain_too_many_px = 20

# a label is flagged when its pixel count exceeds the limit and the label
# itself is greater than zero
oversized = (counts > grain_too_many_px) & (labels > 0)
bad_gids = [int(label) for label in labels[oversized]]

In [None]:
print(f"{len(grains)} grains before filtration")
grains = [grain for grain in grains if grain.gid not in bad_gids]
print(f"{len(grains)} grains after filtration")

In [None]:
cutoff_level = 0.7

rgb_x_array, rgb_y_array, rgb_z_array, grain_labels_array, raw_intensity_array = utils.build_slice_arrays(grains, cutoff_level=cutoff_level)

In [None]:
# plot output after filtration
# (the same three maps as before, rebuilt from the filtered grains: rgb_z colour array, grain label map, raw intensity map)

fig, ax = plt.subplots(constrained_layout=True)
ax.imshow(rgb_z_array, origin="lower")  # originally 1,2,0
plt.show()

In [None]:
# grain label map
fig, ax = plt.subplots(constrained_layout=True)
ax.imshow(grain_labels_array, origin="lower")  # originally 1,2,0
ax.set_title("Grain label map")
plt.show()

In [None]:
# raw intensity map
fig, ax = plt.subplots(constrained_layout=True)
ax.imshow(raw_intensity_array, origin="lower")  # originally 1,2,0
ax.set_title("Raw intensity array")
plt.show()

In [None]:
# write grains to disk

In [None]:
# save the filtered grains together with the five slice arrays, under the minor phase name
utils.save_s3dxrd_grains_minor_phase_after_recon(grains, ds, raw_intensity_array, grain_labels_array, rgb_x_array, rgb_y_array, rgb_z_array, phase_name=phase_name)

In [None]:
# ds.parfile was pointed at the minor phase par file earlier on;
# put the major phase par file back before the dataset is saved
ds.parfile = major_phase_par_file
ds.save()

In [None]:
# Deliberate stop: keeps "Run all cells" from running on into the bulk-processing cell below
if 1:
    raise ValueError("Change the 1 above to 0 to allow 'Run all cells' in the notebook")

In [None]:
# Now that we're happy with our indexing parameters, we can run the below cell to do this in bulk for many samples/datasets
# by default this will do all samples in sample_list, all datasets with a prefix of dset_prefix
# you can add samples and datasets to skip in skips_dict
#
# NOTE: this cell reuses settings chosen in the cells above (phase_name, minor_phase_par_file,
# major_phase_cf_dstol, cf_strong_frac, cf_strong_dstol, is_half_scan, peak_assign_tol,
# correct_sinos_with_ring_current, cutoff_level, grain_too_many_px), so those cells must have
# been run first. cutoff_level is used with whatever value it was last given.

skips_dict = {
    "FeAu_0p5_tR_nscope": ["top_-50um", "top_-100um"]
}

dset_prefix = "top"

sample_list = ["FeAu_0p5_tR_nscope"]

samples_dict = utils.find_datasets_to_process(rawdata_path, skips_dict, dset_prefix, sample_list)

# manual override:
# samples_dict = {"FeAu_0p5_tR_nscope": ["top_100um", "top_200um"]}

# size of the thread pools used below: one worker fewer than the CPUs
# available to this process, but at least one
nthreads = len(os.sched_getaffinity(os.getpid()))
n_workers = max(1, nthreads - 1)

# now we have our samples_dict, we can process our data:


for sample, datasets in samples_dict.items():
    for dataset in datasets:
        print(f"Processing dataset {dataset} in sample {sample}")
        dset_path = os.path.join(processed_data_root_dir, sample, f"{sample}_{dataset}", f"{sample}_{dataset}_dataset.h5")
        if not os.path.exists(dset_path):
            print(f"Missing DataSet file for {dataset} in sample {sample}, skipping")
            continue

        print("Importing DataSet object")

        ds = ImageD11.sinograms.dataset.load(dset_path)
        print(f"I have a DataSet {ds.dset} in sample {ds.sample}")

        ds.grainsfile_minor_phase = os.path.join(ds.analysispath, ds.dsname + f'_grains_{phase_name}.h5')

        if not os.path.exists(ds.grainsfile_minor_phase):
            print(f"Missing grains file for {dataset} in sample {sample}, skipping")
            continue

        # check grains file for existence of slice_recon, skip if it's there
        with h5py.File(ds.grainsfile_minor_phase, "r") as hin:
            if "slice_recon" in hin.keys():
                print(f"Already reconstructed {dataset} in {sample}, skipping")
                continue

        # determine ring currents for sinogram row-by-row intensity correction
        utils.get_ring_current_per_scan(ds)

        cf_4d = ImageD11.columnfile.columnfile(ds.col4dfile)
        cf_4d.parameters.loadparameters(ds.parfile)
        cf_4d.updateGeometry()

        # remembered so that it can be put back on the dataset before saving
        major_phase_par_file = ds.parfile

        # pass phase_name explicitly, as the interactive cell does, so the
        # grains are read for the same phase whose grains file was checked above
        grains = utils.read_s3dxrd_grains_minor_phase_for_recon(ds, phase_name=phase_name)

        major_phase_peaks_mask = utils.unitcell_peaks_mask(cf_4d, dstol=major_phase_cf_dstol, dsmax=cf_4d.ds.max())

        # keep only the peaks that fall outside the major phase mask
        minor_phase_peaks = cf_4d.copy()
        minor_phase_peaks.filter(~major_phase_peaks_mask)

        # Update geometry for minor phase peaks

        ds.parfile = minor_phase_par_file

        minor_phase_peaks.parameters.loadparameters(ds.parfile)
        minor_phase_peaks.updateGeometry()

        # dsmax is taken from the minor phase peaks after their geometry
        # update, matching the interactive cell
        cf_strong = utils.selectpeaks(minor_phase_peaks, frac=cf_strong_frac, dsmax=minor_phase_peaks.ds.max(), dstol=cf_strong_dstol)

        if is_half_scan:
            utils.correct_half_scan(ds)

        # the first major phase grain supplies the sample mask, y0 and pad
        major_phase_grains, _, _, _, _, _ = utils.read_s3dxrd_grains_after_recon(ds)
        whole_sample_mask = major_phase_grains[0].sample_mask
        y0 = major_phase_grains[0].y0
        pad = major_phase_grains[0].recon.shape[0] - major_phase_grains[0].ssino.shape[0]

        utils.assign_peaks_to_grains(grains, cf_strong, tol=peak_assign_tol)

        for grain in tqdm(grains):
            # the mask is computed once and reused for the peak indices
            grain.mask_4d = cf_strong.grain_id == grain.gid
            grain.peaks_4d = cf_strong.index[grain.mask_4d]
            utils.fit_grain_position_from_sino(grain, cf_strong)
            grain.translation = np.array([grain.dx, grain.dy, 0])

        utils.get_rgbs_for_grains(grains)

        print("Peak 2D organise")
        gord, inds, p2d = utils.get_2d_peaks_from_4d_peaks(ds, cf_strong)

        # the slice of gord for grain id `i` runs from inds[i+1] up to inds[i+2]
        for grain in tqdm(grains):
            i = grain.gid
            grain.peaks_2d = gord[inds[i+1] : inds[i+2]]

        print("Making sinograms")
        do_sinos_partial = partial(utils.do_sinos, p2d=p2d, ds=ds)

        with concurrent.futures.ThreadPoolExecutor(max_workers=n_workers) as pool:
            for i in tqdm(pool.map(do_sinos_partial, grains), total=len(grains)):
                pass

        if correct_sinos_with_ring_current:
            for grain in tqdm(grains):
                utils.correct_sinogram_rows_with_ring_current(grain, ds)

        print("Running iradon")

        run_this_iradon = partial(utils.iradon_grain, pad=pad, y0=y0, sample_mask=whole_sample_mask, workers=1, apply_halfmask=is_half_scan, mask_central_zingers=is_half_scan)

        with concurrent.futures.ThreadPoolExecutor(max_workers=n_workers) as pool:
            for i in tqdm(pool.map(run_this_iradon, grains), total=len(grains)):
                pass

        fit_pos_partial = partial(utils.fit_grain_position_from_recon, ds=ds, y0=y0)

        with concurrent.futures.ThreadPoolExecutor(max_workers=n_workers) as pool:
            for i in tqdm(pool.map(fit_pos_partial, grains), total=len(grains)):
                pass

        grains = [grain for grain in grains if not grain.bad_recon]

        for g in grains:
            g.translation = np.array([g.x_blob, g.y_blob, 0])

        # run filtration twice (works better to filter out dodgy grains):
        # each pass rebuilds the label map, flags labels above zero that cover
        # more than grain_too_many_px pixels and drops the grains with those ids
        for _filter_pass in range(2):
            rgb_x_array, rgb_y_array, rgb_z_array, grain_labels_array, raw_intensity_array = utils.build_slice_arrays(grains, cutoff_level)

            labels, counts = np.unique(grain_labels_array, return_counts=True)
            bad_gids = [int(label) for (label, count) in zip(labels, counts) if count > grain_too_many_px and label > 0]

            grains = [grain for grain in grains if grain.gid not in bad_gids]

        # final arrays, built from the grains that survived both passes
        rgb_x_array, rgb_y_array, rgb_z_array, grain_labels_array, raw_intensity_array = utils.build_slice_arrays(grains, cutoff_level)

        utils.save_s3dxrd_grains_minor_phase_after_recon(grains, ds, raw_intensity_array, grain_labels_array, rgb_x_array, rgb_y_array, rgb_z_array, phase_name=phase_name)

        # put the major phase par file back before the dataset is saved
        ds.parfile = major_phase_par_file
        ds.save()

print("Done!")