# Jupyter notebook based on ImageD11 to process scanning 3DXRD data
# Written by Haixing Fang, Jon Wright and James Ball
## Date: 26/02/2024

## NOTE: These notebooks are under active development
They require the latest version of ImageD11 from Git to run.

If you don't have this set up yet, you can run the cell below.

It will automatically download ImageD11 from GitHub and build it in `Code/ImageD11` under your home directory.

In [None]:
import os

home_dir = !echo $HOME
home_dir = str(home_dir[0])

# USER: You can change this location if you want

id11_code_path = os.path.join(home_dir, "Code/ImageD11")

# check whether we already have ImageD11 here

if os.path.exists(id11_code_path):
    raise FileExistsError("ImageD11 already present! Giving up")

!git clone https://github.com/FABLE-3DXRD/ImageD11 {id11_code_path}
output = !cd {id11_code_path} && python setup.py build_ext --inplace

if not os.path.exists(os.path.join(id11_code_path, "build")):
    raise FileNotFoundError(f"Can't find build folder in {id11_code_path}, compilation went wrong somewhere")

import sys

sys.path.insert(0, id11_code_path)

# if this works, we installed ImageD11 properly!
try:
    import ImageD11.cImageD11
except:
    raise FileNotFoundError("Couldn't import cImageD11, there's a problem with your Git install!")

In [None]:
# USER: Change the path below to point to your local copy of ImageD11:

import os

home_dir = !echo $HOME
home_dir = str(home_dir[0])

# USER: You can change this location if you want

id11_code_path = os.path.join(home_dir, "Code/ImageD11")

import sys

sys.path.insert(0, id11_code_path)

In [None]:
# import functions we need

import glob, pprint
import fabio
import time
import shutil

import ImageD11.sinograms.dataset
import ImageD11.sinograms.lima_segmenter
import ImageD11.sinograms.assemble_label
import ImageD11.sinograms.properties
import ImageD11.nbGui.nb_utils as utils

import numpy as np
import fabio
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from skimage import filters, measure, morphology
from ipywidgets import interact, interactive, widgets, fixed, Layout
import h5py

%matplotlib ipympl

In [None]:
# Check that we're importing ImageD11 from the home directory rather than from the Jupyter kernel
# (the IPython "?" help output lists the file the module was loaded from;
# it should sit under id11_code_path)

?ImageD11.sinograms.dataset

In [None]:
# NOTE: For old datasets before the new directory layout structure, we don't distinguish between RAW_DATA and PROCESSED_DATA

### USER: specify your experimental directory
# (used as "dataroot" when the DataSet object is created)

rawdata_path = "/home/esrf/james1997a/Data/ihma439/id11/20231211/RAW_DATA"

# list the directory, sorted by modification time with the newest entry last,
# as a quick check that the path is correct
!ls -lrt {rawdata_path}

### USER: specify where you want your processed data to go
# (used as "analysisroot" when the DataSet object is created)

processed_data_root_dir = "/home/esrf/james1997a/Data/ihma439/id11/20231211/PROCESSED_DATA/James/20240304"

In [None]:
# USER: pick a sample and a dataset you want to segment
# (these are passed as "sample" and "dset" when the DataSet object is created)

sample = "FeAu_0p5_tR_nscope"
dataset = "top_100um"

# USER: specify path to detector mask
# Known mask files are listed below; the active choice is the uncommented
# assignment at the end of this cell.

# mask_path = '/data/id11/inhouse1/ewoks/detectors/files/eiger_E-08-0173/mask_with_gaps_E-08-0173.edf'  # temporary eiger mask (Nov 2023)
# mask_path = '/data/id11/inhouse1/ewoks/detectors/files/eiger_E-08-0144/mask.edf'  # normal eiger mask

mask_path = '/data/id11/inhouse1/ewoks/detectors/files/eiger_E-08-0173/mask_with_gaps_E-08-0173.edf'  # temporary eiger mask (Nov 2023)

In [None]:
# create ImageD11 dataset object from the paths and names chosen above

ds = ImageD11.sinograms.dataset.DataSet(
    dataroot=rawdata_path,
    analysisroot=processed_data_root_dir,
    sample=sample,
    dset=dataset,
)
# NOTE(review): import_all() presumably reads the scan layout from the raw
# data and save() presumably writes ds.dsfile, which later cells pass to
# the segmenter - confirm against ImageD11.sinograms.dataset
ds.import_all()
ds.save()

In [None]:
# Define the initial segmentation parameters (starting values for the sliders below)
start_pars = {
    # "bgfile": bg_path,
    "maskfile": mask_path,
    "cut": 1,
    "pixels_in_spot": 3,
    "howmany": 100000,
}

def segment_frame_from_options(ds, options, image_file_num=None):
    """Segment the first frame of one image file using trial options.

    Parameters
    ----------
    ds
        Dataset object; ``ds.imagefiles``, ``ds.datapath`` and
        ``ds.limapath`` are read.
    options : dict
        Keyword arguments used to build a ``SegmenterOptions`` object.
    image_file_num : int, optional
        Index into ``ds.imagefiles``. Defaults to the middle file.

    Returns
    -------
    tuple
        ``(rshow, sshow)``: the raw first frame and the dense 'intensity'
        image of the segmented frame. Both are multiplied by ``opts.mask``
        when a mask is set.

    Raises
    ------
    ValueError
        If the reader yields no frames for the selected file.

    Notes
    -----
    As a side effect, the module-level ``lima_segmenter.OPTIONS`` is
    replaced by the options object built here.
    """
    segmenter = ImageD11.sinograms.lima_segmenter
    opts = segmenter.OPTIONS = segmenter.SegmenterOptions(**options)
    opts.setup()
    if image_file_num is None:
        image_file_num = len(ds.imagefiles) // 2
    hfile = os.path.join(ds.datapath, ds.imagefiles[image_file_num])

    # private sentinel, so "no frame at all" is distinguishable from a None frame
    no_frame = object()
    with h5py.File(hfile, 'r') as hin:
        frms = hin[ds.limapath]
        # only the first segmented frame is needed
        spf = next(iter(segmenter.reader(frms, opts.mask, opts.cut)), no_frame)
        if spf is no_frame:
            raise ValueError(f"No frames could be read from {hfile}")
        ref = frms[0]

    if spf is None:
        # NOTE(review): the reader appears to yield None when a frame has no
        # pixels above the cut - confirm. Show an empty image instead of
        # failing on None.to_dense().
        spi = np.zeros_like(ref)
    else:
        spi = spf.to_dense('intensity')

    if opts.mask is not None:
        rshow = ref * opts.mask
        sshow = spi * opts.mask
    else:
        rshow = ref
        sshow = spi

    return rshow, sshow

# Sliders for the three segmentation parameters, initialised from start_pars.
# The "howmany" slider holds the base-10 exponent, so its starting value is
# log10 of the start_pars entry and the real value is recovered as 10**value.
cut_slider = widgets.IntSlider(value=start_pars["cut"], min=1, max=20, step=1, description='Cut:')
pixels_in_spot_slider = widgets.IntSlider(value=start_pars["pixels_in_spot"], min=1, max=20, step=1, description='Pixels in Spot:')
howmany_slider = widgets.IntSlider(value=np.log10(start_pars["howmany"]), min=1, max=15, step=1, description='log(howmany):')

# Display the image initially
# (raw and segmented frames side by side with shared axes, so zooming one zooms both)
fig, axs = plt.subplots(1, 2, sharex=True, sharey=True, figsize=(16, 9))
raw_image, segmented_image = segment_frame_from_options(ds, start_pars)
im1 = axs[0].imshow(raw_image, cmap="viridis", norm=LogNorm(vmin=1, vmax=1000), interpolation="nearest")
im2 = axs[1].imshow(segmented_image, cmap="viridis", norm=LogNorm(vmin=1, vmax=1000), interpolation="nearest")
axs[0].set_title("Raw image")
axs[1].set_title("Segmented image")
# count connected regions in the segmented image; the count is reported as "peaks found"
labels, nblobs = measure.label(segmented_image, connectivity=2, return_num=True)
plt.suptitle(f"{nblobs} peaks found\n cut={cut_slider.value}, pixels_in_spot={pixels_in_spot_slider.value}, howmany={10**howmany_slider.value}")
plt.show()

def update_image(cut, pixels_in_spot, howmany):
    """Slider callback: re-segment the preview frame and refresh both images.

    ``cut`` and ``pixels_in_spot`` are used as given. ``howmany`` is the
    base-10 exponent from the slider and is expanded to ``10**howmany``.
    All other options are taken from ``start_pars``.
    """
    n_keep = 10**howmany
    trial_opts = {
        **start_pars,
        "cut": cut,
        "pixels_in_spot": pixels_in_spot,
        "howmany": n_keep,
    }
    raw_frame, seg_frame = segment_frame_from_options(ds, trial_opts)
    # only the number of connected regions is needed for the title
    _, n_found = measure.label(seg_frame, connectivity=2, return_num=True)
    for artist, frame in ((im1, raw_frame), (im2, seg_frame)):
        artist.set_data(frame)
    plt.suptitle(f"{n_found} peaks found\n cut={cut}, pixels_in_spot={pixels_in_spot}, howmany={n_keep}")
    plt.draw()

# tie the three sliders to update_image so each slider change re-runs the segmentation preview
interactive_plot = widgets.interactive(update_image, cut=cut_slider, pixels_in_spot=pixels_in_spot_slider, howmany=howmany_slider)

# show the slider controls in the notebook
display(interactive_plot)

In [None]:
# Final segmentation parameters, read back from the current slider positions
# (the "howmany" slider stores a base-10 exponent)
end_pars = {
    # "bgfile": bg_path,
    "maskfile": mask_path,
    "cut": cut_slider.value,
    "pixels_in_spot": pixels_in_spot_slider.value,
    "howmany": 10**howmany_slider.value,
}

In [None]:
# create batch file to send to SLURM cluster

# setup() receives the saved dataset file plus the parameters in end_pars
sbat = ImageD11.sinograms.lima_segmenter.setup(ds.dsfile, **end_pars)
# a None return is treated here as "already segmented", so stop rather than resubmit
if sbat is None:
    raise ValueError("This scan has already been segmented!")
print(sbat)

In [None]:
# submit the segmentation batch file and (per the helper's name) wait for it to finish
# NOTE(review): 60 is presumably a wait/polling time in seconds, cf. wait_time_sec=60 in the bulk cell - confirm
utils.slurm_submit_and_wait(sbat, 60)

In [None]:
# label sparse peaks
# (run on the same dataset file that was passed to the segmenter)

ImageD11.sinograms.assemble_label.main(ds.dsfile)

In [None]:
# generate peaks table
# NOTE(review): the options select the 'lmlabel' algorithm, set wtmax to 70000 and
# turn off save_overlaps; their exact meaning is defined in ImageD11.sinograms.properties - confirm

ImageD11.sinograms.properties.main(ds.dsfile, options={'algorithm': 'lmlabel', 'wtmax': 70000, 'save_overlaps': False})

In [None]:
# make a new subfolder called "sparse" that holds all the individual "scan______sparse.h5" files

sparse_folder_path = os.path.join(ds.analysispath, "sparse")

# exist_ok=True replaces the separate exists() check: no check-then-create race,
# and re-running the cell is harmless when the folder is already there
os.makedirs(sparse_folder_path, exist_ok=True)

# per-scan sparse files currently sitting directly in the analysis folder
scan_sparse_files = glob.glob(os.path.join(ds.analysispath, "scan*_sparse.h5"))

for scan_sparse_file in scan_sparse_files:
    shutil.move(scan_sparse_file, sparse_folder_path)

In [None]:
# TODO: incorporate DATA/visitor/ma5839/id11/20240118/SCRIPTS/0_S3DXRD_segment_and_label_single_dset.ipynb

In [None]:
# Safety stop: with the condition left at 1, "Run all cells" halts here,
# so the bulk-processing cell below only runs once the user has opted in.
if 1:
    raise ValueError("Change the 1 above to 0 to allow 'Run all cells' in the notebook")

In [None]:
# Now that we're happy with our segmentation parameters, we can run the cell below to do this in bulk for many samples/datasets
# by default this will process all samples in sample_list, and all datasets whose names start with dset_prefix
# you can list samples and datasets to skip in skips_dict

# datasets to leave out, keyed by sample name
skips_dict = {
    "FeAu_0p5_tR_nscope": ["top_-50um", "top_-100um"]
}

dset_prefix = "top"

sample_list = ["FeAu_0p5_tR_nscope"]

samples_dict = utils.find_datasets_to_process(rawdata_path, skips_dict, dset_prefix, sample_list)

# manual override:
# samples_dict = {"FeAu_0p5_tR_nscope": ["top_100um", "top_200um"]}

print(samples_dict)

# now we have our samples_dict, we can process our data:
mask_path = '/data/id11/inhouse1/ewoks/detectors/files/eiger_E-08-0173/mask_with_gaps_E-08-0173.edf'

# you can change these if needed, but they will default to those you selected with the widget

try:
    seg_pars = {"maskfile": mask_path,
                "cut": cut_slider.value,
                "pixels_in_spot": pixels_in_spot_slider.value,
                "howmany": 10**howmany_slider.value}
except NameError:
    # the widget cell was not run in this session, so the sliders do not exist:
    # pass only the mask and leave the remaining options to the segmenter
    seg_pars = {"maskfile": mask_path}


# batch files to submit, and the DataSet objects they belong to (kept in step)
sbats = []
dataset_objects = []

for sample, datasets in samples_dict.items():
    for dataset in datasets:
        print(f"Processing dataset {dataset} in sample {sample}")

        ds = ImageD11.sinograms.dataset.DataSet(dataroot=rawdata_path,
                                                analysisroot=processed_data_root_dir,
                                                sample=sample,
                                                dset=dataset)

        # an existing sparse file means this dataset was already segmented
        if os.path.exists(ds.sparsefile):
            print(f"Found existing Sparse file for {dataset} in sample {sample}, skipping")
            continue

        print("Importing DataSet object")
        try:
            ds.import_all()
        except (ValueError, KeyError):
            # a scan that cannot be imported is skipped rather than stopping the whole batch
            print("Very dodgy scan! Skipping")
            continue
        print(f"I have a DataSet {ds.dset} in sample {ds.sample}")
        ds.save()

        print("Segmenting")
        sbat = ImageD11.sinograms.lima_segmenter.setup(ds.dsfile, **seg_pars)

        if sbat is None:
            print(f"{dataset} in sample {sample} already segmented, skipping")
            continue

        sbats.append(sbat)
        dataset_objects.append(ds)


# only call the SLURM helper when something was queued; if every dataset was
# skipped, sbats and dataset_objects are both empty
if sbats:
    utils.slurm_submit_many_and_wait(sbats, wait_time_sec=60)

for ds in dataset_objects:
    print("Labelling sparse peaks")
    ImageD11.sinograms.assemble_label.main(ds.dsfile)

    print("Generating peaks table")
    ImageD11.sinograms.properties.main(ds.dsfile, options={'algorithm': 'lmlabel', 'wtmax': 70000, 'save_overlaps': False})

    print("Cleaning up sparse files")
    sparse_folder_path = os.path.join(ds.analysispath, "sparse")

    # exist_ok=True makes this safe to re-run and avoids a check-then-create race
    os.makedirs(sparse_folder_path, exist_ok=True)

    scan_sparse_files = glob.glob(os.path.join(ds.analysispath, "scan*_sparse.h5"))

    for scan_sparse_file in scan_sparse_files:
        shutil.move(scan_sparse_file, sparse_folder_path)
print("Done!")