# Test indexing_iterative in grainmaps.py
## Mainly tests indexing_iterative, which indexes grains iteratively: it matches peaks to the already-indexed grains and then continues indexing new grains from the unmatched peaks
## It is normally used as a supplement to tomo_1_index.ipynb or tomo_1_index_minor_phase.ipynb
## Jan 2025

In [None]:
# Run the ID11 installer script in this namespace; it is expected to define
# setup_ImageD11_from_git(), which is called just below.
# NOTE(review): exec() runs whatever code is stored at this path - only use it
# with the trusted, facility-provided script.
# The context manager closes the script file as soon as it has been read.
with open('/data/id11/nanoscope/install_ImageD11_from_git.py') as installer_script:
    exec(installer_script.read())
# Keep the value returned by the setup call; the trailing comment shows optional arguments.
PYTHONPATH = setup_ImageD11_from_git( ) # ( os.path.join( os.environ['HOME'],'Code'), 'ImageD11_git' )

In [None]:
# import functions we need

import os

import h5py
import numpy as np

import matplotlib
# %matplotlib ipympl
from matplotlib import pyplot as plt

import ImageD11.nbGui.nb_utils as utils

import ImageD11.grain
import ImageD11.indexing
import ImageD11.columnfile
import ImageD11.parameters
import ImageD11.sinograms.dataset
from ImageD11.unitcell import Phases
from ImageD11.peakselect import select_ring_peaks_by_intensity

In [None]:
# USER: point dset_file at the dataset file you want to process

dset_file = '/data/visitor/ihmi1549/id11/20240305/PROCESSED_DATA/FeAu_No1_interrupted/FeAu_No1_interrupted_s3DXRD_z5/FeAu_No1_interrupted_s3DXRD_z5_dataset.h5'

# load the dataset object from that file
ds = ImageD11.sinograms.dataset.load(dset_file)

# keep frequently used dataset attributes under short names
sample, dataset = ds.sample, ds.dsname
rawdata_path, processed_data_root_dir = ds.dataroot, ds.analysisroot

# show the dataset, then its shape
print(ds)
print(ds.shape)

In [None]:
# USER: specify the path to the parameter file
# (an example json can be found in the same folder as this notebook)

par_file = '/data/visitor/ihmi1549/id11/20240305/SCRIPTS/HF/S3DXRD/pars.json'

# record the parameter file path on the dataset object

ds.parfile = par_file

# save the dataset (presumably persisting the new parfile setting - confirm in dataset.save)
ds.save()

In [None]:
# load the phases from the parameter file and attach them to the dataset

ds.phases = ds.get_phases_from_disk()
# last expression of the cell: the notebook displays the available unit cells
ds.phases.unitcells

In [None]:
# now let's select a phase to index from our parameters json
# USER: phase_str is used as a key into ds.phases.unitcells, so it must match one of the phases shown above
phase_str = 'Fe_bcc'

ucell = ds.phases.unitcells[phase_str]

# print the lattice parameters and space group of the selected unit cell as a sanity check
print(ucell.lattice_parameters, ucell.spacegroup)

In [None]:
# Build the columnfile (cf) holding the 4D peaks
# (these are corrected for detector spatial distortion)
cf_4d = ds.get_cf_4d()

# update the columnfile parameters for the phase being indexed
ds.update_colfile_pars(cf_4d, phase_name=phase_str)

# write the 4D peaks to disk only when no file exists yet,
# so the spatial correction does not have to be repeated
col4d_on_disk = os.path.exists(ds.col4dfile)
if not col4d_on_disk:
    ImageD11.columnfile.colfile_to_hdf(cf_4d, ds.col4dfile)

In [None]:
# Mask keeping only the 4D peaks that cover more than 25 pixels
m = cf_4d['Number_of_pixels'] > 25

# 2D histogram of omega against dty for the masked peaks, weighted by
# sqrt(sum_intensity) and shown on a log colour scale - should look sinusoidal
fig, ax = plt.subplots()
counts, xedges, yedges, im = ax.hist2d(
    cf_4d['omega'][m],
    cf_4d['dty'][m],
    weights=np.sqrt(cf_4d['sum_intensity'][m]),
    bins=(ds.obinedges, ds.ybinedges),
    norm=matplotlib.colors.LogNorm(),
)
fig.colorbar(im, ax=ax)

ax.set_xlabel("Omega angle")
ax.set_ylabel("dty")

plt.show()

In [None]:
# Cake plot of the 4D peaks (fewer of them): eta against dstar, with the
# computed ring positions marked in red
# if the parameters in the par file are good, the peaks should form straight lines

# compute the rings up to the largest dstar present in the peaks
ucell.makerings(cf_4d.ds.max())

fig, ax = plt.subplots(figsize=(10, 5), layout='constrained')

ax.scatter(cf_4d.ds, cf_4d.eta, s=1)
# one red tick per ring, drawn along eta = 0
ax.plot(ucell.ringds, [0] * len(ucell.ringds), '|', ms=90, c="red")

ax.set_ylabel("eta")
ax.set_xlabel("dstar")

plt.show()

In [None]:
# OPTIONAL: export CF to an flt so we can play with it with ImageD11_gui
# uncomment the below line

# cf_4d.writefile(f'{sample}_{dataset}_4d_peaks.flt')

In [None]:
# Filter the 4D peaks (cf_4d) down to the strongest ones - for indexing purposes only!
# dsmax limits the rings given to the indexer; 6-8 rings is normally good

# USER: adjust "frac" and re-run this cell until the orange dot sits on the "elbow" of the blue line
# frac is the fractional intensity cutoff that will be selected
# if the blue line is not elbow-shaped in the logscale plot, change "doplot"
# (the y scale of the logscale plot) until it is

cf_strong_frac = 0.985
cf_strong_dsmax = 1.594
cf_strong_dstol = 0.005

cf_strong = select_ring_peaks_by_intensity(
    cf_4d,
    frac=cf_strong_frac,
    dsmax=cf_strong_dsmax,
    dstol=cf_strong_dstol,
    doplot=0.97,
)

# number of peaks before and after the selection
print(cf_4d.nrows)
print(cf_strong.nrows)

In [None]:
# OPTIONAL: export CF to an flt so we can play with it with ImageD11_gui
# uncomment the below line

# cf_strong.writefile(f'{sample}_{dataset}_strong_4d_peaks.flt')

In [None]:
# Compare the intensities of all 4D peaks with those of the selected strong peaks

fig, ax = plt.subplots(figsize=(10, 5), constrained_layout=True)

# pixel markers: every 4D peak first, the strong selection on top
ax.plot(cf_4d.ds, cf_4d.sum_intensity, ',', label='cf_4d')
ax.plot(cf_strong.ds, cf_strong.sum_intensity, ',', label='cf_strong')
# red ticks at the ring positions, drawn at an intensity of 1e4
ax.plot(ucell.ringds, [1e4] * len(ucell.ringds), '|', ms=90, c="red")

# logarithmic intensity axis
ax.semilogy()

ax.set_ylabel("Intensity")
ax.set_xlabel("Dstar")
ax.legend()

plt.show()

In [None]:
# Same intensity plot as before, but showing only the strong peaks (cf_strong);
# the full cf_4d peak set is deliberately left out here

fig, ax = plt.subplots(figsize=(10, 5), constrained_layout=True)

ax.plot(cf_strong.ds, cf_strong.sum_intensity, ',', label='cf_strong')
# red ticks at the ring positions, drawn at an intensity of 1e4
ax.plot(ucell.ringds, [1e4] * len(ucell.ringds), '|', ms=90, c="red")

# logarithmic intensity axis
ax.semilogy()

ax.set_ylabel("Intensity")
ax.set_xlabel("Dstar")
ax.legend()

plt.show()

In [None]:
# create our ImageD11 indexer from the strong peaks only

indexer = ImageD11.indexing.indexer_from_colfile(cf_strong)

# report how many peaks were handed to the indexer
print(f"Indexing {cf_strong.nrows} peaks")

In [None]:
# USER: set a tolerance in d-space (for assigning peaks to powder rings)

indexer_ds_tol = 0.006
indexer.ds_tol = indexer_ds_tol

# change the log level so we can see what the ring assignments look like;
# try/finally guarantees the level is set back to 3 even if assigntorings()
# raises, so a failed run does not leave verbose logging switched on

ImageD11.indexing.loglevel = 1
try:
    # assign peaks to powder rings
    indexer.assigntorings()
finally:
    # change log level back again
    ImageD11.indexing.loglevel = 3

In [None]:
# Plot the strong peaks coloured by the ring they were assigned to
# (indexer.ra holds the ring assignments)

fig, ax = plt.subplots(figsize=(10, 5), layout='constrained')

ax.scatter(cf_strong.ds, cf_strong.eta, s=1, c=indexer.ra, cmap='tab20')
# red ticks at the ring positions, drawn along eta = 0
ax.plot(ucell.ringds, [0] * len(ucell.ringds), '|', ms=90, c="red")

# show the d-star range of the strong peaks with a 0.05 margin on both sides
ax.set_xlim(cf_strong.ds.min() - 0.05, cf_strong.ds.max() + 0.05)
ax.set_ylabel("eta")
ax.set_xlabel("d-star")

plt.show()

In [None]:
# Now we are indexing!

# USER: rings used to generate trial orientations
# generally pick two or three low-multiplicity rings that are isolated from other phases
# (see the ring assignment output from a few cells above)
rings_for_gen = [0, 1, 3, 5]

# USER: rings used to score the orientations that were found
# generally just leave out dodgy rings (close to other phases, only a few peaks, etc.)
rings_for_scoring = [0, 1, 2, 3, 4, 5]

# sequence of hkl tolerances the indexer will iterate through
# (a longer alternative: [0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.075])
hkl_tols_seq = [0.01, 0.02, 0.03, 0.04, 0.05]

# sequence of minpks fractions the indexer will iterate through
# (a stricter alternative: [0.9,])
fracs = [0.9, 0.7]

# tolerance in g-vector angle, derived from the omega step of the dataset
cosine_tol = np.cos(np.radians(90 - ds.ostep))

# maximum number of UBIs we can find per pair of rings
max_grains = 1000

grains, indexer = utils.do_index(
    cf=cf_strong,
    unitcell=ds.phases.unitcells[phase_str],
    dstol=indexer_ds_tol,
    forgen=rings_for_gen,
    foridx=rings_for_scoring,
    hkl_tols=hkl_tols_seq,
    fracs=fracs,
    cosine_tol=cosine_tol,
    max_grains=max_grains,
)
print(f'Found {len(grains)} grains!')

In [None]:
# give every grain a temporary ID equal to its position in the grain list

for grain_number, grain in enumerate(grains):
    grain.gid = grain_number

In [None]:
# mean unit cell length of each grain: cube root of the determinant of its ubi matrix
mean_unit_cell_lengths = [np.cbrt(np.linalg.det(grain.ubi)) for grain in grains]

# median over all grains
a0 = np.median(mean_unit_cell_lengths)

# plot the per-grain values in grain order
fig, ax = plt.subplots()
ax.plot(mean_unit_cell_lengths)
ax.set_ylabel("Unit cell length")
ax.set_xlabel("Grain ID")
plt.show()

print(a0)

In [None]:
# assign peaks to grains
# peak_assign_tol is passed as the `tol` argument of the assignment
# (presumably an hkl-error tolerance - confirm in nb_utils)

peak_assign_tol = 0.05

utils.assign_peaks_to_grains(grains, cf_strong, tol=peak_assign_tol)

In [None]:
# plot the indexing results for the strong peaks;
# 'First attempt' is the third argument (presumably used as the figure title - confirm in nb_utils)
utils.plot_index_results(indexer, cf_strong, 'First attempt')

In [None]:
# plot grain sinograms from the strong peaks; the last argument is capped at 25
# (presumably the number of grains to show - confirm in nb_utils)
utils.plot_grain_sinograms(grains, cf_strong, min(len(grains), 25))

## Note that the operations with deleting grains group and saving newly indexed grains have been commented out
## You may need to uncomment these lines in your own data processing

In [None]:
# # if you would like to overwrite the grains, you have to delete the existing grains group in grainsfile
# from ImageD11.forward_model import io
# io.delete_group_from_h5(ds.grainsfile, phase_str)

In [None]:
# # save grain data

# ds.save_grains_to_disk(grains, phase_name=phase_str)

In [None]:
# # save new things to the dataset

# ds.save()

# Use "indexing_iterative" in grainmaps.py to remove all the indexed peaks and use the remaining peaks to index further new grains
# Main steps:
## 1) match peaks with the already-indexed grains
## 2) find the remaining peaks that are unmatched
## 3) use the remaining peaks for indexing new grains
## 4) compare the new grains with the previously indexed grains and merge the duplicates
## 5) generate a new, merged grains object
## These steps can be repeated iteratively until the indexing is satisfactory, e.g. until the percentage of peaks used for indexing has reached a certain level

In [None]:
from ImageD11.forward_model import grainmaps

In [None]:
# read the parameter file attached to the dataset; pars is passed to the
# iterative indexing and forward matching calls further down
pars = ImageD11.parameters.read_par_file(ds.parfile)

In [None]:
# First pass of iterative indexing, starting from the grains found above;
# the indexing settings (tolerances, rings, fractions) are the ones defined earlier
grains_new = grainmaps.indexing_iterative(
    cf_strong, grains, ds, ucell, pars,
    ds_max=1.6,
    tol_angle=0.25,
    tol_pixel=3,
    peak_assign_tol=0.25,
    tol_misori=3,
    crystal_system='cubic',
    indexer_ds_tol=indexer_ds_tol,
    rings_for_gen=rings_for_gen,
    rings_for_scoring=rings_for_scoring,
    hkl_tols_seq=hkl_tols_seq,
    fracs=fracs,
    cosine_tol=cosine_tol,
    max_grains=max_grains,
)

In [None]:
# assign peaks to the grains returned by the first iterative pass

peak_assign_tol = 0.05

utils.assign_peaks_to_grains(grains_new, cf_strong, tol=peak_assign_tol)

In [None]:
# plot grain sinograms after the first iterative pass; the last argument is capped at 20
# (presumably the number of grains to show - confirm in nb_utils)
utils.plot_grain_sinograms(grains_new, cf_strong, min(len(grains_new), 20))

In [None]:
# Second pass of iterative indexing, starting from grains_new, with a looser
# peak-matching criterion (bigger tol_angle and tol_pixel); note ds_max is 1.2 here
grains_new2 = grainmaps.indexing_iterative(
    cf_strong, grains_new, ds, ucell, pars,
    ds_max=1.2,
    tol_angle=0.5,
    tol_pixel=5,
    peak_assign_tol=0.25,
    tol_misori=3,
    crystal_system='cubic',
    indexer_ds_tol=indexer_ds_tol,
    rings_for_gen=rings_for_gen,
    rings_for_scoring=rings_for_scoring,
    hkl_tols_seq=hkl_tols_seq,
    fracs=fracs,
    cosine_tol=cosine_tol,
    max_grains=max_grains,
)

In [None]:
# assign peaks to the grains returned by the second iterative pass

peak_assign_tol = 0.05

utils.assign_peaks_to_grains(grains_new2, cf_strong, tol=peak_assign_tol)

In [None]:
# plot grain sinograms after the second iterative pass; the last argument is capped at 20
# (presumably the number of grains to show - confirm in nb_utils)
utils.plot_grain_sinograms(grains_new2, cf_strong, min(len(grains_new2), 20))

In [None]:
# Compare the grain counts of the initial indexing and of the two indexing_iterative passes;
# more grains are expected after each pass
len(grains), len(grains_new), len(grains_new2)

# Compute the completeness of each grain and remove the grains with relatively low completeness (considered untrustworthy)

In [None]:
from ImageD11.forward_model import forward_model

In [None]:
# Forward-match the strong peaks against the grains of the second iterative pass,
# using the same ds_max / tol_angle / tol_pixel as that pass;
# returns the matched peaks and the completeness values
cf_matched_all, Comp_all = forward_model.forward_match_peaks(
    cf_strong, grains_new2, ds, ucell, pars,
    ds_max=1.2,
    tol_angle=0.5,
    tol_pixel=5,
    thres_int=None,
    verbose=1,
)

In [None]:
# convert the completeness results to a float array so that columns can be sliced (Comp_all[:,0] below)
Comp_all = np.array(Comp_all, dtype = 'float')

In [None]:
# Plot the first column of Comp_all (the completeness) against the grain index
plt.figure()
grain_indices = np.arange(len(grains_new2))
plt.plot(grain_indices, Comp_all[:, 0], 'r.')
plt.show()

In [None]:
# Select the grains whose completeness is below a threshold
# USER: adjust thres_comp to your data
thres_comp = 0.35

# indices of the grains below the threshold (np.argwhere gives one row per grain)
below_threshold = Comp_all[:, 0] < thres_comp
gid_to_remove = np.argwhere(below_threshold)

print(f'Originally there are {len(grains_new2)} grains')
print(f'Now I am going to remove {gid_to_remove.shape[0]} grains with completeness < {thres_comp}')

In [None]:
# drop the low-completeness grains listed in gid_to_remove
# (presumably returns a new grain list and leaves grains_new2 untouched - confirm in grainmaps)
grains_new3 = grainmaps.remove_gid_from_grains(grains_new2, gid_to_remove=gid_to_remove)

In [None]:
# re-assign peaks to the remaining grains, reusing the peak_assign_tol set in an earlier cell,
# then plot their sinograms (last argument capped at 20)
utils.assign_peaks_to_grains(grains_new3, cf_strong, tol=peak_assign_tol)
utils.plot_grain_sinograms(grains_new3, cf_strong, min(len(grains_new3), 20))

In [None]:
# Potential problem: assign_peaks_to_grains may assign 0 peaks to some grains even though these grains have acceptable completeness

# Save the new results to ds

### Note that the operations with deleting grains group and saving newly indexed grains have been commented out
### You may need to uncomment these lines in your own data processing

In [None]:
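# # if you would like to overwrite the grains, you have to delete the existing grains group in grainsfile
# # (the import is needed for io.delete_group_from_h5)
# from ImageD11.forward_model import io
# io.delete_group_from_h5(ds.grainsfile, phase_str)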

In [None]:
# # save grain data

# ds.save_grains_to_disk(grains_new3, phase_name=phase_str)

In [None]:
# # save new things to the dataset

# ds.save()