In [None]:
from utils import load_video, ensure_thw
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib import patheffects as pe
from itertools import cycle

In [None]:
# Input movie. The commented-out paths are alternative inputs (presumably
# outputs of earlier processing runs) — swap one in to analyse it instead.
#example_path = "../../data/output/ants/method_sweep/strong/v7/run_7f6a7e57/artifacts/3czi.mp4"
example_path = "../../data/input/strong_movement/3czi.tif"
#example_path = "../../data/output/cotracker/template_index/strong/v7/run_057af8d9/artifacts/3czi.mp4"
# NOTE(review): frames are requested in "CTHW" order, but later cells treat
# `frames` as (T, H, W) (e.g. frames[0] as a single image, and
# `T, H, W = frames.shape` in plot_rois_on_image). Confirm what load_video
# returns for this file; `ensure_thw` is imported but never used.
# The second return value of load_video is discarded.
frames, _ = load_video(example_path, order="CTHW")

In [None]:
# Materialise as a NumPy array and check its shape.
# NOTE(review): also worth checking frames.dtype — integer TIFF data
# (uint8/uint16) can overflow in integer arithmetic downstream.
frames = np.array(frames)
print(frames.shape)

In [None]:
# Intensity range and mean of the whole movie (quick sanity checks).
np.max(frames)

In [None]:
# NOTE(review): displays the whole array; frames.shape / frames.dtype is
# usually more informative.
frames

In [None]:
np.min(frames)

In [None]:
np.mean(frames)

In [None]:
# First slice along axis 0 — assumed to be the first timepoint (TODO confirm axis order).
plt.imshow(frames[0])

In [None]:
def get_neighbor_coords(image: np.ndarray, x: int, y: int) -> list[tuple[int, int]]:
    """Return the in-bounds 8-connected neighbours of pixel (x, y).

    :param image: array whose first two axes are (rows, columns); only its
        shape is used
    :param x: column index of the centre pixel
    :param y: row index of the centre pixel
    :return: list of (x, y) tuples, ordered by x-offset then y-offset
        (each in -1, 0, +1), excluding the centre pixel itself
    """
    height, width = image.shape[0], image.shape[1]
    offsets = (-1, 0, 1)
    return [
        (x + dx, y + dy)
        for dx in offsets
        for dy in offsets
        if (dx, dy) != (0, 0)
        and 0 <= x + dx < width
        and 0 <= y + dy < height
    ]


In [None]:
def get_corr_with_neighbors(image_stack: np.ndarray, x: int, y: int) -> float:
    """Pearson correlation between one pixel's time series and the mean time
    series of its in-bounds 8-connected neighbours.

    :param image_stack: (N, H, W) array of N timepoints of HxW images
    :param x: x (column) coordinate of the pixel
    :param y: y (row) coordinate of the pixel
    :return: correlation as a Python float; 0.0 when either series is constant
        or when the pixel has no neighbours (1x1 image)
    """
    H, W = image_stack.shape[1], image_stack.shape[2]
    # 3x3 window around the pixel, clipped at the image border
    y0, x0 = max(y - 1, 0), max(x - 1, 0)
    window = image_stack[:, y0:min(y + 2, H), x0:min(x + 2, W)]
    is_neighbor = np.ones(window.shape[1:], dtype=bool)
    is_neighbor[y - y0, x - x0] = False  # exclude the pixel itself
    if not is_neighbor.any():
        # 1x1 image: nothing to correlate with (np.stack([]) used to raise here)
        return 0.0

    pixel_fl_time_series = image_stack[:, y, x]
    # boolean mask over (H, W) axes -> (N, n_neighbors); average across neighbours
    neighbors_mean_time_series = window[:, is_neighbor].mean(axis=1)

    if pixel_fl_time_series.std() == 0.0 or neighbors_mean_time_series.std() == 0.0:
        return 0.0
    return float(np.corrcoef(pixel_fl_time_series, neighbors_mean_time_series)[0, 1])

In [None]:
# Smoke test on the top-left corner pixel (x=0, y=0), which has only 3 neighbours.
corr_test_pixel = get_corr_with_neighbors(frames, 0, 0)
print(corr_test_pixel)

In [None]:
def calculate_all_corrs(image_stack: np.ndarray) -> np.ndarray:
    """Per-pixel correlation with the neighbour mean, computed pixel by pixel.

    Slow reference implementation: one get_corr_with_neighbors call per pixel.

    :param image_stack: (N, H, W) array of N timepoints of HxW images
    :return: (H, W) float64 array of correlations
    """
    H, W = image_stack.shape[1], image_stack.shape[2]
    # Allocate a float array explicitly: np.zeros_like(image_stack[0]) would
    # inherit an integer dtype from integer stacks (e.g. uint16 TIFFs) and
    # truncate every correlation in [-1, 1] to an integer.
    corrs = np.zeros((H, W), dtype=np.float64)
    for y in range(H):
        for x in range(W):
            corrs[y, x] = get_corr_with_neighbors(image_stack, x, y)
    return corrs

In [None]:
from scipy.signal import convolve2d

def neighbor_mean_stack(X: np.ndarray) -> np.ndarray:
    """
    X: (N, H, W) array
    returns M: (N, H, W) neighbor mean for each frame, excluding the center pixel.
    Edge pixels use the mean over available in-bounds neighbors (3/5/8); a 1x1
    image has no neighbors and gets 0.
    Integer input is promoted to floating point (float32 for <=16-bit ints,
    float64 for wider ints); floating input keeps its dtype.
    """
    work_dtype = np.result_type(X.dtype, np.float32)
    # Promote first: convolving integer frames with an integer kernel can
    # overflow, and an integer output array would truncate the means.
    X = np.asarray(X, dtype=work_dtype)
    N, H, W = X.shape
    K = np.array([[1, 1, 1],
                  [1, 0, 1],
                  [1, 1, 1]], dtype=work_dtype)
    # counts of valid neighbors are the same for all frames; compute once
    ones = np.ones((H, W), dtype=work_dtype)
    denom = convolve2d(ones, K, mode='same', boundary='fill', fillvalue=0)  # 8 interior, 5 edges, 3 corners
    has_neighbors = denom > 0  # False only for the single pixel of a 1x1 image

    M = np.zeros_like(X)
    for t in range(N):
        s = convolve2d(X[t], K, mode='same', boundary='fill', fillvalue=0)
        # leave M at 0 where there are no neighbors instead of producing 0/0 = NaN
        np.divide(s, denom, out=M[t], where=has_neighbors)
    return M

def corr_with_neighbors_all(X: np.ndarray) -> np.ndarray:
    """
    X: (N, H, W) — N timepoints of HxW images
    returns corrs: (H, W) Pearson correlation between each pixel's time series
                   and the mean of its neighbors' time series.
    Pixels whose own or neighbor-mean series is constant get 0. The result is
    floating point: X's dtype for float input, float32/float64 for integer input.
    """
    work_dtype = np.result_type(X.dtype, np.float32)
    # Promote before any arithmetic: X*X on uint8/uint16 frames wraps around,
    # and an integer result array would truncate every correlation.
    X = np.asarray(X, dtype=work_dtype)
    N = X.shape[0]
    if N == 0:
        return np.zeros(X.shape[1:], dtype=work_dtype)

    # 1) neighbor mean per frame (N,H,W)
    M = neighbor_mean_stack(X)

    # 2) Center each time series. The one-pass form sum(x^2) - sum(x)^2/N
    #    suffers catastrophic cancellation in float32 for bright, long movies.
    Xc = X - X.mean(axis=0, dtype=np.float64).astype(work_dtype)
    Mc = M - M.mean(axis=0, dtype=np.float64).astype(work_dtype)

    # 3) Covariance and variances over time, accumulated in float64
    cov_xm = (Xc * Mc).sum(axis=0, dtype=np.float64)
    var_x = (Xc * Xc).sum(axis=0, dtype=np.float64)
    var_m = (Mc * Mc).sum(axis=0, dtype=np.float64)

    # 4) Pearson correlation with safe divide (0 where either variance is 0)
    denom = np.sqrt(var_x * var_m)
    corrs = np.zeros(denom.shape, dtype=np.float64)
    np.divide(cov_xm, denom, out=corrs, where=denom > 0)
    return corrs.astype(work_dtype)

In [None]:
def corr_local_on_diff(X: np.ndarray) -> np.ndarray:
    """
    Local correlation map of frame-to-frame changes.

    X: (N,H,W) array of frames
    Returns: (H,W) result of corr_with_neighbors_all applied to
             ΔF = np.diff(X, axis=0), after subtracting each ΔF frame's
             spatial median.
    """
    as_float = X.astype(np.float32, copy=False)
    delta = np.diff(as_float, axis=0)  # (N-1,H,W); a new array, so X is never modified
    # Remove the component shared by the whole frame (e.g. global intensity
    # shifts) before measuring local correlation.
    per_frame_median = np.median(delta, axis=(1, 2), keepdims=True)
    delta -= per_frame_median
    return corr_with_neighbors_all(delta)


In [None]:
# Local-correlation map: each pixel's correlation with the mean of its
# neighbours over time. The commented line correlates frame-to-frame
# differences instead.
corrs = corr_with_neighbors_all(frames)
#corrs = corr_local_on_diff(frames)

In [None]:
corrs

In [None]:
# Flat (row-major) index of the most correlated pixel; wrap in
# np.unravel_index(..., corrs.shape) to get (y, x).
np.argmax(corrs)

In [None]:
np.max(corrs)

In [None]:
plt.imshow(corrs)

In [None]:
# First slice of the movie, for visual comparison with the correlation map.
plt.imshow(frames[0])

In [None]:
def build_rois(corrs: np.ndarray, threshold=0.5) -> np.ndarray:
    """Label 8-connected regions of pixels whose correlation is >= threshold.

    Uses an iterative breadth-first flood fill, so ROI size is not limited by
    Python's recursion limit.

    :param corrs: (H, W) array of per-pixel correlations
    :param threshold: minimum correlation for a pixel to belong to an ROI (inclusive)
    :return: (H, W) int array; 0 = background, 1..K = ROI ids, numbered in
        row-major order of each ROI's first pixel. NaN pixels are background.
    """
    from collections import deque

    corrs = np.asarray(corrs)
    H, W = corrs.shape
    in_roi = corrs >= threshold  # NaN compares False, so NaN pixels are background
    labels = np.zeros((H, W), dtype=int)
    label_id = 0
    # np.nonzero yields seed pixels in row-major order, which fixes the ROI numbering.
    for y0, x0 in zip(*np.nonzero(in_roi)):
        if labels[y0, x0] != 0:
            continue  # already absorbed into an earlier ROI
        label_id += 1
        labels[y0, x0] = label_id  # mark on enqueue so no pixel is queued twice
        queue = deque([(y0, x0)])
        while queue:
            y, x = queue.popleft()
            # clipped 3x3 neighbourhood (8-connectivity); the centre is already labelled
            for ny in range(max(y - 1, 0), min(y + 2, H)):
                for nx in range(max(x - 1, 0), min(x + 2, W)):
                    if in_roi[ny, nx] and labels[ny, nx] == 0:
                        labels[ny, nx] = label_id
                        queue.append((ny, nx))
    return labels

In [None]:
def build_roi(corrs, threshold, visited):
    """Grow one ROI from the highest-correlation pixel not yet in `visited`.

    The seed is the argmax of `corrs` over pixels not in `visited`. The ROI is
    the 8-connected set of non-visited pixels with corrs >= threshold reachable
    from the seed, found with an iterative flood fill (no recursion limit).

    :param corrs: (H, W) array of per-pixel correlations
    :param threshold: minimum correlation for ROI membership (inclusive)
    :param visited: mutable set of (y, x) tuples to exclude; updated in place
        with the pixels of the returned ROI so repeated calls find new ROIs
    :return: set of (y, x) tuples; empty if the best remaining pixel is below
        threshold, NaN, or already visited
    """
    from collections import deque

    H, W = corrs.shape
    # float copy so visited pixels can be masked with -inf even for integer input
    masked_corrs = np.array(corrs, dtype=float)
    for y, x in visited:
        masked_corrs[y, x] = -np.inf
    cy, cx = np.unravel_index(np.argmax(masked_corrs, axis=None), corrs.shape)
    center = (int(cy), int(cx))
    if center in visited or not corrs[center] >= threshold:
        return set()

    roi = {center}
    visited.add(center)
    queue = deque([center])
    while queue:
        y, x = queue.popleft()
        for ny in range(max(y - 1, 0), min(y + 2, H)):
            for nx in range(max(x - 1, 0), min(x + 2, W)):
                if (ny, nx) not in visited and corrs[ny, nx] >= threshold:
                    visited.add((ny, nx))
                    roi.add((ny, nx))
                    queue.append((ny, nx))
    return roi


In [None]:
#def build_roi_recursive(corrs: np.ndarray, center: tuple, threshold=0.5,
#                        visited: set = None) ->set:
#    if visited is None:
#        visited = set()
#    if center in visited or corrs[center] < threshold:
#        return set()
#    visited.add(center)
#
#    roi = {center}
#    neighbors = get_neighbor_coords(corrs, center[1], center[0])
#    for x, y in neighbors:
#        if corrs[y, x] > threshold:
#            roi.update(build_roi_recursive(corrs, (y, x), threshold, visited))
#
#    return roi
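#TODO: use an iterative flood fill (explicit queue/stack) instead of recursion — recursion depth grows with ROI size and can exceed Python's default recursion limit (~1000).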

In [None]:
def build_roi_recursive(corrs: np.ndarray, center: tuple, threshold: float,
                        labels: np.ndarray, label_id: int) -> set:
    """Flood-fill one 8-connected ROI starting at `center`, marking `labels` in place.

    Despite the (kept) name, this is iterative with an explicit queue, so large
    ROIs do not hit Python's recursion limit.

    :param corrs: (H, W) per-pixel correlations
    :param center: (y, x) seed pixel
    :param threshold: minimum correlation for ROI membership (inclusive)
    :param labels: (H, W) int array where -1 means "unvisited". Modified in
        place: ROI pixels get `label_id`; unvisited below-threshold neighbours
        of the ROI get 0.
    :param label_id: id written into `labels` for this ROI
    :return: set of (y, x) tuples in the ROI; empty (and `labels` untouched)
        if the seed is already labelled or below threshold
    """
    from collections import deque

    y, x = center
    if labels[y, x] != -1 or corrs[y, x] < threshold:
        return set()

    H, W = corrs.shape
    roi = {center}
    labels[y, x] = label_id  # mark on enqueue to prevent revisits
    queue = deque([center])
    while queue:
        cy, cx = queue.popleft()
        # clipped 3x3 neighbourhood; the centre itself is already labelled
        for ny in range(max(cy - 1, 0), min(cy + 2, H)):
            for nx in range(max(cx - 1, 0), min(cx + 2, W)):
                if labels[ny, nx] != -1:
                    continue  # in this or another ROI, or already known background
                if corrs[ny, nx] >= threshold:
                    labels[ny, nx] = label_id
                    roi.add((ny, nx))
                    queue.append((ny, nx))
                else:
                    labels[ny, nx] = 0  # Visited but not part of ROI
    return roi

In [None]:
# Segment ROIs as 8-connected regions with corr >= 0.9 (hard-coded threshold).
# `rois` is an (H, W) integer label image (0 = background, 1..K = ROI ids),
# not a list of ROIs.
rois = build_rois(corrs, 0.9)



In [None]:
# Highest label = number of ROIs (ids are assigned consecutively from 1; 0 if none).
rois.max()


In [None]:
plt.imshow(rois)

In [None]:
def calculate_roi_brightness_over_time(image_stack: np.ndarray, labels: np.ndarray) -> dict[int, np.ndarray]:
    """
    Mean intensity of every labelled ROI at each timepoint.

    :param image_stack: (T, H, W) array of images over time
    :param labels: (H, W) array with ROI labels (1, 2, ...); labels <= 0 are ignored
    :return: Dictionary mapping ROI label -> (T,) array of the mean pixel value
             per timepoint, in ascending label order
    """
    # keep every label that is not <= 0 (background / visited-but-not-ROI)
    roi_ids = [lab for lab in np.unique(labels) if not lab <= 0]
    # A boolean (H, W) mask on the trailing axes selects (T, n_pixels); average over pixels.
    return {lab: image_stack[:, labels == lab].mean(axis=1) for lab in roi_ids}

In [None]:
# Mean intensity per ROI per timepoint: {roi_id: (T,) array}.
roi_brightness = calculate_roi_brightness_over_time(frames, rois)

In [None]:
def plot_roi_brightness_curves(roi_brightness: dict[int, np.ndarray]):
    """
    Plots the average brightness over time for each ROI, one line per ROI.

    :param roi_brightness: Dictionary mapping ROI label -> (T,) array of brightness values.
        An empty dictionary produces an empty, labelled plot.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    for label, brightness_curve in roi_brightness.items():
        ax.plot(brightness_curve, label=f'ROI {label}')

    ax.set_xlabel("Time")
    ax.set_ylabel("Average Brightness")
    ax.set_title("ROI Brightness Over Time")
    # Only build a legend when there is something to list; an empty legend
    # triggers a "No artists with labels" warning.
    if roi_brightness:
        ax.legend()
    ax.grid(True)
    fig.tight_layout()
    plt.show()

In [None]:
plot_roi_brightness_curves(roi_brightness)

In [None]:
len(rois[0])

In [None]:
print(len(rois))

In [None]:
def plot_rois_on_image(frames: np.ndarray, labels: np.ndarray, base: str = "mean") -> None:
    """
    Show ROI areas as colored, semi-transparent overlays on top of a base image.
    No labels or numbers, just filled ROI regions.

    :param frames: (T, H, W) image stack
    :param labels: (H, W) label image; every value > 0 is drawn as one ROI.
        Colors come from the "tab20" colormap and repeat after 20 ROIs.
    :param base: "mean" (default), "max", or "frame:<idx>". The frame index is
        clamped to [0, T-1]; a non-integer index falls back to frame 0. Any
        other value also uses the mean image.
    """
    T, H, W = frames.shape

    # ----- choose base image -----
    if base == "max":
        base_img, title_suffix = frames.max(axis=0), "Max-Projection Image"
    elif base.startswith("frame:"):
        try:
            idx = int(base.split(":")[1])
        except ValueError:  # e.g. "frame:" or "frame:abc"
            idx = 0
        idx = max(0, min(T - 1, idx))
        base_img, title_suffix = frames[idx], f"Frame {idx}"
    else:  # "mean" and any unrecognised value
        base_img, title_suffix = frames.mean(axis=0), "Mean Image"

    # ----- build a single RGBA overlay for all ROIs -----
    # Each pixel carries exactly one label, so ROIs never overlap: one layer
    # renders the same as one full-size layer per ROI, without allocating and
    # drawing an (H, W, 4) image for every ROI.
    cmap = plt.get_cmap("tab20")
    colors = cycle([cmap(i) for i in range(cmap.N)])
    overlay = np.zeros((H, W, 4))
    unique_labels = [int(v) for v in np.unique(labels) if v > 0]
    for lab, color in zip(unique_labels, colors):
        mask = labels == lab
        overlay[mask, :3] = color[:3]
        overlay[mask, 3] = 0.75  # alpha only on ROI pixels

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(base_img, cmap="gray", interpolation="nearest")
    ax.imshow(overlay, interpolation="nearest")
    ax.set_title(f"ROIs over {title_suffix}")
    ax.set_axis_off()
    fig.tight_layout()
    plt.show()

In [None]:
# Overlay the ROIs on the mean image (default base="mean"); plot_rois_on_image
# unpacks frames.shape as (T, H, W), so frames must be 3-D here.
plot_rois_on_image(frames, rois)