# Jupyter notebook based on ImageD11 to process 3DXRD data
# Written by Haixing Fang, Jon Wright and James Ball
## Date: 27/02/2024

This notebook will help you to extract the locations of diffraction peaks on your detector images.

It will also merge together your 2D spots (on a stack of detector images with different omega angles).

We merge across omega because the same diffraction spot is often recorded on several consecutive detector images.

The results are saved to the PROCESSED_DATA folder of the experiment, inside the sample and dataset folders that you select within this notebook.

## NOTE: These notebooks are under active development
They require the latest version of ImageD11 from Git to run.

If you don't have this set up yet, you can run the cell below.

It will automatically download (git clone) and compile ImageD11 in your home directory (by default in `~/Code/ImageD11`).

In [None]:
import os

home_dir = !echo $HOME
home_dir = str(home_dir[0])

# USER: You can change this location if you want

id11_code_path = os.path.join(home_dir, "Code/ImageD11")

# check whether we already have ImageD11 here

if os.path.exists(id11_code_path):
    raise FileExistsError("ImageD11 already present! Giving up")

!git clone https://github.com/FABLE-3DXRD/ImageD11 {id11_code_path}
output = !cd {id11_code_path} && python setup.py build_ext --inplace

if not os.path.exists(os.path.join(id11_code_path, "build")):
    raise FileNotFoundError(f"Can't find build folder in {id11_code_path}, compilation went wrong somewhere")

import sys

sys.path.insert(0, id11_code_path)

# if this works, we installed ImageD11 properly!
try:
    import ImageD11.cImageD11
except:
    raise FileNotFoundError("Couldn't import cImageD11, there's a problem with your Git install!")

In [None]:
# USER: Change the path below to point to your local copy of ImageD11:

import os
import sys

# Resolve the home directory in Python rather than via a shell round-trip.
home_dir = os.path.expanduser("~")

# USER: You can change this location if you want

id11_code_path = os.path.join(home_dir, "Code/ImageD11")

# Put the local checkout first on the import path so it shadows any ImageD11
# installed in the kernel's environment. The guard keeps re-runs of this cell
# from stacking duplicate entries at the front of sys.path.
if sys.path[:1] != [id11_code_path]:
    sys.path.insert(0, id11_code_path)

In [None]:
# import functions we need

import glob, pprint

import ImageD11.sinograms.dataset
import ImageD11.sinograms.lima_segmenter
import ImageD11.sinograms.assemble_label
import ImageD11.sinograms.properties

import numpy as np
import fabio
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from skimage import filters, measure, morphology
import ipywidgets as widgets
import h5py
from IPython.display import display
%matplotlib widget

from ImageD11.nbGui import nb_utils as utils

from frelon_peaksearch import worker, process

# from utils import apply_spatial

In [None]:
# Check that we're importing ImageD11 from the home directory rather than from the Jupyter kernel
# (prints the module's file path and warns if it is not inside id11_code_path).

dataset_module_file = os.path.abspath(ImageD11.sinograms.dataset.__file__)
print(f"ImageD11.sinograms.dataset imported from: {dataset_module_file}")
if not dataset_module_file.startswith(os.path.abspath(id11_code_path) + os.sep):
    print(f"WARNING: expected ImageD11 to be imported from {id11_code_path}")

In [None]:
# NOTE: For old datasets before the new directory layout structure, we don't distinguish between RAW_DATA and PROCESSED_DATA

### USER: specify your experimental directory

rawdata_path = "/data/visitor/ihma439/id11/20231211/RAW_DATA"

!ls -lrt {rawdata_path}

### USER: specify where you want your processed data to go

processed_data_root_dir = "/data/visitor/ihma439/id11/20231211/PROCESSED_DATA/James/nb_testing"# NOTE: For old datasets before the new directory layout structure, we don't distinguish between RAW_DATA and PROCESSED_DATA

### USER: specify your experimental directory

rawdata_path = "/data/visitor/ihma439/id11/20231211/RAW_DATA"

!ls -lrt {rawdata_path}

### USER: specify where you want your processed data to go

processed_data_root_dir = "/data/visitor/ihma439/id11/20231211/PROCESSED_DATA/James/nb_testing"

In [None]:
# USER: choose which sample and dataset to segment,
# and where the detector spline file lives (used for the spatial correction later on).

sample, dataset = "FeAu_0p5_tR", "ff1"

spline_file = '/data/id11/inhouse1/ewoks/detectors/files/Frelon2k_C36/frelon36.spline'

In [None]:
# Describe the dataset for ImageD11: where the raw data is, where results go,
# and which detector / motors to read. Then import scan "1.1" and save the DataSet.
dataset_kwargs = dict(
    dataroot=rawdata_path,
    analysisroot=processed_data_root_dir,
    sample=sample,
    dset=dataset,
    detector="frelon3",
    omegamotor="diffrz",
    dtymotor="diffty",
)
ds = ImageD11.sinograms.dataset.DataSet(**dataset_kwargs)
ds.import_all(scans=["1.1"])
ds.save()

In [None]:
# USER: background image and detector mask used by the peak search.
# NOTE(review): bg_file points into a personal home directory — change it for your own analysis.
bg_file, maskfile = (
    "/home/esrf/james1997a/Data/ihma439/id11/20231211/PROCESSED_DATA/FeAu_0p5_tR/tdxrd_all/ff_bkg.edf",
    '/data/id11/inhouse1/ewoks/detectors/files/Frelon2k_C36/mask.edf',
)

In [None]:
# Store the correction files on the DataSet; the worker arguments below read them back from ds.
ds.bgfile = bg_file
ds.maskfile = maskfile
ds.splinefile = spline_file

In [None]:
# Initial peak-search settings: they are used for the first search on the test image
# and as the starting positions of the interactive sliders.
start_worker_args = dict(
    bgfile=ds.bgfile,
    maskfile=ds.maskfile,
    threshold=50,
    smoothsigma=1.0,
    bgc=0.9,
    minpx=3,
    m_offset_thresh=80,
    m_ratio_thresh=135,
)

In [None]:
# Read one image (index 0) of scan "1.1" from the master file to tune the peak search on.
with h5py.File(ds.masterfile, 'r') as h5In:
    test_image = h5In['1.1/measurement/frelon3'][0].astype('uint16')

# Display the image initially
# Panels: raw image | worker.smoothed | worker.smoothed with the found peaks overlaid.
# Axes are shared, so zooming one panel zooms all three.
fig, axs = plt.subplots(1, 3, sharex=True, sharey=True, figsize=(16, 5))
# Run the peak search once with the starting parameters.
test_image_worker = worker(**start_worker_args)
goodpeaks = test_image_worker.peaksearch(img=test_image, omega=0)
fc, sc = goodpeaks[:, 23:25].T  # 23 and 24 are the columns for fc and sc from blob properties

im1 = axs[0].imshow(test_image, norm=LogNorm(vmax=1000))
axs[0].set_title("Original image")
# NOTE(review): titled "Background corrected" but shows worker.smoothed — confirm what .smoothed holds.
im2 = axs[1].imshow(test_image_worker.smoothed, cmap="viridis", norm=LogNorm(vmax=1000), interpolation="nearest")
axs[1].set_title("Background corrected")
# im2, im3 and sc1 are kept as globals so update() can refresh them in place.
im3 = axs[2].imshow(test_image_worker.smoothed, cmap="viridis", norm=LogNorm(vmax=1000), interpolation="nearest")
axs[2].set_title(f"{len(fc)} peaks")
sc1, = axs[2].plot(fc, sc, marker='+', c="r", ls="")
axs[2].set_aspect(1)
plt.show()

# One slider per peak-search parameter, each starting at its start_worker_args value.
# NOTE(review): smsig_slider's max (1.0) equals the starting smoothsigma, so it can only be lowered.
thresh_slider = widgets.IntSlider(value=start_worker_args["threshold"], min=1, max=100, step=1, description='Threshold:')
smsig_slider = widgets.FloatSlider(value=start_worker_args["smoothsigma"], min=0.0, max=1.0, step=0.05, description='Smoothsigma:')
bgc_slider = widgets.FloatSlider(value=start_worker_args["bgc"], min=0.0, max=1.0, step=0.05, description='bgc:')
minpx_slider = widgets.IntSlider(value=start_worker_args["minpx"], min=1, max=5, step=1, description='minpx:')
mofft_slider = widgets.IntSlider(value=start_worker_args["m_offset_thresh"], min=1, max=200, step=1, description='m_offset_thresh:')
mratt_slider = widgets.IntSlider(value=start_worker_args["m_ratio_thresh"], min=1, max=200, step=1, description='m_ratio_thresh:')


def update(threshold, smoothsigma, bgc, minpx, m_offset_thresh, m_ratio_thresh):
    """Re-run the peak search on ``test_image`` with new parameters and refresh the plot.

    Called by ``widgets.interactive`` whenever a slider moves. Each argument is
    forwarded to ``worker`` under the same keyword used in ``start_worker_args``.
    Updates the two smoothed-image panels, the peak overlay and the peak-count title.
    """
    # Pass everything by keyword, as the initial worker(**start_worker_args) call does,
    # so this call does not silently depend on worker's positional parameter order.
    image_worker = worker(bgfile=ds.bgfile,
                          maskfile=ds.maskfile,
                          threshold=threshold,
                          smoothsigma=smoothsigma,
                          bgc=bgc,
                          minpx=minpx,
                          m_offset_thresh=m_offset_thresh,
                          m_ratio_thresh=m_ratio_thresh)
    goodpeaks = image_worker.peaksearch(img=test_image, omega=0)
    fc, sc = goodpeaks[:, 23:25].T  # columns 23 and 24 hold fc and sc
    im2.set_data(image_worker.smoothed)
    im3.set_data(image_worker.smoothed)
    sc1.set_data(fc, sc)
    axs[2].set_title(f"{len(fc)} peaks")
    # Redraw the figure that owns these artists. plt.draw() redraws whichever figure
    # is *current*, which is the wrong one once a later cell has created a new figure.
    im3.figure.canvas.draw_idle()

# Map each update() argument to the slider that drives it; widgets.interactive
# calls update() with the current slider values whenever one of them changes.
slider_controls = {
    "threshold": thresh_slider,
    "smoothsigma": smsig_slider,
    "bgc": bgc_slider,
    "minpx": minpx_slider,
    "m_offset_thresh": mofft_slider,
    "m_ratio_thresh": mratt_slider,
}
interactive_plot = widgets.interactive(update, **slider_controls)

display(interactive_plot)

In [None]:
# Freeze the parameters currently selected on the sliders for the full segmentation run.
end_worker_args = {"bgfile": ds.bgfile, "maskfile": ds.maskfile}
end_worker_args.update(
    threshold=thresh_slider.value,
    smoothsigma=smsig_slider.value,
    bgc=bgc_slider.value,
    minpx=minpx_slider.value,
    m_offset_thresh=mofft_slider.value,
    m_ratio_thresh=mratt_slider.value,
)

In [None]:
# Show the chosen peak-search parameters (rich display as the cell's last expression).
end_worker_args

In [None]:
# now we run the segmenter on all our data

# Count the CPUs this process may use. os.sched_getaffinity is Linux-only,
# so fall back to the total CPU count elsewhere.
try:
    nthreads = len(os.sched_getaffinity(os.getpid()))
except AttributeError:
    nthreads = os.cpu_count() or 1

# Keep one CPU back as before, but never pass fewer than one
# (a single-CPU allocation would otherwise give 0).
# Presumably the second argument is the number of worker processes — confirm.
cf_2d, cf_3d = process(ds, max(1, nthreads - 1), end_worker_args)

In [None]:
# we can use this to verify that the 3D merging is behaving as expected
# don't worry about this too much!

# Rank the 3D peaks by how many 2D peaks were merged into each, pick the one at
# position `peak_rank` (0 = the most 2D contributions), and plot it on top of
# its contributing 2D peaks.

peak_rank = 500  # USER: which 3D peak to inspect, by rank

unique, counts = np.unique(cf_2d.spot3d_id, return_counts=True)
hits_dict = dict(zip(unique, counts))
hits_dict_max = sorted(hits_dict.items(), key=lambda x: x[1], reverse=True)

# Clamp the rank so a dataset with fewer 3D peaks still shows one instead of an empty plot.
peak_rank = max(0, min(peak_rank, len(hits_dict_max) - 1))

m = np.isin(cf_3d.index, [spot3d_id for spot3d_id, _ in hits_dict_max[peak_rank:peak_rank + 1]])
cf_3d_single_peak = cf_3d.copy()
cf_3d_single_peak.filter(m)

peak_2d_mask = np.isin(cf_2d.spot3d_id, cf_3d_single_peak.index)
cf_2d_peaks = cf_2d.copy()
cf_2d_peaks.filter(peak_2d_mask)

# Put both scatters on one omega colour scale. Separate scatter() calls normalise
# their colours independently, so otherwise the 3D marker's colour would not match
# the colorbar.
# NOTE(review): assumes cf_3d.omega and cf_2d.o_raw are in the same units — confirm.
omega_values = np.concatenate([np.ravel(cf_2d_peaks.o_raw), np.ravel(cf_3d_single_peak.omega)])
omega_range = {"vmin": omega_values.min(), "vmax": omega_values.max()} if omega_values.size else {}

fig, ax = plt.subplots()
ax.scatter(cf_3d_single_peak.f_raw, cf_3d_single_peak.s_raw, marker="X", c=cf_3d_single_peak.omega, s=50, label='Merged 3D peak', **omega_range)
cols = ax.scatter(cf_2d_peaks.f_raw, cf_2d_peaks.s_raw, c=cf_2d_peaks.o_raw, s=cf_2d_peaks.s_I / 1000, label='Contributory 2D peaks', **omega_range)
fig.colorbar(cols)
ax.set_xlim(0, 2048)
ax.set_ylim(0, 2048)
ax.invert_yaxis()
ax.legend()
ax.set_title("Color is omega of peak. Scaled by sum intensity")
ax.set_xlabel("f_raw")
ax.set_ylabel("s_raw")
plt.show()

In [None]:
# Spatially correct the 2D peaks using the detector spline file.
# NOTE(review): re-running this cell feeds the already-corrected table back in — confirm apply_spatial_lut is safe to repeat.
cf_2d = utils.apply_spatial_lut(cf_2d, spline_file)

In [None]:
# Spatially correct the merged 3D peaks using the same detector spline file.
# NOTE(review): re-running this cell feeds the already-corrected table back in — confirm apply_spatial_lut is safe to repeat.
cf_3d = utils.apply_spatial_lut(cf_3d, spline_file)

In [None]:
# USER: parameter file loaded into the 2D/3D peak tables (via loadparameters) before updateGeometry().
# NOTE(review): absolute path into a personal home directory — change it for your own analysis.
parfile = '/home/esrf/james1997a/Data/ihma439/id11/20231211/SCRIPTS/James/3DXRD/Fe_tdxrd_refined.par'

In [None]:
# Load the parameter file into the 2D peak table, update its geometry
# (presumably recomputing the geometry-dependent columns — confirm), and write it
# to the dataset's 2D peaks file.
# NOTE(review): ImageD11.columnfile is not imported explicitly in the imports cell;
# it is presumably made available by the ImageD11.sinograms imports.
cf_2d.parameters.loadparameters(parfile)

cf_2d.updateGeometry()
ImageD11.columnfile.colfile_to_hdf(cf_2d, ds.col2dfile)

In [None]:
# Same as for the 2D peaks: load the parameters, update the geometry and write the
# merged 3D peak table to the dataset's 3D peaks file.
cf_3d.parameters.loadparameters(parfile)
cf_3d.updateGeometry()
ImageD11.columnfile.colfile_to_hdf(cf_3d, ds.col3dfile)

In [None]:
# Record which parameter file was used on the DataSet, then save the DataSet again.
ds.parfile = parfile
ds.save()

In [None]:
# Guard that halts "Run All" here, before the bulk-processing cell below.
# USER: set STOP_BEFORE_BULK = False to allow all cells to be run automatically.
STOP_BEFORE_BULK = True
if STOP_BEFORE_BULK:
    raise ValueError("Stopping before bulk processing: set STOP_BEFORE_BULK = False in this cell to continue.")

In [None]:
# Now that we're happy with our peak-search (segmentation) parameters, this cell repeats the processing in bulk for many samples/datasets
# by default this will do all samples in sample_list, all datasets with a prefix of dset_prefix
# you can add samples and datasets to skip in skips_dict
# NOTE: this reassigns the notebook-level ds, cf_2d and cf_3d to the last processed dataset.

skips_dict = {
    "FeAu_0p5_tR": []
}

dset_prefix = "ff"

sample_list = ["FeAu_0p5_tR"]

samples_dict = utils.find_datasets_to_process(rawdata_path, skips_dict, dset_prefix, sample_list)

# manual override:
# samples_dict = {"FeAu_0p5_tR_nscope": ["top_100um", "top_200um"]}

worker_args = end_worker_args

# Count the CPUs this process may use (os.sched_getaffinity is Linux-only, so fall
# back to the total CPU count elsewhere). Keep one back, but never pass fewer than one.
try:
    nthreads = len(os.sched_getaffinity(os.getpid()))
except AttributeError:
    nthreads = os.cpu_count() or 1
n_workers = max(1, nthreads - 1)

for sample, datasets in samples_dict.items():
    for dataset in datasets:
        print(f"Processing dataset {dataset} in sample {sample}")
        print("Importing DataSet object")
        ds = ImageD11.sinograms.dataset.DataSet(dataroot=rawdata_path,
                                                analysisroot=processed_data_root_dir,
                                                sample=sample,
                                                dset=dataset,
                                                detector="frelon3",
                                                omegamotor="diffrz",
                                                dtymotor="diffty")

        # Already-processed datasets are skipped, so the cell can be re-run to resume.
        if os.path.exists(ds.col2dfile):
            print(f"Found existing cf_2d for {dataset} in {sample}, skipping")
            continue

        ds.import_all(scans=["1.1"])
        print(f"I have a DataSet {ds.dset} in sample {ds.sample}")
        ds.save()

        ds.splinefile = spline_file
        ds.maskfile = maskfile
        ds.bgfile = bg_file

        print("Peaksearching")
        cf_2d, cf_3d = process(ds, n_workers, worker_args)

        print("Spatially correcting peaks")
        cf_2d = utils.apply_spatial_lut(cf_2d, spline_file)
        cf_3d = utils.apply_spatial_lut(cf_3d, spline_file)

        print("Saving peaks to file")
        cf_2d.parameters.loadparameters(parfile)
        cf_2d.updateGeometry()
        ImageD11.columnfile.colfile_to_hdf(cf_2d, ds.col2dfile)

        cf_3d.parameters.loadparameters(parfile)
        cf_3d.updateGeometry()
        ImageD11.columnfile.colfile_to_hdf(cf_3d, ds.col3dfile)

        ds.parfile = parfile
        ds.save()