# Jupyter notebook based on ImageD11 to process scanning 3DXRD data
# Written by Haixing Fang, Jon Wright and James Ball
## Date: 26/02/2024

In [None]:
# USER: Change the path below to point to your local copy of ImageD11:

import os

# Resolve the home directory in pure Python instead of through an IPython
# shell escape (`!echo $HOME`), so this cell also runs outside IPython.
home_dir = os.path.expanduser("~")

# USER: You can change this location if you want

id11_code_path = os.path.join(home_dir, "Code/ImageD11")

import sys

# Put the local ImageD11 checkout at the front of the import path so it is
# found before any other installed copy.
sys.path.insert(0, id11_code_path)

In [None]:
# import functions we need

import glob, pprint
import numpy as np
import h5py
from tqdm.notebook import tqdm

import matplotlib
%matplotlib ipympl
from matplotlib import pyplot as plt

import ImageD11.nbGui.nb_utils as utils

import ImageD11.grain
import ImageD11.indexing
import ImageD11.columnfile
from ImageD11.sinograms import properties, dataset

from ImageD11.blobcorrector import eiger_spatial

In [None]:
# NOTE: For old datasets before the new directory layout structure, we don't distinguish between RAW_DATA and PROCESSED_DATA

### USER: specify your experimental directory

rawdata_path = "/data/visitor/ihma439/id11/20231211/RAW_DATA"

!ls -lrt {rawdata_path}

### USER: specify where you want your processed data to go

processed_data_root_dir = "/data/visitor/ihma439/id11/20231211/PROCESSED_DATA/James/nb_testing"

In [None]:
# USER: pick a sample and a dataset you want to index
# NOTE: `dataset` is a plain string naming the scan. It rebinds the `dataset`
# module name imported from ImageD11.sinograms, so later cells reach that
# module through its full path, ImageD11.sinograms.dataset.
sample, dataset = "FeAu_0p5_tR_nscope", "top_250um"

In [None]:
# destination of H5 files: the DataSet file for this sample/dataset lives at
# <processed_data_root_dir>/<sample>/<sample>_<dataset>/<sample>_<dataset>_dataset.h5
dset_file = os.path.join(
    processed_data_root_dir,
    sample,
    f"{sample}_{dataset}",
    f"{sample}_{dataset}_dataset.h5",
)

In [None]:
# load the dataset from file
# Fail early with a clear message if the DataSet file is missing (the batch
# processing cell at the end of the notebook makes the same check).
if not os.path.exists(dset_file):
    raise FileNotFoundError(f"Missing DataSet file {dset_file}")

ds = ImageD11.sinograms.dataset.load(dset_file)

print(ds)
print(ds.shape)

In [None]:
# USER: specify the path to the parameter file and the spatial distortion
# files (all given relative to processed_data_root_dir)

par_file = os.path.join(processed_data_root_dir, '../../../SCRIPTS/James/S3DXRD/Fe_refined.par')
e2dx_file, e2dy_file = (
    os.path.join(processed_data_root_dir, '../../CeO2/e2dx_E-08-0173_20231127.edf'),
    os.path.join(processed_data_root_dir, '../../CeO2/e2dy_E-08-0173_20231127.edf'),
)

# attach the parameter and distortion files to the DataSet
# (assigned left to right: parfile, e2dxfile, e2dyfile)
ds.parfile, ds.e2dxfile, ds.e2dyfile = par_file, e2dx_file, e2dy_file

In [None]:
# if we already have the 2D and 4D peaks, we could just load them from file:

In [None]:
# cf_2d = ImageD11.columnfile.colfile_from_hdf(ds.col2dfile)
# cf_2d.parameters.loadparameters(ds.parfile)
# cf_2d.updateGeometry()

In [None]:
# cf_4d = ImageD11.columnfile.colfile_from_hdf(ds.col4dfile)
# cf_4d.parameters.loadparameters(ds.parfile)
# cf_4d.updateGeometry()

In [None]:
# otherwise load them from the peaks table:

In [None]:
# Load the peaks table, then build and save a spatially corrected columnfile
# of the 2D peaks

peaks_table = ImageD11.sinograms.properties.pks_table.load(ds.pksfile)

# 2D peak centroids -> spatially corrected columnfile
peaks_2d = peaks_table.pk2d(ds.omega, ds.dty)
cf_2d = utils.tocolf(peaks_2d, ds)

# Replace any previous 2D columnfile on disk (delete first, then write), so
# later runs can reload these peaks instead of correcting them again
if os.path.exists(ds.col2dfile):
    os.remove(ds.col2dfile)
ImageD11.columnfile.colfile_to_hdf(cf_2d, ds.col2dfile)

In [None]:
# Build a columnfile (cf) for the 4D peaks (from pk2dmerge), corrected for
# detector spatial distortion.
peaks_4d = peaks_table.pk2dmerge(ds.omega, ds.dty)
cf_4d = utils.tocolf(peaks_4d, ds)  # spatial correction

# If for some reason you do NOT want spatial correction, use this instead:
# cf_4d = ImageD11.columnfile.colfile_from_dict(peaks_4d)
# cf_4d.addcolumn(cf_4d.s_raw, "sc")
# cf_4d.addcolumn(cf_4d.f_raw, "fc")
# cf_4d.parameters.loadparameters(ds.parfile)
# cf_4d.updateGeometry()

# Give every 4D peak a running index 0..nrows-1, so that after filtering we
# can still tell which 4D peaks were used
index_column = np.arange(cf_4d.nrows)
cf_4d.addcolumn(index_column, 'index')

# Replace any previous 4D columnfile on disk (delete first, then write), so
# later runs can reload these peaks instead of correcting them again
if os.path.exists(ds.col4dfile):
    os.remove(ds.col4dfile)
ImageD11.columnfile.colfile_to_hdf(cf_4d, ds.col4dfile)

In [None]:
# Keep only 4D peaks with more than 25 pixels for this plot
m = cf_4d['Number_of_pixels'] > 25

# 2D histogram of omega vs dty for those peaks, weighted by sqrt(intensity)
# on a log colour scale - it should look sinusoidal
fig, ax = plt.subplots()
counts, xedges, yedges, im = ax.hist2d(
    cf_4d['omega'][m],
    cf_4d['dty'][m],
    weights=np.sqrt(cf_4d['sum_intensity'][m]),
    bins=(ds.obinedges, ds.ybinedges),
    norm=matplotlib.colors.LogNorm(),
)
ax.set_xlabel("Omega angle")
ax.set_ylabel("dty")
fig.colorbar(im, ax=ax)

plt.show()

In [None]:
# Plot the 4D peaks (fewer of them) as an unrolled cake: d-star on x, eta on y.
# If the parameters in the par file are good, each ring should appear as a
# straight vertical line.
fig, ax = plt.subplots()
ax.scatter(cf_4d.ds, cf_4d.eta, s=1)
ax.set_xlabel("dstar")
ax.set_ylabel("eta")
plt.show()

In [None]:
# OPTIONAL: export CF to an flt so we can play with it with ImageD11_gui
# uncomment the below line

# cf_4d.writefile(f'{sample}_{dataset}_4d_peaks.flt')

In [None]:
# Filter the 4D peaks (cf_4d) down to the strongest ones - for indexing only!
# dsmax limits which rings are given to the indexer (6-8 rings is normally good).
#
# USER: adjust "frac" and re-run this cell until the orange dot sits on the
# "elbow" of the blue line; that point is the fractional intensity cutoff used.
# If the blue line is not elbow-shaped in the logscale plot, change "doplot"
# (the y scale of the logscale plot) until it is.

cf_strong_frac = 0.994
cf_strong_dsmax = 1.155
cf_strong_dstol = 0.005

cf_strong = utils.selectpeaks(
    cf_4d,
    frac=cf_strong_frac,
    dsmax=cf_strong_dsmax,
    dstol=cf_strong_dstol,
    doplot=0.95,
)

# peak counts before and after filtering
print(cf_4d.nrows)
print(cf_strong.nrows)

In [None]:
# OPTIONAL: export CF to an flt so we can play with it with ImageD11_gui
# uncomment the below line

# cf_strong.writefile(f'{sample}_{dataset}_strong_4d_peaks.flt')

In [None]:
# Intensity vs d-star for the remaining (strong) peaks, with a logarithmic
# intensity axis
fig, ax = plt.subplots()
ax.semilogy(cf_strong.ds, cf_strong.sum_intensity, ',')
ax.set_xlabel("Dstar")
ax.set_ylabel("Intensity")
plt.show()

In [None]:
# now we can define a unit cell from our parameters
# ImageD11.unitcell is not imported explicitly anywhere else in this notebook,
# so import it here instead of relying on another ImageD11 module loading it.
import ImageD11.unitcell

ucell = ImageD11.unitcell.unitcell_from_parameters(cf_strong.parameters)
# compute the rings (ucell.ringds) out to the largest d-star of the strong peaks
ucell.makerings(cf_strong.ds.max())

In [None]:
# now let's plot our peaks again, with the rings from the unitcell included, to check our lattice parameters are good

fig, ax = plt.subplots()

skip = 1  # plot every `skip`-th peak; increase to thin out a dense plot
ax.scatter(cf_strong.ds[::skip], cf_strong.eta[::skip], s=0.5)
# expected ring positions from the unit cell, drawn as tall red ticks at eta = 0
ax.plot(ucell.ringds, [0,] * len(ucell.ringds), '|', ms=90, c="red")
# Raw strings: "\A" is not a valid escape sequence in a normal string literal
# (SyntaxWarning on Python 3.12+); the label text is unchanged.
ax.set_xlabel(r'1 / d ($\AA$)')
ax.set_ylabel(r'$\eta$ (deg)')

plt.show()

In [None]:
# Build an ImageD11 indexer from the strong peaks.
# Aim for roughly 3_000 to 10_000 peaks at this stage.
indexer = ImageD11.indexing.indexer_from_colfile(cf_strong)
print(f"Indexing {cf_strong.nrows} peaks")

In [None]:
# USER: set a tolerance in d-space (for assigning peaks to powder rings)

indexer_ds_tol = 0.01
indexer.ds_tol = indexer_ds_tol

# Temporarily lower the log level so we can see what the ring assignments
# look like. The previous level is restored in `finally`, so it is not left
# at 1 if assigntorings() raises.
previous_loglevel = ImageD11.indexing.loglevel
ImageD11.indexing.loglevel = 1
try:
    # assign peaks to powder rings
    indexer.assigntorings()
finally:
    ImageD11.indexing.loglevel = previous_loglevel

In [None]:
# Colour each strong peak by its powder-ring assignment (indexer.ra) to check
# the assignment by eye
fig, ax = plt.subplots()
ax.scatter(cf_strong.ds, cf_strong.eta, c=indexer.ra, cmap='tab20', s=1)
ax.set_xlabel("d-star")
ax.set_ylabel("eta")
plt.show()

In [None]:
# Now we are indexing!

# Rings used to GENERATE trial orientations: choose two or three
# low-multiplicity rings that are isolated from other phases (see the ring
# assignment printout a few cells above).
rings_for_gen = [0, 1, 3]

# Rings used to SCORE the trial orientations: usually every ring except the
# dodgy ones (close to other phases, only a few peaks in them, etc.).
rings_for_scoring = [0, 1, 2, 3, 4]

# Search schedule for the indexer
hkl_tols_seq = [0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.075]  # hkl tolerances, tried in sequence
fracs = [0.9, 0.7]  # minpks fractions, tried in sequence
cosine_tol = np.cos(np.radians(90 - 0.25))  # tolerance in g-vector angle
max_grains = 1000  # max number of UBIs found per pair of rings

grains, indexer = utils.do_index(
    cf=cf_strong,
    dstol=indexer_ds_tol,
    forgen=rings_for_gen,
    foridx=rings_for_scoring,
    hkl_tols=hkl_tols_seq,
    fracs=fracs,
    cosine_tol=cosine_tol,
    max_grains=max_grains,
)

In [None]:
# Give each grain a sequential ID (useful if we ever delete a grain) and store
# a mean unit-cell length: the cube root of det(UBI).
for gid, grain in enumerate(grains):
    grain.gid = gid
    grain.a = np.cbrt(np.linalg.det(grain.ubi))

In [None]:
# Mean unit-cell length (grain.a) of every grain, plotted in grain order;
# the median is printed as a single representative value, a0.
mean_unit_cell_lengths = [grain.a for grain in grains]

fig, ax = plt.subplots()
ax.plot(mean_unit_cell_lengths)
ax.set_xlabel("Grain ID")
ax.set_ylabel("Unit cell length")
plt.show()

a0 = np.median(mean_unit_cell_lengths)
print(a0)

In [None]:
# assign peaks to grains

peak_assign_tol = 0.05

utils.assign_peaks_to_grains(grains, cf_strong, tol=peak_assign_tol)

print("Storing peak data in grains")
# iterate through all the grains
for g in tqdm(grains):
    # boolean mask over cf_strong rows assigned to this grain; computed once
    # and reused below instead of repeating the comparison
    g.mask_4d = cf_strong.grain_id == g.gid
    # store this grain's peak indices so we know which 4D peaks we used for indexing
    g.peaks_4d = cf_strong.index[g.mask_4d]

In [None]:
# Plot diagnostics of the indexing result for the strong peaks.
# 'First attempt' is presumably used as the plot title (TODO confirm in utils).
utils.plot_index_results(indexer, cf_strong, 'First attempt')

In [None]:
# Plot per-grain sinograms of the strong peaks; presumably this uses the
# grain assignments made by assign_peaks_to_grains above (TODO confirm in utils).
utils.plot_grain_sinograms(grains, cf_strong)

In [None]:
# save grain data
# NOTE(review): presumably writes to ds.grainsfile, which is the file the
# batch cell checks to skip already-indexed datasets - confirm in utils.

utils.save_s3dxrd_grains_after_indexing(grains, ds)

In [None]:
# save new things to the dataset
# (e.g. the parfile / e2dxfile / e2dyfile attributes set earlier in this notebook)

ds.save()

In [None]:
# Guard cell: deliberately stops "Run all cells" here so the batch-processing
# cell below is not started by accident. Change `1` to `0` to run past it.
if 1:
    raise ValueError("Change the 1 above to 0 to allow 'Run all cells' in the notebook")

In [None]:
# Now that we're happy with our indexing parameters, we can run the below cell to do this in bulk for many samples/datasets
# by default this will do all samples in sample_list, all datasets with a prefix of dset_prefix
# you can add samples and datasets to skip in skips_dict
#
# Each dataset goes through the same steps as the interactive cells above and
# reuses the parameters chosen there (par/e2dx/e2dy files, strong-peak
# selection, indexing and peak-assignment tolerances). Datasets with no
# DataSet file, or that already have a grains file, are skipped.

skips_dict = {
    "FeAu_0p5_tR_nscope": ["top_-50um", "top_-100um"]
}

dset_prefix = "top"

sample_list = ["FeAu_0p5_tR_nscope"]

samples_dict = utils.find_datasets_to_process(rawdata_path, skips_dict, dset_prefix, sample_list)

# manual override:
# samples_dict = {"FeAu_0p5_tR_nscope": ["top_100um", "top_200um"]}

# now we have our samples_dict, we can process our data:

for sample, datasets in samples_dict.items():
    for dataset in datasets:
        print(f"Processing dataset {dataset} in sample {sample}")
        dset_path = os.path.join(processed_data_root_dir, sample, f"{sample}_{dataset}", f"{sample}_{dataset}_dataset.h5")
        if not os.path.exists(dset_path):
            print(f"Missing DataSet file for {dataset} in sample {sample}, skipping")
            continue

        print("Importing DataSet object")

        ds = ImageD11.sinograms.dataset.load(dset_path)
        print(f"I have a DataSet {ds.dset} in sample {ds.sample}")
        if os.path.exists(ds.grainsfile):
            print(f"Already have grains for {dataset} in sample {sample}, skipping")
            continue

        ds.parfile = par_file
        ds.e2dxfile = e2dx_file
        ds.e2dyfile = e2dy_file

        # 2D peaks -> spatially corrected columnfile, replacing any old file on disk
        peaks_table = ImageD11.sinograms.properties.pks_table.load(ds.pksfile)
        peaks_2d = peaks_table.pk2d(ds.omega, ds.dty)
        cf_2d = utils.tocolf(peaks_2d, ds)
        if os.path.exists(ds.col2dfile):
            os.remove(ds.col2dfile)
        ImageD11.columnfile.colfile_to_hdf(cf_2d, ds.col2dfile)

        # 4D peaks -> spatially corrected columnfile with an 'index' column
        peaks_4d = peaks_table.pk2dmerge(ds.omega, ds.dty)
        cf_4d = utils.tocolf(peaks_4d, ds)  # spatial correction
        index_column = np.arange(cf_4d.nrows)
        cf_4d.addcolumn(index_column, 'index')
        if os.path.exists(ds.col4dfile):
            os.remove(ds.col4dfile)
        ImageD11.columnfile.colfile_to_hdf(cf_4d, ds.col4dfile)

        # strongest peaks only, for indexing (no interactive plot in batch mode)
        cf_strong = utils.selectpeaks(cf_4d, frac=cf_strong_frac, dsmax=cf_strong_dsmax, dstol=cf_strong_dstol)

        grains, indexer = utils.do_index(cf=cf_strong,
                                        dstol=indexer_ds_tol,
                                        forgen=rings_for_gen,
                                        foridx=rings_for_scoring,
                                        hkl_tols=hkl_tols_seq,
                                        fracs=fracs,
                                        cosine_tol=cosine_tol,
                                        max_grains=max_grains
        )

        for i, g in enumerate(grains):
            g.gid = i

        utils.assign_peaks_to_grains(grains, cf_strong, tol=peak_assign_tol)

        print("Storing peak data in grains")
        for g in tqdm(grains):
            # compute this grain's peak mask once and reuse it for the indices
            g.mask_4d = cf_strong.grain_id == g.gid
            g.peaks_4d = cf_strong.index[g.mask_4d]

        utils.save_s3dxrd_grains_after_indexing(grains, ds)

        ds.save()

print("Done!")