/
7_cebra.py
75 lines (65 loc) · 2.92 KB
/
7_cebra.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import os
import sys
sys.path.append(r'../')
import numpy as np
from functions import *
from cebra import CEBRA
algorithm = 'cebra_h'

# Behavioural neurons, excluded from the input traces of every worm.
# Loop-invariant, so defined once rather than on every iteration.
b_neurons = [
    'AVAR',
    'AVAL',
    'SMDVR',
    'SMDVL',
    'SMDDR',
    'SMDDL',
    'RIBR',
    'RIBL',
]
# Display names for the behavioural states. Only used by the optional
# plotting snippet at the end of the loop (presumably ordered by integer
# state code -- confirm against Database.states).
state_names = ['Dorsal turn', 'Forward', 'No state', 'Reverse-1', 'Reverse-2', 'Sustained reversal', 'Slowing', 'Ventral turn']

# Output folder for latent trajectories and labels. Created up front so a
# missing folder does not make np.savetxt fail only after a model has
# already spent its full training budget.
save_dir = 'data/generated/saved_Y'
os.makedirs(save_dir, exist_ok=True)


def _save_path(prefix, worm_num):
    """Return the output path for array *prefix* of worm *worm_num*.

    For example ``_save_path('Y0_tr', 0)`` gives
    ``data/generated/saved_Y/Y0_tr__cebra_h_worm_0``.
    """
    return f'{save_dir}/{prefix}__{algorithm}_worm_{worm_num}'


for worm_num in range(5):
    ### Load data (excluding behavioural neurons)
    data = Database(data_set_no=worm_num)
    data.exclude_neurons(b_neurons)
    X = data.neuron_traces.T
    B = data.states

    ### Preprocess and prepare data (same pipeline as for BunDLe-Net)
    # The first return value of preprocess_data is not used here.
    _, X = preprocess_data(X, data.fps)
    X_, B_ = prep_data(X, B, win=1)

    ### Train/test split
    X_train, X_test, B_train_1, B_test_1 = timeseries_train_test_split(X_, B_)

    ### Deploy CEBRA hybrid
    cebra_hybrid_model = CEBRA(model_architecture='offset10-model',
                               batch_size=512,
                               learning_rate=3e-4,
                               temperature=1,
                               output_dimension=3,
                               max_iterations=5000,
                               distance='cosine',
                               conditional='time_delta',
                               device='cuda_if_available',
                               verbose=True,
                               time_offsets=10,
                               hybrid=True)
    # Labels are cast to float so CEBRA handles them as a continuous
    # auxiliary variable rather than as discrete (integer) labels.
    cebra_hybrid_model.fit(X_train[:, 0, 0, :], B_train_1.astype(float))
    print(worm_num)

    ### Project into latent space
    # Index 0 / 1 on axis 1 selects the two frames of each sample built by
    # prep_data (presumably time t and t+1 -- confirm against prep_data).
    Y0_tr = cebra_hybrid_model.transform(X_train[:, 0, 0, :])
    Y1_tr = cebra_hybrid_model.transform(X_train[:, 1, 0, :])
    Y0_tst = cebra_hybrid_model.transform(X_test[:, 0, 0, :])
    Y1_tst = cebra_hybrid_model.transform(X_test[:, 1, 0, :])

    ### Save embeddings and labels (one text file per array)
    outputs = {
        'Y0_tr': Y0_tr,
        'Y1_tr': Y1_tr,
        'Y0_tst': Y0_tst,
        'Y1_tst': Y1_tst,
        'B_train_1': B_train_1,
        'B_test_1': B_test_1,
    }
    for prefix, array in outputs.items():
        np.savetxt(_save_path(prefix, worm_num), array)

    # Optional: reload the saved arrays and plot the phase space.
    # Y0_tr = np.loadtxt(_save_path('Y0_tr', worm_num))
    # Y1_tr = np.loadtxt(_save_path('Y1_tr', worm_num))
    # Y0_tst = np.loadtxt(_save_path('Y0_tst', worm_num))
    # Y1_tst = np.loadtxt(_save_path('Y1_tst', worm_num))
    # B_train_1 = np.loadtxt(_save_path('B_train_1', worm_num)).astype(int)
    # B_test_1 = np.loadtxt(_save_path('B_test_1', worm_num)).astype(int)
    # plot_phase_space(Y0_tr, B_train_1, state_names)
    # plot_phase_space(Y0_tst, B_test_1, state_names)