In [43]:
# Automatically reload edited local modules (e.g. manage_landmarks) before running each cell.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
import os
import napari
import numpy as np
import pandas as pd
from qtpy.QtWidgets import (
    QPushButton, QVBoxLayout, QWidget,
    QTableWidget, QTableWidgetItem, QHeaderView
)


# Enable Qt event loop in Jupyter if needed, e.g.:
# %gui qt

# Dummy image stacks (replace with your actual 3D data)
# em_stack = np.random.random((100, 100, 50))  # Example EM stack
# lm_stack = np.random.random((100, 100, 50))  # Example LM stack
import tifffile
# Both directories are relative paths, resolved against the notebook's working directory.
processed_data_dir = 'processed_data'  # root for input stacks, the affine matrix and the raw landmark CSV
landmark_dir = 'landmarks'  # holds landmarks.csv, which is read at start-up and written on save

In [45]:
# Paths to the EM/LM input stacks (the stacks themselves are read in a later cell).
em_file_name = 'dp1_em_downsampled_stack_transformed'
# Pass path components separately so os.path.join uses the platform separator.
em_stack_file = os.path.join(processed_data_dir, 'transformed_stacks', f'{em_file_name}.tif')
lm_stack_file = os.path.join(processed_data_dir, 'lm_em_stacks', 'dp1_lm_anatomy_trial11_CLAHE.tif')
# Scale factors of the LM stack, ordered (z, y, x). Used as the LM image-layer
# scale and to scale raw LM landmark coordinates.
lm_scale = (100, 4, 4)


In [46]:
# Landmarks: either rebuild them from the raw annotation CSV (load_raw_landmarks=True),
# or reload the curated landmarks.csv (the same file the save action writes).
load_raw_landmarks = False
if load_raw_landmarks:
    from compute_affine3d_from_landmarks import transform_points, invert_affine_matrix
    raw_landmark_file = os.path.join(processed_data_dir, landmark_dir, 'landmarks_Nila_confirmed.csv')
    affine3d_file = os.path.join(processed_data_dir, 'transformed_stacks', f'{em_file_name}_affine3d_mat.txt')
    affine3d_mat = np.loadtxt(affine3d_file)

    # NOTE(review): not used in this cell; kept in case interactive/later code relies on it.
    inv_affine3d_mat = invert_affine_matrix(affine3d_mat)

    raw_landmark_df = pd.read_csv(raw_landmark_file)

    # EM coordinates in (z, y, x) order, mapped with the saved 3D affine.
    # NOTE(review): presumably into the space of the transformed EM stack (em_stack_file) — confirm.
    points = raw_landmark_df[['z_em_downsampled', 'y_em_downsampled', 'x_em_downsampled']].values
    transformed_points = transform_points(points, affine3d_mat)

    landmark_df = raw_landmark_df[['landmark_name']].copy()
    # transformed_points columns follow the (z, y, x) order of `points`.
    landmark_df['x_EM'] = transformed_points[:, 2]
    landmark_df['y_EM'] = transformed_points[:, 1]
    landmark_df['z_EM'] = transformed_points[:, 0]
    # Scale raw LM coordinates by lm_scale (z, y, x) to match the scaled LM image layer.
    landmark_df['x_LM'] = raw_landmark_df['x_lm'] * lm_scale[2]
    landmark_df['y_LM'] = raw_landmark_df['y_lm'] * lm_scale[1]
    landmark_df['z_LM'] = raw_landmark_df['z_lm'] * lm_scale[0]
    landmark_df['tag'] = ''
else:
    landmark_df = pd.read_csv(os.path.join(landmark_dir, 'landmarks.csv'), header=0)

In [47]:
# Read both TIFF stacks fully into memory.
em_stack, lm_stack = tifffile.imread(em_stack_file), tifffile.imread(lm_stack_file)


In [48]:
# -- Napari viewer: EM and LM image stacks plus one points layer per modality --
viewer = napari.Viewer()

em_layer = viewer.add_image(em_stack, name='EM Stack')
# The LM stack is displayed with lm_scale (z, y, x) applied.
lm_layer = viewer.add_image(lm_stack, name='LM Stack', scale=lm_scale)

# Landmark points: index i of the EM layer is paired with index i of the LM layer.
em_points_layer = viewer.add_points(name='EM Landmarks', ndim=3, size=5, face_color='red')
lm_points_layer = viewer.add_points(name='LM Landmarks', ndim=3, size=5, face_color='blue')

# Face colour used for points added to the LM layer from now on.
lm_points_layer.current_face_color = 'yellow'

# Turn off direct editing on both points layers.
# NOTE(review): presumably so points change only through the keybinding workflows — confirm.
em_points_layer.editable = False
lm_points_layer.editable = False


In [49]:

# Ctrl + mouse wheel steps through z (dims axis 0) instead of zooming.
# Wheel events without Control are left untouched.
def scroll_through_z_stack(viewer, event):
    """Step one z-slice per Ctrl + mouse-wheel notch.

    A positive vertical wheel delta moves to the previous slice and a negative
    delta to the next slice. The new index is clamped to ``[0, nsteps - 1]``.

    Parameters
    ----------
    viewer : napari.Viewer
        Viewer whose ``dims.current_step`` is changed.
    event : napari mouse event
        Wheel event. ``event.delta`` is ``(dx, dy)`` and ``event.modifiers``
        holds the pressed modifier keys.
    """
    if 'Control' not in event.modifiers:
        return

    _, dy = event.delta
    current_step = tuple(viewer.dims.current_step)
    current_z = current_step[0]
    if dy > 0:
        viewer.dims.current_step = (max(current_z - 1, 0),) + current_step[1:]
    elif dy < 0:
        # nsteps counts slider positions, i.e. step indices like current_step.
        # dims.range is in world units (it is affected by lm_scale), and recent
        # napari versions treat its upper bound as inclusive.
        last_z = viewer.dims.nsteps[0] - 1
        viewer.dims.current_step = (min(current_z + 1, last_z),) + current_step[1:]

    # Mark the event as handled so napari does not continue with zoom.
    # NOTE(review): confirm this suppresses the camera zoom in the installed napari.
    event.handled = True


# Register with an explicit call. Using list.append as a decorator would rebind
# scroll_through_z_stack to None, because append returns None.
viewer.mouse_wheel_callbacks.append(scroll_through_z_stack)

In [50]:

# LandmarkManager (local module) holds the landmark pairs, exposed as `.df`.
# LandmarkTableWidget shows them in a table docked on the right of the napari window.
from manage_landmarks import LandmarkManager, LandmarkTableWidget
landmark_manager = LandmarkManager()
# Instantiate the table widget
table_widget = LandmarkTableWidget(landmark_manager)
viewer.window.add_dock_widget(table_widget, area='right')

<napari._qt.widgets.qt_viewer_dock_widget.QtViewerDockWidget at 0x72586e5434a0>

In [51]:
# --- Save landmarks to CSV (Ctrl-S or the "Save Landmarks" button) ---
@viewer.bind_key("Control-S", overwrite=True)
def save_landmarks(viewer=None):
    """Write the current landmark table to ``<landmark_dir>/landmarks.csv``.

    Parameters
    ----------
    viewer : optional
        Ignored. napari calls key-binding callbacks with the viewer as the first
        argument, and Qt's ``clicked`` signal may pass its ``checked`` flag, so
        the function must accept (and ignore) one positional argument.
    """
    # landmark_dir may not exist yet, e.g. when landmarks were rebuilt from the raw CSV.
    os.makedirs(landmark_dir, exist_ok=True)
    out_path = os.path.join(landmark_dir, "landmarks.csv")
    landmark_manager.df.to_csv(out_path, index=False)
    print(f"Landmarks saved to {out_path}")

class ControlWidget(QWidget):
    """Small dock panel with a "Save Landmarks" button wired to ``save_landmarks``."""

    def __init__(self):
        super().__init__()
        self.setWindowTitle("Controls")
        # Keep the layout in a local variable. Assigning ``self.layout`` would shadow
        # QWidget.layout(); after setLayout() the widget owns the layout, and it
        # stays reachable as ``self.layout()``.
        layout = QVBoxLayout()
        self.setLayout(layout)

        self.save_button = QPushButton("Save Landmarks")
        self.save_button.clicked.connect(save_landmarks)
        layout.addWidget(self.save_button)

# Dock the control panel on the right of the napari window.
control_widget = ControlWidget()
viewer.window.add_dock_widget(control_widget, area='right')

# NOTE(review): when the Qt event loop is already integrated with the kernel (e.g. `%gui qt`),
# napari.run() presumably returns without blocking — confirm for the installed napari version.
napari.run()


# --- Keybindings for quickly switching between the LM and EM stacks ---
def _show_only_stack(viewer, shown, hidden):
    """Make image layer ``shown`` visible, then hide image layer ``hidden``."""
    viewer.layers[shown].visible = True
    viewer.layers[hidden].visible = False


@viewer.bind_key('u')
def display_lm_stack(viewer):
    """Press 'u' to show only the LM stack."""
    _show_only_stack(viewer, 'LM Stack', 'EM Stack')


@viewer.bind_key('i')
def display_em_stack(viewer):
    """Press 'i' to show only the EM stack."""
    _show_only_stack(viewer, 'EM Stack', 'LM Stack')



In [52]:
def load_landmarks_from_dataframe(df, point_size=20):
    """Load landmark pairs into the manager, the table and both points layers.

    Parameters
    ----------
    df : pandas.DataFrame
        Landmark table passed unchanged to ``landmark_manager.load_from_dataframe``.
    point_size : float, default 20
        Marker size applied to every point of both points layers.
    """
    landmark_manager.load_from_dataframe(df)
    table_widget.update_table()

    # Update points layers directly from the landmark manager (coordinates in z, y, x order).
    em_points_layer.data = landmark_manager.get_em_points_zyx()
    lm_points_layer.data = landmark_manager.get_lm_points_zyx()

    for points_layer in (em_points_layer, lm_points_layer):
        points_layer.size = point_size
    for points_layer in (em_points_layer, lm_points_layer):
        points_layer.refresh()

In [53]:
# Populate the manager, the table and both points layers from the landmarks loaded above.
load_landmarks_from_dataframe(landmark_df)

In [54]:
# NOTE(review): presumably returns immediately when the kernel already runs the Qt event loop — confirm.
napari.run()

In [55]:
def on_lm_selection_pair_em(event):
    """Mirror a single LM point selection onto the EM point with the same index.

    LM and EM points are paired by index. Any other LM selection clears the EM
    selection: nothing selected, several points, or an index with no EM
    counterpart yet.
    """
    selected_indices = lm_points_layer.selected_data
    if len(selected_indices) == 1:
        index = next(iter(selected_indices))
        # The EM layer can be shorter than the LM layer, e.g. when an LM point's
        # EM partner has not been placed yet. Never select a missing EM point.
        if index < len(em_points_layer.data):
            em_points_layer.selected_data = {index}
            return
    em_points_layer.selected_data.clear()

def on_lm_selection_update_table(event):
    """Mirror the LM point selection onto the landmark table (points -> table sync).

    A single selected LM point selects the table row with the same index. Any
    other selection clears the table selection. ``on_table_selection`` is
    disconnected meanwhile so the table change does not echo back.
    """
    selected_indices = lm_points_layer.selected_data
    table_widget.table.itemSelectionChanged.disconnect(on_table_selection)
    try:
        if len(selected_indices) == 1:
            index = next(iter(selected_indices))
            table_widget.table.selectRow(index)
        else:
            table_widget.table.clearSelection()
    finally:
        # Always reconnect; otherwise an exception above would silently break
        # table -> points sync for the rest of the session.
        table_widget.table.itemSelectionChanged.connect(on_table_selection)

# LM selection drives both the EM selection (same index) and the selected table row.
lm_points_layer.selected_data.events.items_changed.connect(on_lm_selection_pair_em)
lm_points_layer.selected_data.events.items_changed.connect(on_lm_selection_update_table)

# Callback for table selection (table → points sync)
def on_table_selection():
    """Select the LM point matching the first selected table row, or clear the LM selection.

    ``on_lm_selection_update_table`` is disconnected meanwhile so the change
    does not echo back into the table. It is reconnected even if the update fails.
    """
    rows = table_widget.table.selectionModel().selectedRows()
    # Temporarily disconnect to prevent recursion
    lm_points_layer.selected_data.events.items_changed.disconnect(on_lm_selection_update_table)
    try:
        if rows:
            lm_points_layer.selected_data = {rows[0].row()}
        else:
            lm_points_layer.selected_data.clear()
    finally:
        lm_points_layer.selected_data.events.items_changed.connect(on_lm_selection_update_table)

# Table selection drives the LM points selection (table → points sync).
# NOTE(review): Qt allows duplicate connections, so re-running this cell connects the handler again.
table_widget.table.itemSelectionChanged.connect(on_table_selection)


<PyQt5.QtCore.QMetaObject.Connection at 0x72586e35d1d0>

In [56]:
# --- Keybindings for jumping to selected landmarks in LM or EM ---
def _show_stack_exclusively(viewer, shown_stack, hidden_stack):
    """Make image layer ``shown_stack`` visible, then hide ``hidden_stack``."""
    viewer.layers[shown_stack].visible = True
    viewer.layers[hidden_stack].visible = False


def _center_on_selected_point(viewer, points_layer):
    """Centre the camera and z-slider on the first selected point of ``points_layer``.

    Does nothing when ``points_layer`` has no selection.
    """
    # Read the selection now, so changes made just before this call are seen.
    selected_indices = points_layer.selected_data
    if not selected_indices:
        return
    idx = next(iter(selected_indices))
    point = points_layer.data[idx]
    viewer.camera.center = point
    # Move dims axis 0 (z) to the point's plane.
    viewer.dims.set_point(0, point[0])
    viewer.layers.selection.active = points_layer


# overwrite=True is a bind_key option (it has no effect as a callback parameter);
# it lets this cell be re-run without napari rejecting the existing binding.
@viewer.bind_key('w', overwrite=True)
def jump_to_selected_lm(viewer, overwrite=True):
    """Press 'w' to jump to the selected LM landmark.

    ``overwrite`` is unused and kept only so the signature stays compatible.
    """
    _show_stack_exclusively(viewer, 'LM Stack', 'EM Stack')
    _center_on_selected_point(viewer, lm_points_layer)


@viewer.bind_key('e', overwrite=True)
def jump_to_selected_em(viewer, overwrite=True):
    """Press 'e' to jump to the selected EM landmark.

    If no EM point is selected, the EM point paired with the selected LM point
    is selected first. ``overwrite`` is unused and kept for compatibility.
    """
    _show_stack_exclusively(viewer, 'EM Stack', 'LM Stack')
    if len(em_points_layer.selected_data) == 0:
        on_lm_selection_pair_em(None)
    _center_on_selected_point(viewer, em_points_layer)


In [57]:
from qtpy.QtCore import Qt

def tag_selected_landmark(tag):
    """Set ``tag`` on the landmark in the first selected table row and keep that row selected.

    Prints a notice and changes nothing when no row is selected.
    """
    selected = table_widget.table.selectionModel().selectedRows()
    if not selected:
        print("No landmark selected to tag.")
        return
    row = selected[0].row()
    landmark_manager.set_tag(row, tag)
    table_widget.update_table()
    # Re-select the row after the table refresh so the selection stays on it.
    table_widget.table.selectRow(row)

# Tagging keybindings. overwrite=True replaces any existing napari binding for
# these keys (e.g. when this cell is re-run).
overwrite = True


@viewer.bind_key('c', overwrite=overwrite)
def tag_confirmed(viewer):
    """Press 'c' to tag the selected landmark as 'confirmed'."""
    tag_selected_landmark('confirmed')


@viewer.bind_key('x', overwrite=overwrite)
def tag_ambiguous(viewer):
    """Press 'x' to tag the selected landmark as 'questioned'."""
    tag_selected_landmark('questioned')


@viewer.bind_key('d', overwrite=overwrite)
def tag_delete(viewer):
    """Press 'd' to tag the selected landmark as 'delete'."""
    tag_selected_landmark('delete')


In [58]:
# Marker size for points added from now on.
point_size = 20

# Also render points that lie outside the current z plane.
lm_points_layer.out_of_slice_display = True
em_points_layer.out_of_slice_display = True

em_points_layer.current_size = point_size
lm_points_layer.current_size = point_size



In [59]:
from qtpy.QtWidgets import QShortcut
from qtpy.QtGui import QKeySequence

def select_next_landmark():
    """Move the table selection one row down, wrapping from the last row to the first.

    Selects row 0 when nothing is selected. Does nothing for an empty table.
    """
    table = table_widget.table
    selected = table.selectionModel().selectedRows()
    total = table.rowCount()
    if total == 0:
        return
    target = (selected[0].row() + 1) % total if selected else 0
    table.selectRow(target)

# Napari key binding for the 'Down' arrow: select the next landmark row.
@viewer.bind_key('Down')
def goto_next_landmark(event=None):
    """Select the next landmark row.

    The single argument (the viewer, when called from napari or a Qt shortcut) is ignored.
    """
    select_next_landmark()


def select_previous_landmark():
    """Move the table selection one row up, wrapping from the first row to the last.

    Selects the last row when nothing is selected. Does nothing for an empty table.
    """
    table = table_widget.table
    selected = table.selectionModel().selectedRows()
    total = table.rowCount()
    if not total:
        return
    target = total - 1 if not selected else (selected[0].row() - 1) % total
    table.selectRow(target)

def goto_previous_landmark(viewer=None):
    """Select the previous landmark row, wrapping around at the top.

    ``viewer`` is accepted for use as a key-binding callback but is not used yet.
    Add viewer-specific behaviour here if needed.
    """
    select_previous_landmark()  # viewer intentionally unused

# Napari key binding for the 'Up' arrow: select the previous landmark row.
@viewer.bind_key('Up')
def _goto_previous_landmark_napari(viewer):
    """napari 'Up' handler that delegates to goto_previous_landmark."""
    goto_previous_landmark(viewer=viewer)

In [60]:
@viewer.bind_key('t')
def toggle_lm_points_visibility(viewer):
    """Press 't' to show/hide the LM landmark points."""
    currently_visible = lm_points_layer.visible
    lm_points_layer.visible = not currently_visible


@viewer.bind_key('y')
def toggle_em_points_visibility(viewer):
    """Press 'y' to show/hide the EM landmark points."""
    currently_visible = em_points_layer.visible
    em_points_layer.visible = not currently_visible

In [61]:
# Global, mutable UI state shared by the editing keybindings and data-event handlers:
# EM-landmark replacement mode and new-pair adding mode.
class ViewerStatus:
    """Flags and counters for the interactive landmark-editing modes."""

    def __init__(self):
        # Table row whose EM point is being replaced; None when not replacing.
        self.replace_landmark_idx = None
        # True while a new LM/EM pair is being added.
        self.adding_pair = False
        # Point counts recorded when pair-adding starts (LM first, then EM).
        self.initial_lm_count = self.initial_em_count = 0


viewer_status = ViewerStatus()

# Step 1: Press 'r' to initiate EM landmark replacement
@viewer.bind_key('r', overwrite=True)
def initiate_em_landmark_replacement(viewer):
    """Enter EM-replacement mode for the landmark selected in the table.

    Shows only the EM stack and puts the EM points layer into 'add' mode. The next
    EM point placed is handled by ``replace_em_landmark``. Refuses to start while
    a new LM/EM pair is being added, because that flow also waits for a new EM
    point and the two would compete for it.
    """
    if viewer_status.adding_pair:
        print("Finish or cancel (Esc) adding the current pair before replacing a landmark.")
        return

    selected_rows = table_widget.table.selectionModel().selectedRows()
    if not selected_rows:
        print("No landmark selected in the table.")
        return

    viewer_status.replace_landmark_idx = selected_rows[0].row()

    # Switch viewer to EM stack only
    viewer.layers['EM Stack'].visible = True
    viewer.layers['LM Stack'].visible = False

    # Activate EM landmarks layer in add mode
    em_points_layer.mode = 'add'
    viewer.layers.selection.active = em_points_layer

    print(f"Place a new EM landmark to replace landmark '{landmark_manager.get_landmark_name(viewer_status.replace_landmark_idx)}'. Press Esc to cancel.")


# Step 2: handle the replacement once the user has placed the new EM point.
def replace_em_landmark(event):
    """Move a newly placed EM point into the slot of the landmark being replaced.

    Intended to stay connected to ``em_points_layer.events.data``. It acts only
    when replacement mode is active (``viewer_status.replace_landmark_idx`` is set)
    and the EM layer holds more points than the landmark table. In that case:
    - the last (just added) point's coordinates overwrite row ``replace_landmark_idx``;
    - the extra point is removed;
    - the landmark manager and the table are updated;
    - replacement mode ends.
    """
    if viewer_status.replace_landmark_idx is None:
        print("No landmark selected for replacement.")
        return  # Ignore if not in replace mode

    idx = viewer_status.replace_landmark_idx

    if len(em_points_layer.data) <= len(landmark_manager.df):
        return  # Not an addition event, ignore

    # Copy so the stored coordinates do not alias the layer's data array.
    new_em_coords = np.array(em_points_layer.data[-1], copy=True)

    # Replace coordinates at the specific index (preserving landmark indices)
    em_data = em_points_layer.data
    em_data[idx] = new_em_coords

    # Assigning .data below re-emits the data event. Detach this handler to avoid
    # recursion, and always reattach it, even if an update fails, so replacement
    # keeps working afterwards.
    em_points_layer.events.data.disconnect(replace_em_landmark)
    try:
        em_points_layer.data = em_data[:-1]  # Remove the extra appended landmark
        # Select the modified landmark in em_points_layer
        em_points_layer.selected_data = {idx}

        # Update landmark manager DataFrame
        landmark_manager.update_landmark_coords(idx, em_coords=new_em_coords)

        # Refresh layers and UI
        table_widget.update_table()
        table_widget.table.selectRow(idx)

        print(f"Replaced EM coordinates for landmark '{landmark_manager.get_landmark_name(idx)}' at index {idx}.")

        # Reset EM points layer mode and state
        em_points_layer.mode = 'pan_zoom'
        viewer_status.replace_landmark_idx = None
    finally:
        em_points_layer.events.data.connect(replace_em_landmark)
    
# Connect the replacement handler once. It stays connected and returns early (after
# printing a notice) unless viewer_status.replace_landmark_idx is set.
em_points_layer.events.data.connect(replace_em_landmark)


<function __main__.replace_em_landmark(event)>

In [62]:
def on_lm_added(event):
    """Once the LM point of a new pair is placed, switch to adding the EM point.

    Connected to ``lm_points_layer.events.data`` by ``start_adding_pair``. It acts
    only when the LM layer actually holds more points than when adding started,
    so a data event emitted before the new point is stored does not switch layers
    prematurely. NOTE(review): presumably napari can emit such pre-change events —
    the EM side defers its work for the same reason.
    """
    if not viewer_status.adding_pair:
        return
    if len(lm_points_layer.data) <= viewer_status.initial_lm_count:
        return  # the new LM point is not in the layer yet
    # Switch to EM layer for adding
    viewer.layers.selection.active = em_points_layer
    em_points_layer.mode = 'add'
    em_points_layer.events.data.connect(on_em_added)
    print("LM point added. Now add EM point.")

from qtpy.QtCore import QTimer

def on_em_added(event):
    """Record the new LM/EM pair once its EM point has been placed.

    Connected to ``em_points_layer.events.data`` by ``on_lm_added``. The work is
    deferred with ``QTimer.singleShot(0, ...)`` so it runs after the layer data
    has been updated. Several data events can schedule a finalizer; only the
    first one that runs while adding mode is active records the pair.
    """
    if not viewer_status.adding_pair:
        return

    def finalize_em_point():
        # An earlier finalizer (or Esc) may already have ended adding mode.
        if not viewer_status.adding_pair:
            return

        new_em_count = len(em_points_layer.data)
        new_lm_count = len(lm_points_layer.data)

        if new_em_count <= viewer_status.initial_em_count or new_lm_count <= viewer_status.initial_lm_count:
            print("New EM or LM point not found yet.")
            return

        if len(landmark_manager.df) > viewer_status.initial_em_count:
            print("Landmark manager already has more points than initial EM count.")
            return

        # The first point added after adding started, in each layer.
        new_em_point = em_points_layer.data[viewer_status.initial_em_count]
        new_lm_point = lm_points_layer.data[viewer_status.initial_lm_count]

        landmark_manager.add_pair(new_em_point, new_lm_point)
        table_widget.update_table()

        viewer_status.adding_pair = False
        em_points_layer.mode = 'pan_zoom'
        lm_points_layer.mode = 'pan_zoom'

        print("Added new pair")

        # Reconnect selection sync (suspended by start_adding_pair)
        lm_points_layer.selected_data.events.items_changed.connect(on_lm_selection_pair_em)
        lm_points_layer.selected_data.events.items_changed.connect(on_lm_selection_update_table)

        # Disconnect temp listeners
        lm_points_layer.events.data.disconnect(on_lm_added)
        em_points_layer.events.data.disconnect(on_em_added)

    # Delay execution until the event loop has updated .data
    QTimer.singleShot(0, finalize_em_point)



@viewer.bind_key('Escape', overwrite=True)
def cancel_operations(viewer):
    """Press Esc to abort pending pair-adding and/or EM-replacement operations.

    Adding a pair:
    - detach the temporary data listeners;
    - drop any points added since the operation started;
    - return both points layers to pan/zoom;
    - restore LM selection sync.

    Replacing an EM landmark: leave replacement mode (the 'r' prompt tells the
    user to press Esc to cancel) and return the EM points layer to pan/zoom.
    """
    if viewer_status.adding_pair:
        print("Adding pair canceled.")
        lm_points_layer.events.data.disconnect(on_lm_added)
        try:
            # on_em_added is connected only after the LM point has been placed.
            em_points_layer.events.data.disconnect(on_em_added)
        except Exception:
            pass  # best effort: it may never have been connected
        # Remove any new points by slicing to original counts
        lm_points_layer.data = lm_points_layer.data[:viewer_status.initial_lm_count]
        em_points_layer.data = em_points_layer.data[:viewer_status.initial_em_count]
        viewer_status.adding_pair = False
        em_points_layer.mode = 'pan_zoom'
        lm_points_layer.mode = 'pan_zoom'
        lm_points_layer.selected_data.events.items_changed.connect(on_lm_selection_pair_em)
        lm_points_layer.selected_data.events.items_changed.connect(on_lm_selection_update_table)

    if viewer_status.replace_landmark_idx is not None:
        print("Landmark replacement canceled.")
        viewer_status.replace_landmark_idx = None
        em_points_layer.mode = 'pan_zoom'



@viewer.bind_key('Space', overwrite=True)
def start_adding_pair(viewer):
    """Press Space to add a new landmark pair: LM point first, then EM point.

    The function:
    - records the current point counts (used to locate the new points and to
      roll back on Esc);
    - puts the LM points layer into 'add' mode;
    - suspends LM selection sync while adding;
    - listens for the new LM point via ``on_lm_added``.

    Any pending EM replacement is cancelled first, because its handler would
    otherwise consume the pair's EM point.
    """
    if viewer_status.adding_pair:
        print("Already in adding mode. Press Esc to cancel.")
        return

    if viewer_status.replace_landmark_idx is not None:
        print("Pending EM landmark replacement canceled.")
        viewer_status.replace_landmark_idx = None
        em_points_layer.mode = 'pan_zoom'

    viewer_status.adding_pair = True
    viewer_status.initial_lm_count = len(lm_points_layer.data)
    viewer_status.initial_em_count = len(em_points_layer.data)
    print('Initial counts:', viewer_status.initial_lm_count, viewer_status.initial_em_count)
    # Activate LM layer and set to add mode
    viewer.layers.selection.active = lm_points_layer
    lm_points_layer.mode = 'add'
    print("Adding new pair. Add LM point first, then EM point. Press Esc to cancel.")
    lm_points_layer.selected_data.events.items_changed.disconnect(on_lm_selection_pair_em)
    lm_points_layer.selected_data.events.items_changed.disconnect(on_lm_selection_update_table)
    # Connect the event handlers
    lm_points_layer.events.data.connect(on_lm_added)




In [63]:

# Qt Shortcut so that key bindings works even when the table has focus
# shortcut_n = QShortcut(QKeySequence('n'), table_widget)
# shortcut_n.activated.connect(select_next_landmark)

from qtpy.QtWidgets import QShortcut
from qtpy.QtGui import QKeySequence
from functools import partial

# Key -> callback for Qt shortcuts on the table widget, so the key bindings also
# work while the table has focus. Each callback is invoked with the viewer.
shortcuts = [
    ('c', tag_confirmed),
    ('x', tag_ambiguous),
    ('d', tag_delete),
    ('w', jump_to_selected_lm),
    ('e', jump_to_selected_em),
    ('Down', goto_next_landmark),
    ('Up', goto_previous_landmark),
    ('t', toggle_lm_points_visibility),
    ('y', toggle_em_points_visibility),
    ('r', initiate_em_landmark_replacement),
    ('Escape', cancel_operations),
    ('u', display_lm_stack),
    ('i', display_em_stack),
    ('Space', start_adding_pair),
]

# Re-running this cell must not stack a second QShortcut per key on the table.
# Qt treats several enabled shortcuts with the same sequence as ambiguous and then
# emits activatedAmbiguously instead of activated, so none of them would fire.
for old_shortcut in getattr(table_widget, 'shortcuts', []):
    old_shortcut.setEnabled(False)
    old_shortcut.deleteLater()

# Keep references on the widget so the shortcuts don't get garbage-collected.
table_widget.shortcuts = []

for key, func in shortcuts:
    shortcut = QShortcut(QKeySequence(key), table_widget)
    # functools.partial passes the viewer as the callback's first argument.
    shortcut.activated.connect(partial(func, viewer))
    table_widget.shortcuts.append(shortcut)
