### CONDA ENVIRONMENTS

For steps __1. preprocess__ and __2. mip__, `conda activate g5ht-pipeline`

For step __3. segment__, `conda activate segment-torch` or `conda activate torchcu129`

For step __4. spline, 5. orient, 6. warp, 7. reg__

## TODO:

1. I wonder if I computed a spline on each and every z slice and warped each, oriented each of them, and warped each of them, if the problem of weirdly sheared image stacks would be solved
2. quick mp4 for all recordings
   1. now working in engaging, works per one nd2 sbatch
3. focus check for all recordings
   1. maybe focus check can be used to specify which z slices are good to use and which frames are good to use
4. for recordings starting in december 2025, need to trim first 2 rather than last 2 z slices
5. flip worms so that VNC is always up
6. fixed mask could be automated, but if not, make sure to save which index is fixed
7. extract behavior
8. posture similarity
   1. posture might consist of the spline + thresholded z-stack
      1. I'm thinking that the orientation shouldn't matter, but the z-planes in focus will, and curvature/spline of the head will
      2. maybe need to actually interpolate to 117 z slices
   2. sub registration problems
   3. label each set of registered frames with one set of ROIs, or auto segment ROIs from each set of registered frames
9.  track z over time, which zslices are consistent
   1. focus + correlation
10. beads -> train/test
11. gfp+1 relative to rfp channel (might only apply to pre december 2025 recordings)
12. wholistic 
    1.  parameter sweep, might change
    2.  python version
13. autocorr/scorr
14. automate z slice trimming
    1.  pre december 2025 (trim last 2 z slices)
    2.  post december 2025 (trim first z slice)
15. photobleaching estimation?
    1.  record immo with serotonin
    2.  at least do it for RFP
16. try deltaF/F [ (F(t) - F0) / F0 ]
17. coding directions (preencounter-baseline) (postencounter-baseline)
    1.  then show voxel weights
18. port everything to engaging

In [55]:
import sys
import os
import importlib
from tqdm import tqdm

try:
    import utils
    is_torch_env = False
except ImportError:
    is_torch_env = True
    print("utils not loaded because conda environment doesn't have nd2reader installed. probably using torchcu129 env, which is totally fine for just doing the segmentation step")

## SPECIFY DATA TO PROCESS

In [65]:
DATA_PTH = r'D:\DATA\g5ht-free\20251028'

INPUT_ND2 = 'date-20251028_time-1500_strain-ISg5HT_condition-starvedpatch_worm001.nd2'

INPUT_ND2_PTH = os.path.join(DATA_PTH, INPUT_ND2)

NOISE_PTH = r'C:\Users\munib\POSTDOC\CODE\g5ht-pipeline\noise\noise_042925.tif'

OUT_DIR = os.path.splitext(INPUT_ND2_PTH)[0]

STACK_LENGTH = 41

if not is_torch_env:
    noise_stack = utils.get_noise_stack(NOISE_PTH, STACK_LENGTH)
    num_frames, height, width, num_channels = utils.get_range_from_nd2(INPUT_ND2_PTH, stack_length=STACK_LENGTH) 
    beads_alignment_file = utils.get_beads_alignment_file(INPUT_ND2_PTH)
else:
    print("utils not loaded because conda environment doesn't have nd2reader installed. probably using torchcu129 env, which is totally fine for just doing the segmentation step")

print(INPUT_ND2)
print('Num z-slices: ', STACK_LENGTH)
if not is_torch_env:
    print('Number of frames: ', num_frames)
    print('Height: ', height)
    print('width: ', width)
    print('Number of channels: ', num_channels)
    print('Beads alignment file: ', beads_alignment_file)

date-20251028_time-1500_strain-ISg5HT_condition-starvedpatch_worm001.nd2
Num z-slices:  41
Number of frames:  1200
Height:  512
width:  512
Number of channels:  2
Beads alignment file:  None


## 10. LABEL ROIs

`conda activate g5ht-pipeline`

- after this step, use `lbl` conda env to label ROI of fixed frame
  - run `labelme` in terminal


maybe also see here for video annotation: https://github.com/wkentaro/labelme/tree/main/examples/video_annotation

### EXPORT FIXED VOLUME AS PNGs for labeling with `labelme`

In [69]:
# code that exports each z-slice of fixed.tif as a separate png
import tifffile
import os
import glob
import scipy.ndimage as ndi

PTH = os.path.splitext(INPUT_ND2_PTH)[0]

# in PTH directory, find a fixed_XXXX*.tif file, where XXXX are digits
fixed_fn = glob.glob(os.path.join(PTH, 'fixed_[0-9][0-9][0-9][0-9]*.tif'))[0]
fixed_pth = os.path.join(PTH, fixed_fn)

# fixed_pth = os.path.join(PTH, 'fixed.tif')
# fixed_stack = ndi.zoom(tifffile.imread(fixed_pth), zoom=(3,1,1,1))
fixed_stack = tifffile.imread(fixed_pth)

out_dir = os.path.join(PTH, 'fixed_png')
if not os.path.exists(out_dir):
    os.makedirs(out_dir)

for i in range(fixed_stack.shape[0]):
    slice_pth = os.path.join(out_dir, f'fixed_z{i:02d}.png')
    # make sure to save channel 1, and that it is visible, correct data type, clipped to 0-255
    slice_img = fixed_stack[i,1,:,:]
    slice_img = (slice_img - slice_img.min()) / (slice_img.max() - slice_img.min()) * 255
    slice_img = slice_img.astype('uint8')
    tifffile.imwrite(slice_pth, slice_img) 

out_dir = os.path.join(PTH, 'fixed_xz_png')
if not os.path.exists(out_dir):
    os.makedirs(out_dir)

for i in range(fixed_stack.shape[2]):
    slice_pth = os.path.join(out_dir, f'fixed_xz{i:02d}.png')
    # make sure to save channel 1, and that it is visible, correct data type, clipped to 0-255
    slice_img = fixed_stack[:,1,i,:]
    slice_img = (slice_img - slice_img.min()) / (slice_img.max() - slice_img.min()) * 255
    slice_img = slice_img.astype('uint8')
    tifffile.imwrite(slice_pth, slice_img) 

out_dir = os.path.join(PTH, 'fixed_yz_png')
if not os.path.exists(out_dir):
    os.makedirs(out_dir)

for i in range(fixed_stack.shape[3]):
    slice_pth = os.path.join(out_dir, f'fixed_yz{i:02d}.png')
    # make sure to save channel 1, and that it is visible, correct data type, clipped to 0-255
    slice_img = fixed_stack[:,1,:,i]
    slice_img = (slice_img - slice_img.min()) / (slice_img.max() - slice_img.min()) * 255
    slice_img = slice_img.astype('uint8')
    # save as high quality tiff
    tifffile.imwrite(slice_pth, slice_img) 

### PARSE OUTPUT OF `labelme`

- outputs `roi.tif`

In [70]:
import numpy as np
import json
from skimage.draw import polygon

PTH = os.path.splitext(INPUT_ND2_PTH)[0]

out_dir = os.path.join(PTH, 'fixed_png')
fixed_fn = glob.glob(os.path.join(PTH, 'fixed_[0-9][0-9][0-9][0-9]*.tif'))[0]
fixed_pth = os.path.join(PTH, fixed_fn)
fixed_stack = tifffile.imread(fixed_pth)
Z,C,H,W = fixed_stack.shape

roi = np.zeros((Z, H, W), dtype=fixed_stack.dtype) # ZHW

# get all unique roi_labels from json files
roi_json_files = glob.glob(os.path.join(out_dir, 'fixed_z[0-9][0-9]*.json'))
roi_labels = []
for roi_json_file in roi_json_files:
    with open(roi_json_file, 'r') as f:
        roi_dict = json.load(f)
        roi_labels.append([shape['label'] for shape in roi_dict['shapes']])
# get all unique roi_labels
roi_labels = list(set([item for sublist in roi_labels for item in sublist]))
print(roi_labels)

# # roi_labels = ['PC','MC','IM','TB','NR','VNC','DNC']
# # procorpus, metacorpus, isthmus, terminal bulb, nerve ring, ventral nerve cord, dorsal nerve cord

for i in range(Z):
    slice_roi_json = os.path.join(out_dir, f'fixed_z{i:02d}.json')
    # if slice_roi_json doesn't exist, continue
    if not os.path.exists(slice_roi_json):
        continue
    with open(slice_roi_json, 'r') as f:
        roi_dict = json.load(f)
        # loop through each shape in roi_dict['shapes']
        for shape in roi_dict['shapes']:
            label = shape['label']
            if label in roi_labels:
                points = shape['points']
                # get integer coordinates
                points = [(int(round(p[1])), int(round(p[0]))) for p in points]
                # create a mask for the polygon
                
                rr, cc = polygon([p[0] for p in points], [p[1] for p in points], shape=(H,W))
                # should set to correct z slice
                roi[i, rr, cc] = roi_labels.index(label) + 1 # start from 1

# save roi stack as tif image, imagej=true and save the roi labels as metadata
roi_pth = os.path.join(PTH, 'roi.tif')
tifffile.imwrite(roi_pth, roi.astype(np.uint8), imagej=True, metadata={'Labels': roi_labels})

['NR', 'isthmus', 'DNC', 'VNC']


## 11. QUANTIFY

`conda activate g5ht-pipeline`

Have to first label dorsal and ventral nerve rings and pharynx. See ...

In [73]:
import sys
import os
import quantify
from numpy import genfromtxt
import matplotlib.pyplot as plt
import numpy as np
import importlib

import matplotlib
font = {'family' : 'Arial',
        'weight' : 'normal',
        'size'   : 15}
matplotlib.rc('font', **font)
matplotlib.rcParams['svg.fonttype'] = 'none'

_ = importlib.reload(sys.modules['quantify'])

PTH = os.path.splitext(INPUT_ND2_PTH)[0]
REG_DIR = r'registered_elastix'
PLOT_ONLY = True
time_type = 'frame'  # 'min', 'sec', or 'frame'
baseline_window = (0,60) # specify baseline frame range, default is (0, 60)

# %matplotlib inline
%matplotlib qt
# %matplotlib notebook
# %matplotlib widget

plt.close('all')

sys.argv = ["", PTH, REG_DIR, baseline_window, PLOT_ONLY, time_type]
quantify.main()

## 12 QUANTIFY VOXELS

In [72]:
import sys
import os
import quantify_voxels
from numpy import genfromtxt
import matplotlib.pyplot as plt
import numpy as np
import tifffile
from skimage.morphology import erosion, disk
import importlib
import glob

import matplotlib
%matplotlib widget
font = {'family' : 'DejaVu Sans',
        'weight' : 'normal',
        'size'   : 15}
matplotlib.rc('font', **font)

_ = importlib.reload(sys.modules['quantify_voxels'])

PTH = r'D:\DATA\g5ht-free\20251028\date-20251028_time-1500_strain-ISg5HT_condition-starvedpatch_worm001'
reg_dir = 'registered_elastix'
bin_factor = 2
fps = 1/0.533
baseline_window = (0,60) # specify baseline frame range, default is (0, 60)

sys.argv = ["", PTH, reg_dir, bin_factor, baseline_window]
quantify_voxels.main()

Processing 1200 files using 24 workers...


Processing stacks: 100%|██████████| 1200/1200 [00:32<00:00, 36.81it/s]


Processed data shape: (1200, 39, 100, 250)
Binning factor: 2
Saved normalized data (ratiometric) to D:\DATA\g5ht-free\20251028\date-20251028_time-1500_strain-ISg5HT_condition-starvedpatch_worm001\normalized_voxels.npy


In [None]:
# load normalized_voxels.npy, also been binned
g5 = np.load(os.path.join(PTH, 'normalized_voxels.npy'))
rfp_mean = np.load(os.path.join(PTH, 'rfp_mean.npy'))
gfp_mean = np.load(os.path.join(PTH, 'gfp_mean.npy'))
baseline = np.load(os.path.join(PTH, 'baseline.npy'))

# load mask, bin it
# find the fixed_mask_*.tif file in PTH directory
try:
    fixed_mask_fn = glob.glob(os.path.join(PTH, 'fixed_mask_*.tif'))[0]
except:
    fixed_mask_fn = glob.glob(os.path.join(PTH, 'fixed_mask*.tif'))[0]

mask = tifffile.imread(fixed_mask_fn)

h, w = mask.shape
h_binned = h // bin_factor
w_binned = w // bin_factor
# binning of mask
mask_binned = mask.reshape(h_binned, bin_factor, w_binned, bin_factor).max(axis=(1,3))

In [None]:
# find percentage of time a voxel is zero within the worm mask
g5_masked = g5 * mask_binned[np.newaxis, :, :]

g5_masked.shape

# g5_masked is shape (T, Z, H, W) array with zeros outside the worm region
# find all voxels where g5_masked is zero for each time point
zero_mask = g5_masked == 0
# calculate the probability of a voxel being zero across time, but don't include time points where the voxel is masked out (i.e., outside the worm)
zero_prob = np.mean(zero_mask, axis=0)
# set zero probability to NaN for voxels outside the worm (mask binned needs a z dimension added)
zero_prob[mask_binned[np.newaxis].repeat(zero_prob.shape[0], axis=0) == 0] = np.nan

In [None]:
print(g5.shape)
print(rfp_mean.shape)
print(gfp_mean.shape)
print(baseline.shape)
print(mask_binned.shape)

In [None]:


plt.close('all')
%matplotlib qt
# plot zero_prob as an image, with colorbar, for each z slice in subplots
fig, axes = plt.subplots(4, 10, figsize=(20, 8), constrained_layout=True)
for z in range(zero_prob.shape[0]):
    ax = axes[z // 10, z % 10]
    im = ax.imshow(zero_prob[z, :, :], cmap='viridis', vmin=0, vmax=1)
    ax.set_title(f'Z={z}')
    ax.axis('off')
    # replace last subplot with colorbar
    if z == zero_prob.shape[0] - 1:
        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
# delete last subplots if z slices are less than 40
for z in range(zero_prob.shape[0], 40):
    ax = axes[z // 10, z % 10]
    ax.axis('off')
plt.show()

plt.figure()
plt.hist(zero_prob.flatten(), bins=50, color='blue', alpha=0.7)
plt.xlabel('Probability of Voxel Being Zero')
plt.ylabel('Frequency')
plt.title('Histogram of Zero Probability Across Voxels')
plt.show()

In [None]:
probability_threshold = 0.3 # if a voxel is zero more than this fraction of the time, consider it outside the worm or a bad voxel
# create a mask of good voxels
good_voxel_mask = zero_prob < probability_threshold
normalized_data = g5_masked * good_voxel_mask[np.newaxis].repeat(g5_masked.shape[0], axis=0)

In [None]:
normalized_data.shape

In [None]:
norm_data_mean = np.mean(normalized_data, axis=0)

# clip norm_data mean between 0th and 99th percentiles
p0 = np.percentile(norm_data_mean, 0)
p99 = np.percentile(norm_data_mean, 99)
norm_data_mean = np.clip(norm_data_mean, p0, p99)

z2plot = [0,5,15,25,35]
# z2plot = [0,5]

%matplotlib inline
plt.close('all')
fig, axes = plt.subplots(len(z2plot), 3, figsize=(15, 9), constrained_layout=True)
for i in z2plot: # for each z slice
    plt.subplot(len(z2plot), 3, z2plot.index(i)*3+1)
    plt.pcolor(gfp_mean[i,:,:])
    plt.title(f'GFP Mean Z={i}')
    # plt.colorbar()

    plt.subplot(len(z2plot), 3, z2plot.index(i)*3+2)
    plt.pcolor(rfp_mean[i,:,:])
    plt.title(f'RFP Mean Z={i}')
    # plt.colorbar()

    plt.subplot(len(z2plot), 3, z2plot.index(i)*3+3)
    plt.pcolor(norm_data_mean[i,:,:])
    # add colorbar outside to the right
    plt.colorbar()
    # plt.clim(0, 12)
    # make all axes not be squished
plt.show()


In [None]:
# cluster normalized_data using k means
from sklearn.cluster import KMeans
n_clusters = 10
kmeans = KMeans(n_clusters=n_clusters, random_state=0)
T, Z, H, W = normalized_data.shape
normalized_data_reshaped = normalized_data.reshape(T, Z*H*W).T  # shape (Z*H*W, T)
kmeans.fit(normalized_data_reshaped)


In [None]:
# plot cluster mean activity over time
cluster_means = np.zeros((n_clusters, T))
for c in range(n_clusters):
    cluster_means[c, :] = normalized_data_reshaped[kmeans.labels_ == c, :].mean(axis=0)

In [None]:
plt.figure(figsize=(10,6))
t = np.arange(T) * (1/fps)
ic = 0
for c in range(6,9):#range(n_clusters):
    plt.plot(t, cluster_means[c, :], label=f'Cluster {ic+1}')
    ic += 1
plt.legend(frameon=False)
plt.xlabel('Time (sec)')
plt.ylabel('$\ R / R_{baseline}$')
plt.show()

In [None]:
# plot cluster 1 and 2 spatial maps at each z slice
plt.close('all')
ic = 0
for c in range(6,9):#range(n_clusters):
    fig, axes = plt.subplots(4, 10, figsize=(20, 8), constrained_layout=True)
    for z in range(Z):
        ax = axes[z // 10, z % 10]
        cluster_map = kmeans.labels_.reshape(Z, H, W)[z, :, :] == c
        im = ax.imshow(cluster_map, cmap='gray')
        ax.set_title(f'Cluster {ic+1} Z={z}')
        ax.axis('off')
    plt.suptitle(f'Spatial Map of Cluster {c+1} Across Z Slices')
    ic += 1
    # deleate last subplots if z slices are less than 40
    for z in range(Z, 40):
        ax = axes[z // 10, z % 10]
        ax.axis('off')
    plt.show()

In [None]:
# PCA
from sklearn.decomposition import PCA
# flatten normalized_data to (T, Z*H*W)
T, Z, H, W = normalized_data.shape
data_reshaped = normalized_data.reshape(T, Z*H*W)
pca = PCA(n_components=5)
pca.fit(data_reshaped)

In [None]:

# get scores, plot
scores = pca.transform(data_reshaped)
# plot the first 5 principal component scores
for i in range(5):
    plt.figure()
    plt.plot(scores[:,i])
    plt.show()

# plot the first 5 principal component weights as images across z slices
components = pca.components_



In [None]:

# plot gfp_mean, rfp_mean, dat in a (5,3) grid of subplots (5 z slices, 3 columns)


# divide gfp_mean by rfp_mean, but only if rfp_mean is not zero

dat = gfp_mean / rfp_mean
# dat = np.divide(gfp_mean, rfp_mean, out=np.zeros_like(gfp_mean), where=rfp_mean!=0)
# divide each voxel in dat by its baseline, but only if baseline is not zero
dat = np.divide(dat, baseline, out=np.zeros_like(dat), where=baseline!=0)
dat = dat * mask_binned[np.newaxis, :, :]

# remove outliers by clipping to 1st and 99th percentile
p1 = np.percentile(dat, 0)
p99 = np.percentile(dat, 99.9)
dat = np.clip(dat, p1, p99)


# plt.figure()
# plt.hist(baseline.ravel(), bins=1000, range=(0,100))
# plt.show()

# z2plot = [0,5,15,25,35]
z2plot = [0,5]

%matplotlib inline
plt.close('all')
fig, axes = plt.subplots(len(z2plot), 3, figsize=(15, 5), constrained_layout=True)
for i in z2plot: # for each z slice
    plt.subplot(len(z2plot), 3, z2plot.index(i)*3+1)
    plt.pcolor(gfp_mean[i,:,:])
    plt.title(f'GFP Mean Z={i}')
    # plt.colorbar()

    plt.subplot(len(z2plot), 3, z2plot.index(i)*3+2)
    plt.pcolor(rfp_mean[i,:,:])
    plt.title(f'RFP Mean Z={i}')
    # plt.colorbar()

    plt.subplot(len(z2plot), 3, z2plot.index(i)*3+3)
    plt.pcolor(dat[i,:,:])
    # add colorbar outside to the right
    plt.colorbar()
    plt.clim(0, 12)
    # make all axes not be squished
plt.show()


In [None]:
# load mask, bin
# find the fixed_mask_*.tif file in PTH directory
try:
    fixed_mask_fn = glob.glob(os.path.join(PTH, 'fixed_mask_*.tif'))[0]
except:
    fixed_mask_fn = glob.glob(os.path.join(PTH, 'fixed_mask*.tif'))[0]

mask = tifffile.imread(fixed_mask_fn)

h, w = mask.shape
h_binned = h // bin_factor
w_binned = w // bin_factor
# binning of mask
mask_binned = mask.reshape(h_binned, bin_factor, w_binned, bin_factor).max(axis=(1,3))
# shrink the mask slightly using erosion skimage
mask_binned = erosion(mask_binned, disk(5))

# zero out values outside the mask
g5_masked = g5 * mask_binned[np.newaxis,np.newaxis,:,:]



In [None]:
g5_masked.shape

In [None]:
g5_masked_trimmed  = g5_masked[:, 5:27, :,40:230]
g5_masked_trimmed.shape

In [None]:
plt.close('all')

# get the intensity over time of the voxels in the middle of the (x,y) plane, definding a box, and all z slices
middle_x = int(g5_masked_trimmed.shape[3] // 1.5)
middle_y = int(g5_masked_trimmed.shape[2] // 1.95)
box_size = 10  # Define the size of the box around the middle point
middle_voxels = g5_masked_trimmed[:, 5:27, middle_y-box_size:middle_y+box_size, middle_x-box_size:middle_x+box_size].mean(axis=(2,3))
# divide each z slice by the mean activity in the first 60 time points to get F/F0
middle_voxels = middle_voxels / (middle_voxels[:60,:].mean(axis=0) + 1e-6)

t = np.arange(middle_voxels.shape[0]) / fps

plt.close('all')

plt.figure(figsize=(10, 6))
plt.plot(middle_voxels)
plt.xlabel('Time (sec)')
plt.ylabel('$F/F_{\\mathrm{baseline}}$')
# plt.ylim(0, 5)
plt.show()
# plot same thing as heatmap and sort by mean signal across time for each z slice
sorted_indices = np.argsort(middle_voxels.mean(axis=0))
middle_voxels_sorted = middle_voxels[:, sorted_indices]

plt.figure(figsize=(10, 6))
plt.pcolor(t, np.arange(middle_voxels.shape[1]), middle_voxels_sorted.T, cmap='plasma',vmin=0,vmax=5)
plt.colorbar(label='$F/F_{\\mathrm{baseline}}$')
plt.xlabel('Time (sec)')
plt.ylabel('Z slice (sorted by mean intensity)')
plt.show()

# plot g5_masked and the box defined by box_size and middle_x, middle_y to confirm that the box is in the middle of the (x,y) plane and covers the expected area
for i in range(g5_masked_trimmed.shape[1]):
    plt.figure(figsize=(10, 4))
    plt.pcolor(g5_masked_trimmed[158,i,:,:],vmin=0, vmax=10)
    plt.plot([middle_x-box_size, middle_x+box_size, middle_x+box_size, middle_x-box_size, middle_x-box_size],
             [middle_y-box_size, middle_y-box_size, middle_y+box_size, middle_y+box_size, middle_y-box_size], color='red')
    plt.show()

# GRID ANALYSIS

In [None]:
import grid_analysis
import sys
import importlib
_ = importlib.reload(sys.modules['grid_analysis'])

PTH = r'D:\DATA\g5ht-free\20251028\date-20251028_time-1500_strain-ISg5HT_condition-starvedpatch_worm001'

xspace = 10
yspace = 10
fps = 1/0.533
bin_factor = 1
z_start = 5
z_end = 25
baseline_start_sec = 5 # start baseline calculation here (relative to the start of the recording, not relative to start_time_sec)
baseline_end_sec = 35 # end baseline calculation here (relative to the start of the recording, not relative to start_time_sec)
start_time_sec = 0
time_food_sec = 435
sys.argv = ["", PTH, xspace, yspace, fps, bin_factor, z_start, z_end, baseline_start_sec, baseline_end_sec, start_time_sec]#time_food_sec] 
grid_analysis.main()

In [None]:
import grid_analysis
_ = importlib.reload(sys.modules['grid_analysis'])
# load the grid_timeseries_flat_smoothed.npz file and plot the heatmap using the plot_heatmap function
data = np.load(os.path.join(PTH, 'grid_timeseries_flat_smoothed.npz'))
flat_timeseries_smoothed = data['timeseries']
z_labels = data['z_labels']
grid_analysis.plot_heatmap(flat_timeseries_smoothed, z_labels, fps=1/0.533, output_path=os.path.join(PTH, 'grid_heatmap.png'), data_start_time=start_time_sec)
# grid_analysis.plot_heatmap(flat_timeseries_smoothed, z_labels, fps=1/0.533, output_path=os.path.join(PTH, 'grid_heatmap.png'), data_start_time=start_time_sec)

In [None]:
flat_timeseries_smoothed.shape

In [None]:
plt.figure(figsize=(10, 6))
plt.plot(np.arange(flat_timeseries_smoothed.shape[0])/fps,flat_timeseries_smoothed[:,20:27])
# plt.ylim(0.03,0.3)
plt.show()

In [None]:
np.unique(z_labels)

In [None]:
flat_timeseries_smoothed.shape

In [None]:
_ = importlib.reload(sys.modules['grid_analysis'])
# sort the grid squares by their mean signal across time, and plot the heatmap again with the grid squares in sorted order
mean_signal_grid = flat_timeseries_smoothed.mean(axis=0)
sorted_indices = np.argsort(mean_signal_grid)
flat_timeseries_smoothed_sorted = flat_timeseries_smoothed[:, sorted_indices]

plt.close('all')

# smooth some more
from grid_analysis import causal_smooth
toplot = causal_smooth(flat_timeseries_smoothed_sorted, sigma=1.0)
grid_analysis.plot_heatmap(toplot, z_labels, fps=1/0.533, output_path=os.path.join(PTH, 'grid_heatmap_sorted.png'), data_start_time=start_time_sec)
# plt.xlim(-50,125)
# plt.ylim(50,770)

# plot the mean signal across all grid squares for each z-slice over time, and mark the time of food addition
mean_signal_grid = flat_timeseries_smoothed.mean(axis=1)
plt.figure(figsize=(10, 6))
plt.plot(np.arange(flat_timeseries_smoothed.shape[0])/fps - start_time_sec, mean_signal_grid)
plt.axvline(0, color='red', linestyle='--', label='Food addition')
plt.xlabel('Time (sec)')
plt.ylabel('Mean $F/F_{\\mathrm{baseline}}$ across grid squares')
plt.legend()
# plt.xlim(-50,125)
# plt.ylim(0.03,0.3)
plt.show()

In [None]:
flat_timeseries_smoothed.shape

In [None]:
# pca on flat_timeseries_smoothed, which of size (time,features)
from sklearn.decomposition import PCA
import utils
pca = PCA(n_components=5)
pca.fit(flat_timeseries_smoothed)
# get scores
pc_scores = pca.transform(flat_timeseries_smoothed)
explained_variance = pca.explained_variance_ratio_
explained_variance


line_colors = ['C0', 'C1', 'C2', 'C3', 'C4']
ylim = (-12,14)
for i in range(5):
    fig, ax = utils.pretty_plot()
    ax.plot(np.arange(pc_scores.shape[0])/fps, pc_scores[:,i], label=f'PC{i+1}', lw=3, color=line_colors[i])
    ax.set_ylim(ylim)
    ax.set_xlabel('Time (sec)')
    plt.ylabel(f'PC{i+1} projection')
    plt.show()

# scree plot
fig, ax = utils.pretty_plot()
ax.bar(np.arange(len(explained_variance))+1, explained_variance*100)
ax.set_xlabel('PC')
ax.set_ylabel('Variance explained (%)')
plt.show()

In [None]:
# cluster flat_timeseries_smoothed, then plot heatmap with clusters indicated
from sklearn.cluster import KMeans
n_clusters = 3
kmeans = KMeans(n_clusters=n_clusters, random_state=0)
kmeans.fit(flat_timeseries_smoothed.T)
cluster_labels = kmeans.labels_


In [None]:
plt.figure()
plt.plot(sorted_timeseries)
plt.show()

In [None]:
# sort flat_timeseries_smoothed by cluster labels
sorted_indices = np.argsort(cluster_labels)
sorted_timeseries = flat_timeseries_smoothed[:, sorted_indices]

plt, ax = utils.pretty_plot(figsize=(10,6))
im = ax.pcolormesh(sorted_timeseries.T, cmap='plasma')
plt.show()

# plot mean of each cluster over time on the same axis
i = 0
for c in range(n_clusters-1):
    cluster_mean = flat_timeseries_smoothed[:, cluster_labels == c].mean(axis=1)
    # normalize between 0 and 1
    cluster_mean = (cluster_mean - cluster_mean.min()) / (cluster_mean.max() - cluster_mean.min())
    # smooth cluste means
    # cluster_mean = pca_analysis.causal_smooth(cluster_mean[:,np.newaxis], sigma=3.0)
    if i==0:
        fig, ax = utils.pretty_plot()
    ax.plot(np.arange(cluster_mean.shape[0])/fps, cluster_mean, lw=1)
    ax.set_xlabel('Time (sec)')
    # plt.ylabel(f'Cluster {c+1} mean $F/F_{{\\mathrm{{baseline}}}}$')
    i += 1
plt.show()

# cross correlation between cluster 1 and 2 means signals
from scipy.signal import correlate
cluster1_mean = flat_timeseries_smoothed[:, cluster_labels == 0].mean(axis=1)
cluster2_mean = flat_timeseries_smoothed[:, cluster_labels == 1].mean(axis=1)
# normalize means bwetween 0 and 1
cluster1_mean = (cluster1_mean - cluster1_mean.min()) / (cluster1_mean.max() - cluster1_mean.min())
cluster2_mean = (cluster2_mean - cluster2_mean.min()) / (cluster2_mean.max() - cluster2_mean.min())
# compute cross-correlation
corr = correlate(cluster1_mean - cluster1_mean.mean(), cluster2_mean - cluster2_mean.mean(), mode='full')
lags = np.arange(-len(cluster1_mean)+1, len(cluster1_mean))
# plot xcorr
fig, ax = utils.pretty_plot()
ax.plot(lags/fps, corr)
ax.set_xlabel('Lag (sec)')
ax.set_ylabel('Cross-correlation')
plt.show()

# PCA analysis

In [None]:
import pca_analysis
import sys
_ = importlib.reload(sys.modules['pca_analysis'])

PTH = r'D:\DATA\g5ht-free\20251028\date-20251028_time-1500_strain-ISg5HT_condition-starvedpatch_worm001'

# Arguments: input_dir, fps, bin_factor, z_start, z_end, lowpass_freq, n_components,
#            n_pcs_plot, baseline_start, baseline_end, start_time, time_offset
sys.argv = ["", PTH, 1/0.533, 1, 7, 18, 0, 50, 6, 5, 35, 0]
pca_analysis.main()

# DMD analysis

In [None]:
import dmd_analysis
import sys
_ = importlib.reload(sys.modules['dmd_analysis'])

PTH = r'D:\DATA\g5ht-free\20251028\date-20251028_time-1500_strain-ISg5HT_condition-starvedpatch_worm001'

# Arguments: input_dir, fps, bin_factor, z_start, z_end, lowpass_freq, dmd_rank, 
#            n_modes_plot, baseline_start, baseline_end, start_time, time_offset
sys.argv = ["", PTH, 1/0.533, 1, 7, 18, 0, 5, 6, 5, 35, 0]

dmd_analysis.main()

# DIFFUSION MODELING

- 20251223 worm005
- choose a z slice (z=26)
- choose frames (800-1000)
- make movie
- show gfp/rfp frame for 822

- load 2x2 binned, R/R0 
- load mask and bin

## Conceptual Explanation of the Reaction-Diffusion Model

### The Physical Picture

Imagine you're watching serotonin (or another signaling molecule) being released, spreading, and getting taken back up in a worm's nervous system. The fluorescence intensity you measure at each pixel reflects this dynamic process. The code models three key phenomena:

1. **Diffusion**: Molecules spread from high to low concentration regions, like ink dropped in water
2. **Decay/Reuptake**: Molecules are removed from the extracellular space (via transporters, degradation, etc.) at a rate proportional to their local concentration
3. **Release Sources**: Certain locations periodically release molecules into the system

### The Mathematical Model

The fluorescence field $i(x,y,t)$ evolves according to a **reaction-diffusion PDE**:

$$\frac{\partial i}{\partial t} = D \nabla^2 i - k \cdot i + s(x,y,t)$$

Where:
- $D$ is the **diffusion coefficient** (how fast molecules spread spatially)
- $k$ is the **decay rate** (how fast molecules are removed)
- $s(x,y,t)$ is the **source field** (where and when molecules are released)

### Discretization on an Irregular Domain

Since you have a tissue mask (not a regular rectangle), the code builds a **graph Laplacian**. Each masked pixel becomes a node, connected to its 4 neighbors. The Laplacian operator $L$ is a sparse matrix where:
- Off-diagonal entries are +1 for connected pixels
- Diagonal entries are −(number of neighbors)

This naturally implements **no-flux boundary conditions**: at the mask edge, there are simply fewer neighbors, so diffusion stops there automatically.

### Time Stepping (Semi-Implicit)

To advance the simulation forward in time, the code uses a **backward Euler** (or optionally Crank-Nicolson) scheme. Instead of naively computing "next = current + dt × derivative" (which can blow up), it solves:

$$(I - dt \cdot A) \cdot i_{t+1} = i_t + dt \cdot s_t$$

where $A = D \cdot L - k \cdot I$. This requires solving a linear system at each time step, but the matrix is **sparse** and gets **LU-factorized once**, making subsequent solves very fast.

### Learning the Source Field from Data

The source term $s(x,y,t)$ is unknown. The code learns it from the movie itself using **Non-negative Matrix Factorization (NMF)**:

1. **Preprocessing**: High-pass filter the movie temporally (remove slow baseline drift), then rectify (keep only positive values, since we're interested in "release-like" transient increases)

2. **NMF Decomposition**: Factor the preprocessed activity as:
   $$A \approx W \cdot H$$
   where $W$ is (time × M) and $H$ is (M × pixels). This finds M spatial patterns that, when weighted by time-varying amplitudes, reconstruct the activity.

3. **Interpretation**: 
   - $\Phi = H^T$ (N×M) are the **spatial source maps** — where release happens
   - $a_t = W$ (T×M) are the **time courses** — when each source is active

The source field is then: $s(t) = g \cdot \Phi \cdot a_t$, where $g$ is a global gain parameter.

### The Fitting Pipeline

The full fitting process is **staged**:

1. **Learn sources** (Φ, a_t) via NMF on the preprocessed movie — this identifies *where* and *when* release happens

2. **Fit parameters** (D, k, g) with sources fixed — the optimizer simulates the model with candidate parameters and minimizes the mismatch to observed data

3. **Optionally refine** the time courses with temporal smoothness regularization

### Parameter Fitting Details

The optimizer (L-BFGS-B) searches for D, k, g that minimize mean squared error between simulated and observed fluorescence. To enforce positivity, parameters are represented in log-space: optimizing $\log(D)$ ensures $D > 0$.

For efficiency with large images:
- The Laplacian is **sparse** (only ~4N nonzero entries for N pixels)
- The implicit matrix is **factorized once** per (D,k) pair
- Loss is computed on a **random subset** of pixels

### What the Results Tell You

- **D (pixels²/s)**: How fast the signal spreads spatially. Larger D means faster diffusion. Remember to scale by pixel size² if comparing across binning levels.
  - If your original pixel size is s (µm/pixel), and you bin by factor b, then fitted Dbinned​ relates to physical Dμm roughly by:
    - Dμm2/s≈Dbinned⋅(b s)^2

- **k (1/s)**: How fast the signal decays. The "half-life" is $\ln(2)/k$ seconds.

- **Φ (spatial maps)**: Where the sources are located. Each column is a probability-like distribution over pixels.

- **a_t (time courses)**: When each source is active. Peaks indicate release events.

- **R²**: Overall fit quality — how much variance the model explains.

### Key Design Choices

1. **Phenomenological, not biophysical**: The model fits the *fluorescence* dynamics directly, not the underlying concentration. This avoids needing to know the sensor's binding kinetics.

2. **Data-driven sources**: Rather than assuming source locations, NMF discovers them from the data's spatio-temporal structure.

3. **Implicit time stepping**: Ensures numerical stability even with large D or dt.

4. **Modular pipeline**: Each stage can be tuned independently (e.g., adjust NMF preprocessing, change number of sources, use Crank-Nicolson for better accuracy).

### REGULARIZATION

1. **L1 sparsity on source time courses** - encourages bursts rather than continuous activity
2. **Total Variation (TV) on sources** - penalizes frequent changes, promoting sparse "events"  
3. **D prior/regularization** - soft prior pushing D away from zero
4. **Source energy penalty** - limits total source "power" to force more diffusion 


## Regularization Added to diffusion_v2.py

I've added optional regularization to prevent D from collapsing to near-zero. The key additions:

### `RegularizationConfig` dataclass
Located around line 900, this controls all regularization terms:

```python
reg_config = RegularizationConfig(
    # Source sparsity (temporal)
    source_l1_weight=0.1,    # L1 penalty: ∑|a_t| — promotes sparse firing
    source_tv_weight=0.05,   # TV penalty: ∑|a_t - a_{t-1}| — promotes discrete events
    source_energy_weight=0.0, # L2 penalty: ∑a_t² — limits total source power
    
    # D priors (prevent collapse)
    D_min_penalty_weight=0.001,  # 1/D penalty — blows up as D→0
    D_prior_weight=0.5,          # Gaussian prior on log(D)
    D_prior_mean=0.05,           # Prior mean for D
    
    # Source gain limit
    g_max_penalty_weight=0.0,    # Penalty if g exceeds threshold
    g_max_threshold=5.0          # Threshold for g penalty
)
```

### How It Works

The regularized loss is:

$$
\mathcal{L}_{\text{reg}} = \mathcal{L}_{\text{data}} + \lambda_1 \|a_t\|_1 + \lambda_{TV} \text{TV}(a_t) + \frac{\alpha}{D} + \beta (\log D - \log D_0)^2
$$

| Term | Effect |
|------|--------|
| `source_l1_weight` | Forces sources to be sparse in time (few frames active) |
| `source_tv_weight` | Forces sources to have discrete on/off events |
| `D_min_penalty_weight` | $\frac{1}{D}$ penalty prevents D→0 |
| `D_prior_weight` | Soft prior pulling D toward `D_prior_mean` |

### Usage

```python
from diffusion_v2 import RegularizationConfig, fit_diffusion_model_2d

reg_config = RegularizationConfig(
    source_l1_weight=0.1,
    D_prior_weight=0.5,
    D_prior_mean=0.05  # Your expected D value
)

result = fit_diffusion_model_2d(
    Y, mask, dt,
    reg_config=reg_config,  # Pass the config
    ...
)
```

### Tuning Tips

1. **If D still collapses**: Increase `D_min_penalty_weight` or `D_prior_weight`
2. **If sources are too active**: Increase `source_l1_weight` 
3. **If sources are noisy/oscillatory**: Increase `source_tv_weight`
4. **Set `D_prior_mean`** to your expected diffusion coefficient from physics/calibration

Made changes.



## SAVE TRIMMED DATA

In [None]:
import sys
import os
import quantify_voxels
from numpy import genfromtxt
import matplotlib.pyplot as plt
import numpy as np
import tifffile
from skimage.morphology import erosion, disk
import importlib
import glob

import matplotlib
%matplotlib widget
font = {'family' : 'DejaVu Sans',
        'weight' : 'normal',
        'size'   : 15}
matplotlib.rc('font', **font)

_ = importlib.reload(sys.modules['quantify_voxels'])

PTH = r'D:\DATA\g5ht-free\20251223\date-20251223_strain-ISg5HT_condition-starvedpatch_worm005'
bin_factor = 2
fps = 1/0.533
frames = (800,1000)  # only analyze frames from 800 to 1000
z = 26
zero_prob_threshold = 0.3 # remove voxels that are zero more than this fraction of the time


# load normalized_voxels.npy (R/R0), already been binned by bin_factor
g5 = np.load(os.path.join(PTH, 'normalized_voxels.npy'))
# load mask, bin it
fixed_mask_fn = glob.glob(os.path.join(PTH, 'fixed_mask_*.tif'))[0]
mask = tifffile.imread(fixed_mask_fn)
h, w = mask.shape
h_binned = h // bin_factor
w_binned = w // bin_factor
mask_binned = mask.reshape(h_binned, bin_factor, w_binned, bin_factor).max(axis=(1,3))
# mask g5 data
g5_trimmed = g5[frames[0]:frames[1], z, :, :]

# remove voxels that are only present a small percentage of the time
# find percentage of time a voxel is zero within the worm mask
g5_masked = g5 * mask_binned[np.newaxis, :, :]

# g5_masked is shape (T, Z, H, W) array with zeros outside the worm region
# find all voxels where g5_masked is zero for each time point
zero_mask = g5_masked == 0
# calculate the probability of a voxel being zero across time, but don't include time points where the voxel is masked out (i.e., outside the worm)
zero_prob = np.mean(zero_mask, axis=0)
# set zero probability to NaN for voxels outside the worm (mask binned needs a z dimension added)
zero_prob[mask_binned[np.newaxis].repeat(zero_prob.shape[0], axis=0) == 0] = np.nan
good_voxel_mask = zero_prob < zero_prob_threshold
g5 = g5_masked * good_voxel_mask[np.newaxis].repeat(g5_masked.shape[0], axis=0)

# save g5_trimmed and mask_binned to a npy file
np.savez(os.path.join(PTH, f'g5_trimmed_frames{frames[0]}to{frames[1]}.npz'), g5=g5, mask_binned=mask_binned, zero_prob=zero_prob)


## LOAD DATA

In [25]:
import sys
import os
import quantify_voxels
from numpy import genfromtxt
import matplotlib.pyplot as plt
import numpy as np
import tifffile
from skimage.morphology import erosion, disk
import importlib
import glob

import matplotlib
%matplotlib widget
font = {'family' : 'DejaVu Sans',
        'weight' : 'normal',
        'size'   : 15}
matplotlib.rc('font', **font)

_ = importlib.reload(sys.modules['quantify_voxels'])

PTH = r'D:\DATA\g5ht-free\20251223\date-20251223_strain-ISg5HT_condition-starvedpatch_worm005'
bin_factor = 2
fps = 1/0.533
# frames = (800,1000)  # only analyze frames from 800 to 1000 (these time points correspond to start of food encounter and two serotonin cycles)
frames = (800,1199)
z = 26
rfp_thresh = 50 # only include voxels where rfp_mean is above this threshold
keep_width = (25,228) # only keep this x range before

# load g5_trimmed
data = np.load(os.path.join(PTH, f'g5_trimmed_frames.npz'))
g5 = data['g5'][frames[0]:frames[1], z, :, :]  # g5 is R/R0 for frames 800 to 1000
mask = data['mask_binned']
print(g5.shape)
print(mask.shape)

# load rfp_mean.npy
rfp_mean = np.load(os.path.join(PTH, 'rfp_mean.npy'))
print(rfp_mean.shape)
rfp_mean_z = rfp_mean[z,:,:] # HW

# update mask based on rfp_thresh
mask_updated = mask.copy()
mask_updated[rfp_mean_z < rfp_thresh] = 0

# g5 is R/R0 for a single z slice, for frames 800 to 1000
# mask_updated is the segmentation mask updated based on rfp_thresh
g5 = g5 * mask_updated[np.newaxis, :, :]

# trim width in g5 and mask_updated
g5 = g5[:, :, keep_width[0]:keep_width[1]]
mask_updated = mask_updated[:, keep_width[0]:keep_width[1]]
print('Number of voxels in updated mask:', np.sum(mask_updated > 0))

raise NotImplementedError('Need to implement temporal smoothing before moving on.')


(399, 100, 250)
(100, 250)
(39, 100, 250)
Number of voxels in updated mask: 8409


In [14]:
%matplotlib qt
plt.figure()
plt.hist(rfp_mean[z,:,:].ravel(), bins=100)
plt.show()

## TEST POINT SOURCE IDENTIFICATION

Identify candidate source voxels for point source diffusion model.

In [31]:
def identify_point_sources(Y, mask, fps=1/0.533, n_sources=10, 
                          methods=['variance', 'early_response', 'peak_detection', 'gradient'],
                          early_window=(10, 60), baseline_window=(0, 60),
                          spatial_exclusion_radius=5, verbose=True):
    """
    Identify candidate point source locations from fluorescence movie.
    
    Parameters
    ----------
    Y : (T, H, W) array
        Fluorescence movie (already F/F0 or R/R0).
    mask : (H, W) bool array
        Tissue mask.
    fps : float
        Frames per second.
    n_sources : int
        Number of candidate sources to identify per method.
    methods : list of str
        Which detection methods to use. Options:
        - 'variance': High temporal variance voxels
        - 'early_response': Voxels that respond early
        - 'peak_detection': Voxels with strong transient peaks
        - 'gradient': Spatial gradient maxima (likely centers of activity)
    early_window : tuple (start_frame, end_frame)
        Time window to detect early responders.
    baseline_window : tuple (start_frame, end_frame)
        Baseline period for computing response magnitude.
    spatial_exclusion_radius : int
        Minimum distance (pixels) between detected sources.
    verbose : bool
        Print diagnostic info.
    
    Returns
    -------
    sources_dict : dict
        Dictionary with keys as method names, values as (M, 2) arrays of (row, col) coords.
    metrics_dict : dict
        Dictionary with metric values for each detected source.
    """
    T, H, W = Y.shape
    mask = mask.astype(bool)
    
    # Mask the data
    Y_masked = Y * mask[np.newaxis, :, :]
    
    sources_dict = {}
    metrics_dict = {}
    
    # =========================================================================
    # METHOD 1: TEMPORAL VARIANCE
    # =========================================================================
    if 'variance' in methods:
        # Compute temporal variance for each voxel (only where mask is True)
        var_map = np.var(Y_masked, axis=0)
        var_map[~mask] = 0  # Zero out outside mask
        
        # Find top n_sources local maxima
        sources_var = _find_local_maxima_2d(var_map, n_sources, 
                                            exclusion_radius=spatial_exclusion_radius,
                                            mask=mask)
        sources_dict['variance'] = sources_var
        metrics_dict['variance'] = var_map[sources_var[:, 0], sources_var[:, 1]]
        
        if verbose:
            print(f"Variance method: found {len(sources_var)} sources")
            print(f"  Variance range: {metrics_dict['variance'].min():.3f} - {metrics_dict['variance'].max():.3f}")
    
    # =========================================================================
    # METHOD 2: EARLY RESPONDERS
    # =========================================================================
    if 'early_response' in methods:
        # Compute baseline
        baseline = np.mean(Y_masked[baseline_window[0]:baseline_window[1]], axis=0)
        baseline[baseline == 0] = 1  # Avoid division by zero
        
        # Compute response magnitude in early window
        early_mean = np.mean(Y_masked[early_window[0]:early_window[1]], axis=0)
        response_mag = (early_mean - baseline) / baseline
        response_mag[~mask] = -np.inf  # Exclude outside mask
        
        # Find voxels with strongest early response
        sources_early = _find_local_maxima_2d(response_mag, n_sources,
                                              exclusion_radius=spatial_exclusion_radius,
                                              mask=mask)
        sources_dict['early_response'] = sources_early
        metrics_dict['early_response'] = response_mag[sources_early[:, 0], sources_early[:, 1]]
        
        if verbose:
            print(f"Early response method: found {len(sources_early)} sources")
            print(f"  Response magnitude range: {metrics_dict['early_response'].min():.3f} - {metrics_dict['early_response'].max():.3f}")
    
    # =========================================================================
    # METHOD 3: PEAK DETECTION (transient activity)
    # =========================================================================
    if 'peak_detection' in methods:
        from scipy.signal import find_peaks
        
        # For each voxel, count number of significant peaks
        peak_count_map = np.zeros((H, W))
        peak_height_map = np.zeros((H, W))
        
        for i in range(H):
            for j in range(W):
                if not mask[i, j]:
                    continue
                trace = Y_masked[:, i, j]
                # Find peaks with minimum prominence
                peaks, properties = find_peaks(trace, prominence=0.3, distance=10)
                peak_count_map[i, j] = len(peaks)
                if len(peaks) > 0:
                    peak_height_map[i, j] = np.max(properties['prominences'])
        
        # Use peak height as metric
        peak_height_map[~mask] = 0
        sources_peak = _find_local_maxima_2d(peak_height_map, n_sources,
                                             exclusion_radius=spatial_exclusion_radius,
                                             mask=mask)
        sources_dict['peak_detection'] = sources_peak
        metrics_dict['peak_detection'] = peak_height_map[sources_peak[:, 0], sources_peak[:, 1]]
        
        if verbose:
            print(f"Peak detection method: found {len(sources_peak)} sources")
            print(f"  Peak height range: {metrics_dict['peak_detection'].min():.3f} - {metrics_dict['peak_detection'].max():.3f}")
    
    # =========================================================================
    # METHOD 4: SPATIAL GRADIENT (centers of activity)
    # =========================================================================
    if 'gradient' in methods:
        # Compute mean intensity over time
        mean_intensity = np.mean(Y_masked, axis=0)
        
        # Compute Laplacian (regions where intensity is locally maximal)
        from scipy.ndimage import laplace
        laplacian = -laplace(mean_intensity)  # Negative because we want local maxima
        laplacian[~mask] = -np.inf
        
        sources_grad = _find_local_maxima_2d(laplacian, n_sources,
                                             exclusion_radius=spatial_exclusion_radius,
                                             mask=mask)
        sources_dict['gradient'] = sources_grad
        metrics_dict['gradient'] = laplacian[sources_grad[:, 0], sources_grad[:, 1]]
        
        if verbose:
            print(f"Gradient method: found {len(sources_grad)} sources")
            print(f"  Laplacian range: {metrics_dict['gradient'].min():.3f} - {metrics_dict['gradient'].max():.3f}")
    
    return sources_dict, metrics_dict


def _find_local_maxima_2d(metric_map, n_sources, exclusion_radius=5, mask=None):
    """
    Find n_sources local maxima in metric_map with spatial exclusion.
    
    Returns (n, 2) array of (row, col) coordinates.
    """
    from scipy.ndimage import maximum_filter
    
    if mask is not None:
        metric_map = metric_map.copy()
        metric_map[~mask] = -np.inf
    
    # Find local maxima using maximum filter
    local_max = maximum_filter(metric_map, size=exclusion_radius) == metric_map
    local_max[metric_map == -np.inf] = False
    local_max[np.isnan(metric_map)] = False
    
    # Get coordinates and values
    coords = np.argwhere(local_max)
    values = metric_map[local_max]
    
    # Sort by value (descending)
    sorted_idx = np.argsort(values)[::-1]
    coords_sorted = coords[sorted_idx]
    
    # Take top n_sources
    return coords_sorted[:n_sources]


def visualize_point_sources(Y, mask, sources_dict, metrics_dict, 
                            time_point=0, figsize=(18, 12)):
    """
    Visualize identified point sources overlaid on the image.
    
    Parameters
    ----------
    Y : (T, H, W) array
    mask : (H, W) bool
    sources_dict : dict
        Output from identify_point_sources.
    metrics_dict : dict
        Metrics from identify_point_sources.
    time_point : int
        Which time frame to show.
    figsize : tuple
    """
    import matplotlib.pyplot as plt
    from matplotlib.patches import Circle
    
    n_methods = len(sources_dict)
    
    fig, axes = plt.subplots(2, n_methods, figsize=figsize, constrained_layout=True)
    if n_methods == 1:
        axes = axes.reshape(2, 1)
    
    # Row 1: Sources overlaid on single time frame
    # Row 2: Sources overlaid on temporal mean
    
    mean_img = np.mean(Y, axis=0)
    frame_img = Y[time_point]
    
    colors = plt.cm.tab10(np.linspace(0, 1, 10))
    
    for col_idx, (method, sources) in enumerate(sources_dict.items()):
        # Top row: single frame
        ax1 = axes[0, col_idx]
        ax1.imshow(frame_img, cmap='gray', vmin=0, vmax=np.percentile(frame_img[mask], 99))
        
        for i, (row, col) in enumerate(sources):
            circle = Circle((col, row), radius=3, color=colors[i % 10], 
                          fill=False, linewidth=2, label=f'S{i+1}')
            ax1.add_patch(circle)
            ax1.plot(col, row, 'x', color=colors[i % 10], markersize=8, markeredgewidth=2)
        
        ax1.set_title(f'{method.replace("_", " ").title()}\n(Frame {time_point})')
        ax1.axis('off')
        
        # Bottom row: temporal mean
        ax2 = axes[1, col_idx]
        ax2.imshow(mean_img, cmap='gray', vmin=0, vmax=np.percentile(mean_img[mask], 99))
        
        for i, (row, col) in enumerate(sources):
            circle = Circle((col, row), radius=3, color=colors[i % 10], 
                          fill=False, linewidth=2)
            ax2.add_patch(circle)
            ax2.plot(col, row, 'x', color=colors[i % 10], markersize=8, markeredgewidth=2)
        
        ax2.set_title('Temporal Mean')
        ax2.axis('off')
    
    plt.suptitle('Identified Point Source Candidates', fontsize=16, y=0.98)
    plt.show()
    
    # Also plot the time traces for each source
    fig2, axes2 = plt.subplots(n_methods, 1, figsize=(12, 3*n_methods), constrained_layout=True)
    if n_methods == 1:
        axes2 = [axes2]
    
    t = np.arange(Y.shape[0]) / (1/0.533)  # Assume default fps
    
    for ax_idx, (method, sources) in enumerate(sources_dict.items()):
        ax = axes2[ax_idx]
        
        for i, (row, col) in enumerate(sources):
            trace = Y[:, row, col]
            ax.plot(t, trace, label=f'S{i+1} ({row},{col})', alpha=0.7, linewidth=1.5)
        
        ax.set_xlabel('Time (s)')
        ax.set_ylabel('R/R₀')
        ax.set_title(f'{method.replace("_", " ").title()} - Source Time Traces')
        ax.legend(ncol=3, fontsize=8, frameon=False)
        ax.grid(alpha=0.3)
    
    plt.show()
    
    # Print metric values
    print("\n" + "="*60)
    print("SOURCE METRICS")
    print("="*60)
    for method, sources in sources_dict.items():
        print(f"\n{method.upper()}:")
        for i, (row, col) in enumerate(sources):
            metric_val = metrics_dict[method][i]
            print(f"  Source {i+1}: ({row:3d}, {col:3d}) - metric = {metric_val:.4f}")


In [32]:
# Run point source identification on the loaded data
%matplotlib qt
plt.close('all')

# Identify candidate sources using multiple methods
sources_dict, metrics_dict = identify_point_sources(
    g5, 
    mask_updated, 
    fps=fps,
    n_sources=4,  # Find top 4 candidates per method
    methods=['variance', 'early_response', 'peak_detection', 'gradient'],
    early_window=(10, 60),  # frames 10-60 for early response
    baseline_window=(0, 30),  # frames 0-30 for baseline
    spatial_exclusion_radius=20,  # min 8 pixels between sources
    verbose=True
)

# Visualize the identified sources
visualize_point_sources(
    g5, 
    mask_updated, 
    sources_dict, 
    metrics_dict,
    time_point=100,  # show sources on frame 100
    figsize=(20, 10)
)

Variance method: found 4 sources
  Variance range: 8.658 - 18.887
Early response method: found 4 sources
  Response magnitude range: 0.827 - 1.193
Peak detection method: found 4 sources
  Peak height range: 13.800 - 15.702
Gradient method: found 4 sources
  Laplacian range: 5.691 - 7.860

SOURCE METRICS

VARIANCE:
  Source 1: ( 74, 170) - metric = 18.8867
  Source 2: ( 68, 157) - metric = 16.3984
  Source 3: ( 40, 142) - metric = 8.7281
  Source 4: ( 44, 113) - metric = 8.6583

EARLY_RESPONSE:
  Source 1: ( 80, 193) - metric = 1.1931
  Source 2: ( 43,  18) - metric = 1.1505
  Source 3: ( 47, 168) - metric = 1.0452
  Source 4: ( 44, 114) - metric = 0.8265

PEAK_DETECTION:
  Source 1: ( 74, 169) - metric = 15.7023
  Source 2: ( 68, 157) - metric = 15.4045
  Source 3: ( 44, 113) - metric = 14.4816
  Source 4: ( 40, 142) - metric = 13.8001

GRADIENT:
  Source 1: ( 74, 169) - metric = 7.8595
  Source 2: ( 43, 113) - metric = 6.9809
  Source 3: ( 47, 169) - metric = 6.6254
  Source 4: ( 30, 

## POINT SOURCE DIFFUSION MODEL (diffusion_v3)

Uses fixed point sources instead of NMF-learned spatial maps.
This addresses the identifiability problem where D → 0.

In [33]:
import diffusion_v3 as diff3
_ = importlib.reload(sys.modules['diffusion_v3'])

# Run smoke test to verify the module works
diff3.run_smoke_test(verbose=True)

Running diffusion_v3 smoke test...
Variance method: found 5 sources
  Variance range: 0.005 - 0.051
Early response method: found 5 sources
  Response magnitude range: 0.044 - 0.055
Peak detection method: found 5 sources
  Peak height range: 0.305 - 0.972
Gradient method: found 5 sources
  Laplacian range: 0.585 - 1.039
Detected 3 sources


INFO: Fitted: D=50.0000, k=0.0100, g=50.0000, MSE=0.003841, R²=-0.2059


Fitted: D=50.0000, k=0.0100, g=50.0000
MSE=0.003841, R²=-0.2059
Smoke test PASSED ✓


True

In [36]:
# Detect point sources using diffusion_v3
_ = importlib.reload(sys.modules['diffusion_v3'])

%matplotlib qt
plt.close('all')

# Configure source detection
source_config = diff3.SourceDetectionConfig(
    n_sources=3,
    methods=('variance', 'early_response', 'peak_detection', 'gradient'),
    early_window=(10, 60),
    baseline_window=(0, 30),
    spatial_exclusion_radius=25,
    peak_prominence=0.3,
    peak_distance=10
)

# Detect sources
sources_dict_v3, metrics_dict_v3 = diff3.identify_point_sources(
    g5, mask_updated, config=source_config, verbose=True
)

# Visualize all methods
diff3.visualize_point_sources(
    g5, mask_updated, sources_dict_v3, metrics_dict_v3,
    time_point=100, fps=fps, figsize=(20, 10)
)

Variance method: found 3 sources
  Variance range: 8.658 - 18.887
Early response method: found 3 sources
  Response magnitude range: 1.045 - 1.193
Peak detection method: found 3 sources
  Peak height range: 13.800 - 15.702
Gradient method: found 3 sources
  Laplacian range: 6.625 - 7.860


In [37]:
# Combine sources from different methods and visualize final selection
_ = importlib.reload(sys.modules['diffusion_v3'])

# Combine sources - use 'variance' method as primary (usually most reliable)
# Other options: 'union', 'intersection', 'weighted'
combined_sources = diff3.combine_source_candidates(
    sources_dict_v3, metrics_dict_v3,
    method='variance',  # Use variance-based sources
    max_sources=3,  # Limit to 3 sources
    exclusion_radius=25
)

print(f"Combined sources: {len(combined_sources)} locations")
print("Source coordinates (row, col):")
for i, (r, c) in enumerate(combined_sources):
    print(f"  Source {i+1}: ({r}, {c})")

# Visualize combined sources
diff3.visualize_combined_sources(
    g5, mask_updated, combined_sources,
    fps=fps, time_point=100, figsize=(14, 5)
)

Combined sources: 3 locations
Source coordinates (row, col):
  Source 1: (74, 170)
  Source 2: (40, 142)
  Source 3: (44, 113)


In [38]:
# Fit the point source diffusion model
_ = importlib.reload(sys.modules['diffusion_v3'])

import logging
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')

# Fit with joint optimization of a(t), D, k, g
result_v3 = diff3.fit_point_source_model(
    g5, 
    mask_updated.astype(bool),
    combined_sources,
    dt=0.533,
    D0=1.0,       # Initial guess for D (should be higher now!)
    k0=0.5,       # Initial guess for k
    g0=1.0,       # Initial guess for g
    fit_a_t=True, # Jointly optimize time courses
    loss_subsample=30000,
    max_iter=100,
    verbose=True
)

print("\n" + "="*60)
print("POINT SOURCE MODEL RESULTS")
print("="*60)
print(f"Diffusion coefficient D = {result_v3['D']:.4f} pixels²/s")
print(f"Decay rate k = {result_v3['k']:.4f} s⁻¹")
print(f"Source gain g = {result_v3['g']:.4f}")
print(f"\nFit quality:")
print(f"  MSE = {result_v3['mse']:.6f}")
print(f"  R² = {result_v3['r_squared']:.4f}")

INFO: 
=== Outer iteration 1 ===
INFO: Optimizing D, k, g...
INFO: Fitted: D=50.0000, k=0.0100, g=50.0000, MSE=1.128625, R²=-0.1912
INFO: Optimizing source time courses...
INFO:   D=50.0000, k=0.0100, g=50.0000, MSE=1.407616
  if abs(prev_mse - mse) / (prev_mse + 1e-10) < 1e-4:
INFO: 
=== Outer iteration 2 ===
INFO: Optimizing D, k, g...
INFO: Fitted: D=2.0136, k=0.0100, g=50.0000, MSE=1.340011, R²=-0.4143
INFO: Optimizing source time courses...
INFO:   D=2.0136, k=0.0100, g=50.0000, MSE=1.338758
INFO: 
=== Outer iteration 3 ===
INFO: Optimizing D, k, g...
INFO: Fitted: D=2.0495, k=0.0100, g=50.0000, MSE=1.338746, R²=-0.4129
INFO: Optimizing source time courses...
INFO:   D=2.0495, k=0.0100, g=50.0000, MSE=1.338744
INFO: Converged!
INFO: 
Final: D=2.0495, k=0.0100, g=50.0000, MSE=1.338744, R²=-0.4129



POINT SOURCE MODEL RESULTS
Diffusion coefficient D = 2.0495 pixels²/s
Decay rate k = 0.0100 s⁻¹
Source gain g = 50.0000

Fit quality:
  MSE = 1.338744
  R² = -0.4129


In [39]:
# Visualize source activity (locations and time courses)
_ = importlib.reload(sys.modules['diffusion_v3'])

%matplotlib qt
plt.close('all')

# Create a PointSourceResult-like object for visualization
# (since fit_point_source_model returns a dict, we need to construct the visualization manually)

# Source locations and time courses
fig, axes = plt.subplots(1, 2, figsize=(14, 5), constrained_layout=True)

# Left: Source locations on mean image
mean_img = np.mean(g5, axis=0)
vmax = np.percentile(mean_img[mask_updated], 99) if mask_updated.sum() > 0 else mean_img.max()

axes[0].imshow(mean_img, cmap='gray', vmin=0, vmax=vmax)

colors = plt.cm.tab10(np.linspace(0, 1, len(combined_sources)))
from matplotlib.patches import Circle
for i, (row, col) in enumerate(combined_sources):
    circle = Circle((col, row), radius=4, color=colors[i], fill=False, linewidth=2)
    axes[0].add_patch(circle)
    axes[0].text(col + 5, row, f'{i+1}', color=colors[i], fontsize=10, fontweight='bold')

axes[0].set_title('Source Locations')
axes[0].axis('off')

# Right: Time courses
t = np.arange(result_v3['a_t'].shape[0]) * 0.533
for i in range(result_v3['a_t'].shape[1]):
    axes[1].plot(t, result_v3['a_t'][:, i], label=f'Source {i+1}', 
                color=colors[i], linewidth=1.5)

axes[1].set_xlabel('Time (s)')
axes[1].set_ylabel('Source Activity')
axes[1].set_title('Fitted Source Time Courses')
axes[1].legend(ncol=2, frameon=False)
axes[1].grid(alpha=0.3)

plt.suptitle(f'Point Source Model\nD={result_v3["D"]:.3f} px²/s, k={result_v3["k"]:.3f} s⁻¹, R²={result_v3["r_squared"]:.3f}', 
             fontsize=12)
plt.show()

In [40]:
# Compare observed vs reconstructed at selected time points
_ = importlib.reload(sys.modules['diffusion_v3'])

%matplotlib qt
plt.close('all')

time_points = [0, 50, 100, 150, 199]  # Select time points to visualize
n_times = len(time_points)

fig, axes = plt.subplots(3, n_times, figsize=(15, 10), constrained_layout=True)

Ihat = result_v3['Ihat']  # (T+1, N) reconstruction

for col, t in enumerate(time_points):
    # Row 1: Observed
    obs = g5[t]
    vmax = np.percentile(obs[mask_updated], 99) if mask_updated.sum() > 0 else obs.max()
    axes[0, col].imshow(obs, cmap='viridis', vmin=0, vmax=vmax)
    axes[0, col].set_title(f't = {t * 0.533:.1f}s')
    axes[0, col].axis('off')
    
    # Row 2: Reconstructed
    recon = np.zeros_like(obs)
    recon[mask_updated] = Ihat[t+1]  # t+1 because Ihat[0] is initial condition
    axes[1, col].imshow(recon, cmap='viridis', vmin=0, vmax=vmax)
    axes[1, col].axis('off')
    
    # Row 3: Residual
    resid = np.zeros_like(obs)
    resid[mask_updated] = obs[mask_updated] - Ihat[t+1]
    vlim = np.percentile(np.abs(resid[mask_updated]), 95) if mask_updated.sum() > 0 else 1
    axes[2, col].imshow(resid, cmap='RdBu_r', vmin=-vlim, vmax=vlim)
    axes[2, col].axis('off')

axes[0, 0].set_ylabel('Observed', fontsize=12)
axes[1, 0].set_ylabel('Reconstructed', fontsize=12)
axes[2, 0].set_ylabel('Residual', fontsize=12)

plt.suptitle(f'Point Source Model Reconstruction\n'
             f'D={result_v3["D"]:.3f} px²/s, k={result_v3["k"]:.3f} s⁻¹, g={result_v3["g"]:.3f}, '
             f'R²={result_v3["r_squared"]:.3f}', fontsize=14)
plt.show()

In [54]:
# Use the full pipeline function for convenience
_ = importlib.reload(sys.modules['diffusion_v3'])

# This runs the complete pipeline: detection → fitting
result_full = diff3.fit_diffusion_point_sources(
    g5,
    mask_updated.astype(bool),
    dt=0.533,
    n_sources=20,
    bin_factors=(2, 2),  #  additional binning on top of existing (2,2)
    source_detection_config=diff3.SourceDetectionConfig(
        n_sources=8,
        methods=('variance', 'early_response', 'peak_detection'),
        spatial_exclusion_radius=8
    ),
    source_combine_method='variance',
    D0=0.01,
    k0=1.0,
    g0=1.0,
    fit_a_t=True,
    loss_subsample=30000,
    verbose=True
)

print("\n" + "="*60)
print("FULL PIPELINE RESULTS")
print("="*60)
print(f"Number of sources: {result_full.n_sources}")
print(f"Diffusion coefficient D = {result_full.D:.4f} pixels²/s")
print(f"Decay rate k = {result_full.k:.4f} s⁻¹")
print(f"Source gain g = {result_full.g:.4f}")
print(f"\nFit quality:")
print(f"  MSE = {result_full.mse:.6f}")
print(f"  R² = {result_full.r_squared:.4f}")

INFO: Binning movie 2x2...
INFO: Detecting point sources...
INFO: Using 8 point sources
INFO: Fitting point source diffusion model...
INFO: 
=== Outer iteration 1 ===
INFO: Optimizing D, k, g...


Variance method: found 8 sources
  Variance range: 2.351 - 15.455
Early response method: found 8 sources
  Response magnitude range: 0.491 - 1.056
Peak detection method: found 8 sources
  Peak height range: 6.725 - 12.832


INFO: Fitted: D=50.0000, k=0.0163, g=21.3328, MSE=0.754277, R²=-0.0048
INFO: Optimizing source time courses...
INFO:   D=50.0000, k=0.0163, g=21.3328, MSE=0.748770
  if abs(prev_mse - mse) / (prev_mse + 1e-10) < 1e-4:
INFO: 
=== Outer iteration 2 ===
INFO: Optimizing D, k, g...
INFO: Fitted: D=50.0000, k=0.0125, g=14.7717, MSE=0.727345, R²=0.0311
INFO: Optimizing source time courses...
INFO:   D=50.0000, k=0.0125, g=14.7717, MSE=0.860590
INFO: 
=== Outer iteration 3 ===
INFO: Optimizing D, k, g...
INFO: Fitted: D=50.0000, k=0.0115, g=9.9734, MSE=0.726696, R²=0.0320
INFO: Optimizing source time courses...


KeyboardInterrupt: 

In [50]:
# Visualize results using built-in functions
_ = importlib.reload(sys.modules['diffusion_v3'])

%matplotlib qt
plt.close('all')

# Visualize fit quality
diff3.visualize_fit_result(result_full, time_points=[0, 50, 100, 150, 199], figsize=(15, 10))

# Visualize source activity
diff3.visualize_source_activity(result_full, figsize=(14, 6))

In [47]:
# Visualize how diffusion spreads from a single source (impulse response)
_ = importlib.reload(sys.modules['diffusion_v3'])

%matplotlib qt
plt.close('all')

# Show diffusion spread from source 0
diff3.plot_diffusion_spread(
    result_full, 
    source_idx=0, 
    time_after_pulse=[0.5, 1.0, 2.0, 5.0, 20.0],
    figsize=(16, 4)
)

# Also try source 1 if it exists
if result_full.n_sources > 1:
    diff3.plot_diffusion_spread(
        result_full, 
        source_idx=1, 
        time_after_pulse=[0.5, 1.0, 2.0, 5.0],
        figsize=(16, 4)
    )

## NMF SOURCE DIFFUSION MODEL

### SMOKE TEST

In [None]:
import diffusion_v2 as diff
_ = importlib.reload(sys.modules['diffusion_v2'])

diff.run_smoke_test()

### FIT MODEL

In [15]:
# ============================================================
# IMPROVED DIFFUSION MODEL FITTING (diffusion_v2)
# ============================================================
# Uses the new diffusion_v2 module with:
# - Vectorized Laplacian construction with validation
# - Cached LU factorization for efficient simulation
# - Cleaner NMF source learning with preprocessing options
# - Staged fitting: learn sources → fit D,k,g → optional refinement
# - Comprehensive result dataclass with diagnostics

import diffusion_v2 as diff
_ = importlib.reload(sys.modules['diffusion_v2'])


import logging
# Enable logging to see progress
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')

# Fit the reaction-diffusion model
# Y: (T, H, W) fluorescence movie (already F/F0 or ΔR/R0)
# mask: (H, W) boolean tissue mask

reg_config = diff.RegularizationConfig(
    source_l1_weight=0.1,
    D_prior_weight=0.1,
    D_prior_mean=0.05  # Your expected D value
)

result = diff.fit_diffusion_model_2d(
    g5, 
    mask_updated.astype(bool),
    reg_config=None,
    dt=0.533,                    # seconds per frame
    n_sources=3,                # number of source components (start with 3-8)
    bin_factors=(2,2),          # (1,1) = no additional binning (data already 2x2 binned)
    hp_sigma_frames=3.0,         # high-pass filter for source learning
    loss_subsample=100000,        # subsample pixels for faster optimization
    theta=1.0,                   # 1.0=backward Euler (stable), 0.5=Crank-Nicolson
    refine_sources=False,        # set True to refine time courses after fitting
    verbose=True
)

# Print fitted parameters
print("\n" + "="*50)
print("FITTED PARAMETERS")
print("="*50)
print(f"Diffusion coefficient D = {result.D:.4f} pixels²/s")
print(f"Decay rate k = {result.k:.4f} s⁻¹")
print(f"Source gain g = {result.g:.4f}")
print(f"\nFit quality:")
print(f"  MSE = {result.mse:.6f}")
print(f"  R² = {result.r_squared:.4f}")
print(f"\nData dimensions:")
print(f"  T = {result.Y_used.shape[0]} frames")
print(f"  N = {result.mask_used.sum()} masked pixels")
print(f"  M = {result.n_sources} source components")

INFO: Input: T=499, H=100, W=250, N_masked=6746
INFO: Binning by (2, 2)
INFO: After binning: T=499, H=50, W=125, N_masked=1594
INFO: Building Laplacian...
INFO: Stability: λ_max≈7.94, D_max_explicit=0.47
INFO: Learning 3 sources via NMF...
INFO: NMF sources learned: M=3, recon_error=0.8281
INFO: NMF reconstruction error: 0.8281
INFO: Fitting D, k, g...
INFO: Eval 10: D=0.4971, k=0.1437, g=1.4012, MSE=0.701969
INFO: Eval 20: D=0.4778, k=0.1032, g=2.1255, MSE=1.136951
INFO: Eval 30: D=0.4843, k=0.1262, g=1.6800, MSE=0.562474
INFO: Eval 40: D=0.4711, k=0.1298, g=1.7115, MSE=0.560623
INFO: Eval 50: D=0.0890, k=0.5607, g=6.6263, MSE=0.499055
INFO: Eval 60: D=0.0510, k=0.5178, g=6.9496, MSE=0.471288
INFO: Eval 70: D=0.0452, k=0.4583, g=5.9911, MSE=0.468576
INFO: Eval 80: D=0.0167, k=0.3059, g=3.9554, MSE=0.462175
INFO: Eval 90: D=0.0026, k=0.2784, g=3.7025, MSE=0.457139
INFO: Eval 100: D=0.0008, k=0.2738, g=3.6406, MSE=0.456718
INFO: Eval 110: D=0.0001, k=0.2807, g=3.7173, MSE=0.456531
INFO:


FITTED PARAMETERS
Diffusion coefficient D = 0.0001 pixels²/s
Decay rate k = 0.2794 s⁻¹
Source gain g = 3.7030

Fit quality:
  MSE = 0.456525
  R² = 0.4379

Data dimensions:
  T = 499 frames
  N = 1594 masked pixels
  M = 3 source components


In [16]:
# Visualize learned source spatial maps and time courses
%matplotlib qt
_ = importlib.reload(sys.modules['diffusion_v2'])
diff.visualize_sources(result, max_sources=5, figsize=(15, 6))

In [17]:
# Compare observed vs reconstructed at selected time points
_ = importlib.reload(sys.modules['diffusion_v2'])
diff.plot_reconstruction_comparison(result, time_points=[0, 50, 100, 150, 199])

In [None]:
# Compute detailed residual metrics
_ = importlib.reload(sys.modules['diffusion_v2'])
metrics = diff.compute_residual_metrics(result)
print("Residual Metrics:")
for k, v in metrics.items():
    print(f"  {k}: {v:.4f}")

### Select number of sources using elbow method (optional)
Use this to determine the optimal number of source components M.

In [None]:
# Elbow plot to select number of sources (optional - takes a few minutes)
M_best, errors = diff.select_n_sources_elbow(g5, mask, M_range=range(2, 15), plot=True)
print(f"Suggested number of sources: M = {M_best}")

In [None]:
# Reshape a source spatial map to image for custom visualization
source_idx = 0  # which source to visualize
phi_img = diff.reshape_to_image(result.Phi[:, source_idx], result.mask_used)

plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.imshow(phi_img, cmap='hot')
plt.colorbar(label='Spatial weight')
plt.title(f'Source {source_idx+1} spatial map')
plt.axis('off')

plt.subplot(1, 2, 2)
t = np.arange(result.a_t.shape[0]) * result.dt
plt.plot(t, result.a_t[:, source_idx], 'b-', lw=2)
plt.xlabel('Time (s)')
plt.ylabel('Amplitude')
plt.title(f'Source {source_idx+1} time course')
plt.tight_layout()
plt.show()