# Averaging detector data with Dask

We often want to average large detector data across trains, keeping the pulses within each train separate, so we have an average image for pulse 0, another for pulse 1, etc.

This data may be too big to load into memory at once, but using [Dask](https://dask.org/) we can work with it like a NumPy array. Dask takes care of splitting the job up into smaller pieces and assembling the result.
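
To make the idea of splitting the job into pieces concrete, here is a minimal NumPy-only sketch of an average computed one chunk at a time. It is an illustration, not how Dask is implemented: the helper `chunked_mean`, the chunk size and the random data are all made up for this example.

```python
import numpy as np

def chunked_mean(arr, chunk_size):
    """Mean over axis 0, touching only `chunk_size` rows of `arr` at a time."""
    total = np.zeros(arr.shape[1:], dtype=np.float64)
    for start in range(0, arr.shape[0], chunk_size):
        chunk = arr[start:start + chunk_size]  # one small piece of the job
        total += chunk.sum(axis=0)             # partial result for this piece
    return total / arr.shape[0]                # assemble the final average

rng = np.random.default_rng(seed=0)
data = rng.random((10, 4, 6))  # made-up data: 10 "trains" of 4x6 values
result = chunked_mean(data, chunk_size=3)

# The chunked result matches the average computed in one go
print(np.allclose(result, data.mean(axis=0)))
```

Dask does this kind of bookkeeping for us, and can hand the pieces to many workers at once.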

In [None]:
from karabo_data import open_run

import dask.array as da
from dask.distributed import Client, progress
from dask_jobqueue import SLURMCluster
import numpy as np

First, we use [Dask-Jobqueue](https://jobqueue.dask.org/en/latest/) to talk to the Maxwell cluster.

In [None]:
partition = 'exfel'  # For EuXFEL staff
#partition = 'upex'   # For users

cluster = SLURMCluster(
    queue=partition,
    # 16 Dask workers per job - our SLURM config gives every job its own node
    processes=16, cores=16, memory='200GB',
)

# Get a notebook widget showing the cluster state
cluster

In [None]:
# Submit 2 SLURM jobs, for 32 Dask workers
cluster.scale(32)

If the cluster is busy, you might need to wait a while for the jobs to start.
The cluster widget above will update when they're running.
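
The relationship between workers and SLURM jobs is simple arithmetic. This small sketch only restates the numbers used above (16 worker processes per job, 32 workers requested); the variable names are made up for illustration.

```python
import math

processes_per_job = 16  # matches processes=16 in the SLURMCluster call
requested_workers = 32  # matches cluster.scale(32)

# Each SLURM job starts `processes_per_job` workers, so round up to whole jobs
jobs_needed = math.ceil(requested_workers / processes_per_job)
print(jobs_needed)
```

Dask-Jobqueue also accepts the number of jobs directly, e.g. `cluster.scale(jobs=2)`, which should be equivalent with this configuration.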

Next, we'll set Dask up to use those workers:

In [None]:
client = Client(cluster)
print("Created dask client:", client)

Now that Dask is ready, let's open the run we're going to operate on:

In [None]:
run = open_run(proposal=2212, run=103)
run.info()

We're working with data from the DSSC detector.
In this run, it's recording 75 frames for each train:

In [None]:
counts = run.get_data_counts('SCS_DET_DSSC1M-1/DET/0CH0:xtdf', 'image.data')
counts.unique()

Now, we'll define how we're going to average over trains for each module:

In [None]:
def average_module(modno, run, pulses_per_train=75):
    """Return a lazy per-pulse average over trains for one detector module."""
    source = f'SCS_DET_DSSC1M-1/DET/{modno}CH0:xtdf'

    arr = run.get_dask_array(source, 'image.data')
    # Make a new dimension for trains
    arr_trains = arr.reshape(-1, pulses_per_train, 128, 512)
    if modno == 0:
        print("array shape:", arr.shape)  # frames, dummy, 128, 512
        print("Reshaped to:", arr_trains.shape)

    return arr_trains.mean(axis=0, dtype=np.float32)
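
The reshape-then-mean step in `average_module` can be checked on a small in-memory array. The sketch below uses plain NumPy and made-up sizes (3 trains, 4 pulses, 2 x 5 pixels) rather than real detector data; only the axis layout mirrors the function above.

```python
import numpy as np

n_trains, pulses_per_train = 3, 4  # illustrative sizes, not the real run
height, width = 2, 5               # stand-in for the 128 x 512 module

# Frames arrive flat: train 0 pulse 0, train 0 pulse 1, ..., train 1 pulse 0, ...
n_frames = n_trains * pulses_per_train
frames = np.arange(n_frames * height * width, dtype=np.float32)
frames = frames.reshape(n_frames, 1, height, width)  # frames, dummy, y, x

# Split the frame axis into (train, pulse); the dummy axis is absorbed
by_train = frames.reshape(-1, pulses_per_train, height, width)

# Average over trains, keeping one image per pulse
per_pulse_mean = by_train.mean(axis=0, dtype=np.float32)

print(by_train.shape)        # (3, 4, 2, 5)
print(per_pulse_mean.shape)  # (4, 2, 5)
```

This only lines up correctly if every train contains exactly `pulses_per_train` frames, which is why the frame counts were checked with `counts.unique()` earlier.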

In [None]:
mod_averages = [
    average_module(i, run, pulses_per_train=75)
    for i in range(16)
]

mod_averages

In [None]:
# Stack the averages into a single array
all_average = da.stack(mod_averages)
all_average

So far, no real computation has happened. Now that we've defined what we want, let's tell Dask to compute it.

This will take a minute or two. If you're running it, scroll up to the Dask cluster widget and click the status link to see what it's doing.

In [None]:
%%time
all_average_arr = all_average.compute()  # Get a concrete numpy array for the result

`all_average_arr` is a regular NumPy array with our results. Here are the values from the corner of module 0, pulse 0:

In [None]:
print(all_average_arr[0, 0, :5, :5])

Please shut down the cluster (or scale it down to 0 workers) if you won't be using it for a while.
This releases the resources for other people.

In [None]:
client.close()
cluster.close()