In [None]:
%load_ext autoreload
%autoreload 2
import os
import sys
import numpy as np
import pandas as pd
import xarray as xr
import pingouin as pg
import plotly.express as px
from plotly.express.colors import n_colors
import plotly.graph_objects as go
from os.path import join as pjoin
from natsort import natsort_keygen, natsorted
from scipy.stats import zscore

sys.path.append("../")
import circletrack_behavior as ctb
import plotting_functions as pf
import pca_ica as ica

In [None]:
## Settings
# Data are read from <dpath>/<experiment>/output/...; figures are written to fig_path.
project_folder = ['CircleTrack_Updating']
experiment_folders = ['Reversal2']
dpath = f'../../{project_folder[0]}'
fig_path = '../../../Manuscripts/MemoryUpdating/intermediate_plots'

## Plot colors
chance_color = 'darkgrey'
avg_color = 'midnightblue'
subject_color = 'darkgrey'
fading_color = 'red'
two_colors = ['darkorchid', 'midnightblue']
# Translucent (alpha 0.4) versions of two_colors; a darkgrey alternative is 'rgba(169,169,169,0.4)'.
error_color = ['rgba(153,50,204,0.4)', 'rgba(25,25,112,0.4)']

## Session labels, in chronological order (used to order categorical x axes)
session_list = ['A1', 'A2', 'A3', 'A4', 'A5', 'AU1', 'AU2', 'AU3', 'AU4', 'AR1', 'AR2']
old_new_list = ['AU1', 'AU2', 'AU3', 'AU4', 'AR1', 'AR2']
# Subfolder of <experiment>/output that holds the per-mouse ensemble .nc files.
data_of_interest = 'ica_ensembles'

# exist_ok makes this a no-op when the folder already exists and avoids a check-then-create race.
os.makedirs(fig_path, exist_ok=True)

# Preserve .attrs through xarray operations.
xr.set_options(keep_attrs=True)

### Example session: population activity, ensemble membership, and ensemble activation strength.

In [None]:
# Example session: ICA ensembles (patterns + activations) and the aligned minian S array for one mouse.
mouse, sess, ens_id = 'rv04', 6, 4
experiment = experiment_folders[0]

# File names for this mouse/session.
session = f'{mouse}_ensembles_{sess}.nc'
s_session = f'{mouse}_S_{sess}.nc'

# Full paths to the ensemble file and the S file.
exp_path = pjoin(dpath, f'{experiment}/output/{data_of_interest}/')
mpath = pjoin(exp_path, f'{mouse}/{session}')
spath = pjoin(dpath, f'{experiment}/output/aligned_minian/{mouse}/S/{s_session}')

assemblies = xr.open_dataset(mpath)
patterns, act = assemblies['patterns'], assemblies['activations']

S = xr.open_dataset(spath)['S']

In [None]:
## Three-row overview of the example session: position, population activity, raster.
lin_pos = 'lin_position'
fig = pf.custom_graph_template(x_title='', y_title='', rows=3, height=600, width=800, shared_x=True, font_size=18)

## Linear position in first row with rewards
behav_t = S['behav_t']
# Frames whose lick_port is not -1 are overlaid as red markers (presumably reward/lick events).
licked = S['lick_port'] != -1
fig.add_trace(go.Scatter(x=behav_t, y=S[lin_pos], line_color='darkgrey', showlegend=False), row=1, col=1)
fig.add_trace(go.Scatter(x=behav_t[licked], y=S[lin_pos][licked], mode='markers', marker_color='red',
                         marker_size=2, showlegend=False), row=1, col=1)

## Average population activity in second row
# Z-score each neuron over time (axis=1), average across neurons, then z-score that average.
# A neuron with a constant trace (zero variance) z-scores to NaN; nanmean ignores it instead of
# letting a single silent neuron turn the entire population trace into NaN.
with np.errstate(invalid='ignore', divide='ignore'):
    neuron_z = zscore(S.values, axis=1)
avg_act = zscore(np.nanmean(neuron_z, axis=0))
fig.add_trace(go.Scatter(x=behav_t, y=avg_act, line_color='black', showlegend=False), row=2, col=1)

## Raster of population activity in third row (1 wherever S > 0)
sbin = (S.values > 0).astype(int)
fig.add_trace(go.Heatmap(x=behav_t, y=S['unit_id'], z=sbin, colorscale='gray_r', showscale=False), row=3, col=1)

fig.update_yaxes(title='Position (rad)', row=1)
fig.update_yaxes(title='Population Activity', row=2)
fig.update_yaxes(title='Neuron', row=3)
fig.update_xaxes(title='Time (s)', row=3)
fig.show()
fig.write_image(pjoin(fig_path, f'{mouse}_triplot.png'))

In [None]:
## Stem plot of every neuron's weight into ensemble ens_id, with member neurons highlighted
fig = pf.stem_plot(
    patterns[ens_id],
    plot_members=True,
    member_color='darkturquoise',
    x_title='Neuron',
    y_title='Weight',
)
fig.show()
fig.write_image(pjoin(fig_path, f'{mouse}_{sess}_emerging_ensemble_membership_ensid{ens_id}.png'))

In [None]:
## Activation strength of the same ensemble, averaged within each trial
z_scored = True
fig = pf.custom_graph_template(x_title='Trial',
                               y_title='Z-Scored Activation Strength' if z_scored else 'Activation Strength')

# Normalize once, outside the trial loop: z-score each ensemble's activation over time (axis=1).
zdata = zscore(act.values, axis=1) if z_scored else act.values.copy()
trial_labels = act['trials'].values
valid = ~pd.isna(trial_labels)  # frames with a NaN trial label are skipped

# One slot per trial label 0..max. Slots for trials with no frames stay NaN, so plotly leaves a gap
# instead of drawing a spurious point at 0.
n_trials = int(np.nanmax(trial_labels)) + 1
trial_act = np.full(n_trials, np.nan)
for trial in np.unique(trial_labels[valid]):
    trial_act[int(trial)] = np.mean(zdata[ens_id, trial_labels == trial])

# x needs exactly one entry per trial_act slot; a shorter x makes plotly silently drop the last
# trial. Slot t is drawn at x = t + 1.
# NOTE(review): this gives a 1-based trial axis only if trial labels start at 0 — confirm.
x_axis = np.arange(1, n_trials + 1)
fig.add_trace(go.Scatter(x=x_axis, y=trial_act, mode='markers', marker_color='darkturquoise', showlegend=False))
fig.show()
fig.write_image(pjoin(fig_path, f'{mouse}_{sess}_activation_strength_bytrial_ensid{ens_id}.png'))

In [None]:
## Classify every ensemble's activation trend across trials, for every session of every mouse.
ens_dict = {'mouse': [], 'day': [], 'session': [], 'ens_id': [], 'trend': []}

for experiment in natsorted(os.listdir(dpath)):
    if experiment not in experiment_folders:
        continue
    exp_path = pjoin(dpath, f'{experiment}/output/{data_of_interest}/')
    for mouse in natsorted(os.listdir(exp_path)):
        print(mouse)
        mpath = pjoin(exp_path, f'{mouse}')
        # os.listdir returns entries in arbitrary order, but the enumerate index below becomes the
        # 'day' label. Session files are therefore naturally sorted ('..._2.nc' before '..._10.nc').
        session_files = natsorted(f for f in os.listdir(mpath) if f.endswith('.nc'))
        for idx, session in enumerate(session_files):
            # rv07: shift the index past day indices 3 and 8, so its files map to days
            # 1,2,3,5,6,7,8,10,11. NOTE(review): presumably those two sessions are missing for
            # rv07 — confirm. The second check deliberately sees the already-shifted index.
            if mouse == 'rv07':
                if idx > 2:
                    idx += 1
                if idx > 7:
                    idx += 1

            # Close each file after use; .load() keeps `act` usable once the file is closed.
            with xr.open_dataset(pjoin(mpath, session)) as assemblies:
                act = assemblies['activations'].load()
                session_label = assemblies.attrs['session_two']

            trends, binned_act, slopes, tau = ica.define_ensemble_trends_across_trials(act.values, act)

            # One row per ensemble; `trends` maps a trend name to the ensemble ids with that trend.
            for key in trends:
                for ens in trends[key]:
                    ens_dict['mouse'].append(mouse)
                    ens_dict['day'].append(idx+1)
                    ens_dict['session'].append(session_label)
                    ens_dict['ens_id'].append(ens)
                    ens_dict['trend'].append(key)
ens_df = pd.DataFrame(ens_dict)

In [None]:
## Proportion of each mouse's ensembles in each trend category, per session.
# Every (mouse, session) gets one row per trend seen anywhere in ens_df. The proportion is 0 when
# that mouse had no ensemble of that trend in that session. Without these explicit zeros, the
# across-mouse mean/SEM in avg_prop would average only mice that had at least one such ensemble.
trend_levels = list(ens_df['trend'].unique())
session_rank = {name: rank for rank, name in enumerate(session_list)}
prop_dict = {'mouse': [], 'day': [], 'session': [], 'trend': [], 'proportion': []}
for mouse in ens_df['mouse'].unique():
    mdata = ens_df[ens_df['mouse'] == mouse]
    for session in mdata['session'].unique():
        sdata = mdata[mdata['session'] == session]
        trend_counts = sdata['trend'].value_counts()
        # Reuse the day recorded when loading, which already accounts for rv07's skipped days.
        day = sdata['day'].iloc[0]
        for trend in trend_levels:
            prop_dict['mouse'].append(mouse)
            prop_dict['day'].append(day)
            prop_dict['session'].append(session)
            prop_dict['trend'].append(trend)
            prop_dict['proportion'].append(trend_counts.get(trend, 0) / len(sdata))
prop_df = pd.DataFrame(prop_dict)
# Sort chronologically within each mouse so per-mouse line plots connect sessions in order.
# Sessions absent from session_list map to NaN and sort last.
prop_df = (prop_df.assign(_rank=prop_df['session'].map(session_rank))
           .sort_values(['mouse', '_rank'])
           .drop(columns='_rank')
           .reset_index(drop=True))
avg_prop = prop_df.groupby(['session', 'trend'], as_index=False).agg({'proportion': ['mean', 'sem']})

In [None]:
## Plot proportion of fading (decreasing-trend) ensembles in each session.
# Plotly connects a line trace's points in data order, not axis-category order. Each mouse's
# points are therefore sorted into session_list order before its line is drawn.
session_rank = {name: rank for rank, name in enumerate(session_list)}
fading = avg_prop[avg_prop['trend'] == 'decreasing']
fig = pf.custom_graph_template(x_title='', y_title='Proportion Fading Ensembles')
fig.add_trace(go.Bar(x=fading['session'], y=fading['proportion']['mean'], marker_color=fading_color, showlegend=False,
                     marker_line_color='black', marker_line_width=2, error_y=dict(type='data', array=fading['proportion']['sem'])))

for mouse in prop_df['mouse'].unique():
    mdata = prop_df[(prop_df['mouse'] == mouse) & (prop_df['trend'] == 'decreasing')]
    mdata = mdata.sort_values('session', key=lambda s: s.map(session_rank), kind='stable')
    fig.add_trace(go.Scatter(x=mdata['session'], y=mdata['proportion'], mode='lines', line_color=subject_color,
                             showlegend=False, name=mouse, line_width=1, opacity=0.7))
fig.update_xaxes(categoryorder='array', categoryarray=session_list)
fig.show()
fig.write_image(pjoin(fig_path, 'proportion_fading_ensembles.png'))

In [None]:
## Plot proportion of emerging (increasing-trend) ensembles in each session.
# Plotly connects a line trace's points in data order, not axis-category order. Each mouse's
# points are therefore sorted into session_list order before its line is drawn.
session_rank = {name: rank for rank, name in enumerate(session_list)}
emerging = avg_prop[avg_prop['trend'] == 'increasing']
fig = pf.custom_graph_template(x_title='', y_title='Proportion Emerging Ensembles')
fig.add_trace(go.Bar(x=emerging['session'], y=emerging['proportion']['mean'], marker_color='darkturquoise', showlegend=False,
                     marker_line_color='black', marker_line_width=2, error_y=dict(type='data', array=emerging['proportion']['sem'])))

for mouse in prop_df['mouse'].unique():
    mdata = prop_df[(prop_df['mouse'] == mouse) & (prop_df['trend'] == 'increasing')]
    mdata = mdata.sort_values('session', key=lambda s: s.map(session_rank), kind='stable')
    fig.add_trace(go.Scatter(x=mdata['session'], y=mdata['proportion'], mode='lines', line_color=subject_color,
                             showlegend=False, name=mouse, line_width=1, opacity=0.7))
fig.update_xaxes(categoryorder='array', categoryarray=session_list)
fig.show()
fig.write_image(pjoin(fig_path, 'proportion_emerging_ensembles.png'))

In [None]:
## Plot proportions of fading (left panel) and emerging (right panel) ensembles side by side.
# Plotly connects a line trace's points in data order, not axis-category order. Each mouse's
# points are therefore sorted into session_list order before its line is drawn.
session_rank = {name: rank for rank, name in enumerate(session_list)}
fig = pf.custom_graph_template(x_title='', y_title='', rows=1, columns=2, shared_y=True, width=800, titles=['Fading', 'Emerging'])

# (trend label, bar color) per panel; panel index is the subplot column.
panels = [('decreasing', fading_color), ('increasing', 'darkturquoise')]
for col, (trend, bar_color) in enumerate(panels, start=1):
    trend_avg = avg_prop[avg_prop['trend'] == trend]
    fig.add_trace(go.Bar(x=trend_avg['session'], y=trend_avg['proportion']['mean'], marker_color=bar_color, showlegend=False,
                         marker_line_color='black', marker_line_width=2, error_y=dict(type='data', array=trend_avg['proportion']['sem'])),
                  row=1, col=col)

    for mouse in prop_df['mouse'].unique():
        mdata = prop_df[(prop_df['mouse'] == mouse) & (prop_df['trend'] == trend)]
        mdata = mdata.sort_values('session', key=lambda s: s.map(session_rank), kind='stable')
        fig.add_trace(go.Scatter(x=mdata['session'], y=mdata['proportion'], mode='lines', line_color=subject_color,
                                 showlegend=False, name=mouse, line_width=1, opacity=0.7), row=1, col=col)
fig.update_xaxes(categoryorder='array', categoryarray=session_list)
fig.update_yaxes(title='Proportion of Ensembles')
fig.show()
fig.write_image(pjoin(fig_path, 'proportion_of_ensembles.png'))