In [6]:
from pathlib import Path
import re, os, sys
import numpy as np
import matplotlib.pyplot as plt
import importlib
import roicat
from functools import reduce

import roi_tracking_helpers
importlib.reload(roi_tracking_helpers)
from roi_tracking_helpers import roicat_align_rois, roicat_visualize_tracked_rois


In [7]:
# Sessions to align, and the locations their data is read from / written to
sessions_to_align = [
    'ses-010_date-20250314_protocol-t4',
    'ses-011_date-20250315_protocol-t5',
    # 'ses-012_date-20250318_protocol-t6',
]

# Pull the integer that follows "protocol-t" out of every session name
protocol_nums = []
for session_name in sessions_to_align:
    protocol_match = re.search(r'protocol-t(\d+)', session_name)
    protocol_nums.append(int(protocol_match.group(1)))

basepath = Path('/Volumes/mrsic_flogel/public/projects/AtApSuKuSaRe_20250129_HFScohort2')
animal = 'TAA0000066'
func_img_path = 'funcimg/Session'

# Alignment results live in a per-animal folder named after the protocols,
# joined with underscores (one "t<N>" token per session)
protocol_tag = '_'.join(f't{n}' for n in protocol_nums)
alignment_path = basepath.joinpath(animal, protocol_tag)
savepath = alignment_path.joinpath('analysis')

# The animal ID is passed to the ROICaT helpers as the data name
roicat_data_name = animal

# Directory handed to the ROICaT helper functions
roicat_dir = r'/Users/athinaapostolelli/Documents/SWC/VR_ABCD_project/ROICaT'

#### Align neural data and ROIs

In [8]:
# Reuse a previously saved ROI alignment for these sessions when one exists;
# otherwise load each session's DF/F traces and run the ROICaT alignment.

# Build the cache filename once. Computing the tag outside the f-string also
# avoids nesting identical quotes inside an f-string, which is a SyntaxError
# on Python versions older than 3.12.
protocol_tag = '_'.join(f't{n}' for n in protocol_nums)
aligned_rois_file = Path(alignment_path) / f'roicat_aligned_ROIs_{protocol_tag}.npy'

if aligned_rois_file.exists():
    print('Found ROIs alignment for these sessions. Loading...')
    idx_original_aligned = np.load(aligned_rois_file)
else:
    # Load data to align: one DF/F array per session, in session order
    data_to_align = []
    for sess in sessions_to_align:
        session_dir = Path(basepath) / animal / sess
        DF_F_file = session_dir / func_img_path / 'suite2p/plane0/DF_F0.npy'
        valid_frames_file = session_dir / 'valid_frames.npz'

        # Fail early and name exactly which inputs are missing
        missing_files = [str(f) for f in (DF_F_file, valid_frames_file) if not f.exists()]
        if missing_files:
            missing_list = ', '.join(missing_files)
            raise FileNotFoundError(f'Missing input file(s) for session {sess}: {missing_list}')

        valid_frames = np.load(valid_frames_file)['valid_frames']

        # Keep only the columns selected by valid_frames
        # (assumes DF_F0 is laid out as (n_rois, n_frames) — TODO confirm)
        DF_F_all = np.load(DF_F_file)
        DF_F = DF_F_all[:, valid_frames]

        data_to_align.append(DF_F)

    # Align the data; only the second return value is used in this notebook
    _, idx_original_aligned = roicat_align_rois(
        roicat_dir=roicat_dir,
        roicat_data_name=roicat_data_name,
        sessions_to_align=sessions_to_align,
        basepath=basepath,
        animal=animal,
        alignment_method='DF_F0',  # this is ignored if data is not None
        data=data_to_align,
        plot_alignment=False,
        savepath=alignment_path,
    )
    


Found ROIs alignment for these sessions. Loading...


#### Visualize ROIs

In [62]:
# Show every ROI that was matched in all of the aligned sessions

# Stack the per-session index arrays into one 2-D array
# (presumably (n_sessions, n_neurons), as in the original note — verify against roicat_align_rois)
idx_array = np.vstack(idx_original_aligned)

# A column is valid only when no session has a NaN entry in it
valid_mask = ~np.isnan(idx_array).any(axis=0)
valid_indices = np.flatnonzero(valid_mask)

# Per session, keep the entries at the valid columns and cast them to int
tracked_neuron_ids = [
    session[valid_indices].astype(int)
    for session in idx_original_aligned
]

roicat_visualize_tracked_rois(
    roicat_dir,
    roicat_data_name,
    sessions_to_align=sessions_to_align,
    tracked_neuron_ids=tracked_neuron_ids,
)

In [14]:
# Choose specific subset of neurons
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary Python objects;
# only use it on files produced by a trusted pipeline.
tracked_neuron_file = np.load(os.path.join(savepath, 'tracked_reward_tuned_neurons.npz'), allow_pickle=True)
# [()] extracts the stored object from its array wrapper
# (presumably a 0-d object array holding a per-session, per-goal mapping — TODO confirm)
tracked_neuron_idx = tracked_neuron_file['tracked_reward_id_tuned_idx'][()]
neurons_considered_sessions = tracked_neuron_file['neurons_considered_sessions']  # neurons with good signal for each session

ABCD_goals = [1,2,3,4]
num_sessions = len(sessions_to_align)
# overlapping_neurons = {s: {goal: [] for goal in (ABCD_goals if protocol_nums[s] >= 5 else [1])} for s in range(num_sessions)}

# Label such as "T4 T5". It is the same for every goal, so build it once here;
# this also keeps quotes from being nested inside the print f-string below,
# which is a SyntaxError on Python versions older than 3.12.
session_labels = ' '.join(f'T{protocol_nums[s]}' for s in range(num_sessions))

# Visualize neurons tuned to each goal in all sessions
for goal in ABCD_goals:
    # Skip goals that are not present in every session's entry
    if not all(goal in tracked_neuron_idx[s] for s in range(num_sessions)):
        continue

    # Neurons listed for this goal in every session
    neuron_lists = [tracked_neuron_idx[s][goal] for s in range(num_sessions)]
    overlapping_neurons = reduce(np.intersect1d, neuron_lists)
    # Use the overlapping values to index each session's considered-neuron array
    tracked_neuron_ids = [session[overlapping_neurons].astype(int) for session in neurons_considered_sessions]

    print(f'Visualizing neurons tuned to goal {goal} in {session_labels}:')
    roicat_visualize_tracked_rois(roicat_dir, roicat_data_name, sessions_to_align=sessions_to_align, tracked_neuron_ids=tracked_neuron_ids)

Visualizing neurons tuned to goal 1 in T4 T5:
