# Jupyter notebook based on ImageD11 to process scanning 3DXRD data
# Written by Haixing Fang, Jon Wright and James Ball
## Date: 21/02/2024

In [None]:
# The ImageD11 release installed in the site-wide Jupyter environment currently has a bug.
# It is fixed by this commit:
# https://github.com/FABLE-3DXRD/ImageD11/commit/4af88b886b1775585e868f2339a0eb975401468f
# Until a release containing the fix reaches the environment, get the latest ImageD11
# from GitHub and put it in your home directory.
# USER: Change the path below to point to your local copy of ImageD11:

import os
import sys

username = os.environ.get("USER")
id11_code_path = f"/home/esrf/{username}/Code/ImageD11"

# the local checkout goes first on the import path so it is found before the site-wide install
sys.path.insert(0, id11_code_path)

In [None]:
# import functions we need

import concurrent.futures
import timeit
import glob
import pprint
from shutil import rmtree
import time

import matplotlib
%matplotlib ipympl

import h5py
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt

import ImageD11.columnfile
from ImageD11.sinograms import properties, roi_iradon
from ImageD11.blobcorrector import eiger_spatial
from ImageD11.grain import grain

from skimage.filters import threshold_otsu
from skimage.morphology import convex_hull_image

import ImageD11.nbGui.nb_utils as utils

import ipywidgets as widgets
from ipywidgets import interact

In [None]:
# NOTE: Older datasets (recorded before the new directory layout) do not have separate
# RAW_DATA and PROCESSED_DATA folders.
# In that case, use this cell to specify where your experimental folder is, and do not run the cell below
# e.g /data/visitor/ma4752/id11/20210513

### USER: specify your experimental directory

rawdata_path = "/home/esrf/james1997a/Data/ihma439/id11/20231211/RAW_DATA"

# list the directory contents (sorted by time, oldest first) as a quick check that the path is right
!ls -lrt {rawdata_path}

### USER: specify where you want your processed data to go

processed_data_root_dir = "/home/esrf/james1997a/Data/ihma439/id11/20231211/PROCESSED_DATA/James/20240221"

In [None]:
# USER: pick the sample and the dataset you want to process in this notebook
# (both names are used below to build the path of the dataset H5 file)

sample = "FeAu_0p5_tR_nscope"
dataset = "top_100um"

In [None]:
# location of the dataset H5 file inside the processed data folder

dset_path = os.path.join(processed_data_root_dir, sample, f"{sample}_{dataset}", f"{sample}_{dataset}_dataset.h5")

# parameter file loaded into the 4D peaks columnfile and passed to utils.tocolf below
par_path = os.path.join(processed_data_root_dir, 'Fe_refined.par')

# detector spatial distortion files (dx, dy), passed to utils.tocolf when 2D peaks are corrected
# NOTE: these are found relative to processed_data_root_dir (two levels up, then CeO2/)
e2dx_path = os.path.join(processed_data_root_dir, '../../CeO2/e2dx_E-08-0173_20231127.edf')
e2dy_path = os.path.join(processed_data_root_dir, '../../CeO2/e2dy_E-08-0173_20231127.edf')

In [None]:
# Load the dataset object from the H5 file at dset_path.
# It is needed for the motor positions (ds.omega, ds.dty) and the bin definitions used below,
# which are not available from the peaks columnfile alone.
# NOTE(review): relies on ImageD11.sinograms.dataset having been imported as a side effect of the imports cell - confirm
ds = ImageD11.sinograms.dataset.load(dset_path)

In [None]:
# Import 4D peaks from the columnfile that the dataset points to

cf_4d = ImageD11.columnfile.columnfile(ds.col4dfile)

# attach the parameters from par_path, then recompute the geometry-dependent columns with them
cf_4d.parameters.loadparameters(par_path)
cf_4d.updateGeometry()

print(f"Read {cf_4d.nrows} 4D peaks")

In [None]:
def read_grains(ds):
    """Load every grain stored in the grains file of a dataset.

    Opens ds.grainsfile read-only and walks the 'grains' group in ascending
    numeric order of the entry names. Each entry's 'ubi' attribute is turned
    into an ImageD11 grain, and the entry name is stored on it as the integer
    attribute gid.

    Returns a list of grain objects ordered by gid.
    """
    loaded = []
    with h5py.File(ds.grainsfile, 'r') as hin:
        all_grains = hin['grains']
        # entry names are integers stored as text, so sort them numerically
        for name in tqdm(sorted(all_grains.keys(), key=int)):
            this_grain = ImageD11.grain.grain(all_grains[name].attrs['ubi'][:])
            this_grain.gid = int(name)
            loaded.append(this_grain)

    return loaded

In [None]:
grains = read_grains(ds)

# NOTE: this loop variable rebinds the name `grain`, which the imports cell bound to the ImageD11 grain class
for grain in grains:
    # print(grain.gid)
    # cube root of det(ubi); plotted further down as the "unit cell length" of each grain
    grain.a = np.cbrt(np.linalg.det(grain.ubi))
    
print(f"{len(grains)} grains imported")

In [None]:
# Here we filter our peaks (cf_4d) to select only the strongest ones.
# Unlike in the indexing step, dsmax is set to the largest ds present, so nothing is removed by the dstar cut.
# This means many more peaks per grain = stronger sinograms.

# USER: modify the "frac" parameter below and re-run the cell until the orange dot sits nicely on the "elbow" of the blue line
# this indicates the fractional intensity cutoff we will select
# if the blue line does not look elbow-shaped in the logscale plot, try changing the "doplot" parameter (the y scale of the logscale plot) until it does

cf_strong = utils.selectpeaks(cf_4d, frac=0.995, dsmax=cf_4d.ds.max(), doplot=0.9)
# number of peaks before filtering (printed) and after filtering (shown as the cell output)
print(cf_4d.nrows)
cf_strong.nrows

In [None]:
# Now let's do a whole-sample tomographic reconstruction

In [None]:
# If the sinograms are only half-sinograms (we scanned dty across half the sample rather than the full sample), set the below to True.
# This flag also switches on the half-mask and central-zinger options of the reconstructions further down.
is_half_scan = False

In [None]:
# only half-scan datasets are passed through utils.correct_half_scan
if is_half_scan:
    utils.correct_half_scan(ds)

In [None]:
# label the strong 4D peaks with the grain they belong to (stored in cf_strong.grain_id)
utils.assign_peaks_to_grains(grains, cf_strong, tol=0.25)

print("Storing peak data in grains")
for g in tqdm(grains):
    # boolean selection of the strong 4D peaks labelled with this grain's id
    g.mask_4d = cf_strong.grain_id == g.gid
    # the matching peak indices, so we know which 4D peaks this grain owns
    g.peaks_4d = cf_strong.index[g.mask_4d]

In [None]:
# per grain: fit the centre from the strong peaks, then colour it along the three lab axes
for grain in tqdm(grains):
    grain.peaks_4d_selected, grain.cen, grain.dx, grain.dy = utils.graincen(grain.gid, cf_strong, doplot=False)
    # same order as before: z, then y, then x
    for axis_name, axis_vector in (("z", (0,0,1)), ("y", (0,1,0)), ("x", (1,0,0))):
        setattr(grain, "rgb_" + axis_name, utils.grain_to_rgb(grain, ax=axis_vector,))

In [None]:
# make sure we get cen right (centre of rotation should be the middle of dty)
# each point is the fitted cen of one grain, in list order; they should scatter around one common value
fig, ax = plt.subplots()
ax.plot([g.cen for g in grains])

plt.show()

In [None]:
# take the median of the per-grain centres as the centre of rotation; it is used for the
# projection shifts of the reconstructions below
c0 = np.median([g.cen for g in grains])

print('Center of rotation in dty', c0)

In [None]:
# generate sinogram for whole sample
# 2D histogram of all 4D peaks over (dty, omega); only every second omega bin edge is used,
# so the omega bins here are twice as wide as the dataset's own

whole_sample_sino, xedges, yedges = np.histogram2d(cf_4d.dty, cf_4d.omega, bins=[ds.ybinedges, ds.obinedges[::2]])

fig, ax = plt.subplots()
ax.imshow(whole_sample_sino, interpolation="nearest", vmin=0)
ax.set_aspect(1)
plt.show()

In [None]:
# "quick" MLEM reconstruction of the whole-sample sinogram

# extra pixels added to the reconstruction size; also reused for the per-grain reconstructions below
pad = 50

outsize = whole_sample_sino.shape[0] + pad

# number of CPUs this process is allowed to run on
nthreads = len(os.sched_getaffinity(os.getpid()))

if is_half_scan:
    # keep the first half of the dty rows, weight the row at the join by 0.5, zero the rest
    halfmask = np.zeros_like(whole_sample_sino)

    halfmask[:len(halfmask)//2-1, :] = 1
    halfmask[len(halfmask)//2-1, :] = 0.5

    ssino_to_recon = whole_sample_sino * halfmask
else:
    ssino_to_recon = whole_sample_sino

# Leave one CPU free, but never ask for fewer than one worker: on a single-CPU
# allocation nthreads - 1 would be 0. This matches the clamp used by the
# thread pools later in the notebook.
recon = ImageD11.sinograms.roi_iradon.mlem(ssino_to_recon, 
                                           theta=ds.obincens[::2],
                                           workers=max(1, nthreads - 1),
                                           output_size=outsize,
                                           projection_shifts=np.full(ssino_to_recon.shape, -c0/2),
                                           niter=30)

In [None]:
# we should be able to easily segment this using scikit-image
# work on a copy so that the MLEM result in `recon` is left untouched
recon_man_mask = recon.copy()

# we can incorporate our own manual mask too, by modifying the slice below
# without a mask, MLEM can introduce artifacts in the corners
# so we can manually mask those out
# NOTE: this line is active, and the hard-coded 280 presumably suits the example dataset only -
# USER: adjust it (or comment it out) for your own reconstruction size

recon_man_mask[280:, 280:] = 0

# automatic threshold from Otsu's method
thresh = threshold_otsu(recon_man_mask)

# we can also override the threshold if we don't like it:

# thresh = 0.025

binary = recon_man_mask > thresh

# convex hull of everything above threshold; this becomes the whole-sample mask
chull = convex_hull_image(binary)

fig, axs = plt.subplots(1, 3, sharex=True, sharey=True, constrained_layout=True)
axs[0].imshow(recon_man_mask, vmin=0)
axs[1].imshow(binary)
axs[2].imshow(chull)

axs[0].set_title("MLEM reconstruction")
axs[1].set_title("Binarised threshold")
axs[2].set_title("Convex hull")

plt.show()

In [None]:
# the convex hull is the sample mask used for the grain reconstructions and saved with each grain
whole_sample_mask = chull

In [None]:
# scatter of the strong peaks in (omega, dty), coloured by grain id
# peaks with a negative grain_id are left out (presumably these are the unassigned ones)
fig, ax = plt.subplots()
m = cf_strong.grain_id >= 0
ax.scatter(cf_strong.omega[m], cf_strong.dty[m], c=cf_strong.grain_id[m])
plt.show()

In [None]:
# grain.a was set to cbrt(det(ubi)) when the grains were imported
mean_unit_cell_lengths = [grain.a for grain in grains]

fig, ax = plt.subplots()
ax.plot(mean_unit_cell_lengths)
# NOTE: the x axis is the position in the grains list, which equals the gid only if the gids are 0..N-1
ax.set_xlabel("Grain ID")
ax.set_ylabel("Unit cell length")
plt.show()

# median over all grains
a0 = np.median(mean_unit_cell_lengths)
    
print(a0)

In [None]:
# visual check of the per-grain sinograms built from the strong 4D peaks
utils.plot_grain_sinograms(grains, cf_strong)

In [None]:
# plt.style.use('dark_background')
# Grain positions (dx, dy) in a 2x2 figure: three panels coloured by the IPF colour
# along Z, Y and X, and one coloured by the marker-size value.
fig, ax = plt.subplots(2,2, figsize=(12,12))
a = ax.ravel()
x = [g.dx for g in grains]
y = [g.dy for g in grains]
# marker size: sum of each grain's selected-peak mask, scaled down by 10
s = [g.peaks_4d_selected.sum()/10 for g in grains]

ipf_panels = (
    (a[0], "rgb_z", dict(title='IPF color Z',  aspect='equal')),
    (a[1], "rgb_y", dict(title='IPF color Y', aspect='equal')),
    (a[2], "rgb_x", dict(title='IPF color X',  aspect='equal')),
)
for panel, colour_attr, panel_settings in ipf_panels:
    panel.scatter(x, y, s=s, c=[getattr(g, colour_attr) for g in grains])
    panel.set(**panel_settings)

a[3].scatter(x, y, c=s)
a[3].set(title='Number of 4D peaks', aspect='equal')

fig.supxlabel("Lab x")
fig.supylabel("Lab y")

plt.show()

In [None]:
# populate translations of grains
# the fitted (dx, dy) give the in-plane position; the third component is set to 0
for g in grains:
    g.translation = np.array([g.dx, g.dy, 0])

In [None]:
# Big scary block
# Must understand what this does!
# Goal: pass the grain label of each strong 4D peak down to the 2D peaks it was merged from,
# then group the 2D peaks by grain.

# Ensure cf is sorted by spot3d_id
# NOTE: spot3d_id should be spot4d_id, because we have merged into 4D?
assert (np.argsort(cf_strong.spot3d_id) == np.arange(cf_strong.nrows)).all()

# load the 2d peak labelling output
pks = ImageD11.sinograms.properties.pks_table.load(ds.pksfile)

# Grab the 2d peak centroids
p2d = pks.pk2d(ds.omega, ds.dty)

# NOTE: These are not spatially corrected?!
# (do_sinos applies the spatial correction later, when it converts them with utils.tocolf)

# ordering of the 2D peaks by the spot3d_id they belong to (numba_histo is not used again)
numba_order, numba_histo = utils.counting_sort(p2d['spot3d_id'])

# one integer grain label per 2D peak, filled in by find_grain_id below
grain_2d_id = utils.palloc(p2d['spot3d_id'].shape, np.dtype(int))

# pass a copy of the grain labels so that cf_strong.grain_id itself is not handed to find_grain_id
cleanid = cf_strong.grain_id.copy()

# presumably writes the grain label of each 2D peak into grain_2d_id - TODO confirm against nb_utils
utils.find_grain_id(cf_strong.spot3d_id, cleanid, p2d['spot3d_id'], grain_2d_id, numba_order)

# ordering of the 2D peaks by grain label, plus the counts returned by the sort
gord, counts = utils.counting_sort(grain_2d_id)

# cumulative counts with a leading 0: boundaries of each label's section inside gord
inds = np.concatenate(((0,), np.cumsum(counts)))

# I think what we end up with is:
# gord: the 2D peak indices, grouped by grain label
# inds: for each grain, where its associated 2D peaks can be found inside gord
# NOTE: the next cell reads grain gid's peaks as gord[inds[gid+1] : inds[gid+2]], i.e. the slices
# are offset by one (presumably because unassigned peaks take up the leading section - TODO confirm)

In [None]:
# now our 2D peak assignments are known, let's populate our grain objects with our 2D peaks
# NOTE: the slice is looked up by grain.gid, not by the position of the grain in the list
# NOTE: `i` is left behind as a notebook global; do_sinos below reads it in its return statement

for grain in tqdm(grains):
    i = grain.gid
    # section of gord holding this grain's 2D peak indices (offset by one, see inds above)
    grain.peaks_2d = gord[inds[i+1] : inds[i+2]]
    # grain.mask_2d = np.isin(cf_2d.index, grain.peaks_2d)

In [None]:
def map_grain_from_peaks(g, flt, ds):
    """
    Computes the sinogram of one grain from its 2D peaks.

    g   : grain; g.hkl_2d_strong (3 x npeaks integer hkl) and g.etasigns_2d_strong must already be set
    flt : columnfile holding only the peaks of this grain (dty, omega, sum_intensity are read)
    ds  : dataset, used for the dty bin centres (ds.ybincens)

    Each distinct (h, k, l, sign of eta) combination becomes one sinogram row.

    Returns sinoangles, ssino, hits
      sinoangles : mean omega of each row, sorted ascending
      ssino      : row-normalised intensities, transposed to (dty bin, row) and sorted like sinoangles
      hits       : number of peaks accumulated into each cell, same layout as ssino
    """   
    NY = len(ds.ybincens)  # number of y translations
    # dty bin index of each peak, assuming evenly spaced bin centres
    iy = np.round((flt.dty - ds.ybincens[0]) / (ds.ybincens[1]-ds.ybincens[0])).astype(int)  # flt column for y translation index

    # The problem is to assign each spot to a place in the sinogram
    hklmin = g.hkl_2d_strong.min(axis=1)  # Get minimum integer hkl (e.g -10, -9, -10)
    dh = g.hkl_2d_strong - hklmin[:,np.newaxis]  # subtract minimum hkl from all integer hkls, so they start at 0
    de = (g.etasigns_2d_strong.astype(int) + 1)//2  # eta sign -1 -> 0, +1 -> 1 (a sign of 0 also gives 0)
    #   4D array of h,k,l,+/-
    # pkmsk is whether a peak has been observed with this HKL or not
    pkmsk = np.zeros(list(dh.max(axis=1) + 1 )+[2,], int)  # make zeros-array the size of (max dh +1) and add another axis of length 2
    pkmsk[ dh[0], dh[1], dh[2], de ] = 1  # we found these HKLs for this grain
    #   sinogram row to hit
    # running count of observed entries minus one: each observed (h,k,l,sign) gets its own row number
    pkrow = np.cumsum(pkmsk.ravel()).reshape(pkmsk.shape) - 1  #
    # counting where we hit an HKL position with a found peak
    # e.g (-10, -9, -10) didn't get hit, but the next one did, so increment

    npks = pkmsk.sum( )  # number of distinct (h,k,l,sign) entries = number of sinogram rows
    destRow = pkrow[ dh[0], dh[1], dh[2], de ]  # sinogram row of every peak
    sino = np.zeros( ( npks, NY ), 'f' )
    hits = np.zeros( ( npks, NY ), 'f' )
    angs = np.zeros( ( npks, NY ), 'f' )
    adr = destRow * NY + iy  # flat address of (row, dty bin) in the (npks, NY) arrays
    # Just accumulate 
    # presumably put_incr64 adds each value into the flat address given by adr - TODO confirm
    sig = flt.sum_intensity
    ImageD11.cImageD11.put_incr64( sino, adr, sig )
    ImageD11.cImageD11.put_incr64( hits, adr, np.ones(len(de),dtype='f'))
    ImageD11.cImageD11.put_incr64( angs, adr, flt.omega)
    
    # mean omega per row: summed omega divided by number of hits
    sinoangles = angs.sum( axis = 1) / hits.sum( axis = 1 )
    # Normalise: each row is divided by its own maximum
    sino = (sino.T/sino.max( axis=1 )).T
    # Sort (cosmetic): by angle, ties broken by the original row number
    order = np.lexsort((np.arange(npks), sinoangles))
    sinoangles = sinoangles[order]
    ssino = sino[order].T
    return sinoangles, ssino, hits[order].T

def do_sinos(g, hkltol=0.25):
    """Build the 2D-peak sinogram of one grain and store it on the grain.

    Reads the notebook globals p2d, par_path, e2dx_path, e2dy_path and ds.

    g      : grain with g.peaks_2d (indices into the p2d columns), g.ubi and g.gid set
    hkltol : a peak is kept only if the squared distance between its computed
             hkl and the nearest integer hkl is below hkltol**2

    Sets g.etasigns_2d_strong, g.hkl_2d_strong, g.sinoangles, g.ssino and g.hits.

    Returns (g.gid, g). The first element used to be the notebook global `i`,
    which is unrelated to the grain being processed and changes while the
    thread pool is running; the grain's own id is returned instead.
    """
    # this grain's 2D peaks as a columnfile, spatially corrected with the dx/dy files
    flt = utils.tocolf({p:p2d[p][g.peaks_2d] for p in p2d}, par_path, dxfile=e2dx_path, dyfile=e2dy_path)

    hkl_real = np.dot(g.ubi, (flt.gx, flt.gy, flt.gz))  # hkl of all assigned peaks
    hkl_int = np.round(hkl_real).astype(int)  # nearest integer hkl
    dh = ((hkl_real - hkl_int)**2).sum(axis = 0)  # squared distance to the integer hkl

    # keep only peaks that index within tolerance
    flt.filter(dh < hkltol*hkltol)

    # recompute the integer hkl for the peaks that survived the filter
    hkl_real = np.dot(g.ubi, (flt.gx, flt.gy, flt.gz))
    hkl_int = np.round(hkl_real).astype(int)

    g.etasigns_2d_strong = np.sign(flt.eta)
    g.hkl_2d_strong = hkl_int  # integer hkl of assigned peaks after hkltol filtering
    g.sinoangles, g.ssino, g.hits = map_grain_from_peaks(g, flt, ds)
    return g.gid, g

In [None]:
# Determine sinograms of all grains
# do_sinos stores its results on each grain object, so the values coming back from the pool are not used

# number of CPUs this process is allowed to run on
nthreads = len(os.sched_getaffinity(os.getpid()))

# leave one CPU free, but always use at least one worker
with concurrent.futures.ThreadPoolExecutor(max_workers= max(1,nthreads-1)) as pool:
    # consuming the results drives the progress bar and re-raises any exception from a worker
    for i in tqdm(pool.map(do_sinos, grains), total=len(grains)):
        pass

In [None]:
# Show sinogram of single grain

g = grains[0]

fig, ax = plt.subplots()

# each sinogram column is divided by its mean over the dty bins, and shown on a log colour scale
ax.imshow((g.ssino/g.ssino.mean(axis=0)), norm=matplotlib.colors.LogNorm(), interpolation='nearest', origin="lower")

plt.show()

In [None]:
def run_iradon_id11(grain, pad=20, y0=c0/2, workers=1, sample_mask=whole_sample_mask, apply_halfmask=is_half_scan, mask_central_zingers=is_half_scan):
    """Reconstruct the shape of one grain from its sinogram with the ImageD11 iradon.

    grain                : grain with grain.ssino and grain.sinoangles set
    pad                  : added to the number of sinogram rows to give the output size
    y0                   : shift applied (negated) to every projection
    workers              : worker count passed on to iradon
    sample_mask          : mask passed on to iradon
    apply_halfmask       : zero the second half of the sinogram rows before reconstructing
    mask_central_zingers : replace the centre of the reconstruction with the median of
                           the ring of pixels just outside it

    NOTE: the defaults for y0, sample_mask, apply_halfmask and mask_central_zingers are
    taken from the notebook globals when this cell is run, not when the function is
    called. Re-run this cell after changing them, or pass the values explicitly.

    Stores the result in grain.recon and returns the grain.
    """
    outsize = grain.ssino.shape[0] + pad

    if apply_halfmask:
        # keep the first half of the rows, weight the row at the join by 0.5, zero the rest
        halfmask = np.zeros_like(grain.ssino)

        halfmask[:len(halfmask)//2-1, :] = 1
        halfmask[len(halfmask)//2-1, :] = 0.5

        ssino_to_recon = grain.ssino * halfmask
    else:
        ssino_to_recon = grain.ssino

    # Perform iradon transform of grain sinogram, store result (reconstructed grain shape) in grain.recon
    grain.recon = ImageD11.sinograms.roi_iradon.iradon(ssino_to_recon, 
                                                       theta=grain.sinoangles, 
                                                       mask=sample_mask,
                                                       output_size=outsize,
                                                       projection_shifts=np.full(grain.ssino.shape, -y0),
                                                       filter_name='hamming',
                                                       interpolation='linear',
                                                       workers=workers)

    if mask_central_zingers:
        grs = grain.recon.shape[0]
        # pixel coordinates relative to the centre of the reconstruction
        xpr, ypr = -grs//2 + np.mgrid[:grs, :grs]
        inner_mask_radius = 25
        outer_mask_radius = inner_mask_radius + 2

        radius_squared = xpr ** 2 + ypr ** 2
        inner_circle_mask = radius_squared < inner_mask_radius ** 2
        outer_circle_mask = radius_squared < outer_mask_radius ** 2

        # Ring = inside the outer circle but NOT inside the inner one. The previous
        # `inner & outer` was just the inner disc, so the fill value was computed
        # from the very pixels that are being replaced.
        mask_ring = outer_circle_mask & ~inner_circle_mask
        if not mask_ring.any():
            # reconstruction too small to contain a ring: fall back to the inner disc
            mask_ring = inner_circle_mask

        fill_value = np.median(grain.recon[mask_ring])
        grain.recon[inner_circle_mask] = fill_value

    return grain

In [None]:
# if you want, you can override the y0 value here
# (this y0 is passed explicitly to run_iradon_id11 in the cells below and saved with each grain)

# y0 = 1.5  # for example!

y0 = c0/2

In [None]:
# try the reconstruction on the first grain before running all of them
g = grains[0]

# NOTE: workers=20 is hard-coded here; lower it if fewer CPUs are available to you
run_iradon_id11(g, pad=pad, y0=y0, workers=20)

In [None]:
# show the test reconstruction of the first grain next to the sinogram it was made from
g = grains[0]

fig, axs = plt.subplots(1,2, figsize=(10,5))
axs[0].imshow(g.recon, vmin=0)
axs[0].set_title("ID11 iradon")
axs[1].imshow(g.ssino, aspect='auto')
axs[1].set_title("ssino")

plt.show()

In [None]:
# reconstruct every grain in a thread pool; each result is stored on the grain as grain.recon
nthreads = len(os.sched_getaffinity(os.getpid()))

with concurrent.futures.ThreadPoolExecutor( max_workers= max(1,nthreads-1) ) as pool:
    # the same pad and y0 are used for every grain
    recon_jobs = pool.map(lambda one_grain: run_iradon_id11(one_grain, pad, y0), grains)
    # consuming the results drives the progress bar
    for i in tqdm(recon_jobs, total=len(grains)):
        pass

In [None]:
# keep a reference to the iradon result under a second name: grain.recon is rebound to the
# MLEM result further down, while og_recon goes on pointing at this reconstruction
# NOTE: this is the same array object, not a copy
for grain in grains:
    grain.og_recon = grain.recon

In [None]:
# Interactive viewer: reconstruction and sinogram of one grain, chosen with a slider.
# The figure starts on the same grain as the slider (index 0). It used to start on
# grains[8], which disagreed with the slider and failed when there were fewer than 9 grains.
initial_grain_index = 0

fig, a = plt.subplots(1,2,figsize=(10,5))
rec = a[0].imshow(grains[initial_grain_index].og_recon, vmin=0)
sin = a[1].imshow(grains[initial_grain_index].ssino, aspect='auto')
a[0].set(title=str(initial_grain_index))

# Function to update the displayed image based on the selected frame
def update_frame(i):
    """Show the reconstruction and the sinogram of grains[i] in the existing figure."""
    rec.set_array(grains[i].og_recon)
    sin.set_array(grains[i].ssino)
    a[0].set(title=str(i))
    fig.canvas.draw()

# Create a slider widget to select the frame number
frame_slider = widgets.IntSlider(
    value=initial_grain_index,
    min=0,
    max=len(grains) - 1,
    step=1,
    description='Grain:'
)

interact(update_frame, i=frame_slider)

plt.show()

In [None]:
# One inverse-pole-figure panel per lab axis (x, y, z): every grain is placed in the
# cubic triangle and coloured according to its crystal direction along that axis.
f,a = plt.subplots( 1,3, figsize=(15,5) )
ty, tx = utils.triangle().T
for i,title in enumerate( 'xyz' ):
    # unit vector along lab axis number i
    ax = np.zeros(3)
    ax[i] = 1.
    hkl = [utils.crystal_direction_cubic( g.ubi, ax ) for g in grains]
    xy = np.array([utils.hkl_to_pf_cubic(h) for h in hkl ])
    rgb = np.array([utils.hkl_to_color_cubic(h) for h in hkl ])
    # store the colour on the grain; after the last pass this holds the colour for 'z'
    for j, grain_colour in enumerate(rgb):
        grains[j].rgb = grain_colour
    a[i].scatter( xy[:,1], xy[:,0], c = rgb )   # Note the "x" axis of the plot is the 'k' direction and 'y' is h (smaller)
    a[i].set(title=title, aspect='equal', facecolor='silver', xticks=[], yticks=[])
    # outline of the triangle
    a[i].plot( tx, ty, 'k-', lw = 1 )

In [None]:
# combine the per-grain reconstructions (grain.recon, here still the iradon result) into slice maps;
# no intensity cutoff is applied at this stage (cutoff_level=0)
rgb_array, grain_labels_array, raw_intensity_array = utils.build_slice_arrays(grains, cutoff_level=0)

In [None]:
# plot initial output: colour map of the slice from the iradon reconstructions

fig, ax = plt.subplots(constrained_layout=True)
ax.imshow(rgb_array)
plt.show()

In [None]:
# map showing which grain label was given to each pixel
fig, ax = plt.subplots(constrained_layout=True)
ax.imshow(grain_labels_array)  # originally 1,2,0
ax.set_title("Grain label map")
plt.show()

In [None]:
# intensity map returned by build_slice_arrays
fig, ax = plt.subplots(constrained_layout=True)
ax.imshow(raw_intensity_array)
ax.set_title("Raw intensity array")
plt.show()

In [None]:
# we can clean up these reconstructions using an MLEM iterative recon
# we can use the whole sample shape mask for this

In [None]:
# write og_recon and ssino and circle_mask to disk

# HDF5 dataset options shared by everything saved from this notebook
cmp = {'compression':'gzip',
       'compression_opts': 2,
       'shuffle' : True }

def save_array(grp, name, ary):
    """Write an array into an HDF5 group as a compressed dataset and return the dataset.

    grp  : open h5py group (or file) to write into
    name : dataset name inside grp
    ary  : array to store; its shape and dtype define the dataset

    An existing dataset with the same shape is reused and overwritten. An existing
    dataset with a different shape is deleted and recreated: require_dataset
    refuses a shape mismatch, which made re-running the notebook fail after
    changing e.g. `pad` or the peak selection. Attributes of a replaced dataset
    are lost, so callers should set them again (as the callers here do).
    """
    if name in grp:
        existing = grp[name]
        # only ever replace datasets, never a group that happens to have this name
        if isinstance(existing, h5py.Dataset) and existing.shape != ary.shape:
            del grp[name]

    hds = grp.require_dataset(name, 
                              shape=ary.shape,
                              dtype=ary.dtype,
                              **cmp)
    hds[:] = ary
    return hds

def save_grains(grains, ds):
    """Save the peak assignments and the per-grain sinogram data into the grains file.

    Opens ds.grainsfile for update. Reads the notebook globals gord, inds,
    whole_sample_mask and y0 in addition to its arguments.

    Written to the 'peak_assignments' group: gord and inds.
    Written to each existing 'grains/<gid>' group: ssino, sinoangles, og_recon,
    circle_mask, translation, peaks_2d_sinograms, peaks_4d_sinograms, and the
    attributes cen and y0.
    """
    with h5py.File(ds.grainsfile, 'r+') as hout:
        # reuse the group if an earlier run already created it
        grp = hout.require_group('peak_assignments')

        # The descriptions give the slice this notebook actually uses to fill
        # grain.peaks_2d (one position later than the grain id).
        ds_gord = save_array( grp, 'gord', gord )
        ds_gord.attrs['description'] = 'Grain ordering: g[i].pks = gord[ inds[i+1] : inds[i+2] ]'
        ds_inds = save_array( grp, 'inds', inds )
        ds_inds.attrs['description'] = 'Grain indices: g[i].pks = gord[ inds[i+1] : inds[i+2] ]'

        grains_group = 'grains'
        for g in tqdm(grains):
            gg = hout[grains_group][str(g.gid)]

            # save stuff for sinograms
            save_array(gg, 'ssino', g.ssino).attrs['description'] = 'Sinogram of peak intensities sorted by omega'
            save_array(gg, 'sinoangles', g.sinoangles).attrs['description'] = 'Projection angles for sinogram'
            save_array(gg, 'og_recon', g.og_recon).attrs['description'] = 'Original ID11 iRadon reconstruction'
            # the same whole-sample mask is stored with every grain
            save_array(gg, 'circle_mask', whole_sample_mask).attrs['description'] = 'Reconstruction mask to use for MLEM'

            # might as well save peaks stuff while we're here
            save_array(gg, 'translation', g.translation).attrs['description'] = 'Grain translation in lab frame'
            save_array(gg, 'peaks_2d_sinograms', g.peaks_2d).attrs['description'] = "2D peaks from strong 4D peaks that were assigned to this grain for sinograms"
            save_array(gg, 'peaks_4d_sinograms', g.peaks_4d).attrs['description'] = "Strong 4D peaks that were assigned to this grain for sinograms"

            gg.attrs['cen'] = g.cen
            gg.attrs['y0'] = y0

In [None]:
# write the sinograms, iradon reconstructions and peak assignments into the grains file
save_grains(grains, ds)

In [None]:
# Both switches follow is_half_scan. They are inserted as "Yes"/"No" text into the
# command line of the MLEM script written further down.
dohm = "Yes" if is_half_scan else "No"
mask_cen = "Yes" if is_half_scan else "No"

In [None]:
# folder for the SLURM script and its output/error/log files
slurm_mlem_path = os.path.join(ds.analysispath, "slurm_mlem")

# WARNING: an existing folder is deleted together with everything in it, so each run starts clean
if os.path.exists(slurm_mlem_path):
    print(f"Removing {slurm_mlem_path}")
    rmtree(slurm_mlem_path)

os.mkdir(slurm_mlem_path)

In [None]:
# folder that receives one MLEM reconstruction text file per grain
recons_path = os.path.join(ds.analysispath, "mlem_recons")

# WARNING: an existing folder is deleted, so reconstructions from a previous run are lost
if os.path.exists(recons_path):
    print(f"Removing {recons_path}")
    rmtree(recons_path)

os.mkdir(recons_path)

In [None]:
# Write the SLURM array job script that runs the MLEM reconstruction, one array task per grain.
bash_script_path = os.path.join(slurm_mlem_path, ds.dsname + '_mlem_recon_slurm.sh')
# python_script_path = os.path.join(ds.analysisroot, "run_mlem_recon.py")
python_script_path = os.path.join(id11_code_path, "ImageD11/nbGui/S3DXRD/run_mlem_recon.py") 
# %A and %a are filled in by SLURM with the array job id and the array task id
outfile_path =  os.path.join(slurm_mlem_path, ds.dsname + '_mlem_recon_slurm_%A_%a.out')
errfile_path =  os.path.join(slurm_mlem_path, ds.dsname + '_mlem_recon_slurm_%A_%a.err')
# The braces are required: without them bash reads "$SLURM_ARRAY_JOB_ID_" as a variable
# whose name ends in an underscore, which is empty, and the job id is dropped from the log name.
log_path = os.path.join(slurm_mlem_path, ds.dsname + '_mlem_recon_slurm_${SLURM_ARRAY_JOB_ID}_${SLURM_ARRAY_TASK_ID}.log')

# one output file per array task; the task id is the position of the grain in the grains list
reconfile = os.path.join(recons_path, ds.dsname + "_mlem_recon_$SLURM_ARRAY_TASK_ID.txt")

# USER: job settings
n_simultaneous_jobs = 50  # maximum number of array tasks running at the same time
cores_per_task = 8
niter = 50  # passed to the reconstruction script

bash_script_string = f"""#!/bin/bash
#SBATCH --job-name=mlem-recon
#SBATCH --output={outfile_path}
#SBATCH --error={errfile_path}
#SBATCH --array=0-{len(grains)-1}%{n_simultaneous_jobs}
#SBATCH --time=02:00:00
# define memory needs and number of tasks for each array job
#SBATCH --ntasks=1
#SBATCH --cpus-per-task={cores_per_task}
#
date
python3 {python_script_path} {ds.grainsfile} $SLURM_ARRAY_TASK_ID {reconfile} {pad} {niter} {dohm} {mask_cen} > {log_path} 2>&1
date
"""

# the script is a single string, so write() it in one go
with open(bash_script_path, "w") as bashscriptfile:
    bashscriptfile.write(bash_script_string)

In [None]:
# submit the array job and block until it has finished
# NOTE(review): the second argument (60) is presumably a wait/polling time in seconds - confirm in nb_utils
utils.slurm_submit_and_wait(bash_script_path, 60)

In [None]:
# collect results into grain attributes
# the filenames are element position not gid
# NOTE: this rebinds grain.recon to the MLEM result; the iradon result stays available as grain.og_recon

for i, grain in enumerate(tqdm(grains)):
    grain.recon = np.loadtxt(os.path.join(recons_path, ds.dsname + f"_mlem_recon_{i}.txt"))

In [None]:
# look at all our grains: an evenly spaced sample of roughly n_grains_to_plot reconstructions

n_grains_to_plot = 25

# A step of at least 1: with fewer grains than n_grains_to_plot the integer division
# gives 0, and a slice step of 0 raises ValueError.
grains_step = max(1, len(grains)//n_grains_to_plot)
grains_to_plot = grains[::grains_step]

# near-square grid just large enough for the sampled grains
grid_size = np.ceil(np.sqrt(len(grains_to_plot))).astype(int)
nrows = (len(grains_to_plot)+grid_size-1)//grid_size

# squeeze=False keeps axs a 2D array even for a single panel, so ravel() always works
fig, axs = plt.subplots(grid_size, nrows, figsize=(10,10), layout="constrained", sharex=True, sharey=True, squeeze=False)
for i, ax in enumerate(axs.ravel()):
    if i < len(grains_to_plot):
        # get corresponding grain for this axis
        g = grains_to_plot[i]
        ax.imshow(g.recon, vmin=0)
        # ax.invert_yaxis()
        # title = position of this grain in the grains list
        ax.set_title(i * grains_step)
    else:
        # no grain left for this panel
        ax.set_axis_off()
    
plt.show()

In [None]:
# rebuild the slice maps, now from the MLEM reconstructions in grain.recon,
# this time with an intensity cutoff (cutoff_level=0.2)
rgb_array, grain_labels_array, raw_intensity_array = utils.build_slice_arrays(grains, cutoff_level=0.2)

In [None]:
# colour map of the slice from the MLEM reconstructions
fig, ax = plt.subplots(constrained_layout=True)
ax.imshow(rgb_array)
plt.show()

In [None]:
# intensity map returned by build_slice_arrays for the MLEM reconstructions
fig, ax = plt.subplots(constrained_layout=True)
ax.imshow(raw_intensity_array)
ax.set_title("Sinogram raw intensity map")
plt.show()

In [None]:
# map showing which grain label was given to each pixel, from the MLEM reconstructions
fig, ax = plt.subplots(constrained_layout=True)
ax.imshow(grain_labels_array)
ax.set_title("Grain label map")
plt.show()

In [None]:
# save recons and 2d properties to existing grain file
# NOTE: this definition replaces the save_grains defined earlier in the notebook;
# once this cell has run, the earlier version is no longer reachable under that name.

def save_grains(grains, ds):
    """Save the final slice maps and per-grain reconstructions into the grains file.

    Opens ds.grainsfile for update. Reads the notebook globals
    raw_intensity_array and grain_labels_array in addition to its arguments.

    Written to the 'slice_recon' group: intensity and labels.
    Written to each existing 'grains/<gid>' group: recon (the current grain.recon).
    """
    with h5py.File(ds.grainsfile, 'r+') as hout:
        # reuse the group if an earlier run already created it
        grp = hout.require_group('slice_recon')
        save_array(grp, 'intensity', raw_intensity_array).attrs['description'] = 'Raw intensity array for all grains'
        save_array(grp, 'labels', grain_labels_array).attrs['description'] = 'Grain labels array for all grains'

        grains_group = 'grains'

        for g in tqdm(grains):
            gg = hout[grains_group][str(g.gid)]

            save_array(gg, 'recon', g.recon).attrs['description'] = 'Final reconstruction'

In [None]:
# write the final slice maps and reconstructions (uses the save_grains defined in the cell directly above)
save_grains(grains, ds)