# Jupyter notebook based on ImageD11 to process scanning 3DXRD data
# Written by Haixing Fang, Jon Wright and James Ball
## Date: 16/01/2024

In [None]:
# The ImageD11 installed in the site-wide Jupyter environment has a bug, fixed upstream in
# https://github.com/FABLE-3DXRD/ImageD11/commit/4af88b886b1775585e868f2339a0eb975401468f
# Until a release containing the fix is installed there, use a local clone of ImageD11
# from GitHub (e.g. somewhere in your home directory).
# USER: change the path below to point to your local copy of ImageD11:

import sys

id11_code_path = "/home/esrf/james1997a/Code/ImageD11"

# Prepend, so the local copy is found before any installed ImageD11
sys.path[0:0] = [id11_code_path]

In [None]:
# import functions we need

import os
import concurrent.futures
import timeit

import matplotlib
%matplotlib widget

import h5py
import tqdm
import numba
import pprint
import numpy as np
import skimage.transform
import ipywidgets as ipyw
import matplotlib.pyplot as plt

import ImageD11.refinegrains
import ImageD11.columnfile
import ImageD11.sinograms.properties
import ImageD11.sinograms.roi_iradon
from ImageD11.blobcorrector import eiger_spatial
from ImageD11.grain import grain

In [None]:
# OLD DATASETS

# NOTE: Old datasets (from before the RAW_DATA / PROCESSED_DATA directory layout) have no
# separate raw and processed folders.
# For such a dataset, use this cell to give the location of your experiment folder, and delete
# the "# NEW DATASETS" cell below. Both cells set rawdata_path, processed_data_root_dir and
# sparse_pixels_dir, so whichever of them runs last wins.
# e.g. /data/visitor/4752/id11/20210513

### USER: specify your experimental directory

rawdata_path = "/home/esrf/james1997a/Data/ma4752/id11/20210618"

# List the experiment folder (long format, oldest first) as a quick check that the path is right
!ls -lrt {rawdata_path}

### USER: specify where you want your processed data to go

processed_data_root_dir = "/home/esrf/james1997a/Data/ma4752/id11/20240118/James"
sparse_pixels_dir = os.path.join(processed_data_root_dir, "SparsePixels")  # USER: modify this to change the name of the SparsePixels folder inside processed_data_root_dir

In [None]:
# NEW DATASETS
# If your experiment has RAW_DATA and PROCESSED_DATA folders, run this cell and delete the
# "# OLD DATASETS" cell above. Both cells set rawdata_path, processed_data_root_dir and
# sparse_pixels_dir, so whichever of them runs last wins.

### USER: specify your experimental directory

base_dir = "/home/esrf/james1997a/Data/ihma439/id11/20231211"

rawdata_path = os.path.join(base_dir, 'RAW_DATA')

# List the raw data folder (long format, oldest first) as a quick check that the path is right
!ls -lrt {rawdata_path}

# Processed output goes under base_dir/PROCESSED_DATA/...
processed_data_root_dir = os.path.join(base_dir, 'PROCESSED_DATA/James/20240123')  # USER: modify this to change the destination folder if desired
sparse_pixels_dir = os.path.join(processed_data_root_dir, "SparsePixels")  # USER: modify this to change the name of the SparsePixels folder inside processed_data_root_dir

In [None]:
# USER: choose the sample, and the dataset within it, to process
sample, dataset = "FeAu_0p5_tR_nscope", "top_100um"

In [None]:
# desination of H5 files

dset_path = os.path.join(sparse_pixels_dir, f"ds_{sample}_{dataset}.h5" )
sparse_path = os.path.join(sparse_pixels_dir, f'{sample}_{dataset}_sparse.h5')
pks_path = os.path.join(sparse_pixels_dir, f'pks_{sample}_{dataset}.h5')
cf_path = os.path.join(sparse_pixels_dir, f'cf_{sample}_{dataset}.h5')
grains_path = os.path.join(sparse_pixels_dir, f'grains_{sample}_{dataset}.map')
par_path = 'Fe_refined.par'

e2dx_path = os.path.join(processed_data_root_dir, '../../CeO2/e2dx_E-08-0173_20231127.edf')
e2dy_path = os.path.join(processed_data_root_dir, '../../CeO2/e2dy_E-08-0173_20231127.edf')

In [None]:
# Import spatially-corrected 4D peaks from file (no intensity filtering)

cf_4d = ImageD11.columnfile.columnfile(cf_path)

cf_4d.parameters.loadparameters(par_path)
cf_4d.updateGeometry()

print(f"Read {cf_4d.nrows} peaks")

# Import grains from file

grains = ImageD11.grain.read_grain_file(grains_path)

# Mean unit cell length per grain: the cube root of det(UBI).
# The loop variable is `g`, not `grain`, so the `grain` class imported from
# ImageD11.grain is not shadowed by the last grain object.
for g in grains:
    g.a = np.cbrt(np.linalg.det(g.ubi))

print(f"{len(grains)} grains imported")

In [None]:
# filter the peaks to select only the brightest ones for sinogram use

def strongest_peaks(colf, uself=True, frac=0.995, B=0.2, doplot=None):
    """
    Select the peaks that together carry a fraction `frac` of the corrected intensity.

    Intensities are multiplied by exp(B * ds**2), which boosts high-ds peaks. If `uself`
    is set, they are also multiplied by ImageD11.refinegrains.lf(tth, eta) (presumably a
    Lorentz factor -- TODO confirm).

    Peaks are then sorted strongest first. The cutoff is the corrected intensity of the
    peak at which the normalised cumulative intensity first reaches `frac`, and only
    peaks strictly brighter than the cutoff are kept.

    Parameters
    ----------
    colf : columnfile providing sum_intensity and ds (plus tth and eta if uself)
    uself : bool, apply the lf(tth, eta) correction
    frac : float in [0, 1], fraction of the total corrected intensity to retain
    B : float, coefficient in the exp(B * ds**2) correction
    doplot : None for no plot; otherwise the lower y-limit of the log-scale panel

    Returns
    -------
    mask : bool array, one entry per row of colf, True for selected peaks

    Raises
    ------
    ValueError : if frac is outside [0, 1]
    """
    if not 0 <= frac <= 1:
        raise ValueError(f"frac must be in [0, 1], got {frac}")
    # correct intensities for their fall-off with ds
    cor_intensity = colf.sum_intensity * (np.exp(colf.ds*colf.ds*B))
    if uself:
        lf = ImageD11.refinegrains.lf(colf.tth, colf.eta)
        cor_intensity *= lf
    if len(cor_intensity) == 0:
        # nothing to select; the cumulative sum below would have no last element
        return np.zeros(0, dtype=bool)
    order = np.argsort(cor_intensity)[::-1]  # sort the peaks by intensity, strongest first
    sortedpks = cor_intensity[order]
    cums = np.cumsum(sortedpks)
    cums /= cums[-1]  # normalised: the last entry is exactly 1.0, so frac <= 1 is in range
    enough = np.searchsorted(cums, frac)
    # Aim is to select the strongest peaks
    cutoff = sortedpks[enough]
    mask = cor_intensity > cutoff
    if doplot is not None:
        nselected = mask.sum()
        fig, axs = plt.subplots(1, 2, figsize=(10, 5))
        axs[0].plot(cums, ',')
        axs[0].set(xlabel='npks', ylabel='fractional intensity')
        axs[0].plot([nselected], [frac], "o")
        axs[1].plot(cums, ',')
        axs[1].set(xlabel='npks logscale', ylabel='fractional intensity', xscale='log', ylim=(doplot, 1.),
                   xlim=(np.searchsorted(cums, doplot), len(cums)))
        axs[1].plot([nselected], [frac], "o")
        plt.show()
    return mask

def selectpeaks(cf, dstol=0.005, dsmax=100, frac=0.99):
    """
    Return a filtered copy of cf with only ring peaks that are among the strongest.

    Rings are generated from the unit cell in cf.parameters up to dsmax. A peak passes the
    first stage if |ds - ring_ds| < dstol for some ring with ring_ds < dsmax. The survivors
    are then reduced with strongest_peaks(frac=frac). That call always shows its diagnostic
    plot, using frac*0.5 as the lower limit of the log-scale panel.
    """
    unit_cell = ImageD11.unitcell.unitcell_from_parameters(cf.parameters)
    unit_cell.makerings(dsmax)
    near_ring = np.zeros(cf.nrows, bool)
    for ring_ds in (d for d in unit_cell.ringds if d < dsmax):
        near_ring |= (abs(cf.ds - ring_ds) < dstol)
    selected = cf.copy()
    selected.filter(near_ring)
    keep = strongest_peaks(selected, frac=frac, doplot=frac * 0.5)
    selected.filter(keep)
    return selected

In [None]:
# Filter cf_4d down to its strongest peaks for building sinograms.
# Compared with the indexing step, frac is a little lower and there is effectively no dsmax
# cut (dsmax is the largest ds present). Each grain therefore keeps many more peaks, which
# gives stronger sinograms. Peaks must still lie within dstol of a computed ring.

# USER: adjust "frac" below and re-run this cell until the orange dot sits on the "elbow"
# of the blue line; that point is the fractional-intensity cutoff that will be used.
# The lower y-limit of the log-scale plot is set inside selectpeaks (doplot = frac*0.5);
# change it there if the log-scale curve does not show a clear elbow.

cf_strong = selectpeaks(cf_4d, dsmax=cf_4d.ds.max(), frac=0.99)

In [None]:
def assign_peaks_to_grains( cf, grains, tol = 0.05 ):
    """
    Label each peak in cf with the index (in `grains`) of the grain that indexes it.

    First pass: ImageD11.cImageD11.score_and_assign is called once per grain with shared
    `labels` and `drlv2` (squared hkl error) arrays. Presumably it relabels a peak whenever
    this grain indexes it within tol with a smaller drlv2 than seen so far, and returns the
    count of peaks it matched -- TODO confirm against cImageD11. Peaks no grain claims
    stay at -1.

    The labels are stored in cf as the 'grain_id' column. Each grain also gets these attributes:
      gid, mask/m (bool mask of its peaks), pks (their row indices), npks,
      hkl (3 x npks integer hkl), etasigns (sign of eta for its peaks),
      allnpks / allmask (peaks it matches when scored on its own, ignoring other grains).
    For each grain it prints: index, npks, npks/allnpks.

    Returns (cf, grains); both are also modified in place.
    """
    # -1 marks "no grain"; 0 would be a valid grain index
    labels = np.full( cf.nrows, -1, dtype='i' )
    gv = np.transpose( (cf.gx,cf.gy,cf.gz)).astype( float )
    drlv2 =  np.ones( cf.nrows, 'd' )
    for i, g in enumerate(grains):
        ImageD11.cImageD11.score_and_assign( g.ubi, gv, tol, drlv2, labels, i )

    indx = np.arange(cf.nrows,dtype='i')
    li = labels.copy()
    cf.addcolumn(li, 'grain_id')
    for i, g in enumerate(grains):
        g.gid = i
        g.mask = g.m = (li == i)       # which peaks were assigned to this grain by ubi
        g.pks = indx[li==i]
        g.npks = len(g.pks)
        g.hkl = np.round (np.dot(g.ubi, gv[g.mask].T)).astype( int )
        g.etasigns = np.sign(cf.eta[ g.mask ]).astype(int)

        # how many peaks does this grain get out of the ones it would have taken on its own?
        labels.fill(-1)
        drlv2.fill(1)
        g.allnpks = ImageD11.cImageD11.score_and_assign( g.ubi, gv, tol, drlv2, labels, i)
        g.allmask = (labels == i)
        # a grain that matches nothing on its own would otherwise divide 0 by 0
        kept = g.npks / g.allnpks if g.allnpks else 0.0
        print(i,g.npks,"%.3f"%(kept))

    return cf, grains

In [None]:
cf_strong, grains = assign_peaks_to_grains(cf_strong, grains, tol=0.05)  # tol written out (it is the default)

In [None]:
# Peaks assigned to a grain (grain_id >= 0), coloured by grain ID:
# each grain should trace its own dty-vs-omega sinusoid
fig, ax = plt.subplots()
m = cf_strong.grain_id >= 0
ax.scatter(cf_strong.omega[m], cf_strong.dty[m], c=cf_strong.grain_id[m])
ax.set(xlabel="Omega", ylabel="Y translation (um)", title="Assigned peaks, coloured by grain ID")
plt.show()

In [None]:
# Mean unit cell length of every grain (g.a, the cube root of det(UBI))
mean_unit_cell_lengths = [g.a for g in grains]

fig, ax = plt.subplots()
ax.plot(mean_unit_cell_lengths)
ax.set(xlabel="Grain ID", ylabel="Unit cell length")
plt.show()

# Reference lattice parameter: median over all grains
a0 = np.median(mean_unit_cell_lengths)

print(a0)

In [None]:
# One small dty-vs-omega panel per grain, on a near-square grid
n_grains = len(grains)
grid_rows = int(np.ceil(np.sqrt(n_grains)))
grid_cols = (n_grains + grid_rows - 1) // grid_rows

# squeeze=False keeps `axs` a 2D array even for a 1x1 grid (a single grain)
fig, axs = plt.subplots(grid_rows, grid_cols, figsize=(10,10), layout="constrained",
                        sharex=True, sharey=True, squeeze=False)
# zip stops at the last grain, leaving any spare panels empty
for ax, g in zip(axs.ravel(), grains):
    m = cf_strong.grain_id == g.gid
    ax.plot(cf_strong.omega[m], cf_strong.dty[m], ',')

fig.supxlabel("Omega")
fig.supylabel("Y translation (um)")

plt.show()

In [None]:
def grain_to_rgb(g, ax=(0,0,1)):
    """IPF (r, g, b) colour of the sample direction `ax` in grain g's crystal frame, using cubic symmetry."""
    crystal_dir = crystal_direction_cubic(g.ubi, ax)
    return hkl_to_color_cubic(crystal_dir)

def crystal_direction_cubic(ubi, axis):
    """
    Express the sample direction `axis` in a grain's crystal frame and reduce it by cubic symmetry.

    The crystal-frame direction is hkl = UBI . axis. Under cubic symmetry, a direction is
    equivalent to any permutation of its components with signs flipped. The representative
    returned is therefore the absolute values sorted so that |h| <= |k| <= |l|.
    """
    return np.sort(np.abs(np.dot(ubi, axis)))

def hkl_to_color_cubic(hkl):
    """
    RGB inverse-pole-figure colour of a direction in the cubic standard triangle.

    Following
    https://mathematica.stackexchange.com/questions/47492/how-to-create-an-inverse-pole-figure-color-map
    write [x, y, z] = u*[0,0,1] + v*[0,1,1] + w*[1,1,1], so that
        u = z - y,  v = y - x,  w = x.
    (u, v, w) is scaled so its largest entry is 1. This maps [001] to red, [011] to green
    and [111] to blue.

    hkl must satisfy x <= y <= z and z >= 0 (asserted), e.g. the output of
    crystal_direction_cubic. Returns an (r, g, b) tuple.
    """
    x, y, z = hkl
    assert x <= y <= z
    assert z >= 0
    weights = (z - y, y - x, x)
    biggest = max(weights)
    return tuple(w / biggest for w in weights)

def hkl_to_pf_cubic(hkl):
    """
    Stereographic projection of a sorted cubic direction onto the z = 0 plane.

    hkl: numpy array (x, y, z) with x <= y <= z and z >= 0 (asserted).
    Returns (x / (z + |hkl|), y / (z + |hkl|)).
    """
    x, y, z = hkl
    assert x <= y <= z
    assert z >= 0
    denom = z + np.sqrt(np.sum(hkl * hkl))
    return x / denom, y / denom

def triangle():
    """Projected points (via hkl_to_pf_cubic) outlining the cubic standard triangle.

    Starts with the corners [011], [001], [111], then adds 5 points stepping from
    [111] towards [011] to trace the curved edge.
    """
    corners = [np.array(v) for v in ((0, 1, 1), (0, 0, 1), (1, 1, 1))]
    edge = [corners[2] * (1 - t) + corners[0] * t for t in np.linspace(0.1, 1, 5)]
    return np.array([hkl_to_pf_cubic(np.array(p)) for p in corners + edge])


def calcy(cos_omega, sin_omega, sol):
    """
    Evaluate the sinogram track dty = sol[0] + sol[1]*cos(omega) + sol[2]*sin(omega).

    sol is the (offset, cos coefficient, sin coefficient) triple returned by fity.
    """
    offset, c_coef, s_coef = sol[0], sol[1], sol[2]
    return offset + cos_omega * c_coef + sin_omega * s_coef

def fity(y, cos_omega, sin_omega, wt=1):
    """
    Fit a sinogram track to get a grain centroid.

    Model: y = d0 + x*cos(omega) + yc*sin(omega).

    Method: weighted linear least squares via the normal equations (A^T A) sol = A^T b.
    A's columns are wt*1, wt*cos_omega and wt*sin_omega, and b = wt*y. Because wt scales
    both A and b, each point effectively enters with weight wt**2. With 0/1 weights this
    is just a selection mask.

    Parameters
    ----------
    y : 1D array of dty values
    cos_omega, sin_omega : 1D arrays, same length as y
    wt : scalar or 1D array of per-point weights (default 1)

    Returns
    -------
    sol : array [d0, x, yc], usable directly by calcy

    Raises
    ------
    numpy.linalg.LinAlgError : if the normal matrix is singular (e.g. all weights zero)
    """
    # rows of the weighted design matrix, one per basis function
    g = np.array([wt * np.ones(y.shape, float), wt * cos_omega, wt * sin_omega])
    m = g @ g.T            # normal matrix, 3 x 3
    r = g @ (wt * y)       # right-hand side
    # solve rather than form an explicit inverse: cheaper and numerically better behaved
    return np.linalg.solve(m, r)


def fity_robust(dty, co, so, nsigma=5, doplot=False):
    """
    Robust sinusoid fit of dty against omega, iteratively rejecting outlier peaks.

    Starts from an unweighted fit (fity). Then, three times over:
      - compute the residuals;
      - set the threshold to nsigma * max(std of the currently selected residuals, 1.0);
      - keep peaks whose |residual| is below the threshold;
      - refit using only those peaks.
    The 1.0 floor (commented in the code as 1 micron) stops the threshold collapsing
    when the fit is very good.

    Parameters
    ----------
    dty : 1D array of dty values
    co, so : cos(omega), sin(omega), same length as dty
    nsigma : rejection threshold in units of the residual std
    doplot : if True, plot the fits and the residuals

    Returns
    -------
    (selected, cen, dx, dy) : bool mask of kept peaks and the final fit parameters,
        with dty = cen + dx*cos(omega) + dy*sin(omega) (see calcy)
    """
    cen, dx, dy = fity(dty, co, so)
    calc2 = calc1 = calcy(co, so, (cen, dx, dy))  # calc1: initial fit, kept for the plot
    selected = np.ones(co.shape, bool)
    for _ in range(3):
        err = dty - calc2
        estd = max( err[selected].std(), 1.0 ) # 1 micron
        es = estd*nsigma
        selected = abs(err) < es
        cen, dx, dy = fity( dty, co, so, selected.astype(float) )
        calc2 = calcy(co, so, (cen, dx, dy))
    # bad peaks are those more than nsigma (default 5) sigma away
    if doplot:
        f, a = plt.subplots(1,2)
        theta = np.arctan2( so, co )
        a[0].plot(theta, calc1, ',')
        a[0].plot(theta, calc2, ',')
        a[0].plot(theta[selected], dty[selected], "o")
        a[0].plot(theta[~selected], dty[~selected], 'x')
        a[1].plot(theta[selected], (calc2 - dty)[selected], 'o')
        a[1].plot(theta[~selected], (calc2 - dty)[~selected], 'x')
        a[1].set(ylim = (-es, es))
        plt.show()  # was `pl.show()`: `pl` is never defined, so doplot=True raised NameError
    return selected, cen, dx, dy

def graincen(gid, colf, doplot=True):
    """
    Robustly fit the dty-vs-omega sinusoid of one grain's peaks.

    Selects the rows of colf with grain_id == gid and passes dty, cos(omega) and sin(omega)
    (omega converted from degrees) to fity_robust.
    Returns (selected, cen, dx, dy) as produced by fity_robust.
    """
    in_grain = colf.grain_id == gid
    omega_rad = np.radians(colf.omega[in_grain])
    selected, cen, dx, dy = fity_robust(colf.dty[in_grain], np.cos(omega_rad), np.sin(omega_rad), doplot=doplot)
    return selected, cen, dx, dy


@numba.njit(parallel=True)
def pmax(ary):
    """ Find the min/max of an array in parallel

    Returns (min, max) -- note the order: minimum first, despite the name.
    ary must be non-empty: ary.flat[0] is read unconditionally.
    """
    mx = ary.flat[0]
    mn = ary.flat[0]
    # relies on numba treating these max()/min() updates inside prange as parallel reductions
    for i in numba.prange(1,ary.size):
        mx = max( ary.flat[i], mx )
        mn = min( ary.flat[i], mn )
    return mn, mx

@numba.njit(parallel=True)
def palloc(shape, dtype):
    """ Allocate and fill an array with zeros in parallel

    Gives the same result as np.zeros(shape, dtype), but the zero-fill loop runs under
    numba.prange. Callable from Python and from other numba-jitted functions
    (counting_sort uses it).
    """
    ary = np.empty(shape, dtype=dtype)
    for i in numba.prange( ary.size ):
        ary.flat[i] = 0
    return ary

# counting sort by grain_id
@numba.njit
def counting_sort(ary, maxval=None, minval=None):
    """ Counting sort for an integer array. O(n + (maxval - minval)).
    The sorting loops are single threaded; the pmax/palloc helpers it calls run in parallel.
    Numpy should be doing this...

    Returns (result, histogram):
      result    -- int64 permutation: ary[result] is ary sorted ascending. It is stable:
                   equal values keep their original relative order.
      histogram -- int64 counts, histogram[k] = number of elements equal to minval + k.

    Pass both maxval and minval, or neither. If neither is given, they are found with pmax.
    If they are given, every value must lie in [minval, maxval]; this is not checked, and
    numba does not bounds-check the indexing below.
    """
    if maxval is None:
        assert minval is None
        minval, maxval = pmax( ary ) # find with a first pass
    maxval = int(maxval)
    minval = int(minval)
    histogram = palloc( (maxval - minval + 1,), np.int64 )
    indices = palloc( (maxval - minval + 2,), np.int64 )
    result = palloc( ary.shape, np.int64 )
    # count occurrences of each value
    for gid in ary:
        histogram[gid - minval] += 1
    # indices[k] = first output slot for value minval + k (exclusive prefix sum)
    indices[0] = 0
    for i in range(len(histogram)):
        indices[ i + 1 ] = indices[i] + histogram[i]
    # scatter each position i into the next free slot for its value
    i = 0
    for gid in ary:
        j = gid - minval
        result[indices[j]] = i
        indices[j] += 1
        i += 1
    return result, histogram


@numba.njit(parallel=True)
def find_grain_id(spot3d_id, grain_id, spot2d_label, grain_label, order, nthreads=20):
    """
    Copy grain labels from the merged (3D) peaks onto the individual 2D peaks.

    spot3d_id    = the 3d spot labels of the merged, indexed peaks; must be ascending
    grain_id     = the grain assigned to each merged peak (same shape as spot3d_id)
    spot2d_label = the 3d label of each 2d peak
    grain_label  => output: grain of each 2d peak, -1 if its 3d label is not in spot3d_id
    order        = indices that visit spot2d_label in ascending order
    nthreads     = number of interleaved slices of `order`, processed in parallel

    Each slice order[tid::nthreads] is still in ascending label order, so a single forward
    scan through spot3d_id per slice finds the matches.
    """
    assert spot3d_id.shape == grain_id.shape
    assert spot2d_label.shape == grain_label.shape
    assert spot2d_label.shape == order.shape
    n3d = spot3d_id.shape[0]
    T = nthreads
    print("Using",T,"threads")
    for tid in numba.prange( T ):
        pcf = 0 # defined inside the prange body, so private to each iteration
        for i in order[tid::T]:
            grain_label[i] = -1
            pkid = spot2d_label[i]
            # Bounds check: spot3d_id may come from a filtered subset, so 2d labels larger
            # than its last entry must not walk off the end (numba does not bounds-check)
            while pcf < n3d and spot3d_id[pcf] < pkid:
                pcf += 1
            if pcf < n3d and spot3d_id[pcf] == pkid:
                grain_label[i] = grain_id[pcf]
                

def tocolf(pkd, parfile, dxfile=e2dx_path, dyfile=e2dy_path):
    """ Convert a dictionary of peak columns into an ImageD11 columnfile.

    The peaks are first passed through eiger_spatial, built from the dxfile/dyfile distortion
    maps (defaults are the paths set when this cell was run). Parameters are then loaded
    from parfile, and updateGeometry adds the geometric columns (tth, eta, g-vectors, etc).
    """
    corrector = eiger_spatial(dxfile=dxfile, dyfile=dyfile)
    corrected_peaks = corrector(pkd)
    colf = ImageD11.columnfile.colfile_from_dict(corrected_peaks)
    colf.parameters.loadparameters(parfile)
    colf.updateGeometry()
    return colf


def map_grain_from_peaks(g, flt, ds):
    """
    Compute the sinogram of one grain from its peaks (no reconstruction is done here).

    flt : columnfile of this grain's peaks only (uses dty, omega, sum_intensity)
    g   : grain with hkl (3 x npeaks integer array) and etasigns, both row-aligned with flt
    ds  : dataset; ds.ybincens gives the dty bin centres (assumed evenly spaced)

    Each distinct (h, k, l, eta>0) reflection becomes one sinogram row; eta == 0 is grouped
    with eta < 0. Each peak goes into its nearest dty bin, and its intensity, a hit count
    and its omega are accumulated there.

    Returns (sinoangles, ssino, hits):
      sinoangles -- mean omega of each reflection row, sorted ascending
      ssino      -- (NY, nrefl) sinogram, each reflection scaled to a maximum of 1,
                    columns in sinoangles order
      hits       -- (NY, nrefl) number of peaks accumulated in each cell, same column order
    """
    NY = len(ds.ybincens)
    # nearest dty bin of each peak
    # NOTE(review): peaks outside the bin range would give iy outside [0, NY); not checked here
    iy = np.round( (flt.dty - ds.ybincens[0]) / (ds.ybincens[1]-ds.ybincens[0]) ).astype(int)

    # The problem is to assign each spot to a place in the sinogram
    hklmin = g.hkl.min(axis=1)
    dh = g.hkl - hklmin[:,np.newaxis]
    de = (g.etasigns.astype(int) + 1)//2
    #   4D array of h,k,l,+/-
    pkmsk = np.zeros( list(dh.max(axis=1) + 1 )+[2,], int )
    pkmsk[ dh[0], dh[1], dh[2], de ] = 1
    #   sinogram row to hit
    pkrow = np.cumsum( pkmsk.ravel() ).reshape( pkmsk.shape ) - 1
    npks = pkmsk.sum( )
    destRow = pkrow[ dh[0], dh[1], dh[2], de ]
    sino = np.zeros( ( npks, NY ), 'f' )
    hits = np.zeros( ( npks, NY ), 'f' )
    angs = np.zeros( ( npks, NY ), 'f' )
    adr = destRow * NY + iy
    # Just accumulate
    sig = flt.sum_intensity
    ImageD11.cImageD11.put_incr64( sino, adr, sig )
    ImageD11.cImageD11.put_incr64( hits, adr, np.ones(len(de),dtype='f'))
    ImageD11.cImageD11.put_incr64( angs, adr, flt.omega)

    # every row received at least one peak by construction, so the denominator is non-zero
    sinoangles = angs.sum( axis = 1) / hits.sum( axis = 1 )
    # Normalise:
    sino = (sino.T/sino.max( axis=1 )).T
    # Sort (cosmetic):
    order = np.lexsort( (np.arange(npks), sinoangles) )
    sinoangles = sinoangles[order]
    ssino = sino[order].T
    return sinoangles, ssino, hits[order].T

In [None]:
# For every grain:
#  - robust dty-vs-omega fit: pks3d is the mask of peaks kept; cen, dx, dy are the fit parameters
#  - IPF colours for the sample z, y and x directions
# (`g`, not `grain`, so the imported ImageD11 `grain` class is not shadowed)
for g in grains:
    g.pks3d, g.cen, g.dx, g.dy = graincen(g.gid, cf_strong, doplot=False)
    g.rgb_z = grain_to_rgb(g, ax=(0,0,1))
    g.rgb_y = grain_to_rgb(g, ax=(0,1,0))
    g.rgb_x = grain_to_rgb(g, ax=(1,0,0))

In [None]:
# Check the per-grain fitted offsets (g.cen): they should cluster around a single value,
# the centre of rotation in dty (expected near the middle of the dty range)
fig, ax = plt.subplots()
ax.plot([g.cen for g in grains])
ax.set(xlabel="Grain ID", ylabel="Fitted centre in dty", title="Per-grain sinusoid offset")
plt.show()

In [None]:
# Centre of rotation in dty: the median of the per-grain fitted offsets,
# which is insensitive to a few badly fitted grains
grain_centres = [g.cen for g in grains]
c0 = np.median(grain_centres)

print('Center of rotation in dty', c0)

# Sanity check against the per-grain dty-vs-omega plots above:
# a wrong centre would show up as sinusoids shifted away from it.

In [None]:
# Fitted grain positions (dx, dy). Marker size is proportional to the number of peaks kept
# by the robust fit (/10); panels 1-3 are coloured by IPF colour along Z, Y and X.
plt.style.use('dark_background')
fig, ax = plt.subplots(2, 2, figsize=(12, 12))
a = ax.ravel()
x = [g.dx for g in grains]
y = [g.dy for g in grains]
s = [g.pks3d.sum() / 10 for g in grains]

ipf_panels = (('rgb_z', 'IPF color Z'), ('rgb_y', 'IPF color Y'), ('rgb_x', 'IPF color X'))
for panel, (colour_attr, panel_title) in zip(a[:3], ipf_panels):
    panel.scatter(x, y, s=s, c=[getattr(g, colour_attr) for g in grains])
    panel.set(title=panel_title, aspect='equal')

a[3].scatter(x, y, c=s)
a[3].set(title='Number of 3d peaks', aspect='equal')

fig.supxlabel("Lab x")
fig.supylabel("Lab y")

plt.show()

In [None]:
# Big scary block: label every 2D peak (one per detector frame) with the grain that its merged
# peak in cf_strong was assigned to, then group the 2D peaks by grain so that each grain's
# sinogram can be built from them.

# Ensure cf is sorted by spot3d_id
# (find_grain_id walks cf_strong.spot3d_id forwards only, so it must be ascending)
# NOTE: spot3d_id should be spot4d_id, because we have merged into 4D?
assert (np.argsort(cf_strong.spot3d_id) == np.arange(cf_strong.nrows)).all()

# load the 2d peak labelling output
pks = ImageD11.sinograms.properties.pks_table.load(pks_path)

# Load the dataset (for motor positions, not sure why these are not in peaks)
ds = ImageD11.sinograms.dataset.load(dset_path)

# Grab the 2d peak centroids
p2d = pks.pk2d(ds.omega, ds.dty)

# NOTE: These are not spatially corrected?!
# (the spatial correction is applied later, per grain, when do_sinos converts them with tocolf)

# numba_order: indices that visit the 2D peaks in ascending spot3d_id order
numba_order, numba_histo = counting_sort(p2d['spot3d_id'])

# output buffer for find_grain_id: grain label of every 2D peak (-1 = no grain)
grain_2d_id = palloc(p2d['spot3d_id'].shape, np.dtype(int))

cleanid = cf_strong.grain_id.copy()

find_grain_id(cf_strong.spot3d_id, cleanid, p2d['spot3d_id'], grain_2d_id, numba_order)

# gord: 2D peak indices sorted by grain label
# counts: number of 2D peaks per label, starting from the smallest label present
#         (normally -1, i.e. not assigned to any grain)
gord, counts = counting_sort(grain_2d_id)

inds = np.concatenate(((0,), np.cumsum(counts)))

# inds holds offsets into gord: the 2D peaks carrying the k-th label (smallest label + k)
# are gord[inds[k]:inds[k+1]]. So, when the smallest label is -1, grain i's 2D peaks are
# gord[inds[i+1]:inds[i+2]]. Within each group the peaks keep their original order.

In [None]:
def do_sinos(g, hkltol=0.250):
    """
    Build the sinogram of one grain from its 2D peaks.

    Uses globals from earlier cells: gord, inds and grain_2d_id (2D peaks grouped by grain
    label), p2d (2D peak columns), par_path and ds.

    Steps:
      1. gather the 2D peaks labelled with this grain;
      2. convert them to a spatially corrected columnfile (tocolf);
      3. drop peaks whose hkl lies hkltol or further from the nearest integer hkl;
      4. build the sinogram with map_grain_from_peaks.

    Sets on g: pks, dherrall, npksall (before the hkl filter); dherr, npks, etasigns, hkl
    (after it); and sinoangles, ssino, hits.
    Returns (gid, g).
    """
    i = g.gid
    # gord/inds come from counting_sort(grain_2d_id), whose groups start at the smallest
    # label present. That is normally -1 (unindexed peaks), but look it up instead of
    # assuming it: gord[0] is a peak carrying the smallest label.
    slot = i - grain_2d_id[gord[0]]
    g.pks = gord[inds[slot] : inds[slot + 1]]
    assert grain_2d_id[g.pks[0]] == i
    flt = tocolf( {p:p2d[p][g.pks] for p in p2d} , par_path)

    def hkl_error(colf):
        """Nearest integer hkl of each peak, and its squared distance from the computed hkl."""
        hkl_real = np.dot(g.ubi, (colf.gx, colf.gy, colf.gz))
        hkl_int = np.round(hkl_real).astype(int)
        return hkl_int, ((hkl_real - hkl_int)**2).sum(axis = 0)

    _, dh = hkl_error(flt)
    assert len(dh) == flt.nrows
    g.dherrall = dh.mean()
    g.npksall = flt.nrows
    # keep peaks whose hkl lies within hkltol of an integer hkl (dh is a squared distance)
    flt.filter( dh < hkltol*hkltol )
    hkl_int, dh = hkl_error(flt)
    g.dherr = dh.mean()
    g.npks = flt.nrows
    g.etasigns = np.sign(flt.eta)
    g.hkl = hkl_int
    g.sinoangles, g.ssino, g.hits = map_grain_from_peaks(g, flt, ds)
    return i,g

In [None]:
# Determine sinograms of all grains (do_sinos stores its results on each grain object)

nthreads = len(os.sched_getaffinity(os.getpid()))  # number of CPUs this process may run on
n_workers = max(1, nthreads - 1)  # leave one CPU free

with concurrent.futures.ThreadPoolExecutor(max_workers=n_workers) as pool:
    for _ in tqdm.tqdm(pool.map(do_sinos, grains), total=len(grains)):
        pass

In [None]:
# Show sinogram of single grain
# g.ssino rows are dty bins and columns are the grain's reflections ordered by mean omega.
# Dividing by each column's mean puts all reflections on a comparable log colour scale.

g = grains[0]

fig, ax = plt.subplots()

ax.imshow((g.ssino/g.ssino.mean(axis=0)), norm=matplotlib.colors.LogNorm(), interpolation='nearest', origin="lower")
ax.set(xlabel="Reflection (sorted by omega)", ylabel="dty bin", title=f"Sinogram of grain {g.gid}")

plt.show()

In [None]:
def run_iradon_id11(grain, pad=20, y0=c0/2):
    """
    Reconstruct a grain's shape from its sinogram with ImageD11's iradon (Hamming filter).

    grain : object with ssino (rows: dty bins, columns: projections) and sinoangles
    pad   : extra pixels added to the output size (output size = ssino.shape[0] + pad)
    y0    : every projection is shifted by -y0. The default, c0/2, is evaluated once,
            when this function is defined, from the centre-of-rotation cell.

    The reconstruction is stored on grain.recon; the grain is returned.
    """
    shifts = np.full(grain.ssino.shape, -y0)
    grain.recon = ImageD11.sinograms.roi_iradon.iradon(
        grain.ssino,
        theta=grain.sinoangles,
        output_size=grain.ssino.shape[0] + pad,
        projection_shifts=shifts,
        filter_name='hamming',
    )
    return grain

In [None]:
# Reconstruct one grain first as a quick check before running all of them.
# The trailing ";" stops Jupyter from displaying the returned grain object's repr.
g = grains[0]

run_iradon_id11(g);

In [None]:
# Side by side: the sinogram of grain g and its iradon reconstruction
fig, (ax_sino, ax_recon) = plt.subplots(1, 2)
ax_sino.imshow(g.ssino)
ax_sino.set_title("ID11 Sinogram")
ax_recon.imshow(g.recon, origin="lower", vmin=0)
ax_recon.set_title("ID11 iradon")
plt.show()

In [None]:
# Reconstruct every grain in parallel; run_iradon_id11 stores each result on grain.recon
n_workers = max(1, nthreads - 1)
with concurrent.futures.ThreadPoolExecutor(max_workers=n_workers) as pool:
    for _ in tqdm.tqdm(pool.map(run_iradon_id11, grains), total=len(grains)):
        pass

In [None]:
# Inverse pole figures (cubic standard triangle) for the sample x, y and z directions.
# Side effect: grains[j].rgb is overwritten for each direction, so it ends up holding the z colour.
f, a = plt.subplots(1, 3, figsize=(15, 5))
ty, tx = triangle().T
for i, title in enumerate('xyz'):
    direction = np.zeros(3)   # unit vector along the sample axis `title`
    direction[i] = 1.
    hkl = [crystal_direction_cubic(g.ubi, direction) for g in grains]
    xy = np.array([hkl_to_pf_cubic(h) for h in hkl])
    rgb = np.array([hkl_to_color_cubic(h) for h in hkl])
    for j in range(len(grains)):
        grains[j].rgb = rgb[j]
    # Note the "x" axis of the plot is the 'k' direction and 'y' is h (smaller)
    a[i].scatter(xy[:, 1], xy[:, 0], c=rgb)
    a[i].set(title=title, aspect='equal', facecolor='silver', xticks=[], yticks=[])
    a[i].plot(tx, ty, 'k-', lw=1)
plt.show()

In [None]:
template = grains[0].recon

# Per-pixel outputs:
#   l                -- grain ID owning the pixel (-1 = none)
#   red, grn, blu    -- IPF-Z colour weighted by the grain's smoothed intensity
#   redl, grnl, blul -- flat IPF-Z colour of the owning grain
l = np.zeros_like(template) - 1
red, grn, blu = (np.zeros_like(template) for _ in range(3))
redl, grnl, blul = (np.zeros_like(template) for _ in range(3))

# Best smoothed intensity seen so far per pixel; a grain must beat 0.4 to claim a pixel
s = np.zeros_like(template)
s.fill(0.4)


def norm(r):
    """Divide r by the mean of its pixels above 20% of its maximum, then clip to [0, 1]."""
    bright = r > r.max() * 0.2
    return (r / r[bright].mean()).clip(0, 1)


def smoothclamp(x, mi, mx):
    """Smoothstep 3t^2 - 2t^3 of t = (x - mi)/(mx - mi), taken as 0 below mi and 1 above mx."""
    t = (x - mi) / (mx - mi)
    ramp = np.where(t < 0, 0, np.where(t <= 1, 3*t**2 - 2*t**3, 1))
    return mi + (mx - mi) * ramp


# Each pixel goes to whichever grain has the strongest smoothed, normalised reconstruction there
for g in grains:
    strength = smoothclamp(norm(g.recon), 0, 1)
    wins = strength > s
    px = strength[wins]
    s[wins] = px
    for weighted, flat, component in zip((red, grn, blu), (redl, grnl, blul), g.rgb_z):
        weighted[wins] = px * component
        flat[wins] = component
    l[wins] = g.gid

In [None]:
# Stack the three colour channels into a (rows, cols, 3) RGB image
image_to_show = np.dstack((red, grn, blu))
fig, ax = plt.subplots(constrained_layout=True)
ax.imshow(image_to_show)
plt.show()