In [None]:
import numpy as np
from matplotlib import pyplot as plt
import pandas as pd
from pathlib import Path
import seaborn as sns
from sklearn.metrics import confusion_matrix
import pynapple as nap
from spatial_manifolds.toroidal import *
from spatial_manifolds.behaviour_plots import *
from matplotlib.colors import TwoSlopeNorm
from scipy.spatial import distance
from sklearn.preprocessing import StandardScaler
import umap
from cebra import CEBRA
import cebra.integrations.plotly
from sklearn.neighbors import KNeighborsRegressor, KNeighborsClassifier

from spatial_manifolds.circular_decoder import circular_decoder, cross_validate_decoder, cross_validate_decoder_time, circular_nanmean
from spatial_manifolds.data.curation import curate_clusters
from scipy.stats import zscore
from spatial_manifolds.util import gaussian_filter_nan
from spatial_manifolds.predictive_grid import compute_travel_projected, wrap_list
from spatial_manifolds.behaviour_plots import *
from spatial_manifolds.behaviour_plots import trial_cat_priority
from spatial_manifolds.detect_grids import *
from spatial_manifolds.brainrender_helper import *

import numpy as np
import matplotlib.pyplot as plt
import hdbscan
from sklearn.preprocessing import StandardScaler
import matplotlib.cm as cm
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from IPython.display import HTML

import warnings
# Silence every Python warning for the rest of the session. This also hides numerical
# warnings (e.g. from scipy/sklearn), so remove it while debugging.
warnings.filterwarnings('ignore')
# Automatically reload edited modules (e.g. spatial_manifolds.*) before each cell runs.
%load_ext autoreload
%autoreload 2
# Interactive matplotlib backend (requires the ipympl package).
%matplotlib ipympl

In [None]:
# Output directory for figure files. It is built from the home directory instead of a
# hard-coded '/Users/<name>/' prefix so the notebook also runs on other machines.
# When the home directory is /Users/harryclark this equals the previous value
# '/Users/harryclark/Documents/figs/FIGURE1/'. It is kept as a str ending in '/'
# because it is passed as `figpath` to the plotting helpers.
fig_path = str(Path.home() / 'Documents' / 'figs' / 'FIGURE1') + '/'

# Session to analyse.
mouse = 29
day = 23

# Other good example sessions (presumably mice[i] pairs with days[i]; the current
# session 29/23 appears at index 4):
#mice = [25, 25, 26, 27, 29, 28]
#days = [25, 24, 18, 26, 23, 25]

In [None]:
# Classify cells from open-field (OF1) activity and from VR activity.
# The last OF1 output is bound to `all_cells` (previously `all`) so the builtin all()
# is not shadowed for the rest of the notebook.
gcs, ngs, ns, sc, ngs_ns, all_cells = cell_classification_of1(mouse, day, percentile_threshold=99)  # subset
rc, rsc, vr_ns = cell_classification_vr(mouse, day)

# create grid modules using HDBSCAN
g_m_ids, g_m_cluster_ids = HDBSCAN_grid_modules(gcs, all_cells, mouse, day, min_cluster_size=3, cluster_selection_epsilon=3,
                                                figpath=fig_path, curate_with_vr=True, curate_with_brain_region=True)

# NOTE: sort only after this call. The module lists are passed together with g_m_ids
# here, presumably in matching order.
plot_grid_modules_rate_maps(gcs, g_m_ids, g_m_cluster_ids, mouse, day, figpath=fig_path)

# We now have cluster ids classified into modules, non-grid spatial cells and
# non-spatial cells, as defined by activity in the open field.
# Modules are ordered largest first, so cluster_ids_by_group[0] is the biggest module.
g_m_cluster_ids = sorted(g_m_cluster_ids, key=len, reverse=True)
cluster_ids_by_group = [
    *g_m_cluster_ids,                # grid cells by module [0, 1, 2, ...]
    ngs.cluster_id.values.tolist(),  # non grid spatial [-4]
    ns.cluster_id.values.tolist(),   # non spatial cells [-3]
    gcs.cluster_id.values.tolist(),  # all grid cells [-2]
    sc.cluster_id.values.tolist(),   # speed cells [-1]
]

# One VR rate-map figure per group; the label is the group's index in cluster_ids_by_group.
for m, cluster_ids in enumerate(cluster_ids_by_group):
    plot_vr_rate_maps(mouse, day, cluster_ids, label=f'{m}', figpath=fig_path)

#plot_vr_rate_maps(mouse, day, rc.cluster_id.values, label=f'ramp_cells', figpath=fig_path)
#plot_vr_rate_maps(mouse, day, rsc.cluster_id.values, label=f'speed_ramp_cells', figpath=fig_path)

In [None]:
# Stop plot for the selected mouse/day session (helper from a star import; fig_path is
# passed as figpath, presumably where the figure is saved — confirm).
plot_stops_mouse_day(mouse, day, figpath=fig_path)

In [None]:
def plot_outbound_homebound_similarity(mouse, day, cluster_id_groups=None,
                                       cluster_id_labels=('GC', 'NGS', 'NS'),
                                       figpath=''):
    """Scatter outbound vs homebound tuning-curve similarity between two trial groups.

    For every cluster in each group, the VR spatial tuning curve is smoothed,
    truncated to bins that contain ephys data, z-scored, and averaged separately
    over 'rz1bhit' and 'rz1nbhit' trials. The Pearson correlation between the two
    average profiles is computed over the first half of the profile ("outbound")
    and over the second half ("homebound"). Each cluster is drawn as one point.

    Parameters
    ----------
    mouse, day :
        Session identifiers forwarded to ``compute_vr_tcs``.
    cluster_id_groups : sequence of iterables of cluster ids, optional
        One iterable per population. ``None`` (default) uses the notebook's current
        ``cluster_ids_by_group[-2]``, ``cluster_ids_by_group[-4]`` and
        ``vr_ns.cluster_id.values``, looked up when the function is called.
    cluster_id_labels : sequence of str
        Legend label for each group, matched to ``cluster_id_groups`` by position.
    figpath : str or path-like
        Directory for the output PDF; '' means the current working directory.

    Notes
    -----
    Relies on notebook / star-import globals ``compute_vr_tcs``,
    ``get_trial_groups_and_colors``, ``get_avg_profile``, ``zscore``, ``tl`` and
    ``bs``. ``tl`` and ``bs`` are presumably track length and spatial bin size
    (confirm).
    """
    # Import locally so this function does not depend on which star import happens
    # to provide these names; the body used them unqualified.
    from scipy import stats
    from scipy.ndimage import gaussian_filter

    if cluster_id_groups is None:
        # Resolved at call time. A default written in the signature would be frozen to
        # whatever these globals held when the def cell ran, and would raise NameError
        # on a fresh kernel if the def ran first.
        cluster_id_groups = [cluster_ids_by_group[-2],
                             cluster_ids_by_group[-4],
                             vr_ns.cluster_id.values]

    tcs, _tcs_time, _autocorrs, last_ephys_bin, beh, _clusters = compute_vr_tcs(mouse, day, apply_zscore=False)
    trial_groups, _trial_colors = get_trial_groups_and_colors(beh, last_ephys_bin, tl, bs)
    half_track_bin = int(0.5 * tl / bs)  # index splitting the outbound and homebound halves

    fig, ax = plt.subplots(1, 1, figsize=(2.3, 2.3))
    colors = ['black', '#707B8F', '#A7A388']
    for i, (cluster_ids, label) in enumerate(zip(cluster_id_groups, cluster_id_labels)):
        outbound_similarities = []
        homebound_similarities = []
        for cluster_id in cluster_ids:
            tc = gaussian_filter(np.nan_to_num(tcs[cluster_id]).astype(np.float64), sigma=2.5)
            tc = tc[:last_ephys_bin]  # only keep bins with ephys data in them
            tc = zscore(tc)

            _, b_y = get_avg_profile(tc, bs, tl, mask=trial_groups == 'rz1bhit')
            _, nb_y = get_avg_profile(tc, bs, tl, mask=trial_groups == 'rz1nbhit')

            outbound_similarities.append(stats.pearsonr(b_y[:half_track_bin], nb_y[:half_track_bin])[0])
            homebound_similarities.append(stats.pearsonr(b_y[half_track_bin:], nb_y[half_track_bin:])[0])
        # Cycle the palette so passing more than three groups does not raise IndexError.
        ax.scatter(outbound_similarities, homebound_similarities,
                   color=colors[i % len(colors)], label=label, alpha=0.7)

    ax.set_xlabel('Outbound similarity')
    ax.set_ylabel('Homebound similarity')
    # Identity line across the full [-1, 1] range (np.arange(-1, 1, 0.1) stopped at 0.9).
    ax.plot([-1, 1], [-1, 1], linestyle='dashed', color='black')
    ax.set_ylim(-1, 1)
    ax.set_xlim(-1, 1)
    ax.set_xticks([-1, 0, 1])
    ax.set_yticks([-1, 0, 1])

    fig.legend()
    # Join with Path: with figpath='' the old f'{figpath}/...' form wrote to the filesystem root.
    out_file = Path(figpath) / f'M{mouse}D{day}_similarity_outbound_homebound.pdf'
    fig.savefig(out_file, dpi=300, bbox_inches='tight')
    plt.show()
    
# Compare b vs nb trial similarity for all grid cells, non-grid spatial cells
# and VR non-spatial cells.
plot_outbound_homebound_similarity(
    mouse,
    day,
    cluster_id_groups=[
        cluster_ids_by_group[-2],   # all grid cells
        cluster_ids_by_group[-4],   # non grid spatial cells
        vr_ns.cluster_id.values,    # VR non-spatial cells
    ],
    cluster_id_labels=['GC', 'NGS', 'NS'],
    figpath=fig_path,
)

In [None]:
# NOTE(review): exploratory CEBRA embedding analysis, disabled by wrapping it in a string
# literal. The string is the cell's last expression, so Jupyter echoes its whole text as
# cell output; commenting the code out (or deleting the cell) would avoid that.
# Issues to fix before re-enabling:
#  - `train_mask` is assigned twice. The second assignment (position > 0) overwrites the
#    'rz1bhit' selection, so training is not restricted to 'rz1bhit' bins and can overlap
#    the 'rz1nbhit' test bins.
#  - np.stack over numeric and string columns produces a string array. That is why
#    columns are cast back with .astype(float), and the cast of column 3 (trial type)
#    only succeeds if every type was 'b' or 'nb'.
#  - `time_in_time` uses np.arange with a float step, which can yield one element more or
#    fewer than len(pos_in_time) and make np.stack fail.
#  - `time_bs`, `tl`, `bs` and `gaussian_filter` must come from elsewhere (star imports /
#    globals) — confirm; `type` shadows the builtin.
'''from sklearn.preprocessing import StandardScaler
import umap
from cebra import CEBRA
import cebra.integrations.plotly
from spatial_manifolds.cebra_helper import encode_1d_to_2d, plot_embeddings
from sklearn.neighbors import KNeighborsRegressor, KNeighborsClassifier
cluster_id_groups=[cluster_ids_by_group[-2], cluster_ids_by_group[-4], vr_ns.cluster_id.values]
cluster_id_labels=['GC', 'NGS', 'NS']

tcs, tcs_time, autocorrs, last_ephys_bin, beh, clusters = compute_vr_tcs(mouse, day, apply_zscore=False)           
trial_groups, trial_colors = get_trial_groups_and_colors(beh, last_ephys_bin, tl, bs)
sorted_trial_indices, sorted_trial_colors = get_sorted_trials_and_colors(beh, last_ephys_bin, tl, bs)
last_ephys_time_bin = clusters[clusters.index[0]].count(bin_size=time_bs, time_units = 'ms').index[-1]
ep = nap.IntervalSet(start=0, end=last_ephys_time_bin, time_units = 's')
speed_in_time = beh['S'].bin_average(bin_size=time_bs, time_units = 'ms', ep=ep)
dt_in_time = beh['travel'].bin_average(bin_size=time_bs, time_units = 'ms', ep=ep)-((beh['trial_number'][0]-1)*tl)
pos_in_time = dt_in_time%tl
time_in_time = np.arange(((time_bs/1000)/2), len(pos_in_time)*(time_bs/1000)+((time_bs/1000)/2), time_bs/1000)
trial_number_in_time = (dt_in_time//tl)+beh['trial_number'][0]
cycl_pos_in_time = encode_1d_to_2d(positions=pos_in_time)
cos_pos_in_time = cycl_pos_in_time[:,0]
sin_pos_in_time = cycl_pos_in_time[:,1]
trial_type_in_time = []
trial_group_in_time = []
for tn in trial_number_in_time:
    trial = beh['trials'][beh['trials']['number'] == tn]
    type = trial['type'].iloc[0]
    performance = trial['performance'].iloc[0]
    context = trial['context'].iloc[0]
    group = f'{context}{type}{performance}'
    trial_type_in_time.append(type)
    trial_group_in_time.append(group)
trial_type_in_time = np.array(trial_type_in_time)
trial_group_in_time = np.array(trial_group_in_time)

# cast b and nb to 0 and 1
trial_type_in_time[trial_type_in_time == 'b'] = '0'
trial_type_in_time[trial_type_in_time == 'nb'] = '1'

all_behaviour = np.stack([pos_in_time, 
                          cos_pos_in_time, 
                          sin_pos_in_time, 
                          trial_type_in_time, 
                          trial_group_in_time, 
                          trial_number_in_time, 
                          time_in_time], axis=0)
all_behaviour = np.transpose(all_behaviour)

train_mask = all_behaviour[:, 4] == 'rz1bhit'
train_mask = all_behaviour[:, 0].astype(np.float64) > 0

test_mask = all_behaviour[:, 4] == 'rz1nbhit'
label_train = all_behaviour[train_mask]
label_test = all_behaviour[test_mask]

for i, (cluster_ids, label) in enumerate(zip(cluster_id_groups, cluster_id_labels)):
    tcs_array = []
    for id in cluster_ids:
        tc = tcs_time[id]
        tc = gaussian_filter(np.nan_to_num(tc).astype(np.float64), sigma=2.5)
        tc = zscore(tc)
        tcs_array.append(tc)
    tcs_array = np.array(tcs_array)

    # Transpose to shape (N, M) for UMAP
    data_transposed = tcs_array.T
    
    max_i = 5000
    dims = 32  # here, we set as a variable for hypothesis testing below.

    # build behaviour models
    position_model = CEBRA(model_architecture='offset10-model',
                           batch_size=512, 
                           learning_rate=3e-4,
                           temperature=10, 
                           output_dimension=dims, 
                           max_iterations=max_i,
                           distance='cosine', 
                           conditional='time_delta',  
                           device='cuda_if_available',
                           verbose=True, time_offsets=1)
    
    position_trial_type_model = CEBRA(model_architecture='offset10-model',
                                batch_size=512, 
                                learning_rate=3e-4,
                                temperature=10, 
                                output_dimension=dims, 
                                max_iterations=max_i,
                                distance='cosine', 
                                conditional='time_delta',  
                                device='cuda_if_available',
                                verbose=True, time_offsets=1)
    
    neural_train = data_transposed[train_mask]
    neural_test = data_transposed[test_mask]

    # train models
    position_model.fit(neural_train, label_train[:, 1:3].astype(float))
    position_trial_type_model.fit(neural_train, label_train[:, 1:4].astype(float))

    fig = plt.figure(figsize=(4, 4))
    ax = plt.subplot(111)
    ax.plot(position_model.state_dict_['loss'], c='red', alpha=0.3, label='position')
    ax.plot(position_trial_type_model.state_dict_['loss'], c='blue', alpha=0.3, label='position+type')
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_xlabel('Iterations')
    ax.set_ylabel('InfoNCE Loss')
    plt.legend(bbox_to_anchor=(0.5, 0.3), frameon=False)
    plt.show()

    #plot embeddings
    position_embedding = position_model.transform(neural_train)
    position_type_embedding = position_trial_type_model.transform(neural_train)
    position_embedding_test = position_model.transform(neural_test)
    position_type_embedding_test = position_trial_type_model.transform(neural_test)

    fig = plt.figure(figsize=(8, 8))
    ax1 = fig.add_subplot(2, 2, 1, projection="3d")
    ax2 = fig.add_subplot(2, 2, 2, projection="3d")
    ax3 = fig.add_subplot(2, 2, 3, projection="3d")
    ax4 = fig.add_subplot(2, 2, 4, projection="3d")
    ax1 = plot_embeddings(ax1, position_embedding, pos_in_time[train_mask], cmap="track", viewing_angle=2)
    ax2 = plot_embeddings(ax2, position_type_embedding, pos_in_time[train_mask], cmap="track", viewing_angle=2)
    ax3 = plot_embeddings(ax3, position_embedding_test, pos_in_time[test_mask], cmap="track", viewing_angle=2)
    ax4 = plot_embeddings(ax4, position_type_embedding_test, pos_in_time[test_mask], cmap="track", viewing_angle=2)
    ax1.set_title('Position')
    ax2.set_title('Position+Type')
    plt.show()
'''

In [None]:
# NOTE(review): exploratory UMAP analysis, disabled by wrapping it in a string literal.
# The string is the cell's last expression, so Jupyter echoes its whole text as output.
# Issues to fix before re-enabling:
#  - The trial type/group loop appears twice verbatim; the second copy just recomputes
#    the same two arrays.
#  - UMAP is fit with n_components=2. The "3D" scatter uses position as its third axis,
#    so the z-axis label 'UMAP 3' (and the "3D UMAP" wording) is misleading.
#  - `time_bs`, `tl`, `bs` and `gaussian_filter` must come from elsewhere — confirm;
#    `type` shadows the builtin.
'''cluster_id_groups=[cluster_ids_by_group[-2], cluster_ids_by_group[-4], vr_ns.cluster_id.values]
cluster_id_labels=['GC', 'NGS', 'NS']

tcs, tcs_time, autocorrs, last_ephys_bin, beh, clusters = compute_vr_tcs(mouse, day, apply_zscore=False)           
trial_groups, trial_colors = get_trial_groups_and_colors(beh, last_ephys_bin, tl, bs)
sorted_trial_indices, sorted_trial_colors = get_sorted_trials_and_colors(beh, last_ephys_bin, tl, bs)
last_ephys_time_bin = clusters[clusters.index[0]].count(bin_size=time_bs, time_units = 'ms').index[-1]
ep = nap.IntervalSet(start=0, end=last_ephys_time_bin, time_units = 's')
speed_in_time = beh['S'].bin_average(bin_size=time_bs, time_units = 'ms', ep=ep)
dt_in_time = beh['travel'].bin_average(bin_size=time_bs, time_units = 'ms', ep=ep)-((beh['trial_number'][0]-1)*tl)
pos_in_time = dt_in_time%tl
trial_number_in_time = (dt_in_time//tl)+beh['trial_number'][0]
trial_type_in_time = []
trial_group_in_time = []
for tn in trial_number_in_time:
    trial = beh['trials'][beh['trials']['number'] == tn]
    type = trial['type'].iloc[0]
    performance = trial['performance'].iloc[0]
    context = trial['context'].iloc[0]
    group = f'{context}{type}{performance}'
    trial_type_in_time.append(type)
    trial_group_in_time.append(group)
trial_type_in_time = np.array(trial_type_in_time)
trial_group_in_time = np.array(trial_group_in_time)

trial_type_in_time = []
trial_group_in_time = []
for tn in trial_number_in_time:
    trial = beh['trials'][beh['trials']['number'] == tn]
    type = trial['type'].iloc[0]
    performance = trial['performance'].iloc[0]
    context = trial['context'].iloc[0]
    group = f'{context}{type}{performance}'
    trial_type_in_time.append(type)
    trial_group_in_time.append(group)
trial_type_in_time = np.array(trial_type_in_time)
trial_group_in_time = np.array(trial_group_in_time)

moving_mask = speed_in_time > 0
trial_group_in_time = trial_group_in_time[moving_mask]
trial_type_in_time = trial_type_in_time[moving_mask]
pos_in_time = pos_in_time[moving_mask]

b_mask = trial_group_in_time == 'rz1bhit'
nb_mask = trial_group_in_time == 'rz1nbhit'

for i, (cluster_ids, label) in enumerate(zip(cluster_id_groups, cluster_id_labels)):
    tcs_array = []
    for id in cluster_ids:
        tc = tcs_time[id]
        tc = gaussian_filter(np.nan_to_num(tc).astype(np.float64), sigma=2.5)
        tc = zscore(tc)
        tcs_array.append(tc)
    tcs_array = np.array(tcs_array)

    # Transpose to shape (N, M) for UMAP
    data_transposed = tcs_array.T
    data_transposed = data_transposed[moving_mask]

    scaler = StandardScaler()
    data_normalized = scaler.fit_transform(data_transposed)

    reducer = umap.UMAP(n_components=2, random_state=7)
    embedding = reducer.fit_transform(data_normalized)

    # Plot the 3D embedding
    fig = plt.figure(figsize=(6, 6))
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(embedding[:, 0][b_mask], embedding[:, 1][b_mask], pos_in_time[b_mask], alpha=0.05, color='tab:blue', s=5)
    ax.scatter(embedding[:, 0][nb_mask], embedding[:, 1][nb_mask], pos_in_time[nb_mask], alpha=0.05, color='tab:orange', s=5)
    ax.set_title('3D UMAP Projection of Neural Data')
    ax.set_xlabel('UMAP 1')
    ax.set_ylabel('UMAP 2')
    ax.set_zlabel('UMAP 3')
    ax.view_init(elev=30, azim=45)
    plt.show()'''

In [None]:
# NOTE(review): disabled exploratory code kept as a string literal. The string is echoed
# as cell output because it is the cell's last expression.
# Issues to fix before re-enabling:
#  - `angles1` and `angles2` are never defined in this notebook. `trial_groups` is only
#    defined inside a function and inside other disabled strings, so this fails on a
#    fresh kernel.
#  - The 'rz1bhit' and 'rz1nbhit' loops are copies differing only in group and colour;
#    a helper would remove the duplication.
#  - `fig, ax` is created but plotting goes through the plt state machine.
'''fig, ax = plt.subplots(1,1, figsize=(3,3))

for i in range(len(angles1[trial_groups == 'rz1bhit'])):
    angles_1_i = angles1[trial_groups == 'rz1bhit'][i]
    angles_2_i = angles2[trial_groups == 'rz1bhit'][i]
    diff_angles1 = np.append(np.array([0]), np.diff(angles_1_i))
    diff_angles2 = np.append(np.array([0]), np.diff(angles_2_i))

    # wrapping around correction
    diff_angles1[diff_angles1>(0.66*2*np.pi)] = diff_angles1[diff_angles1>(0.66*2*np.pi)]-(2*np.pi)
    diff_angles1[diff_angles1<(-0.66*2*np.pi)] = diff_angles1[diff_angles1<(-0.66*2*np.pi)]+(2*np.pi)

    diff_angles2[diff_angles2>(0.66*2*np.pi)] = diff_angles2[diff_angles2>(0.66*2*np.pi)]-(2*np.pi)
    diff_angles2[diff_angles2<(-0.66*2*np.pi)] = diff_angles2[diff_angles2<(-0.66*2*np.pi)]+(2*np.pi)
    
    cum_angles1 = np.cumsum(diff_angles1)
    cum_angles2 = np.cumsum(diff_angles2)
    plt.plot(cum_angles1, cum_angles2, color="tab:blue", alpha=0.3)
plt.show()

fig, ax = plt.subplots(1,1, figsize=(3,3))
for i in range(len(angles1[trial_groups == 'rz1nbhit'])):
    angles_1_i = angles1[trial_groups == 'rz1nbhit'][i]
    angles_2_i = angles2[trial_groups == 'rz1nbhit'][i]
    diff_angles1 = np.append(np.array([0]), np.diff(angles_1_i))
    diff_angles2 = np.append(np.array([0]), np.diff(angles_2_i))

    # wrapping around correction
    diff_angles1[diff_angles1>(0.66*2*np.pi)] = diff_angles1[diff_angles1>(0.66*2*np.pi)]-(2*np.pi)
    diff_angles1[diff_angles1<(-0.66*2*np.pi)] = diff_angles1[diff_angles1<(-0.66*2*np.pi)]+(2*np.pi)

    diff_angles2[diff_angles2>(0.66*2*np.pi)] = diff_angles2[diff_angles2>(0.66*2*np.pi)]-(2*np.pi)
    diff_angles2[diff_angles2<(-0.66*2*np.pi)] = diff_angles2[diff_angles2<(-0.66*2*np.pi)]+(2*np.pi)
    
    cum_angles1 = np.cumsum(diff_angles1)
    cum_angles2 = np.cumsum(diff_angles2)
    plt.plot(cum_angles1, cum_angles2, color="tab:orange",alpha=0.3)
plt.show()'''

In [None]:
# NOTE(review): disabled exploratory code kept as a string literal (echoed as cell output).
# Before re-enabling: `angles1`, `angles2` and top-level `trial_groups` are not defined in
# the active code of this notebook, so this fails on a fresh kernel.
'''fig, ax = plt.subplots(1,1, figsize=(3,3))

for i in range(len(angles1[trial_groups == 'rz1bhit'])):
    angles_1_i = angles1[trial_groups == 'rz1bhit'][i]+np.pi
    angles_2_i = angles2[trial_groups == 'rz1bhit'][i]+np.pi
    plt.plot(angles_1_i, angles_2_i, color="tab:blue", alpha=0.3)

for i in range(len(angles1[trial_groups == 'rz1nbhit'])):
    angles_1_i = angles1[trial_groups == 'rz1nbhit'][i]+np.pi
    angles_2_i = angles2[trial_groups == 'rz1nbhit'][i]+np.pi
    plt.plot(angles_1_i, angles_2_i, color="tab:orange", alpha=0.3)
plt.show()
'''

In [None]:
# NOTE(review): disabled exploratory code kept as a string literal (echoed as cell output).
# Issues to fix before re-enabling:
#  - `cum_angles2` is computed from np.diff(angles1), identical to `cum_angles1`;
#    presumably meant angles2.
#  - Despite the names, cum_angles* hold per-step differences; the cumulative sum is only
#    taken when plotting.
#  - `angles1` / `angles2` are not defined anywhere in this notebook's active code.
'''cum_angles1 = np.append(np.array([0]), np.diff(angles1))
cum_angles2 = np.append(np.array([0]), np.diff(angles1))

plt.hist(cum_angles1,bins=100)
plt.show()

plt.hist(cum_angles2,bins=100)
plt.show()

# wrapping around correction
cum_angles1[cum_angles1>(0.66*2*np.pi)] = cum_angles1[cum_angles1>(0.66*2*np.pi)]-(2*np.pi)
cum_angles1[cum_angles1<(-0.66*2*np.pi)] = cum_angles1[cum_angles1<(-0.66*2*np.pi)]+(2*np.pi)

cum_angles2[cum_angles2>(0.66*2*np.pi)] = cum_angles2[cum_angles2>(0.66*2*np.pi)]-(2*np.pi)
cum_angles2[cum_angles2<(-0.66*2*np.pi)] = cum_angles2[cum_angles2<(-0.66*2*np.pi)]+(2*np.pi)
 
plt.plot(np.cumsum(cum_angles1))
plt.show()'''

In [None]:
# Spectrogram for the largest grid module (cluster_ids_by_group[0]; modules are sorted by size).
# NOTE(review): labelled "GC" although this is module 0 only. Elsewhere "GC" is used for all
# grid cells (cluster_ids_by_group[-2]) — confirm the intended label.
plot_spectrogram(mouse, day, cluster_ids=cluster_ids_by_group[0], figpath=fig_path, label="GC")

In [None]:
# Toroidal projection for the largest grid module (cluster_ids_by_group[0]).
plot_toroidal_projection(mouse, day, cluster_ids=cluster_ids_by_group[0], figpath=fig_path)

In [None]:
# Individual rate maps plus average for the largest grid module (cluster_ids_by_group[0]).
# NOTE(review): label 'GC' here means module 0, not all grid cells — confirm.
plot_individual_rate_maps_with_avg(mouse, day, cluster_ids=cluster_ids_by_group[0], label='GC', figpath=fig_path)

In [None]:
# Individual rate maps plus average for non-grid spatial cells (cluster_ids_by_group[-4]).
plot_individual_rate_maps_with_avg(mouse, day, cluster_ids=cluster_ids_by_group[-4], label='NGS', figpath=fig_path)

In [None]:
# Non-grid spatial (NGS) cells, cluster_ids_by_group[-4]: projected stops, then decoding.
ngs_ids = cluster_ids_by_group[-4]
plot_projected_stops(mouse, day, cluster_ids=ngs_ids, label="NGS", figpath=fig_path)
plot_decoding(mouse, day, cluster_ids=ngs_ids, label="NGS", figpath=fig_path)

In [None]:
# All grid cells (GC), cluster_ids_by_group[-2]: projected stops, then decoding.
gc_ids = cluster_ids_by_group[-2]
plot_projected_stops(mouse, day, cluster_ids=gc_ids, label="GC", figpath=fig_path)
plot_decoding(mouse, day, cluster_ids=gc_ids, label="GC", figpath=fig_path)

In [None]:
#plot_decoding(mouse, day, cluster_ids=np.intersect1d(cluster_ids_by_group[-4], rc.cluster_id.values), label="NGS_ramp_cells", figpath=fig_path)

In [None]:
#plot_decoding(mouse, day, cluster_ids=rc.cluster_id.values, label="ramp_cells", figpath=fig_path)

In [None]:
#compare_decodings(mouse, day, cluster_ids_1=cluster_ids_by_group[-2], 
#                  cluster_ids_2=np.intersect1d(cluster_ids_by_group[-4], rc.cluster_id.values), label1='GC', label2='RC_NGS', figpath=fig_path)

In [None]:
#compare_decodings(mouse, day, cluster_ids_1=cluster_ids_by_group[-2], 
#                  cluster_ids_2=cluster_ids_by_group[-4], label1='GC', label2='NGS', figpath=fig_path)