In [None]:
import fastplotlib as fpl
import os
import sys
import masknmf
import tifffile
import numpy as np
%load_ext autoreload

In [None]:
# Path to the raw two-photon TIFF stack. Set the MASKNMF_DATA_PATH environment variable to use a
# different file instead of editing this absolute, machine-specific default.
# NOTE(review): the directory is for subject SP067 but the filename says SP058 — confirm this is the intended file.
DATA_PATH = os.environ.get(
    "MASKNMF_DATA_PATH",
    "/home/kushal/amol_data/cortexlab/Subjects/SP067/2024-07-18/001/raw_imaging_data_00/2024-07-18_1_SP058_2P_00000_00001.tif",
)
# Keep everything from index 4110 onward along axis 1.
# NOTE(review): presumably crops one plane/ROI out of a taller tiled frame — confirm what 4110 corresponds to.
CROP_START = 4110
data = tifffile.imread(DATA_PATH)[:, CROP_START:]

In [None]:
# Inspect the dimensions of the loaded (cropped) stack.
np.shape(data)

# If you don't have a good template estimate, run the generic template-estimation procedure. The result is a `PiecewiseRigidMotionCorrector` object whose template is used to register all frames.

In [None]:
# Piecewise-rigid registration settings, named so they are easy to find and tune.
# Each is a two-element list, presumably one value per spatial axis — confirm in masknmf.
NUM_BLOCKS = [10, 10]
BLOCK_OVERLAPS = [5, 5]
MAX_RIGID_SHIFTS = [15, 15]
MAX_DEVIATION_RIGID = [2, 2]
BATCH_SIZE = 100

strategy = masknmf.PiecewiseRigidMotionCorrector(
    num_blocks=NUM_BLOCKS,
    overlaps=BLOCK_OVERLAPS,
    max_rigid_shifts=MAX_RIGID_SHIFTS,
    max_deviation_rigid=MAX_DEVIATION_RIGID,
    batch_size=BATCH_SIZE,
)

In [None]:
# Estimate a registration template from the raw stack.
# Presumably this populates `strategy.template`, which is displayed in a later cell.
strategy.compute_template(data)

In [None]:
import torch

In [None]:
# Release unused GPU memory cached by PyTorch's allocator (e.g. after template estimation).
torch.cuda.empty_cache()

In [None]:
# Display the estimated template. Copy it to host memory and convert it to NumPy for fastplotlib.
template_image = strategy.template.cpu().numpy()
iw_template = fpl.ImageWidget(template_image)
iw_template.show()

# Define a RegistrationArray that lazily loads motion-corrected frames of the raw data

In [None]:
# Pair the raw stack with the registration strategy. Frames are presumably registered lazily on
# indexing (e.g. `moco_results[i]`, `moco_results.shifts[:]` below) — confirm in masknmf docs.
moco_results = masknmf.RegistrationArray(data, strategy)

In [None]:
from ipywidgets import VBox

In [None]:
# Per-frame, per-block shifts from the registration array.
# NOTE(review): assumes shape (n_frames, *block_grid, 2), last axis = the two shift components — confirm.
shifts = moco_results.shifts[:]

n_frames = shifts.shape[0]
# Number of blocks = product of the grid dims. dtype=int keeps this an integer even when there is
# no grid axis: np.prod(()) is the float 1.0, which reshape rejects.
n_blocks = int(np.prod(shifts.shape[1:-1], dtype=int))

# Stack the two components side by side, then put blocks first: (n_blocks, n_frames, 2).
shifts_stack = np.dstack([
    shifts[..., 0].reshape(n_frames, n_blocks),
    shifts[..., 1].reshape(n_frames, n_blocks),
]).transpose(1, 0, 2)

In [None]:
# Raw vs. motion-corrected movie side by side, each view shown through an 11-frame np.mean window.
iw = fpl.ImageWidget(
    [data, moco_results],
    cmap="viridis",
    window_funcs={"t": (np.mean, 11)},
    figure_kwargs={"size": (1000, 500)},
)

fig = fpl.Figure()

# One value per frame: the max over blocks of the larger of the two shift components.
# NOTE(review): this is a signed max, so large negative shifts do not show up — use a norm if magnitude is intended.
shifts_lg = fig[0, 0].add_line(shifts_stack.max(axis=-1).max(axis=0), thickness=1.1)

# Selector on the shift trace, used to scrub through frames.
sel = shifts_lg.add_linear_selector()


@sel.add_event_handler("selection")
def update(ev: dict | fpl.GraphicFeatureEvent):
    """Keep the ImageWidget frame and the selector position in sync.

    Receives a ``fpl.GraphicFeatureEvent`` when the selector moves, and the ImageWidget's
    ``current_index`` dict (e.g. ``{"t": 42}``) when the widget's frame changes.
    """
    if isinstance(ev, fpl.GraphicFeatureEvent):
        # Selector moved: jump the movies to the selected frame.
        index = ev.get_selected_index()
        iw.current_index = {"t": index}
    else:
        # Movie frame changed: move the selector to match.
        index = ev["t"]
        sel.selection = index


iw.add_event_handler(update, "current_index")

VBox([iw.show(), fig.show()])

In [None]:
# Re-read the shifts and overlay the new per-frame trace (cyan) on the existing shift plot,
# e.g. to compare against the trace drawn earlier.
# NOTE(review): assumes shifts has shape (n_frames, *block_grid, 2) — confirm.
shifts = moco_results.shifts[:]

n_frames = shifts.shape[0]
# Integer block count. np.prod(()) would be the float 1.0, which reshape rejects.
n_blocks = int(np.prod(shifts.shape[1:-1], dtype=int))

# (n_blocks, n_frames, 2)
shifts_stack = np.dstack([
    shifts[..., 0].reshape(n_frames, n_blocks),
    shifts[..., 1].reshape(n_frames, n_blocks),
]).transpose(1, 0, 2)

fig[0, 0].add_line(shifts_stack.max(axis=-1).max(axis=0), colors="cyan", thickness=1.1)

In [None]:
# NOTE(review): mutates a private attribute of the shared strategy in place (constructed with [2, 2]).
# Presumably later shift reads pick this up; outputs produced earlier in the notebook no longer
# match this configuration — confirm with masknmf.
moco_results.strategy._max_deviation_rigid = [1, 1]

In [None]:
# NOTE(review): sets the private rigid-shift limit (constructed with [15, 15]) to [0, 0] in place.
# Presumably this disables the rigid component for subsequent registration — confirm with masknmf.
# Earlier outputs in the notebook no longer reflect this setting.
moco_results.strategy._max_rigid_shifts = [0, 0]

In [None]:
# Shape after reducing over the first (block) axis.
np.max(shifts_stack, axis=0).shape

In [None]:
# Re-read the per-block shifts, e.g. after changing strategy parameters in the cells above.
# NOTE(review): assumes shape (n_frames, *block_grid, 2) — confirm.
shifts = moco_results.shifts[:]

n_frames = shifts.shape[0]
# Integer block count. np.prod(()) would be the float 1.0, which reshape rejects.
n_blocks = int(np.prod(shifts.shape[1:-1], dtype=int))

# (n_blocks, n_frames, 2)
shifts_stack = np.dstack([
    shifts[..., 0].reshape(n_frames, n_blocks),
    shifts[..., 1].reshape(n_frames, n_blocks),
]).transpose(1, 0, 2)

In [None]:
# Expected layout: (n_blocks, n_frames, 2).
np.shape(shifts_stack)

In [None]:
def pretty_size(size):
    """Format a ``torch.Size`` as ``"d0 x d1 x ..."`` (e.g. ``"3 x 4"``)."""
    assert isinstance(size, torch.Size)
    return " x ".join(map(str, size))


def dump_tensors(gpu_only=True):
    """Print every tensor currently tracked by Python's garbage collector.

    Best-effort debugging aid for finding what is holding (GPU) memory.

    Parameters
    ----------
    gpu_only : bool
        If True (default), only report tensors whose ``is_cuda`` is True.

    Prints one line per tensor (type name, " GPU"/" pinned" flags, shape), then
    ``Total size: N``, where N is the summed number of *elements* (not bytes).
    """
    import gc

    total_size = 0
    for obj in gc.get_objects():
        try:
            # Variables and Tensors are one type in modern PyTorch, so a plain tensor check is
            # sufficient; the old `.data`-wrapper branch could never report anything.
            if torch.is_tensor(obj) and (not gpu_only or obj.is_cuda):
                print(
                    "%s:%s%s %s"
                    % (
                        type(obj).__name__,
                        " GPU" if obj.is_cuda else "",
                        # is_pinned is a method; the bare attribute is always truthy.
                        " pinned" if obj.is_pinned() else "",
                        pretty_size(obj.size()),
                    )
                )
                total_size += obj.numel()
        except Exception:
            # Deliberately best-effort: gc can return arbitrary objects whose attribute access
            # raises; skip them rather than abort the dump.
            continue
    print("Total size:", total_size)

In [None]:
# List the CUDA tensors still alive (e.g. to see what is holding GPU memory).
dump_tensors(gpu_only=True)

# Visualize with the fastplotlib ImageWidget

In [None]:
# Raw movie, motion-corrected movie and the template side by side (11-frame np.mean window).
# Uses `strategy`, the corrector whose template was estimated above.
iw = fpl.ImageWidget(
    data=[data, moco_results, strategy.template.cpu().numpy()],
    names=['raw data', 'motion corrected', 'template'],
    figure_shape=(1, 3),
    cmap="viridis",
    window_funcs={"t": (np.mean, 11)},
    figure_kwargs={"size": (1300, 800)},
)

# Arrows are drawn at 1/VECTOR_SCALE of the shift length, otherwise they're too big.
VECTOR_SCALE = 5


def _shift_directions(t):
    """Return frame ``t``'s per-block shifts as an ``(n_blocks, 2)`` array of scaled arrow directions."""
    u, v = moco_results.shifts[t].transpose(-1, 0, 1)
    return np.column_stack([u.ravel(), v.ravel()]) / VECTOR_SCALE


# NOTE(review): assumes block_centers and shifts[t] are (grid_rows, grid_cols, 2) with components
# in the (x, y) order fastplotlib expects — confirm against masknmf.
x, y = moco_results.block_centers.transpose(-1, 0, 1)

# positions of each vector as [n_points, 2] array
positions = np.column_stack([x.ravel(), y.ravel()])

vector_field = iw.figure[0, 0].add_vector_field(
    positions=positions,
    directions=_shift_directions(0),
    alpha=0.7,
    alpha_mode="add",
    color="w",
)


@iw.add_event_handler
def update_vector_field(index):
    """Redraw the shift arrows for the frame the ImageWidget is now showing (``index["t"]``)."""
    vector_field.directions = _shift_directions(index["t"])


iw.show()

In [None]:
%%timeit
# Time reading one randomly chosen frame from the lazy motion-corrected array.
# The upper bound matches the 11-frame slice benchmark, so both sample the same range.
i = np.random.randint(0, moco_results.shape[0] - 11)
moco_results[i]

In [None]:
%%timeit
# Time reading an 11-frame slice (the same window length as the ImageWidget's np.mean window).
i = np.random.randint(0, moco_results.shape[0] - 11)
moco_results[i:i+11]

In [None]:
# Load the VizTracer IPython extension, which provides the %%viztracer cell magic.
%load_ext viztracer

In [None]:
%%viztracer
# Trace a single 11-frame read from the lazy array to see where the time goes.
i = np.random.randint(0, moco_results.shape[0] - 11)
moco_results[i:i+11]

In [None]:
# NOTE(review): scratch arithmetic (= 20.0) — presumably a per-frame time budget in ms at 50 frames/s; confirm or delete.
1_000 / 50

In [None]:
# Clear the window functions (the 11-frame np.mean) configured when `iw` was created.
iw.window_funcs = None

In [None]:
# Close the ImageWidget when done with it.
iw.close()

In [None]:
# Check the container type of block_centers (the vector-field cell calls .transpose(-1, 0, 1) on it,
# which only works with a NumPy-style multi-axis transpose).
type(moco_results.block_centers)