# Jupyter notebook based on ImageD11 to process 3DXRD data
# Written by Haixing Fang, Jon Wright and James Ball
## Date: 27/02/2024

Now that we have good experimental parameters, we can index more grains!

In [None]:
# Run the facility install helper so that ImageD11 is taken from a git checkout.
# The helper is expected to define setup_ImageD11_from_git (it is called right below).
# NOTE(review): exec() runs the file with full notebook privileges; only point it at a trusted path.
# A context manager is used so the script's file handle is closed once it has been read.
with open('/data/id11/nanoscope/install_ImageD11_from_git.py') as installer_script:
    exec(installer_script.read())
PYTHONPATH = setup_ImageD11_from_git( ) # ( os.path.join( os.environ['HOME'],'Code'), 'ImageD11_git' )

In [None]:
# import functions we need

import os, glob, pprint
import numpy as np
import h5py
from tqdm.notebook import tqdm

import matplotlib
%matplotlib widget
from matplotlib import pyplot as plt
import scipy.spatial

# import utils
from ImageD11.nbGui import nb_utils as utils

import ImageD11.grain
import ImageD11.indexing
import ImageD11.columnfile
import ImageD11.refinegrains
import ImageD11.grid_index_parallel
from ImageD11.sinograms import properties, dataset

from ImageD11.blobcorrector import eiger_spatial
from ImageD11.peakselect import select_ring_peaks_by_intensity

In [None]:
# location of the dataset H5 file to process

dset_path = '/data/visitor/ihma439/id11/20231211/PROCESSED_DATA/James/20240909/FeAu_0p5_tR/FeAu_0p5_tR_ff1/FeAu_0p5_tR_ff1_dataset.h5'

In [None]:
# read the DataSet object back from its H5 file

ds = ImageD11.sinograms.dataset.load(dset_path)

# keep the sample and dataset names: they are used to build file names in later cells
# (from here on `dataset` is this name, no longer the module imported from ImageD11.sinograms)
sample, dataset = ds.sample, ds.dset

# summary of what was loaded
print(ds)
print(ds.shape)

In [None]:
# read the phase definitions that belong to this dataset
ds.phases = ds.get_phases_from_disk()
# NOTE(review): a bare expression is only displayed when it is the last statement of a cell,
# so this line shows nothing here
ds.phases.unitcells

# now let's select a phase to index from our parameters json
phase_str = 'Fe'

# unit cell object of the chosen phase, used for ring positions and indexing below
ucell = ds.phases.unitcells[phase_str]

print(ucell.lattice_parameters, ucell.spacegroup)

In [None]:
# load 3d columnfile from disk

cf_3d = ds.get_cf_3d_from_disk()
# update the columnfile parameters for the phase selected above
ds.update_colfile_pars(cf_3d, phase_name=phase_str) 

# also write the peaks to a .flt file (relative path, so it lands in the current working directory)
cf_3d_path = f'{sample}_{dataset}_3d_peaks.flt'
cf_3d.writefile(cf_3d_path)

In [None]:
# plot the 3D peaks (fewer of them) as a cake: d-star on x, eta on y
# if the parameters in the par file are good, the peaks should line up on straight vertical lines

# build the ring list, passing the largest d-star found in the peaks
ucell.makerings(cf_3d.ds.max())

fig, ax = plt.subplots(figsize=(16, 9), layout='constrained')

# peaks first, then one red tick per ring at eta = 0
ring_positions = ucell.ringds
ax.scatter(cf_3d.ds, cf_3d.eta, s=1)
ax.plot(ring_positions, [0] * len(ring_positions), '|', ms=90, c="red")

ax.set_xlabel("D-star")
ax.set_ylabel("eta")

plt.show()

# First step: Visually inspect if we can easily see Friedel pairs
# Not worth doing if we can't see them!

In [None]:
# here we are filtering our peaks (cf_3d) to select only the strong peaks
# NOTE(review): the original note said "from the first ring"; the cut applied is dsmax = 0.6 - confirm which rings that keeps

# USER: selection settings handed to select_ring_peaks_by_intensity below
cf_strong_frac = 0.9837
cf_strong_dsmax = 0.6
cf_strong_dstol = 0.01

cf_strong = select_ring_peaks_by_intensity(cf_3d, frac=cf_strong_frac, dsmax=cf_strong_dsmax, doplot=0.8, dstol=cf_strong_dstol)
print(f"Got {cf_strong.nrows} strong peaks for indexing")

In [None]:
# intensity against d-star: every 3D peak, with the selected strong peaks drawn on top
fig, ax = plt.subplots(figsize=(16, 9), constrained_layout=True)

# ring position ticks, placed at an intensity of 1e4 so they are visible on the log axis
ring_positions = ucell.ringds
ax.plot(ring_positions, [1e4] * len(ring_positions), '|', ms=90, c="red")

for peaks, legend_name in ((cf_3d, 'cf_3d'), (cf_strong, 'first ring')):
    ax.plot(peaks.ds, peaks.sum_intensity, ',', label=legend_name)

# logarithmic intensity axis
ax.set_yscale('log')

ax.set_xlabel("Dstar")
ax.set_ylabel("Intensity")
ax.legend()

plt.show()

In [None]:
# Visual check for Friedel pairs among the strong peaks.
# NOTE(review): lf is computed here but not used again in this cell
lf = ImageD11.refinegrains.lf(cf_strong.tth, cf_strong.eta)

f = plt.figure(figsize=(15,5))
ax = f.add_subplot()

# log10 of the intensity is the y value of both scatter plots
log_intensity = np.log10(cf_strong.sum_intensity)

# first window: peaks with omega strictly between 3 and 5 degrees, drawn as circles
# and coloured by eta (azimuthal position on the ring)
om1 = (cf_strong.omega > 3) & (cf_strong.omega < 5)
ax.scatter(cf_strong.omega[om1], log_intensity[om1], c=cf_strong.eta[om1], marker='o')

# eta expected for the Friedel partner of each peak: 180 - eta, wrapped back below 180
etapair = 180 - cf_strong.eta
etapair = np.where(etapair > 180, etapair - 360, etapair)

# second window: the partners sit 180 degrees further in omega (183 to 185 degrees);
# they are shifted back by 180 and drawn as crosses, coloured by their expected partner eta
om2 = (cf_strong.omega > 183) & (cf_strong.omega < 185)
ax.scatter(cf_strong.omega[om2] - 180, log_intensity[om2], c=etapair[om2], marker='+')

# for valid friedel pairs, we should see 'o' and '+' markers close together in omega and intensity, with similar colours (eta)
plt.show()

In [None]:
def calc_tth_eta( c, pi, pj ):
    """Return (tth, eta) in degrees for the peak pairs (pi[k], pj[k]) of c.

    The pair vector is formed from the xl, yl, zl columns of c as
    (xl[pi] + xl[pj], yl[pi] + yl[pj], zl[pi] - zl[pj]).
    tth is the angle between that vector and the x axis; eta is
    arctan2(-y, z) of the same vector.

    c  : object exposing indexable xl, yl, zl attributes
    pi : indices of the first peak of each pair
    pj : indices of the second peak of each pair
    """
    pair_x = c.xl[pi] + c.xl[pj]
    pair_y = c.yl[pi] + c.yl[pj]
    pair_z = c.zl[pi] - c.zl[pj]
    # length of the component perpendicular to x
    radial = np.sqrt(pair_y*pair_y + pair_z*pair_z)
    return (
        np.degrees(np.arctan2(radial, pair_x)),
        np.degrees(np.arctan2(-pair_y, pair_z)),
    )

def find_friedel_pairs(cf_in, doplot=False):
    """Pair up Friedel partners in cf_in and return them merged in one columnfile.

    Every peak is described by (omega, eta, tth, log10(sum_intensity)), each
    multiplied by a fixed weight. A second copy of the peaks is mapped to where
    a Friedel partner would be found (omega + 180, 180 - eta, same tth and
    intensity) and a KD-tree search keeps every pair closer than 1 in that
    weighted space. tth and eta are then recomputed from the lab coordinates
    of both peaks of a pair (calc_tth_eta), pairs whose d-star is not within
    0.002 of a ring in ucell.ringds are dropped, and g-vectors are recomputed
    from the new angles.

    Reads the notebook-level `ucell`; its ringds are presumably filled by an
    earlier ucell.makerings(...) call.

    Parameters
    ----------
    cf_in : columnfile
        Peaks to search. Must provide omega, eta, tth, ds, sum_intensity,
        xl, yl, zl, gx, gy, gz columns and a 'wavelength' parameter.
    doplot : bool
        When true, diagnostic matplotlib figures are shown along the way.

    Returns
    -------
    columnfile
        First members of all kept pairs followed by their partners, with
        tth, eta, ds and gx/gy/gz replaced by the pair-derived values.
        It shares the parameters object of cf_in.
    """
    # the notebook does not import ImageD11.transform itself, so bring it in here
    from ImageD11 import transform

    # relative weights of each coordinate in the nearest-neighbour search
    womega = 1.5
    weta = 0.2
    wtth = 1.5
    wI = 0.5
    t1 = scipy.spatial.cKDTree( np.transpose( [
                                womega*(cf_in.omega%360),
                                weta*(cf_in.eta%360),
                                wtth*cf_in.tth,
                                wI*np.log10(cf_in.sum_intensity) ] ))

    # same peaks, moved to where their Friedel partner is expected
    t2 = scipy.spatial.cKDTree( np.transpose([
                                 womega*((cf_in.omega+180)%360),
                                 weta*((180-cf_in.eta)%360),
                                 wtth* cf_in.tth,
                                 wI*np.log10(cf_in.sum_intensity) ] ))

    coo = t1.sparse_distance_matrix( t2, max_distance=1, output_type='coo_matrix' ) # 1 degree eta might be tight?

    inds = np.arange(cf_in.nrows)
    p1 = inds[coo.row]
    p2 = inds[coo.col]

    tth, eta = calc_tth_eta( cf_in, p1, p2 )

    wavelength = cf_in.parameters.get('wavelength')
    dstar = 2*np.sin(np.radians(tth)/2)/wavelength

    if doplot:
        # weights must come from the columnfile being searched (cf_in), because
        # p1/p2 index into it; a notebook-level columnfile may have other rows
        s1 = cf_in.sum_intensity[p1]
        s2 = cf_in.sum_intensity[p2]

        f,a = plt.subplots(2,1,figsize=(20,6))
        a[0].hist2d(dstar,eta,bins=(2000,360), norm='log', weights=s1+s2)
        a[0].plot(ucell.ringds, np.zeros_like(ucell.ringds),"|r",lw=1,ms=90)
        a[0].set(ylabel='eta/deg')
        a[1].hist2d(dstar,coo.data,
                    bins=(1000,128), norm='log')
        a[1].plot( ucell.ringds, np.full_like(ucell.ringds,4),"|r",lw=1,ms=20)
        a[1].set(xlabel='dstar', ylabel='distance for search')
        plt.show()

        # per-coordinate mismatch between the two members of each pair
        f,a = plt.subplots(t1.data.shape[1],1,figsize=(20,6))
        for i in range(t1.data.shape[1]):
            a[i].hist2d(dstar, t1.data[coo.row,i] - t2.data[coo.col,i], bins=(1000,128), norm='log')

        plt.show()

    # keep only the pairs that land on one of the rings of the unit cell
    m = np.zeros_like(p1, dtype=bool)
    for d in ucell.ringds:
        m |= abs(dstar - d)<0.002

    c1 = cf_in.copyrows( p1[m] )
    c2 = cf_in.copyrows( p2[m] )

    # both members of a pair get the pair-derived two-theta and d-star
    c1.tth[:] = tth[m]
    c2.tth[:] = tth[m]
    c1.ds[:] = dstar[m]
    c2.ds[:] = dstar[m]

    if doplot:
        # measured eta of the first member against the pair-derived eta
        fig, ax = plt.subplots()
        ax.plot(c1.eta%360, eta[m]%360,',')
        plt.show()

    # the partner's eta is 180 - eta, wrapped back below 180
    c1.eta[:] = eta[m]
    e2 = 180 - eta[m]
    c2.eta[:] = np.where( e2 > 180, e2-360, e2)

    cpair = ImageD11.columnfile.colfile_from_dict({
        t: np.concatenate( (c1[t], c2[t]) ) for t in c1.titles } )
    cpair.parameters = cf_in.parameters

    if doplot:
        plt.figure()
        plt.plot(c1.ds, c1.eta, ',')
        plt.plot(c2.ds, c2.eta, ',')
        plt.plot(cpair.ds, cpair.eta, ',')
        plt.show()

    # g-vectors have to follow the corrected angles
    cpair.gx[:],cpair.gy[:],cpair.gz[:] = transform.compute_g_vectors( cpair.tth, cpair.eta, cpair.omega, wavelength )

    if doplot:
        plt.figure()
        plt.plot(cpair.ds, cpair.sum_intensity*np.exp(5*cpair.ds**2),',')
        plt.semilogy()
        plt.show()

    return cpair

In [None]:
# pair up Friedel partners among all the 3D peaks (no diagnostic plots)
cf_friedel_pairs = find_friedel_pairs(cf_3d, doplot=False)

In [None]:
# here we are filtering the Friedel-pair peaks (cf_friedel_pairs) to select only the strong ones
# dsmax is set to the largest d-star present, so no peak is excluded by the d-star cut

cf_friedel_pairs_strong_frac = 0.9837
cf_friedel_pairs_strong_dsmax = cf_friedel_pairs.ds.max()
cf_friedel_pairs_strong_dstol = 0.01

cf_friedel_pairs_strong = select_ring_peaks_by_intensity(cf_friedel_pairs, frac=cf_friedel_pairs_strong_frac, dsmax=cf_friedel_pairs_strong_dsmax, doplot=0.8, dstol=cf_friedel_pairs_strong_dstol)
print(f"Got {cf_friedel_pairs_strong.nrows} strong peaks for indexing")
# cf_strong_path = f'{sample}_{dataset}_3d_peaks_strong.flt'
# cf_strong.writefile(cf_strong_path)

In [None]:
# specify our ImageD11 indexer with these peaks

indexer = ImageD11.indexing.indexer_from_colfile(cf_friedel_pairs_strong)

print(f"Indexing {cf_friedel_pairs_strong.nrows} peaks")

# USER: set a tolerance in d-space (for assigning peaks to powder rings)

indexer_ds_tol = 0.05
indexer.ds_tol = indexer_ds_tol

# change the log level so we can see what the ring assignments look like

ImageD11.indexing.loglevel = 1

# assign peaks to powder rings

indexer.assigntorings()

# change log level back again

ImageD11.indexing.loglevel = 3

In [None]:
# let's plot the assigned peaks, coloured by the ring they were assigned to

fig, ax = plt.subplots()

pair_ds = cf_friedel_pairs_strong.ds
ring_positions = ucell.ringds

# indexer.ra is the ring assignments
ax.scatter(pair_ds, cf_friedel_pairs_strong.eta, c=indexer.ra, cmap='tab20', s=1)
# one red tick per ring at eta = 0
ax.plot(ring_positions, [0] * len(ring_positions), '|', ms=90, c="red")
ax.set_xlabel("d-star")
ax.set_ylabel("eta")
# leave a margin of 0.05 on both sides of the d-star range of the peaks
ax.set_xlim(pair_ds.min() - 0.05, pair_ds.max() + 0.05)

plt.show()

In [None]:
# now we are indexing!
# we have to choose which rings we want to generate orientations on
# generally we want two or three low-multiplicity rings that are isolated from other phases
# take a look at the ring assignment output from a few cells above, and choose two or three
# USER: as set here, ring 1 is listed twice
rings_for_gen = [1, 1]

# now we want to decide which rings to score our found orientations against
# generally we can just exclude dodgy rings (close to other phases, only a few peaks in etc)
rings_for_scoring = [0, 1, 2, 3, 4, 5, 6, 7, 8]

# the sequence of hkl tolerances the indexer will iterate through
hkl_tols_seq = [0.01, 0.02, 0.03]
# the sequence of minpks fractions the indexer will iterate through
fracs = [0.5]
# the tolerance in g-vector angle (0.25 degrees, expressed as the cosine of 90 - 0.25 degrees)
cosine_tol = np.cos(np.radians(90 - 0.25))
# the max number of UBIs we can find per pair of rings
max_grains = 1000

# run the indexing; the first return value is discarded and `indexer` is replaced by the returned one
_, indexer = utils.do_index(cf=cf_friedel_pairs_strong,
                                dstol=indexer.ds_tol,
                                forgen=rings_for_gen,
                                foridx=rings_for_scoring,
                                hkl_tols=hkl_tols_seq,
                                fracs=fracs,
                                cosine_tol=cosine_tol,
                                max_grains=max_grains,
                                unitcell=ucell
)

In [None]:
# inspect the results of the index

# fills indexer.histogram and indexer.bins, which are plotted below
indexer.histogram_drlv_fit()

plt.figure()
# one curve per row of the histogram; the first and last bin edges and the last count are left out
plotted_bins = indexer.bins[1:-1]
for counts in indexer.histogram:
    plt.plot(plotted_bins, counts[:-1], '-')

In [None]:
# now we switch to grid indexing

In [None]:
# omega step: mean spacing of the sorted omega values in the first row of ds.omega, rounded to 3 decimals
omegas_sorted = np.sort(ds.omega)[0]
omega_step = np.round(np.diff(omegas_sorted).mean(), 3)
# half a step is used as the omega tolerance
omega_slop = omega_step/2

# settings handed to ImageD11.grid_index_parallel.domap
gridpars = {
    'DSTOL' : 0.004,
    'OMEGAFLOAT' : omega_slop,
    'COSTOL' : 0.002,
    'NPKS' : 10,
    'TOLSEQ' : [ 0.05, ],
    'SYMMETRY' : "cubic",
    'RING1'  : [1,5],
    'RING2' : [1,5],
    'NUL' : True,
    'FITPOS' : True,
    'tolangle' : 0.25,
    'toldist' : 100.,
    'NPROC' : None, # guess from cpu_count
    'NTHREAD' : 1 ,
    }

# add a 'labels' column (copy of the indexer assignments) and an all-zero 'drlv2' column
cf_friedel_pairs_strong.addcolumn(indexer.ga.copy(), 'labels')
cf_friedel_pairs_strong.addcolumn(np.zeros(cf_friedel_pairs_strong.nrows), 'drlv2')

# step size of 0.1 for each translation component
for axis in 'xyz':
    cf_3d.parameters.stepsizes[f't_{axis}'] = 0.1

fittedgrains = []
# grain labels in indexer.ga start at 1, hence start=1
for grain_label, ubi in enumerate(indexer.ubis, start=1):
    grains = [ImageD11.grain.grain(ubi.copy())]
    # only take indexed spots using Friedel pairs
    assigned = indexer.ga == grain_label
    cfit = ImageD11.columnfile.colfile_from_dict(
        {title: cf_friedel_pairs_strong[title][assigned] for title in cf_friedel_pairs_strong.titles})
    if cfit.nrows == 0:
        # no peaks assigned to this orientation
        continue
    fitted = ImageD11.grid_index_parallel.domap(cf_3d.parameters, cfit, grains, gridpars)
    best = fitted[0]
    fittedgrains.append(best)
    print(best.ubi)
    print(best.translation, best.npks, best.nuniq)

In [None]:
# 3D scatter of the fitted grain positions
# set centre_plot to True to subtract the mean x and y position before plotting
centre_plot = False

fig = plt.figure(figsize=(12, 12))
ax = fig.add_subplot(projection='3d', proj_type="ortho")

# one list per translation component
xx, yy, zz = ([g.translation[k] for g in fittedgrains] for k in range(3))

# col = [utils.grain_to_rgb(grain) for grain in fittedgrains]  # IPF-Z colour instead
col = [float(g.npks) for g in fittedgrains]

# marker size: 0.01 times the number found after "mean = " in the grain's intensity_info text
sizes = []
for g in fittedgrains:
    mean_text = g.intensity_info.split("mean = ")[1].split(" , ")[0].replace("'", "")
    sizes.append(0.01*float(mean_text))

if centre_plot:
    plot_x, plot_y = xx-np.mean(xx), yy-np.mean(yy)
else:
    plot_x, plot_y = xx, yy
scatterplot = ax.scatter(plot_x, plot_y, zz, c=col, s=sizes)

ax.set_aspect("equal")
plt.colorbar(scatterplot)
ax.set_title("Grains coloured by n peaks")
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_zlabel("z")
plt.show()

In [None]:
# histogram of the number of peaks per fitted grain, used to choose absolute_minpks in the next cell
fig, ax = plt.subplots()
ax.hist([float(grain.npks) for grain in fittedgrains], bins=50)
# ax.semilogy()
plt.show()

In [None]:
# find the spike in the histogram above
# USER: grains must have more than this many peaks to be kept
absolute_minpks = 250

In [None]:
# keep only the grains with more than absolute_minpks peaks (strict comparison: exactly absolute_minpks is dropped)
grains_filtered = [grain for grain in fittedgrains if float(grain.npks) > absolute_minpks]

In [None]:
# 3D scatter of the positions of the grains that passed the peak-count filter
# set centre_plot to True to subtract the mean x and y position before plotting
centre_plot = False

fig = plt.figure(figsize=(12, 12))
ax = fig.add_subplot(projection='3d', proj_type="ortho")

# one list per translation component
xx, yy, zz = ([g.translation[k] for g in grains_filtered] for k in range(3))

# col = [utils.grain_to_rgb(grain) for grain in grains_filtered]  # IPF-Z colour instead
col = [float(g.npks) for g in grains_filtered]

# marker size: 0.01 times the number found after "mean = " in the grain's intensity_info text
sizes = []
for g in grains_filtered:
    mean_text = g.intensity_info.split("mean = ")[1].split(" , ")[0].replace("'", "")
    sizes.append(0.01*float(mean_text))

if centre_plot:
    plot_x, plot_y = xx-np.mean(xx), yy-np.mean(yy)
else:
    plot_x, plot_y = xx, yy
scatterplot = ax.scatter(plot_x, plot_y, zz, c=col, s=sizes)

ax.set_aspect("equal")
plt.colorbar(scatterplot)
ax.set_title("Grains coloured by n peaks")
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_zlabel("z")
plt.show()

In [None]:
# write the filtered grains to disk
# (relative path, so the file lands in the current working directory; it is the -u input of makemap.py below)

filtered_map_path = f'{sample}_{dataset}_nice_grains.map'

ImageD11.grain.write_grain_file(filtered_map_path, grains_filtered)

In [None]:
# write cf_3d to disk temporarily
# (same file name as used when the peaks were first loaded; it is the -f input of makemap.py below
#  and is deleted again in the clean-up cell)

cf_3d_path = f'{sample}_{dataset}_3d_peaks.flt'
cf_3d.writefile(cf_3d_path)

In [None]:
# write a classic parameter file for makemap.py

from ImageD11 import parameters

# start from an empty parameters object and fill it with the values for the selected phase
pars = parameters.parameters()
pars.parameters.update(ds.phases.get_xfab_pars_dict(phase_str))

# named after the phase, e.g. 'Fe.par'; note that it is not removed by the clean-up cell
oldparfile = phase_str + '.par'

pars.saveparameters(oldparfile)

In [None]:
# run makemap again against all peaks

symmetry = "cubic"

# -U output of makemap.py: the refined grain map
new_filtered_map_path = f'{sample}_{dataset}_nice_grains.map.new'
# presumably written by makemap.py next to its -f input; it is read back in the next cell
new_cf_3d_path = cf_3d_path + '.new'

# USER: tolerance passed to makemap.py as -t
final_makemap_tol = 0.01

# IPython shell escape: the {name} fields are filled from the notebook variables,
# and the lines printed by the command are captured in makemap_output
makemap_output = !makemap.py -p {oldparfile} -u {filtered_map_path} -U {new_filtered_map_path} -f {cf_3d_path} -s {symmetry} -t {final_makemap_tol} --omega_slop={omega_slop} --no_sort

In [None]:
# read the grain map written by makemap.py
grains_final = ImageD11.grain.read_grain_file(new_filtered_map_path)

# import makemap output columnfile with peak assignments
# (cf_3d now refers to this columnfile rather than the one loaded at the start)
cf_3d = ImageD11.columnfile.columnfile(new_cf_3d_path)

# write 3D columnfile to disk
ImageD11.columnfile.colfile_to_hdf(cf_3d, ds.col3dfile, name='peaks')

# save grain data
ds.save_grains_to_disk(grains_final, phase_name=phase_str)

# save the dataset object itself
ds.save()

In [None]:
# cleaning up: delete the intermediate files exchanged with makemap.py, if they are present

temporary_files = (cf_3d_path, filtered_map_path, new_filtered_map_path, new_cf_3d_path)
for path in temporary_files:
    if os.path.exists(path):
        os.remove(path)

In [None]:
# change to 0 to allow all cells to be run automatically
# (the exception raised here stops a "run all" before the bulk-processing cell below)
if 1:
    raise ValueError("Hello!")

In [None]:
# Now that we're happy with our indexing parameters, we can run the below cell to do this in bulk for many samples/datasets
# by default this will do all samples in sample_list, all datasets with a prefix of dset_prefix
# you can add samples and datasets to skip in skips_dict

skips_dict = {
    "FeAu_0p5_tR": []
}

dset_prefix = "ff"

sample_list = ["FeAu_0p5_tR"]
    
samples_dict = utils.find_datasets_to_process(ds.dataroot, skips_dict, dset_prefix, sample_list)


for sample, datasets in samples_dict.items():
    for dataset in datasets:
        print(f"Processing dataset {dataset} in sample {sample}")
        print("Importing DataSet object")
        dset_path = os.path.join(ds.analysisroot, sample, f"{sample}_{dataset}", f"{sample}_{dataset}_dataset.h5")
        
        if not os.path.exists(dset_path):
            print(f"Couldn't find {dataset} in {sample}, skipping")
            continue
            
        ds = ImageD11.sinograms.dataset.load(dset_path)
        print(f"I have a DataSet {ds.dset} in sample {ds.sample}")
        
        if os.path.exists(ds.grainsfile):
            # check grains file for existance of minor phase, skip if it's there
            with h5py.File(ds.grainsfile, "r") as hin:
                if phase_str in hin.keys():
                    print(f"Already have grains for {dataset} in sample {sample}, skipping")
                    continue
        
        ds.phases = ds.get_phases_from_disk()
        ucell = ds.phases.unitcells[phase_str]
        
        print("Loading 3D peaks")
        cf_3d = ds.get_cf_3d_from_disk()
        ds.update_colfile_pars(cf_3d, phase_name=phase_str) 
        cf_3d_path = f'{sample}_{dataset}_3d_peaks.flt'
        cf_3d.writefile(cf_3d_path)
        
        ucell.makerings(cf_3d.ds.max())

        print("Finding Friedel pairs")
        cf_friedel_pairs = find_friedel_pairs(cf_3d, doplot=False)
        cf_friedel_pairs_strong_dsmax = cf_friedel_pairs.ds.max()
        cf_friedel_pairs_strong = select_ring_peaks_by_intensity(cf_friedel_pairs, frac=cf_friedel_pairs_strong_frac, dsmax=cf_friedel_pairs_strong_dsmax, dstol=cf_friedel_pairs_strong_dstol)
        
        print('Finding orientations from collapsed Friedel pairs')
        _, indexer = utils.do_index(cf=cf_friedel_pairs_strong,
                                dstol=indexer.ds_tol,
                                forgen=rings_for_gen,
                                foridx=rings_for_scoring,
                                hkl_tols=hkl_tols_seq,
                                fracs=fracs,
                                cosine_tol=cosine_tol,
                                max_grains=max_grains,
                                unitcell=ucell
                                   )
        
        print('Fitting positions of indexed grains')
        omegas_sorted = np.sort(ds.omega)[0]
        omega_step = np.round(np.diff(omegas_sorted).mean(), 3)
        omega_slop = omega_step/2
        gridpars['OMEGAFLOAT'] = omega_slop
        
        cf_friedel_pairs_strong.addcolumn(indexer.ga.copy(), 'labels')
        cf_friedel_pairs_strong.addcolumn(np.zeros(cf_friedel_pairs_strong.nrows), 'drlv2')

        for v in 'xyz':
            cf_3d.parameters.stepsizes[f't_{v}'] = 0.1
        
        fittedgrains = []
        for i in range(len(indexer.ubis)):
            grains = [ImageD11.grain.grain(indexer.ubis[i].copy() ),]
            cfit = ImageD11.columnfile.colfile_from_dict(
                { t:cf_friedel_pairs_strong[t][indexer.ga==i+1] for t in cf_friedel_pairs_strong.titles} )
            if cfit.nrows == 0:
                continue
            fitted = ImageD11.grid_index_parallel.domap( cf_3d.parameters,
                                            cfit,
                                            grains,
                                            gridpars )
            fittedgrains.append( fitted[0] )
        
        grains_filtered = [grain for grain in fittedgrains if float(grain.npks) > absolute_minpks]
        filtered_map_path = f'{sample}_{dataset}_nice_grains.map'
        ImageD11.grain.write_grain_file(filtered_map_path, grains_filtered)
        new_filtered_map_path = f'{sample}_{dataset}_nice_grains.map.new'
        new_cf_3d_path = cf_3d_path + '.new'
        makemap_output = !makemap.py -p {oldparfile} -u {filtered_map_path} -U {new_filtered_map_path} -f {cf_3d_path} -s {symmetry} -t {final_makemap_tol} --omega_slop={omega_slop} --no_sort
        
        grains_final = ImageD11.grain.read_grain_file(new_filtered_map_path)
        cf_3d = ImageD11.columnfile.columnfile(new_cf_3d_path)
        ImageD11.columnfile.colfile_to_hdf(cf_3d, ds.col3dfile, name='peaks')
        ds.save_grains_to_disk(grains_final, phase_name=phase_str)
        ds.save()
        
        for path in [
            cf_3d_path,
            filtered_map_path,
            new_filtered_map_path,
            new_cf_3d_path
        ]:
            if os.path.exists(path):
                os.remove(path)

print("Done!")