In [None]:
import h5py
import numpy as np
import matplotlib.pyplot as plt
from rastermap import Rastermap, utils
from scipy.stats import zscore
import scipy
import seaborn
import pickle
import cv2
from sklearn import decomposition, manifold
import pandas as pd


In [None]:
from skimage.measure import find_contours
from scipy.sparse import csc_matrix
def get_contours(A, dims, thr=0.9, thr_method='nrg', swap_dim=False, slice_dim = None):
    """Gets contour of spatial components and returns their coordinates

     Args:
         A:   mapping with 'data', 'indices', 'indptr' and 'shape' entries
                   CSC pieces of the matrix of spatial components (d x K). Each entry is
                   read with ``[:]`` and assembled into a ``scipy.sparse.csc_matrix``.

         dims: sequence of ints
                   Spatial dimensions of movie

         thr: scalar between 0 and 1
                   Energy threshold for computing contours (default 0.9)

         thr_method: string
                  Method of thresholding:
                      'max' sets to zero pixels that have value less than a fraction of the max value
                      'nrg' keeps the pixels that contribute up to a specified fraction of the energy
                  Any other value emits a warning and falls back to 'max'.

         swap_dim: bool
                  If False (default), each column of A should be reshaped in F-order to recover the mask;
                  this is correct if the dimensions have not been reordered from (y, x[, z]).
                  If True, each column should be reshaped in C-order; this is correct for dims = ([z, ]x, y).

         slice_dim: int or None
                  Which dimension to slice along if we have 3D data. (i.e., get contours on each plane along this axis).
                  The default (None) is 0 if swap_dim is True, else -1.

     Returns:
         Coor: list of dicts, one per component, each with
                'coordinates': NaN-separated contour vertices (a list of such arrays, one per
                               slice, for 3D dims); an empty array for a component with no entries
                'neuron_id':   1-based component index
                'CoM':         only present for empty components, as [NaN, NaN]
    """
    # Imported locally: the file never imports `warn`, so the fallback branch used to raise NameError.
    from warnings import warn

    A = csc_matrix((A['data'][:], A['indices'][:], A['indptr'][:]), shape=A['shape'][:])
    d, nr = np.shape(A)
    # Materialise dims once as plain ints so array-like inputs are not re-read per component.
    dims = tuple(int(n) for n in dims)

    use_energy = thr_method == 'nrg'
    if not use_energy and thr_method != 'max':
        warn("Unknown threshold method. Choosing max")

    if slice_dim is None:
        slice_dim = 0 if swap_dim else -1
    reshape_order = 'C' if swap_dim else 'F'

    def get_slice_coords(B: np.ndarray) -> np.ndarray:
        """Get contour coordinates for a 2D slice"""
        d1, d2 = B.shape
        vertices = find_contours(B.T, thr)
        # this fix is necessary for having disjoint figures and borders plotted correctly
        v = np.atleast_2d([np.nan, np.nan])
        for vtx in vertices:
            num_close_coords = np.sum(np.isclose(vtx[0, :], vtx[-1, :]))
            if num_close_coords < 2:
                if num_close_coords == 0:
                    # case angle
                    newpt = np.round(np.mean(vtx[[0, -1], :], axis=0) / [d2, d1]) * [d2, d1]
                    vtx = np.concatenate((newpt[np.newaxis, :], vtx, newpt[np.newaxis, :]), axis=0)
                else:
                    # case one is border
                    vtx = np.concatenate((vtx, vtx[0, np.newaxis]), axis=0)
            v = np.concatenate(
                (v, vtx, np.atleast_2d([np.nan, np.nan])), axis=0)
        return v

    coordinates = []

    # for each patch
    for i in range(nr):
        col = slice(A.indptr[i], A.indptr[i + 1])
        patch_data = A.data[col]

        # An empty component has no contour; handle it for every threshold method
        # (previously only 'nrg' was guarded and 'max' crashed on patch_data.max()).
        if len(patch_data) == 0:
            coordinates.append(dict(
                coordinates=np.array([]),
                CoM=np.array([np.nan, np.nan]),
                neuron_id=i + 1,
            ))
            continue

        if use_energy:
            # cumulative energy of the component, pixels ordered from highest to lowest value
            indx = np.argsort(patch_data)[::-1]
            cumEn = np.cumsum(patch_data[indx]**2)
            # we work with normalized values
            cumEn /= cumEn[-1]
            # pixels outside the component get 1 so they sit above any threshold
            Bvec = np.ones(d)
            Bvec[A.indices[col][indx]] = cumEn
        else:
            Bvec = np.zeros(d)
            Bvec[A.indices[col]] = patch_data / patch_data.max()

        Bmat = np.reshape(Bvec, dims, order=reshape_order)

        pars: dict = dict()
        if len(dims) == 2:
            pars['coordinates'] = get_slice_coords(Bmat)
        else:
            # make a list of the contour coordinates for each 2D slice
            pars['coordinates'] = [
                get_slice_coords(Bmat.take(s, axis=slice_dim))
                for s in range(dims[slice_dim])
            ]

        pars['neuron_id'] = i + 1
        coordinates.append(pars)
    return coordinates


In [None]:
def load_data(path, filename, pnr):
    """Load one session's component contours and image, and plot the component centres.

    Args:
        path: session folder. Not used by this function; kept so existing calls keep working.
        filename: HDF5 file containing 'estimates/S_dff', 'estimates/A' and 'dims'.
        pnr: path of a pickle file holding the image shown behind the centres.

    Returns:
        centers: (n_valid, 2) array with the mean of each component's finite contour vertices.
        img: the object unpickled from ``pnr`` (indexed via ``img.shape`` for the plot limits).
        valid_index: integer array with the component indices that produced a centre.
    """
    with h5py.File(filename, "r") as f:
        data = f['estimates']['S_dff'][:]
        A = f['estimates']['A']
        dims = f['dims']
        print(data.shape)
        # A and dims are file-backed handles, so contours must be computed while the file is open.
        coordinates = get_contours(A, dims)

    valid_index = []
    centers = []
    for i, c in enumerate(coordinates):
        contour = c['coordinates']
        # Empty components come back as a 1-D empty array; indexing [:, 0] on that used to
        # raise IndexError. Anything that is not an (N, 2)-style array has no centre.
        if not isinstance(contour, np.ndarray) or contour.ndim != 2 or contour.shape[0] == 0:
            continue
        finite = ~(np.isnan(contour[:, 0]) | np.isnan(contour[:, 1]))
        if not finite.any():
            continue
        center = contour[finite].mean(axis=0)
        if np.isnan(center[0]) or np.isnan(center[1]):
            continue
        centers.append(center)
        valid_index.append(i)

    # Fixed shapes/dtypes so the results stay indexable even when no component is valid.
    centers = np.array(centers, dtype=float).reshape(-1, 2)
    valid_index = np.array(valid_index, dtype=int)

    # NOTE: pickle.load can execute arbitrary code; only open files from a trusted source.
    with open(pnr, 'rb') as file:
        img = pickle.load(file)
    plt.figure(figsize=(12.8, 8), dpi=150)
    plt.imshow(img)
    plt.show()

    plt.figure(figsize=(12.8, 8), dpi=150)
    plt.xlim(0, img.shape[1])
    plt.ylim(0, img.shape[0])
    plt.scatter(centers[:, 0], centers[:, 1], s=10)
    plt.gca().invert_yaxis()
    # Label every centre with its row index in `centers` (used to pick landmarks by eye).
    for i, c in enumerate(centers):
        plt.annotate(i, c, size=5)
    plt.show()

    return centers, img, valid_index

In [None]:
# Session folder; the last path component is the recording date used in the file names.
path = 'G:/mPFC-2/2025-05-18/'
date = path[-11:-1]
filename = f"{path}{date}_cnmf.hdf5"
pnr = f"{path}{date}_pnr.pickle"

centers_1, img_1, valid_index_1 = load_data(path, filename, pnr)

In [None]:
# Session loaded into the "_2" variables (the destination side of the alignment below).
path = 'G:/mPFC-3/2025-07-05/'
filename = f"{path}2025-07-05_cnmf.hdf5"
pnr = f"{path}2025-07-05_pnr.pickle"

centers_2, img_2, valid_index_2 = load_data(path, filename, pnr)

In [None]:
# NOTE: this reassigns centers_1 / img_1 / valid_index_1, replacing whatever an earlier
# cell stored under the same names.
path = 'G:/mPFC-3/2025-07-11/'
filename = f"{path}2025-07-11_cnmf.hdf5"
pnr = f"{path}2025-07-11_pnr.pickle"

centers_1, img_1, valid_index_1 = load_data(path, filename, pnr)

In [None]:
def align(centers, src_points, dst_points, img):
    """Warp an image and a set of points with the perspective transform src -> dst.

    Args:
        centers: iterable of points; the first two entries of each are used as (x, y).
        src_points: four source control points, passed to cv2.getPerspectiveTransform.
        dst_points: four destination control points, passed to cv2.getPerspectiveTransform.
        img: image to warp; also provides the output size and the plot limits.

    Returns:
        Array with one transformed (x, y) row per entry of ``centers``.
    """
    homography = cv2.getPerspectiveTransform(src_points, dst_points)
    height, width = img.shape[0], img.shape[1]

    # Show the warped image for a visual check of the transform.
    warped = cv2.warpPerspective(img, homography, (width, height))
    plt.figure(figsize=(12.8, 8), dpi=150)
    plt.imshow(warped)
    plt.show()

    def project(point):
        """Apply the homography to one point and divide by the homogeneous scale."""
        x_y_s = np.matmul(homography, np.array([point[0], point[1], 1]).T)
        return [x_y_s[0] / x_y_s[2], x_y_s[1] / x_y_s[2]]

    new_centers = np.array([project(point) for point in centers])

    # Scatter the transformed points, labelled with their row index.
    plt.figure(figsize=(12.8, 8), dpi=150)
    plt.xlim(0, width)
    plt.ylim(0, height)
    plt.scatter(new_centers[:, 0], new_centers[:, 1], s=10)
    plt.gca().invert_yaxis()
    for label, point in enumerate(new_centers):
        plt.annotate(label, point, size=5)
    plt.show()

    return new_centers

In [None]:
# Row indices of the four landmark points: a-d index centers_1 (source),
# e-h index centers_2 (destination).
a, b, c, d = 10, 203, 602, 740
e, f, g, h = 6, 196, 569, 740

# Define the source points (corners of the distorted image)
src_points = np.float32([[centers_1[k, 0], centers_1[k, 1]] for k in (a, b, c, d)])

# Define the destination points (where you want the corners to be)
dst_points = np.float32([[centers_2[k, 0], centers_2[k, 1]] for k in (e, f, g, h)])

new_centers_1 = align(centers_1, src_points, dst_points, img_1)

In [None]:
def register(array_1, array_2,img,threshold=20):
    """Greedily pair points of ``array_1`` with points of ``array_2`` by increasing distance.

    All pairwise Euclidean distances are sorted in ascending order; a pair is accepted when
    neither of its points has been used yet. Matching stops at the first distance above
    ``threshold`` or once min(len(array_1), len(array_2)) pairs have been found.

    Args:
        array_1: (n_1, k) array of point coordinates.
        array_2: (n_2, k) array of point coordinates.
        img: image whose shape sets the axis limits of the diagnostic scatter plot.
        threshold: largest distance accepted for a pair (default 20).

    Returns:
        registered_1: row indices into array_1 of the matched points, in matching order.
        registered_2: row indices into array_2, aligned with registered_1.
        corresponding_12: dict mapping each matched array_1 index to its array_2 index.
        corresponding_21: dict mapping each matched array_2 index to its array_1 index.
    """
    pts_1 = np.asarray(array_1)
    pts_2 = np.asarray(array_2)
    n_1, n_2 = len(pts_1), len(pts_2)

    registered_1 = []
    registered_2 = []
    corresponding_12 = {}
    corresponding_21 = {}

    if n_1 > 0 and n_2 > 0:
        # cost_matrix[i, j] = distance between array_1[i] and array_2[j], computed in one shot.
        diff = pts_2[np.newaxis, :, :] - pts_1[:, np.newaxis, :]
        cost_matrix = np.sqrt((diff ** 2).sum(axis=2))

        flat_cost = cost_matrix.ravel()
        # Stable sort: equal distances keep row-major (i, j) order, so ties resolve deterministically.
        order = np.argsort(flat_cost, kind='stable')
        rows = order // n_2
        cols = order % n_2

        max_pairs = min(n_1, n_2)
        # Sets give O(1) "already used" checks; the lists keep the matching order for the caller.
        used_1 = set()
        used_2 = set()
        for dist, i, j in zip(flat_cost[order], rows, cols):
            if dist > threshold or len(registered_1) >= max_pairs:
                break
            if i in used_1 or j in used_2:
                continue
            used_1.add(i)
            used_2.add(j)
            registered_1.append(i)
            registered_2.append(j)
            corresponding_12[i] = j
            corresponding_21[j] = i

    # Overlay the matched points of both sets for a visual check.
    plt.figure(figsize=(12.8, 8), dpi=150)
    plt.scatter(pts_1[registered_1, 0], pts_1[registered_1, 1], s=10)
    plt.scatter(pts_2[registered_2, 0], pts_2[registered_2, 1], s=10)

    plt.xlim(0, img.shape[1])
    plt.ylim(0, img.shape[0])
    plt.gca().invert_yaxis()

    plt.show()
    return registered_1, registered_2, corresponding_12, corresponding_21

In [None]:
# register() expects (array_1, array_2, img, threshold). The image was missing, so 100 was
# taken as `img` and the threshold silently stayed at its default.
registered_1, registered_2, corresponding_12, corresponding_21 = register(new_centers_1, centers_2, img_1, 100)

In [None]:
# Number of point pairs matched by the most recent register() call.
print(len(registered_1))

In [None]:
def generate_continuous_data(data, event, trial, padding,delay,interval,percentile,gaussian):
    """Clean activity traces, z-score them and cut out the window spanning the chosen trials.

    Processing order: optional block-wise baseline normalisation, per-neuron percentile
    thresholding, Gaussian smoothing along time (sigma fixed at 2), z-scoring of each neuron,
    then slicing to the frames between the first and last selected event plus padding.
    The input array is left untouched.

    Args:
        data: (neurons, frames) array of activity.
        event: sequence of event frame positions, indexed by trial number.
        trial: [first, last] trial indices (inclusive) to keep.
        padding: [frames kept before the first event, frames kept after the last event].
        delay: offset added cumulatively to the event indices (i * delay for the i-th event).
        interval: block length in frames for baseline normalisation; <= 0 disables it.
        percentile: per neuron, positive samples below this percentile of the neuron's
            positive samples are set to 0.
        gaussian: accepted for interface compatibility but not used by this function.

    Returns:
        data: (neurons, window_frames) z-scored array for the selected window.
        event_index: list of event positions relative to the start of the window.
    """
    # Local imports so the submodules are loaded even if the file only does `import scipy`.
    import scipy.ndimage
    import scipy.stats

    # Work on a private copy: the original code edited the caller's array through the .T view.
    work = np.array(data, copy=True)
    if not np.issubdtype(work.dtype, np.floating):
        work = work.astype(float)
    work = work.T  # (frames, neurons)

    if interval > 0:
        # Divide every block of `interval` frames by its mean over all neurons.
        for start in range(0, work.shape[0], interval):
            block = work[start:start + interval, :]
            baseline = np.mean(block)
            if baseline != 0:  # an all-zero block would otherwise turn into NaNs
                block /= baseline

    for n in range(work.shape[1]):
        trace = work[:, n]
        positive = trace[trace > 0]
        if positive.size == 0:
            # No positive samples: there is nothing to threshold against.
            continue
        trace[trace < np.percentile(positive, percentile)] = 0

    work = scipy.ndimage.gaussian_filter1d(work, 2, axis=0)
    work = scipy.stats.zscore(work.T, axis=1)

    # Clamp the window start at 0; a negative start would be read as "from the end" by slicing.
    window_start = max(int(round(event[trial[0]])) - padding[0], 0)
    window_stop = int(round(event[trial[1]])) + padding[1]
    windowed = np.array(work[:, window_start:window_stop])

    # Event positions relative to the window, then shifted by the cumulative delay.
    relative_events = [int(round(event[i])) - window_start for i in range(trial[0], trial[1] + 1)]
    event_index = [round(e + i * delay) for i, e in enumerate(relative_events)]

    return windowed, event_index

In [None]:
# NOTE(review): valid_index_1 and registered_1 are taken from earlier cells; confirm they
# were computed for this same session before trusting the heatmap.
path = 'G:/mPFC-2/2025-05-15/'
filename = f"{path}2025-05-15_cnmf.hdf5"
Trial_idx_file = f"{path}Trial_idx1.mat"
Trial_idx = scipy.io.loadmat(Trial_idx_file)['Trial_idx1']
door_open = Trial_idx[:, 1]
with h5py.File(filename, "r") as f:
    data = f['estimates']['S_dff'][:]
print(data.shape)
data, door_open_index = generate_continuous_data(
    data[valid_index_1], door_open, [0, len(door_open)-1], [100, 200], 0.13, 1000, 30, 0.1)

# One (neurons, 100 frames) window per event: 50 frames before to 50 frames after.
single_trial_data = [data[registered_1, e-50:e+50] for e in door_open_index]
trial_averaged_data = np.mean(single_trial_data, axis=0)
# Order neurons by the frame at which their trial-averaged response peaks.
sorted_index = np.argsort(np.argmax(trial_averaged_data, axis=1))
seaborn.heatmap(trial_averaged_data[sorted_index])
plt.show()

In [None]:
# NOTE(review): valid_index_2 and registered_2 are taken from earlier cells; confirm they
# were computed for this same session.
path = 'G:/mPFC-3/2025-07-12/'
filename = f"{path}2025-07-12_cnmf.hdf5"
Trial_idx_file = f"{path}Trial_idx1.mat"
Trial_idx = scipy.io.loadmat(Trial_idx_file)['Trial_idx1']
door_open = Trial_idx[:, 1]
with h5py.File(filename, "r") as f:
    data = f['estimates']['S_dff'][:]
print(data.shape)
data, door_open_index = generate_continuous_data(
    data[valid_index_2], door_open, [0, len(door_open)-1], [100, 200], 0.13, 1000, 30, 0.1)

# One (neurons, 100 frames) window per event: 50 frames before to 50 frames after.
single_trial_data = [data[registered_2, e-50:e+50] for e in door_open_index]
trial_averaged_data = np.mean(single_trial_data, axis=0)
# Rows are shown in the order stored in sorted_index, which this cell does not recompute.
seaborn.heatmap(trial_averaged_data[sorted_index])
plt.show()

In [None]:
# Multi alignment
# sessions
Sessions = [
    'G:/mPFC-3/2025-07-04/',
    'G:/mPFC-3/2025-07-05/',
    'G:/mPFC-3/2025-07-06/',
    'G:/mPFC-3/2025-07-08/',
    'G:/mPFC-3/2025-07-11/',
    'G:/mPFC-3/2025-07-13/',
]

# template
template_path = 'G:/mPFC-3/2025-07-12/'

# Align points: one row per entry of Sessions. The first four values index the session's
# centres (source points), the last four index the template's centres (destination points).
Points = [
    [32, 167, 371, 329, 18, 196, 589, 696],
    [19, 74, 362, 399, 18, 489, 589, 648],
    [24, 89, 234, 287, 28, 489, 594, 646],
    [35, 219, 554, 665, 65, 213, 588, 720],
    [10, 203, 602, 740, 6, 196, 569, 740],
    [48, 238, 517, 636, 18, 197, 751, 805],
]

# # sessions
# Sessions = ['G:/mPFC-1/2025-05-08/','G:/mPFC-1/2025-05-10/','G:/mPFC-1/2025-05-18/']

# # template
# template_path = 'G:/mPFC-1/2025-05-15/'

# # Align points
# Points = [[7,64,75,186,3,155,213,380],
#           [1,93,359,459,6,173,236,297],
#          [2,35,167,238,3,48,221,284]]

# # sessions
# Sessions = ['G:/mPFC-2/2025-05-08/','G:/mPFC-2/2025-05-18/']

# # template
# template_path = 'G:/mPFC-2/2025-05-15/'

# # Align points
# Points = [[132,232,365,540,145,271,403,575],
#          [21,170,253,403,2,278,385,559]]

# # sessions
# Sessions = ['G:/M2-1/2025-03-26/','G:/M2-1/2025-04-05/']

# # template
# template_path = 'G:/M2-1/2025-04-01/'

# # Align points
# Points = [[37,132,242,204,22,176,375,439],
#           [12,170,251,361,24,177,280,284]]

# Load template
date = template_path[-11:-1]
filename = f"{template_path}{date}_cnmf.hdf5"
pnr = f"{template_path}{date}_pnr.pickle"
centers_2, img_2, valid_index_2 = load_data(template_path, filename, pnr)

# Per-session results, in the order of Sessions.
valid_indices = []  # valid component indices of each session
R2 = []             # template-side indices matched to each session
C21 = []            # template index -> session index mapping for each session
for path, points in zip(Sessions, Points):
    date = path[-11:-1]
    filename = f"{path}{date}_cnmf.hdf5"
    pnr = f"{path}{date}_pnr.pickle"
    centers_1, img_1, valid_index_1 = load_data(path, filename, pnr)
    valid_indices.append(valid_index_1)

    # First four landmarks index the session's centres, last four the template's centres.
    a, b, c, d, e, f, g, h = points[:8]
    src_points = np.float32([[centers_1[k, 0], centers_1[k, 1]] for k in (a, b, c, d)])
    dst_points = np.float32([[centers_2[k, 0], centers_2[k, 1]] for k in (e, f, g, h)])

    new_centers_1 = align(centers_1, src_points, dst_points, img_1)
    registered_1, registered_2, corresponding_12, corresponding_21 = register(
        new_centers_1, centers_2, img_1, 50)
    R2.append(registered_2)
    C21.append(corresponding_21)

In [None]:
# Align all sessions
# Keep only the template indices that were matched in every session.
registered_neurons = np.arange(len(valid_index_2))
for r in R2:
    registered_neurons = np.intersect1d(registered_neurons, r)
print(len(registered_neurons))

# Reg_Neu[s][k] is the index in session s matched to template index registered_neurons[k].
Reg_Neu = np.array([[c[n] for n in registered_neurons] for c in C21])

In [None]:
# Rastermap sorting
date = template_path[-11:-1]
filename = f"{template_path}{date}_cnmf.hdf5"
Trial_idx_file = f"{template_path}Trial_idx1.mat"
Trial_idx = scipy.io.loadmat(Trial_idx_file)['Trial_idx1']
door_open = Trial_idx[:, 1]
with h5py.File(filename, "r") as f:
    data = f['estimates']['S_dff'][:]
print(data.shape)
template_data, template_door_open_index = generate_continuous_data(
    data[valid_index_2], door_open, [0, len(door_open)-1], [100, 200], 0.13, 1000, 30, 0.1)

# Peak-time ordering from the template's trial average (50 frames either side of each event).
single_trial_data = [template_data[registered_neurons, e-50:e+50] for e in template_door_open_index]
trial_averaged_data = np.mean(single_trial_data, axis=0)
sorted_index = np.argsort(np.argmax(trial_averaged_data, axis=1))

# Rastermap ordering fitted on the template's registered neurons; reused for every session.
model = Rastermap(locality=0.2, time_lag_window=20).fit(template_data[registered_neurons])
sorted_order = model.isort


Data = []
Event_Index = []
# Average
for idx, path in enumerate(Sessions):
    print(path)
    date = path[-11:-1]
    filename = f"{path}{date}_cnmf.hdf5"
    Trial_idx_file = f"{path}Trial_idx1.mat"
    Trial_idx = scipy.io.loadmat(Trial_idx_file)['Trial_idx1']
    door_open = Trial_idx[:, 1]
    with h5py.File(filename, "r") as f:
        data = f['estimates']['S_dff'][:]
    print(data.shape)
    data, door_open_index = generate_continuous_data(
        data[valid_indices[idx]], door_open, [0, len(door_open)-1], [100, 200], 0.13, 1000, 30, 0.1)

    # Rows of this session that correspond to the registered template neurons.
    Data.append(data[Reg_Neu[idx]])
    Event_Index.append(door_open_index)

    single_trial_data = [data[Reg_Neu[idx], e-50:e+50] for e in door_open_index]
    trial_averaged_data = np.mean(single_trial_data, axis=0)
    # if idx == 0:
    #     sorted_index = np.argsort(np.argmax(trial_averaged_data,axis=1))

        # # Rastermap sorting
        # model = Rastermap(locality=0.2, time_lag_window=20).fit(data[Reg_Neu[idx]])
        # sorted_order = model.isort
    seaborn.heatmap(trial_averaged_data[sorted_index])
    plt.show()

    # Full-session raster in Rastermap order, with a marker 20 frames before and
    # 50 frames after every event.
    fig = plt.figure(figsize=(150, 2))
    seaborn.heatmap(data[Reg_Neu[idx]][sorted_order], vmin=0, vmax=5, cmap="gray_r")
    for e in door_open_index:
        plt.plot([e-20, e-20], [0, data.shape[0]], color='blue', linestyle='--', linewidth=0.4)
        plt.plot([e+50, e+50], [0, data.shape[0]], color='grey', linestyle='--', linewidth=0.4)
    plt.show()

# Insert template
# Data.insert(1,template_data[registered_neurons])
# Event_Index.insert(1,template_door_open_index)
Data.insert(5, template_data[registered_neurons])
Event_Index.insert(5, template_door_open_index)

In [None]:
# Join every entry of Data along axis 1 (the frame axis of each session's array).
multi_session_cat_data = np.concatenate(Data,axis=1)
# Number of arrays in Data (the sessions plus the inserted template).
print(len(Data))

In [None]:
def PCA(transform_data, transform_event_index, fit_data, fit_event_index, n_components=3,mask_period=(-20,50)):
    """Fit a PCA on the frames of ``fit_data`` outside event windows and project ``transform_data``.

    Args:
        transform_data: (neurons, frames) array to project.
        transform_event_index: event frame positions of ``transform_data`` (used for plotting).
        fit_data: (neurons, frames) array used to fit the PCA.
        fit_event_index: event frame positions of ``fit_data``; frames inside
            [event + mask_period[0], event + mask_period[1]) are excluded from the fit.
        n_components: number of principal components (default 3).
        mask_period: (start_offset, stop_offset) of the excluded window around each event.

    Returns:
        embeddings: (frames, n_components) projection of ``transform_data``.
    """
    # True = frame is used for fitting.
    mask = np.ones([fit_data.shape[1]], dtype=bool)
    for i in fit_event_index:
        # Clamp at 0: a negative bound would be interpreted as counting from the end and
        # would mask the wrong frames for events close to the start.
        start = max(i + mask_period[0], 0)
        stop = max(i + mask_period[1], 0)
        mask[start:stop] = False

    reducer = decomposition.PCA(n_components)
    # Transpose so that rows are frames (samples) and columns are neurons (features).
    fit_data = fit_data.T
    transform_data = transform_data.T

    # Fit model
    reducer.fit(fit_data[mask])

    # Transform
    embeddings = reducer.transform(transform_data)
    print(reducer.explained_variance_ratio_)

    # One figure per component; the x axis is the frame index divided by 10
    # (presumably seconds at 10 frames/s - TODO confirm), with matching event markers.
    for i in range(n_components):
        plt.figure(figsize=(150, 5), dpi=80)
        for j in transform_event_index:
            plt.plot([j/10-2, j/10-2], [-5, 5], color='black', linestyle='--')
            plt.plot([j/10+5, j/10+5], [-5, 5], color='grey', linestyle='--')
        plt.plot(np.array(range(len(embeddings[:, i])))/10, embeddings[:, i])
        plt.show()

    return embeddings

In [None]:
# Positions in Data / Event_Index of the array to project and the array to fit on.
# Position 5 is where the template was inserted.
transform = 5
fit = 5
embeddings = PCA(
    Data[transform],
    Event_Index[transform],
    Data[fit],
    Event_Index[fit],
    n_components=10,
)


In [None]:
plt.figure(figsize=(150, 5), dpi=80)

# Load behavior file
path = 'G:/mPFC-3/2025-07-12/'
behav_file = f"{path}behav_score.xlsx"
bf = pd.read_excel(behav_file, header=None)
label = np.squeeze(bf.values)
print(label)
plt.plot(embeddings[:, 5])

# One pair of dashed markers per event: the first (20 frames before the event) is coloured
# by the trial's label, the second (50 frames after) is always grey.
for i, e in enumerate(Event_Index[transform]):
    if label[i] in ("ss", "ms"):
        onset_color = 'green'
    elif label[i] in ("sf", "mf"):
        onset_color = 'red'
    elif label[i] == "n":
        onset_color = 'black'
    else:
        onset_color = 'blue'
    plt.plot([e-20, e-20], [-5, 5], color=onset_color, linestyle='--')
    plt.plot([e+50, e+50], [-5, 5], color='grey', linestyle='--')

plt.show()