# ND2 extractor
* An nd2 extractor using the package `nd2` from `https://pypi.org/project/nd2/`.
* Supports different dimensions of nd2 files (works fine if there is only one time point or one field of view).
* Adjustable threshold for parallelisation to avoid unnecessary `joblib` overheads.
* Automatic colour channel identification from metadata.
* Human-readable metadata extraction to `.txt` files.
* Metadata output in a `.json` format for loading back into your pipeline later as a dictionary.

In [None]:
import numpy as np
import nd2
import os
from PIL import Image
from joblib import Parallel, delayed
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import json

## 1. Get metadata
* `generate_metadata_txt` outputs a text file with human-readable metadata about the experiment.
* `generate_metadata_json` outputs a json file with various metadata from your experiment, which can be loaded back in as a dictionary at various points in your pipeline.

In [None]:
def generate_metadata_txt(nd2_file, outfile):
    """
    Write a human-readable, plain-text summary of an nd2 file's metadata.

    The report contains the nd2 file name, one section per entry of the
    file's ``text_info`` (key followed by its text), and the x and y voxel
    sizes.

    Parameters
    ----------
    nd2_file : path of the nd2 file to read.
    outfile : path of the text file to write; it is opened in "w" mode, so an
        existing file is overwritten.
    """
    with nd2.ND2File(nd2_file) as ndfile, open(outfile, "w") as report:
        report.write("nd2 file is {} \n\n".format(nd2_file))

        # One section per text_info entry: the key as a heading, then its text.
        for heading, body in ndfile.text_info.items():
            report.write(heading + ": \n\n")
            report.write(body + "\n\n")

        report.write("Pixel to micron conversion: \n")
        voxel = ndfile.voxel_size()
        report.write("x voxel size is {} microns/pixel \n".format(str(voxel.x)))
        report.write("y voxel size is {} microns/pixel \n".format(str(voxel.y)))

In [None]:
# Path of the nd2 file that every cell below works on; replace the placeholder with your own file.
nd2_file = "your_file.nd2"

In [None]:
# Name the text report after the nd2 file's base name. os.path strips only the
# final extension and ignores folders, whereas split(".")[0] cut the name at the
# first dot anywhere in the path (e.g. "./data/x.nd2" gave an empty stem).
outfile = "metadata_nd2_" + os.path.splitext(os.path.basename(nd2_file))[0] + ".txt"
generate_metadata_txt(nd2_file, outfile)

In [None]:
def generate_metadata_json(nd2_file, json_outfile, save_file=True,
                           camera_name="Hamamatsu C14440-20UP SN:000470"):
    """
    Collect nd2 metadata into a dictionary and optionally save it as json.

    The experiment loops are looked up by their ``type`` ("TimeLoop" and
    "XYPosLoop") rather than by position, so files with a single time point or
    a single field of view (which lack the corresponding loop) are handled:
    a missing XY loop gives an empty "FOV_xyz_PFS" entry, and a missing time
    loop gives ``None`` for the imaging-interval entries.

    Volume, calibration and microscope entries are read from the first channel
    only.

    Parameters
    ----------
    nd2_file : path of the nd2 file to read.
    json_outfile : path of the json file to write (only used if ``save_file``).
    save_file : if True, dump the dictionary to ``json_outfile``.
    camera_name : text stored under "camera_name". It is not read from the nd2
        file; the default is the value that used to be hard-coded here.

    Returns
    -------
    dict holding the collected metadata.
    """
    ##TODO
    # laser power setting for each channel (is currently in the text version of metadata)
    # exposure time setting for each channel (is currently in the text version of metadata)
    # actual time of camera exposure for each image (ask Nikon?)
    # actual time of illumination for each image (ask Nikon?)

    nd2_metadata = dict()
    with nd2.ND2File(nd2_file) as ndfile:
        nd2_metadata["file_path"] = ndfile.path
        nd2_metadata["shape"] = ndfile.shape
        nd2_metadata["ndim"] = ndfile.ndim
        nd2_metadata["dtype"] = str(ndfile.dtype)
        # Copy into a plain dict so the entry is independent of the ND2File
        # object and always json-serialisable.
        nd2_metadata["sizes"] = dict(ndfile.sizes)
        nd2_metadata["is_rgb"] = ndfile.is_rgb

        # Find the loops by type instead of assuming experiment[0] is the time
        # loop and experiment[1] the XY loop; either may be absent or reordered.
        time_loop = None
        xy_loop = None
        for loop in ndfile.experiment:
            loop_type = getattr(loop, "type", None)
            if loop_type == "TimeLoop" and time_loop is None:
                time_loop = loop
            elif loop_type == "XYPosLoop" and xy_loop is None:
                xy_loop = loop

        FOV_xyz_PFS = dict()
        if xy_loop is not None:
            for count, position in enumerate(xy_loop.parameters.points):
                FOV_xyz_PFS["xy{}".format(str(count).zfill(3))] = (position.stagePositionUm, position.pfsOffset)
        nd2_metadata["FOV_xyz_PFS_key"] = {"FOV_key": "([x, y, z], PFS_offset)"}
        nd2_metadata["FOV_xyz_PFS"] = FOV_xyz_PFS

        channel_entries = ndfile.metadata.channels
        first_channel = channel_entries[0]
        volume = first_channel.volume

        nd2_metadata["axes_order"] = ["x", "y", "z"]
        nd2_metadata["axes_calibrated"] = volume.axesCalibrated
        nd2_metadata["axes_calibration"] = volume.axesCalibration
        nd2_metadata["FOV_size_in_pixels"] = volume.voxelCount
        nd2_metadata["microns_per_pixel"] = float(volume.axesCalibration[0])

        if time_loop is not None:
            period_diff = time_loop.parameters.periodDiff
            nd2_metadata["num_timepoints"] = time_loop.count
            nd2_metadata["imaging_interval_ms"] = time_loop.parameters.periodMs
            nd2_metadata["imaging_interval_ms_mean"] = period_diff.avg
            nd2_metadata["imaging_interval_ms_max"] = period_diff.max
            nd2_metadata["imaging_interval_ms_min"] = period_diff.min
        else:
            # No time loop: fall back to the "T" size (1 if the axis is absent);
            # there is no imaging interval to report.
            nd2_metadata["num_timepoints"] = nd2_metadata["sizes"].get("T", 1)
            nd2_metadata["imaging_interval_ms"] = None
            nd2_metadata["imaging_interval_ms_mean"] = None
            nd2_metadata["imaging_interval_ms_max"] = None
            nd2_metadata["imaging_interval_ms_min"] = None

        channel_list = []
        channel_dict_key = dict()
        channel_dict_key["channel_name"] = ("excitation_wavelength", "emission_wavelength")
        channel_lambdas = dict()
        for entry in channel_entries:
            channel_list.append(entry.channel.name)
            channel_lambdas[entry.channel.name] = (entry.channel.excitationLambdaNm,
                                                   entry.channel.emissionLambdaNm)
        nd2_metadata["channels"] = channel_list
        nd2_metadata["channel_wavelengths_key"] = channel_dict_key
        nd2_metadata["channel_wavelengths"] = channel_lambdas

        microscope = first_channel.microscope
        nd2_metadata["objective"] = microscope.objectiveName
        nd2_metadata["numerical_aperture"] = microscope.objectiveNumericalAperture
        nd2_metadata["objective_magnification"] = microscope.objectiveMagnification
        nd2_metadata["post_objective_magnification"] = microscope.zoomMagnification
        nd2_metadata["immersion_refractive_index"] = microscope.immersionRefractiveIndex
        nd2_metadata["camera_name"] = camera_name

    if save_file:
        with open(json_outfile, 'w') as f:
            json.dump(nd2_metadata, f)

    return nd2_metadata

In [None]:
# Name the json file after the nd2 file's base name. os.path strips only the
# final extension and ignores folders, whereas split(".")[0] cut the name at the
# first dot anywhere in the path.
json_outfile = "metadata_nd2_" + os.path.splitext(os.path.basename(nd2_file))[0] + ".json"
nd2_metadata = generate_metadata_json(nd2_file, json_outfile, save_file = True)

## 2. Function definitions for nd2 extraction

In [None]:
def get_nd2_dimensions(nd2_file):
    """
    Return the axis sizes of an nd2 file as a dict sorted by axis name.

    The result always contains "T", "P" and "C": any of these that is missing
    from the file's ``sizes`` is added with size 1, so downstream code can rely
    on all three keys.

    Parameters
    ----------
    nd2_file : path of the nd2 file to read.

    Returns
    -------
    dict mapping axis name to size, with keys in sorted order.
    """
    with nd2.ND2File(nd2_file) as ndfile:
        # Copy first: the defaults added below must not be written into the
        # mapping that belongs to the ND2File object.
        dimensions = dict(ndfile.sizes)

    for axis in ("T", "P", "C"):
        dimensions.setdefault(axis, 1)

    return dict(sorted(dimensions.items()))

In [None]:
def get_channel_dict(nd2_file):
    """
    Build a lookup from channel index to channel name (colour).

    Keys are the channel indices as strings ("0", "1", ...), in the order the
    channels appear in the nd2 metadata; values are the channel names.
    """
    with nd2.ND2File(nd2_file) as ndfile:
        index_to_name = {
            str(index): entry.channel.name
            for index, entry in enumerate(ndfile.metadata.channels)
        }

    return index_to_name

In [None]:
def create_joblib_packages(dimensions, parallelisation_min=50, njobs_max=16):
    """
    Split all frames of the nd2 into contiguous packages, one package per job.

    Every frame is identified by a ``(p, c, t)`` tuple (field of view, channel,
    time point). The flat list of frames is cut into ``njobs`` consecutive
    chunks whose sizes differ by at most one; concatenating the packages in
    order gives back ``all_frames``.

    Parameters
    ----------
    dimensions : dict with integer sizes under the keys "P", "C" and "T".
    parallelisation_min : if the total number of frames is below this value, a
        single job is used to avoid the joblib overhead.
    njobs_max : number of jobs to use otherwise. It is capped at the number of
        frames so that no job receives an empty package.

    Returns
    -------
    njobs : number of jobs, equal to ``len(joblib_packages)``.
    joblib_packages : list of lists of ``(p, c, t)`` tuples.
    all_frames : flat list of every ``(p, c, t)`` tuple, ordered by p, then c,
        then t.

    Raises
    ------
    ValueError : if parallelisation is required but ``njobs_max`` is below 1.
    """
    all_frames = [
        (p, c, t)
        for p in range(dimensions["P"])
        for c in range(dimensions["C"])
        for t in range(dimensions["T"])
    ]
    num_imgs = len(all_frames)

    ## if nd2 is small, it is better to avoid joblib overhead ##
    if num_imgs < parallelisation_min:
        njobs = 1
    else:
        if njobs_max < 1:
            raise ValueError("njobs_max must be at least 1, got {}".format(njobs_max))
        # Never start more jobs than there are frames to hand out.
        njobs = max(1, min(njobs_max, num_imgs))

    # Integer arithmetic keeps the chunk boundaries exact (no float rounding)
    # and spreads any remainder evenly across the packages.
    boundaries = [(job * num_imgs) // njobs for job in range(njobs + 1)]
    joblib_packages = [
        all_frames[start:stop]
        for start, stop in zip(boundaries, boundaries[1:])
    ]

    return njobs, joblib_packages, all_frames

In [None]:
def save_png(numpy_array, channels, p, c, t, save_dir):
    """
    Save one image array as a png named after its position, channel and time.

    The file is written to ``save_dir`` as ``xy<p>_<channel name>_T<t>.png``,
    with ``p`` zero-padded to 3 digits and ``t`` to 4 digits. The channel name
    is looked up in ``channels`` under the key ``str(c)``.
    """
    image = Image.fromarray(numpy_array)

    fov_label = str(p).zfill(3)
    channel_name = channels[str(c)]
    time_label = str(t).zfill(4)
    target_path = "{}/xy{}_{}_T{}.png".format(save_dir, fov_label, channel_name, time_label)

    image.save(target_path)

In [None]:
def extract_pngs(dask_array, save_dir, dimensions, channels, package):
    """
    Extract all png files from a package.

    Parameters
    ----------
    dask_array : array holding the nd2 image data. It is indexed with the
        T, P and C indices (in that order) followed by two full slices.
    save_dir : directory the png files are written to.
    dimensions : dict with the sizes of the "T", "P" and "C" axes.
    channels : dict mapping ``str(channel index)`` to channel name.
    package : iterable of ``(p, c, t)`` tuples to extract.

    Notes
    -----
    An axis whose size is 1 is treated as absent from ``dask_array`` and is
    left out of the index. When T, P and C are all 1 the array is taken to be a
    single 2D image; previously that case fell into the "T == 1 and P == 1"
    branch and indexed the array with one index too many.
    """
    # Which leading axes are present is the same for every frame, so work it
    # out once instead of re-testing the sizes inside the loop.
    has_t = dimensions["T"] != 1
    has_p = dimensions["P"] != 1
    has_c = dimensions["C"] != 1
    image_plane = (slice(None), slice(None))

    for p, c, t in tqdm(package):
        ## based on nd2 dimensions, index dask array appropriately ##
        leading = tuple(
            index
            for index, present in ((t, has_t), (p, has_p), (c, has_c))
            if present
        )
        arr = np.array(dask_array[leading + image_plane])
        save_png(arr, channels, p, c, t, save_dir)
        

In [None]:
def extract_all_pngs(nd2_file, save_dir, dimensions, channels, njobs, packages, backend="loky"):
    """
    Extract all png files from the nd2.

    Parameters
    ----------
    nd2_file : path of the nd2 file to read.
    save_dir : directory the png files are written to; created if missing
        (including parent folders).
    dimensions : dict with the sizes of the "T", "P" and "C" axes.
    channels : dict mapping ``str(channel index)`` to channel name.
    njobs : number of parallel jobs; 1 runs everything in this process.
    packages : list of packages, each a list of ``(p, c, t)`` tuples.
    backend : joblib backend used when ``njobs`` is not 1.
    """
    ## create a save folder ##
    try:
        os.makedirs(save_dir)
    except FileExistsError:
        # Only an existing target is expected here; any other failure (e.g.
        # missing permissions) is left to propagate instead of being reported
        # as "already exists".
        print("Target save directory already exists!")

    ## generate dask array from nd2 file ##
    dask_array = nd2.imread(nd2_file, dask=True)

    ## extract pngs from nd2 file, avoiding joblib overhead for small jobs ##
    if njobs == 1:
        for package in packages:
            extract_pngs(dask_array, save_dir, dimensions, channels, package)

    ## parallelise for big jobs ##
    else:
        Parallel(n_jobs=njobs, backend=backend)(
            delayed(extract_pngs)(dask_array, save_dir, dimensions, channels, package)
            for package in tqdm(packages)
        )

## 3. Extract all pngs

In [None]:
# Folder that receives the extracted png files.
save_dir = "extracted"
# Axis sizes of the nd2 file, standardised by get_nd2_dimensions.
dimensions = get_nd2_dimensions(nd2_file)
# Channel index (as a string) -> channel name, used in the png file names.
channels = get_channel_dict(nd2_file)
# Number of jobs, the per-job frame packages, and the flat list of all frames.
njobs, packages, all_frames = create_joblib_packages(dimensions)

In [None]:
# Write every frame of the nd2 file to save_dir as a png, using the packages prepared above.
extract_all_pngs(nd2_file, save_dir, dimensions, channels, njobs, packages)

## 4. Additional metadata
* `temporal_frame_spacing` outputs a graph showing the consistency of the frame spacing in different channels
    * TODO: Plot all colour channels on same graph in `temporal_frame_spacing`

In [None]:
def temporal_frame_spacing(nd2_file, dimensions, channel=0, channels=None, outfile=None):
    """
    Plots the temporal frame spacing of a single colour channel.

    The relative acquisition time of the selected frames is read from the nd2
    frame metadata, the differences between consecutive times are plotted, and
    the figure is saved as a png next to the nd2 file (same path without the
    extension, followed by the channel label).

    Parameters
    ----------
    nd2_file : path of the nd2 file to read.
    dimensions : dict with the size of the "T" axis.
    channel : index of the channel to analyse.
    channels : optional dict mapping ``str(channel index)`` to channel name.
        When omitted, the channel index itself is used as the label in the
        title, the summary text and the figure file name.
    outfile : optional path of a text file; if given, the mean frame spacing
        and its coefficient of variation are appended to it.
    """
    # Resolve the label once; channels defaults to None, so it must not be
    # indexed unconditionally.
    channel_label = channels[str(channel)] if channels else str(channel)

    with nd2.ND2File(nd2_file) as ndfile:
        frame_count = ndfile.metadata.contents.frameCount
        # NOTE(review): frames are sampled from index `channel` with a stride of
        # frameCount / T -- confirm this matches the file's frame ordering.
        frame_step = int(frame_count / dimensions["T"])
        frame_timings = [
            ndfile.frame_metadata(frame).channels[channel].time.relativeTimeMs / 1000
            for frame in range(channel, frame_count, frame_step)
        ]

    frame_spacing = np.diff(frame_timings).tolist()
    time_points = list(range(len(frame_spacing)))

    if outfile:
        mean_frame_spacing = np.mean(frame_spacing)
        CV_frame_spacing = np.std(frame_spacing)/mean_frame_spacing
        with open(outfile, 'a') as report:
            report.write("\n")
            report.write("The mean temporal frame spacing in the {} channel is {}s and the coefficient of variation in the frame spacing is {}".format(channel_label, str(round(mean_frame_spacing, 4)), str(round(CV_frame_spacing, 5))))

    plt.plot(time_points, frame_spacing)
    plt.xlabel("Time point (-)")
    plt.ylabel("Frame spacing (s)")
    if channels:
        plt.title("Temporal frame spacing in {} channel".format(channel_label))
    else:
        plt.title("Temporal frame spacing in channel {}".format(channel_label))

    # splitext removes only the final extension, so dots elsewhere in the path
    # (e.g. "./data/file.nd2") no longer truncate the figure name.
    figure_stem = os.path.splitext(nd2_file)[0]
    plt.savefig("{}_{}_channel_temporal_frame_spacing.png".format(figure_stem, channel_label), bbox_inches='tight', dpi=250)
    plt.show()

In [None]:
# Plot the frame spacing of every channel in turn; each call also appends its
# summary line to the text metadata file (outfile).
for i in range(len(channels)):
    temporal_frame_spacing(nd2_file, dimensions, channel=i, channels=channels, outfile=outfile)