In [1]:
from mylib.statistic_test import *
from sklearn.cluster import DBSCAN
from sklearn.decomposition import PCA
from mazepy.datastruc.neuact import SpikeTrain
from mazepy.datastruc.variables import VariableBin

# Identifier of this analysis; it names the figure folder below and is also
# used later in the notebook to build the cached BIC file names.
code_id = "0850 - Lisa Paper Revisits"
# Output directory for the figures saved by the plotting functions.
# NOTE(review): `join`, `figpath` and `mkdir` are not defined in this notebook;
# presumably they come from the star import of mylib.statistic_test -- confirm.
loc = join(figpath, "Dsp", code_id)
mkdir(loc)


def get_lapwise_ratemap(trace: dict):
    """
    Build one rate map per lap for every neuron of a session.

    Parameters
    ----------
    trace : dict
        Session data. The keys read here are 'lap beg time', 'lap end time',
        'correct_time', 'correct_nodes', 'n_neuron', 'ms_time',
        'spike_nodes_original' and 'Spikes_original'.

    Returns
    -------
    np.ndarray
        Array of shape (laps, len(CP_DSP[3]), neurons): the lap-wise tuning
        curves restricted to the bins listed in CP_DSP[3].
    """
    lap_starts, lap_ends = trace['lap beg time'], trace['lap end time']
    n_laps = lap_starts.shape[0]

    # First behavioural frame at or after the start of each lap.
    first_frames = np.array([
        np.where(trace['correct_time'] >= lap_starts[lap])[0][0]
        for lap in range(n_laps)
    ])
    lap_routes = classify_lap(
        spike_nodes_transform(trace['correct_nodes'], 12),
        first_frames
    )
    lap_maps = np.zeros(
        (trace['n_neuron'], 144, first_frames.shape[0]),
        dtype=np.float64
    )

    # The route labels are only reported here; they are not returned.
    print(np.unique(lap_routes))

    frame_time = trace['ms_time']
    has_position = ~np.isnan(trace['spike_nodes_original'])

    for lap in tqdm(range(first_frames.shape[0])):
        # Imaging frames inside this lap that carry a valid position.
        in_lap = (frame_time >= lap_starts[lap]) & (frame_time <= lap_ends[lap])
        frames = np.where(in_lap & has_position)[0]

        # Positions converted with spike_nodes_transform(..., 12), then made 0-based.
        positions = spike_nodes_transform(
            trace['spike_nodes_original'][frames].astype(np.int64), 12
        ) - 1

        train = SpikeTrain(
            activity=trace['Spikes_original'][:, frames],
            time=frame_time[frames],
            variable=VariableBin(positions),
        )
        tuning = train.calc_tuning_curve(144, t_interv_limits=100)
        # No smoothing kernel is applied here (the `@ trace['Ms'].T` step was disabled).
        lap_maps[:, :, lap] = tuning.to_array()

    # Keep only the bins listed in CP_DSP[3] (1-based there, 0-based here).
    kept_bins = CP_DSP[3] - 1
    lap_maps = lap_maps[:, kept_bins, :]

    # (neurons, bins, laps) -> (laps, bins, neurons)
    return np.transpose(lap_maps, (2, 1, 0))

def get_all_mice_data(mouse):
    """
    Collect lap-wise rate maps of one mouse across all of its sessions.

    Parameters
    ----------
    mouse : int
        Mouse identifier, matched against the 'MiceID' columns of `f2` and
        `f_CellReg_dsp`.

    Returns
    -------
    X : np.ndarray
        Lap-wise maps of all sessions stacked along the first axis, restricted
        to the neurons registered in every session.
    session_label : np.ndarray
        Session index (as float) of every lap in X.
    route_label : np.ndarray
        Route label of every lap in X, as returned by classify_lap.
    centroid_init : np.ndarray
        Two stacked template vectors, each averaged over sessions: row 0 is the
        mean of the 'node 0', 'node 4', 'node 5' and 'node 9' maps, row 1 is
        the 'node 3' map.
    """
    def _node_template(trace, node, cell_rows):
        # Flattened (bins x neurons) 'old_map_clear' map of one node entry.
        node_map = trace[f'node {node}']['old_map_clear']
        return node_map[cell_rows, :][:, CP_DSP[3]-1].T.flatten()

    session_rows = np.where(f2['MiceID'] == mouse)[0]
    cellreg_row = np.where(f_CellReg_dsp['MiceID'] == mouse)[0][0]

    # NOTE(review): pickle.load executes arbitrary code; only use trusted files.
    traces = []
    for row in session_rows:
        with open(f2['Trace File'][row], 'rb') as handle:
            traces.append(pickle.load(handle))

    with open(f_CellReg_dsp['cellreg_folder'][cellreg_row], 'rb') as handle:
        index_map = pickle.load(handle).astype(np.int64)

    # Every mouse except 10232 has its first index_map row dropped.
    # NOTE(review): presumably that row is a session not analysed here -- confirm.
    if mouse != 10232:
        index_map = index_map[1:, :]

    # Identify neurons that appear in all sessions (non-zero entry in every row).
    registered = index_map != 0
    ncell = np.where(registered.sum(axis=0) == index_map.shape[0])[0]

    Xs, session_label, route_label = [], [], []
    centroid_init0, centroid_init1 = [], []
    for sess, trace in enumerate(traces):
        X = get_lapwise_ratemap(trace)
        cell_rows = index_map[sess, ncell] - 1  # 1-based IDs -> 0-based rows

        cent1 = _node_template(trace, 3, cell_rows)
        cent0 = np.mean(
            np.vstack([_node_template(trace, n, cell_rows) for n in [0, 4, 5, 9]]),
            axis=0
        )
        centroid_init0.append(cent0)
        centroid_init1.append(cent1)

        # Route label of every lap, from the first frame of each lap.
        lap_starts = trace['lap beg time']
        first_frames = np.array([
            np.where(trace['correct_time'] >= lap_starts[lap])[0][0]
            for lap in range(lap_starts.shape[0])
        ])
        routes = classify_lap(
            spike_nodes_transform(trace['correct_nodes'], 12),
            first_frames
        )

        Xs.append(X[:, :, cell_rows])
        session_label.append(np.ones(X.shape[0]) * sess)
        route_label.append(routes)

    centroid_init0 = np.mean(np.vstack(centroid_init0), axis=0)
    centroid_init1 = np.mean(np.vstack(centroid_init1), axis=0)

    return (
        np.concatenate(Xs, axis=0),
        np.concatenate(session_label),
        np.concatenate(route_label),
        np.vstack([centroid_init0, centroid_init1])
    )

        E:\Data\FinalResults\Dsp\0850 - Lisa Paper Revisits is already existed!


In [2]:
from sklearn.cluster import KMeans

def fit_kmeans(X, R: int, kmeans_init=None, is_return_model: bool = False):
    """
    Cluster lap-wise population maps with a KMeans model.

    Parameters
    ----------
    X : np.ndarray
        shape: (I x J x K) tensor of normalized firing rates

        I: Trials
        J: Spatial bins
        K: Neurons
    R : int
        Number of clusters.
    kmeans_init : optional
        Forwarded to ``KMeans(init=...)`` when not None, e.g. an (R x JK)
        array of initial centroids. When None, the KMeans default is used.
    is_return_model : bool
        When True, the fitted KMeans object is returned as a third value.

    Returns
    -------
    U : np.ndarray, shape (I, R)
        One-hot cluster membership of every trial.
    V : np.ndarray, shape (R, J*K)
        Cluster centers.
    kmean : KMeans
        Only when `is_return_model` is True. Its `labels_` and
        `cluster_centers_` keep the order produced by the fit, i.e. they are
        NOT reordered by the swap described below.

    Notes
    -----
    If the first cluster holds fewer trials than the second one, the first two
    clusters are exchanged in U and V; any further clusters stay where they are.
    """
    n_trials = X.shape[0]
    X_wrap = np.reshape(X, (n_trials, X.shape[1]*X.shape[2]))

    if kmeans_init is not None:
        kmean = KMeans(n_clusters=R, init=kmeans_init)
    else:
        kmean = KMeans(n_clusters=R)

    kmean.fit(X_wrap)

    # One-hot encode the labels in a single indexed assignment.
    U = np.zeros((n_trials, R))
    U[np.arange(n_trials), kmean.labels_] = 1

    V = kmean.cluster_centers_

    if R >= 2:
        nclusters = np.sum(U, axis=0)
        if nclusters[0] < nclusters[1]:
            # Exchange clusters 0 and 1 but keep all the others; indexing with
            # just [1, 0] would drop clusters 2..R-1 whenever R > 2.
            order = np.arange(R)
            order[[0, 1]] = [1, 0]
            U = U[:, order]
            V = V[order, :]

    if is_return_model:
        return U, V, kmean
    else:
        return U, V

def fit_pca(X, n_components: int, kmeans_init=None):
    """
    Reduce lap-wise population maps with PCA.

    Parameters
    ----------
    X : np.ndarray
        shape: (I x J x K) tensor of normalized firing rates

        I: Trials
        J: Spatial bins
        K: Neurons

    n_components: int
        Number of components to keep

    kmeans_init : np.ndarray, optional
        Extra vectors with J*K features each (e.g. cluster centroids) to
        project with the PCA fitted on X.

    Returns
    -------
    reduced_X: np.ndarray
        The reduced map, shape: (I x n_components)
    reduced_init: np.ndarray
        Only returned when `kmeans_init` is given: its projection onto the
        same components.
    """
    trial_count = X.shape[0]
    feature_count = X.shape[1] * X.shape[2]
    flat_maps = np.reshape(X, (trial_count, feature_count))

    model = PCA(n_components=n_components)

    if kmeans_init is not None:
        # Fit on the trials only, then project trials and centroids alike.
        model.fit(flat_maps)
        return model.transform(flat_maps), model.transform(kmeans_init)

    return model.fit_transform(flat_maps)

def get_lap_distances(X, V):
    """
    Refer to 'Distance to cluster calculations' Section of Lisa paper.

    For every lap, computes

        sum((2 * X_i - (V0 + V1)) * (V0 - V1)) / sum((V0 - V1) ** 2)

    over all position bins and neurons. The value is +1 when the lap equals
    centroid 0, -1 when it equals centroid 1 and 0 halfway between them.

    Parameters
    ----------
    X: ndarray (I x J x K) tensor of normalized firing rates
        I: Trials/Laps
        J: Position Bins
        K: Neurons
    V: ndarray, low-dimensional factors (R x JK)
        R: Rank, which must be 2 here (V is reshaped to (2, J, K))

    Returns
    -------
    Distances: ndarray of shape (I,), one signed distance per lap.
        The result is NaN/inf when both centroids are identical, because the
        denominator is then zero.
    """
    I, J, K = X.shape
    V = np.reshape(V, (2, J, K))

    # These terms do not depend on the lap, so compute them once.
    centroid_diff = V[0] - V[1]
    centroid_sum = V[0] + V[1]
    denominator = np.sum(centroid_diff ** 2)

    # Contract the bin and neuron axes for all laps at once.
    numerator = np.einsum('ijk,jk->i', 2 * X - centroid_sum, centroid_diff)

    return numerator / denominator

def get_lap_neuron_distances(X, V):
    """
    Refer to 'Distance to cluster calculations' Section of Lisa paper.

    Same signed distance as the lap-level one, but evaluated separately for
    every neuron: the sums run over position bins only.

    Parameters
    ----------
    X: ndarray (I x J x K) tensor of normalized firing rates
        I: Trials/Laps
        J: Position Bins
        K: Neurons
    V: ndarray, low-dimensional factors (R x JK)
        R: Rank, which must be 2 here (V is reshaped to (2, J, K))

    Returns
    -------
    Distances: ndarray of shape (I, K), one signed distance for each neuron
    in each lap. Entries are NaN/inf for neurons whose two centroid profiles
    are identical, because their denominator is zero.
    """
    I, J, K = X.shape
    V = np.reshape(V, (2, J, K))

    # Lap-independent terms, computed once.
    centroid_diff = V[0] - V[1]                         # (J, K)
    centroid_sum = V[0] + V[1]                          # (J, K)
    denominator = np.sum(centroid_diff ** 2, axis=0)    # (K,)

    # Sum over position bins only, keeping the lap and neuron axes.
    numerator = np.einsum('ijk,jk->ik', 2 * X - centroid_sum, centroid_diff)

    return numerator / denominator

def get_lap_pos_distances(X, P, V, bins_template: np.ndarray = CP_DSP[3]):
    """
    Refer to 'Distance to cluster calculations' Section of Lisa paper.

    Same signed distance as the lap-level one, but evaluated separately for
    every position bin: the sums run over neurons only.

    Parameters
    ----------
    X: ndarray (I x J x K) tensor of normalized firing rates
        I: Trials/Laps
        J: Position Bins
        K: Neurons
    P: unused. Kept so that existing positional calls keep working.

    V: ndarray, low-dimensional factors (R x JK)
        R: Rank, which must be 2 here (V is reshaped to (2, J, K))

    bins_template: unused. Kept for interface compatibility.

    Returns
    -------
    Distances: ndarray of shape (I, J), one signed distance for each position
    in each lap. Entries are NaN/inf for bins where both centroids are
    identical, because their denominator is zero.
    """
    I, J, K = X.shape
    V = np.reshape(V, (2, J, K))

    # Lap-independent terms, computed once.
    centroid_diff = V[0] - V[1]                         # (J, K)
    centroid_sum = V[0] + V[1]                          # (J, K)
    denominator = np.sum(centroid_diff ** 2, axis=1)    # (J,)

    # Sum over neurons only, keeping the lap and position axes.
    numerator = np.einsum('ijk,jk->ij', 2 * X - centroid_sum, centroid_diff)

    return numerator / denominator


def plot_lapwise_distances(X, session_label, route_label, mouse, kmeans_init = None):
    """
    Plot the signed distance-to-cluster of every lap, one panel per session.

    A two-cluster KMeans is fitted on X (optionally seeded with `kmeans_init`),
    the lap distances are computed from its centers and drawn against the lap
    order within each of seven sessions. Laps with positive and negative
    distance get different marker colours, and a column of route-coloured
    markers is drawn at x = 2.4. The figure is saved as PNG and SVG in `loc`.

    Parameters
    ----------
    X : np.ndarray, (laps x bins x neurons)
    session_label : np.ndarray, session index of every lap (0..6 are plotted)
    route_label : np.ndarray, route index of every lap (0..6 are plotted)
    mouse : identifier used in the output file names
    kmeans_init : optional initial centroids forwarded to fit_kmeans
    """
    _, centers = fit_kmeans(X, 2, kmeans_init=kmeans_init)
    # Report how many centroid entries are negative.
    print(np.where(centers<0)[0].shape[0])
    lap_dist = get_lap_distances(X, centers)

    x_extent = 2
    marker_style = dict(markeredgewidth=0, markersize=2.5, linewidth=0.5)

    _, axes = plt.subplots(nrows=1, ncols=7, figsize=(7, 4))
    for sess in range(7):
        ax = Clear_Axes(axes[sess], close_spines=['left', 'top', 'right'], ifxticks=True)

        laps = np.where(session_label == sess)[0]
        lap_order = np.arange(laps.shape[0])
        sess_dist = lap_dist[laps]
        sess_routes = route_label[laps]

        # Gray trace connecting consecutive laps.
        ax.plot(sess_dist, lap_order, linewidth=0.5, color='#B5B5B6')

        # Markers coloured by the sign of the distance (positive first).
        for selected, shade in (
            (np.where(sess_dist > 0)[0], '#28306E'),
            (np.where(sess_dist < 0)[0], '#A2C78D'),
        ):
            ax.plot(
                sess_dist[selected],
                lap_order[selected],
                's',
                color=shade,
                **marker_style
            )

        # Route identity of every lap, drawn as a column at x = 2.4.
        for route in range(7):
            on_route = np.where(sess_routes == route)[0]
            ax.plot(
                np.repeat(2.4, on_route.shape[0]),
                lap_order[on_route],
                's',
                markeredgewidth=0,
                markersize=3,
                color=DSPPalette[route]
            )

        # Dashed separators wherever the route changes between two laps.
        for boundary in np.where(np.ediff1d(sess_routes) != 0)[0] + 0.5:
            ax.axhline(boundary, color='k', linewidth=0.1, ls='--')

        ax.axvline(0, color='k', linewidth=0.1, ls='--')
        # The x axis is reversed: larger distances are drawn on the left.
        ax.set_xlim([x_extent+0.5, -x_extent])
        ax.set_xticks(np.linspace(-2, 2, 5))

    for extension in ("png", "svg"):
        plt.savefig(join(loc, f"DistanceToCluster {mouse}.{extension}"), dpi = 600)
    plt.show()

def hex_to_rgba(hex_color):
    """
    Convert a hex color (#RRGGBB or #RRGGBBAA) to RGBA format (0-255).

    Leading '#' characters are optional. When no alpha pair is present the
    alpha channel defaults to 255.

    Returns
    -------
    tuple of four ints (r, g, b, a)

    Raises
    ------
    ValueError
        If the string (without '#') is not 6 or 8 characters long, or contains
        characters that are not hexadecimal digits.
    """
    digits = hex_color.lstrip('#')  # Remove '#' if present
    if len(digits) not in (6, 8):
        raise ValueError("Invalid hex color format. Use #RRGGBB or #RRGGBBAA.")

    # Read the string two hex digits at a time: RR, GG, BB and optionally AA.
    channels = [int(digits[start:start + 2], 16) for start in range(0, len(digits), 2)]
    if len(channels) == 3:
        channels.append(255)  # Default alpha

    return tuple(channels)

def hex_to_rgba_normalized(hex_color):
    """
    Convert a hex color (#RRGGBB or #RRGGBBAA) to RGBA format (0-1).

    Returns
    -------
    np.ndarray of four floats (r, g, b, a), each channel divided by 255.
    """
    channels_255 = np.array(hex_to_rgba(hex_color))
    return channels_255 / 255

# Colour lookup tables with one RGBA/RGB row per label, so that an array of
# integer labels can be turned into per-point colours by fancy indexing.
# Route colours, converted from the hex strings in DSPPalette.
DSPPaletteRGBA = np.vstack(list(map(hex_to_rgba_normalized, DSPPalette)))
# Seven colours of seaborn's "rainbow" palette, indexed by session label.
DayPaletteRGBA = np.asarray(sns.color_palette("rainbow", 7))
# Two colours, indexed by the map (cluster) a lap was assigned to.
MAPPaletteRGBA = np.vstack(list(map(hex_to_rgba_normalized, ['#333766', '#A4C096'])))

def plot_pca_clusters(X, session_label, route_label, mouse, map_identity, kmeans_init = None):
    """
    Scatter the laps in the plane of their first two principal components.

    Three stacked panels show the same points coloured by route (top), by
    session (middle) and by map identity (bottom). When `kmeans_init` is given,
    the projected centroids are overlaid on the top two panels. The figure is
    saved as PNG and SVG in `loc`.

    Parameters
    ----------
    X : np.ndarray, (laps x bins x neurons)
    session_label : np.ndarray, session index of every lap
    route_label : np.ndarray of ints, route index of every lap
    mouse : identifier used in the output file names
    map_identity : np.ndarray, index (0 or 1) into MAPPaletteRGBA for every lap
    kmeans_init : optional centroids to project alongside the laps
    """
    show_centroids = kmeans_init is not None
    if show_centroids:
        reduced_X, reduced_centroids = fit_pca(X, 2, kmeans_init=kmeans_init)
    else:
        reduced_X = fit_pca(X, 2)

    _, axes = plt.subplots(nrows=3, ncols=1, figsize=(3, 9))
    ax0, ax1, ax2 = [
        Clear_Axes(axes[row], close_spines=['top', 'right']) for row in range(3)
    ]
    pc1, pc2 = reduced_X[:, 0], reduced_X[:, 1]

    # Top panel: laps coloured by route.
    ax0.scatter(pc1, pc2, s=5, color=DSPPaletteRGBA[route_label, :], alpha=0.9)

    # Middle panel: laps coloured by session.
    # Blues correspond to Session 0 while reds correspond to Session 6
    ax1.scatter(
        pc1,
        pc2,
        s=5,
        alpha=0.9,
        color=DayPaletteRGBA[session_label.astype(np.int64), :]
    )

    if show_centroids:
        for ax in (ax0, ax1):
            ax.scatter(
                reduced_centroids[:, 0],
                reduced_centroids[:, 1],
                s=10,
                color=MAPPaletteRGBA
            )

    ax0.set_aspect("equal")
    ax1.set_aspect("equal")

    # Bottom panel: laps coloured by the map they were assigned to.
    ax2.scatter(
        pc1,
        pc2,
        s=5,
        color=MAPPaletteRGBA[map_identity.astype(np.int64), :],
        alpha=0.9
    )
    ax2.set_aspect("equal")

    for extension in ("png", "svg"):
        plt.savefig(join(loc, f"PCA Clusters {mouse}.{extension}"), dpi = 600)
    plt.show()

# Identify whether two maps describe the data best

In [None]:
def select_best_rank(X, max_rank: int, kmeans_init='k-means++'):
    """
    Score KMeans models with 1..max_rank clusters on lap-wise maps using BIC.

    Each neuron of X is first rescaled to [0, 1] (minimum -> 0, 95th
    percentile -> 1, clipped). A KMeans model is then fitted for every rank
    and its BIC is computed from the within-cluster sum of squares under a
    single shared variance.

    Parameters:
        X: ndarray (I x J x K) tensor of firing rates
            I: Trials/Laps, J: Position bins, K: Neurons
        max_rank: int, maximum rank R (number of clusters) to test
        kmeans_init: initialisation used for the rank-2 model only.
            Either an array with 2*J*K entries holding two centroids, which
            are min-max rescaled per neuron on a private copy, or None, a
            string such as 'k-means++', or a callable, which are handed to
            fit_kmeans unchanged. All other ranks use the KMeans default.

    Returns:
        BIC: ndarray of shape (max_rank,), BIC[r - 1] is the score of rank r.
            The lowest value marks the preferred rank.
    """
    n_laps, n_bins, n_neurons = X.shape
    BIC = np.zeros(max_rank)

    # Float copy: leaves the caller's tensor untouched and avoids truncating
    # the rescaled values if X happens to have an integer dtype.
    XS = np.array(X, dtype=np.float64)

    for i in range(n_neurons):
        sups = np.percentile(X[:, :, i], 95)
        infs = np.min(X[:, :, i])
        if infs == sups:
            # Constant up to the 95th percentile: nothing to rescale.
            continue

        XS[:, :, i] = np.clip(
            (X[:, :, i] - infs) / (sups - infs),
            a_min=0,
            a_max=1
        )

    if kmeans_init is None or isinstance(kmeans_init, str) or callable(kmeans_init):
        # Named strategies cannot be reshaped or rescaled; pass them through.
        rank2_init = kmeans_init
    else:
        # np.array(...) copies, so the caller's centroids are not modified
        # (np.reshape alone would return a view onto the caller's array).
        centroids = np.array(kmeans_init, dtype=np.float64).reshape(2, n_bins, n_neurons)
        valmax = centroids.max(axis=1, keepdims=True)
        valmin = centroids.min(axis=1, keepdims=True)
        span = valmax - valmin
        # Rescale only profiles with a positive maximum and a non-zero range;
        # a constant profile would otherwise be divided by zero.
        rescale = (valmax > 0) & (span > 0)
        safe_span = np.where(rescale, span, 1.0)
        centroids = np.where(rescale, (centroids - valmin) / safe_span, centroids)
        rank2_init = centroids.reshape(2, n_bins * n_neurons)

    # Number of data points and features (independent of the rank).
    X_wrap = np.reshape(XS, (n_laps, n_bins * n_neurons))
    n, d = X_wrap.shape

    for rank_R in range(1, max_rank + 1):
        if rank_R == 2:
            _, _, kmeans = fit_kmeans(XS, rank_R, kmeans_init=rank2_init, is_return_model=True)
        else:
            _, _, kmeans = fit_kmeans(XS, rank_R, is_return_model=True)

        # Number of clusters
        k = kmeans.n_clusters

        # Within-cluster sum of squares (inertia)
        WCSS = kmeans.inertia_

        # Estimate variance (shared by all clusters and dimensions)
        variance = WCSS / (n * d)

        # Log-likelihood
        log_likelihood = -0.5 * n * d * np.log(2 * np.pi * variance) - 0.5 * WCSS / variance

        # Number of free parameters
        num_params = k * d  # k clusters with d dimensions

        # Compute BIC
        BIC[rank_R - 1] = -2 * log_likelihood + num_params * np.log(n)
        print(f"R: {rank_R} BIC: {BIC[rank_R - 1]:.2f}\n log-likelihood: {log_likelihood:.2f}, num_params: {num_params}, Variance: {variance:.2f}, WCSS: {WCSS:.2f}, n: {n}, d: {d}")
    print()
    return BIC

# Build the cache path once so that the existence check, the write and the
# read all point at the same file. The original check used a file name with
# two spaces before "[BIC]" while the file was saved with one, so the cache
# was never found and everything was recomputed on every run.
bic_cache_file = join(figdata, f"{code_id} [BIC].pkl")

if not exists(bic_cache_file):
    BICData = {
        "MiceID": [],
        "X": [],
        "BIC": []
    }

    for mouse in [10212, 10224, 10227, 10232]:
        X, session_label, route_label, kmeans_init = get_all_mice_data(mouse)
        # Pass a copy so that the centroids used for the fits and plots below
        # cannot be altered by any in-place rescaling inside select_best_rank.
        BIC = select_best_rank(X, 5, kmeans_init=np.copy(kmeans_init))
        BICData["MiceID"].append(np.repeat(mouse, 5))
        BICData["X"].append(np.arange(1, 6))
        BICData["BIC"].append(BIC)

        # U[:, 1] is the membership of the second cluster; it is used as the
        # map identity (0 or 1) of every lap in the PCA plot.
        U, V = fit_kmeans(X, 2, kmeans_init=kmeans_init)
        plot_lapwise_distances(X, session_label, route_label, mouse, kmeans_init=kmeans_init)
        plot_pca_clusters(X, session_label, route_label, mouse, U[:, 1])

    # One flat array per column: 5 ranks x 4 mice.
    for k in BICData.keys():
        BICData[k] = np.concatenate(BICData[k])

    with open(bic_cache_file, "wb") as f:
        pickle.dump(BICData, f)

    BICD = pd.DataFrame(BICData)
    BICD.to_excel(join(figdata, f"{code_id} [BIC].xlsx"), index = False)
else:
    with open(bic_cache_file, "rb") as f:
        BICData = pickle.load(f)

# BIC of every mouse (points) and the across-mice summary (bars) for each rank.
fig = plt.figure(figsize=(2, 3))
ax = Clear_Axes(plt.axes(), close_spines=['top', 'right'], ifxticks=True, ifyticks=True)
sns.stripplot(
    x = 'X',
    y='BIC',
    data=BICData,
    hue='MiceID',
    palette = ['#F2E8D4', '#8E9F85', '#C3AED6', '#A7D8DE'],
    size = 4,
    linewidth=0.2,
    jitter=0.2,
    ax = ax,
    zorder=1
)
sns.barplot(
    x = 'X',
    y='BIC',
    data=BICData,
    ax = ax,
    zorder=2,
    capsize=0.5,
    linewidth=0.5,
    err_kws={"color": 'k', 'linewidth': 0.5}
)
#plt.savefig(join(loc, "BIC.png"), dpi=600)
#plt.savefig(join(loc, "BIC.svg"), dpi=600)
plt.show()

# Compare rank 2 against its neighbours (ranks 1 and 3) with paired t-tests
# across mice.
idx1 = np.where(BICData['X'] == 1)[0]
idx2 = np.where(BICData['X'] == 2)[0]
idx3 = np.where(BICData['X'] == 3)[0]
print_estimator(BICData['BIC'][idx1])
print_estimator(BICData['BIC'][idx2])
print_estimator(BICData['BIC'][idx3])

print(ttest_rel(BICData['BIC'][idx1], BICData['BIC'][idx2]))
print(ttest_rel(BICData['BIC'][idx3], BICData['BIC'][idx2]))

[0 1 2 3 4 5 6]


  firing_rate = spike_counts/(occu_time/1000)
100%|██████████| 52/52 [00:00<00:00, 262.08it/s]


[0 1 2 3 4 5 6]


  0%|          | 0/60 [00:00<?, ?it/s]

In [None]:
""