## Usage
- Edit the settings in the cell below.
- Cell -> Run All.
- A Napari window will open, where you can scroll through your data.

In [None]:

# Plate folder to display, written relative to the instruments/Nano folder
# on lmu_active1. Only the last uncommented assignment takes effect; keep
# alternatives above it or comment them out.
plate = 'lyu/140324-A53T P62 staining/140324-A53T P62 staining/2024-03-27/20378/TimePoint_1'
plate = 'karkkael/microwell images/Plate1/2024-09-23/1/TimePoint_1/'

# Example lmu_active1 root folders for Linux and Windows.
# NOTE(review): the visible code cells pick the root with get_lmu_active1()
# and do not appear to read an `lmu_active1` variable - confirm before
# relying on these lines.
#lmu_active1 = Path('/mnt/lmu_active1') # Linux
#lmu_active1 = Path(r'L:\lmu_active1') # Windows (raw string keeps the backslash literal)

# Colors to use, one name per channel. As with `plate`, the last
# uncommented assignment wins.
#colormap = ["yellow", "magenta", "cyan"]
colormap = ["blue", "green", "red"]
colormap = ["blue", "green", "red", "magenta"]


## Code
You don't need to make changes in the code cells below. 

In [None]:
from aicsimageio.aics_image import AICSImage
from pathlib import Path
import matplotlib.pyplot as plt
import napari
import numpy as np
import os
import pandas as pd
import platform

def get_lmu_active1():
    """Return the lmu_active1 root folder for the current operating system.

    Returns:
        The root folder as a string: the ``L:`` drive location on Windows,
        the ``/mnt`` mount point on Linux.

    Raises:
        ValueError: If ``platform.system()`` reports any other system.
    """
    roots_by_os = {
        "Windows": "L:\\lmu_active1",
        "Linux": "/mnt/lmu_active1",
    }
    current_os = platform.system()

    # Guard clause: anything outside the table has no known mount point.
    if current_os not in roots_by_os:
        raise ValueError(f"Unsupported operating system: {current_os}")

    return roots_by_os[current_os]
        
# original image folder
# Candidate locations of the plate, highest priority first. The local copies
# keep their original precedence over the network share, but they no longer
# override it unconditionally: a location is only used when it exists.
_plate_locations = [
    Path('/home/hajaalin/data/micro') / Path(plate),
    Path('/home/user/nanodata') / Path(plate),
    Path(get_lmu_active1()) / Path('instruments/Nano') / Path(plate),
]

# Fall back to the documented network-share path when nothing exists, so the
# path shown in later output and errors is the one users are told to edit.
orig = next(
    (location for location in _plate_locations if location.exists()),
    _plate_locations[-1],
)
print(f"Reading images from: {orig}")


In [None]:
# Column names used throughout the notebook.
PATH = 'Path'
DATE = 'Date'
TIMEPOINT = 'TimePoint'
ZSTEP = 'ZStep'
PLATE = 'Plate'
WELL = 'Well'
SITE = 'Site'
CHANNEL = 'Channel'

# Walk the plate folder once and keep every TIFF except thumbnail images.
# (The folder used to be scanned twice, with the first result discarded.)
files = [str(x) for x in orig.glob("**/*.tif") if "thumb" not in x.name]

# Fail here with a readable message instead of an IndexError further down.
if not files:
    raise FileNotFoundError(f"No .tif images found under {orig}")

df = pd.DataFrame(files, columns=[PATH])

print(files[-1])


# Maps the placeholders used in the pattern below to the DataFrame column
# (and regex group) names.
metadata_columns = {
    'mc1': DATE,
    'mc2': TIMEPOINT,
    'mc3': ZSTEP,
    'mc4': PLATE,
    'mc5': WELL,
    'mc6': SITE,
    'mc7': CHANNEL
}

# Cross-platform pattern (accepts / and \ as separators) with dynamic column
# names. Expected layout:
#   .../<date>/<folder>/TimePoint_<n>[/ZStep_<n>]/<plate>_<well>_s<site>_w<channel>...
# The ZStep folder is optional. The site number accepts one or more digits so
# that sites 10 and above are parsed as well.
pattern = (
    r'.*[/\\](?P<{mc1}>\d{{4}}-\d{{2}}-\d{{2}})'
    r'[/\\][^/\\]*'
    r'[/\\]TimePoint_(?P<{mc2}>\d+)'
    r'(?:[/\\]ZStep_(?P<{mc3}>\d+))?'
    r'[/\\](?P<{mc4}>[^_]+)_(?P<{mc5}>\w\d{{2}})_s(?P<{mc6}>\d+)_(?P<{mc7}>w\d)'
).format(**metadata_columns)

# Apply the regex pattern and extract the desired columns
df_extracted = df[PATH].str.extract(pattern)

# Paths that do not follow the naming scheme come back as all-NaN rows;
# report them because they cannot be assigned to a plate, well or site.
unmatched = df_extracted[PLATE].isna()
if unmatched.any():
    print(f"Warning: {int(unmatched.sum())} file path(s) did not match the expected naming pattern")

# Add the extracted columns back to the original dataframe
df = df.join(df_extracted)

# Show the result
df.head()

In [None]:
# Display the last rows of the file table.
df.tail()

In [None]:
# Display the last collected image path.
files[-1]

In [None]:
# Sorted channel labels found in the file names. dropna() skips paths the
# pattern could not parse: a NaN among the strings would make sorted() raise
# a TypeError.
wavelengths = sorted(df[CHANNEL].dropna().unique())
wavelengths

In [None]:
# Names of the metadata columns, as a live view of the dict's values.
metadata_cols = metadata_columns.values()
metadata_cols

In [None]:
# Rows without a Z step are plain 2D images; rows with one belong to Z-stacks.
mask = df[ZSTEP].isna()

# 2D table, ordered so the channels of each site sit next to each other.
df2d = df.loc[mask].reset_index(drop=True)
df2d = df2d.sort_values(by=[PLATE, WELL, SITE, CHANNEL], ignore_index=True)

# 3D table: the Z step is converted to int first so that it sorts
# numerically rather than as text.
df3d = df.loc[~mask].reset_index(drop=True)
df3d[ZSTEP] = df3d[ZSTEP].astype(int)
df3d = df3d.sort_values(by=[PLATE, WELL, SITE, ZSTEP, CHANNEL], ignore_index=True)
df3d.head()

In [None]:
# One row per (plate, well, site); every remaining column is collected into
# a list holding that site's values.
grouped2d = df2d.groupby(by=[PLATE, WELL, SITE]).agg(list)
grouped2d

In [None]:
# One row per (plate, well, site); every remaining column is collected into
# a list holding that site's values (including one Z step per image).
grouped3d = df3d.groupby(by=[PLATE, WELL, SITE]).agg(list)
grouped3d

In [None]:
# Sanity check on one site of the Z-stack table: every list-valued column
# should have the same length (one entry per image).
example_key = ('Plate1', 'C07', '2')

if grouped3d.empty:
    print("No Z-stack images found")
else:
    # The hard-coded example only exists on one plate; fall back to the
    # first available site instead of raising a KeyError.
    if example_key not in grouped3d.index:
        example_key = grouped3d.index[0]

    example_site = grouped3d.loc[example_key]
    for column in (PATH, DATE, TIMEPOINT, ZSTEP, CHANNEL):
        print(len(example_site[column]))
    print(example_site)

    paths = example_site[PATH]
    for p in paths:
        print(p)

In [None]:
import dask.array as da
from aicsimageio import AICSImage

def create_dask_array(grouped2d, verbose=True):
    """Build one lazy Dask array holding the 2D images of all plates.

    Args:
        grouped2d: DataFrame indexed by (PLATE, WELL, SITE) whose PATH, DATE,
            TIMEPOINT and CHANNEL cells hold equally long lists, one entry
            per image of that site.
        verbose: Print progress and debugging information (default True,
            which matches the previous output).

    Returns:
        A tuple ``(index_map, final_dask_array)``. ``index_map`` maps each
        ``(plate, well, site)`` label triple to its integer position along
        the first three axes of the array. The array axes are plate, well,
        site, channel, followed by the axes of a single image.

    Raises:
        ValueError: If ``grouped2d`` has no rows. ``dask.array.stack`` is
            also expected to raise it when the pieces differ in shape, e.g.
            wells with different numbers of sites (confirm against the dask
            documentation).
    """
    if grouped2d.empty:
        raise ValueError(
            "create_dask_array: the input table is empty, there are no 2D images to stack"
        )

    plates = []     # one stacked array per plate
    index_map = {}  # (plate, well, site) labels -> integer position

    for plate, plate_group in grouped2d.groupby(PLATE):
        wells = []

        for well, well_group in plate_group.groupby(WELL):
            sites = []

            for site, site_group in well_group.groupby(SITE):
                # The list lengths are the positions this site will occupy
                # once the lists are stacked.
                index_map[(plate, well, site)] = (len(plates), len(wells), len(sites))

                if verbose:
                    print(site_group.columns)
                    print(site_group.shape)

                # Expand the list-valued columns in lockstep: one row per image.
                exploded_site_group = site_group.explode([PATH, DATE, TIMEPOINT, CHANNEL])
                if verbose:
                    print(exploded_site_group.shape)
                    print(exploded_site_group.apply(type).unique())
                    print(exploded_site_group.head())

                channels = []
                for channel_path in exploded_site_group[PATH]:
                    if verbose:
                        print(plate, well, site, channel_path)
                    # Lazy loading: the pixel data is not read here.
                    dask_data = AICSImage(channel_path).get_image_dask_data()
                    # Drop the length-1 axes (presumably leaves a plain
                    # (y, x) image for single-plane files - TODO confirm).
                    channels.append(dask_data.squeeze())
                if verbose:
                    print()

                # Stack the channel images of this site along a new first axis.
                site_stack = da.stack(channels, axis=0)
                if verbose:
                    print(site_stack.shape)
                sites.append(site_stack)

            # Stack all site-level arrays into a well-level array.
            wells.append(da.stack(sites, axis=0))

        # Stack all well-level arrays into a plate-level array.
        plates.append(da.stack(wells, axis=0))

    final_dask_array = da.stack(plates)
    return index_map, final_dask_array

In [None]:
# A dataset that contains only Z-stacks has no 2D images. Skip the 2D array
# in that case instead of failing, so that "Run All" can continue.
if grouped2d.empty:
    index_map_2d, final_dask_array_2d = {}, None
    print("No 2D images found - skipping the 2D array")
else:
    index_map_2d, final_dask_array_2d = create_dask_array(grouped2d)
    print(final_dask_array_2d.shape)
    print(index_map_2d)

In [None]:
import dask.array as da
from aicsimageio import AICSImage

def create_dask_array_with_z(grouped3d, verbose=True):
    """Build one lazy Dask array holding the Z-stack images of all plates.

    Args:
        grouped3d: DataFrame indexed by (PLATE, WELL, SITE) whose PATH, DATE,
            TIMEPOINT, ZSTEP and CHANNEL cells hold equally long lists, one
            entry per image of that site.
        verbose: Print progress and debugging information (default True,
            which matches the previous output).

    Returns:
        A tuple ``(index_map, final_dask_array)``. ``index_map`` maps each
        ``(plate, well, site)`` label triple to its integer position along
        the first three axes of the array. The array axes are plate, well,
        site, Z step, channel, followed by the axes of a single image.

    Raises:
        ValueError: If ``grouped3d`` has no rows. ``dask.array.stack`` is
            also expected to raise it when the pieces differ in shape, e.g.
            sites with different numbers of Z steps (confirm against the
            dask documentation).
    """
    if grouped3d.empty:
        raise ValueError(
            "create_dask_array_with_z: the input table is empty, there are no Z-stack images to stack"
        )

    plates = []     # one stacked array per plate
    index_map = {}  # (plate, well, site) labels -> integer position

    for plate, plate_group in grouped3d.groupby(PLATE):
        wells = []

        for well, well_group in plate_group.groupby(WELL):
            sites = []

            for site, site_group in well_group.groupby(SITE):
                z_steps = []

                # The list lengths are the positions this site will occupy
                # once the lists are stacked.
                index_map[(plate, well, site)] = (len(plates), len(wells), len(sites))

                if verbose:
                    print(site_group.columns)
                    print(site_group.shape)
                    # Type of the elements in the ZStep column.
                    print(site_group[ZSTEP].apply(type).unique())

                # Expand the list-valued columns in lockstep so that every
                # row pairs one path with its own Z step and channel.
                exploded_df = site_group.explode([PATH, DATE, TIMEPOINT, ZSTEP, CHANNEL])
                if verbose:
                    print(exploded_df.shape)
                    print(exploded_df.apply(type).unique())

                # One group per Z step; its rows are that slice's channel images.
                for zstep, zstep_group in exploded_df.groupby(ZSTEP):
                    channels = []

                    for channel_path in zstep_group[PATH]:
                        if verbose:
                            print(plate, well, site, zstep, channel_path)
                        # Lazy loading: the pixel data is not read here.
                        dask_data = AICSImage(channel_path).get_image_dask_data()
                        # Drop the length-1 axes (presumably leaves a plain
                        # (y, x) image for single-plane files - TODO confirm).
                        channels.append(dask_data.squeeze())

                    if verbose:
                        print()
                    # Stack the channels of this Z step along a new first axis
                    # (requires all channel images to have the same shape).
                    z_steps.append(da.stack(channels, axis=0))

                if verbose:
                    print()
                # Stack the Z steps of this site along a new first axis.
                site_stack = da.stack(z_steps, axis=0)
                if verbose:
                    print(site_stack.shape)
                sites.append(site_stack)

            # Stack all site-level arrays into a well-level array.
            wells.append(da.stack(sites, axis=0))

        # Stack all well-level arrays into a plate-level array.
        plates.append(da.stack(wells, axis=0))

    final_dask_array = da.stack(plates)
    return index_map, final_dask_array

In [None]:
# Build the lazy Z-stack array together with its
# (plate, well, site) -> position lookup, then show the shape and the lookup.
index_map_3d, final_dask_array_3d = create_dask_array_with_z(grouped3d)
print(final_dask_array_3d.shape)
print(index_map_3d)

In [None]:
# First plate, well and site of the Z-stack array. `final_dask_array` was
# never defined; the seven-axis index matches the array built by
# create_dask_array_with_z (plate, well, site, z, channel and - assuming
# each image squeezes to 2D - y, x).
final_dask_array_3d[0, 0, 0, :, :, :, :]

In [None]:
import napari
import dask.array as da
from magicgui import magicgui

# Choices offered for navigation: the distinct plate, well and site labels of
# the Z-stack table, in order of first appearance.
plates = df3d[PLATE].unique().tolist()
wells = df3d[WELL].unique().tolist()
sites = df3d[SITE].unique().tolist()

# One list per line.
print(plates, wells, sites, sep="\n")

In [None]:
# Initialize Napari viewer
viewer = napari.Viewer()

# The viewer shows the Z-stack dataset. These two names were used below but
# never defined; the navigation choices come from df3d and the channel axis
# used below matches the array built by create_dask_array_with_z.
final_dask_array = final_dask_array_3d
index_map = index_map_3d

# After indexing with a (plate, well, site) position the remaining axes start
# with (z, channel, ...), so the channels sit on axis 1.
CHANNEL_AXIS = 1


def update_image(data_slice):
    """Replace the displayed layers with the images at *data_slice*.

    Args:
        data_slice: Integer (plate, well, site) position, as stored in
            ``index_map``.
    """
    image = final_dask_array[data_slice]
    n_channels = image.shape[CHANNEL_AXIS]

    # Carry the contrast settings over from the layers currently shown, but
    # only when there is exactly one per channel; on the first call (no
    # layers yet) and on any mismatch napari picks the limits itself.
    contrast_limits = [
        layer.contrast_limits
        for layer in viewer.layers
        if hasattr(layer, "contrast_limits")
    ]
    if len(contrast_limits) != n_channels:
        contrast_limits = None

    # Apply the colors from the settings cell (previously ignored). With too
    # few colors for the channels, fall back to napari's defaults.
    channel_colors = colormap[:n_channels] if len(colormap) >= n_channels else None

    # Remove the previous image layers before adding the new ones.
    viewer.layers.clear()
    viewer.add_image(
        image,
        channel_axis=CHANNEL_AXIS,  # split into one layer per channel
        colormap=channel_colors,
        contrast_limits=contrast_limits,
        name=wavelengths,
    )


# Show the first plate / well / site right away.
update_image(index_map[plates[0], wells[0], sites[0]])


# Pull-downs for plate, well and site; any change redraws immediately.
@magicgui(plate={"choices": plates}, well={"choices": wells}, site={"choices": sites, "label": "Site"}, call_button=False, auto_call=True)
def navigation_widget(plate: str, well: str, site: str):
    """Show the images of the selected plate, well and site, if present."""
    data_slice = index_map.get((plate, well, site))
    if data_slice is None:
        # Not every combination of the three pull-downs has data.
        print(f"No data found for (plate: {plate}, well: {well}, site: {site})")
        return
    update_image(data_slice)


# Add widget to Napari viewer
viewer.window.add_dock_widget(navigation_widget)
napari.run()


In [None]:
# Inspect the layers currently shown in the viewer. `image_layer` was never
# defined; the layers are read back from the viewer instead.
image_layer = list(viewer.layers)
print(type(image_layer))
if image_layer:
    print(type(image_layer[0]))
    print(image_layer[0].contrast_limits)
    print(image_layer[0].name)
contrast_limits = [layer.contrast_limits for layer in image_layer]
print(contrast_limits)