In [1]:
import glob
import sys
import os
import subprocess
from pathlib import Path
import pandas as pd
import numpy as np
import tifffile
from nd2reader import ND2Reader
import imageio
import matplotlib.pyplot as plt
import itk
import warnings; warnings.filterwarnings('ignore', category=UserWarning, module='itk')

import dask.array as da
import nd2

from tqdm import tqdm

import plot_utils
import utils

# Select the interactive widget (ipympl) matplotlib backend for this kernel.
%matplotlib widget

# Alias plt.close so cells can close the current figure before drawing a new one.
from matplotlib.pyplot import close as close_plots
close_plots()

# Inject a CSS rule so <pre> output keeps its whitespace verbatim (no line wrapping).
from IPython.display import display, HTML
display(HTML("<style>pre { white-space: pre !important; }</style>"))

In [2]:
# --- Acquisition layout -----------------------------------------------------
# Number of planes stacked per volume; also the default used by the helpers below.
STACK_LENGTH = 41

# --- Input recording --------------------------------------------------------
# NOTE(review): absolute, machine-specific paths; adjust before running elsewhere.
DATA_PTH = Path(r'C:\Users\munib\POSTDOC\DATA\g5-HT-free\albert-g5ht-example\022025_eft_41z_starved_worm002')
INPUT_ND2 = '022025_eft_41z_starved_worm002.nd2'
INPUT_ND2_PTH = os.path.join(DATA_PTH, INPUT_ND2)

# Output location: the ND2 path with its extension stripped.
OUT_DIR = os.path.splitext(INPUT_ND2_PTH)[0]

# --- Noise reference --------------------------------------------------------
NOISE_PTH = r'C:\Users\munib\POSTDOC\CODE\g5-HT3_imaging_pipeline/noise_042925.tif'
noise_tif = tifffile.imread(NOISE_PTH)

# Repeat the single noise image along a new leading axis, once per plane, as float32.
noise_stack = np.repeat(noise_tif[np.newaxis, ...], STACK_LENGTH, axis=0).astype(np.float32)

# Indices of the stacks available in the recording (aka frame indices).
stack_range = utils.get_range_from_nd2(INPUT_ND2_PTH)

In [3]:
# Echo the stack indices reported by utils.get_range_from_nd2 for this recording.
stack_range

range(0, 1200)

In [18]:
# Load one stack (index 1000) through the project helper, passing the noise stack
# with denoising and the extra third channel enabled.
# NOTE(review): the exact denoise / third-channel semantics live in utils; confirm there.
frame = utils.get_stack_from_nd2(
    INPUT_ND2_PTH,
    1000,
    noise_stack,
    denoise=True,
    make_third_channel=True,
)

In [19]:
# Inspect the shape of the array returned by utils.get_stack_from_nd2.
frame.shape

(39, 3, 512, 512)

In [20]:
close_plots()

# toplot= frame[:,[1,0,2],:,:]
toplot = frame
toplot[:,0,:,:] = toplot[:,0,:,:] / np.max(toplot[:,0,:,:])
toplot[:,1,:,:] = toplot[:,1,:,:] / np.max(toplot[:,1,:,:])
# max on each channel # TODO

%matplotlib qt
fig = plt.figure(layout='constrained', figsize=(5,5))
axgen = plot_utils.reflow_gen(fig)
# for s in range(STACK_LENGTH-2):
for s in range(STACK_LENGTH-2):
    ax = next(axgen)
    # ax.imshow(toplot[s,:,:,:].transpose())
    ax.imshow(toplot[s,:,:,:].transpose() / np.max(toplot[s,:,:,:].transpose()))
    ax.imshow(toplot[s,:,:,:].transpose())
    ax.set_title(s)
plt.show()

In [22]:
# Save `frame` as an uncompressed OME BigTIFF, labelling its axes as Z, C, Y, X.
tifffile.imwrite(
    'temp.tif', frame,
    ome=True, bigtiff=True, compression=None,
    metadata={"axes": "ZCYX"},
)

In [None]:
def get_range(input_nd2, stack_length=41):
    """Return the range of whole-stack indices contained in an ND2 file.

    The count is the file's ``num_frames`` metadata entry floor-divided by
    ``stack_length``; a trailing partial stack is not counted.

    Args:
        input_nd2: Path of the ND2 file to open.
        stack_length: Number of frames that make up one stack.

    Returns:
        ``range(n)`` where ``n`` is the number of complete stacks.
    """
    reader = ND2Reader(input_nd2)
    with reader as nd2_file:
        n_stacks = nd2_file.metadata['num_frames'] // stack_length
    return range(n_stacks)

def get_frames(path, stack_length=41, n_frames=10):
    """Load the first ``n_frames`` stacks of an ND2 file into one array.

    Two file layouts are handled:

    * no ``z`` axis in the file: the planes of successive stacks are stored back
      to back along ``t``, so plane ``s`` of stack ``it`` is reader frame
      ``it * stack_length + s``;
    * a ``z`` axis in the file: each reader frame is already a whole stack and
      its depth is taken from the file (``stack_length`` is then not used).

    Args:
        path: Path of the ND2 file.
        stack_length: Planes per stack for files without a ``z`` axis.
        n_frames: Number of stacks to load (default 10, the previous fixed value).

    Returns:
        float64 array of shape ``(n_frames, y, x, c, z)``; ``c`` is 1 when the
        file has no channel axis.

    Raises:
        ValueError: if ``n_frames`` or ``stack_length`` is smaller than 1.
    """
    # TODO: convert this to get_frames_in_range()
    if n_frames < 1 or stack_length < 1:
        raise ValueError("n_frames and stack_length must both be >= 1")

    with ND2Reader(path) as rdr:
        # Peek at dimensions present in the file, e.g. {'t':1200,'z':41,'y':1024,'x':1024,'c':2}
        sizes = dict(rdr.sizes)

        # Index the reader along time when the file has a time axis.
        rdr.iter_axes = 't' if 't' in sizes else ''

        # Bundle the spatial axes, plus Z and channels when present, into each
        # returned frame: one of 'zyxc', 'zyx', 'yxc' or 'yx'.
        have_z = 'z' in sizes
        have_c = 'c' in sizes
        rdr.bundle_axes = ('z' if have_z else '') + 'yx' + ('c' if have_c else '')

        arr = None  # allocated from the first frame so no image size is hard-coded
        if have_z:
            for it in tqdm(range(n_frames)):
                vol = np.asarray(rdr[it])
                if not have_c:
                    vol = vol[..., np.newaxis]       # (z, y, x) -> (z, y, x, 1)
                vol = np.moveaxis(vol, 0, -1)        # (z, y, x, c) -> (y, x, c, z)
                if arr is None:
                    arr = np.zeros((n_frames,) + vol.shape)
                arr[it] = vol
        else:
            for t in tqdm(range(n_frames * stack_length)):
                # Stack index and plane-within-stack for raw frame t. The original
                # used `t % n_frames` and read `rdr[s]`, which only ever touched the
                # first stack and copied it into every output slot.
                it, s = divmod(t, stack_length)
                plane = np.asarray(rdr[t])
                if not have_c:
                    plane = plane[..., np.newaxis]   # (y, x) -> (y, x, 1)
                if arr is None:
                    arr = np.zeros((n_frames,) + plane.shape + (stack_length,))
                arr[it, :, :, :, s] = plane

    return arr

def get_frame(path, frame_index, stack_length=41):
    """Load one stack from an ND2 file whose planes are stored back to back.

    ``frame_index`` counts whole stacks (the indices returned by ``get_range``),
    so the planes read are reader frames
    ``frame_index * stack_length`` .. ``(frame_index + 1) * stack_length - 1``.
    The original used ``frame_index`` directly as the first raw plane, so index 1
    returned stack 0 shifted by one plane instead of the second stack.

    Args:
        path: Path of the ND2 file.
        frame_index: Zero-based index of the stack to load.
        stack_length: Number of planes per stack.

    Returns:
        float64 array of shape ``(y, x, c, stack_length)``.

    Raises:
        ValueError: if ``stack_length`` is smaller than 1.
    """
    if stack_length < 1:
        raise ValueError("stack_length must be >= 1")

    with ND2Reader(path) as rdr:
        # Peek at dimensions present in the file, e.g. {'t':1200,'y':512,'x':512,'c':2}
        sizes = dict(rdr.sizes)

        # Index the reader along time when the file has a time axis; each returned
        # frame bundles one (y, x, c) plane.
        rdr.iter_axes = 't' if 't' in sizes else ''
        rdr.bundle_axes = 'yxc'

        first_plane = frame_index * stack_length
        arr = None  # allocated from the first plane so no image size is hard-coded
        for i, f in enumerate(range(first_plane, first_plane + stack_length)):
            plane = np.asarray(rdr[f])
            if arr is None:
                # Depth follows stack_length (it was a literal 41 before).
                arr = np.zeros(plane.shape + (stack_length,))
            arr[..., i] = plane

    return arr

In [None]:
frame0 = get_frame(INPUT_ND2_PTH, 0)
print(frame0.shape)

# Build an all-zero channel with the same spatial/depth extent as one existing channel.
new_channel = np.zeros(frame0[:, :, :1, :].shape, dtype=int)

# Append it along the channel axis (axis=2), giving three channels.
frame0 = np.concatenate((frame0, new_channel), axis=2)

# Swap the first two channels; the zero channel stays last.
channel_order = [1, 0, 2]
frame0 = frame0[:, :, channel_order, :]

frame0.shape

In [None]:
%matplotlib qt
fig = plt.figure(layout='constrained', figsize=(5,5))
axgen = plot_utils.reflow_gen(fig)
for s in range(STACK_LENGTH):
    ax = next(axgen)
    ax.imshow(frame0[:,:,:,s] / np.max(frame0[:,:,:,s]))
    ax.set_title(s)
plt.show()

In [None]:
frame1 = get_frame(INPUT_ND2_PTH, 1)
print(frame1.shape)

# Build an all-zero channel with the same spatial/depth extent as one existing channel.
new_channel = np.zeros(frame1[:, :, :1, :].shape, dtype=int)

# Append it along the channel axis (axis=2), giving three channels.
frame1 = np.concatenate((frame1, new_channel), axis=2)

# Swap the first two channels; the zero channel stays last.
channel_order = [1, 0, 2]
frame1 = frame1[:, :, channel_order, :]

frame1.shape

In [None]:
%matplotlib qt
fig = plt.figure(layout='constrained', figsize=(5,5))
axgen = plot_utils.reflow_gen(fig)
for s in range(STACK_LENGTH):
    ax = next(axgen)
    ax.imshow(frame1[:,:,:,s] / np.max(frame1[:,:,:,s]))
    ax.set_title(s)
plt.show()

In [None]:
# Compare frame0's shape before and after moving axes 2 and 3 to the front.
print(frame0.shape)
print(np.transpose(frame0, (2, 3, 0, 1)).shape)

In [None]:
# Save frame0 with axes 2 and 3 moved to the front as an uncompressed OME BigTIFF,
# labelling the written axes as C, Z, Y, X.
tifffile.imwrite(
    'temp2.tif', np.transpose(frame0, (2, 3, 0, 1)),
    ome=True, bigtiff=True, compression=None,
    metadata={"axes": "CZYX"},
)