# NK Cell Cytotoxicity Analysis Pipeline

### 📓 Fluorescence Time-lapse Imaging — Tutorial

This tutorial demonstrates an automated imaging analysis pipeline for quantifying NK cell cytotoxic behaviour at the single-cell level. It processes fluorescence time-lapse microscopy data through key steps including:

- Image import and visualisation
- Channel unmixing and preprocessing
- Cell segmentation using Cellpose
- Tracking with BayesianTracker
- Functional analysis and visual outputs

**Environment & Visualisation Setup**

This section configures the Python environment by importing all required libraries and setting up visualisation styles for consistent plotting.

In [None]:
import os
import sys
import json
import pickle
import textwrap
from glob import glob

import numpy as np
import pandas as pd

import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.animation import FuncAnimation, PillowWriter, FFMpegWriter
from IPython.display import HTML, display

from tqdm.auto import tqdm

import pims
import napari
from cellpose import models

from skimage.color import label2rgb
from skimage.exposure import adjust_gamma
from skimage.measure import regionprops
from skimage.morphology import disk, dilation

from sklearn.preprocessing import MinMaxScaler, RobustScaler, StandardScaler
from sklearn.mixture import GaussianMixture
from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics import silhouette_score

sys.path.append('./Cyto-Visual')

In [None]:
mpl.rcParams['animation.embed_limit'] = 1000

plt.rcParams.update({
    "mathtext.fontset": "stix",
    "font.family": "STIXGeneral",
    "legend.fontsize": 14,
    "xtick.labelsize": 16,
    "ytick.labelsize": 16,
    "axes.titlesize": 18,
    "axes.labelsize": 18,
    "savefig.format": "pdf",
    "legend.edgecolor": "0.0",
    "legend.framealpha": 1.0,
})

sns.set(style="whitegrid", context="notebook")

# $\Alpha$. Import the imaging file 

In [None]:
from getIN import load_tiff_sequence

# Define the folder containing your imaging data here:
folder_path = "/Users/marsian/PhD_LifeSciRes/Image_Python/Cathal_Data/good_3/test"

# Import the image sequences
ioi_files, ioi_files_name = load_tiff_sequence(folder_path)

**Use Napari for Visualisation**

<p align="left">
  <img src="https://napari.org/stable/_static/logo.png" alt="napari logo" width="40" style="vertical-align:middle; margin-right:10px;">
  <code>Napari</code> is a fast, interactive, and user-friendly viewer optimised for large, multi-dimensional image data. Compared to <code>matplotlib</code>-based animations in Jupyter notebooks, Napari provides a smoother, more flexible experience for visualising time-lapse microscopy data interactively.
</p>

> 📌 In a standalone script, call <code>napari.run()</code> after creating the viewer to start the GUI event loop. In a Jupyter notebook, the viewer window opens without it.

---

While Napari is ideal for interactive exploration, `matplotlib` is still useful when you need to **export animations** as GIF or MP4 files.  
We provide a helper function to generate and save publication-ready animations using `matplotlib.animation`.
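For readers who want to see the underlying mechanism, the sketch below exports a toy stack as a GIF with `FuncAnimation` and `PillowWriter`. It is independent of the `mat_visua` helper, whose internals are not shown here; the random `movie` array and the output file name are placeholders.

```python
import os
import tempfile

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import FuncAnimation, PillowWriter

rng = np.random.default_rng(0)
movie = rng.random((5, 32, 32))   # toy (T, Y, X) stack standing in for one channel

fig, ax = plt.subplots(figsize=(3, 3))
im = ax.imshow(movie[0], cmap="gray", vmin=0, vmax=1)
ax.axis("off")

def update(t):
    im.set_data(movie[t])         # swap in frame t
    return (im,)

anim = FuncAnimation(fig, update, frames=movie.shape[0], interval=100)
out_path = os.path.join(tempfile.gettempdir(), "toy_animation.gif")
anim.save(out_path, writer=PillowWriter(fps=10))   # GIF export needs Pillow
plt.close(fig)
```

Swapping `PillowWriter` for `FFMpegWriter` would produce an MP4 instead, provided `ffmpeg` is installed.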

In [None]:
file = 0 # Define the file sequence that you would like to visualise
viewer = napari.Viewer()
max_list = [0.4, 0.6, 0.8]
Colormap = ['I Orange', 'I Forest', 'I Bordeaux']
for ch in range(ioi_files[file].shape[3]-1):
    viewer.add_image(
        ioi_files[file][:,:,:,ch],
        name = f'file_{file}_post_ch_{ch}',
        blending='minimum', 
        colormap=Colormap[ch],
        contrast_limits=(0, max_list[ch]*np.max(ioi_files[file][:,:,:,ch]))
    )

In [None]:
# Use Matplotlib to visualise
from visua import mat_visua

# Just display
html = mat_visua(ioi_files)
display(html)

# Display and save as MP4
# mat_visua(ioi_files, save=True, filename="NK_animation", format="mp4")

# Save as GIF
# mat_visua(ioi_files, save=True, filename="NK_animation", format="gif")

# $\Beta$. Pre-processing


Use the `pre_processing()` function to apply a customisable sequence of image processing steps to 4D fluorescence time-lapse stacks `(T, Y, X, C)`.

---

**Default usage**  
This runs `["trim", "unmix", "normalize"]` by default:

```python
processed = pre_processing(ioi_files)
```

---

**Custom pipeline**  
Define your own ordered steps:

```python
steps = ["background", "unmix", "normalize", "gamma"]
processed = pre_processing(
    ioi_files,
    time_interval=100,
    processing_steps=steps,
    gamma=1.5
)
```
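For intuition, a per-channel min-max normalisation followed by gamma correction can be sketched in NumPy as below. This is an illustration only: `normalise_and_gamma` is a hypothetical helper, and whether `pre_processing` normalises per channel, per frame, or by percentiles is an assumption not confirmed here.

```python
import numpy as np

def normalise_and_gamma(stack, gamma=1.5):
    """Hypothetical helper: rescale each channel of a (T, Y, X, C) stack to [0, 1], then apply gamma."""
    stack = stack.astype(np.float64)
    # Minimum and maximum over time and space, kept separately for each channel
    lo = stack.min(axis=(0, 1, 2), keepdims=True)
    hi = stack.max(axis=(0, 1, 2), keepdims=True)
    span = np.where(hi > lo, hi - lo, 1.0)   # avoid division by zero for flat channels
    normalised = (stack - lo) / span
    # On [0, 1] data, gamma > 1 darkens mid-tones and gamma < 1 brightens them
    return normalised ** gamma

rng = np.random.default_rng(0)
toy = rng.integers(0, 4096, size=(5, 16, 16, 3))
out = normalise_and_gamma(toy, gamma=1.5)
print(out.shape, out.min(), out.max())
```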

---

**Use a custom unmixing matrix:**

```python
custom_M = np.array([
    [0.55, 0.05, 0.20],
    [0.01, 0.70, 0.02],
    [0.08, 0.00, 0.45]
])
processed = pre_processing(ioi_files, mixing_matrix=custom_M)
```
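The idea behind linear unmixing can be sketched as follows. The sketch assumes that rows of the matrix are detection channels and columns are fluorophores, so that `observed = M · true` at each pixel; `pre_processing` may use a different convention, and `unmix` is a hypothetical helper, not part of `prepo`.

```python
import numpy as np

def unmix(stack, mixing_matrix):
    """Hypothetical helper: invert a C x C mixing matrix on a (T, Y, X, C) stack."""
    inv = np.linalg.inv(mixing_matrix)
    # Each pixel is a row vector of C channel values: true = observed @ inv(M).T
    unmixed = stack @ inv.T
    return np.clip(unmixed, 0, None)   # negative intensities are not physical

M = np.array([
    [0.55, 0.05, 0.20],
    [0.01, 0.70, 0.02],
    [0.08, 0.00, 0.45],
])
rng = np.random.default_rng(1)
true_signal = rng.random((2, 4, 4, 3))
observed = true_signal @ M.T           # simulate bleed-through between channels
recovered = unmix(observed, M)
print(np.allclose(recovered, true_signal))
```

On real data the matrix is only an estimate, so the recovered channels are approximate and the clipping step matters more than in this noiseless toy case.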

---

**Parameters summary**

| Parameter         | Description                                      |
|------------------|--------------------------------------------------|
| `ioi_files`       | Dict of 4D stacks `(T, Y, X, C)`                 |
| `time_interval`   | Trim to max number of frames                    |
| `processing_steps`| List of steps: `["trim", "background", "threshold", "unmix", "normalize", "gamma"]` |
| `mixing_matrix`   | `"default"` or custom `np.ndarray` (C × C)       |
| `gamma`           | Gamma value (e.g. 1.5) if `"gamma"` in steps     |

---

**Output**  
Returns a dictionary of processed stacks with the same keys and shapes as `ioi_files`.

```python
processed[0].shape  # (T, Y, X, C)
```

In [None]:
from prepo import *

custom_matrix = np.array([
    [0.55, 0.05, 0.20],
    [0.01, 0.70, 0.02],
    [0.08, 0.00, 0.45]
])

ioi_processing = pre_processing(
    ioi_files,
    time_interval=120,
    processing_steps=["trim", "unmix", "normalize"],
    mixing_matrix=custom_matrix
)

In [None]:
file = 0
viewer = napari.Viewer()
Colormap = ['I Blue', 'I Forest', 'I Bordeaux']
for ch in range(ioi_processing[file].shape[3]):
    viewer.add_image(
        ioi_processing[file][:,:,:,ch],
        name = f'post_ch_{ch}',
        blending='minimum', 
        colormap=Colormap[ch]
    )

# for ch in range(ioi_processing[file].shape[3]):
#     viewer.add_image(
#         ioi_files[file][:time_interval,:,:,ch],
#         name = f'raw_ch_{ch}',
#         blending='minimum', 
#         colormap=Colormap[ch]
#     )

# $\Gamma$. Segmenting and Merging the Masks, and Extracting the Regional Information

In [None]:
from seg import *  

**Cellpose Segmentation Parameters**

| Parameter            | Description                                                                                  | Typical Range | Default | Effect |
|----------------------|----------------------------------------------------------------------------------------------|---------------|---------|--------|
| `flow_threshold`     | Maximum allowed error of the flows for each mask                                             | 0.0 – 1.0     | 0.4     | Increase this threshold if Cellpose returns fewer ROIs than you expect. Decrease it if Cellpose returns too many ill-shaped ROIs. |
| `cellprob_threshold` | Pixels with a cell probability above this threshold are used to run dynamics and determine ROIs | -6.0 – 6.0    | 0.0     | Decrease this threshold if Cellpose returns fewer ROIs than you expect. Increase it if Cellpose returns too many ROIs, particularly from dim areas. |

Use these to balance sensitivity and specificity depending on signal quality.

> ⚠️ **Warning**: `run_cellpose_segmentation()` can be **very slow** when processing many timepoints or large image stacks.

To speed things up, we provide a **Snakemake-based pipeline** for batch segmentation on **HPC clusters** — see the [GitHub repository](https://github.com/ElephesSung/CytotoxicVision) for details.

For this tutorial, we **re-import pre-segmented mask files** to save time.  
You're welcome to run the full segmentation yourself and export your own results.

In [None]:
ioi_masks = run_cellpose_segmentation(
    ioi_processing,
    diameter=None,
    channels=[0, 0],
    flow_thresholds=[0.4, 0.4, 0.6],
)

You can use this code to export your mask data (`ioi_masks`) to a `.pkl` file:

```python
import pickle
import os

file_name = "ioi_masks_Les_gamma2.pkl"  # Name of the saved file
file_path = os.path.join(folder_path, file_name)  # Combine folder path with filename
os.makedirs(folder_path, exist_ok=True)

# Save dictionary to a file
with open(file_path, "wb") as file:
    pickle.dump(ioi_masks, file)
print(f"Dictionary saved successfully to {file_path}")
```

In [None]:
# You can re-import your saved pkl file here
file_path = os.path.join(folder_path, "ioi_masks_Les.pkl")
with open(file_path, "rb") as f:  # 'f' avoids overwriting the sample index 'file' used elsewhere
    ioi_masks = pickle.load(f)
print("Dictionary reloaded successfully from:", file_path)

**Remove Small Masks**

Utility functions are provided to:

- Plot mask area distributions by file and channel.
- Remove small masks likely caused by segmentation noise.

You can define a minimum area threshold per channel or apply a global threshold.
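As a rough illustration of area-based cleaning on a single labelled frame, the sketch below counts the pixels of each label and zeroes out the small ones. `remove_small_labels` is a hypothetical helper written for this example; it is not the implementation of `clean_masks`.

```python
import numpy as np

def remove_small_labels(label_image, min_area):
    """Hypothetical helper: zero out labelled objects smaller than min_area pixels."""
    areas = np.bincount(label_image.ravel())   # areas[k] = pixel count of label k
    too_small = areas < min_area
    too_small[0] = False                       # label 0 is background, leave it alone
    cleaned = label_image.copy()
    cleaned[too_small[label_image]] = 0
    return cleaned

labels = np.zeros((6, 6), dtype=int)
labels[0:3, 0:3] = 1    # 9-pixel object
labels[5, 5] = 2        # 1-pixel speck
cleaned = remove_small_labels(labels, min_area=4)
print(np.unique(cleaned))
```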

In [None]:
plot_mask_area_distributions(
    ioi_masks,
    file_names=ioi_files_name,
    channel_labels=["NK cells", "721.221 cells", "Death reporter"],
    bins=50,
    xlim=(0, 175),
    # save_path="mask_area_distribution.svg"  # Save as SVG
)

cleaned_ioi_masks = clean_masks(ioi_masks, area_threshold=40, channels=[1, 2])

**Region Extraction & Mask Merging**

This function merges selected mask channels and extracts per-object region properties (e.g. area, centroid, intensity) for downstream analysis.
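The two quantities involved in merging, intersection-over-union (IoU) and an `OR` merge, can be illustrated on toy boolean masks. How `region_info` applies `iou_threshold` internally is not shown in this tutorial, so the sketch only demonstrates the quantities themselves.

```python
import numpy as np

def iou(mask_a, mask_b):
    """Intersection-over-union of two boolean masks of equal shape."""
    union = np.logical_or(mask_a, mask_b).sum()
    if union == 0:
        return 0.0
    return float(np.logical_and(mask_a, mask_b).sum() / union)

a = np.zeros((8, 8), dtype=bool)
b = np.zeros((8, 8), dtype=bool)
a[1:5, 1:5] = True      # 16 pixels
b[3:7, 3:7] = True      # 16 pixels, overlapping a in a 2 x 2 corner

score = iou(a, b)                 # 4 shared pixels out of 28 in the union
merged = np.logical_or(a, b)      # an 'OR' merge keeps every pixel present in either mask
print(round(score, 3), int(merged.sum()))
```

With a threshold as low as 0.04, even this small corner overlap would count as the two masks describing the same object.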

In [None]:
ioi_region, ioi_masks_merge = region_info(
    ioi_masks=cleaned_ioi_masks,
    ioi_processing=ioi_processing,
    ioi_files=ioi_files,
    merge_channels=(1, 2),
    merge_mode='OR',
    iou_threshold=0.04,
    extract_channels=(0, 1, 2),
    region_features=("area", "coords", "centroid", "label"),
    include_intensity=True,
    include_raw_intensity=True,
    time_trim=60
)

You can extract fluorescence intensities from the merged 721.221 tumour and apoptotic reporter channels to identify two distinct clusters, corresponding to live and dead cells.
This feature is used later, after tracking, to classify cell states.

In [None]:
from tempo_ANA import plot_joint_fluo
plot_joint_fluo(
    ioi_region[1]['ch12'],
    file_name=ioi_files_name[1],
    mode="linear"
    )

# $\Delta$ Cell migration tracking

<p align="left">
  <img src="https://camo.githubusercontent.com/778aad38a3c4e6c07eb70f1de8e596d12d09b36b6e67e6b49df00bf67b0a68ed/68747470733a2f2f62747261636b2e72656164746865646f63732e696f2f656e2f6c61746573742f5f696d616765732f62747261636b5f6c6f676f2e706e67" alt="btrack logo" width="100" style="vertical-align:middle; margin-right:10px;">
  We use <code>btrack</code>, the Bayesian tracker, to track cell migration.
  For more information and optimisation tips, see the 
  <a href="https://github.com/quantumjot/btrack" target="_blank">GitHub repo</a> and the 
  <a href="https://btrack.readthedocs.io/en/latest/" target="_blank">official documentation</a>.
</p>
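For 2D time-lapse data, napari's `add_tracks` takes an array whose rows are `[track_id, t, y, x]`. The sketch below builds such an array from a toy table. The column names `id`, `t`, `y` and `x` mirror those used later in this notebook, but whether `run_btrack` assembles `ioi_napari_tracks` in exactly this way is an assumption.

```python
import numpy as np
import pandas as pd

# Toy tracking table with deliberately shuffled rows
df = pd.DataFrame({
    "id": [2, 1, 1, 2],
    "t":  [0, 1, 0, 1],
    "y":  [10.0, 5.5, 5.0, 11.0],
    "x":  [20.0, 7.5, 7.0, 21.0],
})

# napari expects rows of [track_id, t, y, x], ordered by track and then by time
tracks = df.sort_values(["id", "t"])[["id", "t", "y", "x"]].to_numpy(dtype=float)
print(tracks)
```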


In [None]:
import track
from track import *

feature_dict = {
    'ch0': ['area'],
    'ch12': ['area']
}

ioi_tracking, ioi_napari_tracks = run_btrack(
    ioi_region=ioi_region,
    ioi_masks=ioi_masks,
    feature_dictionary=feature_dict,
    ioi_processing=ioi_processing,
    ioi_files=ioi_files,
    config_path="./cell_config.json"
)

#os.path.join(folder_path, "")

In [None]:
import napari

file = 0
trim = 50
viewer = napari.Viewer()

pixel_size_xy = 810.44 / 512  # µm per pixel, assuming an 810.44 µm field of view across 512 pixels
scale = (1, pixel_size_xy, pixel_size_xy)

channel_colormaps = ['I Blue', 'I Forest', 'I Bordeaux']
mask_colormap = 'gist_earth'

for ch in range(ioi_processing[file].shape[-1]):
    viewer.add_image(
        ioi_processing[file][0:trim, :, :, ch],
        name=f'post_ch_{ch}',
        blending='minimum',
        colormap=channel_colormaps[ch],
        scale=scale,
    )

for ch in range(ioi_masks_merge[file].shape[-1]):
    viewer.add_image(
        ioi_masks_merge[file][0:trim, :, :, ch],
        name=f'seg_ch_{ch}',
        blending='additive',
        colormap=mask_colormap,
        scale=scale,
    )

track_labels = {
    'ch0': 'Tracks_NK',
    'ch12': 'Tracks_TUDR',
}

for label_key, display_name in track_labels.items():
    if label_key in ioi_napari_tracks[file]:
        viewer.add_tracks(
            ioi_napari_tracks[file][label_key],
            name=display_name,
            blending='translucent',
            visible=True,
            colormap=mask_colormap,
            tail_length=8,
            scale=scale,
        )

viewer.scale_bar.visible = True
viewer.scale_bar.unit = "um"
viewer.scale_bar.position = "bottom_right"
viewer.scale_bar.font_size = 8

# $\Delta$. Temporal Evolution of the fluorescence intensities.

### $\delta$-i. Classifying Target Cell Status: From Clustering to Thresholding

We initially used unsupervised clustering (Ward or GMM) on fluorescence features—Calcium Green (channel 1) and TO-PRO-3 (channel 2)—to distinguish live and dead target cells. In this scheme, live targets exhibit high Calcium Green and low TO-PRO-3, while dead targets show the opposite pattern.

However, a subset of cells showed high signals in both channels (“double positive”), and these also corresponded to dead cells. To address this, we adopted a threshold-based classification that assigns cells to four groups using a marginal threshold for each channel. Only the (high Calcium Green, low TO-PRO-3) group is considered truly alive; all other combinations, including double positives, are classified as dead. This approach provides a more accurate and biologically meaningful distinction between live and dead target cells.
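The quadrant logic can be sketched in a few lines. `classify_by_thresholds` is a hypothetical helper, and the thresholds and intensities are made-up values; in the pipeline the thresholds are estimated from the data (for example with a GMM or Otsu's method) rather than set by hand.

```python
import numpy as np

def classify_by_thresholds(ch1, ch2, thr1, thr2, group_to_cluster):
    """Hypothetical helper: map (ch1 high?, ch2 high?) quadrants to a cluster label."""
    high1 = np.asarray(ch1) > thr1
    high2 = np.asarray(ch2) > thr2
    return np.array([group_to_cluster[(int(a), int(b))] for a, b in zip(high1, high2)])

# Only (high live marker, low death marker) counts as alive (0); everything else is dead (1)
mapping = {(1, 0): 0, (0, 0): 1, (0, 1): 1, (1, 1): 1}

live_marker = np.array([7.2, 5.3, 7.0, 5.2])    # log-intensities, made-up values
death_marker = np.array([5.5, 6.1, 6.0, 5.5])
states = classify_by_thresholds(live_marker, death_marker, thr1=6.0, thr2=5.8,
                                group_to_cluster=mapping)
print(states)   # [0 1 1 1]
```

The third cell is the double-positive case: it has a high live-marker signal but is still labelled dead.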

In [None]:
from tempo_ANA import *

In [None]:
clustering_test = classify_cell_states(
    ioi_tracking,
    ioi_files_name=ioi_files_name,
    cluster_labels=["live targets", "dead targets"] ,
    # use_gmm=True,
    use_raw=True,
    use_log=True,
    point_size=12,
    # n_clusters=4,
    xlim=(5.1, 7.6),  
    ylim=(5.4, 6.2)   
)

In [None]:
# Default: GMM thresholding, log1p, raw intensities, default group-to-cluster mapping
ioi_tracking = threshold_cell_states(
    ioi_tracking,
    ioi_files_name=ioi_files_name,
    threshold_method="gmm",  # "gmm" or "kde", "otsu", "yen", "li"
    use_log=True,
    use_raw=True,
    cluster_labels=["Alive", "Dead"],
    point_size=12,
    group_to_cluster={(1,0): 0, (0,0): 1, (0,1): 1, (1,1): 1},  # (1,0) is alive, others dead
    xlim=(5.1, 7.6),
    ylim=(5.4, 6.2)
)

In [None]:
import tempo_ANA
from tempo_ANA import *

ioi_tracking = threshold_cell_states_3(
    ioi_tracking,
    ioi_files_name=ioi_files_name,
    cluster_labels=["live 721.221 cells", "dead 721.221 cells"],
    palette={"live 721.221 cells": "seagreen", "dead 721.221 cells": "salmon"},
    xlabel="log(Calcium Green intensity)",
    ylabel="log(TO-PRO-3 intensity)",
    axis_label_fontsize=18,
    # legend_title="Cell Status",
    legend_fontsize=14,
    hist_bins=100,
    save_path="./",
    save_format="pdf"  # "svg" or "pdf"
)

### $\delta$-ii. Temporal evolution of fluorescence intensity 
We visualise individual cell tracks based on their temporal fluorescence intensity in two channels:

- **Channel 1**: 721.221 (green) — marker for live cells  
- **Channel 2**: TO-PRO-3 (red) — marker for cell death

Each row corresponds to a cell behaviour category:
1. **Always Live** (consistently low TO-PRO-3, high 721.221)
2. **Always Dead** (high TO-PRO-3, low 721.221)
3. **Live → Dead Transition** (initially live, then death marker rises)

Each subplot shows a single cell’s intensity over time (`Frames`).  
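The three categories can be expressed as a rule on each track's per-frame state sequence (0 = live, 1 = dead). The helper below is a hypothetical sketch of that rule. How the pipeline's `categorise` treats sequences that flicker or revert is not documented here, so they are labelled `"other"` in this example.

```python
def categorise_states(states):
    """Hypothetical helper: label a per-frame state sequence (0 = live, 1 = dead)."""
    if all(s == 0 for s in states):
        return "always_live"
    if all(s == 1 for s in states):
        return "always_dead"
    # A clean transition is a run of 0s followed by a run of 1s, with no reversal
    first_dead = states.index(1)
    if all(s == 1 for s in states[first_dead:]):
        return "live_to_dead"
    return "other"   # flickering or dead-to-live sequences

tracks = {
    "a": [0, 0, 0, 0],
    "b": [1, 1, 1],
    "c": [0, 0, 1, 1, 1],
    "d": [0, 1, 0, 1],
}
labels = {k: categorise_states(v) for k, v in tracks.items()}
print(labels)
```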

In [None]:
# plot_all_TempoFluo(ioi_tracking[0]['ch12'], min_track_length=20)

In [None]:
# Select the tracking dataframe for one sample (here, file 1)
df = ioi_tracking[1]['ch12']

# Categorise cells by their temporal cell status transition
always_0, always_1, transition_0_to_1 = categorise(df, min_track_length=25)

# Plot temporal fluorescence profiles for each category
plot_ex_TempoFluo(
    df,
    always_0,
    always_1,
    transition_0_to_1,
    num_cols=7,  # You can change this to however many columns you want
    figsize=(16, 6),
    ch1_label='721.221',
    ch2_label='TO-PRO-3',
    ch1_color='seagreen',
    ch2_color='salmon',
    show_ids=False,
    ylim=(0, 1600),
    save_path="temporal_profiles.pdf"
)

In [None]:
plot_barrels(
    df,
    always_0,
    always_1,
    transition_0_to_1,
    ch1_label='Calcein Green AM (mean)',
    ch2_label='TO-PRO-3 (mean)',
    xlims=[(0, 60), (0, 60), (-30, 25)],
    titles=[None, None, None],
    # save_path="barrels_plot.pdf"
)

In [None]:
# Select the sample index and target cell ID
sample_index = 0
target_cell_id = 12

import importlib
importlib.reload(tempo_ANA)
from tempo_ANA import animate_scProfile
interactive_html = animate_scProfile(
    ioi_tracking=ioi_tracking,         # List of tracking DataFrames
    ioi_files=ioi_files,               # Collection of 4D image arrays (T, H, W, C)
    file=sample_index,                 # Select the specific sample
    target_Tu_id=target_cell_id,       # Particle ID to animate
    time_interval=60,                  # Number of frames to animate
    crop_size=100,                     # Size of cropped view around cell
    fig_width=15,                      # Width of the output figure
    fig_height=6,                      # Height of the output figure
    profile_time_extend=5,            # Extend profile range on both sides
    save_gif_path="./",                # Where to save the animation as a GIF
    ch1_label='Calcein Green AM',
    ch2_label='TO-PRO-3'
)

# Display in notebook
display(interactive_html)

# $\Epsilon$. Contact Detection.


This section outlines a pipeline to detect and summarise NK cell contacts with target cells (e.g., tumour cells), based on mask overlap and distance criteria across time-lapse imaging data.

> ⚠️ **Note:** The current algorithm is not yet validated against manual ground truth annotations. Experimental verification will be conducted in future steps.

---
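The overlap criterion can be illustrated without allocating an image: grow each cell's pixel set by a small disk and test whether the two sets intersect. The helpers below are a hypothetical, set-based sketch of the same idea as the mask dilation used in the next cell; they are not guaranteed to match it pixel for pixel at image borders.

```python
def dilate_coords(coords, radius=1):
    """Grow a collection of (y, x) pixel coordinates by a disk of the given radius."""
    offsets = [
        (dy, dx)
        for dy in range(-radius, radius + 1)
        for dx in range(-radius, radius + 1)
        if dy * dy + dx * dx <= radius * radius
    ]
    return {(y + dy, x + dx) for y, x in coords for dy, dx in offsets}

def in_contact(coords_a, coords_b, radius=1):
    """True if the two dilated pixel sets share at least one pixel."""
    return not dilate_coords(coords_a, radius).isdisjoint(dilate_coords(coords_b, radius))

nk = [(5, 5), (5, 6)]
near = [(5, 8), (5, 9)]    # one-pixel gap from the NK mask
far = [(5, 12)]
print(in_contact(nk, near), in_contact(nk, far))   # True False
```

Dilating both masks by one pixel means that cells separated by a one-pixel gap are still scored as touching, which tolerates slightly conservative segmentation boundaries.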

In [None]:
def detect_contacts(ioi_tracking, normal_TUDR_list, distance_threshold=30, min_track_length=25):
    from skimage.morphology import disk, dilation
    import numpy as np
    import json
    from tqdm.auto import tqdm

    def enlarge_coordinates(coords, radius):
        max_y = max(coord[0] for coord in coords) + radius + 10
        max_x = max(coord[1] for coord in coords) + radius + 10
        binary_mask = np.zeros((int(max_y), int(max_x)), dtype=bool)
        for y, x in coords:
            binary_mask[int(y), int(x)] = 1
        footprint = disk(radius)
        dilated_mask = dilation(binary_mask, footprint)
        return [tuple(p) for p in np.argwhere(dilated_mask)]

    def has_intersection(coords1, coords2):
        return int(bool(set(coords1) & set(coords2)))

    for image_key in tqdm(ioi_tracking, desc='files', total=len(ioi_tracking)):
        df_TUDR = ioi_tracking[image_key]['ch12']
        df_TUDR = df_TUDR[df_TUDR['id'].isin(normal_TUDR_list)]

        df_NK = ioi_tracking[image_key]['ch0']
        df_NK = df_NK.groupby('id').filter(lambda sub: len(sub) >= min_track_length)
        NK_ids = df_NK['id'].unique()

        for nk_id in tqdm(NK_ids, desc='NK_cell_id', leave=False):
            nk_df = df_NK[df_NK['id'] == nk_id]
            for t in nk_df['t'].unique():
                nk_frame = nk_df[nk_df['t'] == t]
                nk_coords = [tuple(c) for c in np.array(nk_frame['coords'].iloc[0])]
                nk_y, nk_x = nk_frame['y'].iloc[0], nk_frame['x'].iloc[0]

                TUDR_frame = df_TUDR[df_TUDR['t'] == t]
                contact_no = 0
                contacted_ids = []

                for _, tu in TUDR_frame.iterrows():
                    tu_y, tu_x = tu['y'], tu['x']
                    tu_coords = [tuple(p) for p in np.array(tu['coords'])]
                    dist = np.hypot(nk_y - tu_y, nk_x - tu_x)
                    if dist <= distance_threshold:
                        intersects = has_intersection(
                            enlarge_coordinates(tu_coords, 1),
                            enlarge_coordinates(nk_coords, 1)
                        )
                        if intersects:
                            contact_no += 1
                            contacted_ids.append(tu['id'])

                            # Ensure NK_incontact_id exists and update
                            if 'NK_incontact_id' not in ioi_tracking[image_key]['ch12'].columns:
                                ioi_tracking[image_key]['ch12']['NK_incontact_id'] = ""
                            existing = ioi_tracking[image_key]['ch12'].loc[
                                (ioi_tracking[image_key]['ch12']['id'] == tu['id']) &
                                (ioi_tracking[image_key]['ch12']['t'] == t), 'NK_incontact_id'
                            ].values[0]
                            try:
                                current_ids = json.loads(existing) if existing else []
                                if isinstance(current_ids, int):
                                    current_ids = [current_ids]
                                current_ids.append(int(nk_id))
                            except json.JSONDecodeError:
                                current_ids = [int(nk_id)]
                            ioi_tracking[image_key]['ch12'].loc[
                                (ioi_tracking[image_key]['ch12']['id'] == tu['id']) &
                                (ioi_tracking[image_key]['ch12']['t'] == t), 'NK_incontact_id'
                            ] = json.dumps(current_ids)

                # Save contact info to NK cell
                ioi_tracking[image_key]['ch0'].loc[
                    (ioi_tracking[image_key]['ch0']['id'] == nk_id) &
                    (ioi_tracking[image_key]['ch0']['t'] == t), 'contact'
                ] = contact_no
                ioi_tracking[image_key]['ch0'].loc[
                    (ioi_tracking[image_key]['ch0']['id'] == nk_id) &
                    (ioi_tracking[image_key]['ch0']['t'] == t), 'TU_contact'
                ] = json.dumps(contacted_ids)

In [None]:
def summarise_contacts(ioi_tracking, NK_channel='ch0', min_track_length=25):
    import numpy as np
    import json
    from tqdm.auto import tqdm

    stats_data = {}

    def count_contact_events(binary_list):
        binary = np.array(binary_list)
        changes = np.diff(binary, prepend=0)
        starts = np.where(changes == 1)[0]
        ends = np.where(changes == -1)[0]
        if len(ends) < len(starts):
            ends = np.append(ends, len(binary))
        return len(starts), (ends - starts).tolist()

    for image_key in tqdm(ioi_tracking, desc='File', total=len(ioi_tracking)):
        stats_data[image_key] = {'NK': {}}
        df = ioi_tracking[image_key][NK_channel]
        df = df.groupby('id').filter(lambda x: len(x) >= min_track_length)
        for pid in tqdm(df['id'].unique(), desc='NK Particle', leave=False):
            stats_data[image_key]['NK'][pid] = {}
            df_pid = df[df['id'] == pid].sort_values('t')

            contact_seq = [
                json.loads(row['TU_contact']) if isinstance(row['TU_contact'], str) else row['TU_contact']
                for _, row in df_pid.iterrows()
            ]
            stats_data[image_key]['NK'][pid]['contact_sequence'] = contact_seq

            all_tu = sorted(set(num for s in contact_seq if isinstance(s, (list, tuple)) for num in s))
            stats_data[image_key]['NK'][pid]['TU_unique_id'] = all_tu

            total_contacts = 0
            durations = []
            for tu in all_tu:
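                # Frames without a recorded contact list (e.g. NaN) count as no contact
                binary = [1 if isinstance(s, (list, tuple)) and tu in s else 0 for s in contact_seq]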
                num, dur = count_contact_events(binary)
                total_contacts += num
                durations += dur

            stats_data[image_key]['NK'][pid]['total_contacts'] = total_contacts
            stats_data[image_key]['NK'][pid]['duration_list'] = durations

    return stats_data
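To make the event counting concrete, the snippet below restates the run-detection logic of `count_contact_events` as a standalone function and applies it to a toy per-frame contact sequence. The name `contact_events` is used only for this illustration.

```python
import numpy as np

def contact_events(binary):
    """Count runs of 1s in a per-frame contact sequence and return their lengths."""
    binary = np.asarray(binary)
    changes = np.diff(binary, prepend=0)
    starts = np.where(changes == 1)[0]    # frames where a contact begins
    ends = np.where(changes == -1)[0]     # frames where a contact has just ended
    if len(ends) < len(starts):           # contact still ongoing at the last frame
        ends = np.append(ends, len(binary))
    return len(starts), (ends - starts).tolist()

n_events, durations = contact_events([0, 1, 1, 0, 1])
print(n_events, durations)   # 2 [2, 1]
```

Durations are in frames, so converting them to minutes requires the acquisition interval of the time-lapse.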

In [None]:
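# 1. Detect contacts
# Note: always_0, always_1 and transition_0_to_1 were computed above from ioi_tracking[1] only,
# yet detect_contacts applies this single ID list to every file in ioi_tracking.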
detect_contacts(
    ioi_tracking=ioi_tracking,
    normal_TUDR_list=always_0 + always_1 + transition_0_to_1,
    distance_threshold=30,
    min_track_length=25
)

# 2. Summarise contact statistics
contact_stats = summarise_contacts(
    ioi_tracking=ioi_tracking,
    NK_channel='ch0',           # default
    min_track_length=25         # default
)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.ticker as ticker

def plot_nk_contact_duration_distribution(ioi_file_stats_data, 
                                          binwidth=1, 
                                          contact_xlim=(0, 60), 
                                          figsize_per_file=(5, 4),
                                          save_path=None, 
                                          save_format=None):
    """
    Plot NK cell contact duration distributions across files.

    Parameters
    ----------
    ioi_file_stats_data : dict
        Output from summarise_contacts, structured by file -> 'NK' -> particle_id.
    binwidth : int
        Bin width for histogram bars.
    contact_xlim : tuple
        (min, max) range for x-axis (contact durations).
    figsize_per_file : tuple
        Size of each subplot (width, height). Currently unused: the figure size is fixed at (10, 6).
    save_path : str or None
        Directory to save the figure.
    save_format : str or None
        'svg' or 'pdf'. The figure is saved only when both save_path and save_format are given;
        otherwise it is shown.
    """

    num_files = len(ioi_file_stats_data)
    fig, axes = plt.subplots(
        num_files, 1, 
        # figsize=(figsize_per_file[0], figsize_per_file[1] * num_files),
        figsize =(10, 6),
        sharex=False,
        dpi = 300
    )

    if num_files == 1:
        axes = [axes]  # ensure it's iterable

    for file_idx, ax in enumerate(axes):
        durations = [
            duration
            for particle_id in ioi_file_stats_data[file_idx]['NK']
            for duration in ioi_file_stats_data[file_idx]['NK'][particle_id]['duration_list']
        ]

        if not durations:
            ax.text(0.5, 0.5, "No contacts", ha="center", va="center", fontsize=10)
            ax.axis("off")
            continue

        sns.histplot(
            durations, 
            binwidth=binwidth,
            ax=ax,
            kde=False,
            stat='count',
            color="salmon", 
            edgecolor="black"
        )

        ax.set_title(f"File {file_idx} – NK Contact Duration", fontsize=11, fontweight="bold")
        ax.set_ylabel("Frequency", fontsize=9)
        ax.set_xlim(*contact_xlim)

        # Ticks
        ax.xaxis.set_major_locator(ticker.MultipleLocator(5))
        ax.xaxis.set_minor_locator(ticker.AutoMinorLocator(2))
        ax.tick_params(axis='x', which='both', labelsize=7)
        ax.tick_params(axis='y', which='both', labelsize=7)

        # Grid
        ax.grid(axis='y', linestyle='--', alpha=0.5)

    # Final shared X label
    axes[-1].set_xlabel("Contact Duration (frames)", fontsize=10)

    plt.tight_layout()

    if save_path and save_format:
        if save_format.lower() not in ['svg', 'pdf']:
            raise ValueError("save_format must be either 'svg' or 'pdf'")
        output_file = f"{save_path}/nk_contact_durations.{save_format.lower()}"
        plt.savefig(output_file, bbox_inches='tight')
        print(f"Figure saved to {output_file}")
    else:
        plt.show()

In [None]:
plot_nk_contact_duration_distribution(
    ioi_file_stats_data=contact_stats,
    binwidth=1,
    contact_xlim=(0, 60),
    save_path=None,             # Set path to save, e.g., './figures'
    save_format=None            # Use 'svg' or 'pdf' to save
)