# Jupyter notebook based on ImageD11 to process scanning 3DXRD data
# Written by Haixing Fang, Jon Wright and James Ball
## Date: 12/10/2024

This notebook will try to perform a point-by-point strain refinement from your tomographic-derived grain shapes.  

### NOTE: It is highly recommended to run this notebook on a Jupyter server with many cores and a lot of RAM.  
The origin-computation step (`refine.get_origins()`) in particular runs locally and can be compute-intensive for large datasets.  
If this is a big scan (e.g. 100+ million 2D peaks), you should definitely run the refinement on the cluster rather than locally.

In [None]:
import os

# Limit OpenMP / OpenBLAS / MKL to one thread each. These are set here, before the
# numerical libraries are imported in the next cell. (Presumably this avoids thread
# oversubscription when the refinement parallelises on its own — TODO confirm.)
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['OPENBLAS_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'

# Execute the site installer script in this namespace. It defines
# setup_ImageD11_from_git(). A context manager closes the file handle
# instead of leaving it open.
with open('/data/id11/nanoscope/install_ImageD11_from_git.py') as installer:
    exec(installer.read())
PYTHONPATH = setup_ImageD11_from_git( ) # ( os.path.join( os.environ['HOME'],'Code'), 'ImageD11_git' )

In [None]:
# import functions we need

import os
import concurrent.futures
import timeit

import matplotlib
%matplotlib ipympl

import h5py
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib.colors import Normalize

from xfab.symmetry import Umis

import ImageD11.columnfile
from ImageD11.sinograms.tensor_map import TensorMap
from ImageD11.sinograms.point_by_point import PBPRefine
from ImageD11.peakselect import select_ring_peaks_by_intensity
from ImageD11.sinograms import properties, roi_iradon
from ImageD11.sinograms import geometry
from ImageD11.sinograms.sinogram import GrainSinogram, build_slice_arrays, write_slice_recon, read_slice_recon, write_h5, read_h5, write_pbp_strain
from ImageD11.grain import grain
from ImageD11 import cImageD11

import ImageD11.nbGui.nb_utils as utils

In [None]:
# USER: Pass path to dataset file

dset_file = 'si_cube_test/processed/Si_cube/Si_cube_S3DXRD_nt_moves_dty/Si_cube_S3DXRD_nt_moves_dty_dataset.h5'

# Load the DataSet, then expose the identifiers and folders that later cells use.
ds = ImageD11.sinograms.dataset.load(dset_file)

sample, dataset = ds.sample, ds.dsname
rawdata_path, processed_data_root_dir = ds.dataroot, ds.analysisroot

# Show a summary of the loaded dataset and its shape.
print(ds)
print(ds.shape)
print(ds.shape)

In [None]:
# load phases from parameter file
# ds.phases.unitcells maps phase name -> unit cell (it is indexed by name in the next cell).

ds.phases = ds.get_phases_from_disk()
# bare expression on the last line so Jupyter displays the available unit cells
ds.phases.unitcells

In [None]:
# now let's select a phase to index from our parameters json
# (both names must be keys of ds.phases.unitcells)
major_phase_str = 'Fe'
minor_phase_str = 'Au'

major_phase_unitcell, minor_phase_unitcell = (
    ds.phases.unitcells[name] for name in (major_phase_str, minor_phase_str)
)

# report lattice parameters and space group for each selected phase
for name, ucell in ((major_phase_str, major_phase_unitcell),
                    (minor_phase_str, minor_phase_unitcell)):
    print(name, ucell.lattice_parameters, ucell.spacegroup)

In [None]:
# load 4d peaks
# cf_4d is used below for ring selection and the intensity vs. d* plot

cf_4d = ds.get_cf_4d()

In [None]:
# for now - set parameters with major phase
# loads the major phase's parameters onto cf_4d
# (presumably this also refreshes geometry-derived columns such as ds — confirm)

ds.update_colfile_pars(cf_4d, phase_name=major_phase_str)

In [None]:
# Select the peaks lying on each phase's rings, for the intensity plot below.
# The two selection calls take identical arguments, so the phase must be switched
# on cf_4d before each one. Otherwise the "minor phase" selection is just a copy of
# the major-phase one. (Assumes the selection uses the unit cell from cf_4d's current
# parameters — TODO confirm in ImageD11.peakselect.)
ds.update_colfile_pars(cf_4d, phase_name=major_phase_str)
cf_major_phase = select_ring_peaks_by_intensity(cf_4d, frac=1, dsmax=cf_4d.ds.max(), dstol=0.005, doplot=None)

ds.update_colfile_pars(cf_4d, phase_name=minor_phase_str)
cf_minor_phase = select_ring_peaks_by_intensity(cf_4d, frac=1, dsmax=cf_4d.ds.max(), dstol=0.005, doplot=None)

# restore the major-phase parameters that the previous cell set on cf_4d
ds.update_colfile_pars(cf_4d, phase_name=major_phase_str)

# compute each phase's ring positions (ringds) up to the largest selected d*
major_phase_unitcell.makerings(cf_major_phase.ds.max())
minor_phase_unitcell.makerings(cf_minor_phase.ds.max())

In [None]:
# now we can take a look at the intensities of the remaining peaks
# Scatter intensity against d* for all 4D peaks and for each phase selection,
# then mark each phase's ring positions with vertical ticks at a fixed height.

fig, ax = plt.subplots(figsize=(16, 9), constrained_layout=True)

for colf, lbl, colour in (
    (cf_4d, 'cf_4d', 'blue'),
    (cf_major_phase, 'major phase', 'orange'),
    (cf_minor_phase, 'minor phase', 'green'),
):
    ax.plot(colf.ds, colf.sum_intensity, ',', label=lbl, c=colour)

for ucell, tick_height, colour in (
    (major_phase_unitcell, 5e4, "red"),
    (minor_phase_unitcell, 1e4, "brown"),
):
    ax.plot(ucell.ringds, [tick_height] * len(ucell.ringds), '|', ms=90, c=colour)

ax.semilogy()

ax.set_xlabel("Dstar")
ax.set_ylabel("Intensity")
ax.legend()

plt.show()

In [None]:
# you should choose the rings that you want to refine off from the plot above
# these are ring indices (presumably positions in the refined phase's ring list — confirm).
# The refinement setup cell below defines `forref`, the ring list actually passed to
# PBPRefine; make sure it matches this selection.

rings_to_refine = [0, 2, 3, 4, 5, 6, 7, 8, 10, 12, 13]

In [None]:
# Import 2D peaks
# parameters for the minor phase (the phase refined below) are loaded onto cf_2d

cf_2d = ds.get_cf_2d()
ds.update_colfile_pars(cf_2d, phase_name=minor_phase_str)

print(f"Read {cf_2d.nrows} 2D peaks")

In [None]:
# import grainsinos
# read the minor phase's GrainSinogram objects from the grains file;
# `grains` keeps just the grain object of each one

grainsinos = read_h5(ds.grainsfile, ds, minor_phase_str)
grains = [gs.grain for gs in grainsinos]

In [None]:
# import slice reconstructions
# the minor phase's TensorMap lives in the grains file under the group 'TensorMap_<phase name>'

tensor_map = TensorMap.from_h5(ds.grainsfile, h5group='TensorMap_' + minor_phase_str)

In [None]:
# plot the TensorMap's 'phase_ids' field to check which voxels belong to which phase
tensor_map.plot('phase_ids')

In [None]:
# make a PBPMap from our TensorMap
# z_layer=0 selects the first z layer of the TensorMap (presumably a single slice here — confirm)

pmap = tensor_map.to_pbpmap(z_layer=0, default_npks=20, default_nuniq=20)
# fills voxels that have grains with npks = 20 and nuniq = 20

In [None]:
# pick the best solution per voxel and plot it
# (meaning of the argument 1 is not visible here — presumably a minimum-peaks threshold; confirm)
pmap.choose_best(1)
pmap.plot_best(1)

In [None]:
# set up a refinement manager object

# y0 is taken from the first grain's sinogram reconstruction
y0 = grainsinos[0].recon_y0
fpks = 0.9
hkl_tol_origins = 0.05
hkl_tol_refine = 0.1
hkl_tol_refine_merged = 0.05
ds_tol = 0.004
ifrac = 1e-3
# use the rings chosen from the intensity plot, rather than a separately
# hard-coded copy that silently ignores edits to rings_to_refine
forref = rings_to_refine
phase_str = minor_phase_str

refine = PBPRefine(dset=ds, y0=y0, fpks=fpks,
                   hkl_tol_origins=hkl_tol_origins,
                   hkl_tol_refine=hkl_tol_refine,
                   hkl_tol_refine_merged=hkl_tol_refine_merged,
                   ds_tol=ds_tol, ifrac=ifrac,
                   phase_name=phase_str, forref=forref)

In [None]:
# change the default paths of the refinement manager to append the phase name
# so we don't conflict
# '<stem><ext>' becomes '<stem>_<phase>.h5' (the original extension is dropped)

for attr in ('own_filename', 'icolf_filename', 'pbpmap_filename', 'refinedmap_filename'):
    stem, _ext = os.path.splitext(getattr(refine, attr))
    setattr(refine, attr, stem + f'_{phase_str}.h5')

In [None]:
# choose 2D peaks to refine with
# hand the minor-phase 2D peaks (cf_2d) to the refinement manager

refine.setpeaks(cf_2d)

# or load from disk:
# refine.loadpeaks()

In [None]:
# plot the peaks you selected
# (interactive plot of the peaks passed to refine.setpeaks above)

refine.iplot()

In [None]:
# tell it which point-by-point map we are refining
# (pmap was built from the TensorMap and had choose_best applied above)

refine.setmap(pmap)

# or load from disk:
# refine.loadmap()

In [None]:
# set the mask from minimum peak values
# anything greater than 0 should be accepted
# i.e. select every voxel whose best solution has at least one peak; since to_pbpmap
# filled grain voxels with npks = 20, presumably this is every grain voxel — confirm

refine.mask = pmap.best_npks > 0

In [None]:
# display the boolean mask just set (True = voxel selected)
fig, ax = plt.subplots()
ax.imshow(refine.mask, origin='lower')
plt.show()

In [None]:
# generate a single-valued map to refine on
# (minpeaks=1 presumably drops voxels with fewer than one peak — confirm)

refine.setsingle(refine.pbpmap, minpeaks=1)

In [None]:
# compute diffraction origins - these will be added as a column to refine.icolf
# will then save the new column to disk to avoid re-computation
# this runs locally and can be slow for large datasets

refine.get_origins()

In [None]:
# run the refinement
# if get_origins() in the previous cell took more than a couple of minutes to run, I suggest setting use_cluster=True below
# otherwise if you asked for lots of cores and RAM on this Jupyter instance, you can run it locally (use_cluster=False)
# use_cluster is also read by the save cell and the batch cell below

use_cluster = False

refine.run_refine(use_cluster=use_cluster, pythonpath=PYTHONPATH)

In [None]:
# save refinement results to disk
# only when running locally: the refinement has finished in-process by now.
# (With use_cluster=True, presumably the submitted jobs write their own output — confirm.)

if not use_cluster:
    refine.to_h5()

# persist the DataSet object
ds.save()

In [None]:
# deliberate stop so "Run all cells" does not continue into the batch cell below
if 1:
    raise ValueError("Change the 1 above to 0 to allow 'Run all cells' in the notebook")

In [None]:
# Now that we're happy with our refinement parameters, we can run the below cell to do this in bulk for many samples/datasets
# by default this will do all samples in sample_list, all datasets with a prefix of dset_prefix
# you can add samples and datasets to skip in skips_dict
#
# Refinement settings (fpks, hkl/ds tolerances, ifrac, forref, phase_str, use_cluster,
# PYTHONPATH) are reused from the cells above.

skips_dict = {
    "FeAu_0p5_tR_nscope": ["top_-50um", "top_-100um"]
}

dset_prefix = "top"

sample_list = ["FeAu_0p5_tR_nscope"]

samples_dict = utils.find_datasets_to_process(ds.dataroot, skips_dict, dset_prefix, sample_list)

# manual override:
# samples_dict = {"FeAu_0p5_tR_nscope": ["top_100um", "top_150um"]}

# `ds` is re-bound inside the loop, so capture the analysis root once up front
# rather than reading it from whichever dataset was loaded last
analysis_root = ds.analysisroot

# output paths get the phase name appended, exactly as for the interactive run above,
# so results for different phases don't overwrite each other
phase_suffixed_attrs = ('own_filename', 'icolf_filename', 'pbpmap_filename', 'refinedmap_filename')

# now we have our samples_dict, we can process our data:

for sample, datasets in samples_dict.items():
    for dataset in datasets:
        print(f"Processing dataset {dataset} in sample {sample}")
        dset_path = os.path.join(analysis_root, sample, f"{sample}_{dataset}", f"{sample}_{dataset}_dataset.h5")
        if not os.path.exists(dset_path):
            print(f"Missing DataSet file for {dataset} in sample {sample}, skipping")
            continue

        print("Importing DataSet object")

        ds = ImageD11.sinograms.dataset.load(dset_path)
        print(f"I have a DataSet {ds.dset} in sample {ds.sample}")

        # this workflow refines from the tomographic grain shapes (GrainSinograms and
        # TensorMap in the grains file), not from PBP indexing output
        if not os.path.exists(ds.grainsfile):
            print(f"Can't find grains file for {dataset} in sample {sample}, skipping")
            continue

        grainsinos = read_h5(ds.grainsfile, ds, phase_str)
        y0 = grainsinos[0].recon_y0

        refine = PBPRefine(dset=ds, y0=y0, fpks=fpks,
                           hkl_tol_origins=hkl_tol_origins,
                           hkl_tol_refine=hkl_tol_refine,
                           hkl_tol_refine_merged=hkl_tol_refine_merged,
                           ds_tol=ds_tol, ifrac=ifrac,
                           phase_name=phase_str, forref=forref)
        for attr in phase_suffixed_attrs:
            setattr(refine, attr, os.path.splitext(getattr(refine, attr))[0] + f'_{phase_str}.h5')

        # check for this phase's own output file, not the unsuffixed default
        if os.path.exists(refine.refinedmap_filename):
            print(f"Already have PBP refinement output file for {dataset} in sample {sample}, skipping")
            continue

        cf_2d = ds.get_cf_2d()
        ds.update_colfile_pars(cf_2d, phase_name=phase_str)

        if not os.path.exists(ds.col2dfile):
            ImageD11.columnfile.colfile_to_hdf(cf_2d, ds.col2dfile)

        tensor_map = TensorMap.from_h5(ds.grainsfile, h5group='TensorMap_' + phase_str)
        pmap = tensor_map.to_pbpmap(z_layer=0, default_npks=20, default_nuniq=20)
        pmap.choose_best(1)

        refine.setmap(pmap)
        refine.setpeaks(cf_2d)
        refine.mask = pmap.best_npks > 0
        refine.setsingle(refine.pbpmap, minpeaks=1)
        refine.get_origins()
        refine.run_refine(use_cluster=use_cluster, pythonpath=PYTHONPATH)
        if not use_cluster:
            # wait to complete locally, then save
            refine.to_h5()
        ds.save()

print("Done!")