# 3DXRD indexing notebook - Grid indexing method  
__Written by Haixing Fang, Jon Wright and James Ball__  
__Date: 21/02/2025__

In [None]:
import os

# Pin the OpenMP, OpenBLAS and MKL thread-count environment variables to '1'.
# NOTE(review): these are set before the numerical imports in the later import
# cell, presumably so those libraries pick the values up at load time - confirm.
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['OPENBLAS_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'

# Execute the site install helper in this namespace; the parameters cell calls
# setup_ImageD11_from_git(), which is presumably defined by this script.
# The context manager closes the file handle deterministically instead of
# leaving it to the garbage collector.
with open('/data/id11/nanoscope/install_ImageD11_from_git.py') as install_script:
    exec(install_script.read())

In [None]:
# this cell is tagged with 'parameters'
# to view the tag, select the cell, then find the settings gear icon (right or left sidebar) and look for Cell Tags

PYTHONPATH = setup_ImageD11_from_git( ) # ( os.path.join( os.environ['HOME'],'Code'), 'ImageD11_git' )

# destination of H5 files
# replace below with e.g.:
# dset_path = '/data/visitor/expt1234/20240101/PROCESSED_DATA/sample/dataset/sample_dataset.h5'

dset_path = ''

# name of the phase to index - used as the key into ds.phases.unitcells
phase_str = 'Fe'

# path to parameters .json/.par
parfile = ''

# peak filtration options
# passed as frac / dsmax / dstol to select_ring_peaks_by_intensity
# cf_strong_dstol is also used as the ring assignment tolerance (indexer.ds_tol)
cf_strong_frac = 0.999
cf_strong_dsmax = 1.017
cf_strong_dstol = 0.025

# indexing options
# ring numbers (compared against indexer.ra) whose peaks are exported for grid indexing
rings_to_use = [0, 1, 3]

# makemap options
# makemap_tol_seq is passed to the grid indexer as gridpars['TOLSEQ'];
# its last entry is also the tolerance for the standalone makemap.py runs
symmetry = "cubic"
makemap_tol_seq = [0.02, 0.015, 0.01]

# options dictionary handed to grid_index_parallel
# the "Grid index" cell adds COSTOL, NPROC, NPKS, OMEGAFLOAT, TOLSEQ and SYMMETRY
# and sets NTHREAD to 1 again, so changing NTHREAD here has no effect
gridpars = {
        'DSTOL' : 0.004,
        'RING1'  : [1,0,],
        'RING2' : [0,],
        'NUL' : True,
        'FITPOS' : True,
        'tolangle' : 0.50,
        'toldist' : 100.,
        'NTHREAD' : 1 ,
}

# the search grid runs from -lim to +lim (inclusive) along each axis in steps of grid_step
grid_xlim = 600  # um - extent away from rotation axis to search for grains
grid_ylim = 600
grid_zlim = 200
grid_step = 100  # step size of search grid, um

# fraction of expected number of peaks to accept in Makemap output
# used to compute minpeaks from peaks_expected; minpeaks becomes gridpars['NPKS']
frac = 0.85

# find the spike
# grains from the first makemap pass are kept only if npks > absolute_minpks
# choose the value from the "peaks per grain" histogram
absolute_minpks = 56

# NOTE(review): dset_prefix is not referenced by any cell visible in this notebook - confirm it is still needed
dset_prefix = 'ff'

In [None]:
import os
import numpy as np
from matplotlib import pyplot as plt

import random
import ImageD11.cImageD11
import ImageD11.grain
import ImageD11.indexing
import ImageD11.columnfile
from ImageD11.sinograms import dataset
from ImageD11.peakselect import select_ring_peaks_by_intensity
from ImageD11.nbGui import nb_utils as utils
from ImageD11.grid_index_parallel import grid_index_parallel

%matplotlib widget

# Load data
## Dataset

In [None]:
# Load the dataset object stored at dset_path (via the `dataset` module
# imported above) and print its summary.
ds = dataset.load(dset_path)
print(ds)

## Parameters
Specify the path to your parameter file.

You can optionally set up some default parameters for either an Eiger or Frelon detector like so:
```python
from ImageD11.parameters import AnalysisSchema
asc = AnalysisSchema.from_default(detector='eiger')  # or detector='frelon'
asc.save('./pars.json')
```
Please note in this case that you will still have to update the `geometry.par` values accordingly for your experiment.  
If you haven't already, you should run one of the calibration notebooks to determine these.

In [None]:
# Point the dataset at the parameter file chosen in the parameters cell, then call ds.save()
# NOTE(review): presumably ds.save() persists the new parfile path to the dataset file - confirm
ds.parfile = parfile
ds.save()

## Phases
If the parameter file was a json, we can access the unit cells via `ds.phases.unitcells`

In [None]:
# Populate ds.phases via ds.get_phases_from_disk(); the unit cells are shown as the cell output
ds.phases = ds.get_phases_from_disk()
ds.phases.unitcells

In [None]:
# Pick the unit cell of the phase selected in the parameters cell (phase_str)
# a phase name that is not a key of ds.phases.unitcells fails here
ucell = ds.phases.unitcells[phase_str]
print(ucell)

## Peaks

In [None]:
# Get the 3D peaks columnfile through the dataset
cf_3d = ds.get_cf_3d_from_disk()
# Apply the parameters of the selected phase to the columnfile
# NOTE(review): presumably this also computes the derived columns (ds, eta) plotted below - confirm
ds.update_colfile_pars(cf_3d, phase_name=phase_str) 
# Also write the peaks to a .flt file in the working directory:
# makemap.py reads it (-f) in the final refinement, and the cleanup cell removes it
cf_3d_path = 'cf_3d.flt'
cf_3d.writefile(cf_3d_path)

# Visualise data
### $d^{*}$ vs $\eta$

In [None]:
# Plot every `skip`-th peak in (d*, eta) space with the unit cell ring
# positions (ucell.ringds) drawn as red vertical lines.
skip = 1  # increase to plot fewer peaks and speed up plotting
ucell.makerings(cf_3d.ds.max())  # called with the largest observed d*
fig, ax = plt.subplots(figsize=(10,5), layout='constrained')
ax.plot(cf_3d.ds[::skip], cf_3d.eta[::skip], ',')
ax.vlines(ucell.ringds, -50, 50, color='red')
ax.set_xlabel(r'$d^{*}~(\AA^{-1})$')
ax.set_ylabel(r'$\eta~(\degree)$')
ax.set_title('2D azimuthal transform')
plt.show()

# Filtration
Here we are filtering our peaks (`cf_3d`) to select only the strongest ones for indexing purposes only!  
We first filter the peaks in $d^{*}$ to keep only those close to the predicted peaks from the unit cell.  
We then sort our peaks by intensity, and take a certain intensity-weighted fraction of them.  
`dstol`: The tolerance in $d^{*}$ between a peak and a predicted reflection.  
`dsmax`: The maximum allowed peak $d^{*}$ value. Used to limit the number of rings given to the indexer - 6-8 rings max are normally sufficient.   
`frac`: The intensity fraction: `frac=0.9` keeps 90% of the peak intensity. We recommend that you choose a value close to the 'elbow' of the plot.

In [None]:
# Strong peaks used for indexing: filtered with the cf_strong_* settings
# from the parameters cell.
cf_strong = select_ring_peaks_by_intensity(
    cf_3d,
    frac=cf_strong_frac,
    dsmax=cf_strong_dsmax,
    dstol=cf_strong_dstol,
    ucell=ucell,
    doplot=0.5,
)

In [None]:
# Same selection as cf_strong, but with dsmax raised to the largest observed
# d* so that peaks on all rings are kept.
# These extra strong peaks are exported for grain refinement later (using makemap).
cf_strong_allrings = select_ring_peaks_by_intensity(
    cf_3d,
    frac=cf_strong_frac,
    dsmax=cf_3d.ds.max(),
    dstol=cf_strong_dstol,
    ucell=ucell,
    doplot=0.5,
)
cf_strong_allrings_path = 'cf_strong_allrings.flt'
cf_strong_allrings.writefile(cf_strong_allrings_path)

In [None]:
# Compare all peaks with the strong subset: intensity against d* on a log
# scale, with the ring positions marked in red.
skip = 1  # increase to plot fewer peaks and speed up plotting
fig, ax = plt.subplots(figsize=(10, 5), layout='constrained')
for colfile, label in ((cf_3d, 'cf_3d'), (cf_strong, 'cf_strong')):
    ax.plot(colfile.ds[::skip], colfile.sum_intensity[::skip], ',', label=label)
ax.vlines(ucell.ringds, 1e3, 1e4, color='red')
ax.set_xlabel(r'$d^{*}~(\AA^{-1})$')
ax.set_ylabel('Intensity')
ax.set_yscale('log')
ax.set_title('Peak filtration')
ax.legend()
plt.show()

# Indexing
## Ring assignment

In [None]:
# Build an indexer from the strong peaks and the reference unit cell
indexer = ImageD11.indexing.indexer_from_colfile_and_ucell(cf_strong, ucell)
# Use the same d* tolerance for ring assignment as for the peak filtration
indexer.ds_tol = cf_strong_dstol
# loglevel is set to 1 only around assigntorings() and to 3 afterwards
# NOTE(review): presumably so the ring table needed for the peak count below is printed - confirm
ImageD11.indexing.loglevel = 1
indexer.assigntorings()
ImageD11.indexing.loglevel = 3
print(f"Indexing {cf_strong.nrows} peaks")

In [None]:
# Show the ring assignments: peaks in (d*, eta) space coloured by ring number
# (modulo 20 to match the 20-colour 'tab20' map).
skip = 1  # increase to plot fewer peaks and speed up plotting
fig, ax = plt.subplots(layout='constrained', figsize=(10,5))
ax.scatter(
    indexer.colfile.ds[::skip],
    indexer.colfile.eta[::skip],
    c=indexer.ra[::skip]%20,
    cmap='tab20',
    s=1,
)
ax.vlines(ucell.ringds, -50, 50, color='red')
# x range: from the lower of (first ring, weakest-d* peak) to the largest peak d*, padded by 0.02
ax.set_xlim(min(ucell.ringds[0], cf_strong.ds.min()) - 0.02, cf_strong.ds.max() + 0.02)
ax.set_xlabel(r'$d^{*}~(\AA^{-1})$')
ax.set_ylabel(r'$\eta~(\degree)$')
ax.set_title('Ring assignments')
plt.show()

Now we need to compute the number of expected peaks.  
To do this, you add up the multiplicities of the rings you chose.  
If you recorded a 360 degree scan, multiply the result by 2.  
e.g. given this output:
```
# info: Ring     (  h,  k,  l) Mult  total indexed to_index  ubis  peaks_per_ubi   tth
# info: Ring 3   ( -2, -2,  0)   12   2251       0     2251    93     24  16.11
# info: Ring 2   ( -1, -1, -2)   24   4899       0     4899   101     48  13.94
# info: Ring 1   ( -2,  0,  0)    6   1233       0     1233   102     12  11.37
# info: Ring 0   ( -1, -1,  0)   12   2861       0     2861   118     24  8.03
```
Selecting rings `[0,1,3]` we would get `(12+6+12)*2 = 60` peaks

In [None]:
# Sum of the multiplicities of the chosen rings, doubled for a 360 degree scan.
# Update this by hand from the ring table if rings_to_use changes.
peaks_expected = 2 * (12 + 6 + 12)

# choose the fraction of the number of peaks expected - this should be around 0.9 if you had a good clean segmentation
# if you suspect you are missing peaks in your data, decrease to around 0.6

# Rounding to 2 decimals before int() stops float error (e.g. 28.999999...)
# from truncating to the integer below the intended one.
minpeaks = int(np.round(frac * peaks_expected, 2))
minpeaks

## Choose rings to export for grid index

In [None]:
# Copy the strong peaks and keep only those assigned to one of rings_to_use.
peaks_to_export = cf_strong.copy()
mask = np.zeros(cf_strong.nrows, dtype=bool)
for ring in rings_to_use:
    mask = mask | (indexer.ra == ring)  # accumulate peaks on this ring
peaks_to_export.filter(mask)

In [None]:
# Compare all peaks with the subset exported for grid indexing: intensity
# against d* on a log scale, with the ring positions marked in red.
skip = 1  # increase to plot fewer peaks and speed up plotting
fig, ax = plt.subplots(figsize=(10, 5), layout='constrained')
for colfile, label in ((cf_3d, 'cf_3d'), (peaks_to_export, 'peaks_to_export')):
    ax.plot(colfile.ds[::skip], colfile.sum_intensity[::skip], ',', label=label)
ax.vlines(ucell.ringds, 1e3, 1e4, color='red')
ax.set_xlabel(r'$d^{*}~(\AA^{-1})$')
ax.set_ylabel('Intensity')
ax.set_yscale('log')
ax.set_title('Peak filtration')
ax.legend()
plt.show()

## Grid index

In [None]:
# --- file names used by the grid indexer and the makemap refinement ---
grid_peaks_path = 'grid_peaks.flt'
new_grid_peaks_path = 'grid_peaks.flt.new'
oldparfile = phase_str + '.par'
tmp_output_path = 'tmp'
# NOTE(review): map_path is read back after indexing, so it must match what
# grid_index_parallel writes for tmp_output_path - confirm before changing either.
map_path = 'alltmp.map'
new_map_path = 'alltmp.map.new'

peaks_to_export.writefile(grid_peaks_path)  # export peaks
ds.phases.to_old_pars_file(oldparfile, phase_str)  # export parameter file

# leave one core free, but always use at least one process
nproc = max(ImageD11.cImageD11.cores_available() - 1, 1)
omega_slop = ds.ostep/2

# --- run-time entries added to the user-supplied gridpars (same dict, updated in place) ---
gridpars.update({
    'COSTOL': np.cos(np.radians(90 - ds.ostep)),
    'NPROC': nproc,
    'NTHREAD': 1,
    'NPKS': minpeaks,
    'OMEGAFLOAT': omega_slop,
    'TOLSEQ': makemap_tol_seq,
    'SYMMETRY': symmetry,
})

# --- grid to search: every (x, y, z) offset within the limits, endpoints included ---
translations = [
    (dx, dy, dz)
    for dx in range(-grid_xlim, grid_xlim + 1, grid_step)
    for dy in range(-grid_ylim, grid_ylim + 1, grid_step)
    for dz in range(-grid_zlim, grid_zlim + 1, grid_step)
]

# shuffle the search order with a fixed seed so runs are reproducible
random.seed(42)
random.shuffle(translations)

grid_index_parallel(grid_peaks_path, oldparfile, tmp_output_path, gridpars, translations)

# View outputs

In [None]:
# Read the grains from map_path and attach the reference unit cell to each one
grains2 = ImageD11.grain.read_grain_file(map_path)
for g in grains2:
    g.ref_unitcell = ucell
# NOTE(review): presumably computes the per-grain colours used by the IPF plot in the next cell - confirm in nb_utils
utils.get_rgbs_for_grains(grains2)

In [None]:
# Orientation overview of the grid-indexed grains (IPF presumably = inverse pole figure - confirm in nb_utils)
utils.plot_all_ipfs(grains2)

In [None]:
# Positions of the grid-indexed grains, coloured by 'npks'
utils.plot_grain_positions(grains2, colour='npks', centre_plot=False, size_scaling=0.5)

In [None]:
# Same position plot, coloured by 'z'
utils.plot_grain_positions(grains2, colour='z', centre_plot=False, size_scaling=0.5)

# Grain refinement

In [None]:
# Run makemap.py (IPython shell escape) on the grid-indexed grains (-u map_path) against the
# exported grid peaks (-f), using the last tolerance in makemap_tol_seq.
# The following cells read new_map_path (-U) and new_grid_peaks_path, so both are expected as outputs.
# The command's console output is captured in makemap_output - inspect it if those files are missing.
makemap_output = !makemap.py -p {oldparfile} -u {map_path} -U {new_map_path} -f {grid_peaks_path} -s {symmetry} -t {makemap_tol_seq[-1]} --omega_slop={omega_slop} --no_sort

In [None]:
# Diagnostic histograms for the first refinement, with the same omega_slop and tolerance as the makemap run
utils.plot_grain_histograms(new_grid_peaks_path, new_map_path, oldparfile, omega_slop, tol=makemap_tol_seq[-1])

In [None]:
# Read the refined grains, dropping any whose intensity_info reports "no peaks".
grains3 = [
    g
    for g in ImageD11.grain.read_grain_file(new_map_path)
    if "no peaks" not in g.intensity_info
]
# Parse the value after "mean = " in intensity_info into a numeric
# g.intensity attribute (quote characters stripped).
for g in grains3:
    g.intensity = float(
        g.intensity_info.split("mean = ")[1].split(" , ")[0].replace("'", "")
    )

In [None]:
# Positions of the refined grains, coloured by 'npks'
utils.plot_grain_positions(grains3, colour='npks', centre_plot=False, size_scaling=0.5)

In [None]:
# Histogram of peaks per refined grain; use it to choose absolute_minpks.
fig, ax = plt.subplots(figsize=(10, 7), layout='constrained')
ax.hist([float(g.npks) for g in grains3], bins=30)
ax.set_xlabel('Number of peaks per grain')
ax.set_ylabel('Count')
ax.set_title('Histogram of peaks per grain')
plt.show()

In [None]:
# keep only grains with strictly more than absolute_minpks peaks
# (a grain with exactly absolute_minpks peaks is removed)
# most grains should have a high number of peaks
# choose absolute_minpks such that the low-peak grains are removed
grains_filtered = [grain for grain in grains3 if float(grain.npks) > absolute_minpks]

In [None]:
# Positions of the grains that passed the absolute_minpks filter, coloured by 'npks'
utils.plot_grain_positions(grains_filtered, colour='npks', centre_plot=False, size_scaling=0.5)

In [None]:
# Second refinement pass: the filtered grains against ALL 3D peaks (cf_3d.flt), not only the grid peaks
filtered_map_path = 'nice_grains.map'
new_filtered_map_path = 'nice_grains.map.new'
# the makemap output columnfile is read back from the input peaks path + '.new'
new_cf_3d_path = cf_3d_path + '.new'

# write the filtered grains to a map file for makemap to read
ImageD11.grain.write_grain_file(filtered_map_path, grains_filtered)

# run makemap on filtered grains with all 3D peaks
# console output is captured in makemap_output - inspect it if the output files are missing
makemap_output = !makemap.py -p {oldparfile} -u {filtered_map_path} -U {new_filtered_map_path} -f {cf_3d_path} -s {symmetry} -t {makemap_tol_seq[-1]} --omega_slop={omega_slop} --no_sort

# import makemap output columnfile with peak assignments
# note: this rebinds cf_3d, replacing the columnfile loaded earlier in the notebook
cf_3d = ImageD11.columnfile.columnfile(new_cf_3d_path)

# write 3D columnfile to disk
# NOTE(review): the target is ds.col3dfile - presumably this replaces the stored 3D peaks with the labelled ones; confirm
ImageD11.columnfile.colfile_to_hdf(cf_3d, ds.col3dfile, name='peaks')

# re-import filtered grains with new peak statistics
grains_final = ImageD11.grain.read_grain_file(new_filtered_map_path)
for g in grains_final:
    g.ref_unitcell = ucell

utils.get_rgbs_for_grains(grains_final)
print(f"{len(grains_final)} final grains imported")

In [None]:
# Positions of the final grains, coloured by 'npks'
utils.plot_grain_positions(grains_final, colour='npks', centre_plot=False, size_scaling=0.5)

In [None]:
# Same position plot for the final grains, coloured by 'z'
utils.plot_grain_positions(grains_final, colour='z', centre_plot=False, size_scaling=0.5)

In [None]:
# Histogram of peaks per grain after the refinement against all 3D peaks.
fig, ax = plt.subplots(figsize=(10, 7), layout='constrained')
ax.hist([float(g.npks) for g in grains_final], bins=30)
ax.set_xlabel('Number of peaks per grain')
ax.set_ylabel('Count')
ax.set_title('Histogram of peaks per grain')
plt.show()

In [None]:
# Diagnostic histograms for the final refinement, with the same omega_slop and tolerance as the makemap run
utils.plot_grain_histograms(new_cf_3d_path, new_filtered_map_path, oldparfile, omega_slop, tol=makemap_tol_seq[-1])

In [None]:
# Per-grain cell parameters with the first three entries (lengths) and the
# remaining entries (angles) sorted separately.
# Sorting the whole vector at once would push a length into the angle columns
# whenever a cell length is numerically larger than a cell angle.
# NOTE(review): assumes g.unitcell is ordered (a, b, c, alpha, beta, gamma),
# consistent with the [:, :3] "Unit cell length" slice below - confirm.
unit_cell_lengths = np.array([
    np.concatenate((np.sort(g.unitcell[:3]), np.sort(g.unitcell[3:])))
    for g in grains_final
])
median_unit_cell = np.median(unit_cell_lengths, axis=0)
print("Median unit cell:", median_unit_cell)

# One line per sorted axis length against grain index, with the median
# lengths drawn as horizontal lines.
fig, ax = plt.subplots(layout='constrained')
ax.plot(unit_cell_lengths[:,:3])
ax.hlines(median_unit_cell[:3], 0, len(unit_cell_lengths))
ax.set(xlabel="Grain ID", ylabel="Unit cell length")
plt.show()

# Export data

In [None]:
# Store the final grains through the dataset under the selected phase name
ds.save_grains_to_disk(grains_final, phase_name=phase_str)

In [None]:
# Save the dataset object itself
ds.save()

# Deleting temporary files

In [None]:
# Remove the intermediate files this notebook wrote to the working directory.
# Paths that do not exist are skipped.
for path in (
    cf_3d_path,
    cf_strong_allrings_path,
    grid_peaks_path,
    tmp_output_path + '.flt',
    map_path,
    new_map_path,
    new_grid_peaks_path,
    filtered_map_path,
    new_filtered_map_path,
    new_cf_3d_path,
):
    if not os.path.exists(path):
        continue
    os.remove(path)