In [6]:
import sys
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import joblib as jl
import cebra.datasets
from cebra import CEBRA
import torch
import scipy.io as sio
from sklearn.decomposition import PCA
from sklearn.linear_model import LinearRegression
import sklearn.metrics
import time

In [7]:
dur = 40
iterations = 5*1000
batch_size = 512
learning_rate = 0.0001
output_dimension = 3
Temp_para = [1] 
name_para = './data_NMR/Fig2_SU/ner/emb_M1_lr0.0001_itr5k_temp1'
def split_data(neural, continuous_index):
            L = neural.shape[0]
            split_idx = round(L*0.8) 
            neural_train = neural[:split_idx]
            neural_test = neural[split_idx:]
            continuous_index_train = continuous_index[:split_idx]
            continuous_index_test = continuous_index[split_idx:]
            return neural_train,neural_test,continuous_index_train,continuous_index_test
        
angle_to_new_value = {-180: 4,-135: 5,-90: 6,-45: 7,0: 0,45: 1,90: 2,135: 3,180: 4}

directory = "./data/SU_16M1/"
files = os.listdir(directory)
for file in files:
#     if f_type in file:
    mat_contents = sio.loadmat(os.path.join(directory, file))
    filename_parts = file.split("_neural_con_dis_index")
    new_filename = filename_parts[0] + "_embed_"+str(iterations)+"itr_M1.npz"
    file_save = os.path.join(directory, new_filename)
    print(file_save)

    neural = mat_contents['neural_M1']
    continuous_index_XY = mat_contents['continuous_index']
    discrete_index = mat_contents['discrete_index'] ## angles range from -180deg to +180deg
    vectorized_map = np.vectorize(lambda x: angle_to_new_value[x])
    discrete_index = 45*vectorized_map(discrete_index)

    L = neural.shape[0]
    N_values_hist = round(L/5)
    random_indices = np.random.choice(L, size=N_values_hist, replace=False)
    indices_X = continuous_index_XY[random_indices, 0]
    indices_Y = continuous_index_XY[random_indices, 1]
    index_diffs_X = np.abs(indices_X[:, None] - indices_X[None, :]) 
    index_diffs_Y = np.abs(indices_Y[:, None] - indices_Y[None, :])
    l_dist_XY = index_diffs_X + index_diffs_Y
    l_dist_XY_1d = l_dist_XY[~np.eye(N_values_hist, dtype=bool)].flatten()
    print('XY distance>>', np.median(l_dist_XY_1d))
    angles = np.squeeze(discrete_index[random_indices])
    angle_diffs = np.abs(angles[:, None] - angles[None, :]) 
    circular_angle_diffs = np.minimum(angle_diffs, 360 - angle_diffs)
    l_dist_Z_1d = circular_angle_diffs[~np.eye(N_values_hist, dtype=bool)].flatten()
    print('Z distance>>', np.median(l_dist_Z_1d))

    XY_scale = 10
    continuous_index_XY = continuous_index_XY*XY_scale
    #####>>>>>>>>>>>>>>>>>>>>>>>> Only for NMR
#     XYZ_threshold = np.median(l_dist_XY_1d)*XY_scale + np.median(l_dist_Z_1d)
    #####>>>>>>>>>>>>>>>>>>>>>>>> Only for NMR

    for temp in range(len(Temp_para)): 

        continuous_index = np.column_stack((continuous_index_XY, discrete_index))

        neural_train, neural_test, continuous_index_train, \
                    continuous_index_test = split_data(neural, continuous_index)

        #####>>>>>>>>>>>>>>>>>>>>>>>>  Only for NMR
#         L_train = neural_train.shape[0]
#         conr_2para = np.full((L_train,), 0.001)
#         conr_2para[:2] = [XYZ_threshold, Temp_para[temp]]
#         continuous_index_train = np.column_stack((continuous_index_train, conr_2para))
        #####>>>>>>>>>>>>>>>>>>>>>>>> Only for NMR
         
        cebra_veldir_model = CEBRA(model_architecture='offset1-model',
                   batch_size = batch_size,
                   learning_rate = learning_rate,
                   temperature = Temp_para[temp],
                   output_dimension = output_dimension,
                   max_iterations=iterations,
                   distance='cosine',
                   conditional='time_delta',
                   verbose=True,
                   time_offsets=1)
        start_time = time.time()
        cebra_veldir_model.fit(neural_train, continuous_index_train)
        end_time = time.time()
        execution_time = np.round((end_time - start_time), 2)

        cebra_veldir_train = cebra_veldir_model.transform(neural_train)
        cebra_veldir_test  = cebra_veldir_model.transform(neural_test)

        train_loss = cebra_veldir_model.state_dict_['loss']
        X = cebra_veldir_train
        y = continuous_index_train[:,0:2]
        reg_3d = LinearRegression().fit(X, y)       #### 1st fit ####
        pred_vel = reg_3d.predict(X)
        vel_train_r2 = sklearn.metrics.r2_score(y, pred_vel)

        pca = PCA(n_components=2)
        pca_2d = pca.fit(X)                         #### 2nd fit ####
        X_2d = pca_2d.transform(X)
        reg_2d = LinearRegression().fit(X_2d, y)    #### 3rd fit ####
        pred_vel = reg_2d.predict(X_2d)
        vel_train_r2_pca = sklearn.metrics.r2_score(y, pred_vel)
        vel_train_r2_pca = np.round(vel_train_r2_pca, 4)

        print('80% Train Data Temp=', str(Temp_para[temp]), \
              ' r2-3D=', str(np.round(vel_train_r2, 3)), ' r2-2D=', str(vel_train_r2_pca))
        ###************* use previous trained "reg_3d & pca_2d & reg_2d" ###***************
        ###************* use previous trained "reg_3d & pca_2d & reg_2d" ###***************
        X = cebra_veldir_test
        y = continuous_index_test[:,0:2]
        pred_vel = reg_3d.predict(X)
        vel_test_r2 = sklearn.metrics.r2_score(y, pred_vel)

        X_2d = pca_2d.transform(X)
        pred_vel = reg_2d.predict(X_2d)
        vel_test_r2_pca = sklearn.metrics.r2_score(y, pred_vel)
        vel_test_r2_pca = np.round(vel_test_r2_pca, 4)

        print('20% Test  Data Temp=', str(Temp_para[temp]), \
              ' r2-3D=', str(np.round(vel_test_r2, 3)), ' r2-2D=', str(vel_test_r2_pca))

        new_filename = file[:19] + "_Temp_"+str(Temp_para[temp])+ \
            "_iterations_"+str(iterations)+ \
            "_80%train_"+str(vel_train_r2_pca)+ \
            "_20%test_"+str(vel_test_r2_pca)+".npz"
        file_save = os.path.join(name_para,new_filename)
        np.savez(file_save,
                 execution_time = execution_time,
                 temperature = Temp_para[temp],
                 iterations = iterations, 
                 train_loss = train_loss,
                 cebra_veldir_train=cebra_veldir_train,
                 cebra_veldir_test=cebra_veldir_test,
                 continuous_index_train=continuous_index_train,
                 continuous_index_test=continuous_index_test,
                 vel_train_r2 = vel_train_r2,
                 vel_test_r2 = vel_test_r2,
                 vel_train_r2_pca = vel_train_r2_pca,
                 vel_test_r2_pca = vel_test_r2_pca)

./data/SU_16M1/Mihili_20140303_embed_5000itr_M1.npz
XY distance>> 13.456764189734278
Z distance>> 90.0


pos: -0.9286 neg:  6.7454 total:  5.6714 temperature:  1.0000: 100%|█| 5000/5000


80% Train Data Temp= 1  r2-3D= 0.8  r2-2D= 0.4447
20% Test  Data Temp= 1  r2-3D= 0.686  r2-2D= 0.3361
./data/SU_16M1/Chewie_20161007_embed_5000itr_M1.npz
XY distance>> 11.26095975832433
Z distance>> 90.0


pos: -0.9638 neg:  6.7939 total:  5.6555 temperature:  1.0000: 100%|█| 5000/5000


80% Train Data Temp= 1  r2-3D= 0.847  r2-2D= 0.4044
20% Test  Data Temp= 1  r2-3D= 0.799  r2-2D= 0.3581
./data/SU_16M1/Mihili_20140306_embed_5000itr_M1.npz
XY distance>> 13.378499620707554
Z distance>> 90.0


pos: -0.8605 neg:  6.7065 total:  5.7228 temperature:  1.0000: 100%|█| 5000/5000


80% Train Data Temp= 1  r2-3D= 0.692  r2-2D= 0.3567
20% Test  Data Temp= 1  r2-3D= 0.692  r2-2D= 0.3702
./data/SU_16M1/Mihili_20140218_embed_5000itr_M1.npz
XY distance>> 13.806536474763607
Z distance>> 90.0


pos: -0.8921 neg:  6.7769 total:  5.7170 temperature:  1.0000: 100%|█| 5000/5000


80% Train Data Temp= 1  r2-3D= 0.709  r2-2D= 0.3649
20% Test  Data Temp= 1  r2-3D= 0.614  r2-2D= 0.2964
./data/SU_16M1/Chewie_20150319_embed_5000itr_M1.npz
XY distance>> 9.963655842096475
Z distance>> 90.0


pos: -0.9348 neg:  6.8433 total:  5.6678 temperature:  1.0000: 100%|█| 5000/5000


80% Train Data Temp= 1  r2-3D= 0.801  r2-2D= 0.4079
20% Test  Data Temp= 1  r2-3D= 0.771  r2-2D= 0.4056
./data/SU_16M1/Chewie_20150629_embed_5000itr_M1.npz
XY distance>> 10.874213756735433
Z distance>> 90.0


pos: -0.9316 neg:  6.8468 total:  5.6854 temperature:  1.0000: 100%|█| 5000/5000


80% Train Data Temp= 1  r2-3D= 0.737  r2-2D= 0.451
20% Test  Data Temp= 1  r2-3D= 0.633  r2-2D= 0.4161
./data/SU_16M1/Chewie_20150313_embed_5000itr_M1.npz
XY distance>> 10.050386964025328
Z distance>> 90.0


pos: -0.9301 neg:  6.8196 total:  5.6638 temperature:  1.0000: 100%|█| 5000/5000


80% Train Data Temp= 1  r2-3D= 0.812  r2-2D= 0.4264
20% Test  Data Temp= 1  r2-3D= 0.771  r2-2D= 0.4156
./data/SU_16M1/Chewie_20161005_embed_5000itr_M1.npz
XY distance>> 12.36390051950243
Z distance>> 90.0


pos: -0.9618 neg:  6.8193 total:  5.6507 temperature:  1.0000: 100%|█| 5000/5000


80% Train Data Temp= 1  r2-3D= 0.867  r2-2D= 0.4118
20% Test  Data Temp= 1  r2-3D= 0.799  r2-2D= 0.4231
./data/SU_16M1/Chewie_20161021_embed_5000itr_M1.npz
XY distance>> 11.236793899457663
Z distance>> 90.0


pos: -0.9663 neg:  6.8036 total:  5.6511 temperature:  1.0000: 100%|█| 5000/5000


80% Train Data Temp= 1  r2-3D= 0.856  r2-2D= 0.4201
20% Test  Data Temp= 1  r2-3D= 0.84  r2-2D= 0.4395
./data/SU_16M1/Chewie_20161006_embed_5000itr_M1.npz
XY distance>> 12.655270065941487
Z distance>> 90.0


pos: -0.9470 neg:  6.8146 total:  5.6695 temperature:  1.0000: 100%|█| 5000/5000


80% Train Data Temp= 1  r2-3D= 0.849  r2-2D= 0.4226
20% Test  Data Temp= 1  r2-3D= 0.791  r2-2D= 0.4004
./data/SU_16M1/Chewie_20160929_embed_5000itr_M1.npz
XY distance>> 12.648895267062073
Z distance>> 90.0


pos: -0.9417 neg:  6.8053 total:  5.6688 temperature:  1.0000: 100%|█| 5000/5000


80% Train Data Temp= 1  r2-3D= 0.833  r2-2D= 0.4326
20% Test  Data Temp= 1  r2-3D= 0.769  r2-2D= 0.3929
./data/SU_16M1/Chewie_20150630_embed_5000itr_M1.npz
XY distance>> 11.309610571371802
Z distance>> 90.0


pos: -0.9359 neg:  6.8314 total:  5.6796 temperature:  1.0000: 100%|█| 5000/5000


80% Train Data Temp= 1  r2-3D= 0.783  r2-2D= 0.3977
20% Test  Data Temp= 1  r2-3D= 0.751  r2-2D= 0.3347
./data/SU_16M1/Mihili_20140217_embed_5000itr_M1.npz
XY distance>> 12.127720545325664
Z distance>> 90.0


pos: -0.8908 neg:  6.7842 total:  5.7067 temperature:  1.0000: 100%|█| 5000/5000


80% Train Data Temp= 1  r2-3D= 0.723  r2-2D= 0.3888
20% Test  Data Temp= 1  r2-3D= 0.62  r2-2D= 0.3285
./data/SU_16M1/Mihili_20140304_embed_5000itr_M1.npz
XY distance>> 13.772555712077972
Z distance>> 90.0


pos: -0.8878 neg:  6.7185 total:  5.6921 temperature:  1.0000: 100%|█| 5000/5000


80% Train Data Temp= 1  r2-3D= 0.746  r2-2D= 0.3785
20% Test  Data Temp= 1  r2-3D= 0.66  r2-2D= 0.3308
./data/SU_16M1/Mihili_20140307_embed_5000itr_M1.npz
XY distance>> 12.992772956920305
Z distance>> 90.0


pos: -0.8480 neg:  6.7918 total:  5.7480 temperature:  1.0000: 100%|█| 5000/5000

80% Train Data Temp= 1  r2-3D= 0.633  r2-2D= 0.4035
20% Test  Data Temp= 1  r2-3D= 0.595  r2-2D= 0.3632



