In [None]:
%matplotlib ipympl
from mne import read_epochs_fieldtrip
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Read the FieldTrip epochs structure into an MNE Epochs object.
# MNE emits a RuntimeWarning for a channel in the structure it does not
# recognize; silence it only for this call so that later RuntimeWarnings
# (e.g. numpy divide-by-zero / log10(0) in the analysis cells) stay visible.
import warnings

with warnings.catch_warnings():
    warnings.simplefilter("ignore", RuntimeWarning)
    exampledata = read_epochs_fieldtrip('testdata.mat', info=None, data_name='testdata')
exampledata

In [None]:
# Concatenate all epochs into one continuous (n_channels, n_epochs * n_times) matrix.
# Epochs.get_data() is the public accessor and returns (n_epochs, n_channels, n_times).
epochs_arr = exampledata.get_data()
n_epochs, n_channels, n_times = epochs_arr.shape
# (epochs, channels, times) -> (channels, epochs, times) -> (channels, epochs*times)
data = np.swapaxes(epochs_arr, 0, 1).reshape(n_channels, n_epochs * n_times)

# quick look at the first samples of channel 0
plt.figure()
plt.plot(data[0, 0:4048])

In [None]:
# Band-limit the continuous data: 0.1 Hz high-pass followed by 100 Hz low-pass,
# both at a sampling rate of 256 Hz.
from local_funcs import butter_highpass_filter, butter_lowpass_filter

# NOTE(review): the filtering axis is decided inside local_funcs (presumably the
# last/time axis) — confirm. The epochs were concatenated in the previous cell,
# so epoch boundaries are discontinuities that an IIR filter will smear.
filt_dat = butter_lowpass_filter(butter_highpass_filter(data, .1, 256), 100, 256)

# inspect one filtered channel (row 300)
plt.figure()
plt.plot(filt_dat[300, :])


In [None]:
# import CEBRA
import cebra
from cebra import CEBRA

In [None]:
# CEBRA model for the MEG data. conditional='time' selects the time-contrastive
# variant; it is fitted below on `dat` without any labels.
cebra_model = CEBRA(
    model_architecture="offset10-model",
    conditional='time',
    time_offsets=10,
    output_dimension=7,  # the plotting cells below only inspect dims 0 and 1
    batch_size=1024,
    learning_rate=0.01,
    temperature_mode="auto",
    max_iterations=1000,
    device="cuda_if_available",
    verbose=True,
)

In [None]:
# Put the data in a well-behaved numeric range: transpose to (samples, channels)
# and scale every channel column with a median/IQR-based RobustScaler.
from sklearn.preprocessing import RobustScaler
scaler = RobustScaler()
dat = scaler.fit_transform(filt_dat.T)




In [None]:
# Fit the model on the scaled data and embed that same data in one step
# (fit_transform = fit followed by transform on the same input).
embedding = cebra_model.fit_transform(dat)

In [None]:
# Plot the first embedding dimensions as time series (first 2023 samples),
# one row each on a 3-row grid; use range(3) to also show the third dimension.
plt.figure()
for dim in range(2):
    plt.subplot(3, 1, dim + 1)
    plt.plot(embedding[0:2023, dim])

In [None]:
# CEBRA's built-in scatter of the learned embedding (no labels -> default colouring)
cebra.plot_embedding(embedding=embedding)


In [None]:
# Training-loss curve of the model fitted on `dat` above.
# (cebra_mdl_beh is only created further down, so it cannot be referenced here.)
cebra.plot_loss(cebra_model)

In [None]:
# evaluate correlation of single areas with the extracted embeddings
from local_funcs import get_parcel_corr

# NOTE(review): presumably returns an (n_parcels, output_dimension) matrix relating
# each column of `dat` to each embedding dimension — later cells index it as w[:, k]; confirm
w = get_parcel_corr(dat, embedding)



In [None]:
# Heatmap of the absolute parcel/embedding weights.
import seaborn as sns

fig, ax = plt.subplots()
sns.heatmap(np.abs(w), ax=ax)


In [None]:
from local_funcs import remap2mesh
from nilearn import plotting

# Map |weight| of each of the first two embedding dimensions onto the surface mesh.
# Extend to range(3) (and unpack into w1, w2, w3) for the third dimension.
w1, w2 = [remap2mesh(np.abs(w[:, dim])) for dim in range(2)]


In [None]:
# Topography of the first component on the inflated left hemisphere.
# Column 0 of w1 is the left hemisphere, column 1 the right one.
left_surf = '../../Resources/S1200.L.very_inflated_MSMAll.32k_fs_LR.surf.gii'
# right hemisphere: '../../Resources/S1200.R.very_inflated_MSMAll.32k_fs_LR.surf.gii' with w1[:, 1]
plotting.view_surf(left_surf, w1[:, 0])

In [None]:
# Power spectrum of every embedding dimension.
from scipy.fft import fft, fftfreq

SFREQ = 256                  # Hz, same rate used for filtering above
sample_spacing = 1 / SFREQ   # seconds between samples (the `d` argument of fftfreq)
N = embedding.shape[0]

# remove each dimension's mean so the DC component does not dominate
demeaned_ = embedding - embedding.mean(axis=0)

# one-sided power (|FFT|^2) for the non-negative frequencies, one column per dimension
comps_amp = np.abs(fft(demeaned_, axis=0)[0:N // 2, :]) ** 2
xf = fftfreq(N, sample_spacing)[:N // 2]

plt.figure()
# skip bin 0 (DC): after demeaning it is ~0 and log10 would give -inf
plt.plot(xf[1:], np.log10(comps_amp[1:]))
plt.xlabel('Frequency (Hz)')
plt.ylabel('log10 power')


In [None]:
fig, ax = plt.subplots()
ax.plot(embedding[:, 0])

In [None]:
# Load the t-SNE input data and build two-condition labels.
from scipy.io import loadmat

unlabeled_dat = loadmat('TSNE_dat.mat')['dat']
# NOTE(review): assumes the first half of the rows is condition 1 and the second
# half condition 2 — confirm against how TSNE_dat.mat was built.
n_rows = unlabeled_dat.shape[0]
first_half = n_rows // 2
# condition 2 takes the remainder so len(labels) == n_rows even when n_rows is odd
labels = np.array([1] * first_half + [2] * (n_rows - first_half), dtype=int)


In [None]:
# Second CEBRA model, 2-D output, for the t-SNE data.
# NOTE(review): conditional="time" is the time-contrastive setting; as written the
# fit cell below does not pass `labels`, so this model is not label-driven.
cebra_mdl_beh = CEBRA(
    model_architecture="offset10-model",
    conditional="time",
    time_offsets=10,
    output_dimension=2,
    batch_size=2048,
    learning_rate=0.001,
    temperature_mode="auto",
    max_iterations=1000,
    device="cuda_if_available",
    verbose=True,
)


In [None]:
# Fit the second model on the t-SNE data and embed the same data.
# NOTE(review): no labels are passed (the `labels` argument was left out), so the
# discrete labels are only used for colouring the plots below.
cebra_mdl_beh.fit(unlabeled_dat)
embeddings = cebra_mdl_beh.transform(unlabeled_dat)

In [None]:
# CEBRA's scatter of the 2-D embedding, coloured by the two condition labels
cebra.plot_embedding(embedding=embeddings, embedding_labels=labels)


In [None]:
# Manual scatter of the embedding: first half of the rows vs. second half,
# each in its own colour.
split_idx = unlabeled_dat.shape[0] // 2
lbl1, lbl2 = embeddings[:split_idx, :], embeddings[split_idx:, :]

plt.figure()
for group in (lbl1, lbl2):
    plt.scatter(group[:, 0], group[:, 1])



In [None]:
# both embedding dimensions over time, one line per column
plt.figure()
plt.plot(embeddings[:, :2])