In [1]:
"""
This notebook runs the CEBRA analysis using a "Switch/Stay" label for the dynamic foraging task. Here "switch" refers to whether or not the mouse switched sides in a given trial.
"""

'\nThis notebook runs the CEBRA analysis using a "Switch/Stay" label for the dynamic foraging task. Here "switch" refers to whether or not the mouse switched sides in a given trial.\n'

In [1]:
import sys
import os # my addtion

import numpy as np
import itertools
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from mpl_toolkits.mplot3d import Axes3D
from scipy.integrate import solve_ivp
import cebra.data
import torch
import cebra.integrations
import cebra.datasets
from cebra import CEBRA
import torch
import pickle
import cebra_pack.cebra_utils as cp


from matplotlib.collections import LineCollection
import pandas as pd



## A. Load the Data

Here we load data from the fibre photometry pipeline for 4 neuromodulators (DA, 5HT, ACh, NE) recorded in the nucleus accumbens. The main neural data are the dF/F traces of these 4 neuromodulators (NMs), which are stored in a 2D array, `all_nms`.

In [2]:
# load the dataframe that contains data from 1 session
df_trials_ses = pickle.load(open('../data/CO data/df.pkl', "rb"))

In [3]:
df_trials_ses.columns

Index(['bit_code', 'ses_idx', 'rpe', 'left_action_value', 'right_action_value',
       'licks L', 'licks R', 'Lick L (raw)', 'Lick R (raw)', 'trial', 'reward',
       'choice', 'go_cue_absolute_time', 'go_cue', 'choice_time',
       'reward_time', 'onset', 'NM', 'NM_name', 'region', 'last_value_NM',
       'overlap_index', 'NM_no_overlap', 'bins_mids', 'bins_mids_no_overlap'],
      dtype='object')

In [4]:
n_trials = 1765

In [5]:
# download the dictionary containing the traces
traces = pickle.load(open('../data/CO data/traces.pkl', "rb"))

In [6]:
# load the trace times
trace_times = np.load('../data/CO data/Trace times.npy', allow_pickle=True)

In [7]:
# get the choice time 
choice_times = df_trials_ses['choice_time'][0:n_trials].to_numpy()

In [8]:
# Combine the traces into one 2D array
all_nms = np.array([traces[trace] for trace in traces.keys()])
all_nms = np.transpose(all_nms)

# change it to an array of floats (previously it was an array of object datatype)
all_nms_new = all_nms.astype(np.float64)
all_nms_new.shape


(218572, 4)

In [9]:
all_nms.shape

(218572, 4)

In [10]:
# change it to an array of floats (previously it was an array of object datatype)
all_nms_new = all_nms.astype(np.float64)
all_nms_new.shape

(218572, 4)

In [11]:
# convert it to a tensor (this is probably not necessary but we want it to be as close to the inputs in the previous notebook)
all_nms_tensor = torch.from_numpy(all_nms_new)
all_nms_tensor.shape

torch.Size([218572, 4])

## B. Format data and create the behavioural/auxiliary variables

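Now let's format the data. For every trial in which the mouse made a choice, we take a window of the NM traces around the choice time. By default this is the 10 trace samples before and the 10 samples after the sample closest to the choice time; the code comments describe 10 samples as about 1 s. The hope is that this will make it easy to separate the trials where the mouse switched sides from those where it stayed on the same side.

Each trial is labelled as rewarded/unrewarded and as a left/right choice. The behavioural variable used in this analysis is a switch/stay label derived from consecutive left/right choice labels.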

In [12]:
ins = np.argwhere(df_trials_ses['licks R'] > df_trials_ses['licks L'])

In [13]:
df_trials_ses['licks L'].to_numpy()

array([26., 10.,  2., ...,  2.,  1.,  1.])

For the trials where the numbers of left and right licks are equal, the number of licks in either direction is 0, 1 or 2. `format_data` assigns these ties to the left choice (label 1).

In [14]:
# Make a function to format the NM data into a 1s window around the choice
def format_data(neural_data, df, trace_times_, choice_times_, window=None, window_size=10):
    """Cut a window of neural data around each trial's choice time and build per-trial labels.

    Trials whose choice time is NaN (no choice made) are skipped. Every remaining
    trial contributes exactly one flattened window and one entry to each label
    output, so all outputs stay aligned trial-by-trial.

    Parameters
    ----------
    neural_data : array-like
        Neural samples indexed along axis 0 by trace sample (same order as ``trace_times_``).
    df : pandas.DataFrame
        Must contain 'reward', 'licks L', 'licks R' and 'rpe' columns; positional row ``i``
        must describe the same trial as ``choice_times_[i]``.
    trace_times_ : array-like
        Time of each row of ``neural_data``.
    choice_times_ : array-like
        Per-trial choice times; NaN marks trials without a choice.
    window : {None, 'before', 'after'}
        None     -> ``window_size`` samples before and ``window_size`` samples from the choice sample on,
        'before' -> the ``window_size`` samples preceding the choice sample,
        'after'  -> the choice sample and the following ``window_size - 1`` samples.
    window_size : int
        Number of samples taken on each side of the choice sample (default 10).

    Returns
    -------
    nms_HD : torch.Tensor (float64), shape (n_choice_trials, -1) -- each trial's window flattened
    reward_labels : np.ndarray -- 1 = rewarded, 0 = unrewarded
    choice_labels : np.ndarray -- 1 = left (licks L >= licks R, ties go left), 0 = right
    rpe_labels : np.ndarray -- the 'rpe' value of each trial
    n_licks : list -- lick count on the chosen side

    Raises
    ------
    ValueError
        If ``window`` is not recognised, a window falls outside ``neural_data``,
        or no trial has a (non-NaN) choice time.
    """
    # window bounds relative to the sample closest to the choice time
    offsets = {
        None: (-window_size, window_size),
        'before': (-window_size, 0),
        'after': (0, window_size),
    }
    if window not in offsets:
        # previously an unknown value appended labels but no window -> misaligned outputs
        raise ValueError("window must be None, 'before' or 'after', got {!r}".format(window))
    start_offset, stop_offset = offsets[window]

    n_samples = len(neural_data)

    # per-trial windows and labels
    n_data_window = []
    reward_labels = []
    choice_labels = []
    rpe_labels = []
    n_licks = []

    for i, choice_time in enumerate(choice_times_):

        # skip trials where the animal didn't make a choice (null choice time)
        if np.isnan(choice_time):
            continue

        # find the index of the closest time to the choice time in the trace_times array
        idx = np.abs(trace_times_ - choice_time).argmin()

        start, stop = idx + start_offset, idx + stop_offset
        # a negative start would silently wrap around (Python slicing) and a stop past the
        # end would truncate the window; reject both instead of building ragged windows
        if start < 0 or stop > n_samples:
            raise ValueError(
                "window [{}, {}) for trial {} falls outside the neural data ({} samples)".format(
                    start, stop, i, n_samples))
        n_data_window.append(neural_data[start:stop])

        # label the trial as rewarded (1) or unrewarded (0)
        # NOTE(review): a NaN reward is truthy in Python and is therefore labelled 1
        reward_labels.append(1 if df['reward'].iloc[i] else 0)

        # label the trial as a left (1) or right (0) choice; ties go to the left
        licks_left = df['licks L'].iloc[i]
        licks_right = df['licks R'].iloc[i]
        if licks_left >= licks_right:
            choice_labels.append(1)
            n_licks.append(licks_left)
        else:
            choice_labels.append(0)
            n_licks.append(licks_right)

        # get the rpe values at each trial
        rpe_labels.append(df['rpe'].iloc[i])

    if not n_data_window:
        raise ValueError("no trials with a non-NaN choice time")

    # stack the nm data for each trial, one flattened window per row
    nms_HD = np.stack(n_data_window).reshape((len(n_data_window), -1))
    # format it into a tensor
    nms_HD = torch.from_numpy(nms_HD.astype(np.float64))
    print("neural tensor shape: ", nms_HD.shape)

    # convert trial labels into an array
    reward_labels = np.array(reward_labels)
    print("reward labels shape: ",reward_labels.shape)

    choice_labels = np.array(choice_labels)
    print("choice labels shape: ",choice_labels.shape)

    # convert rpe labels to arrays
    rpe_labels = np.array(rpe_labels)
    print("rpe labels shape:", rpe_labels.shape)

    return nms_HD, reward_labels, choice_labels, rpe_labels, n_licks

In [15]:
formatted_nms, reward_labels, choice_labels, rpe_labels, n_licks = format_data(all_nms,df=df_trials_ses,trace_times_=trace_times, choice_times_=choice_times)

neural tensor shape:  torch.Size([1717, 80])
reward labels shape:  (1717,)
choice labels shape:  (1717,)
rpe labels shape: (1717,)


In [16]:
# define function to take the choice labels and make a 'Switch' label

def make_switch_label(choice_label):
    """Turn a sequence of per-trial choice labels into a switch/stay label.

    Element 0 is always 0, because the first trial has no previous trial to compare
    against. Element i (i >= 1) is 1 when ``choice_label[i]`` differs from
    ``choice_label[i-1]`` (the mouse switched sides) and 0 otherwise (it stayed).

    Parameters
    ----------
    choice_label : np.ndarray
        Choice labels in trial order. Any other type fails the assertion.

    Returns
    -------
    np.ndarray of 0/1 switch labels, the same length as ``choice_label``.
    """
    # make sure input is in array form
    assert type(choice_label)==np.ndarray

    # the first trial counts as a "stay"; an empty input yields no labels
    leading = [0] if choice_label.shape[0] > 0 else []

    # compare every trial with the one just before it
    transitions = [
        1 if current != previous else 0
        for previous, current in zip(choice_label[:-1], choice_label[1:])
    ]

    switch_labels = np.array(leading + transitions)
    print('Switch labels shape:', switch_labels.shape)

    return switch_labels


In [17]:
switch_labels = make_switch_label(choice_label=choice_labels)

Switch labels shape: (1717,)


## C. Build and train the CEBRA models

In [18]:
# set the maximum number of iterations for training the model
max_iterations = 2000

In [19]:
# build a CEBRA-Time and CEBRA-Behaviour model
cebra_time_model = CEBRA(model_architecture='offset10-model-mse',
                        batch_size=512,
                        learning_rate=3e-4,
                        temperature=1,
                        output_dimension=3,
                        max_iterations=max_iterations,
                        distance='euclidean',
                        conditional='time',
                        device='cuda_if_available',
                        verbose=True,
                        time_offsets=10)

In [20]:
cebra_behaviour_model = CEBRA(model_architecture='offset10-model-mse',
                        batch_size=512,
                        learning_rate=3e-4,
                        temperature=1,
                        output_dimension=3,
                        max_iterations=max_iterations,
                        distance='euclidean',
                        conditional='time_delta',
                        device='cuda_if_available',
                        verbose=True,
                        time_offsets=10)

Train the two models

In [21]:
# train the time model (no labels here)
cebra_time_model.fit(formatted_nms)

  return F.conv1d(input, weight, bias, self.stride,
pos:  0.8830 neg:  2.9669 total:  3.8499 temperature:  1.0000: 100%|██████████| 2000/2000 [00:23<00:00, 85.06it/s] 


In [22]:
# train the behaviour model (use the labels here)
cebra_behaviour_model.fit(formatted_nms, switch_labels)

pos:  0.1595 neg:  6.0460 total:  6.2055 temperature:  1.0000: 100%|██████████| 2000/2000 [00:16<00:00, 124.53it/s]


## D. Compute and view embeddings

Here, we compute the embeddings from the two trained models and then plot them.

In [23]:
time_embedding = cebra_time_model.transform(formatted_nms)

In [24]:
behaviour_embedding = cebra_behaviour_model.transform(formatted_nms)

In [25]:
# boolean masks splitting the trials into switch trials (label 1) and stay trials (label 0)
switch = (switch_labels == 1).flatten()
stay = (switch_labels == 0).flatten()

In [26]:
# get auc scores
mean_scores, errors = cp.get_auc([behaviour_embedding, time_embedding], trial_labels=switch_labels)
mean_scores

[0.6963423649035996, 0.5]

In [27]:
np.round(mean_scores,2)

array([0.7, 0.5])

In [28]:
# create a figure and make the plots
fig1 = plt.figure(figsize=(16,4))
gs = gridspec.GridSpec(1, 2, figure=fig1)

ax1 = fig1.add_subplot(gs[0,0], projection='3d')
ax2 = fig1.add_subplot(gs[0,1], projection='3d')
axes =[ax1,ax2]

for ax in axes:


    ax.set_xlabel("latent 1", labelpad=0.01)
    ax.set_ylabel("latent 2", labelpad=0.01)
    ax.set_zlabel("latent 3", labelpad=0.01)

    # Hide X and Y axes label marks
    ax.xaxis.set_tick_params(labelbottom=False)
    ax.yaxis.set_tick_params(labelleft=False)
    ax.zaxis.set_tick_params(labelright=False)

    # Hide X and Y axes tick marks
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_zticks([])


# colour maps
colours = ['cool', 'plasma']

# Aucs rounded
mean_scores = np.round(mean_scores,2)

# plot the time embedding 
cebra.plot_embedding(embedding=time_embedding[switch,:], embedding_labels=switch_labels[switch],ax=ax1, markersize=2, title='Time embedding', cmap=colours[0])
cebra.plot_embedding(embedding=time_embedding[stay,:], embedding_labels=switch_labels[stay],ax=ax1, markersize=2, title='Time embedding, AUC:{}'.format(mean_scores[1]), cmap=colours[1])

# plot the behaviour embedding 
cebra.plot_embedding(embedding=behaviour_embedding[switch,:], embedding_labels=switch_labels[switch],ax=ax2, markersize=2, title='Behaviour embedding', cmap=colours[0])
cebra.plot_embedding(embedding=behaviour_embedding[stay,:], embedding_labels=switch_labels[stay],ax=ax2,markersize=2, title='Behaviour embedding, AUC: {}'.format(mean_scores[0]),  cmap=colours[1])


# Adjust the subplot layout manually
#plt.subplots_adjust(left=0.095, right=0.1, top=0.95, bottom=0.05, wspace=0.001)

# Adjust the subplot layout manually
plt.subplots_adjust(left=0.00001, right=0.55, top=0.95, bottom=0.05, wspace=0.0001)

# Use tight_layout with padding to ensure labels are not cut off


# Adjust label positions using bbox
ax2.set_zlabel("latent 3", labelpad=1, fontsize=10, bbox=dict(facecolor='white', edgecolor='none', pad=0.5))


Text(0.5, 0, 'latent 3')

In [29]:
np.unique(switch_labels, return_counts=True)

(array([0, 1]), array([1441,  276]))

In [30]:
# build, train and compute with the time and behaviour models with this new labels
t_embed,b_embed =  cp.build_train_compute(formatted_nms,switch_labels)

pos:  0.0592 neg:  5.4446 total:  5.5038 temperature:  1.0000: 100%|██████████| 2000/2000 [00:14<00:00, 136.60it/s]
pos:  0.2666 neg:  5.8733 total:  6.1399 temperature:  1.0000: 100%|██████████| 2000/2000 [00:17<00:00, 114.56it/s]


In [31]:
# define a function to view the embeddings
def view(time_embedding, behaviour_embedding, labels, label_classes, size=5):
    """Draw the time and behaviour embeddings side by side in two 3D panels.

    Parameters
    ----------
    time_embedding, behaviour_embedding : arrays indexed as ``embedding[mask, :]``
    labels : per-sample labels, indexed with the same masks
    label_classes : sequence of two index masks; in each panel ``label_classes[1]``
        is plotted first with the 'plasma' colour map, then ``label_classes[0]``
        with 'cool'.
    size : marker size passed to ``cebra.plot_embedding``.
    """
    fig = plt.figure(figsize=(14,8))
    grid = gridspec.GridSpec(1, 2, figure=fig)

    # both axes are created before anything is plotted
    panels = [
        (fig.add_subplot(grid[0,0], projection='3d'), time_embedding, 'Time embedding'),
        (fig.add_subplot(grid[0,1], projection='3d'), behaviour_embedding, 'Behaviour embedding'),
    ]

    # colour map per class index
    cmap_for_class = {0: 'cool', 1: 'plasma'}

    for panel_ax, embedding, panel_title in panels:
        for cls in (1, 0):
            mask = label_classes[cls]
            cebra.plot_embedding(
                embedding=embedding[mask,:],
                embedding_labels=labels[mask],
                ax=panel_ax,
                markersize=size,
                title=panel_title,
                cmap=cmap_for_class[cls],
            )

    grid.tight_layout(figure=fig)

In [32]:
#%matplotlib inline
view(t_embed, b_embed,switch_labels, [switch,stay])

**TODO:** try this with the InfoNCE setting — maybe the behaviour embedding will make sense then.

In [33]:
individual_nms = cp.individual_datasets(traces_=traces)

shape of formatted array: (218572, 1)
shape of formatted array: (218572, 1)
shape of formatted array: (218572, 1)
shape of formatted array: (218572, 1)


In [34]:
b_embeds, t_embeds, sw_labels, [switch, stay] = cp.nm_analysis_2(individual_nms, df_trials_ses, trace_times, choice_times, title=" ",other_label=switch_labels)

neural tensor shape:  torch.Size([1717, 20])
reward labels shape:  (1717,)
choice labels shape:  (1717,)
rpe labels shape: (1717,)


pos:  0.1838 neg:  5.6258 total:  5.8096 temperature:  1.0000: 100%|██████████| 2000/2000 [00:14<00:00, 142.23it/s]
pos:  0.0987 neg:  6.1227 total:  6.2214 temperature:  1.0000: 100%|██████████| 2000/2000 [00:17<00:00, 117.58it/s]


COMPLETED ANALYSIS OF NM 0: 
neural tensor shape:  torch.Size([1717, 20])
reward labels shape:  (1717,)
choice labels shape:  (1717,)
rpe labels shape: (1717,)


pos:  0.1095 neg:  5.4901 total:  5.5995 temperature:  1.0000: 100%|██████████| 2000/2000 [00:13<00:00, 145.55it/s]
pos:  0.1308 neg:  6.1011 total:  6.2319 temperature:  1.0000: 100%|██████████| 2000/2000 [00:17<00:00, 114.91it/s]


COMPLETED ANALYSIS OF NM 1: 
neural tensor shape:  torch.Size([1717, 20])
reward labels shape:  (1717,)
choice labels shape:  (1717,)
rpe labels shape: (1717,)


pos:  0.3448 neg:  5.4783 total:  5.8231 temperature:  1.0000: 100%|██████████| 2000/2000 [00:13<00:00, 147.89it/s]
pos:  0.0841 neg:  6.1592 total:  6.2433 temperature:  1.0000: 100%|██████████| 2000/2000 [00:16<00:00, 118.85it/s]


COMPLETED ANALYSIS OF NM 2: 
neural tensor shape:  torch.Size([1717, 20])
reward labels shape:  (1717,)
choice labels shape:  (1717,)
rpe labels shape: (1717,)


pos:  0.1471 neg:  5.4739 total:  5.6210 temperature:  1.0000: 100%|██████████| 2000/2000 [00:15<00:00, 129.77it/s]
pos:  0.0001 neg:  6.2382 total:  6.2383 temperature:  1.0000: 100%|██████████| 2000/2000 [00:15<00:00, 131.48it/s]

COMPLETED ANALYSIS OF NM 3: 





In [35]:
len(b_embeds)

4

In [36]:
means, sds = cp.get_auc(b_embeds, sw_labels)

In [37]:
means

[0.5794763097285499, 0.6290430860211809, 0.546714238300697, 0.5]

In [38]:
# first make function to make the plots given a list of embeddings
def plot4_embeddings(embeddings, labels , l_class, means=None, titles=['DA only', 'NE only', '5HT only', 'ACh only'], t=""):

    # number of plots
    n_plots = len(embeddings)

    n_columns = 2
    n_rows = n_plots//n_columns

    # create axis
    fig = plt.figure(figsize=(8,4*n_plots))
    gs = gridspec.GridSpec(n_rows, n_columns, figure=fig)

    # colour 
    c = ['cool','plasma','pink','winter']

    for i, embed in enumerate(embeddings):

        # create the axes
        ax = fig.add_subplot(gs[i // n_columns, i%n_columns], projection='3d')

        ax.set_xlabel("latent 1", labelpad=0.001, fontsize=13)
        ax.set_ylabel("latent 2", labelpad=0.001, fontsize=13)
        ax.set_zlabel("latent 3", labelpad=0.001, fontsize=13)

        # Hide X and Y axes label marks
        ax.xaxis.set_tick_params(labelbottom=False)
        ax.yaxis.set_tick_params(labelleft=False)
        ax.zaxis.set_tick_params(labelright=False)

        # Hide X and Y axes tick marks
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_zticks([])


        if means.any():
            titles=['DA only, AUC:{}'.format(means[0]),'NE only, AUC:{}'.format(means[1]), '5HT only, AUC:{}'.format(means[2]), 'ACh only, AUC:{}'.format(means[3])]


        # plot the embedding
        cebra.plot_embedding(embedding=embed[l_class[0],:], embedding_labels=labels[l_class[0]], ax=ax, markersize=2,title=titles[i], cmap=c[0])
        cebra.plot_embedding(embedding=embed[l_class[1],:], embedding_labels=labels[l_class[1]], ax=ax, markersize=2,title=titles[i], cmap=c[1])

    plt.suptitle(t, fontsize=15)
    plt.tight_layout()

In [39]:
np.round(means,2)

array([0.58, 0.63, 0.55, 0.5 ])

In [40]:
#%matplotlib inline

plot4_embeddings(b_embeds, sw_labels, [switch, stay], means=np.round(means,2))