# Jupyter notebook based on ImageD11 to process scanning 3DXRD data
# Written by Haixing Fang, Jon Wright and James Ball
## Date: 12/10/2024

This notebook performs a point-by-point (pbp) strain refinement of your pbp indexing results.  
As with the pbp indexing, the results of this process are multi-valued.  
You can run 4_visualise to convert the refinement results to an accurate single-valued map.  

### NOTE: It is highly recommended to run this notebook on a Jupyter server with many cores and a lot of RAM.  
The origin computation (`refine.get_origins()` below) in particular runs locally and can be compute-intensive for large datasets.  
If this is a big scan (e.g. 100 million+ 2D peaks), you should definitely refine on the cluster rather than locally.

In [None]:
import os

# Limit the OpenMP / OpenBLAS / MKL thread pools to one thread each.
# These are set here, before the numerical libraries are imported in the
# next cell (presumably to avoid oversubscribing cores - TODO confirm).
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['OPENBLAS_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'

# Execute the shared installer script in this notebook's namespace.
# NOTE(review): exec() runs arbitrary code from this path - only use a
# trusted, facility-managed copy of the script.
# The file is opened in a `with` block so the handle is closed promptly.
with open('/data/id11/nanoscope/install_ImageD11_from_git.py') as installer_script:
    exec(installer_script.read())
# setup_ImageD11_from_git is not defined in this notebook; presumably it is
# provided by the installer script executed above. Its return value is kept
# as PYTHONPATH and passed to refine.run_refine() later on.
PYTHONPATH = setup_ImageD11_from_git( ) # ( os.path.join( os.environ['HOME'],'Code'), 'ImageD11_git' )

In [None]:
import numba
import numpy as np
import scipy.ndimage as ndi
from matplotlib import pyplot as plt
import matplotlib.cm as cm
from matplotlib.colors import Normalize

from ImageD11.grain import grain
from ImageD11 import unitcell
import ImageD11.sinograms.dataset
from ImageD11.sinograms.point_by_point import PBPMap, nb_inv, PBPRefine
from ImageD11.sinograms.tensor_map import TensorMap
from ImageD11.nbGui import nb_utils as utils

%matplotlib ipympl

In [None]:
# USER: Pass path to dataset file (the *_dataset.h5 file for the scan to refine)

dset_file = 'si_cube_test/processed/Si_cube/Si_cube_S3DXRD_nt_moves_dty/Si_cube_S3DXRD_nt_moves_dty_dataset.h5'

# load the DataSet object from the file above
ds = ImageD11.sinograms.dataset.load(dset_file)
   
# convenience aliases for attributes of the loaded dataset
sample = ds.sample
dataset = ds.dsname
rawdata_path = ds.dataroot
processed_data_root_dir = ds.analysisroot

# show a summary of the dataset and its shape as a sanity check
print(ds)
print(ds.shape)

In [None]:
# load phases from parameter file
# the result is stored on the dataset as ds.phases; ds.phases.unitcells is
# used in the next cell to look up the reference unit cell

ds.phases = ds.get_phases_from_disk()

In [None]:
# now let's select the phase to refine from our parameters json
# USER: phase_str must be a key of ds.phases.unitcells
# NOTE(review): the example dataset path in this notebook is a Si cube while
# the phase here is 'Fe' - confirm phase_str matches your sample
phase_str = 'Fe'

# reference unit cell for the chosen phase
ref_ucell = ds.phases.unitcells[phase_str]

# print the lattice parameters and spacegroup as a sanity check
print(ref_ucell.lattice_parameters, ref_ucell.spacegroup)

In [None]:
# let's load our point-by-point map from disk
# ds.pbpfile is the path held by the dataset for the point-by-point index
# output (presumably written by the pbp indexing notebook - TODO confirm)

pmap = PBPMap(ds.pbpfile)

In [None]:
# plot a histogram of unique peaks per ubi
# use it to help pick the min_unique cutoff in the next cell

pmap.plot_nuniq_hist()

In [None]:
# choose the minimum number of peaks you want a pixel to have to be counted
# USER: min_unique is reused later for plot_best, setsingle and the bulk cell

min_unique = 20

# apply the cutoff to the point-by-point map
pmap.choose_best(min_unique)

In [None]:
# let's plot the result of your choice of min_unique
# if it does not look right, change min_unique above and re-run choose_best

pmap.plot_best(min_unique)

In [None]:
# Get the 2D peaks from disk

cf_2d = ds.get_cf_2d()
# update the columnfile parameters from the dataset for the chosen phase
ds.update_colfile_pars(cf_2d, phase_name=phase_str)

# number of rows in the columnfile (presumably one row per 2D peak)
print(cf_2d.nrows)

In [None]:
# set up a refinement manager object
# USER: these settings are passed by keyword to PBPRefine below and are
# reused unchanged by the bulk-processing cell at the end of the notebook.
# See the ImageD11 PBPRefine documentation for the meaning of each value.

y0 = 0.0
fpks = 0.95
ds_tol = 0.006
ifrac = 6e-3

# hkl tolerances for the different stages
hkl_tol_origins = 0.05
hkl_tol_refine = 0.1
hkl_tol_refine_merged = 0.05

refine = PBPRefine(
    dset=ds,
    y0=y0,
    fpks=fpks,
    hkl_tol_origins=hkl_tol_origins,
    hkl_tol_refine=hkl_tol_refine,
    hkl_tol_refine_merged=hkl_tol_refine_merged,
    ds_tol=ds_tol,
    ifrac=ifrac,
    phase_name=phase_str,
)

In [None]:
# tell the refinement manager which point-by-point map we are refining
# (pmap is the map loaded and filtered with choose_best above)

refine.setmap(pmap)

# or load from disk:
# refine.loadmap()

In [None]:
# choose 2D peaks to refine with
# (cf_2d is the 2D peaks columnfile loaded from the dataset above)

refine.setpeaks(cf_2d)

# or load from disk:
# refine.loadpeaks()

In [None]:
# plot the peaks you selected, to check the selection before refining

refine.iplot()

In [None]:
# set whole-sample mask to choose where to refine
# USER: manual_threshold is passed straight to refine.setmask; it is left as
# None here (presumably this lets setmask pick a threshold itself - TODO confirm)

manual_threshold = None

# doplot=True shows the resulting mask
refine.setmask(manual_threshold=manual_threshold, doplot=True)

In [None]:
# generate a single-valued map to refine on
# uses the map held by the refinement manager and the same min_unique cutoff as above

refine.setsingle(refine.pbpmap, minpeaks=min_unique)

In [None]:
# compute diffraction origins - these will be added as a column to refine.icolf
# the new column will then be saved to disk to avoid re-computation
# NOTE: this step runs locally and can be slow for large datasets

refine.get_origins()

In [None]:
# run the refinement
# if the origin computation above (refine.get_origins()) took more than a couple of minutes to run, I suggest setting use_cluster=True below
# otherwise if you asked for lots of cores and RAM on this Jupyter instance, you can run it locally (use_cluster=False)

use_cluster = False

# PYTHONPATH is the value returned by setup_ImageD11_from_git() in the first cell
refine.run_refine(use_cluster=use_cluster, pythonpath=PYTHONPATH)

In [None]:
# make sure refinement is saved to disk

refine.to_h5()

# also save the dataset object itself
ds.save()

In [None]:
# Deliberate stop: with "Run all cells" this raises before the bulk-processing
# cell below is reached. Change the 1 to 0 to let execution continue.
if 1:
    raise ValueError("Change the 1 above to 0 to allow 'Run all cells' in the notebook")

In [None]:
# We can run the below cell to do this in bulk for many samples/datasets
# by default this will do all samples in sample_list, all datasets with a prefix of dset_prefix
# you can add samples and datasets to skip in skips_dict
#
# The settings defined earlier in the notebook are reused unchanged for every
# dataset: phase_str, min_unique, the PBPRefine tolerances, manual_threshold,
# use_cluster and PYTHONPATH.

# ImageD11.columnfile is used below; import it explicitly rather than relying
# on it having been imported as a side effect of another ImageD11 module.
import ImageD11.columnfile

skips_dict = {
    "FeAu_0p5_tR_nscope": ["top_-50um", "top_-100um"]
}

dset_prefix = "top"

sample_list = ["FeAu_0p5_tR_nscope"]

# Capture the root directories once, from the dataset loaded at the top of the
# notebook: `ds` is rebound to each new dataset inside the loop below, so the
# loop must not depend on it for building paths.
raw_root = ds.dataroot
analysis_root = ds.analysisroot

samples_dict = utils.find_datasets_to_process(raw_root, skips_dict, dset_prefix, sample_list)

# manual override:
# samples_dict = {"FeAu_0p5_tR_nscope": ["top_100um", "top_150um"]}

# now we have our samples_dict, we can process our data:

# NOTE: the loop rebinds the notebook-level names sample, dataset, ds, cf_2d,
# pmap and refine; after it finishes they refer to the last dataset processed.
for sample, datasets in samples_dict.items():
    for dataset in datasets:
        print(f"Processing dataset {dataset} in sample {sample}")
        dset_path = os.path.join(analysis_root, sample, f"{sample}_{dataset}", f"{sample}_{dataset}_dataset.h5")
        if not os.path.exists(dset_path):
            print(f"Missing DataSet file for {dataset} in sample {sample}, skipping")
            continue

        print("Importing DataSet object")

        ds = ImageD11.sinograms.dataset.load(dset_path)
        print(f"I have a DataSet {ds.dset} in sample {ds.sample}")
        # skip datasets that have already been refined
        if os.path.exists(ds.refoutfile):
            print(f"Already have PBP refinement output file for {dataset} in sample {sample}, skipping")
            continue

        # the refinement needs the point-by-point indexing output as its input
        if not os.path.exists(ds.pbpfile):
            print(f"Can't find PBP indexing file for {dataset} in sample {sample}, skipping")
            continue

        cf_2d = ds.get_cf_2d()
        ds.update_colfile_pars(cf_2d, phase_name=phase_str)

        # write the 2D peaks columnfile to disk if it is not there yet
        if not os.path.exists(ds.col2dfile):
            ImageD11.columnfile.colfile_to_hdf(cf_2d, ds.col2dfile)

        pmap = PBPMap(ds.pbpfile)
        pmap.choose_best(min_unique)

        refine = PBPRefine(dset=ds, y0=y0, fpks=fpks, hkl_tol_origins=hkl_tol_origins, hkl_tol_refine=hkl_tol_refine, hkl_tol_refine_merged=hkl_tol_refine_merged, ds_tol=ds_tol, ifrac=ifrac, phase_name=phase_str)

        # same sequence as the interactive cells above, but without plotting
        refine.setmap(pmap)
        refine.setpeaks(cf_2d)
        refine.setmask(manual_threshold=manual_threshold, doplot=False)
        refine.setsingle(refine.pbpmap, minpeaks=min_unique)
        refine.get_origins()
        refine.run_refine(use_cluster=use_cluster, pythonpath=PYTHONPATH)
        if not use_cluster:
            # wait to complete locally, then save
            refine.to_h5()
        ds.save()

print("Done!")