# Jupyter notebook based on ImageD11 to process scanning 3DXRD data
# Written by Haixing Fang, Jon Wright and James Ball
## Date: 28/03/2024

In [None]:
# Bootstrap cell: run the beamline install helper script in this namespace.
# The next cell calls setup_ImageD11_from_git(), which is expected to be defined by this script.
# NOTE: exec is only acceptable here because the path is a fixed, trusted local file.
# The context manager closes the script's file handle deterministically instead of leaving it to GC.
with open('/data/id11/nanoscope/install_ImageD11_from_git.py') as _install_script:
    exec(_install_script.read())

In [None]:
# this cell is tagged with 'parameters'
# to view the tag, select the cell, then find the settings gear icon (right or left sidebar) and look for Cell Tags
# (presumably so a parameterising runner can override these values when executing the notebook in batch - confirm)

# python environment stuff
# setup_ImageD11_from_git is defined by the install script exec'd in the previous cell.
# The trailing comment shows example arguments for a custom checkout location
# (NOTE(review): assumed to be (parent directory, folder name) - confirm against the install script).
PYTHONPATH = setup_ImageD11_from_git( ) # ( os.path.join( os.environ['HOME'],'Code'), 'ImageD11_git' )

# dataset file to import
# Path to the saved DataSet .h5 file that is loaded by ImageD11.sinograms.dataset.load below.
# NOTE(review): this example points at a Si_cube dataset while phase_strs below names Fe/Au phases -
# set both to match your own sample.
dset_file = 'si_cube_test/processed/Si_cube/Si_cube_S3DXRD_nt_moves_dty/Si_cube_S3DXRD_nt_moves_dty_dataset.h5'

# Names of the phases whose TensorMaps are combined, in this order.
# Each must be a key of ds.phases.unitcells and match the per-phase TensorMap file / h5 group names.
phase_strs = ['Fe', 'Au']

# whether or not we are combining refined tensormaps (changes where we look for them)
# True:  read {sample}_{dset}_refined_tmap_{phase}.h5 files from ds.analysispath
# False: read the 'TensorMap_{phase}' groups from ds.grainsfile
combine_refined = True

dset_prefix = "top_"  # some common string in the names of the datasets if processing multiple scans

In [None]:
# import functions we need

import os
import concurrent.futures
import timeit

import matplotlib
%matplotlib widget

import h5py
from tqdm.notebook import tqdm
import numpy as np
import matplotlib.pyplot as plt

from ImageD11.sinograms.sinogram import GrainSinogram, build_slice_arrays, write_slice_recon, read_slice_recon, write_h5, read_h5
import ImageD11.columnfile
from ImageD11.sinograms import properties, roi_iradon
from ImageD11.sinograms.tensor_map import TensorMap
from ImageD11.blobcorrector import eiger_spatial
from ImageD11.grain import grain
from ImageD11 import cImageD11
from xfab.parameters import read_par_file

from ImageD11.nbGui import nb_utils as utils

In [None]:
# Load the DataSet from dset_file (set in the parameters cell at the top of the notebook)

ds = ImageD11.sinograms.dataset.load(dset_file)

# Names and directories of the loaded dataset, reused by the cells below
sample, dataset, rawdata_path, processed_data_root_dir = (
    ds.sample,
    ds.dsname,
    ds.dataroot,
    ds.analysisroot,
)

# Show a summary of the DataSet and its scan shape
print(ds)
print(ds.shape)

In [None]:
# load phases from parameter file
# Reads the phase definitions from disk and attaches them to the DataSet as ds.phases.
# The bare expression on the last line makes Jupyter display the loaded unit cells.
ds.phases = ds.get_phases_from_disk()
ds.phases.unitcells

In [None]:
# Which phases are being merged? Show the lattice parameters of each, one phase per line.

lattice_lines = [str(ds.phases.unitcells[phase_name].lattice_parameters) for phase_name in phase_strs]
print('\n'.join(lattice_lines))

In [None]:
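# choose where to import your TensorMaps from
# this is controlled by combine_refined in the parameters cell: True loads the separate per-phase refined TensorMap h5 files,
# False loads the per-phase 'TensorMap_<phase>' groups from the grains file (ds.grainsfile)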

In [None]:
def _load_phase_tensor_map(phase_str):
    """Load the TensorMap of one phase for the current dataset `ds`.

    When `combine_refined` is True, the phase's separate refined TensorMap file in
    ds.analysispath is read; otherwise the 'TensorMap_<phase>' group of ds.grainsfile.
    """
    if not combine_refined:
        return TensorMap.from_h5(ds.grainsfile, h5group='TensorMap_' + phase_str)
    refined_path = os.path.join(ds.analysispath, f'{ds.sample}_{ds.dset}_refined_tmap_{phase_str}.h5')
    return TensorMap.from_h5(refined_path)


# one TensorMap per phase, in the same order as phase_strs
tensor_maps = [_load_phase_tensor_map(phase_str) for phase_str in phase_strs]

In [None]:
# Plot the grain labels of each per-phase TensorMap.
# The try/except is per map, so a map without a 'labels' field is skipped
# without preventing the remaining phases from being plotted.
for tmap in tensor_maps:
    try:
        tmap.plot('labels')
    except KeyError:
        # no labels field in this phase's TensorMap - move on to the next phase
        pass

In [None]:
# Merge the per-phase TensorMaps (built in phase_strs order) into one multi-phase TensorMap.
# The cell below plots its 'phase_ids' field to show which phase each voxel was assigned to.
tensor_map_combined = TensorMap.from_combine_phases(tensor_maps)

In [None]:
# Plot the combined map: phase assignment, grain labels (when present) and IPF-Z colouring
for field_name in ('phase_ids', 'labels', 'ipf_z'):
    try:
        tensor_map_combined.plot(field_name)
    except KeyError:
        # only the 'labels' field is optional; a missing field of any other name is a real error
        if field_name != 'labels':
            raise

In [None]:
# Choose the output file name according to whether refined maps were combined
if combine_refined:
    combined_h5_path = os.path.join(ds.analysispath, f'{ds.sample}_{ds.dset}_refined_tmap_combined.h5')
else:
    combined_h5_path = os.path.join(ds.analysispath, f'{ds.sample}_{ds.dset}_tmap_combined.h5')

# Write the combined TensorMap, then pass the same .h5 path to the ParaView exporter
tensor_map_combined.to_h5(combined_h5_path)
tensor_map_combined.to_paraview(combined_h5_path)

In [None]:
# We can run the below cell to do this in bulk for many samples/datasets
# by default this will do all samples in sample_list, all datasets with a prefix of dset_prefix
# you can add samples and datasets to skip in skips_dict
#
# The loop uses its own variable names (batch_ds, sample_name, dset_name, batch_tmaps, ...)
# so running this cell does not overwrite the ds / sample / dataset / tensor_maps /
# tensor_map_combined variables that the single-dataset cells above work with.
# A dataset is skipped (with a message) when its DataSet file is missing, when its combined
# output already exists, or when any of the input TensorMap files it needs is missing.

# you can optionally skip datasets within a sample:
# skips_dict = {
#     "FeAu_0p5_tR_nscope": ["top_-50um", "top_-100um"]
# }
# otherwise by default skip nothing:
skips_dict = {
    ds.sample: []
}

sample_list = [ds.sample, ]

samples_dict = utils.find_datasets_to_process(rawdata_path, skips_dict, dset_prefix, sample_list)

# manual override:
# samples_dict = {"FeAu_0p5_tR_nscope": ["top_100um", "top_150um"]}

# now we have our samples_dict, we can process our data:

for sample_name, dset_names in samples_dict.items():
    for dset_name in dset_names:
        print(f"Processing dataset {dset_name} in sample {sample_name}")
        dset_path = os.path.join(ds.analysisroot, sample_name, f"{sample_name}_{dset_name}", f"{sample_name}_{dset_name}_dataset.h5")
        if not os.path.exists(dset_path):
            print(f"Missing DataSet file for {dset_name} in sample {sample_name}, skipping")
            continue

        print("Importing DataSet object")

        batch_ds = ImageD11.sinograms.dataset.load(dset_path)
        print(f"I have a DataSet {batch_ds.dset} in sample {batch_ds.sample}")

        # Work out the output file and the input file(s) this dataset needs
        if combine_refined:
            combined_tmap_path = os.path.join(batch_ds.analysispath, f'{batch_ds.sample}_{batch_ds.dset}_refined_tmap_combined.h5')
            input_paths = [os.path.join(batch_ds.analysispath, f'{batch_ds.sample}_{batch_ds.dset}_refined_tmap_{phase_str}.h5') for phase_str in phase_strs]
        else:
            combined_tmap_path = os.path.join(batch_ds.analysispath, f'{batch_ds.sample}_{batch_ds.dset}_tmap_combined.h5')
            input_paths = [batch_ds.grainsfile]

        if os.path.exists(combined_tmap_path):
            print(f"Already have combined TensorMap output file for {dset_name} in sample {sample_name}, skipping")
            continue

        # Skip (rather than crash the whole batch) when e.g. a phase was not refined for this dataset
        missing_paths = [path for path in input_paths if not os.path.exists(path)]
        if missing_paths:
            print(f"Missing input TensorMap file(s) {missing_paths} for {dset_name} in sample {sample_name}, skipping")
            continue

        if combine_refined:
            batch_tmaps = [TensorMap.from_h5(path) for path in input_paths]
        else:
            batch_tmaps = [TensorMap.from_h5(batch_ds.grainsfile, h5group='TensorMap_' + phase_str) for phase_str in phase_strs]
        batch_tmap_combined = TensorMap.from_combine_phases(batch_tmaps)

        batch_tmap_combined.to_h5(combined_tmap_path)
        batch_tmap_combined.to_paraview(combined_tmap_path)

        batch_ds.save()

print("Done!")