## Usage
- Edit the settings in the cell below.
- Cell -> Run All.
- A Napari window will open, where you can scroll through your data.

In [None]:

# Change this to point to your plate as seen in lmu_active1/instruments/Nano.
# NOTE: the assignments below are alternatives; only the LAST one takes
# effect, so keep the plate you want to view as the final assignment.
plate = 'lyu/140324-A53T P62 staining/140324-A53T P62 staining/2024-03-27/20378/TimePoint_1'
plate = 'lyu/220824-hDA A53T-020924-1'
plate = 'karkkael/microwell images/Plate1/2024-09-23/1/TimePoint_1/'

# The lmu_active1 root folder for Linux or Windows is chosen by
# get_lmu_active1() in the code section; the lines below are kept only as a
# reference (use a raw string for the Windows path if you re-enable it).
#lmu_active1 = Path('/mnt/lmu_active1') # Linux
#lmu_active1 = Path(r'L:\lmu_active1') # Windows

# Define the colors you want to use (as many as you have channels).
# As with `plate`, only the last `colormap` assignment takes effect.
#colormap = ["yellow", "magenta", "cyan"]
colormap = ["blue", "green", "red"]
colormap = ["blue", "green", "red", "magenta"]


## Code
You don't need to make changes in the code cells below. 

In [None]:
from aicsimageio.aics_image import AICSImage
from pathlib import Path
import matplotlib.pyplot as plt
import napari
import numpy as np
import os
import pandas as pd
import platform

def get_lmu_active1():
    """Return the lmu_active1 root folder (as a string) for the running OS.

    Raises:
        ValueError: if the operating system is neither Windows nor Linux.
    """
    # One entry per supported platform, keyed by platform.system().
    roots_by_os = {
        "Windows": "L:\\lmu_active1",
        "Linux": "/mnt/lmu_active1",
    }
    system_name = platform.system()

    if system_name not in roots_by_os:
        raise ValueError(f"Unsupported operating system: {system_name}")
    return roots_by_os[system_name]
        
# Original image folder.
# NOTE: `orig` is assigned twice; the last assignment wins, so the images are
# read from the local /home/hajaalin/data/micro copy, not from lmu_active1.
# The first assignment still calls get_lmu_active1(), which raises ValueError
# on an unsupported operating system.
# NOTE(review): the active paths look machine-specific - confirm before sharing.
orig = get_lmu_active1() / Path('instruments/Nano') / Path(plate)
#orig = Path('/home/user/nanodata') / Path(plate)
orig = Path('/home/hajaalin/data/micro') / Path(plate)

# Segmentation folder: again only the last assignment is effective.
seg = Path('/home/user/nanodata/stardist') / Path(plate)
seg = Path('/mnt/lmu_active1/airflow/nano') / plate.replace(' ', '_')
#seg = Path('/tmp')
seg = Path('/home/hajaalin/data/micro/karkkael/stardist_karkkael')
print(seg)

In [None]:
# Column names of the file table built by create_file_list().
PATH = 'Path'
DATE = 'Date'
TIMEPOINT = 'TimePoint'
ZSTEP = 'ZStep'
PLATE = 'Plate'
WELL = 'Well'
SITE = 'Site'
CHANNEL = 'Channel'
UUID = 'UUID'

def create_file_list(orig):
    """Find the .tif images below a folder and parse metadata from their paths.

    Args:
        orig: pathlib.Path of the folder to search recursively. Files whose
            name contains "thumb" are skipped.

    Returns:
        A DataFrame with one row per file: the PATH column followed by DATE,
        TIMEPOINT, ZSTEP, PLATE, WELL, SITE, CHANNEL and UUID. All values are
        strings; ZSTEP is NaN for files that are not inside a ZStep_<n>
        folder, and every parsed column is NaN for paths that do not match
        the expected layout.
    """
    # glob() yields files in arbitrary filesystem order; sort so that the
    # table (and everything derived from it) is reproducible between runs.
    files = sorted(str(p) for p in orig.glob("**/*.tif") if "thumb" not in p.name)
    df = pd.DataFrame(files, columns=[PATH])

    if files:
        print(files[-1])

    # Cross-platform pattern (accepts / and \ as separator); the named groups
    # become the column names of the extracted table.
    sep = r'[/\\]'
    pattern = (
        rf'.*{sep}(?P<{DATE}>\d{{4}}-\d{{2}}-\d{{2}})'
        rf'{sep}[^/\\]*{sep}TimePoint_(?P<{TIMEPOINT}>\d+)'
        rf'(?:{sep}ZStep_(?P<{ZSTEP}>\d+))?'
        rf'{sep}(?P<{PLATE}>[^_]+)_(?P<{WELL}>\w\d{{2}})_s(?P<{SITE}>\d{{1,2}})_(?P<{CHANNEL}>w\d)'
        rf'(?P<{UUID}>[A-F0-9]{{8}}-[A-F0-9]{{4}}-[A-F0-9]{{4}}-[A-F0-9]{{4}}-[A-F0-9]{{12}})'
    )
    print(pattern)

    # Extract the metadata columns and attach them to the path column.
    return df.join(df[PATH].str.extract(pattern))

# Index the original images and the segmentations with the same parser.
df, df_seg = (create_file_list(folder) for folder in (orig, seg))

In [None]:
# Show the first rows of the parsed image table as a quick sanity check
df.head()

In [None]:
# Show the last rows of the parsed image table
df.tail()

In [None]:
# Size and first rows of the segmentation file table
print(df_seg.shape)
df_seg.head()

In [None]:
# Number of files for which no channel could be parsed from the path.
unparsed = df.loc[df[CHANNEL].isna()]
print(unparsed.shape)

In [None]:
# Collect the distinct channel names. Files whose path did not match the
# pattern have NaN in this column; drop them, otherwise sorted() raises
# TypeError when it compares NaN (a float) with the channel strings.
wavelengths = df[CHANNEL].dropna().unique()
print(wavelengths)

# Sorted list of channel names; reused later as layer names in the viewer.
wavelengths = sorted(wavelengths)
print(wavelengths)


In [None]:
# Split the table: rows without a Z step are projection images (2D), rows
# with a Z step are individual Z-slices (3D).
mask = df[ZSTEP].isna()

df2d = df.loc[mask].sort_values(
    by=[PLATE, WELL, SITE, CHANNEL],
    ignore_index=True,
)

# Z steps are parsed as strings; convert to int so they sort numerically.
df3d = (
    df.loc[~mask]
    .assign(**{ZSTEP: lambda table: table[ZSTEP].astype(int)})
    .sort_values(by=[PLATE, WELL, SITE, ZSTEP, CHANNEL], ignore_index=True)
)

In [None]:
# Size and first rows of the projection (2D) table
print(df2d.shape)
df2d.head()

In [None]:
# First rows of the Z-slice (3D) table
df3d.head()

In [None]:
# keep only segmentations that have corresponding projection images
# Rows whose path could not be parsed have a NaN UUID. pandas merge treats
# NaN keys as equal to each other, so drop them on both sides; otherwise
# unrelated unparsed files would be paired up.
uuid_2d = df2d[[UUID]].dropna()
print(uuid_2d.shape)
df_seg_2d = pd.merge(df_seg.dropna(subset=[UUID]), uuid_2d, how='inner', on=UUID)
print(df_seg_2d.shape)
df_seg_2d.head()

In [None]:
# One row per (plate, well, site); every other column becomes a list with
# one entry per projection image of that site.
site_key = [PLATE, WELL, SITE]
grouped2d = df2d.groupby(site_key).agg(list)
grouped2d

In [None]:
if not df3d.empty:
    # One row per (plate, well, site); the other columns become lists with
    # one entry per Z-slice image of that site.
    grouped3d = df3d.groupby(by=[PLATE, WELL, SITE]).agg(list)
    # A bare expression inside an if-block is not displayed, so print it.
    print(grouped3d)
else:
    print('No Z steps')

In [None]:
if not df_seg_2d.empty:
    # One row per (plate, well, site) with the matching segmentation files.
    grouped_seg_2d = df_seg_2d.groupby(by=[PLATE, WELL, SITE]).agg(list)
    # A bare expression inside an if-block is not displayed, so print it.
    print(grouped_seg_2d)
else:
    print('No segmentations')

In [None]:
if not df3d.empty:
    # Spot-check one (plate, well, site) group: all list columns must have the
    # same length. Prefer the original example site, but fall back to the
    # first available group so the cell does not fail with KeyError on plates
    # that lack it.
    example_key = ('Plate1', 'C07', '2')
    if example_key not in grouped3d.index:
        example_key = grouped3d.index[0]
    example = grouped3d.loc[example_key]

    for column in (PATH, DATE, TIMEPOINT, ZSTEP, CHANNEL):
        print(len(example[column]))
    print(example)
    paths = example[PATH]
    for p in paths:
        print(p)
else:
    print('No Z steps')

In [None]:
import dask.array as da
from aicsimageio import AICSImage

def create_dask_array(grouped2d):
    """Stack all images of a grouped 2D table into one lazy dask array.

    Args:
        grouped2d: table indexed by (PLATE, WELL, SITE) whose PATH, DATE,
            TIMEPOINT, CHANNEL and UUID cells hold equally long lists, as
            produced by ``groupby([PLATE, WELL, SITE]).agg(list)``.

    Returns:
        Tuple ``(index_map, array)``. ``array`` has the axes
        (plate, well, site, channel, *image axes); ``index_map`` maps each
        (plate, well, site) key to its position on the first three axes.

    Raises:
        ValueError: if the table contains no groups. Stacking presumably also
            fails when wells or sites have differing shapes (not checked
            here).
    """
    plates = []      # one stacked array per plate
    index_map = {}   # (plate, well, site) -> (plate idx, well idx, site idx)

    for plate, plate_group in grouped2d.groupby(PLATE):
        wells = []

        for well, well_group in plate_group.groupby(WELL):
            sites = []

            for site, site_group in well_group.groupby(SITE):
                # The lists are still being filled, so their current lengths
                # are the indices this site will get in the final array.
                index_map[(plate, well, site)] = (len(plates), len(wells), len(sites))

                # One row per image file instead of one list per column.
                exploded_site_group = site_group.explode([PATH, DATE, TIMEPOINT, CHANNEL, UUID])

                channels = []
                for channel_path in exploded_site_group[PATH]:
                    print(plate, well, site, channel_path)
                    img = AICSImage(channel_path)
                    # Lazy loading; squeeze drops the singleton axes.
                    channels.append(img.get_image_dask_data().squeeze())

                # Stack the channel images of this site.
                site_stack = da.stack(channels, axis=0)
                print(site_stack.shape)
                sites.append(site_stack)

            # Stack all sites of this well.
            wells.append(da.stack(sites, axis=0))

        # Stack all wells of this plate.
        plates.append(da.stack(wells, axis=0))

    if not plates:
        raise ValueError("No images found: cannot build an image array")

    return index_map, da.stack(plates)

In [None]:
# Scratch check of da.stack: stacking a one-element list should add a leading
# axis of length 1 (expected shape (1, 4, 100, 100)).
tmp = np.random.random((4, 100, 100))
tmp.shape
da.stack([tmp]).shape

In [None]:
# Build the lazy array of projection images and its (plate, well, site) index
index_map_2d, final_dask_array_2d = create_dask_array(grouped2d)

In [None]:
if not df_seg_2d.empty:
    # Build the lazy array of segmentation images.
    index_map_seg, final_dask_array_seg = create_dask_array(grouped_seg_2d)
    print(final_dask_array_seg.shape)
    # Summary of the first site; a bare expression inside an if-block would
    # not be displayed, so print it.
    print(final_dask_array_seg[0, 0, 0, :, :, :])
else:
    print('No segmentations')

In [None]:
print(final_dask_array_2d.shape)
if not df_seg_2d.empty:
    print(final_dask_array_seg.shape)
    # final_dask_array_seg only exists when segmentations were found, so the
    # squeezed shape must be reported inside this guard (NameError otherwise).
    print(da.squeeze(final_dask_array_seg, 3).shape)

In [None]:
# Summary of the first site (plate 0, well 0, site 0) of the 2D array
final_dask_array_2d[0, 0, 0, :, :, :]

In [None]:
import dask.array as da
from aicsimageio import AICSImage

def create_dask_array_with_z(grouped3d):
    """Stack all Z-slice images of a grouped 3D table into one lazy dask array.

    Args:
        grouped3d: table indexed by (PLATE, WELL, SITE) whose PATH, DATE,
            TIMEPOINT, ZSTEP and CHANNEL cells hold equally long lists, as
            produced by ``groupby([PLATE, WELL, SITE]).agg(list)``.

    Returns:
        Tuple ``(index_map, array)``. ``array`` has the axes
        (plate, well, site, z, channel, *image axes); ``index_map`` maps each
        (plate, well, site) key to its position on the first three axes.

    Raises:
        ValueError: if the table contains no groups. Stacking presumably also
            fails when sites differ in their number of Z steps or channels
            (not checked here).
    """
    plates = []      # one stacked array per plate
    index_map = {}   # (plate, well, site) -> (plate idx, well idx, site idx)

    for plate, plate_group in grouped3d.groupby(PLATE):
        wells = []

        for well, well_group in plate_group.groupby(WELL):
            sites = []

            for site, site_group in well_group.groupby(SITE):
                # The lists are still being filled, so their current lengths
                # are the indices this site will get in the final array.
                index_map[(plate, well, site)] = (len(plates), len(wells), len(sites))

                # Explode the list columns together so that every row is one
                # image file with its own Z step and channel.
                exploded_df = site_group.explode([PATH, DATE, TIMEPOINT, ZSTEP, CHANNEL])

                z_steps = []
                # groupby iterates the Z steps in sorted order.
                for zstep, zstep_group in exploded_df.groupby(ZSTEP):
                    channels = []

                    for channel_path in zstep_group[PATH]:
                        print(plate, well, site, zstep, channel_path)
                        img = AICSImage(channel_path)
                        # Lazy loading; squeeze drops the singleton axes.
                        channels.append(img.get_image_dask_data().squeeze())

                    # Stack the channel images of this Z step.
                    z_steps.append(da.stack(channels, axis=0))

                # Stack the Z steps of this site.
                site_stack = da.stack(z_steps, axis=0)
                print(site_stack.shape)
                sites.append(site_stack)

            # Stack all sites of this well.
            wells.append(da.stack(sites, axis=0))

        # Stack all wells of this plate.
        plates.append(da.stack(wells, axis=0))

    if not plates:
        raise ValueError("No Z-step images found: cannot build an image array")

    return index_map, da.stack(plates)

In [None]:
if not df3d.empty:
    # Build the lazy array of Z-slice images and its index.
    index_map_3d, final_dask_array_3d = create_dask_array_with_z(grouped3d)
    print(final_dask_array_3d.shape)
    print(index_map_3d)
    # Summary of the first site; a bare expression inside an if-block would
    # not be displayed, so print it.
    print(final_dask_array_3d[0, 0, 0, :, :, :, :])
else:
    print('No Z steps')

In [None]:
import napari
import dask.array as da
from magicgui import magicgui

# Names available on the plate, well and site axes, in the order they appear
# in the sorted 2D table. Rows that could not be parsed carry NaN here; drop
# them so only real names (strings) reach the navigation widget.
plates = list(df2d[PLATE].dropna().unique())
wells = list(df2d[WELL].dropna().unique())
sites = list(df2d[SITE].dropna().unique())

print(plates)
print(wells)
print(sites)

In [None]:
def _tile_along_z(stack, n_slices):
    """Insert a new axis at position 3 and repeat the data n_slices times along it."""
    with_z_axis = stack[:, :, :, None, :, :, :]
    return da.repeat(with_z_axis, repeats=n_slices, axis=3)


if not df3d.empty:
    volume_shape = final_dask_array_3d.shape
    print(volume_shape)
    n_zsteps = volume_shape[3]
    print(n_zsteps)

    # Repeat the 2D projection on every Z-slice so it lines up with the 3D image.
    expanded_2d_da = _tile_along_z(final_dask_array_2d, n_zsteps)
    print(expanded_2d_da.shape)

    if not df_seg_2d.empty:
        print(final_dask_array_seg.shape)
        # Same for the 2D segmentation.
        expanded_seg_da = _tile_along_z(final_dask_array_seg, n_zsteps)
        print(expanded_seg_da.shape)
        # Drop the channel axis (now at position 4): the image channels become
        # separate layers in the viewer, so the labels must not carry one.
        expanded_seg_da = da.squeeze(expanded_seg_da, 4)
        print(expanded_seg_da.shape)


In [None]:
from qtpy.QtWidgets import QVBoxLayout, QWidget, QLabel, QComboBox

# Initialize Napari viewer
viewer = napari.Viewer()

# Apply the colors chosen in the settings cell. napari consumes one entry per
# channel, so use the list only when it has enough entries; otherwise fall
# back to napari's default colors (colormap=None).
n_channels = len(wavelengths)
channel_colormaps = colormap[:n_channels] if len(colormap) >= n_channels else None

if not df3d.empty:
    # Z-stacks: axes are (plate, well, site, z, channel, ...), so the channel
    # axis has index 4.
    viewer.add_image(
        final_dask_array_3d,
        channel_axis=4,
        name=wavelengths,
        colormap=channel_colormaps,
    )

    # 2D projections repeated along Z so they share the axes of the Z-stacks.
    names_2d = [w + " projection" for w in wavelengths]
    viewer.add_image(
        expanded_2d_da,
        channel_axis=4,
        name=names_2d,
        colormap=channel_colormaps,
    )
else:
    # Projections only: axes are (plate, well, site, channel, ...), so the
    # channel axis has index 3.
    viewer.add_image(
        final_dask_array_2d,
        channel_axis=3,
        name=wavelengths,
        colormap=channel_colormaps,
    )

# Add labels
if not df_seg_2d.empty:
    if not df3d.empty:
        labels_data = expanded_seg_da
    else:
        # expanded_seg_da is only built when Z steps exist; without them use
        # the 2D segmentation directly, dropping its channel axis.
        labels_data = da.squeeze(final_dask_array_seg, 3)
    viewer.add_labels(
        labels_data,
        name='stardist_w1',
    )

# Set axis labels. The last two array axes are image rows (Y) and columns (X).
if not df3d.empty:
    viewer.dims.axis_labels = ['Plate', 'Well', 'Site', 'Z-slice', 'Y', 'X']
    # start from Z-slice 0 to have labels visible
    viewer.dims.set_point(3, 0)
else:
    viewer.dims.axis_labels = ['Plate', 'Well', 'Site', 'Y', 'X']

# start from well 0 to match pull-down
viewer.dims.set_point(1, 0)
# start from site 0
viewer.dims.set_point(2, 0)



# Create a widget for navigation
# Index of the well axis in viewer.dims
IDX_WELL = 1


class NavigationWidget(QWidget):
    """Pull-down menu that selects the well shown in the napari viewer.

    The combo box and the well slider are kept in sync in both directions.
    The widget uses the module-level ``viewer``.

    Args:
        wells: list of well names, ordered like the well axis of the images.
    """

    def __init__(self, wells):
        super().__init__()
        layout = QVBoxLayout()

        self.wells = wells

        # Well selection
        self.well_label = QLabel("Well")
        self.well_combo = QComboBox()
        self.well_combo.addItems(wells)
        self.well_combo.currentTextChanged.connect(self.update_image)

        # Adding widgets to layout
        layout.addWidget(self.well_label)
        layout.addWidget(self.well_combo)

        self.setLayout(layout)

        # Follow the sliders: update the pull-down when the viewer position changes.
        viewer.dims.events.point.connect(self._update_display)

    def _update_display(self):
        """Show the well that the viewer's well slider currently points at."""
        slider_index = round(viewer.dims.point[IDX_WELL])
        # Ignore positions without a matching entry instead of raising
        # IndexError (or silently wrapping around for negative values).
        if 0 <= slider_index < len(self.wells):
            self.well_combo.setCurrentText(self.wells[slider_index])

    def update_image(self):
        """Move the viewer's well slider to the well chosen in the pull-down."""
        well = self.well_combo.currentText()

        if well in self.wells:
            viewer.dims.set_point(IDX_WELL, self.wells.index(well))


# Dock the well selector in the viewer window, then start napari.
navigation_widget = NavigationWidget(wells)
viewer.window.add_dock_widget(navigation_widget)
napari.run()


In [None]:
# Current position of the viewer on every dimension (debugging aid)
viewer.dims.point