# In [11]:
import os, re
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression,LogisticRegression
import sklearn.metrics
from scipy.linalg import orthogonal_procrustes
from itertools import permutations, combinations
from sklearn.decomposition import PCA
# ---- Analysis constants read as globals by the decoding functions below ----
# Samples per trial: embeddings are reshaped to (-1, dur, emb_dim) and sample
# counts are divided by dur to obtain trial counts.
dur = 40
# Width of each embedding sample; also the size of the square alignment matrices.
emb_dim = 3
# Number of direction labels; cross_decode loops over labels 0 .. N_angles-1.
N_angles = 8

# ---- Dataset selection ----
# Exactly one (directory, file_save, name_range) triple is active at a time;
# the commented-out triples are alternative datasets.
#   directory  : folder whose files are listed and loaded with np.load
#   file_save  : output path handed to np.savez at the end of the script
#   name_range : character slice applied to each file *path* (directory
#                included) to extract the label stored in date_subjects
#                (presumably the 8-digit recording date -- confirm when the
#                directory string changes length)

# directory = './data/Fig2_SU/nmr/emb_M1_lr001_itr5k_temp07/'
# file_save = './data/Fig3_SU_Decode/nmr_M1.npz'
# name_range = slice(33, 41)

# Active selection.
# NOTE(review): the output name uses "Chewie" while the alternatives are named
# after "M1"/"PMd" -- confirm this is the intended file name.
directory = './data_NMR/Fig2_SU/nmr/emb_M1_lr001_itr5k_temp07/'
file_save = './data_NMR/Fig3_SU_Decode/nmr_Chewie.npz'
name_range = slice(56, 64)

# directory = './data_NMR/Fig2_SU/ner/emb_M1_lr0.0001_itr5k_temp1/'
# file_save = './data_NMR/Fig3_SU_Decode/NER_M1.npz'
# name_range = slice(33, 41)

# directory = './data/Fig2_SU/ceb/emb_M1_10k/'
# file_save = './data/Fig3_SU_Decode/ceb_M1.npz'
# name_range = slice(33, 41)

# directory = './data/Fig2_SU/piv/emb_M1_run3/'
# file_save = './data/Fig3_SU_Decode/piv_M1.npz'
# name_range = slice(38, 46)

# directory = './data/Fig2_SU/nmr/emb_PMd/'
# file_save = './data/Fig3_SU_Decode/nmr_PMd.npz'
# name_range = slice(34, 42)

# directory = './data_NMR/Fig2_SU/ner/emb_PMd_lr0.0001_itr5k_temp1/'
# file_save = './data_NMR/Fig3_SU_Decode/NER_PMd.npz'
# name_range = slice(34, 42)

# directory = './data/Fig2_SU/ceb/emb_PMd/'
# file_save = './data/Fig3_SU_Decode/ceb_PMd.npz'
# name_range = slice(34, 42)

# directory = './data/Fig2_SU/piv/emb_PMd_itr60/'
# file_save = './data/Fig3_SU_Decode/piv_PMd.npz'
# name_range = slice(40, 48)

def get_best_R(R_all, emb_A, emb_A_8angle_align):
    """Pick one alignment matrix out of ``R_all`` to apply to the whole embedding.

    The candidates are split by the sign of their determinant (``>= 0`` versus
    ``< 0``).  Within each group the candidate whose ``|det|`` is closest to 1
    is kept (first one on ties).  The two survivors are then compared by the
    summed absolute difference between ``emb_A @ R`` and
    ``emb_A_8angle_align``; the one with the smaller difference is returned.

    Parameters
    ----------
    R_all : ndarray, shape (d, d, k)
        Stack of ``k`` candidate square matrices along the last axis.
    emb_A : ndarray, shape (n, d)
        Embedding to be aligned.
    emb_A_8angle_align : ndarray, shape (n, d)
        Reference alignment of ``emb_A`` that the chosen matrix should
        reproduce as closely as possible.

    Returns
    -------
    ndarray, shape (d, d)
        The selected slice of ``R_all``.  If both groups give exactly the same
        difference, the non-negative-determinant candidate is returned.

    Raises
    ------
    ValueError
        If ``R_all`` holds no candidate with a comparable (non-NaN)
        determinant.
    """
    determinants = [np.linalg.det(R_all[:, :, k]) for k in range(R_all.shape[2])]

    def _best_in_group(indices):
        # An empty group can never win: its error is +inf rather than a magic
        # number that a real alignment error could exceed.
        if not indices:
            return None, np.inf
        # min() keeps the first minimum, i.e. the lowest candidate index.
        chosen = min(indices, key=lambda k: abs(abs(determinants[k]) - 1))
        R = R_all[:, :, chosen]
        error = np.sum(np.abs(np.matmul(emb_A, R) - emb_A_8angle_align))
        return R, error

    R_pos, diff_pos = _best_in_group(
        [k for k, det in enumerate(determinants) if det >= 0])
    R_neg, diff_neg = _best_in_group(
        [k for k, det in enumerate(determinants) if det < 0])

    if R_pos is None and R_neg is None:
        raise ValueError("R_all contains no candidate matrix with a valid determinant")

    # Ties go to the non-negative-determinant candidate so that a result is
    # always returned.
    if R_neg is None or (R_pos is not None and diff_pos <= diff_neg):
        return R_pos
    return R_neg


def cross_decode(file_path1, file_path2):
    """Fit linear decoders on session A's aligned embedding and score them on session B.

    Both ``.npz`` files must provide ``continuous_index_train``,
    ``continuous_index_test``, ``cebra_veldir_train`` and ``cebra_veldir_test``.
    Train and test parts are concatenated per session.  Column 2 of the index
    arrays is the direction label; when its maximum exceeds 10 it is taken to
    be in degrees (0, 45, ..., 315) and divided by 45.  Columns 0:2 are the
    regression targets.

    Steps:
      1. For every direction ``0 .. N_angles-1`` the trial-averaged embedding
         of A is aligned to that of B with an orthogonal Procrustes fit, and
         A's trials of that direction are transformed with the resulting
         matrix.
      2. ``get_best_R`` picks a single matrix from those fits; it is applied
         to the whole of A's embedding.
      3. A linear regression on the full embedding, and a PCA(2) followed by
         a linear regression, are fitted on A's training portion.
      4. Both decoders are scored (R^2) on B's unmodified test portion.

    Relies on the module-level ``dur``, ``emb_dim`` and ``N_angles``.

    Parameters
    ----------
    file_path1 : str
        ``.npz`` file of session A (decoder is trained on it).
    file_path2 : str
        ``.npz`` file of session B (decoder is evaluated on it).

    Returns
    -------
    tuple
        ``(vel_test_r2_3d, vel_test_r2_2d)``: R^2 on B's test portion for the
        full-embedding decoder and for the PCA(2) decoder.
    """
    def _load_session(path):
        # Read each array exactly once and close the archive right away;
        # indexing an open NpzFile re-reads the array from disk every time.
        with np.load(path) as session:
            labels_train = session['continuous_index_train'][:, :3]
            labels_test = session['continuous_index_test']
            emb_train = session['cebra_veldir_train']
            emb_test = session['cebra_veldir_test']
        labels = np.concatenate((labels_train, labels_test), axis=0)
        emb = np.concatenate((emb_train, emb_test), axis=0)
        if np.max(labels[:, 2]) > 10:  # angles given as 0-45-90-...-315 degrees
            labels[:, 2] = labels[:, 2] / 45
        n_train_trials = int(labels_train.shape[0] / dur)
        n_test_trials = int(labels_test.shape[0] / dur)
        return labels, emb, n_train_trials, n_test_trials

    def _last_rows(arr, count):
        # Explicit start index: arr[-count:] would return the WHOLE array when
        # count == 0 and silently mix training data into the test set.
        return arr[arr.shape[0] - count:, :]

    XYTarget_A, emb_A, train_trial_A, _ = _load_session(file_path1)
    XYTarget_B, emb_B, _, test_trial_B = _load_session(file_path2)

    # Per-direction Procrustes fit (A -> B) and per-direction alignment of A.
    R_all = np.zeros((emb_dim, emb_dim, N_angles))
    # NOTE(review): rows of A whose label is not one of 0..N_angles-1 are never
    # written and keep whatever np.empty_like left there -- confirm that every
    # sample carries a valid direction label.
    emb_A_8angle_align = np.empty_like(emb_A)
    for a in range(N_angles):
        mask_A = (XYTarget_A[:, 2] == a)
        mask_B = (XYTarget_B[:, 2] == a)
        trials_A = emb_A[mask_A, :].reshape(-1, dur, emb_dim)
        trials_B = emb_B[mask_B, :].reshape(-1, dur, emb_dim)
        # Both trial averages are (dur, emb_dim).
        R, _ = orthogonal_procrustes(trials_A.mean(axis=0), trials_B.mean(axis=0))
        R_all[:, :, a] = R
        aligned = np.matmul(trials_A, R_all[:, :, a])
        emb_A_8angle_align[mask_A, :] = aligned.reshape(-1, emb_dim)

    # One matrix for the whole session, chosen against the per-direction result.
    emb_A_whole_align = np.matmul(emb_A, get_best_R(R_all, emb_A, emb_A_8angle_align))

    n_train = train_trial_A * dur
    X_train = emb_A_whole_align[:n_train, :]
    y_train = XYTarget_A[:n_train, 0:2]

    n_test = test_trial_B * dur
    X_test = _last_rows(emb_B, n_test)
    y_test = _last_rows(XYTarget_B, n_test)[:, 0:2]

    # Decoders are fitted on A only ...
    reg_3d = LinearRegression().fit(X_train, y_train)
    pca_2d = PCA(n_components=2).fit(X_train)
    reg_2d = LinearRegression().fit(pca_2d.transform(X_train), y_train)

    # ... and evaluated on B without refitting.
    vel_test_r2_3d = sklearn.metrics.r2_score(y_test, reg_3d.predict(X_test))
    vel_test_r2_2d = sklearn.metrics.r2_score(
        y_test, reg_2d.predict(pca_2d.transform(X_test)))
    return vel_test_r2_3d, vel_test_r2_2d

def self_decode(file_path1):
    """Fit linear decoders on a session's training split and score them on its test split.

    The ``.npz`` file must provide ``cebra_veldir_train`` / ``cebra_veldir_test``
    (decoder inputs) and ``continuous_index_train`` / ``continuous_index_test``
    (columns 0:2 are the regression targets).  Two decoders are fitted on the
    training split: a linear regression on the full embedding, and a PCA(2)
    projection followed by a linear regression.

    Parameters
    ----------
    file_path1 : str
        Path of the session's ``.npz`` file.

    Returns
    -------
    tuple
        ``(vel_test_r2_3d, vel_test_r2_2d)``: test-split R^2 of the
        full-embedding decoder and of the PCA(2) decoder.
    """
    # Pull everything out of the archive up front so the file handle is closed
    # before any fitting starts.
    with np.load(file_path1) as session:
        X_train = session['cebra_veldir_train']
        y_train = session['continuous_index_train'][:, 0:2]
        X_test = session['cebra_veldir_test']
        y_test = session['continuous_index_test'][:, 0:2]

    # All three fits use the training split only.
    reg_3d = LinearRegression().fit(X_train, y_train)
    pca_2d = PCA(n_components=2).fit(X_train)
    reg_2d = LinearRegression().fit(pca_2d.transform(X_train), y_train)

    # Evaluation reuses the fitted reg_3d / pca_2d / reg_2d unchanged.
    vel_test_r2_3d = sklearn.metrics.r2_score(y_test, reg_3d.predict(X_test))
    vel_test_r2_2d = sklearn.metrics.r2_score(
        y_test, reg_2d.predict(pca_2d.transform(X_test)))
    return vel_test_r2_3d, vel_test_r2_2d
            
# Regular files found in the dataset directory; their count sizes the
# square result matrices (row = training session, column = test session).
files = list(filter(
    os.path.isfile,
    (os.path.join(directory, f) for f in os.listdir(directory)),
))
n = len(files)
vel_R_3D = np.zeros((n, n))        # R^2 of the full-embedding decoder
vel_R_2D = np.zeros_like(vel_R_3D)  # R^2 of the PCA(2) decoder
date_subjects = list()             # one label appended per (i, j) comparison
n_compare = 0                      # running count of comparisons performed

def list_and_sort_files(directory):
    """Return the regular files in *directory*, ordered by the date in their name.

    The date is the first run of 8 consecutive digits in the file's base name,
    compared as an integer.  Files without such a run sort first (date 0).
    Files sharing the same date are ordered by path, so the result does not
    depend on the order in which ``os.listdir`` happens to report entries.

    Parameters
    ----------
    directory : str
        Directory to scan (not recursive; sub-directories are skipped).

    Returns
    -------
    list of str
        ``os.path.join(directory, name)`` for every regular file, sorted.
    """
    def extract_date(filename):
        match = re.search(r'(\d{8})', os.path.basename(filename))
        return int(match.group(0)) if match else 0

    candidates = (os.path.join(directory, name) for name in os.listdir(directory))
    files = [path for path in candidates if os.path.isfile(path)]
    return sorted(files, key=lambda path: (extract_date(path), path))
# Sessions in chronological order; entry [i, j] of the result matrices is the
# R^2 obtained when decoding session j with a decoder trained on session i.
sorted_files = list_and_sort_files(directory)

for i, file1 in enumerate(sorted_files):
    for j, file2 in enumerate(sorted_files):
        if i == j:
            # Diagonal: train and test split of the same session.
            vel_test_3d, vel_test_2d = self_decode(file1)
        else:
            # Off-diagonal: train on file1, evaluate on file2.
            vel_test_3d, vel_test_2d = cross_decode(file1, file2)
        vel_R_3D[i, j], vel_R_2D[i, j] = vel_test_3d, vel_test_2d

        # The label always describes file1 and is recorded once per comparison.
        if "M1PMd" not in directory:
            date_subjects.append(file1[name_range])
        else:
            date = file1[-29:-23]
            suffix = file1[-7:-5]
            date_subjects.append(f"{date}{suffix}")
        n_compare += 1

print('label of date:', np.unique(date_subjects))

np.savez(file_save, date_subjects=date_subjects, vel_R_3D=vel_R_3D, vel_R_2D=vel_R_2D)

# Output:
# label of date: ['20150313' '20150319' '20150629' '20150630' '20160929' '20161005'
#  '20161006' '20161007' '20161014' '20161021']
