In [1]:
import sys
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import joblib as jl
import cebra.datasets
from cebra import CEBRA
import torch
import scipy.io as sio
from sklearn.decomposition import PCA
from sklearn.linear_model import LinearRegression
import sklearn.metrics
import time



In [2]:
dur = 40
iterations = 5*1000
batch_size = 512
learning_rate = 0.0001
output_dimension = 3
Temp_para = [1] 
name_para = './data_NMR/Fig2_SU/ner/emb_PMd_lr0.0001_itr5k_temp1'
def split_data(neural, continuous_index):
            L = neural.shape[0]
            split_idx = round(L*0.8) 
            neural_train = neural[:split_idx]
            neural_test = neural[split_idx:]
            continuous_index_train = continuous_index[:split_idx]
            continuous_index_test = continuous_index[split_idx:]
            return neural_train,neural_test,continuous_index_train,continuous_index_test
        
angle_to_new_value = {-180: 4,-135: 5,-90: 6,-45: 7,0: 0,45: 1,90: 2,135: 3,180: 4}

directory = "./data/SU_12PMd/"
files = os.listdir(directory)
for file in files:
#     if f_type in file:
    mat_contents = sio.loadmat(os.path.join(directory, file))
    filename_parts = file.split("_neural_con_dis_index")
    new_filename = filename_parts[0] + "_embed_"+str(iterations)+"itr_PMd.npz"
    file_save = os.path.join(directory, new_filename)
    print(file_save)

    neural = mat_contents['neural_PMd']
    continuous_index_XY = mat_contents['continuous_index']
    discrete_index = mat_contents['discrete_index'] ## angles range from -180deg to +180deg
    vectorized_map = np.vectorize(lambda x: angle_to_new_value[x])
    discrete_index = 45*vectorized_map(discrete_index)

    L = neural.shape[0]
    N_values_hist = round(L/5)
    random_indices = np.random.choice(L, size=N_values_hist, replace=False)
    indices_X = continuous_index_XY[random_indices, 0]
    indices_Y = continuous_index_XY[random_indices, 1]
    index_diffs_X = np.abs(indices_X[:, None] - indices_X[None, :]) 
    index_diffs_Y = np.abs(indices_Y[:, None] - indices_Y[None, :])
    l_dist_XY = index_diffs_X + index_diffs_Y
    l_dist_XY_1d = l_dist_XY[~np.eye(N_values_hist, dtype=bool)].flatten()
    print('XY distance>>', np.median(l_dist_XY_1d))
    angles = np.squeeze(discrete_index[random_indices])
    angle_diffs = np.abs(angles[:, None] - angles[None, :]) 
    circular_angle_diffs = np.minimum(angle_diffs, 360 - angle_diffs)
    l_dist_Z_1d = circular_angle_diffs[~np.eye(N_values_hist, dtype=bool)].flatten()
    print('Z distance>>', np.median(l_dist_Z_1d))

    XY_scale = 10
    continuous_index_XY = continuous_index_XY*XY_scale
    #####>>>>>>>>>>>>>>>>>>>>>>>> Only for NMR
#     XYZ_threshold = np.median(l_dist_XY_1d)*XY_scale + np.median(l_dist_Z_1d)
    #####>>>>>>>>>>>>>>>>>>>>>>>> Only for NMR

    for temp in range(len(Temp_para)): 

        continuous_index = np.column_stack((continuous_index_XY, discrete_index))

        neural_train, neural_test, continuous_index_train, \
                    continuous_index_test = split_data(neural, continuous_index)

        #####>>>>>>>>>>>>>>>>>>>>>>>>  Only for NMR
#         L_train = neural_train.shape[0]
#         conr_2para = np.full((L_train,), 0.001)
#         conr_2para[:2] = [XYZ_threshold, Temp_para[temp]]
#         continuous_index_train = np.column_stack((continuous_index_train, conr_2para))
        #####>>>>>>>>>>>>>>>>>>>>>>>> Only for NMR
         
        cebra_veldir_model = CEBRA(model_architecture='offset1-model',
                   batch_size = batch_size,
                   learning_rate = learning_rate,
                   temperature = Temp_para[temp],
                   output_dimension = output_dimension,
                   max_iterations=iterations,
                   distance='cosine',
                   conditional='time_delta',
                   verbose=True,
                   time_offsets=1)
        start_time = time.time()
        cebra_veldir_model.fit(neural_train, continuous_index_train)
        end_time = time.time()
        execution_time = np.round((end_time - start_time), 2)

        cebra_veldir_train = cebra_veldir_model.transform(neural_train)
        cebra_veldir_test  = cebra_veldir_model.transform(neural_test)

        train_loss = cebra_veldir_model.state_dict_['loss']
        X = cebra_veldir_train
        y = continuous_index_train[:,0:2]
        reg_3d = LinearRegression().fit(X, y)       #### 1st fit ####
        pred_vel = reg_3d.predict(X)
        vel_train_r2 = sklearn.metrics.r2_score(y, pred_vel)

        pca = PCA(n_components=2)
        pca_2d = pca.fit(X)                         #### 2nd fit ####
        X_2d = pca_2d.transform(X)
        reg_2d = LinearRegression().fit(X_2d, y)    #### 3rd fit ####
        pred_vel = reg_2d.predict(X_2d)
        vel_train_r2_pca = sklearn.metrics.r2_score(y, pred_vel)
        vel_train_r2_pca = np.round(vel_train_r2_pca, 4)

        print('80% Train Data Temp=', str(Temp_para[temp]), \
              ' r2-3D=', str(np.round(vel_train_r2, 3)), ' r2-2D=', str(vel_train_r2_pca))
        ###************* use previous trained "reg_3d & pca_2d & reg_2d" ###***************
        ###************* use previous trained "reg_3d & pca_2d & reg_2d" ###***************
        X = cebra_veldir_test
        y = continuous_index_test[:,0:2]
        pred_vel = reg_3d.predict(X)
        vel_test_r2 = sklearn.metrics.r2_score(y, pred_vel)

        X_2d = pca_2d.transform(X)
        pred_vel = reg_2d.predict(X_2d)
        vel_test_r2_pca = sklearn.metrics.r2_score(y, pred_vel)
        vel_test_r2_pca = np.round(vel_test_r2_pca, 4)

        print('20% Test  Data Temp=', str(Temp_para[temp]), \
              ' r2-3D=', str(np.round(vel_test_r2, 3)), ' r2-2D=', str(vel_test_r2_pca))

        new_filename = file[:19] + "_Temp_"+str(Temp_para[temp])+ \
            "_iterations_"+str(iterations)+ \
            "_80%train_"+str(vel_train_r2_pca)+ \
            "_20%test_"+str(vel_test_r2_pca)+".npz"
        file_save = os.path.join(name_para,new_filename)
        np.savez(file_save,
                 execution_time = execution_time,
                 temperature = Temp_para[temp],
                 iterations = iterations, 
                 train_loss = train_loss,
                 cebra_veldir_train=cebra_veldir_train,
                 cebra_veldir_test=cebra_veldir_test,
                 continuous_index_train=continuous_index_train,
                 continuous_index_test=continuous_index_test,
                 vel_train_r2 = vel_train_r2,
                 vel_test_r2 = vel_test_r2,
                 vel_train_r2_pca = vel_train_r2_pca,
                 vel_test_r2_pca = vel_test_r2_pca)

./data/SU_12PMd/Mihili_20140303_embed_5000itr_PMd.npz
XY distance>> 12.968032337715222
Z distance>> 90.0


pos: -0.8458 neg:  6.7929 total:  5.7611 temperature:  1.0000: 100%|███████████████| 5000/5000 [18:32<00:00,  4.50it/s]


80% Train Data Temp= 1  r2-3D= 0.522  r2-2D= 0.3211
20% Test  Data Temp= 1  r2-3D= 0.391  r2-2D= 0.2505
./data/SU_12PMd/Chewie_20161007_embed_5000itr_PMd.npz
XY distance>> 12.318830575188901
Z distance>> 90.0


pos: -0.9711 neg:  6.8501 total:  5.6611 temperature:  1.0000: 100%|███████████████| 5000/5000 [19:06<00:00,  4.36it/s]


80% Train Data Temp= 1  r2-3D= 0.857  r2-2D= 0.4018
20% Test  Data Temp= 1  r2-3D= 0.774  r2-2D= 0.3343
./data/SU_12PMd/Mihili_20140306_embed_5000itr_PMd.npz
XY distance>> 13.321437144870522
Z distance>> 90.0


pos: -0.8223 neg:  6.7536 total:  5.7568 temperature:  1.0000: 100%|███████████████| 5000/5000 [19:00<00:00,  4.38it/s]


80% Train Data Temp= 1  r2-3D= 0.618  r2-2D= 0.3467
20% Test  Data Temp= 1  r2-3D= 0.5  r2-2D= 0.3039
./data/SU_12PMd/Mihili_20140218_embed_5000itr_PMd.npz
XY distance>> 12.62045672852825
Z distance>> 90.0


pos: -0.9204 neg:  6.7715 total:  5.6878 temperature:  1.0000: 100%|███████████████| 5000/5000 [18:58<00:00,  4.39it/s]


80% Train Data Temp= 1  r2-3D= 0.745  r2-2D= 0.3735
20% Test  Data Temp= 1  r2-3D= 0.541  r2-2D= 0.2549
./data/SU_12PMd/Chewie_20161014_embed_5000itr_PMd.npz
XY distance>> 12.0972259902899
Z distance>> 90.0


pos: -0.9581 neg:  6.7599 total:  5.6566 temperature:  1.0000: 100%|███████████████| 5000/5000 [19:00<00:00,  4.38it/s]


80% Train Data Temp= 1  r2-3D= 0.857  r2-2D= 0.422
20% Test  Data Temp= 1  r2-3D= 0.81  r2-2D= 0.4167
./data/SU_12PMd/Chewie_20161005_embed_5000itr_PMd.npz
XY distance>> 10.72540773148727
Z distance>> 90.0


pos: -0.9644 neg:  6.8063 total:  5.6513 temperature:  1.0000: 100%|███████████████| 5000/5000 [19:07<00:00,  4.36it/s]


80% Train Data Temp= 1  r2-3D= 0.861  r2-2D= 0.4048
20% Test  Data Temp= 1  r2-3D= 0.776  r2-2D= 0.389
./data/SU_12PMd/Chewie_20161021_embed_5000itr_PMd.npz
XY distance>> 10.371836226334132
Z distance>> 90.0


pos: -0.9703 neg:  6.7680 total:  5.6401 temperature:  1.0000: 100%|███████████████| 5000/5000 [19:02<00:00,  4.38it/s]


80% Train Data Temp= 1  r2-3D= 0.87  r2-2D= 0.4282
20% Test  Data Temp= 1  r2-3D= 0.834  r2-2D= 0.422
./data/SU_12PMd/Chewie_20161006_embed_5000itr_PMd.npz
XY distance>> 11.966517302624434
Z distance>> 90.0


pos: -0.9703 neg:  6.8104 total:  5.6486 temperature:  1.0000: 100%|███████████████| 5000/5000 [19:04<00:00,  4.37it/s]


80% Train Data Temp= 1  r2-3D= 0.871  r2-2D= 0.4202
20% Test  Data Temp= 1  r2-3D= 0.816  r2-2D= 0.4238
./data/SU_12PMd/Chewie_20160929_embed_5000itr_PMd.npz
XY distance>> 12.798675413913003
Z distance>> 90.0


pos: -0.9658 neg:  6.8225 total:  5.6534 temperature:  1.0000: 100%|███████████████| 5000/5000 [19:02<00:00,  4.37it/s]


80% Train Data Temp= 1  r2-3D= 0.868  r2-2D= 0.4438
20% Test  Data Temp= 1  r2-3D= 0.815  r2-2D= 0.4071
./data/SU_12PMd/Mihili_20140217_embed_5000itr_PMd.npz
XY distance>> 13.747443407328184
Z distance>> 90.0


pos: -0.8851 neg:  6.7834 total:  5.6964 temperature:  1.0000: 100%|███████████████| 5000/5000 [19:07<00:00,  4.36it/s]


80% Train Data Temp= 1  r2-3D= 0.727  r2-2D= 0.3907
20% Test  Data Temp= 1  r2-3D= 0.523  r2-2D= 0.2788
./data/SU_12PMd/Mihili_20140304_embed_5000itr_PMd.npz
XY distance>> 13.341387378758153
Z distance>> 90.0


pos: -0.8781 neg:  6.7186 total:  5.7131 temperature:  1.0000: 100%|███████████████| 5000/5000 [19:02<00:00,  4.38it/s]


80% Train Data Temp= 1  r2-3D= 0.703  r2-2D= 0.3562
20% Test  Data Temp= 1  r2-3D= 0.605  r2-2D= 0.3254
./data/SU_12PMd/Mihili_20140307_embed_5000itr_PMd.npz
XY distance>> 11.673527755510241
Z distance>> 90.0


pos: -0.8544 neg:  6.7798 total:  5.7341 temperature:  1.0000: 100%|███████████████| 5000/5000 [18:59<00:00,  4.39it/s]

80% Train Data Temp= 1  r2-3D= 0.639  r2-2D= 0.3513
20% Test  Data Temp= 1  r2-3D= 0.457  r2-2D= 0.3204



