# Jupyter notebook based on ImageD11 to process scanning 3DXRD data
# Written by Haixing Fang, Jon Wright and James Ball
## Date: 26/02/2024

In [None]:
# USER: Change the path below to point to your local copy of ImageD11:

import os
import sys

# Resolve the home directory in pure Python rather than shelling out with
# "!echo $HOME", so this cell also runs outside IPython (for example when the
# notebook is exported to a plain script).
home_dir = os.path.expanduser("~")

# USER: You can change this location if you want

id11_code_path = os.path.join(home_dir, "Code/ImageD11")

# Put the local checkout first on the import path so it is found before any
# other copy of ImageD11.
sys.path.insert(0, id11_code_path)

In [None]:
# import functions we need

import os
import concurrent.futures
import timeit

import matplotlib
%matplotlib widget

from skimage.feature import blob_log

import h5py
from tqdm.notebook import tqdm
import numba
import pprint
import numpy as np
import skimage.transform
import ipywidgets as ipyw
import matplotlib.pyplot as plt

from functools import partial

import ImageD11.nbGui.nb_utils as utils

import ImageD11.refinegrains
import ImageD11.columnfile
import ImageD11.sinograms.properties
import ImageD11.sinograms.roi_iradon
from ImageD11.blobcorrector import eiger_spatial
from ImageD11.grain import grain

In [None]:
# define our functions

def read_grains_minor_phase(ds):
    """Load the minor-phase grains stored in ``ds.grainsfile_minor_phase``.

    Every entry of the ``grains`` group becomes an ImageD11 grain built from
    its ``ubi`` attribute.  The grain ID is the integer value of the group
    name, and the ``peaks_4d_indexing`` dataset is attached as ``peaks_4d``.

    Returns the grains as a list ordered by increasing grain ID.
    """
    loaded = []
    with h5py.File(ds.grainsfile_minor_phase, 'r') as h5_in:
        all_grains = h5_in['grains']
        # group names are grain IDs stored as strings, so order them numerically
        for name in tqdm(sorted(all_grains.keys(), key=int)):
            entry = all_grains[name]
            this_grain = ImageD11.grain.grain(entry.attrs['ubi'][:])
            this_grain.gid = int(name)
            this_grain.peaks_4d = entry['peaks_4d_indexing'][:]
            loaded.append(this_grain)

    return loaded



def read_grains_main_phase(ds):
    """Load the main-phase grains stored in ``ds.grainsfile``.

    Every entry of the ``grains`` group becomes an ImageD11 grain built from
    its ``ubi`` attribute.  In addition to the grain ID (integer value of the
    group name), each grain receives ``y0`` (attribute), ``sample_mask`` (the
    ``circle_mask`` dataset), ``recon`` and ``ssino``.

    Returns the grains as a list ordered by increasing grain ID.
    """
    loaded = []
    with h5py.File(ds.grainsfile, 'r') as h5_in:
        all_grains = h5_in['grains']
        # group names are grain IDs stored as strings, so order them numerically
        for name in tqdm(sorted(all_grains.keys(), key=int)):
            entry = all_grains[name]

            this_grain = ImageD11.grain.grain(entry.attrs['ubi'][:])
            this_grain.gid = int(name)
            this_grain.y0 = entry.attrs['y0'][()]
            this_grain.sample_mask = entry['circle_mask'][:]
            this_grain.recon = entry['recon'][:]
            this_grain.ssino = entry['ssino'][:]
            loaded.append(this_grain)

    return loaded


def map_grain_from_peaks(g, flt, ds):
    """
    Build the sinogram of one grain from the peaks assigned to it.

    Every distinct (h, k, l, sign of eta) combination observed for the grain
    becomes one projection of the sinogram and every dty bin one position
    along that projection; peak intensities falling in the same cell are
    accumulated.

    Parameters
    ----------
    g : grain
        Must carry hkl_2d_strong (integer hkls, indexed [h/k/l, peak]) and
        etasigns_2d_strong (sign of eta per peak), as set by do_sinos.
    flt : columnfile
        Already the peaks for this grain; dty, omega and sum_intensity are read.
    ds : dataset
        Provides ybincens, the centres of the dty bins.

    Returns
    -------
    sinoangles : mean omega of each projection, in increasing order
    ssino : sinogram indexed [dty bin, projection], each projection divided by its maximum
    hits : number of peaks accumulated per cell, same layout as ssino
    """   
    NY = len(ds.ybincens)  # number of y translations
    # dty bin index of every peak; the bin width is taken from the first two
    # bin centres, so this assumes ds.ybincens is evenly spaced
    iy = np.round((flt.dty - ds.ybincens[0]) / (ds.ybincens[1]-ds.ybincens[0])).astype(int)  # flt column for y translation index

    # The problem is to assign each spot to a place in the sinogram
    hklmin = g.hkl_2d_strong.min(axis=1)  # Get minimum integer hkl (e.g -10, -9, -10)
    dh = g.hkl_2d_strong - hklmin[:,np.newaxis]  # subtract minimum hkl from all integer hkls, so they can be used as array indices
    de = (g.etasigns_2d_strong.astype(int) + 1)//2  # maps the eta sign -1 -> 0 and +1 -> 1 (a sign of 0 also gives 0)
    #   4D array of h,k,l,+/-
    # pkmsk is whether a peak has been observed with this HKL or not
    pkmsk = np.zeros(list(dh.max(axis=1) + 1 )+[2,], int)  # make zeros-array the size of (max dh +1) and add another axis of length 2
    pkmsk[ dh[0], dh[1], dh[2], de ] = 1  # we found these HKLs for this grain
    #   sinogram row to hit
    pkrow = np.cumsum(pkmsk.ravel()).reshape(pkmsk.shape) - 1  #
    # counting where we hit an HKL position with a found peak
    # e.g (-10, -9, -10) didn't get hit, but the next one did, so increment
    # -> every observed (h, k, l, sign) gets its own consecutive row number

    npks = pkmsk.sum( )  # number of distinct (h, k, l, sign) combinations = number of sinogram rows
    destRow = pkrow[ dh[0], dh[1], dh[2], de ]  # sinogram row of every peak
    sino = np.zeros( ( npks, NY ), 'f' )  # summed intensity per cell
    hits = np.zeros( ( npks, NY ), 'f' )  # number of peaks per cell
    angs = np.zeros( ( npks, NY ), 'f' )  # summed omega per cell
    adr = destRow * NY + iy  # flat (row-major) address of every peak in the (npks, NY) arrays
    # Just accumulate 
    # NOTE(review): put_incr64 is defined outside this file; presumably it adds
    # each value into the flattened target array at the given address - confirm
    sig = flt.sum_intensity
    ImageD11.cImageD11.put_incr64( sino, adr, sig )
    ImageD11.cImageD11.put_incr64( hits, adr, np.ones(len(de),dtype='f'))
    ImageD11.cImageD11.put_incr64( angs, adr, flt.omega)
    
    # mean omega of each row: summed omega divided by the number of peaks in the row
    sinoangles = angs.sum( axis = 1) / hits.sum( axis = 1 )
    # Normalise: (each row is divided by its own maximum)
    sino = (sino.T/sino.max( axis=1 )).T
    # Sort (cosmetic): by angle, ties broken by the original row number
    order = np.lexsort((np.arange(npks), sinoangles))
    sinoangles = sinoangles[order]
    ssino = sino[order].T
    
    return sinoangles, ssino, hits[order].T

def do_sinos(g, hkltol=0.25):
    """Build the sinogram of grain ``g`` from the 2D peaks assigned to it.

    Reads the notebook-level globals ``p2d`` (2D peak table), ``par_path``,
    ``e2dx_path``, ``e2dy_path`` and ``ds``; ``g.peaks_2d`` must already hold
    the indices of this grain's 2D peaks.

    Parameters
    ----------
    g : grain
        Grain to process; modified in place.
    hkltol : float
        Peaks whose hkl is further than this from the nearest integer hkl
        are discarded before the sinogram is built.

    Sets ``g.etasigns_2d_strong``, ``g.hkl_2d_strong``, ``g.sinoangles``,
    ``g.ssino`` and ``g.hits``.

    Returns
    -------
    (g.gid, g)
    """
    # convert this grain's 2D peaks to a columnfile and spatially correct them
    flt = utils.tocolf({p: p2d[p][g.peaks_2d] for p in p2d}, par_path, dxfile=e2dx_path, dyfile=e2dy_path)

    hkl_real = np.dot(g.ubi, (flt.gx, flt.gy, flt.gz))  # fractional hkl of all assigned peaks
    hkl_int = np.round(hkl_real).astype(int)  # nearest integer hkl
    dh = ((hkl_real - hkl_int)**2).sum(axis=0)  # squared distance to the integer hkl

    # keep only the peaks that index within hkltol (compared in squared form)
    flt.filter(dh < hkltol*hkltol)

    # recompute the integer hkls for the surviving peaks
    hkl_real = np.dot(g.ubi, (flt.gx, flt.gy, flt.gz))
    hkl_int = np.round(hkl_real).astype(int)

    g.etasigns_2d_strong = np.sign(flt.eta)
    g.hkl_2d_strong = hkl_int  # integer hkl of assigned peaks after hkltol filtering
    g.sinoangles, g.ssino, g.hits = map_grain_from_peaks(g, flt, ds)

    # Return the grain's own ID.  The previous "return i, g" picked up whatever
    # the notebook global "i" happened to be (or failed if it was undefined).
    return g.gid, g


def run_iradon_id11(grain, pad=20, y0=0, workers=1, sample_mask=None, apply_halfmask=False, mask_central_zingers=False):
    """Reconstruct the shape of a grain from its sinogram with the ImageD11 iradon.

    Parameters
    ----------
    grain : grain
        Must carry ``ssino`` and ``sinoangles``; the result is stored in ``grain.recon``.
    pad : int
        Added to the sinogram length (``ssino.shape[0]``) to get the output image size.
    y0 : float
        Every projection is shifted by ``-y0`` (passed as ``projection_shifts``).
    workers : int
        Passed through to the iradon call.
    sample_mask : array or None
        Passed through to the iradon call as ``mask``.
    apply_halfmask : bool
        If True, the sinogram is multiplied by a mask that keeps the first
        ``len(ssino)//2 - 1`` rows, halves the next row and zeroes the rest
        (meant for half scans).
    mask_central_zingers : bool
        If True, the central disc (radius 25 px) of the reconstruction is
        replaced by the median of the 2 px wide ring just outside it.

    Returns
    -------
    The same ``grain`` object, for convenience when mapping over grains.
    """
    outsize = grain.ssino.shape[0] + pad

    if apply_halfmask:
        halfmask = np.zeros_like(grain.ssino)

        halfmask[:len(halfmask)//2-1, :] = 1
        halfmask[len(halfmask)//2-1, :] = 0.5

        ssino_to_recon = grain.ssino * halfmask
    else:
        ssino_to_recon = grain.ssino

    # Perform iradon transform of grain sinogram, store result (reconstructed grain shape) in grain.recon
    grain.recon = ImageD11.sinograms.roi_iradon.iradon(ssino_to_recon,
                                                       theta=grain.sinoangles,
                                                       mask=sample_mask,
                                                       output_size=outsize,
                                                       projection_shifts=np.full(grain.ssino.shape, -y0),
                                                       filter_name='hamming',
                                                       interpolation='linear',
                                                       workers=workers)

    if mask_central_zingers:
        grs = grain.recon.shape[0]
        # NOTE(review): -grs//2 is floor(-grs/2), so for an odd image size the
        # disc is centred one pixel away from grs//2 - confirm this is intended
        xpr, ypr = -grs//2 + np.mgrid[:grs, :grs]
        inner_mask_radius = 25
        outer_mask_radius = inner_mask_radius + 2

        r_squared = xpr ** 2 + ypr ** 2
        inner_circle_mask = r_squared < inner_mask_radius ** 2
        outer_circle_mask = r_squared < outer_mask_radius ** 2

        # Ring between the two radii.  The previous "inner & outer" was just the
        # inner disc, so the fill value was the median of the very pixels being
        # replaced instead of their surroundings.
        mask_ring = outer_circle_mask & ~inner_circle_mask
        fill_value = np.median(grain.recon[mask_ring])
        grain.recon[inner_circle_mask] = fill_value

    return grain


def find_cens_from_recon(grain):
    """Locate the grain in its reconstruction and store the position in microns.

    A blob search is run on ``grain.recon`` and the blob with the largest
    third component (its sigma) is taken as the grain.  Its pixel position,
    measured from the middle of the image, is scaled by ``ds.ystep`` and
    offset by ``y0`` (both notebook globals) and stored as ``grain.x_blob``
    and ``grain.y_blob``.

    ``grain.bad_recon`` is set to True when no blob is found, and to False
    otherwise.  For small grains a missing blob normally means the
    reconstruction is bad, so such grains are excluded from maps and export.
    """
    grain.bad_recon = False

    found = list(blob_log(grain.recon, min_sigma=1, max_sigma=10, num_sigma=10, threshold=.01))
    # biggest blob first
    found.sort(key=lambda blob: blob[2], reverse=True)

    try:
        best = found[0]

        # blob position in recon space: first axis (vertical) is x,
        # second axis (horizontal) is y
        row = best[0]
        col = best[1]

        # the centre of the recon image is the centre of space; convert to microns
        x_um = (row - grain.recon.shape[0]//2) * ds.ystep + y0
        y_um = -(col - grain.recon.shape[1]//2) * ds.ystep + y0

        grain.x_blob = x_um
        grain.y_blob = y_um
    except IndexError:
        # didn't find any blobs
        grain.bad_recon = True

        
# HDF5 dataset options applied to every array written through save_array
cmp = dict(compression='gzip', compression_opts=2, shuffle=True)


def save_array(grp, name, ary):
    """Store ``ary`` as dataset ``name`` inside ``grp`` and return the dataset.

    The dataset is obtained with ``grp.require_dataset`` using the shape and
    dtype of ``ary`` together with the options in ``cmp``, then filled with
    the contents of ``ary``.
    """
    dataset = grp.require_dataset(name, shape=ary.shape, dtype=ary.dtype, **cmp)
    dataset[:] = ary
    return dataset

def save_grains_minor_phase(grains, ds):
    """Write the minor-phase grains and slice maps to ``ds.grainsfile_minor_phase``.

    Any existing file is deleted first, because the grain numbering may have
    changed since it was written.

    Besides ``grains``, the following notebook-level globals are written:
    ``gord`` and ``inds`` (2D peak assignments), ``raw_intensity_array``,
    ``grain_labels_array``, ``rgb_x_array``, ``rgb_y_array``, ``rgb_z_array``
    (slice maps) and ``whole_sample_mask`` (stored with every grain).

    Each grain must carry ``gid``, ``ubi``, ``peaks_4d``, ``peaks_2d``,
    ``ssino``, ``sinoangles``, ``recon``, ``translation`` and ``cen``.
    """
    # delete existing file, because our grain numbers have changed
    if os.path.exists(ds.grainsfile_minor_phase):
        os.remove(ds.grainsfile_minor_phase)

    with h5py.File(ds.grainsfile_minor_phase, 'w-') as hout:  # fail if exists
        # The file has just been created, so the groups below cannot exist
        # yet and are created unconditionally.
        grp = hout.create_group('peak_assignments')

        ds_gord = save_array(grp, 'gord', gord)
        ds_gord.attrs['description'] = 'Grain ordering: g[i].pks = gord[ inds[i] : inds[i+1] ]'
        ds_inds = save_array(grp, 'inds', inds)
        ds_inds.attrs['description'] = 'Grain indices: g[i].pks = gord[ inds[i] : inds[i+1] ]'

        grp = hout.create_group('slice_recon')
        save_array(grp, 'intensity', raw_intensity_array).attrs['description'] = 'Raw intensity array for all grains'
        save_array(grp, 'labels', grain_labels_array).attrs['description'] = 'Grain labels array for all grains'

        # one IPF colour map per lab axis, tagged as an image
        for axis_name, rgb_array in (('x', rgb_x_array), ('y', rgb_y_array), ('z', rgb_z_array)):
            ipf_dset = save_array(grp, f'ipf_{axis_name}_col_map', rgb_array)
            ipf_dset.attrs['description'] = f'IPF {axis_name.upper()} color at each pixel'
            ipf_dset.attrs['CLASS'] = 'IMAGE'

        grains_group = hout.create_group('grains')
        for g in tqdm(grains):
            gg = grains_group.create_group(str(g.gid))
            # save stuff for sinograms

            gg.attrs.update({'ubi': g.ubi})

            save_array(gg, 'peaks_4d_indexing', g.peaks_4d).attrs['description'] = "Strong 4D peaks that were assigned to this grain during indexing"

            save_array(gg, 'ssino', g.ssino).attrs['description'] = 'Sinogram of peak intensities sorted by omega'
            save_array(gg, 'sinoangles', g.sinoangles).attrs['description'] = 'Projection angles for sinogram'
            # 'og_recon' and 'recon' both hold g.recon at this point
            save_array(gg, 'og_recon', g.recon).attrs['description'] = 'Original ID11 iRadon reconstruction'
            save_array(gg, 'recon', g.recon).attrs['description'] = 'Final reconstruction'
            save_array(gg, 'circle_mask', whole_sample_mask).attrs['description'] = 'Reconstruction mask to use for MLEM'

            # might as well save peaks stuff while we're here
            save_array(gg, 'translation', g.translation).attrs['description'] = 'Grain translation in lab frame'
            save_array(gg, 'peaks_2d_sinograms', g.peaks_2d).attrs['description'] = "2D peaks from strong 4D peaks that were assigned to this grain for sinograms"
            save_array(gg, 'peaks_4d_sinograms', g.peaks_4d).attrs['description'] = "Strong 4D peaks that were assigned to this grain for sinograms"

            gg.attrs['cen'] = g.cen

In [None]:
# NOTE: For old datasets before the new directory layout structure, we don't distinguish between RAW_DATA and PROCESSED_DATA

### USER: specify your experimental directory

rawdata_path = "/home/esrf/james1997a/Data/ihma439/id11/20231211/RAW_DATA"

# list the raw data directory (long format, oldest entries first) to check what is there
!ls -lrt {rawdata_path}

### USER: specify where you want your processed data to go

processed_data_root_dir = "/home/esrf/james1997a/Data/ihma439/id11/20231211/PROCESSED_DATA/James/20240226"

In [None]:
# USER: pick a sample and a dataset you want to process

sample = "FeAu_0p5_tR_nscope"
dataset = "top_250um"

In [None]:
# destination of H5 files

dset_path = os.path.join(processed_data_root_dir, sample, f"{sample}_{dataset}", f"{sample}_{dataset}_dataset.h5")

# parameter file loaded into the 4D peaks below
# (par_path is re-pointed to the Au parameter file later in this notebook)
par_path = os.path.join(processed_data_root_dir, 'Fe_refined.par')

# dx/dy files handed to utils.tocolf in do_sinos for the spatial correction of the 2D peaks
e2dx_path = os.path.join(processed_data_root_dir, '../../CeO2/e2dx_E-08-0173_20231127.edf')
e2dy_path = os.path.join(processed_data_root_dir, '../../CeO2/e2dy_E-08-0173_20231127.edf')

In [None]:
# Load the dataset (for motor positions, not sure why these are not in peaks)
# NOTE(review): ImageD11.sinograms.dataset is not imported explicitly at the top of this
# notebook; presumably one of the other ImageD11 imports pulls it in - confirm
ds = ImageD11.sinograms.dataset.load(dset_path)

In [None]:
# determine ring currents for sinogram row-by-row intensity correction
# (the correction itself is applied further down, after the sinograms are built)

utils.get_ring_current_per_scan(ds)

In [None]:
# Import 4D peaks

cf_4d = ImageD11.columnfile.columnfile(ds.col4dfile)

# load the parameters in par_path and update the geometry of the peaks with them
cf_4d.parameters.loadparameters(par_path)
cf_4d.updateGeometry()

print(f"Read {cf_4d.nrows} 4D peaks")

In [None]:
# name of the minor phase; only used here to build the name of its grains file
phase_name = "Au"

ds.grainsfile_minor_phase = os.path.join(ds.analysispath, ds.dsname + f'_grains_{phase_name}.h5')

grains = read_grains_minor_phase(ds)

for grain in grains:
    # print(grain.gid)
    # cube root of det(ubi); plotted further down as the unit cell length of each grain
    grain.a = np.cbrt(np.linalg.det(grain.ubi))
    
print(f"{len(grains)} grains imported")

In [None]:
# isolate main phase peaks, and remove them from the dataset
# (cf_4d still carries the parameters loaded from par_path above)
main_phase_peaks_mask = utils.unitcell_peaks_mask(cf_4d, dstol=0.0075, dsmax=cf_4d.ds.max())

# work on a copy so that cf_4d itself keeps all peaks
minor_phase_peaks = cf_4d.copy()
minor_phase_peaks.filter(~main_phase_peaks_mask)

# Update geometry for minor phase peaks

# NOTE: par_path is re-pointed to the Au parameter file here; do_sinos reads
# this global, so it uses the Au parameters from now on
par_path = os.path.join(processed_data_root_dir, 'Au.par')
minor_phase_peaks.parameters.loadparameters(par_path)
minor_phase_peaks.updateGeometry()

# select the peaks used for the rest of the notebook (cf_strong) from the minor phase peaks
cf_strong = utils.selectpeaks(minor_phase_peaks, dstol=0.005, dsmax=minor_phase_peaks.ds.max(), frac=0.9, doplot=0.01)
print(cf_strong.nrows)

In [None]:
# If the sinograms are only half-sinograms (we scanned dty across half the sample rather than the full sample), set the below to true:
# (also switches on the half mask and the central zinger masking in the reconstruction further down)
is_half_scan = False

In [None]:
if is_half_scan:
    utils.correct_half_scan(ds)

In [None]:
# load major phase grain reconstruction
# for pad and y0
# (the sample mask and y0 are taken from the first main-phase grain)

major_phase_grains = read_grains_main_phase(ds)
whole_sample_mask = major_phase_grains[0].sample_mask
y0 = major_phase_grains[0].y0

# pad = size of the main-phase reconstruction minus the length of its sinogram,
# i.e. the value that reproduces the same output size in run_iradon_id11
pad = ((major_phase_grains[0].recon.shape[0] - major_phase_grains[0].ssino.shape[0]))
pad

In [None]:
# tolerance passed to the peak assignment
tol = 0.25

# NOTE(review): presumably fills cf_strong.grain_id, which is read in the loop below - confirm
utils.assign_peaks_to_grains(grains, cf_strong, tol=tol)

print("Storing peak data in grains")
# iterate through all the grains
for g in tqdm(grains):
    # store this grain's peak indices so we know which 4D peaks we used for sinograms
    g.mask_4d = cf_strong.grain_id == g.gid  # boolean mask over cf_strong
    g.peaks_4d = cf_strong.index[g.mask_4d]  # replaces the peaks_4d read from the grains file

In [None]:
# assigned peaks (grain_id >= 0) in omega/dty space, coloured by grain ID
fig, ax = plt.subplots()
m = cf_strong.grain_id >= 0
ax.scatter(cf_strong.omega[m], cf_strong.dty[m], c=cf_strong.grain_id[m])
plt.show()

In [None]:
mean_unit_cell_lengths = [grain.a for grain in grains]

# NOTE: the horizontal axis is the position in the grains list, which only
# equals the grain ID if the IDs are consecutive and start at 0
fig, ax = plt.subplots()
ax.plot(mean_unit_cell_lengths)
ax.set_xlabel("Grain ID")
ax.set_ylabel("Unit cell length")
plt.show()

# median unit cell length over all grains
a0 = np.median(mean_unit_cell_lengths)
    
print(a0)

In [None]:
# visual check of the grain sinograms built from the 4D peaks
utils.plot_grain_sinograms(grains, cf_strong, 25)

In [None]:
for grain in tqdm(grains):
    # grain.peaks_4d_selected, grain.cen, grain.dx, grain.dy = utils.graincen(grain.gid, cf_strong, doplot=False)
    # IPF colours of the grain along the three lab axes
    grain.rgb_z = utils.grain_to_rgb(grain, ax=(0,0,1),)# symmetry = Symmetry.cubic)
    grain.rgb_y = utils.grain_to_rgb(grain, ax=(0,1,0),)# symmetry = Symmetry.cubic)
    grain.rgb_x = utils.grain_to_rgb(grain, ax=(1,0,0),)# symmetry = Symmetry.cubic)
    # NOTE(review): presumably sets grain.cen, grain.dx and grain.dy, which the following cells read - confirm
    utils.fit_grain_position_from_sino(grain, cf_strong)

In [None]:
# make sure we get cen right (centre of rotation should be the middle of dty)
fig, ax = plt.subplots()
ax.plot([g.cen for g in grains])

plt.show()

In [None]:
# median of the per-grain fitted centres
c0 = np.median([g.cen for g in grains])

print('Center of rotation in dty', c0)

In [None]:
# plt.style.use('dark_background')
# Overview of the fitted grain positions: one panel per IPF colouring, plus
# one panel coloured by the (scaled) number of assigned 4D peaks.
fig, ax = plt.subplots(2,2, figsize=(12,12))

x = [g.dy for g in grains]
y = [g.dx for g in grains]
s = [g.mask_4d.sum()/10 for g in grains]  # marker size; also colours the last panel

panels = ax.ravel()
ipf_panels = (
    ('IPF color Z', 'rgb_z'),
    ('IPF color Y', 'rgb_y'),
    ('IPF color X', 'rgb_x'),
)
for a, (panel_title, colour_attr) in zip(panels, ipf_panels):
    a.scatter(x, y, c=[getattr(g, colour_attr) for g in grains], s=s)
    a.set(title=panel_title,  aspect='equal')

panels[3].scatter(x, y, c=s)
panels[3].set(title='Number of 4D peaks', aspect='equal')

fig.supxlabel("Lab y (transverse)")
fig.supylabel("Lab x (beam)")

# lab y increases to the left in every panel
for a in panels:
    a.invert_xaxis()

plt.show()

In [None]:
# populate translations of grains
# (first estimate from the fitted positions; replaced after the reconstruction further down)
for g in grains:
    g.translation = np.array([g.dx, g.dy, 0])

In [None]:
# Big scary block
# Must understand what this does!

# Ensure cf is sorted by spot3d_id
# NOTE: spot3d_id should be spot4d_id, because we have merged into 4D?
assert (np.argsort(cf_strong.spot3d_id) == np.arange(cf_strong.nrows)).all()

# load the 2d peak labelling output
pks = ImageD11.sinograms.properties.pks_table.load(ds.pksfile)

# Grab the 2d peak centroids
p2d = pks.pk2d(ds.omega, ds.dty)

# NOTE: These are not spatially corrected?!
# (do_sinos applies the dx/dy files when it converts them to a columnfile)

# sort order of the 2D peaks by the ID of the merged peak they belong to
numba_order, numba_histo = utils.counting_sort(p2d['spot3d_id'])

# output array: one grain ID per 2D peak, filled by find_grain_id below
grain_2d_id = utils.palloc(p2d['spot3d_id'].shape, np.dtype(int))

cleanid = cf_strong.grain_id.copy()

utils.find_grain_id(cf_strong.spot3d_id, cleanid, p2d['spot3d_id'], grain_2d_id, numba_order)

# sort the 2D peaks by grain ID; counts holds the number of peaks per ID
gord, counts = utils.counting_sort(grain_2d_id)

# start offsets into gord: cumulative sum of counts with a leading 0
inds = np.concatenate(((0,), np.cumsum(counts)))

# I think what we end up with is:
# inds
# this is an array which tells you which 2D spots each grain owns
# the 2D spots are sorted by spot ID
# inds tells you for each grain where you can find its associated 2D spots

In [None]:
# now our 2D peak assignments are known, let's populate our grain objects with our 2D peaks

for grain in tqdm(grains):
    i = grain.gid
    # NOTE(review): the slice is shifted by one (i+1, i+2); presumably the first
    # slot of inds belongs to the unassigned peaks (ID -1) - confirm
    grain.peaks_2d = gord[inds[i+1] : inds[i+2]]
    # grain.mask_2d = np.isin(cf_2d.index, grain.peaks_2d)

In [None]:
# Determine sinograms of all grains

# number of CPUs this process is allowed to run on
nthreads = len(os.sched_getaffinity(os.getpid()))

# do_sinos modifies each grain in place, so the results of the map are only
# iterated to drive the progress bar; one CPU is left free
with concurrent.futures.ThreadPoolExecutor(max_workers= max(1,nthreads-1)) as pool:
    for i in tqdm(pool.map(do_sinos, grains), total=len(grains)):
        pass

In [None]:
# we can optionally correct the grain sinograms by scaling each row by the ring current:

for grain in tqdm(grains):
    utils.correct_sinogram_rows_with_ring_current(grain, ds)

In [None]:
# Show sinogram of single grain

g = grains[0]

fig, ax = plt.subplots()

# each projection (column of ssino) is divided by its mean; shown on a log colour scale
ax.imshow((g.ssino/g.ssino.mean(axis=0)), norm=matplotlib.colors.LogNorm(), interpolation='nearest', origin="lower")

plt.show()

In [None]:
# you can pick a grain and investigate the effects of changing y0 that gets passed to iradon
# it's best to pick the grain AFTER reconstructing all grains, so you can pick a grain of interest

g = grains[5]
    
# trial values of y0
vals = np.linspace(-8.5, -7.5, 9)

# smallest near-square grid of panels that holds one panel per trial value
grid_size = np.ceil(np.sqrt(len(vals))).astype(int)
nrows = (len(vals)+grid_size-1)//grid_size

fig, axs = plt.subplots(grid_size, nrows, sharex=True, sharey=True)

for inc, val in enumerate(tqdm(vals)):
    # NOTE: this overwrites g.recon, and uses the default pad and no sample mask
    run_iradon_id11(g, y0=val)
    # crop = g.recon[200:240, 230:290]
    crop = g.recon
    
    axs.ravel()[inc].imshow(crop, origin="lower", vmin=0)
    axs.ravel()[inc].set_title(val)
    
plt.show()

In [None]:
# you can overwrite y0 here
# (and pad; both are otherwise taken from the main-phase reconstruction)

# y0 = -7.875
# pad = 50

In [None]:
# number of CPUs this process is allowed to run on
nthreads = len(os.sched_getaffinity(os.getpid()))

# fix every argument except the grain, so the function can be mapped over the grains;
# for half scans the half mask and the central zinger masking are both switched on
run_this_iradon = partial(run_iradon_id11, pad=pad, y0=y0, sample_mask=whole_sample_mask, workers=1, apply_halfmask=is_half_scan, mask_central_zingers=is_half_scan)

# run_iradon_id11 stores the result in grain.recon, so the results of the map
# are only iterated to drive the progress bar
with concurrent.futures.ThreadPoolExecutor( max_workers= max(1,nthreads-1) ) as pool:
    for i in tqdm(pool.map(run_this_iradon, grains), total=len(grains)):
        pass

In [None]:
import ipywidgets as widgets
from ipywidgets import interact
%matplotlib ipympl

fig, a = plt.subplots(1,2,figsize=(10,5))
rec = a[0].imshow(grains[8].recon, vmin=0, origin="lower")
sin = a[1].imshow(grains[8].ssino, aspect='auto')

# Function to update the displayed image based on the selected frame
def update_frame(i):
    rec.set_array(grains[i].recon)
    sin.set_array(grains[i].ssino)
    a[0].set(title=str(grains[i].gid))
    fig.canvas.draw()

# Create a slider widget to select the frame number
frame_slider = widgets.IntSlider(
    value=0,
    min=0,
    max=len(grains) - 1,
    step=1,
    description='Grain:'
)

interact(update_frame, i=frame_slider)

plt.show()

In [None]:
# find the position of every grain in its reconstruction; find_cens_from_recon
# stores the result on the grain, so the results of the map are only iterated
# to drive the progress bar
with concurrent.futures.ThreadPoolExecutor(max_workers= max(1, nthreads-1)) as pool:
    for i in tqdm(pool.map(find_cens_from_recon, grains), total=len(grains)):
        pass

In [None]:
# remove bad recon grains from future analysis
# (grains for which no blob was found in the reconstruction)
print(f"{len(grains)} grains before filtration")
grains = [grain for grain in grains if not grain.bad_recon]
print(f"{len(grains)} grains after filtration")

In [None]:
# replace the translations by the blob positions (note the sign flip on y)
for g in grains:
    g.translation = np.array([g.x_blob, -g.y_blob, 0])

In [None]:
# plt.style.use('dark_background')
# Overview of the grain positions found in the reconstructions: one panel per
# IPF colouring, plus one panel coloured by the (scaled) number of assigned 4D peaks.
fig, ax = plt.subplots(2,2, figsize=(12,12))

x = [g.translation[1] for g in grains]
y = [g.translation[0] for g in grains]
s = [g.mask_4d.sum()/10 for g in grains]  # marker size; also colours the last panel

panels = ax.ravel()
ipf_panels = (
    ('IPF color Z', 'rgb_z'),
    ('IPF color Y', 'rgb_y'),
    ('IPF color X', 'rgb_x'),
)
for a, (panel_title, colour_attr) in zip(panels, ipf_panels):
    a.scatter(x, y, c=[getattr(g, colour_attr) for g in grains], s=s)
    a.set(title=panel_title,  aspect='equal')

panels[3].scatter(x, y, c=s)
panels[3].set(title='Number of 4D peaks', aspect='equal')

fig.supxlabel("Lab y (transverse)")
fig.supylabel("Lab x (beam)")

# lab y increases to the left in every panel
for a in panels:
    a.invert_xaxis()

plt.show()

In [None]:
# one panel per lab axis (x, y, z) showing the crystal direction of every grain along that axis
f,a = plt.subplots( 1,3, figsize=(15,5) )
ty, tx = utils.triangle().T  # outline drawn on top of every panel
for i,title in enumerate( 'xyz' ):
    # unit vector along the current lab axis (NOTE: rebinds the name "ax")
    ax = np.zeros(3)
    ax[i] = 1.
    hkl = [utils.crystal_direction_cubic( g.ubi, ax ) for g in grains]
    xy = np.array([utils.hkl_to_pf_cubic(h) for h in hkl ])
    rgb = np.array([utils.hkl_to_color_cubic(h) for h in hkl ])
    # grain.rgb is overwritten on every pass, so it ends up holding the colour for the last axis (z)
    for j in range(len(grains)):
        grains[j].rgb = rgb[j]
    a[i].scatter( xy[:,1], xy[:,0], c = rgb )   # Note the "x" axis of the plot is the 'k' direction and 'y' is h (smaller)
    a[i].set(title=title, aspect='equal', facecolor='silver', xticks=[], yticks=[])
    a[i].plot( tx, ty, 'k-', lw = 1 )

In [None]:
# combine the per-grain reconstructions into slice maps: IPF colours, grain labels and raw intensity
rgb_x_array, rgb_y_array, rgb_z_array, grain_labels_array, raw_intensity_array = utils.build_slice_arrays(grains, cutoff_level=0.7)

In [None]:
# plot initial output

fig, ax = plt.subplots(constrained_layout=True)
ax.imshow(rgb_z_array, origin="lower")  # originally 1,2,0
plt.show()

In [None]:
fig, ax = plt.subplots(constrained_layout=True)
ax.imshow(grain_labels_array, origin="lower")  # originally 1,2,0
ax.set_title("Grain label map")
plt.show()

In [None]:
fig, ax = plt.subplots(constrained_layout=True)
ax.imshow(raw_intensity_array, origin="lower")  # originally 1,2,0
ax.set_title("Raw intensity array")
plt.show()

In [None]:
# number of pixels carrying each label in the label map
labels, counts = np.unique(grain_labels_array, return_counts=True)

# pixel count per label, for labels greater than 0 only
fig, ax = plt.subplots()
ax.plot(labels[labels > 0], counts[labels > 0])
plt.show()

In [None]:
# filter out grains with more than 25 pixels in the label map
# this normally indicates a dodgy reconstruction for this grain
# only really applies if the grains are very small!
# (labels that are not greater than 0 are never flagged)

bad_gids = [int(label) for (label, count) in zip(labels, counts) if count > 25 and label > 0]
bad_gids

In [None]:
# drop the flagged grains
print(f"{len(grains)} grains before filtration")
grains = [grain for grain in grains if grain.gid not in bad_gids]
print(f"{len(grains)} grains after filtration")

In [None]:
# rebuild the slice maps from the grains that survived the filtration
rgb_x_array, rgb_y_array, rgb_z_array, grain_labels_array, raw_intensity_array = utils.build_slice_arrays(grains, cutoff_level=0.7)

In [None]:
# plot output after filtration

fig, ax = plt.subplots(constrained_layout=True)
ax.imshow(rgb_z_array, origin="lower")  # originally 1,2,0
plt.show()

In [None]:
fig, ax = plt.subplots(constrained_layout=True)
ax.imshow(grain_labels_array, origin="lower")  # originally 1,2,0
ax.set_title("Grain label map")
plt.show()

In [None]:
fig, ax = plt.subplots(constrained_layout=True)
ax.imshow(raw_intensity_array, origin="lower")  # originally 1,2,0
ax.set_title("Raw intensity array")
plt.show()

In [None]:
# write grains to disk
# (replaces any existing minor-phase grains file for this dataset)

In [None]:
save_grains_minor_phase(grains, ds)

In [None]:
# guard so that "Run all cells" stops here instead of starting the bulk processing below
if 1:
    raise ValueError("Change the 1 above to 0 to allow 'Run all cells' in the notebook")

In [None]:
# Now that we're happy with our indexing parameters, we can run the below cell to do this in bulk for many samples/datasets
# by default this will do all samples in sample_list, all datasets with a prefix of dset_prefix
# you can add samples and datasets to skip in skips_dict

# sample name -> datasets of that sample to leave out
skips_dict = {
    "FeAu_0p5_tR_nscope": ["top_-50um", "top_-100um"]
}

dset_prefix = "top"

sample_list = ["FeAu_0p5_tR_nscope"]
    
samples_dict = utils.find_datasets_to_process(rawdata_path, skips_dict, dset_prefix, sample_list)
    
# manual override:
# samples_dict = {"FeAu_0p5_tR_nscope": ["top_100um", "top_200um"]}
    
# now we have our samples_dict, we can process our data:

# parameter files: par_path is loaded into the 4D peaks, minor_phase_par_path into the minor phase peaks
par_path = os.path.join(processed_data_root_dir, 'Fe_refined.par')
minor_phase_par_path = os.path.join(processed_data_root_dir, 'Au.par')

# dx/dy files handed to utils.tocolf in do_sinos for the spatial correction of the 2D peaks
e2dx_path = os.path.join(processed_data_root_dir, '../../CeO2/e2dx_E-08-0173_20231127.edf')
e2dy_path = os.path.join(processed_data_root_dir, '../../CeO2/e2dy_E-08-0173_20231127.edf')

# dstol used to mask out the main phase peaks
main_phase_cf_dstol = 0.0075
# name of the minor phase; used to build the name of its grains file
phase_name = "Au"

# frac and dstol used to select the strong minor phase peaks
cf_strong_frac = 0.9
cf_strong_dstol = 0.005

# set to True for half scans; also switches on the half mask and the central zinger masking
is_half_scan = False
# scale the sinogram rows by the ring current before reconstructing
correct_sinos_with_ring_current = True

# tolerance for assigning peaks to grains
peak_assign_tol = 0.25

# number of CPUs this process is allowed to run on
nthreads = len(os.sched_getaffinity(os.getpid()))

# pad = 50

# cutoff passed to utils.build_slice_arrays
cutoff_level = 0.7

# grains with more pixels than this in the label map are discarded
grain_too_many_px = 22

# Bulk version of the interactive workflow above, run for every dataset in samples_dict.
#
# do_sinos reads the global par_path.  In the interactive workflow that name
# points at the minor phase parameter file by the time the sinograms are built,
# so the same is done here; the main phase file is kept under its own name for
# loading the 4D peaks of every dataset.
main_phase_par_path = par_path

for sample, datasets in samples_dict.items():
    for dataset in datasets:
        print(f"Processing dataset {dataset} in sample {sample}")
        dset_path = os.path.join(processed_data_root_dir, sample, f"{sample}_{dataset}", f"{sample}_{dataset}_dataset.h5")
        if not os.path.exists(dset_path):
            print(f"Missing DataSet file for {dataset} in sample {sample}, skipping")
            continue

        print("Importing DataSet object")

        ds = ImageD11.sinograms.dataset.load(dset_path)
        print(f"I have a DataSet {ds.dset} in sample {ds.sample}")

        ds.grainsfile_minor_phase = os.path.join(ds.analysispath, ds.dsname + f'_grains_{phase_name}.h5')

        if not os.path.exists(ds.grainsfile_minor_phase):
            print(f"Missing grains file for {dataset} in sample {sample}, skipping")
            continue

        # check grains file for existence of slice_recon, skip if it's there
        with h5py.File(ds.grainsfile_minor_phase, "r") as hin:
            if "slice_recon" in hin.keys():
                print(f"Already reconstructed {dataset} in {sample}, skipping")
                continue

        # determine ring currents for sinogram row-by-row intensity correction
        utils.get_ring_current_per_scan(ds)

        cf_4d = ImageD11.columnfile.columnfile(ds.col4dfile)
        cf_4d.parameters.loadparameters(main_phase_par_path)
        cf_4d.updateGeometry()

        grains = read_grains_minor_phase(ds)

        # isolate main phase peaks, and remove them from the dataset
        main_phase_peaks_mask = utils.unitcell_peaks_mask(cf_4d, dstol=main_phase_cf_dstol, dsmax=cf_4d.ds.max())

        minor_phase_peaks = cf_4d.copy()
        minor_phase_peaks.filter(~main_phase_peaks_mask)

        # Update geometry for minor phase peaks
        minor_phase_peaks.parameters.loadparameters(minor_phase_par_path)
        minor_phase_peaks.updateGeometry()

        cf_strong = utils.selectpeaks(minor_phase_peaks, frac=cf_strong_frac, dsmax=cf_4d.ds.max(), dstol=cf_strong_dstol)

        if is_half_scan:
            utils.correct_half_scan(ds)

        # sample mask, y0 and pad are taken from the first main phase grain
        main_phase_grains = read_grains_main_phase(ds)
        whole_sample_mask = main_phase_grains[0].sample_mask
        y0 = main_phase_grains[0].y0
        pad = main_phase_grains[0].recon.shape[0] - main_phase_grains[0].ssino.shape[0]

        utils.assign_peaks_to_grains(grains, cf_strong, tol=peak_assign_tol)

        # store which 4D peaks were assigned to each grain
        for g in tqdm(grains):
            g.mask_4d = cf_strong.grain_id == g.gid
            g.peaks_4d = cf_strong.index[g.mask_4d]

        for grain in tqdm(grains):
            # grain.peaks_4d_selected, grain.cen, grain.dx, grain.dy = utils.graincen(grain.gid, cf_strong, doplot=False)
            grain.rgb_z = utils.grain_to_rgb(grain, ax=(0,0,1),)# symmetry = Symmetry.cubic)
            grain.rgb_y = utils.grain_to_rgb(grain, ax=(0,1,0),)# symmetry = Symmetry.cubic)
            grain.rgb_x = utils.grain_to_rgb(grain, ax=(1,0,0),)# symmetry = Symmetry.cubic)
            utils.fit_grain_position_from_sino(grain, cf_strong)

        c0 = np.median([g.cen for g in grains])

        for g in grains:
            g.translation = np.array([g.dx, g.dy, 0])

        print("Peak 2D organise")
        pks = ImageD11.sinograms.properties.pks_table.load(ds.pksfile)
        p2d = pks.pk2d(ds.omega, ds.dty)
        numba_order, numba_histo = utils.counting_sort(p2d['spot3d_id'])
        grain_2d_id = utils.palloc(p2d['spot3d_id'].shape, np.dtype(int))
        cleanid = cf_strong.grain_id.copy()
        utils.find_grain_id(cf_strong.spot3d_id, cleanid, p2d['spot3d_id'], grain_2d_id, numba_order)
        gord, counts = utils.counting_sort(grain_2d_id)
        inds = np.concatenate(((0,), np.cumsum(counts)))

        for grain in tqdm(grains):
            i = grain.gid
            grain.peaks_2d = gord[inds[i+1] : inds[i+2]]

        print("Making sinograms")
        # do_sinos reads the global par_path: use the minor phase parameters,
        # exactly as in the interactive workflow
        par_path = minor_phase_par_path
        with concurrent.futures.ThreadPoolExecutor(max_workers= max(1,nthreads-1)) as pool:
            for i in tqdm(pool.map(do_sinos, grains), total=len(grains)):
                pass

        if correct_sinos_with_ring_current:
            for grain in tqdm(grains):
                utils.correct_sinogram_rows_with_ring_current(grain, ds)

        print("Running iradon")

        run_this_iradon = partial(run_iradon_id11, pad=pad, y0=y0, sample_mask=whole_sample_mask, workers=1, apply_halfmask=is_half_scan, mask_central_zingers=is_half_scan)

        with concurrent.futures.ThreadPoolExecutor( max_workers= max(1,nthreads-1) ) as pool:
            for i in tqdm(pool.map(run_this_iradon, grains), total=len(grains)):
                pass

        with concurrent.futures.ThreadPoolExecutor(max_workers= max(1, nthreads-1)) as pool:
            for i in tqdm(pool.map(find_cens_from_recon, grains), total=len(grains)):
                pass

        # drop the grains for which no blob was found in the reconstruction
        grains = [grain for grain in grains if not grain.bad_recon]

        for g in grains:
            g.translation = np.array([g.x_blob, -g.y_blob, 0])

        rgb_x_array, rgb_y_array, rgb_z_array, grain_labels_array, raw_intensity_array = utils.build_slice_arrays(grains, cutoff_level)

        # drop the grains that cover too many pixels of the label map, then rebuild the maps
        labels, counts = np.unique(grain_labels_array, return_counts=True)
        bad_gids = [int(label) for (label, count) in zip(labels, counts) if count > grain_too_many_px and label > 0]

        grains = [grain for grain in grains if grain.gid not in bad_gids]

        rgb_x_array, rgb_y_array, rgb_z_array, grain_labels_array, raw_intensity_array = utils.build_slice_arrays(grains, cutoff_level)

        save_grains_minor_phase(grains, ds)

print("Done!")