In [None]:
# This cell is tagged with "parameters" so that papermill can inject values into it.
# When auto_processing is True, outpath is rewritten later in the notebook so that
# the output file name stays backward compatible.

# Default settings. inpath and outpath are empty here and empty values are
# rejected by the checks further down, so both must be supplied before running.
inpath = "" # type: str
outpath = "" # type: str
dask_use_local_cluster = True # type: bool
dask_num_workers = 12 # type: int
process_configfile = "/dls_sw/i14/ops/processing/auto/processing.yaml" # type: str
auto_processing = True # type: bool

## Usage
This notebook takes a scan file `inpath`, reconstructs the differential phase contrast (DPC) images, and saves the results to `outpath`. 

### Parameters
`inpath` : str    
    the full path of the scan file. e.g. "/dls/i14/data/2024/cm37259-1/scan/i14-274317.nxs"  
`outpath` : str    
    the full path of the output file, e.g. "/dls/i14/data/2024/cm37259-1/processed/i14-274317_dpc-phase.nxs". When the notebook is triggered by auto-processing, as indicated by the flag `auto_processing` (defaults to `True`), `outpath` is overridden and rewritten to the form `<directory>/i14-<scan number>-dpc_phase.nxs`; for post-processing, `outpath` is not modified.   
`dask_use_local_cluster` : bool  
    the flag to use a Dask local cluster. If it is true, the reconstruction is often faster at the expense of memory.  
`dask_num_workers` : int    
    the number of Dask workers if local cluster is used.   
`process_configfile` : str   
    the path of the processing configuration file. Defaults to the file commonly shared by the different analyses.   
`auto_processing` : bool    
    the flag stating whether this run is auto-processing (`True`) or post-processing (`False`).

**The parameters should be provided by explicitly modifying the top cell content or using tools such as [papermill](https://papermill.readthedocs.io/en/latest/index.html). If the notebook is run as is, please define the parameters accordingly.**

### Dependencies
The majority of the work is carried out by [i14-utility-dpc](https://gitlab.diamond.ac.uk/i14/i14_utility).
- numpy
- dask[distributed]
- i14_utility (https://gitlab.diamond.ac.uk/i14/i14_utility)

In [None]:
import re
import sys
import time

import dask
from dask.distributed import Client, LocalCluster
import numpy as np
import yaml

### Import some utilities and tools

In [None]:
from i14_utility.dpc.io import get_diffraction_frames, save
from i14_utility.dpc.masking import sample_detector_image, get_mask_detector, get_mask_beam, combine_mask, visualise_mask
from i14_utility.dpc.compute import calculate_maps, calculate_phase, visualise_map

### Config Dask
This needs to be done before any usage of Dask in other packages. Applying the setting below can often make the reconstruction quicker but at the expense of memory.

In [None]:
if dask_use_local_cluster:
    # All scheduler/worker tuning is applied in a single call, before the
    # cluster is created.
    dask.config.set({
        # https://github.com/dask/distributed/discussions/7128
        # disabled to make it quicker (memory is not a huge issue here)
        'distributed.scheduler.worker-saturation': "inf",
        # disabled, as in previous versions of Dask
        'distributed.scheduler.active-memory-manager.start': False,
        # this setting does not exist in previous versions of Dask
        'distributed.worker.memory.transfer': 0.9,
    })

    # One single-threaded process per worker.
    cluster = LocalCluster(
        n_workers=dask_num_workers,
        threads_per_worker=1,
        processes=True,
    )
    client = Client(cluster)

### Check if inpath and outpath are provided
**This does not validate them.**

In [None]:
# Stop early when no input path was supplied; the path itself is not validated.
if not inpath:
    raise TypeError("No 'inpath' is specified.")

In [None]:
# Stop early when no output path was supplied; the path itself is not validated.
if not outpath:
    raise TypeError("No 'outpath' is specified.")

### Override the path of output file
This keeps the name backward compatible. No overriding is needed when it is not auto-processing, as the `outpath` is provided directly.

In [None]:
if auto_processing:
    # Rewrite the supplied path to "<dir>/i14-<scan number>-dpc_phase.nxs" so the
    # output name stays backward compatible. Everything after the scan number
    # (suffix and extension) is dropped.
    regex_scan = re.compile(r"(^.*/i14-\d+)(.*$)")
    outpath, n_replaced = regex_scan.subn(r"\1-dpc_phase.nxs", outpath)
    if n_replaced:
        print(f"The overridden outpath is {outpath}")
    else:
        # The path does not contain "/i14-<digits>", so nothing was rewritten.
        # Say so instead of reporting an override that did not happen.
        print(f"The outpath {outpath} does not match the 'i14-<scan number>' pattern and is left unchanged")

### Get configuration from configuration file

In [None]:
# get processing configuration
# get processing configuration
# keep the default behaviour if the parameter was never defined
if "process_configfile" not in globals():
    process_configfile = "/dls_sw/i14/ops/processing/auto/processing.yaml"

In [None]:
# Built-in fallbacks, used when the configuration file (or one of its
# sections/keys) is missing.

# Dotted source path -> [dotted metadata path, None].
mapping_default = {
    "entry.instrument.dcm_enrg.value.value": [
        "Acquisition_instrument.XRF.beam_energy",
        None,
    ],
    "entry.instrument.sample.sample_rot.value": [
        "Acquisition_instrument.XRF.stage.rotation",
        None,
    ],
    "entry.instrument.detectors.excalibur_z.value": [
        "Acquisition_instrument.DPC.distance",
        None,
    ],
}

# DPC reconstruction options.
dpc_config_default = {
    "crop_size": 256,
    "dpc_exclude_crop_size": 256,
    "dpc_include_crop_size": 128,
    "fractional_dark": 1.05,
    "pixel_size": 55e-6,  # NOTE(review): unit not stated here (presumably metres) - confirm
    "absorption_map": True,
    "scatter_map": True,
    "radial_scatter_map": True,
    "second_moment_map": False,
    "quick_mask": False,
    "dpc_include_beam": True,
    "dpc_exclude_beam": True,
    "offset_x": 0,
    "offset_y": 0,
    "mirror": True,
    "anti": True,
    "zeropad": False,
    "high_pass_filter": False,
    "regularisation": 0.001,
}

# Candidate dataset paths for the detector frames and the metadata keys to read.
io_config_default = {
    "dpc": [
        '/entry/merlin_addetector',
        '/entry/eiger_addetector',
    ],
    "metadata_keys": [
        "instrument/dcm_enrg",
        "instrument/sample/sample_rot",
        "instrument/detectors/excalibur_z",
    ],
}

In [None]:
# Read the processing configuration, falling back to the built-in defaults
# when the file does not exist.
try:
    f = open(process_configfile)
except FileNotFoundError:
    print(f"Configuration file {process_configfile} is not found. Default configurations will be used.")

    process_config = {"mapping": mapping_default,
                      "dpc": dpc_config_default,
                      "io": io_config_default,
                     }
else:
    # The context manager closes the file even when parsing raises.
    with f:
        # NOTE(review): if the file only holds plain data, yaml.safe_load is the
        # more conservative loader - confirm before switching.
        loaded_config = yaml.load(f.read(), Loader=yaml.FullLoader)

    # An empty file parses to None; use an empty dict so that the section
    # lookups that follow fall back to the defaults instead of failing.
    process_config = {} if loaded_config is None else loaded_config

In [None]:
# Use the "mapping" section of the loaded configuration, or the built-in mapping when the section is absent.
mapping = process_config.get("mapping", mapping_default)

In [None]:
# Use the "dpc" section of the loaded configuration. A section that is absent,
# empty or null (a YAML key with no value loads as None) falls back to the
# built-in defaults, so the per-key lookups that follow always get a dict.
dpc_config = process_config.get("dpc") or dpc_config_default

In [None]:
def _dpc_setting(key):
    """Return `key` from `dpc_config`, or its built-in default when the key is absent."""
    return dpc_config.get(key, dpc_config_default[key])


# Cropping, dark level and pixel size
crop_size = _dpc_setting("crop_size")
dpc_include_crop_size = _dpc_setting("dpc_include_crop_size")
dpc_exclude_crop_size = _dpc_setting("dpc_exclude_crop_size")
fractional_dark = _dpc_setting("fractional_dark")
pixel_size = _dpc_setting("pixel_size")

# Which maps to produce
absorption_map = _dpc_setting("absorption_map")
scatter_map = _dpc_setting("scatter_map")
radial_scatter_map = _dpc_setting("radial_scatter_map")
second_moment_map = _dpc_setting("second_moment_map")
quick_mask = _dpc_setting("quick_mask")
dpc_include_beam = _dpc_setting("dpc_include_beam")
dpc_exclude_beam = _dpc_setting("dpc_exclude_beam")

# Phase retrieval options (the configuration keys have no "dpc_" prefix)
offset_x = _dpc_setting("offset_x")
offset_y = _dpc_setting("offset_y")
dpc_mirror = _dpc_setting("mirror")
dpc_anti = _dpc_setting("anti")
dpc_zeropad = _dpc_setting("zeropad")
high_pass_filter = _dpc_setting("high_pass_filter")
regularisation = _dpc_setting("regularisation")

In [None]:
# Use the "io" section of the loaded configuration. A section that is absent,
# empty or null (a YAML key with no value loads as None) falls back to the
# built-in defaults, so the per-key lookups that follow always get a dict.
io_config = process_config.get("io") or io_config_default

In [None]:
# Each I/O setting comes from the loaded configuration, or from the built-in
# default when its key is absent.
dataset_path, metadata_keys = (
    io_config.get(key, io_config_default[key]) for key in ("dpc", "metadata_keys")
)

In [None]:
# Echo the resolved settings as "name = repr(value)", one per line, so the
# executed notebook records what was actually used.
_resolved_settings = {
    "mapping": mapping,
    "crop_size": crop_size,
    "dpc_include_crop_size": dpc_include_crop_size,
    "dpc_exclude_crop_size": dpc_exclude_crop_size,
    "fractional_dark": fractional_dark,
    "scatter_map": scatter_map,
    "absorption_map": absorption_map,
    "radial_scatter_map": radial_scatter_map,
    "dpc_include_beam": dpc_include_beam,
    "dpc_exclude_beam": dpc_exclude_beam,
    "offset_x": offset_x,
    "offset_y": offset_y,
    "dpc_mirror": dpc_mirror,
    "dpc_anti": dpc_anti,
    "dpc_zeropad": dpc_zeropad,
    "high_pass_filter": high_pass_filter,
    "regularisation": regularisation,
    "quick_mask": quick_mask,
    "dataset_path": dataset_path,
    "metadata_keys": metadata_keys,
}
for _label, _value in _resolved_settings.items():
    print(f"{_label} = {_value!r}")

### Load the signal lazily using HyperSpy
It also loads associated metadata such as the scanning x and y positions, energy etc.

In [None]:
# Time the loading step.
start = time.perf_counter()

# `s` is the signal that every later step in the notebook works on.
s = get_diffraction_frames(inpath, dataset_path, metadata_keys)

elapsed = time.perf_counter() - start
print(f"HyperSpy loading: {elapsed} s")

In [None]:
# Bare expression: the notebook shows the signal's representation as this cell's output.
s

### Sample a beam

In [None]:
# Time the sampling step.
start = time.perf_counter()

# The sampled image is reused for the detector mask and the mask visualisation.
sample_detector = sample_detector_image(s)

elapsed = time.perf_counter() - start
print(f"Get sample detector image: {elapsed} s")

### Get different masks

In [None]:
# Time the mask construction.
start = time.perf_counter()

# Detector mask first: that call also returns the accumulated beam which the
# beam mask is built from.
mask_detector, beam_accumulated = get_mask_detector(s, sample_beam=sample_detector)
mask_beam = get_mask_beam(s, beam_accumulated=beam_accumulated)

# Merge both masks into the one used for the map calculation.
mask_combined = combine_mask(mask_detector, mask_beam)

elapsed = time.perf_counter() - start
print(f"Get masks: {elapsed} s")

In [None]:
# Visual check of the sampled detector image against the beam mask and the combined mask
# (what is rendered is defined by i14_utility.dpc.masking.visualise_mask).
visualise_mask(sample_detector, mask_beam, mask_combined)

### Calculate different maps
Absorption map, scatter map, scatter map with radial mask, centre-of-mass (excluding and including the beam)

In [None]:
# One flag per result slot; the order matches the indices used when the
# results are unpacked afterwards (0..5).
flags = [
    absorption_map,
    scatter_map,
    radial_scatter_map,
    False,  # second moment always False
    dpc_exclude_beam,
    dpc_include_beam,
]

start = time.perf_counter()

# Map calculate_maps over the signal into a new signal, then evaluate it.
map_options = dict(mask=mask_combined, flags=flags, inplace=False, ragged=False)
res = s.map(calculate_maps, **map_options)
res.compute()

elapsed = time.perf_counter() - start
print(f"Compute mappings: {elapsed} s")

In [None]:
# The last axis of res.data selects the map type, in the same order as `flags`.
map_results = res.data
# Leading two axes of the signal data; assumed to be the scan grid - TODO confirm.
scan_shape = s.data.shape[:2]

# Scalar-valued maps: one number per scan position.
absp = map_results[:, :, 0].astype(float)
scat = map_results[:, :, 1].astype(float)
smom = map_results[:, :, 3].astype(float)

# Vector-valued maps: stack the per-position entries, then restore the grid
# with the vector length as the trailing axis.
n_radial = len(mask_combined.radial_exclude)
radl = np.vstack(map_results[:, :, 2].flatten()).reshape(*scan_shape, n_radial)
com_exclude = np.vstack(map_results[:, :, 4].flatten()).reshape(*scan_shape, 2)
com_include = np.vstack(map_results[:, :, 5].flatten()).reshape(*scan_shape, 2)

### Retrieve the phase from centre-of-mass
#### Kottler method

In [None]:
# Kottler phase retrieval from the centre of mass that excludes the beam.
kottler_options = dict(
    method='kottler',
    offset_x=offset_x,
    offset_y=offset_y,
    pixel_size=pixel_size,
    zeropad=dpc_zeropad,
    mirroring=dpc_mirror,
    mirror_flip=dpc_anti,
)
phase_exclude, gradient_norm_exclude = calculate_phase(s, com_exclude, **kottler_options)

In [None]:
# Kottler phase retrieval from the centre of mass that includes the beam.
kottler_options = dict(
    method='kottler',
    offset_x=offset_x,
    offset_y=offset_y,
    pixel_size=pixel_size,
    zeropad=dpc_zeropad,
    mirroring=dpc_mirror,
    mirror_flip=dpc_anti,
)
phase_include, gradient_norm_include = calculate_phase(s, com_include, **kottler_options)

#### Lazic method

In [None]:
# Lazic phase retrieval from the centre of mass that excludes the beam.
# Only the phase is kept; the second return value is discarded.
lazic_options = dict(
    method='lazic',
    offset_x=offset_x,
    offset_y=offset_y,
    pixel_size=pixel_size,
    zeropad=dpc_zeropad,
    mirroring=dpc_mirror,
    mirror_flip=dpc_anti,
    high_pass_filter=high_pass_filter,
    regularisation=regularisation,
)
phase_lazic_exclude, _ = calculate_phase(s, com_exclude, **lazic_options)

In [None]:
# Lazic phase retrieval from the centre of mass that includes the beam.
# Only the phase is kept; the second return value is discarded.
lazic_options = dict(
    method='lazic',
    offset_x=offset_x,
    offset_y=offset_y,
    pixel_size=pixel_size,
    zeropad=dpc_zeropad,
    mirroring=dpc_mirror,
    mirror_flip=dpc_anti,
    high_pass_filter=high_pass_filter,
    regularisation=regularisation,
)
phase_lazic_include, _ = calculate_phase(s, com_include, **lazic_options)

### Visualise mapping

In [None]:
# Everything that is visualised and written to the output file, keyed by
# dataset name.
data_to_save = {
    'absorption_map': absp,
    'scatter_map': scat,
    'second_moment': smom,
    # Centre of mass: index 1 of the last axis is stored as x, index 0 as y.
    'CoMx_exclude': com_exclude[:, :, 1],
    'CoMy_exclude': com_exclude[:, :, 0],
    'CoMx_include': com_include[:, :, 1],
    'CoMy_include': com_include[:, :, 0],
    'kottler_exclude': phase_exclude,
    'phase_gradient_norm_exclude': gradient_norm_exclude,
    'kottler_include': phase_include,
    'phase_gradient_norm_include': gradient_norm_include,
    'lazic_exclude': phase_lazic_exclude,
    'lazic_include': phase_lazic_include,
}

if radial_scatter_map:
    # One pair of entries per radial mask: the inverted mask and its map.
    for k, radial_mask in enumerate(mask_combined.radial_exclude):
        data_to_save[f'radial_mask_{k}'] = ~radial_mask
        data_to_save[f'radial_mask_data_{k}'] = radl[:, :, k]

In [None]:
# Visual check of every entry collected in data_to_save
# (what is rendered is defined by i14_utility.dpc.compute.visualise_map).
visualise_map(data_to_save)

### Save the results

In [None]:
# Write the collected results to outpath. The signal and the input path are
# passed along as well; how they are used is defined by i14_utility.dpc.io.save.
save(s, outpath, inpath, data_to_save=data_to_save)

### Send the output file to GDA

In [None]:
# Best-effort notification: when the messenger package cannot be imported the
# notebook still finishes, it just does not announce the output file.
daqmessenger_dir = "/dls_sw/i14/software/daqmessenger"
try:
    # Avoid piling up duplicate sys.path entries when the cell is re-run.
    if daqmessenger_dir not in sys.path:
        sys.path.append(daqmessenger_dir)
    from daqmessenger import DaqMessenger
except (PermissionError, ImportError):
    print("No messenger")
else:
    daq = DaqMessenger("i14-control")
    daq.connect()
    try:
        daq.send_file(outpath)
    finally:
        # Always release the connection, even when sending fails.
        daq.disconnect()

### Shut down

In [None]:
if dask_use_local_cluster:
    # Close the client first, then the cluster it is connected to.
    for dask_resource in (client, cluster):
        dask_resource.close()