In [None]:
import os
import pickle
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import colors
from scipy.stats import mannwhitneyu, wilcoxon
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.metrics.pairwise import cosine_similarity
from scipy.stats import spearmanr
from matplotlib.backends.backend_pdf import PdfPages
import xarray as xr

# sys.path.append(r'H:/anthony/repos/NWB_analysis')
sys.path.append(r'/home/aprenard/repos/NWB_analysis')
sys.path.append(r'/home/aprenard/repos/fast-learning')
# from nwb_wrappers import nwb_reader_functions as nwb_read
import src.utils.utils_imaging as imaging_utils
import src.utils.utils_io as io
from src.behavior import compute_performance, plot_single_session
import warnings

# Set plot parameters.
plt.rcParams['pdf.fonttype'] = 42
plt.rcParams['ps.fonttype'] = 42
plt.rcParams['svg.fonttype'] = 'none'
sns.set_theme(context='paper', style='ticks', palette='deep', font='sans-serif', font_scale=1)
%matplotlib inline

In [None]:
# Path to the directory containing the processed data.
processed_dir = io.solve_common_paths('processed_data')
nwb_dir = io.solve_common_paths('nwb')
db_path = io.solve_common_paths('db')
# Session metadata file.

# # Rewarded and non-rewarded NWB files.
# group_yaml_rew = r"//sv-nas1.rcp.epfl.ch/Petersen-Lab/analysis/Anthony_Renard/mice_info/groups/imaging_rewarded.yaml"
# group_yaml_non_rew = r"//sv-nas1.rcp.epfl.ch/Petersen-Lab/analysis/Anthony_Renard/mice_info/groups/imaging_non_rewarded.yaml"
# nwb_list_rew = io.read_group_yaml(group_yaml_rew)
# nwb_list_non_rew = io.read_group_yaml(group_yaml_non_rew)
# nwb_list = nwb_list_rew + nwb_list_non_rew

In [2]:
plt.imshow(np.zeros((10, 10)))
plt.show()

# 0. Illustrations

In [28]:
sampling_rate = 30
win = (1, 1.3)  # from stimulus onset to 300 ms after.
win = (int(win[0] * sampling_rate), int(win[1] * sampling_rate))
baseline_win = (0, 1)
baseline_win = (int(baseline_win[0] * sampling_rate), int(baseline_win[1] * sampling_rate))
days = ['-2', '-1', '0', '+1', '+2']

mouse_id = 'AR127'

session_list, nwb_files, _, db_filtered = io.select_sessions_from_db(db_path,
                                                                    nwb_dir,
                                                                    two_p_imaging='yes',
                                                                    subject_id=mouse_id,
                                                                    day=days,)
print(session_list)

data = []
mdata_list = []
for session_id in session_list:
    arr, mdata = imaging_utils.load_session_2p_imaging(mouse_id,
                                                        session_id,
                                                        processed_dir)
    arr = imaging_utils.substract_baseline(arr, 3, baseline_win)
    data.append(arr)
    mdata_list.append(mdata)

# Extract UM trials.
for i, arr in enumerate(data):
    arr = imaging_utils.extract_trials(arr, mdata_list[i], 'UM', n_trials=None)
    data[i] = arr



['AR127_20240221_133407', 'AR127_20240222_152629', 'AR127_20240223_131820', 'AR127_20240224_140853', 'AR127_20240225_142858']


In [5]:
for icell in range(90, 120):
    fig, axes = plt.subplots(1, 2)

    for itrial in range(50):
        axes[0].plot(data[2][icell, itrial, :] + itrial * 2)
    axes[1].plot(data[2][icell, :, :].mean(axis=0))
    plt.suptitle(f'Cell {icell}')

  fig, axes = plt.subplots(1, 2)


In [6]:
# Plot a single population vector.
itrial = 0
pop_vector = data[2][:87,itrial,win[0]:win[1]].mean(axis=1)
pop_vector = np.repeat(pop_vector[:, np.newaxis], 10, axis=1)
vmin = np.percentile(pop_vector, 2)
vmax = np.percentile(pop_vector, 98)
plt.figure()
plt.imshow(pop_vector,cmap='viridis', vmin=vmin, vmax=vmax, aspect=2)


<matplotlib.image.AxesImage at 0x257dac34f80>

In [None]:
# Plot LMI vector.
lmi = np.load(r'\\sv-nas1.rcp.epfl.ch\Petersen-Lab\analysis\Anthony_Renard\data_processed\lmi.npy', allow_pickle=True).item()



In [14]:

lmi['AR127']['allcells']

array([ 3.125000e-04,  1.506250e-01, -7.000000e-02,  1.215625e-01,
       -2.468750e-02, -4.518750e-01,  8.562500e-02, -1.734375e-01,
        2.156250e-02, -3.968750e-02, -3.203125e-01, -7.500000e-01,
       -9.750000e-02, -2.125000e-02, -4.209375e-01,  4.687500e-02,
       -6.343750e-02,  1.684375e-01,  9.468750e-02,  1.468750e-02,
       -1.665625e-01, -2.384375e-01, -1.812500e-02,  4.843750e-02,
        2.437500e-02,  1.375000e-01, -1.225000e-01, -2.662500e-01,
        2.300000e-01, -7.687500e-02,  2.750000e-02,  5.343750e-02,
       -1.987500e-01, -7.384375e-01, -1.081250e-01, -1.893750e-01,
       -2.362500e-01, -2.593750e-02, -6.256250e-01, -1.781250e-02,
       -2.812500e-02, -2.362500e-01,  1.531250e-02,  9.375000e-02,
       -1.162500e-01, -5.656250e-02,  2.165625e-01, -2.684375e-01,
       -9.268750e-01, -4.375000e-02, -4.812500e-02, -2.743750e-01,
        8.000000e-02, -1.459375e-01, -1.206250e-01,  1.753125e-01,
       -2.900000e-01,  2.403125e-01, -3.218750e-02,  1.687500e

In [16]:

f, axes = plt.subplots(1, 2, sharey=True)
im = axes[0].imshow(np.repeat(lmi['AR127']['allcells'][:, np.newaxis], 10, axis=1), cmap='coolwarm', vmin=-.5, vmax=.5)
plt.colorbar(im)
# vmax = np.percentile(pop_vectors_dict[mouse_id]['allcells'], 99)
# vmin = np.percentile(pop_vectors_dict[mouse_id]['allcells'], 1)
# im = axes[1].imshow(pop_vectors_dict[mouse_id]['allcells'], cmap='viridis', vmin=vmin, vmax=vmax)
# plt.colorbar(im)
# print(vmin, vmax)


<matplotlib.colorbar.Colorbar at 0x257f026fbc0>

# Compute LMI and responsniveness

# 1. Responses to unmotivated mapping trials across learning days.

## 1.1. PSTH of unmotivated mapping trials across learning days

In [29]:
# Load data needed to compute before and after learning.

sampling_rate = 30
win_sec = (0.8, 3)  
win = (int(win_sec[0] * sampling_rate), int(win_sec[1] * sampling_rate))
baseline_win = (0, 1)
baseline_win = (int(baseline_win[0] * sampling_rate), int(baseline_win[1] * sampling_rate))
days = ['-2', '-1', '0', '+1', '+2']
# Correlation matrix for a specific cell type
cell_type = None
variance_explained_thr = 0.7

_, _, mice, _ = io.select_sessions_from_db(db_path,
                                            nwb_dir,
                                            two_p_imaging='yes')
mice = [m for m in mice if m not in ['AR163']]
print(mice)
len(mice)


['GF305', 'GF306', 'GF307', 'GF308', 'GF310', 'GF311', 'GF313', 'GF314', 'GF317', 'GF318', 'GF319', 'GF323', 'GF333', 'GF334', 'GF348', 'GF350', 'MI062', 'MI069', 'MI072', 'MI075', 'MI076', 'AR132', 'AR133', 'AR137', 'AR139', 'AR127', 'AR143', 'AR177', 'AR178', 'AR179', 'AR180']


31

In [30]:
psth = {}
metadata = {}
responsive_p_values = {}

for mouse_id in mice:
    # Disregard these mice as the number of trials is too low.
    # if mouse_id in ['GF307', 'GF310', 'GF333', 'AR144', 'AR135']:
    #     continue
    session_list, nwb_files, _, db_filtered = io.select_sessions_from_db(db_path,
                                                                        nwb_dir,
                                                                        two_p_imaging='yes',
                                                                        subject_id=mouse_id,
                                                                        day=days,)
    print(session_list)

    data = []
    mdata_list = []
    for session_id in session_list:
        arr, mdata = imaging_utils.load_session_2p_imaging(mouse_id,
                                                            session_id,
                                                            processed_dir)
        arr = imaging_utils.substract_baseline(arr, 3, baseline_win)
        data.append(arr)
        mdata_list.append(mdata)

    # Extract UM trials.
    for i, arr in enumerate(data):
        arr = imaging_utils.extract_trials(arr, mdata_list[i], 'A', n_trials=40)
        data[i] = arr

    # Get some metadata.
    reward_group = io.get_reward_group_from_db(db_path, session_list[0])
    metadata[mouse_id] = {}
    metadata[mouse_id]['reward_group'] = reward_group
    metadata[mouse_id]['cell_types'] = mdata['cell_types']
    metadata[mouse_id]["rois"] = mdata['rois']
    
    psth[mouse_id] = {}
    responsive_p_values[mouse_id] = {}

    for cell_type in ['allcells', 'wS2', 'wM1']:
        # Select cell type.
        if cell_type == 'allcells':
            data_subtype = data
        else:
            data_subtype = []
            cell_type_mask = mdata_list[0]['cell_types']==cell_type
            data_subtype = [arr[cell_type_mask] for arr in data]

        # If no cells of the specified type, skip.
        if data_subtype[0].shape[0] == 0:
            continue

        # Compute average response for each trial, each day.
        # --------------------------------------------------

        psth[mouse_id][cell_type] = []
        for day in data_subtype:
            psth[mouse_id][cell_type].append(np.nanmean(day[:, :, win[0]:win[1]], axis=1))

        # # Test responsiveness.
        # # --------------------

        # baseline_avg = []
        # response_avg = []
        # for day in data_subtype:
        #     baseline_avg.append(np.nanmean(day[:, :, baseline_win[0]:baseline_win[1]], axis=2))
        #     response_avg.append(np.nanmean(day[:, :, win[0]:win[1]], axis=2))

        # # Compare response amplitude to baseline.
        # n_cells = data_subtype[0].shape[0]
        # p_values = [np.zeros(n_cells) for _ in range(len(data_subtype))]
        # for iday, day in enumerate(data_subtype):
        #     for icell in range(n_cells):
        #         # Temporary fix for cells that are always 0 due to baseline substraction.
        #         if np.all(baseline_avg[iday][icell] == 0.):
        #             p_values[iday][icell] = 1
        #         else:
        #             _, p_values[iday][icell] = wilcoxon(baseline_avg[iday][icell], response_avg[iday][icell])
        # p_values = np.stack(p_values, axis=0)
        # responsive_p_values[mouse_id][cell_type] = p_values

['GF305_27112020_083119', 'GF305_28112020_103938', 'GF305_29112020_103331', 'GF305_30112020_110255', 'GF305_02122020_132229']
here (133, 40, 181)
here (133, 40, 181)
here (133, 40, 181)
here (133, 40, 181)
here (133, 40, 181)
['GF306_27112020_104436', 'GF306_28112020_125555', 'GF306_29112020_131929', 'GF306_30112020_133249', 'GF306_02122020_161611']
here (215, 40, 181)
here (215, 40, 181)
here (215, 40, 181)
here (215, 40, 181)
here (215, 40, 181)
['GF307_17112020_080325', 'GF307_18112020_075939', 'GF307_19112020_083908', 'GF307_20112020_082942', 'GF307_21112020_102608']
here (150, 40, 181)
here (150, 40, 181)
here (150, 40, 181)
here (150, 40, 181)
here (150, 40, 181)
['GF308_17112020_105052', 'GF308_18112020_093627', 'GF308_19112020_103527', 'GF308_20112020_122826', 'GF308_21112020_135515']
here (147, 40, 181)
here (147, 40, 181)
here (147, 40, 181)
here (147, 40, 181)
here (147, 40, 181)
['GF310_17112020_132720', 'GF310_18112020_122252', 'GF310_19112020_131953', 'GF310_20112020_1509

Plot PSTH's for all cells and projectors.

In [31]:
# Convert to pandas.
mouse_ids = list(psth.keys())
df = []

for mouse_id in mouse_ids:
    # if mouse_id == 'GF308':
    #     continue
    
    for icell, roi in enumerate(metadata[mouse_id]['rois']):
        time = np.linspace(win_sec[0], win_sec[1], psth[mouse_id]['allcells'][0].shape[1]) - 1
        cell_type = metadata[mouse_id]['cell_types'][icell]
        for iday in range(len(days)):
            trace = psth[mouse_id]['allcells'][iday][icell]
            temp = pd.DataFrame(np.stack([time, trace], axis=1), columns=['time', 'activity'])
            temp['day'] = days[iday]
            temp['mouse_id'] = mouse_id
            temp['roi'] = roi
            temp['cell_type'] = cell_type
            temp['reward_group'] = metadata[mouse_id]['reward_group']
            df.append(temp)
df = pd.concat(df)
df = df.reset_index(drop=True)


In [None]:
# Plot.
data = df.loc[df.time<1.5]

# GF305 has baseline artefact on day -1 at auditory trials.
data = data.loc[~data.mouse_id.isin(['GF305'])]

data = data.loc[data['day'].isin(['-2', '-1', '0', '+1', '+2'])]
# data = data.loc[data['cell_type']=='wM1']
# df = df.loc[df.mouse_id.isin(['GF305', 'GF306', 'GF307'])]
fig = sns.relplot(data=data, x='time', y='activity', errorbar='ci', col='day',
            kind='line', hue='reward_group',
            hue_order=['R-','R+'], palette=sns.color_palette(['#d51a1c', '#238443']),
            height=3, aspect=0.8)
for ax in fig.axes.flatten():
    ax.axvline(0, color='#FF9600', linestyle='--')
    ax.set_title('')    

fig = sns.relplot(data=data, x='time', y='activity', errorbar='se', col='day', row='cell_type',
            kind='line', hue='reward_group',
            hue_order=['R-','R+'], palette=sns.color_palette(['#d51a1c', '#238443']), row_order=['wS2', 'wM1',],
            height=3, aspect=0.8)
for ax in fig.axes.flatten():
    ax.axvline(0, color='#FF9600', linestyle='--')
    ax.set_title('')


Single PSTH's per mouse.

In [95]:
pdf_file = f'psth_individual_mice_auditory.pdf'
output_dir = fr'//sv-nas1.rcp.epfl.ch/Petersen-Lab/analysis/Anthony_Renard/analysis_output/sensory_plasticity/psth'

with PdfPages(os.path.join(output_dir, pdf_file)) as pdf:
    for mouse_id in mice:
        # Plot.
        data = df.loc[df['day'].isin(['-2', '-1', '0', '+1', '+2'])
                      & (df['mouse_id'] == mouse_id)]

        sns.relplot(data=data, x='time', y='activity', errorbar='se', col='day', row='cell_type',
                    kind='line', hue='reward_group',
                    hue_order=['R-','R+'], palette=sns.color_palette(['#d51a1c', '#238443']))
        plt.suptitle(mouse_id)
        # pdf.savefig(dpi=300)
        # plt.close()

## PSTH's during learning

In [22]:
# Load data needed to compute before and after learning.

sampling_rate = 30
win_sec = (0.8, 3)  
win = (int(win_sec[0] * sampling_rate), int(win_sec[1] * sampling_rate))
baseline_win = (0, 1)
baseline_win = (int(baseline_win[0] * sampling_rate), int(baseline_win[1] * sampling_rate))
days = ['-2', '-1', '0', '+1', '+2']

_, _, mice, _ = io.select_sessions_from_db(db_path,
                                            nwb_dir,
                                            two_p_imaging='yes')
mice = [m for m in mice if m not in ['AR163']]
print(mice)
len(mice)

['GF305', 'GF306', 'GF307', 'GF308', 'GF310', 'GF311', 'GF313', 'GF314', 'GF317', 'GF318', 'GF319', 'GF323', 'GF333', 'GF334', 'GF348', 'GF350', 'MI062', 'MI069', 'MI072', 'MI075', 'MI076', 'AR132', 'AR133', 'AR137', 'AR139', 'AR127', 'AR143', 'AR177', 'AR178', 'AR179', 'AR180']


31

In [25]:
response_amp = {}
psth = {}
lmi = {}
lmi_p = {}
metadata = {}
responsive_p_values = {}

for mouse_id in mice:
    # Disregard these mice as the number of trials is too low.
    # if mouse_id in ['GF307', 'GF310', 'GF333', 'AR144', 'AR135']:
    #     continue
    session_list, nwb_files, _, db_filtered = io.select_sessions_from_db(db_path,
                                                                        nwb_dir,
                                                                        two_p_imaging='yes',
                                                                        subject_id=mouse_id,
                                                                        day=days,)
    print(session_list)

    data = []
    mdata_list = []
    for session_id in session_list:
        arr, mdata = imaging_utils.load_session_2p_imaging(mouse_id,
                                                            session_id,
                                                            processed_dir)
        arr = imaging_utils.substract_baseline(arr, 3, baseline_win)
        data.append(arr)
        mdata_list.append(mdata)

    # Extract UM trials.
    for i, arr in enumerate(data):
        arr = imaging_utils.extract_trials(arr, mdata_list[i], 'UM', n_trials=40)
        data[i] = arr

    # Get some metadata.
    reward_group = io.get_reward_group_from_db(db_path, session_list[0])
    metadata[mouse_id] = {}
    metadata[mouse_id]['reward_group'] = reward_group
    metadata[mouse_id]['cell_types'] = mdata['cell_types']
    metadata[mouse_id]["rois"] = mdata['rois']
    
    psth[mouse_id] = {}
    responsive_p_values[mouse_id] = {}
    response_amp[mouse_id] = {}
    lmi[mouse_id] = {}
    lmi_p[mouse_id] = {}

    for cell_type in ['allcells', 'wS2', 'wM1']:
        # Select cell type.
        if cell_type == 'allcells':
            data_subtype = data
        else:
            data_subtype = []
            cell_type_mask = mdata_list[0]['cell_types']==cell_type
            data_subtype = [arr[cell_type_mask] for arr in data]

        # If no cells of the specified type, skip.
        if data_subtype[0].shape[0] == 0:
            continue

        # Compute average response for each trial, each day.
        # --------------------------------------------------

        psth[mouse_id][cell_type] = []
        for day in data_subtype:
            psth[mouse_id][cell_type].append(np.nanmean(day[:, :, win[0]:win[1]], axis=1))

        # Compute average response for each trial, each day.
        response_amp[mouse_id][cell_type] = []
        for day in data_subtype:
            response_amp[mouse_id][cell_type].append(np.nanmean(day[:, :, win[0]:win[1]], axis=2))

        # # Compute LMI.
        # if cell_type == 'allcells':
        #     # # pre = np.mean(np.concatenate(response_amp[0:2], axis=1), axis=1)
        #     # # print(pre.shape)
        #     # # post = np.mean(np.concatenate((response_amp[5], response_amp[7]), axis=1), axis=1)
        #     # # lmi[mouse_id] = (post - pre) / (np.abs(post) + np.abs(pre))

        #     # lmis = []
        #     # ncells = len(metadata[mouse_id]['rois'])
        #     # for icell in range(ncells):
        #     #     # mapping trials of D-2, D-1, D+1, D+2.
        #     #     X = np.r_[response_amp[mouse_id][cell_type][0][icell], response_amp[mouse_id][cell_type][1][icell],
        #     #               response_amp[mouse_id][cell_type][3][icell], response_amp[mouse_id][cell_type][4][icell]]
        #     #     y = np.r_[np.zeros(response_amp[mouse_id][cell_type][0][icell].shape[0]),
        #     #               np.zeros(response_amp[mouse_id][cell_type][1][icell].shape[0]),
        #     #               np.ones(response_amp[mouse_id][cell_type][3][icell].shape[0]),
        #     #               np.ones(response_amp[mouse_id][cell_type][4][icell].shape[0])]
        #     #     fpr, tpr, _ = roc_curve(y, X)
        #     #     roc_auc = auc(fpr, tpr)
        #     #     lmis.append((roc_auc - 0.5) * 2)
        #     # lmi[mouse_id]['allcells'] = np.array(lmis)

        #     pre = [response_amp[mouse_id][cell_type][days.index('-2')],
        #            response_amp[mouse_id][cell_type][days.index('-1')]]
        #     pre = np.concatenate(pre, axis=1)
        #     post = [response_amp[mouse_id][cell_type][days.index('+1')],
        #             response_amp[mouse_id][cell_type][days.index('+2')]]
        #     post = np.concatenate(post, axis=1)
        #     lmi[mouse_id]['allcells'], lmi_p[mouse_id]['allcells'] = imaging_utils.compute_lmi(pre, post, nshuffles=100)
        # else:
        #     lmi[mouse_id][cell_type] = lmi[mouse_id]['allcells'][metadata[mouse_id]['cell_types'] == cell_type]
        #     lmi_p[mouse_id][cell_type] = lmi_p[mouse_id]['allcells'][metadata[mouse_id]['cell_types'] == cell_type]

# # Save lmi dicts.
# save_path = r'\\sv-nas1.rcp.epfl.ch\Petersen-Lab\analysis\Anthony_Renard\data_processed\lmi.npy'
# np.save(save_path, lmi, allow_pickle=True)
# save_path = r'\\sv-nas1.rcp.epfl.ch\Petersen-Lab\analysis\Anthony_Renard\data_processed\lmi_p.npy'
# np.save(save_path, lmi_p, allow_pickle=True)

['GF305_27112020_083119', 'GF305_28112020_103938', 'GF305_29112020_103331', 'GF305_30112020_110255', 'GF305_02122020_132229']
here (133, 40, 181)
here (133, 40, 181)
here (133, 40, 181)
here (133, 40, 181)
here (133, 40, 181)
['GF306_27112020_104436', 'GF306_28112020_125555', 'GF306_29112020_131929', 'GF306_30112020_133249', 'GF306_02122020_161611']
here (215, 40, 181)
here (215, 40, 181)
here (215, 40, 181)
here (215, 40, 181)
here (215, 40, 181)
['GF307_17112020_080325', 'GF307_18112020_075939', 'GF307_19112020_083908', 'GF307_20112020_082942', 'GF307_21112020_102608']
here (150, 40, 181)
here (150, 40, 181)
here (150, 40, 181)
here (150, 40, 181)
here (150, 40, 181)
['GF308_17112020_105052', 'GF308_18112020_093627', 'GF308_19112020_103527', 'GF308_20112020_122826', 'GF308_21112020_135515']
here (147, 40, 181)
here (147, 40, 181)
here (147, 40, 181)
here (147, 40, 181)
here (147, 40, 181)
['GF310_17112020_132720', 'GF310_18112020_122252', 'GF310_19112020_131953', 'GF310_20112020_1509

In [37]:
# Coutn cells.

for reward_group in ['R-', 'R+']:
    count_all = []
    count_s2 = []
    count_m1 = []
    for mouse in lmi.keys():
        if metadata[mouse]['reward_group'] != reward_group:
            continue
        for cell_type in lmi[mouse].keys():
            if cell_type == 'allcells':
                count_all.append(len(lmi[mouse][cell_type]))
            elif cell_type == 'wS2':
                count_s2.append(len(lmi[mouse][cell_type]))
            elif cell_type == 'wM1':
                count_m1.append(len(lmi[mouse][cell_type]))
    c_all = np.sum(count_all)
    c_s2 = np.sum(count_s2)
    c_m1 = np.sum(count_m1)
    print(c_all, c_s2, c_m1)


2482 233 115
2846 280 277


In [33]:
lmi['AR127']['allcells']

array([ 3.125000e-04,  1.506250e-01, -7.000000e-02,  1.215625e-01,
       -2.468750e-02, -4.518750e-01,  8.562500e-02, -1.734375e-01,
        2.156250e-02, -3.968750e-02, -3.203125e-01, -7.500000e-01,
       -9.750000e-02, -2.125000e-02, -4.209375e-01,  4.687500e-02,
       -6.343750e-02,  1.684375e-01,  9.468750e-02,  1.468750e-02,
       -1.665625e-01, -2.384375e-01, -1.812500e-02,  4.843750e-02,
        2.437500e-02,  1.375000e-01, -1.225000e-01, -2.662500e-01,
        2.300000e-01, -7.687500e-02,  2.750000e-02,  5.343750e-02,
       -1.987500e-01, -7.384375e-01, -1.081250e-01, -1.893750e-01,
       -2.362500e-01, -2.593750e-02, -6.256250e-01, -1.781250e-02,
       -2.812500e-02, -2.362500e-01,  1.531250e-02,  9.375000e-02,
       -1.162500e-01, -5.656250e-02,  2.165625e-01, -2.684375e-01,
       -9.268750e-01, -4.375000e-02, -4.812500e-02, -2.743750e-01,
        8.000000e-02, -1.459375e-01, -1.206250e-01,  1.753125e-01,
       -2.900000e-01,  2.403125e-01, -3.218750e-02,  1.687500e

Proportion of LMI cells.


In [186]:
lmi_prop = []
for mouse in mice:
    reward_group = metadata[mouse]['reward_group']
    for cell_type in ['allcells', 'wS2', 'wM1']:
        if cell_type in lmi_p[mouse].keys():
            n_lmi_up = np.sum(lmi_p[mouse][cell_type] >= 0.975)
            n_lmi_down = np.sum(lmi_p[mouse][cell_type] <= 0.025)
            prop_lmi_up = n_lmi_up / len(lmi[mouse][cell_type])
            prop_lmi_down = n_lmi_down / len(lmi[mouse][cell_type])
            lmi_prop.append({'mouse_id': mouse,'reward_group': reward_group, 'cell_type': cell_type, 'n_lmi': n_lmi_up, 'prop_lmi': prop_lmi_up, 'modulation': 'up'})
            lmi_prop.append({'mouse_id': mouse,'reward_group': reward_group, 'cell_type': cell_type, 'n_lmi': n_lmi_down, 'prop_lmi': prop_lmi_down, 'modulation': 'down'})
        else:
            lmi_prop.append({'mouse_id': mouse, 'reward_group': reward_group, 'cell_type': cell_type, 'n_lmi_up': np.nan, 'prop_lmi_up': np.nan})
            lmi_prop.append({'mouse_id': mouse, 'reward_group': reward_group, 'cell_type': cell_type, 'n_lmi_down': np.nan, 'prop_lmi_down': np.nan})
lmi_prop = pd.DataFrame(lmi_prop)

In [197]:
g = sns.catplot(data=lmi_prop, x='reward_group', y='prop_lmi', kind='bar', col='modulation', row='cell_type', hue='reward_group',
            palette=sns.color_palette([ '#238443', '#d51a1c',]), height=3, aspect=0.8)

# Perform Mann-Whitney U test to check if the difference between the two reward groups is significant for each modulation and cell type.
results = []
for cell_type in lmi_prop['cell_type'].unique():
    for modulation in ['up', 'down']:
        group_rew = lmi_prop[(lmi_prop['reward_group'] == 'R+') & (lmi_prop['cell_type'] == cell_type) & (lmi_prop['modulation'] == modulation)]['prop_lmi']
        group_unrew = lmi_prop[(lmi_prop['reward_group'] == 'R-') & (lmi_prop['cell_type'] == cell_type) & (lmi_prop['modulation'] == modulation)]['prop_lmi']
        group_rew = group_rew.dropna()
        group_unrew = group_unrew.dropna()
        stat, p = mannwhitneyu(group_rew, group_unrew)
        results.append({'cell_type': cell_type, 'modulation': modulation, 'p_value': p})
        print(f'Cell type {cell_type}, Modulation {modulation}: p-value = {p}')

# Convert results to a DataFrame
results_df = pd.DataFrame(results)
print(results_df)

# Add stars to the plot for each subplot
for ax in g.axes.flat:
    cell_type = ax.get_title().split(' = ')[-1]
    modulation = ax.get_title().split(' = ')[-2]
    for result in results:
        if result['cell_type'] == cell_type and result['modulation'] == modulation:
            if result['p_value'] < 0.05:
                ax.text(0.5, 0.5, '*', ha='center', va='bottom', color='black', transform=ax.transAxes)
            if result['p_value'] < 0.01:
                ax.text(0.5, 0.5, '**', ha='center', va='bottom', color='black', transform=ax.transAxes)
            if result['p_value'] < 0.001:
                ax.text(0.5, 0.5, '***', ha='center', va='bottom', color='black', transform=ax.transAxes)


Cell type allcells, Modulation up: p-value = 0.00038139799917024205
Cell type allcells, Modulation down: p-value = 0.2579351703686603
Cell type wS2, Modulation up: p-value = 0.015403528420213176
Cell type wS2, Modulation down: p-value = 0.38455254690437146
Cell type wM1, Modulation up: p-value = 0.20185429860117265
Cell type wM1, Modulation down: p-value = 0.879153887906272
  cell_type modulation   p_value
0  allcells         up  0.000381
1  allcells       down  0.257935
2       wS2         up  0.015404
3       wS2       down  0.384553
4       wM1         up  0.201854
5       wM1       down  0.879154


PSTH's with LMI cells




In [141]:
# Convert to pandas.
mouse_ids = list(psth.keys())
df = []

for mouse_id in mouse_ids:
    # if mouse_id == 'GF308':
    #     continue
    
    for icell, roi in enumerate(metadata[mouse_id]['rois']):
        # if (lmi_p[mouse_id]['allcells'][icell] < 0.975) & (lmi_p[mouse_id]['allcells'][icell] > 0.025):
        # if (lmi_p[mouse_id]['allcells'][icell] < 0.975):
        if (lmi_p[mouse_id]['allcells'][icell] > 0.025):
            continue

        time = np.linspace(win_sec[0], win_sec[1], psth[mouse_id]['allcells'][0].shape[1]) - 1
        cell_type = metadata[mouse_id]['cell_types'][icell]
        for iday in range(len(days)):
            trace = psth[mouse_id]['allcells'][iday][icell]
            temp = pd.DataFrame(np.stack([time, trace], axis=1), columns=['time', 'activity'])
            temp['day'] = days[iday]
            temp['mouse_id'] = mouse_id
            temp['roi'] = roi
            temp['cell_type'] = cell_type
            temp['reward_group'] = metadata[mouse_id]['reward_group']
            df.append(temp)
df = pd.concat(df)
df = df.reset_index(drop=True)

In [142]:
# Plot.
data = df.loc[df.time<1.5]

# GF305 has baseline artefact on day -1 at auditory trials.
data = data.loc[~data.mouse_id.isin(['GF305'])]

data = data.loc[data['day'].isin(['-2', '-1', '0', '+1', '+2'])]
# data = data.loc[data['cell_type']=='wM1']
# df = df.loc[df.mouse_id.isin(['GF305', 'GF306', 'GF307'])]
fig = sns.relplot(data=data, x='time', y='activity', errorbar='ci', col='day',
            kind='line', hue='reward_group',
            hue_order=['R-','R+'], palette=sns.color_palette(['#d51a1c', '#238443']),
            height=3, aspect=0.8)
for ax in fig.axes.flatten():
    ax.axvline(0, color='#FF9600', linestyle='--')
    ax.set_title('')        

fig = sns.relplot(data=data, x='time', y='activity', errorbar='ci', col='day', row='cell_type',
            kind='line', hue='reward_group',
            hue_order=['R-','R+'], palette=sns.color_palette(['#d51a1c', '#238443']), row_order=['wS2', 'wM1',],
            height=3, aspect=0.8)
for ax in fig.axes.flatten():
    ax.axvline(0, color='#FF9600', linestyle='--')
    ax.set_title('')    


PSTH's with LMI cells -- 5 first whisker/whisker hit trials VS the rest.

In [27]:
# Load data needed to compute before and after learning.

sampling_rate = 30
win_sec = (0.8, 3)  
win = (int(win_sec[0] * sampling_rate), int(win_sec[1] * sampling_rate))
baseline_win = (0, 1)
baseline_win = (int(baseline_win[0] * sampling_rate), int(baseline_win[1] * sampling_rate))
days = ['0']

_, _, mice, _ = io.select_sessions_from_db(db_path,
                                            nwb_dir,
                                            two_p_imaging='yes')
mice = [m for m in mice if m not in ['AR163']]
print(mice)
len(mice)

['GF305', 'GF306', 'GF307', 'GF308', 'GF310', 'GF311', 'GF313', 'GF314', 'GF317', 'GF318', 'GF319', 'GF323', 'GF333', 'GF334', 'GF348', 'GF350', 'MI062', 'MI069', 'MI072', 'MI075', 'MI076', 'AR132', 'AR133', 'AR137', 'AR139', 'AR127', 'AR143', 'AR177', 'AR178', 'AR179', 'AR180']


31

In [28]:
response_amp = {}
psth = {}
metadata = {}

for mouse_id in mice:
    # Disregard these mice as the number of trials is too low.
    # if mouse_id in ['GF307', 'GF310', 'GF333', 'AR144', 'AR135']:
    #     continue
    session_list, nwb_files, _, db_filtered = io.select_sessions_from_db(db_path,
                                                                        nwb_dir,
                                                                        two_p_imaging='yes',
                                                                        subject_id=mouse_id,
                                                                        day=days,)
    print(session_list)

    data = []
    mdata_list = []
    for session_id in session_list:
        arr, mdata = imaging_utils.load_session_2p_imaging(mouse_id,
                                                            session_id,
                                                            processed_dir)
        arr = imaging_utils.substract_baseline(arr, 3, baseline_win)
        data.append(arr)
        mdata_list.append(mdata)

    # Extract UM trials.
    for i, arr in enumerate(data):
        arr = imaging_utils.extract_trials(arr, mdata_list[i], 'W', n_trials=40)
        data[i] = arr

    # Get some metadata.
    reward_group = io.get_reward_group_from_db(db_path, session_list[0])
    metadata[mouse_id] = {}
    metadata[mouse_id]['reward_group'] = reward_group
    metadata[mouse_id]['cell_types'] = mdata['cell_types']
    metadata[mouse_id]["rois"] = mdata['rois']
    
    psth[mouse_id] = {}
    responsive_p_values[mouse_id] = {}
    response_amp[mouse_id] = {}


    for cell_type in ['allcells', 'wS2', 'wM1']:
        # Select cell type.
        if cell_type == 'allcells':
            data_subtype = data
        else:
            data_subtype = []
            cell_type_mask = mdata_list[0]['cell_types']==cell_type
            data_subtype = [arr[cell_type_mask] for arr in data]

        # If no cells of the specified type, skip.
        if data_subtype[0].shape[0] == 0:
            continue

        # Compute average response for each trial, each day.
        # --------------------------------------------------

        psth[mouse_id][cell_type] = []
        for day in data_subtype:
            psth[mouse_id][cell_type].append(day[:, :, win[0]:win[1]])

        # Compute average response for each trial, each day.
        response_amp[mouse_id][cell_type] = []
        for day in data_subtype:
            response_amp[mouse_id][cell_type].append(day[:, :, win[0]:win[1]])


['GF305_29112020_103331']
here (133, 40, 181)
['GF306_29112020_131929']
here (215, 40, 181)
['GF307_19112020_083908']
here (150, 40, 181)
['GF308_19112020_103527']
here (147, 40, 181)
['GF310_19112020_131953']
here (243, 40, 181)
['GF311_19112020_160412']
here (105, 40, 181)
['GF313_29112020_154625']
here (164, 40, 181)
['GF314_29112020_174831']
here (197, 40, 181)
['GF317_17122020_080715']
here (146, 40, 181)
['GF318_17122020_144100']
here (140, 40, 181)
['GF319_26122020_144746']
here (154, 40, 181)
['GF323_09012021_111716']
here (305, 40, 181)
['GF333_24012021_145617']
here (124, 40, 181)
['GF334_24012021_173019']
here (130, 40, 181)
['GF348_31052021_102411']
here (206, 40, 181)
['GF350_31052021_135001']
here (211, 40, 181)
['MI062_02102021_105027']
here (108, 40, 181)
['MI069_21122021_090648']
here (218, 40, 181)
['MI072_21122021_132704']
here (244, 40, 181)
['MI075_21122021_151949']
here (189, 40, 181)
['MI076_21122021_112146']
here (268, 40, 181)
['AR132_20240426_093953']
here (18

In [26]:

# Load lmi dicts.
lmi = np.load(r'\\sv-nas1.rcp.epfl.ch\Petersen-Lab\analysis\Anthony_Renard\data_processed\lmi.npy', allow_pickle=True).item()
lmi_p = np.load(r'\\sv-nas1.rcp.epfl.ch\Petersen-Lab\analysis\Anthony_Renard\data_processed\lmi_p.npy', allow_pickle=True).item()

In [29]:
# Convert to pandas.
mouse_ids = list(psth.keys())
df = []

for mouse_id in mouse_ids:
    # if mouse_id == 'GF308':
    #     continue
    
    for icell, roi in enumerate(metadata[mouse_id]['rois']):
        if (lmi_p[mouse_id]['allcells'][icell] < 0.975) & (lmi_p[mouse_id]['allcells'][icell] > 0.025):
            continue

        time = np.linspace(win_sec[0], win_sec[1], psth[mouse_id]['allcells'][0].shape[2]) - 1
        cell_type = metadata[mouse_id]['cell_types'][icell]
        for iday in range(len(days)):
            for itrial in range(40):
                trace = psth[mouse_id]['allcells'][iday][icell, itrial, :]
                temp = pd.DataFrame(np.stack([time, trace], axis=1), columns=['time', 'activity'])
                temp['day'] = days[iday]
                temp['mouse_id'] = mouse_id
                temp['roi'] = roi
                if lmi_p[mouse_id]['allcells'][icell] > 0.975:
                    temp['modulation'] = 'up'
                else:
                    temp['modulation'] = 'down'
                temp['trial'] = itrial
                temp['cell_type'] = cell_type
                temp['reward_group'] = metadata[mouse_id]['reward_group']
                df.append(temp)
df = pd.concat(df)
df = df.reset_index(drop=True)



In [32]:
# Plot.
data = df.loc[df.time<1.5]
modulation = 'down'
if modulation:
    data = data.loc[data['modulation']==modulation]

# # GF305 has baseline artefact on day -1 at auditory trials.
# data = data.loc[~data.mouse_id.isin(['GF305'])]


fig, axes = plt.subplots(3, 2, sharex=True, sharey=True)

temp = data.loc[(data['trial']<5) & (data['reward_group']=='R+')]
# Average on cells if stats on mice. On cells otherwise.
# temp = temp.groupby(['mouse_id', 'time', 'reward_group', 'cell_type'])['activity'].mean().reset_index()
sns.lineplot(data=temp, x='time', y='activity', errorbar='ci', hue='reward_group',
            palette=sns.color_palette([ '#238443']), linestyle='--',
            ax=axes[0, 0], legend=False)
temp = data.loc[(data['trial']>=5) & (data['reward_group']=='R+')]
# temp = temp.groupby(['mouse_id', 'time', 'reward_group', 'cell_type'])['activity'].mean().reset_index()
sns.lineplot(data=temp, x='time', y='activity', errorbar='ci', hue='reward_group',
            palette=sns.color_palette(['#238443']),
            ax=axes[0, 0], legend=False)

temp = data.loc[(data['trial']<5) & (data['reward_group']=='R-')]
# temp = temp.groupby(['mouse_id', 'time', 'reward_group', 'cell_type'])['activity'].mean().reset_index()
sns.lineplot(data=temp, x='time', y='activity', errorbar='ci', hue='reward_group',
            palette=sns.color_palette([ '#d51a1c']),linestyle='--',
            ax=axes[0, 1], legend=False)
temp = data.loc[(data['trial']>=5) & (data['reward_group']=='R-')]
# temp = temp.groupby(['mouse_id', 'time', 'reward_group', 'cell_type'])['activity'].mean().reset_index()
sns.lineplot(data=temp, x='time', y='activity', errorbar='ci', hue='reward_group',
            palette=sns.color_palette(['#d51a1c']),
            ax=axes[0, 1], legend=False)

temp = data.loc[(data['trial']<5)  & (data['cell_type']=='wS2') & (data['reward_group']=='R+')]
# temp = temp.groupby(['mouse_id', 'time', 'reward_group', 'cell_type'])['activity'].mean().reset_index()
sns.lineplot(data=temp, x='time', y='activity', errorbar='ci', hue='reward_group',
            palette=sns.color_palette(['#238443']),linestyle='--',
            ax=axes[1, 0], legend=False)
temp = data.loc[(data['trial']>=5)  & (data['cell_type']=='wS2') & (data['reward_group']=='R+')]
# temp = temp.groupby(['mouse_id', 'time', 'reward_group', 'cell_type'])['activity'].mean().reset_index()
sns.lineplot(data=temp, x='time', y='activity', errorbar='ci', hue='reward_group',
            palette=sns.color_palette(['#238443',]),
            ax=axes[1, 0], legend=False)

temp = data.loc[(data['trial']<5)  & (data['cell_type']=='wS2') & (data['reward_group']=='R-')]
# temp = temp.groupby(['mouse_id', 'time', 'reward_group', 'cell_type'])['activity'].mean().reset_index()
sns.lineplot(data=temp, x='time', y='activity', errorbar='ci', hue='reward_group',
            palette=sns.color_palette(['#d51a1c']),linestyle='--',
            ax=axes[1, 1], legend=False)
temp = data.loc[(data['trial']>=5)  & (data['cell_type']=='wS2') & (data['reward_group']=='R-')]
# temp = temp.groupby(['mouse_id', 'time', 'reward_group', 'cell_type'])['activity'].mean().reset_index()
sns.lineplot(data=temp, x='time', y='activity', errorbar='ci', hue='reward_group',
            palette=sns.color_palette(['#d51a1c',]),
            ax=axes[1, 1], legend=False)

temp = data.loc[(data['trial']<5)  & (data['cell_type']=='wM1') & (data['reward_group']=='R+')]
# temp = temp.groupby(['mouse_id', 'time', 'reward_group', 'cell_type'])['activity'].mean().reset_index()
sns.lineplot(data=temp, x='time', y='activity', errorbar='ci', hue='reward_group',
            palette=sns.color_palette(['#238443']),linestyle='--',
            ax=axes[2, 0], legend=False)
temp = data.loc[(data['trial']>=5)  & (data['cell_type']=='wM1') & (data['reward_group']=='R+')]
# temp = temp.groupby(['mouse_id', 'time', 'reward_group', 'cell_type'])['activity'].mean().reset_index()
sns.lineplot(data=temp, x='time', y='activity', errorbar='ci', hue='reward_group',
            palette=sns.color_palette(['#238443']),
            ax=axes[2, 0], legend=False)

temp = data.loc[(data['trial']<5)  & (data['cell_type']=='wM1') & (data['reward_group']=='R-')]
# temp = temp.groupby(['mouse_id', 'time', 'reward_group', 'cell_type'])['activity'].mean().reset_index()
sns.lineplot(data=temp, x='time', y='activity', errorbar='ci', hue='reward_group', linestyle='--',
            palette=sns.color_palette(['#d51a1c']),
            ax=axes[2, 1], legend=False)
temp = data.loc[(data['trial']>=5)  & (data['cell_type']=='wM1') & (data['reward_group']=='R-')]
# temp = temp.groupby(['mouse_id', 'time', 'reward_group', 'cell_type'])['activity'].mean().reset_index()
sns.lineplot(data=temp, x='time', y='activity', errorbar='ci', hue='reward_group',
            palette=sns.color_palette(['#d51a1c']),
            ax=axes[2, 1], legend=False)


for ax in axes.flatten():
    ax.axvline(0, color='#FF9600', linestyle='--')
    ax.set_title('')

fig.suptitle(f'First 5 whisker hits VS rest LMI {modulation}')




Text(0.5, 0.98, 'First 5 whisker hits VS rest LMI down')

In [222]:
(data['cell_type']=='allcells').sum()

np.int64(0)

## 1.2. Quantify those responses.

- Amplitude of the response
- Number of significant cells
- variance across days
- dimensionality across days
- ...


In [57]:
# Load data needed to compute before and after learning.

sampling_rate = 30
win = (1, 1.3)  # from stimulus onset to 300 ms after.
win = (int(win[0] * sampling_rate), int(win[1] * sampling_rate))
baseline_win = (0, 1)
baseline_win = (int(baseline_win[0] * sampling_rate), int(baseline_win[1] * sampling_rate))
days = ['-2', '-1', '0', '+1', '+2']
# Correlation matrix for a specific cell type
cell_type = None
variance_explained_thr = 0.7

_, _, mice, _ = io.select_sessions_from_db(db_path,
                                            nwb_dir,
                                            two_p_imaging='yes',)
mice = [m for m in mice if m not in ['AR163']]
print(mice)
len(mice)


['GF305', 'GF306', 'GF307', 'GF308', 'GF310', 'GF311', 'GF313', 'GF314', 'GF317', 'GF318', 'GF319', 'GF323', 'GF333', 'GF334', 'GF348', 'GF350', 'MI062', 'MI069', 'MI072', 'MI075', 'MI076', 'AR132', 'AR133', 'AR137', 'AR139', 'AR127', 'AR143', 'AR177', 'AR178', 'AR179', 'AR180']


31

In [58]:
average_response = {}
peak_response = {}
responsive_p_values = {}
dimensionality = {}
metadata = {}

lmi = {}
globally_responsive = {}

for mouse_id in mice:
    # Disregard these mice as the number of trials is too low.
    # if mouse_id in ['GF307', 'GF310', 'GF333', 'AR144', 'AR135']:
    #     continue
    session_list, nwb_files, _, db_filtered = io.select_sessions_from_db(db_path,
                                                                        nwb_dir,
                                                                        two_p_imaging='yes',
                                                                        subject_id=mouse_id,
                                                                        day=days,)
    print(session_list)

    data = []
    mdata_list = []
    for session_id in session_list:
        arr, mdata = imaging_utils.load_session_2p_imaging(mouse_id,
                                                            session_id,
                                                            processed_dir)
        arr = imaging_utils.substract_baseline(arr, 3, baseline_win)
        data.append(arr)
        mdata_list.append(mdata)

    # Extract UM trials.
    for i, arr in enumerate(data):
        arr = imaging_utils.extract_trials(arr, mdata_list[i], 'UM', n_trials=None)
        data[i] = arr

    # Get some metadata.
    reward_group = io.get_reward_group_from_db(db_path, session_list[0])
    metadata[mouse_id] = {}
    metadata[mouse_id]['reward_group'] = reward_group
    metadata[mouse_id]['cell_types'] = mdata['cell_types']
    
    average_response[mouse_id] = {}
    peak_response[mouse_id] = {}
    responsive_p_values[mouse_id] = {}
    dimensionality[mouse_id] = {}

    for cell_type in ['allcells', 'wS2', 'wM1']:
        # Select cell type.
        if cell_type == 'allcells':
            data_subtype = data
        else:
            data_subtype = []
            cell_type_mask = mdata_list[0]['cell_types']==cell_type
            data_subtype = [arr[cell_type_mask] for arr in data]

        # If no cells of the specified type, skip.
        if data_subtype[0].shape[0] == 0:
            continue

        # Compute average response for each trial, each day.
        # --------------------------------------------------

        average_response[mouse_id][cell_type] = []
        for day in data_subtype:
            average_response[mouse_id][cell_type].append(np.nanmean(day[:, :, win[0]:win[1]], axis=2))

        # Compute peak response for each trial, each day.
        # ------------------------------------------------

        peak_response[mouse_id][cell_type] = []
        for day in data_subtype:
            peak_response[mouse_id][cell_type].append(np.nanmax(day[:, :, win[0]:win[1]], axis=2))

        # # Compute standard deviation of population response.
        # # ----------------------------------------------------  

        # std[mouse_id][cell_type] = []
        # for day in data_subtype:
        #     std[mouse_id][cell_type].append(np.std(np.nanmean(day[:, :, win[0]:win[1]], axis=2), axis=0))


        # Test responsiveness.
        # --------------------

        baseline_avg = []
        response_avg = []
        for day in data_subtype:
            baseline_avg.append(np.nanmean(day[:, :, baseline_win[0]:baseline_win[1]], axis=2))
            response_avg.append(np.nanmean(day[:, :, win[0]:win[1]], axis=2))

        # Compare response amplitude to baseline.
        n_cells = data_subtype[0].shape[0]
        p_values = [np.zeros(n_cells) for _ in range(len(data_subtype))]
        for iday, day in enumerate(data_subtype):
            for icell in range(n_cells):
                # Temporary fix for cells that are always 0 due to baseline substraction.
                if np.all(baseline_avg[iday][icell] == 0.):
                    p_values[iday][icell] = 1
                else:
                    _, p_values[iday][icell] = wilcoxon(baseline_avg[iday][icell], response_avg[iday][icell])
        p_values = np.stack(p_values, axis=0)
        responsive_p_values[mouse_id][cell_type] = p_values


        # # Compute dimensionality of the population response.
        # # --------------------------------------------------

        # dimensionality[mouse_id][cell_type] = []
        # pca_results = []
        # for day in data_subtype:
        #     print(day.shape)
        #     X = np.mean(day[:,:,win[0]:win[1]], axis=2)
        #     X = X.T
        #     X = StandardScaler(with_mean=True, with_std=True).fit_transform(X)
        #     pca = PCA()
        #     model = pca.fit(X)
        #     n_comp = np.sum(model.explained_variance_ratio_.cumsum() < variance_explained_thr) + 1
        #     dimensionality[mouse_id][cell_type].append(n_comp)

['GF305_27112020_083119', 'GF305_28112020_103938', 'GF305_29112020_103331', 'GF305_30112020_110255', 'GF305_02122020_132229']
['GF306_27112020_104436', 'GF306_28112020_125555', 'GF306_29112020_131929', 'GF306_30112020_133249', 'GF306_02122020_161611']
['GF307_17112020_080325', 'GF307_18112020_075939', 'GF307_19112020_083908', 'GF307_20112020_082942', 'GF307_21112020_102608']
['GF308_17112020_105052', 'GF308_18112020_093627', 'GF308_19112020_103527', 'GF308_20112020_122826', 'GF308_21112020_135515']
['GF310_17112020_132720', 'GF310_18112020_122252', 'GF310_19112020_131953', 'GF310_20112020_150929', 'GF310_21112020_160059']
['GF311_17112020_155501', 'GF311_18112020_151838', 'GF311_19112020_160412', 'GF311_20112020_171609', 'GF311_21112020_180049']
['GF313_27112020_141857', 'GF313_28112020_154236', 'GF313_29112020_154625', 'GF313_30112020_154904', 'GF313_03122020_082147']
['GF314_27112020_160459', 'GF314_28112020_171800', 'GF314_29112020_174831', 'GF314_30112020_171906', 'GF314_03122020_1

Quantify population response across days.

In [83]:
mouse_ids = average_response.keys()

df = []
for mouse_id in mouse_ids:
    for cell_type in average_response[mouse_id].keys():
        for iday in range(len(days)):
            amp = np.nanmean(np.nanmean(average_response[mouse_id][cell_type][iday], axis=1), axis=0) * 100
            peak = np.nanmean(np.nanmean(peak_response[mouse_id][cell_type][iday], axis=1), axis=0) * 100
            prop_resp_05 = np.sum(responsive_p_values[mouse_id][cell_type][iday] <= 0.05) / responsive_p_values[mouse_id][cell_type][iday].size * 100
            prop_resp_01 = np.sum(responsive_p_values[mouse_id][cell_type][iday] <= 0.01) / responsive_p_values[mouse_id][cell_type][iday].size * 100
            # dim = dimensionality[mouse_id][cell_type][iday]
            temp = pd.DataFrame([[amp, peak, prop_resp_05, prop_resp_01, days[iday],
                                mouse_id, metadata[mouse_id]['reward_group'], cell_type]],
                                columns=['population_response', 'peak_response', 'prop_responsive_thr_0.05',
                                         'prop_responsive_thr_0.01', 'day', 'mouse_id', 'reward_group', 'cell_type'])
            df.append(temp)
df = pd.concat(df)

output_dir = r'//sv-nas1.rcp.epfl.ch/Petersen-Lab/analysis/Anthony_Renard/analysis_output/sensory_plasticity'
for cell_type in ['allcells', 'wS2', 'wM1']:

    svg_file = f'responses_across_learning_{cell_type}.svg'
    df_file = f'responses_across_learning_{cell_type}.csv'

    sns.set_theme(context='talk', style='ticks', palette='deep', font='sans-serif', font_scale=1)
    palette = sns.color_palette(['#238443', '#d51a1c'])

    fig, axes = plt.subplots(2, 2, figsize=(10, 6), sharex=True)
    sns.barplot(data=df[df.cell_type==cell_type], x='day', y='population_response', hue='reward_group',
                ax=axes[0,0], legend=False, hue_order=['R+', 'R-'], palette=palette)
    axes[0,0].set_title('Amplitude')
    axes[0,0].set_ylabel(r'% dF/F')
    axes[0,0].set_ylim([0, 6])

    sns.barplot(data=df[df.cell_type==cell_type], x='day', y='peak_response', hue='reward_group',
                ax=axes[1,0], legend=False, hue_order=['R+', 'R-'], palette=palette)
    axes[1,0].set_title('Peak')
    axes[1,0].set_ylabel(r'% dF/F')
    axes[1,0].set_ylim([0, 30])

    sns.barplot(data=df[df.cell_type==cell_type], x='day', y='prop_responsive_thr_0.01', hue='reward_group',
                ax=axes[1,1], hue_order=['R+', 'R-'], palette=palette)
    axes[1,1].set_title(r'% responsive cells (p<0.01)')
    axes[1,1].set_ylabel(r'% responsive')
    axes[1,1].set_ylim([0, 100])

    sns.despine()
    plt.tight_layout()
    plt.suptitle(cell_type)
    plt.savefig(os.path.join(output_dir, svg_file), format='svg')
    df.to_csv(os.path.join(output_dir, df_file), index=False)

Amplitude and proportion of significant cells only for each population.

In [82]:

output_dir = r'//sv-nas1.rcp.epfl.ch/Petersen-Lab/analysis/Anthony_Renard/analysis_output/sensory_plasticity/psth'

svg_file = f'amplitude_histogram_unmotivated.svg'
df_file = f'amplitude_histogram_unmotivated.csv'
pvalue_file = f'amplitude_histogram_unmotivated_pvalues.csv'

sns.set_theme(context='talk', style='ticks', palette='deep', font='sans-serif', font_scale=1)
palette = sns.color_palette(['#238443', '#d51a1c'])

fig = sns.catplot(data=df, x='day', y='population_response', hue='reward_group', col='cell_type',
            kind='bar', legend=False, hue_order=['R+', 'R-'], palette=palette)
plt.suptitle('Amplitude')
plt.ylabel(r'% dF/F')   
# plt.ylim([0, 100])

# Perform Mann-Whitney U test to check if the difference between the two reward groups is significant for each day and cell type.
results = []
for cell_type in df['cell_type'].unique():
    for day in df['day'].unique():
        group_rew = df[(df['day'] == day) & (df['reward_group'] == 'R+') & (df['cell_type'] == cell_type)]['population_response']
        group_unrew = df[(df['day'] == day) & (df['reward_group'] == 'R-') & (df['cell_type'] == cell_type)]['population_response']
        stat, p = mannwhitneyu(group_rew, group_unrew)
        results.append({'cell_type': cell_type, 'day': day, 'p_value': p})
        print(f'Cell type {cell_type} Day {day}: p-value = {p}')

# Convert results to a DataFrame and save as CSV
results_df = pd.DataFrame(results)
# Add stars to the plot for each subplot
for ax in fig.axes.flat:
    cell_type = ax.get_title().split(' = ')[-1]
    for result in results:
        if result['cell_type'] == cell_type:
            day_index = list(df['day'].unique()).index(result['day'])
            if result['p_value'] < 0.05:
                ax.text(day_index, 10, '*', ha='center', va='bottom', color='black')
            if result['p_value'] < 0.01:
                ax.text(day_index, 10, '**', ha='center', va='bottom', color='black')
            if result['p_value'] < 0.001:
                ax.text(day_index, 10, '***', ha='center', va='bottom', color='black')

sns.despine()
# plt.tight_layout()
plt.savefig(os.path.join(output_dir, svg_file), format='svg')
df.to_csv(os.path.join(output_dir, df_file), index=False)
results_df.to_csv(os.path.join(output_dir, pvalue_file), index=False)

Cell type allcells Day -2: p-value = 1.0
Cell type allcells Day -1: p-value = 0.9841653411945137
Cell type allcells Day 0: p-value = 0.02759273177042067
Cell type allcells Day +1: p-value = 0.0013964497904225218
Cell type allcells Day +2: p-value = 0.0010574718299194757
Cell type wS2 Day -2: p-value = 0.8028850291124622
Cell type wS2 Day -1: p-value = 0.5602735733449105
Cell type wS2 Day 0: p-value = 0.09067732403710933
Cell type wS2 Day +1: p-value = 0.02466970946003466
Cell type wS2 Day +2: p-value = 0.0327123776970761
Cell type wM1 Day -2: p-value = 0.8580276569875211
Cell type wM1 Day -1: p-value = 0.6333423006990662
Cell type wM1 Day 0: p-value = 0.37109336952269756
Cell type wM1 Day +1: p-value = 0.43823907554994224
Cell type wM1 Day +2: p-value = 0.10740463633025366


In [85]:

output_dir = r'//sv-nas1.rcp.epfl.ch/Petersen-Lab/analysis/Anthony_Renard/analysis_output/sensory_plasticity/psth'

svg_file = f'responsivecells_histogram_unmotivated.svg'
df_file = f'responsivecells_histogram_unmotivated.csv'
pvalue_file = f'responsivecells_histogram_unmotivated_pvalues.csv'

sns.set_theme(context='talk', style='ticks', palette='deep', font='sans-serif', font_scale=1)
palette = sns.color_palette(['#238443', '#d51a1c'])

fig = sns.catplot(data=df, x='day', y='prop_responsive_thr_0.01', hue='reward_group', col='cell_type',
            kind='bar', legend=False, hue_order=['R+', 'R-'], palette=palette)
plt.suptitle('Proportion responsive cells (MW test p<0.01)')
# plt.ylabel(r'% dF/F')   
plt.ylim([0, 100])

# Perform Mann-Whitney U test to check if the difference between the two reward groups is significant for each day and cell type.
results = []
for cell_type in df['cell_type'].unique():
    for day in df['day'].unique():
        group_rew = df[(df['day'] == day) & (df['reward_group'] == 'R+') & (df['cell_type'] == cell_type)]['prop_responsive_thr_0.01']
        group_unrew = df[(df['day'] == day) & (df['reward_group'] == 'R-') & (df['cell_type'] == cell_type)]['prop_responsive_thr_0.01']
        stat, p = mannwhitneyu(group_rew, group_unrew)
        results.append({'cell_type': cell_type, 'day': day, 'p_value': p})
        print(f'Cell type {cell_type} Day {day}: p-value = {p}')

# Convert results to a DataFrame and save as CSV
results_df = pd.DataFrame(results)
# Add stars to the plot for each subplot
for ax in fig.axes.flat:
    cell_type = ax.get_title().split(' = ')[-1]
    for result in results:
        if result['cell_type'] == cell_type:
            day_index = list(df['day'].unique()).index(result['day'])
            if result['p_value'] < 0.05:
                ax.text(day_index, 10, '*', ha='center', va='bottom', color='black')
            if result['p_value'] < 0.01:
                ax.text(day_index, 10, '**', ha='center', va='bottom', color='black')
            if result['p_value'] < 0.001:
                ax.text(day_index, 10, '***', ha='center', va='bottom', color='black')

sns.despine()
# plt.tight_layout()
plt.savefig(os.path.join(output_dir, svg_file), format='svg')
df.to_csv(os.path.join(output_dir, df_file), index=False)
results_df.to_csv(os.path.join(output_dir, pvalue_file), index=False)

Cell type allcells Day -2: p-value = 0.07090482557629343
Cell type allcells Day -1: p-value = 0.08422189624865307
Cell type allcells Day 0: p-value = 0.016328024119138175
Cell type allcells Day +1: p-value = 0.0007960612010168237
Cell type allcells Day +2: p-value = 0.00012788770850288427
Cell type wS2 Day -2: p-value = 0.38990740161738124
Cell type wS2 Day -1: p-value = 0.2916362440467125
Cell type wS2 Day 0: p-value = 0.24389001329596105
Cell type wS2 Day +1: p-value = 0.002038784480434768
Cell type wS2 Day +2: p-value = 0.0024970589395121056
Cell type wM1 Day -2: p-value = 0.17923911846412488
Cell type wM1 Day -1: p-value = 0.14360745604501088
Cell type wM1 Day 0: p-value = 0.24451814060285926
Cell type wM1 Day +1: p-value = 0.033846876635162365
Cell type wM1 Day +2: p-value = 0.002518533777682309


# 2. Correlation matrices and responsive similarity across learning.

## 2.1. For each mouse individually.

Plot population vectors rasters and correlation matrices.

In [40]:
# Parameters.

sampling_rate = 30
win = (1, 1.180)  # from stimulus onset to 300 ms after.
win_length = f'{int(np.round((win[1]-win[0]) * 1000))}'  # for file naming.
win = (int(win[0] * sampling_rate), int(win[1] * sampling_rate))
baseline_win = (0, 1)
baseline_win = (int(baseline_win[0] * sampling_rate), int(baseline_win[1] * sampling_rate))
days = ['-2', '-1', '0', '+1', '+2']
substract_baseline = True
sns.set_theme(context='paper', style='ticks', palette='deep', font='sans-serif', font_scale=1)

_, _, mice, _ = io.select_sessions_from_db(db_path,
                                            nwb_dir,
                                            two_p_imaging='yes',)
print(mice)
excluded_mice = ['GF307', 'GF310', 'GF333', 'MI075', 'AR144', 'AR135', 'AR163']
mice = [m for m in mice if m not in excluded_mice]
mice = ['AR127']

['GF305', 'GF306', 'GF307', 'GF308', 'GF310', 'GF311', 'GF313', 'GF314', 'GF317', 'GF318', 'GF319', 'GF323', 'GF333', 'GF334', 'GF348', 'GF350', 'MI062', 'MI069', 'MI072', 'MI075', 'MI076', 'AR132', 'AR133', 'AR137', 'AR139', 'AR127', 'AR143', 'AR163', 'AR177', 'AR178', 'AR179', 'AR180']


In [19]:
rewarded_mice = [mouse_id for mouse_id in metadata.keys() if metadata[mouse_id]['reward_group'] == 'R+']
count_rewarded_mice = len(rewarded_mice)
nonrewarded_mice = [mouse_id for mouse_id in metadata.keys() if metadata[mouse_id]['reward_group'] == 'R-']
count_nonrewarded_mice = len(nonrewarded_mice)
print(rewarded_mice)
print(len(rewarded_mice))
print(nonrewarded_mice)
print(len(nonrewarded_mice))


['GF305', 'GF306', 'GF308', 'GF311', 'GF313', 'GF314', 'GF317', 'GF318', 'GF323', 'GF334', 'AR133', 'AR127', 'AR143', 'AR177']
14
['GF319', 'GF348', 'GF350', 'MI062', 'MI069', 'MI072', 'MI076', 'AR132', 'AR137', 'AR139', 'AR178', 'AR179', 'AR180']
13


Load the data and save computations in dictionnaries.

In [18]:


def load_psth_data(mouse_id, days, win, baseline_win, processed_dir):
    session_list, _, _, _ = io.select_sessions_from_db(db_path,
                                                        nwb_dir,
                                                        two_p_imaging='yes',
                                                        subject_id=mouse_id,
                                                        day=days,)
    print(session_list)

    data = []
    mdata_list = []
    for session_id in session_list:
        arr, mdata = imaging_utils.load_session_2p_imaging(mouse_id,
                                                            session_id,
                                                            processed_dir)
        if substract_baseline:
            arr = imaging_utils.substract_baseline(arr, 3, baseline_win)
            
        data.append(arr)
        mdata_list.append(mdata)
        
        # Create xarray including the metadata
        coords = {'trial': np.arange(data[0].shape[1]), 'cell': np.arange(data[0].shape[0]), 'day': days}
        data_xr = xr.DataArray(np.stack(data, axis=-1), dims=('cell', 'trial', 'day'), coords=coords)
        data_xr.attrs['mouse_id'] = mouse_id
        data_xr.attrs['reward_group'] = reward_group
        data_xr.attrs['cell_types'] = mdata_list[0]['cell_types']
        data_xr.attrs['rois'] = mdata_list[0]['rois']



metadata = {}
response_amp = {}
corr_avg_days = {}
corr_avg_pre_post = {}
lmi = {}
lmi_p = {}
responsive_p_values = {}
globally_responsive = {}

# mice = ['AR127']

for mouse_id in mice:
    session_list, nwb_files, _, db_filtered = io.select_sessions_from_db(db_path,
                                                                        nwb_dir,
                                                                        two_p_imaging='yes',
                                                                        subject_id=mouse_id,
                                                                        day=days,)
    print(session_list)

    data = []
    mdata_list = []
    for session_id in session_list:
        arr, mdata = imaging_utils.load_session_2p_imaging(mouse_id,
                                                            session_id,
                                                            processed_dir)
        if substract_baseline:
            arr = imaging_utils.substract_baseline(arr, 3, baseline_win)
        data.append(arr)
        mdata_list.append(mdata)

    reward_group = io.get_reward_group_from_db(db_path, session_list[0])
    metadata[mouse_id] = {}
    metadata[mouse_id]['reward_group'] = reward_group
    metadata[mouse_id]['rois'] = mdata_list[0]['rois']
    metadata[mouse_id]['cell_types'] = mdata_list[0]['cell_types']
    
    for d, mday in enumerate(mdata_list):
        metadata[mouse_id][days[d]] = {}
        metadata[mouse_id][days[d]]['trials'] = mdata_list[d]['trials']
        metadata[mouse_id][days[d]]['trial_types'] = mdata_list[d]['trial_types']

    # Extract UM trials.
    for i, arr in enumerate(data):
        arr = imaging_utils.extract_trials(arr, mdata_list[i], trial_type, n_trials=None)
        data[i] = arr

    return data, 


metadata = {}
response_amp = {}
corr_avg_days = {}
corr_avg_pre_post = {}
responsive_p_values = {}
globally_responsive = {}

# Load lmi dicts.
lmi = np.load(r'\\sv-nas1.rcp.epfl.ch\Petersen-Lab\analysis\Anthony_Renard\data_processed\lmi.npy', allow_pickle=True).item()
lmi_p = np.load(r'\\sv-nas1.rcp.epfl.ch\Petersen-Lab\analysis\Anthony_Renard\data_processed\lmi_p.npy', allow_pickle=True).item()

# mice = ['AR127']




    response_amp[mouse_id] = {}
    corr_avg_days[mouse_id] = {}
    corr_avg_pre_post[mouse_id] = {}
    responsive_p_values[mouse_id] = {}
    globally_responsive[mouse_id] = {}

    for cell_type in ['allcells', 'wS2', 'wM1']:

        # Select cell type.
        if cell_type == 'allcells':
            data_subtype = data
        else:
            data_subtype = []
            cell_type_mask = mdata_list[0]['cell_types']==cell_type
            data_subtype = [arr[cell_type_mask] for arr in data]

        # if cell_type == 'allcells':
        #     # Example with and without strong cells for mouse AR127.
        #     strong_cells = [3,11,33,48,57,67,80,86,104,153,166,175]
        #     mask = np.ones(data_subtype[0].shape[0], dtype=bool)
        #     mask[strong_cells] = False
        #     data_subtype = [arr[mask] for arr in data_subtype]

        # If no cells of the specified type, skip.
        if data_subtype[0].shape[0] == 0:
            continue

        # Compute average response for each trial, each day.
        response_amp[mouse_id][cell_type] = []
        for day in data_subtype:
            response_amp[mouse_id][cell_type].append(np.nanmean(day[:, :, win[0]:win[1]], axis=2))

        # Compute LMI.
        # if cell_type == 'allcells':
        #     # # pre = np.mean(np.concatenate(response_amp[0:2], axis=1), axis=1)
        #     # # print(pre.shape)
        #     # # post = np.mean(np.concatenate((response_amp[5], response_amp[7]), axis=1), axis=1)
        #     # # lmi[mouse_id] = (post - pre) / (np.abs(post) + np.abs(pre))

        #     # lmis = []
        #     # ncells = len(metadata[mouse_id]['rois'])
        #     # for icell in range(ncells):
        #     #     # mapping trials of D-2, D-1, D+1, D+2.
        #     #     X = np.r_[response_amp[mouse_id][cell_type][0][icell], response_amp[mouse_id][cell_type][1][icell],
        #     #               response_amp[mouse_id][cell_type][3][icell], response_amp[mouse_id][cell_type][4][icell]]
        #     #     y = np.r_[np.zeros(response_amp[mouse_id][cell_type][0][icell].shape[0]),
        #     #               np.zeros(response_amp[mouse_id][cell_type][1][icell].shape[0]),
        #     #               np.ones(response_amp[mouse_id][cell_type][3][icell].shape[0]),
        #     #               np.ones(response_amp[mouse_id][cell_type][4][icell].shape[0])]
        #     #     fpr, tpr, _ = roc_curve(y, X)
        #     #     roc_auc = auc(fpr, tpr)
        #     #     lmis.append((roc_auc - 0.5) * 2)
        #     # lmi[mouse_id]['allcells'] = np.array(lmis)

        #     pre = [response_amp[mouse_id][cell_type][days.index('-2')],
        #            response_amp[mouse_id][cell_type][days.index('-1')]]
        #     pre = np.concatenate(pre, axis=1)
        #     post = [response_amp[mouse_id][cell_type][days.index('+1')],
        #             response_amp[mouse_id][cell_type][days.index('+2')]]
        #     post = np.concatenate(post, axis=1)
        #     lmi[mouse_id]['allcells'], lmi_p[mouse_id]['allcells'] = imaging_utils.compute_lmi(pre, post, nshuffles=None)
        # else:
        #     lmi[mouse_id][cell_type] = lmi[mouse_id]['allcells'][metadata[mouse_id]['cell_types'] == cell_type]
        #     lmi_p[mouse_id][cell_type] = lmi_p[mouse_id]['allcells'][metadata[mouse_id]['cell_types'] == cell_type]


        # # Test responsiveness.
        # if cell_type == 'allcells':
        #     base = []
        #     resp = []
        #     for day in data_subtype:
        #         base.append(np.nanmean(day[:, :, baseline_win[0]:baseline_win[1]], axis=2))
        #         resp.append(np.nanmean(day[:, :, win[0]:win[1]], axis=2))

        #     # Compare response amplitude to baseline.
        #     n_cells = data_subtype[0].shape[0]
        #     p_values = [np.zeros(n_cells) for _ in range(len(data_subtype))]
        #     for iday, day in enumerate(data_subtype):
        #         for icell in range(n_cells):
        #             # Temporary fix for cells that are always 0 due to baseline substraction.
        #             if np.all(base[iday][icell] == 0.):
        #                 p_values[iday][icell] = 1
        #             else:
        #                 _, p_values[iday][icell] = wilcoxon(base[iday][icell], resp[iday][icell])
        #     p_values = np.stack(p_values, axis=1)
        #     responsive_p_values[mouse_id][cell_type] = p_values

        #     # Test global responsiveness by pulling trials of all days together.
        #     base = np.concatenate(base, axis=1)
        #     resp = np.concatenate(resp, axis=1)
        #     p_values = np.zeros(n_cells)
        #     for icell in range(n_cells):
        #         # Temporary fix for cells that are always 0 due to baseline substraction.
        #         if np.all(base[icell] == 0.):
        #             p_values[icell] = 1
        #         else:
        #             _, p_values[icell] = wilcoxon(base[icell], resp[icell])
        #         globally_responsive[mouse_id][cell_type] = p_values
        # else:
        #     responsive_p_values[mouse_id][cell_type] = responsive_p_values[mouse_id]['allcells'][metadata[mouse_id]['cell_types'] == cell_type]
        #     globally_responsive[mouse_id][cell_type] = globally_responsive[mouse_id]['allcells'][metadata[mouse_id]['cell_types'] == cell_type]


['AR127_20240221_133407', 'AR127_20240222_152629', 'AR127_20240223_131820', 'AR127_20240224_140853', 'AR127_20240225_142858']


Plot individual population vector rasters and correlation matrix.

In [19]:
cell_selection = 'no_selection'
responsiveness_thr = 0.001
# percent_best_lmi = 15
# lmi_thr = np.percentile(np.abs(np.concatenate([lmi[mouse_id]['allcells'] for mouse_id in mice])), 100-percent_best_lmi)
sns.set_theme(context='paper', style='ticks', palette='deep', font='sans-serif', font_scale=1)

In [22]:
pdf_file = f'correlation_matrices_pop_vector_individual_mice_win_{win_length}_ms_{cell_selection}.pdf'
output_dir = fr'//sv-nas1.rcp.epfl.ch/Petersen-Lab/analysis/Anthony_Renard/analysis_output/sensory_plasticity/correlation_matrices'

with PdfPages(os.path.join(output_dir, pdf_file)) as pdf:
    for mouse_id in mice:
        for cell_type in ['allcells', 'wS2', 'wM1']:
            reward_group = metadata[mouse_id]['reward_group']

            if cell_type not in response_amp[mouse_id].keys():
                continue

            if cell_selection == 'no_selection':
                selected_cells = np.ones(response_amp[mouse_id][cell_type][0].shape[0], dtype=bool)
            elif cell_selection == 'responsive':
                selected_cells = globally_responsive[mouse_id][cell_type] <= responsiveness_thr
            elif cell_selection == 'lmi':
                selected_cells = (lmi_p[mouse_id][cell_type] >= 0.975) | (lmi_p[mouse_id][cell_type] <= 0.025)
            
            if np.sum(selected_cells) == 0:
                continue

            pop_vectors = [np.copy(arr[selected_cells]) for arr in response_amp[mouse_id][cell_type]]
            n_trials = [arr.shape[1] for arr in pop_vectors]
            print(f'{mouse_id} {n_trials}')
            pop_vectors = np.concatenate(pop_vectors, axis=1)
            corr_matrix = np.corrcoef(pop_vectors.T)

            # Plot population vectors.
            vmax = np.percentile(pop_vectors, 98)
            vmin = np.percentile(pop_vectors, 2)
            edges = np.cumsum(n_trials)
            for i in edges[:-1] - 0.5:
                plt.axvline(x=i, color='#252525', linestyle='-', lw=0.5)
            plt.xticks(edges - 0.5, edges)
            f = plt.figure()
            im = plt.imshow(pop_vectors, cmap='viridis', vmin=vmin, vmax=vmax)
            cbar = f.colorbar(im, ticks=[vmin, 0, vmax])
            cbar.ax.set_yticklabels([f'{vmin:.2f}', '0', f'> {vmax:.2f}'])
            cbar.ax.tick_params(size=0)
            pdf.savefig(dpi=300)
            plt.close()

            # Set color map limit to the max without the diagonal.
            vmax = np.percentile(corr_matrix[~np.eye(corr_matrix.shape[0], dtype=bool)], 98)
            vmin = np.percentile(corr_matrix, 2)
            f = plt.figure(figsize=(15, 15))
            im = plt.imshow(corr_matrix, vmin = vmin, vmax=vmax, cmap='viridis')
            n_trials = [arr.shape[1] for arr in response_amp[mouse_id][cell_type]]

            for i in edges[:-1] - 0.5:
                plt.axvline(x=i, color='#252525', linestyle='-', lw=0.5)
                plt.axhline(y=i, color='#252525', linestyle='-', lw=0.5)
            plt.xticks(edges - 0.5, edges)
            plt.yticks(edges - 0.5, edges)
            plt.title(f'{mouse_id} {reward_group} {cell_type}')
            cbar_ax = f.add_axes([0.85, 0.15, 0.05, 0.7])
            cbar = f.colorbar(im, cax=cbar_ax, ticks=[vmin, 0, vmax])
            cbar.ax.set_yticklabels([f'{vmin:.2f}', '0', f'> {vmax:.2f}'])
            cbar.ax.tick_params(size=0)
            pdf.savefig(dpi=300)
            plt.close()

GF305 [50, 50, 50, 50, 50]
GF305 [50, 50, 50, 50, 50]
GF305 [50, 50, 50, 50, 50]
GF306 [50, 50, 50, 50, 50]
GF306 [50, 50, 50, 50, 50]


  c = cov(x, y, rowvar, dtype=dtype)
  c *= np.true_divide(1, fact)
  c *= np.true_divide(1, fact)


GF306 [50, 50, 50, 50, 50]
GF308 [50, 50, 50, 50, 50]
GF308 [50, 50, 50, 50, 50]
GF308 [50, 50, 50, 50, 50]
GF311 [50, 50, 50, 50, 50]
GF311 [50, 50, 50, 50, 50]
GF313 [50, 50, 50, 50, 50]
GF313 [50, 50, 50, 50, 50]
GF313 [50, 50, 50, 50, 50]
GF314 [50, 50, 50, 50, 50]
GF314 [50, 50, 50, 50, 50]
GF314 [50, 50, 50, 50, 50]


  c = cov(x, y, rowvar, dtype=dtype)
  c *= np.true_divide(1, fact)
  c *= np.true_divide(1, fact)


GF317 [50, 50, 50, 50, 50]
GF317 [50, 50, 50, 50, 50]
GF317 [50, 50, 50, 50, 50]
GF318 [50, 50, 50, 50, 50]
GF318 [50, 50, 50, 50, 50]
GF318 [50, 50, 50, 50, 50]
GF319 [50, 50, 50, 50, 50]
GF323 [50, 50, 50, 50, 50]
GF323 [50, 50, 50, 50, 50]
GF323 [50, 50, 50, 50, 50]
GF334 [50, 50, 50, 50, 50]
GF334 [50, 50, 50, 50, 50]
GF334 [50, 50, 50, 50, 50]
GF348 [50, 50, 50, 50, 50]
GF350 [50, 50, 50, 50, 50]
MI062 [50, 50, 50, 50, 50]
MI069 [50, 50, 50, 50, 50]
MI069 [50, 50, 50, 50, 50]
MI072 [50, 50, 50, 50, 50]
MI072 [50, 50, 50, 50, 50]
MI072 [50, 50, 50, 50, 50]
MI076 [50, 50, 50, 50, 50]
MI076 [50, 50, 50, 50, 50]
MI076 [50, 50, 50, 50, 50]
AR132 [50, 50, 50, 50, 50]
AR132 [50, 50, 50, 50, 50]
AR132 [50, 50, 50, 50, 50]
AR133 [50, 49, 50, 50, 50]
AR133 [50, 49, 50, 50, 50]
AR137 [50, 49, 50, 50, 50]
AR137 [50, 49, 50, 50, 50]
AR137 [50, 49, 50, 50, 50]
AR139 [50, 50, 50, 50, 50]
AR139 [50, 50, 50, 50, 50]
AR127 [49, 49, 50, 49, 49]
AR127 [49, 49, 50, 49, 49]
AR143 [50, 50, 50, 50, 50]
A

  c = cov(x, y, rowvar, dtype=dtype)
  c *= np.true_divide(1, fact)
  c *= np.true_divide(1, fact)


AR180 [50, 50, 50, 50, 50]
AR180 [50, 50, 50, 50, 50]
AR180 [50, 50, 50, 50, 50]


## 2.2. Global population correlation matrix and population vectors.

In [12]:
# Parameters

zscore = False
cell_selection = 'lmi'
responsiveness_thr = 0.01
# percent_best_lmi = 20
# lmi_thr = np.percentile(np.abs(np.concatenate([lmi[mouse_id]['allcells'] for mouse_id in mice])), 100-percent_best_lmi)
sns.set_theme(context='paper', style='ticks', palette='deep', font='sans-serif', font_scale=1)

lmi = np.load(r'\\sv-nas1.rcp.epfl.ch\Petersen-Lab\analysis\Anthony_Renard\data_processed\lmi.npy', allow_pickle=True).item()
lmi_p = np.load(r'\\sv-nas1.rcp.epfl.ch\Petersen-Lab\analysis\Anthony_Renard\data_processed\lmi_p.npy', allow_pickle=True).item()

In [31]:
output_dir = fr'//sv-nas1.rcp.epfl.ch/Petersen-Lab/analysis/Anthony_Renard/analysis_output/sensory_plasticity/correlation_matrices'
pdf_file = f'correlation_matrices_pop_vector_global_population_win_{win_length}_ms_cell_selection_{cell_selection}_zscore_{zscore}.pdf'

response_amp_selection = {}
with PdfPages(os.path.join(output_dir, pdf_file)) as pdf:
    for cell_type in ['allcells', 'wS2', 'wM1']:
        for reward_group in ['R+', 'R-']:
            mice_ids = [mouse_id for mouse_id in mice if (metadata[mouse_id]['reward_group'] == reward_group)
                                                         and (cell_type in response_amp[mouse_id].keys())]
            mice_ids = [m for m in mice_ids if m not in ['MI069', 'MI072']]
            print(cell_type, reward_group, mice_ids)
            # # Copying because I will take a subset of the data.
            # response_amp_selection = np.copy(response_amp).item()
            # response_amp_selection = {mouse_id: response_amp_selection[mouse_id] for mouse_id in mice_ids if cell_type in response_amp_selection[mouse_id].keys()}

            response_amp_selection[cell_type] = {}
            for mouse_id in mice_ids:

                response_amp_selection[cell_type][mouse_id] = []

                if cell_selection == 'no_selection':
                    selected_cells = np.ones(response_amp[mouse_id][cell_type][0].shape[0], dtype=bool)
                elif cell_selection == 'responsive':
                    selected_cells = globally_responsive[mouse_id][cell_type] <= responsiveness_thr
                elif cell_selection == 'lmi':
                    selected_cells = (lmi_p[mouse_id][cell_type] >= 0.975) | (lmi_p[mouse_id][cell_type] <= 0.025)


                for iday in range(len(days)):
                    temp = [np.copy(response_amp[mouse_id][cell_type][iday][selected_cells]) for iday in range(len(days))]
                    response_amp_selection[cell_type][mouse_id] = temp

            # Even the number of trials per days across mice.
            min_trials = []
            for iday in range(len(days)):
                m = [data[iday].shape[1] for _, data in response_amp_selection[cell_type].items()]
                print(m)
                min_trials.append(np.min(m))
            print(min_trials)

            for mouse_id, data in response_amp_selection[cell_type].items():
                for iday in range(len(days)):
                    response_amp_selection[cell_type][mouse_id][iday] = data[iday][:, :min_trials[iday]]
            
            pop_vectors = np.concatenate([np.concatenate(data, axis=1) for _, data in response_amp_selection[cell_type].items()], axis=0)
            # AR180 has some cells with 0. response whichi gives nan.
            pop_vectors = np.nan_to_num(pop_vectors, nan=0.0)
            if zscore:
                pop_vectors = (pop_vectors - np.mean(pop_vectors, axis=1, keepdims=True)) / np.std(pop_vectors, axis=1, keepdims=True)
                pop_vectors = np.nan_to_num(pop_vectors, nan=0.0)


            
            corr_matrix = np.corrcoef(pop_vectors.T)
            # corr_matrix = spearmanr(pop_vectors.T, axis=1)[0]

            # Plot population vectors.
            # To have same color scale across R+ and R-.
            if reward_group == 'R+':
                vmax_vectors = np.percentile(pop_vectors, 98)
                vmin_vectors = np.percentile(pop_vectors, 1)

            if cell_type == 'allcells':
                # Split in subplots for readability.
                f, axes = plt.subplots(1, 4)
                for i, (cell_start, cell_end) in enumerate([(0, 600), (600, 1200), (1200, 1800),(1800, 2400)]):
                    im = axes[i].imshow(pop_vectors[cell_start:cell_end], cmap='viridis', vmin=vmin_vectors, vmax=vmax_vectors)
                    n_trials = min_trials
                    edges = np.cumsum(n_trials)
                    for j in edges[:-1] - 0.5:
                        axes[i].axvline(x=j, color='#252525', linestyle='-', lw=0.5)
                    axes[i].set_xticks(edges - 0.5)
                    axes[i].set_xticklabels(edges)
                cb_ax = f.add_axes([.91,.124,.04,.754])
                cbar = f.colorbar(im, ticks=[vmin_vectors, 0, vmax_vectors], cax=cb_ax)
                cbar.ax.set_yticklabels([f'<{vmin_vectors:.2f}', '0', f'> {vmax_vectors:.2f}'])
                cbar.ax.tick_params(size=0)
                plt.suptitle(f'Population vectors {reward_group} {cell_type}')
                pdf.savefig(dpi=300)
                plt.close()
                plt.tight_layout()
            else:
                f = plt.figure()
                im = plt.imshow(pop_vectors, cmap='viridis', vmin=vmin_vectors, vmax=vmax_vectors)
                n_trials = min_trials
                edges = np.cumsum(n_trials)
                for i in edges[:-1] - 0.5:
                    plt.axvline(x=i, color='#252525', linestyle='-', lw=0.5)
                plt.xticks(edges - 0.5, edges)
                cbar = f.colorbar(im, ticks=[vmin_vectors, 0, vmax_vectors])
                cbar.ax.set_yticklabels([f'<{vmin_vectors:.2f}', '0', f'> {vmax_vectors:.2f}'])
                cbar.ax.tick_params(size=0)
                plt.title(f'Population vectors {reward_group} {cell_type}')
                pdf.savefig(dpi=300)
                plt.close()


            # Set color map limit to the max without the diagonal.
            if reward_group == 'R+':
                vmax_matrix = np.percentile(corr_matrix[~np.eye(corr_matrix.shape[0], dtype=bool)], 98)
                vmin_matrix = np.percentile(corr_matrix, 2)

            f = plt.figure(figsize=(15, 15))
            im = plt.imshow(corr_matrix, vmin = vmin_matrix, vmax=vmax_matrix, cmap='viridis')
            n_trials = min_trials
            edges = np.cumsum(n_trials)
            for i in edges[:-1] - 0.5:
                plt.axvline(x=i, color='#252525', linestyle='-', lw=0.5)
                plt.axhline(y=i, color='#252525', linestyle='-', lw=0.5)
            plt.xticks(edges - 0.5, edges)
            plt.yticks(edges - 0.5, edges)
            plt.title(f'Correlation over trials {reward_group} {cell_type}')
            cbar_ax = f.add_axes([0.85, 0.15, 0.05, 0.7])
            cbar = f.colorbar(im, cax=cbar_ax, ticks=[vmin_matrix, 0, vmax_matrix])
            cbar.ax.set_yticklabels([f'{vmin_matrix:.2f}', '0', f'> {vmax_matrix:.2f}'])
            cbar.ax.tick_params(size=0)
            pdf.savefig(dpi=300)
            plt.close()

allcells R+ ['GF305', 'GF306', 'GF308', 'GF311', 'GF313', 'GF314', 'GF317', 'GF318', 'GF323', 'GF334', 'AR133', 'AR127', 'AR143', 'AR177']
[50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 49, 50, 50]
[50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 49, 49, 50, 50]
[50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50]
[50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 49, 50, 50]
[50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 49, 50, 50]
[np.int64(49), np.int64(49), np.int64(50), np.int64(49), np.int64(49)]


  im = axes[i].imshow(pop_vectors[cell_start:cell_end], cmap='viridis', vmin=vmin_vectors, vmax=vmax_vectors)
  im = axes[i].imshow(pop_vectors[cell_start:cell_end], cmap='viridis', vmin=vmin_vectors, vmax=vmax_vectors)
  im = axes[i].imshow(pop_vectors[cell_start:cell_end], cmap='viridis', vmin=vmin_vectors, vmax=vmax_vectors)


allcells R- ['GF319', 'GF348', 'GF350', 'MI062', 'MI076', 'AR132', 'AR137', 'AR139', 'AR178', 'AR179', 'AR180']
[50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50]
[50, 50, 50, 50, 50, 50, 49, 50, 50, 50, 50]
[50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50]
[50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50]
[50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50]
[np.int64(50), np.int64(49), np.int64(50), np.int64(50), np.int64(50)]


  im = axes[i].imshow(pop_vectors[cell_start:cell_end], cmap='viridis', vmin=vmin_vectors, vmax=vmax_vectors)
  im = axes[i].imshow(pop_vectors[cell_start:cell_end], cmap='viridis', vmin=vmin_vectors, vmax=vmax_vectors)
  im = axes[i].imshow(pop_vectors[cell_start:cell_end], cmap='viridis', vmin=vmin_vectors, vmax=vmax_vectors)


wS2 R+ ['GF305', 'GF306', 'GF308', 'GF311', 'GF313', 'GF314', 'GF317', 'GF318', 'GF323', 'GF334', 'AR133', 'AR143', 'AR177']
[50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50]
[50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 49, 50, 50]
[50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50]
[50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50]
[50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50]
[np.int64(50), np.int64(49), np.int64(50), np.int64(50), np.int64(50)]
wS2 R- ['MI076', 'AR132', 'AR137', 'AR139', 'AR178', 'AR179', 'AR180']
[50, 50, 50, 50, 50, 50, 50]
[50, 50, 49, 50, 50, 50, 50]
[50, 50, 50, 50, 50, 50, 50]
[50, 50, 50, 50, 50, 50, 50]
[50, 50, 50, 50, 50, 50, 50]
[np.int64(50), np.int64(49), np.int64(50), np.int64(50), np.int64(50)]
wM1 R+ ['GF305', 'GF306', 'GF308', 'GF311', 'GF313', 'GF314', 'GF317', 'GF318', 'GF323', 'GF334', 'AR127', 'AR143', 'AR177']
[50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 49, 50, 50]
[50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 49, 50, 50]
[50, 50, 50, 50, 50, 

## Correlation matrices with average population vectors.


In [7]:
# Parameters
cell_selection = 'lmi'
responsiveness_thr = 0.01
percent_best_lmi = 20
zscore = True
average_trials = 50  # Number of trials to average for each day.
lmi_thr = np.percentile(np.abs(np.concatenate([lmi[mouse_id]['allcells'] for mouse_id in mice])), 100-percent_best_lmi)
sns.set_theme(context='paper', style='ticks', palette='deep', font='sans-serif', font_scale=1)

In [11]:
output_dir = fr'//sv-nas1.rcp.epfl.ch/Petersen-Lab/analysis/Anthony_Renard/analysis_output/sensory_plasticity/correlation_matrices'
pdf_file = f'correlation_matrices_average_pop_vector_global_population_win_{win_length}_ms_cell_selection_{cell_selection}_zscore_{zscore}.pdf'

response_amp_selection = {}
with PdfPages(os.path.join(output_dir, pdf_file)) as pdf:
    for cell_type in ['allcells', 'wS2', 'wM1']:
        for reward_group in ['R+', 'R-']:
            mice_ids = [mouse_id for mouse_id in mice if (metadata[mouse_id]['reward_group'] == reward_group)
                                                            and (cell_type in response_amp[mouse_id].keys())]
            # mice_ids = [m for m in mice_ids if m not in ['MI072', 'MI069', 'AR132']]
            # print(cell_type, reward_group, mice_ids)
            # # Copying because I will take a subset of the data.
            # response_amp_selection = np.copy(response_amp).item()
            # response_amp_selection = {mouse_id: response_amp_selection[mouse_id] for mouse_id in mice_ids if cell_type in response_amp_selection[mouse_id].keys()}

            response_amp_selection[cell_type] = {}
            for mouse_id in mice_ids:

                response_amp_selection[cell_type][mouse_id] = []

                if cell_selection == 'no_selection':
                    selected_cells = np.ones(response_amp[mouse_id][cell_type][0].shape[0], dtype=bool)
                elif cell_selection == 'responsive':
                    selected_cells = globally_responsive[mouse_id][cell_type] <= responsiveness_thr
                elif cell_selection == 'lmi':
                    
                    selected_cells = np.abs(lmi[mouse_id][cell_type]) >= lmi_thr


                for iday in range(len(days)):
                    temp = [np.copy(response_amp[mouse_id][cell_type][iday][selected_cells]) for iday in range(len(days))]
                    response_amp_selection[cell_type][mouse_id] = temp

            # Even the number of trials per days across mice.
            min_trials = []
            for iday in range(len(days)):
                m = [data[iday].shape[1] for _, data in response_amp_selection[cell_type].items()]
                print(m)
                min_trials.append(np.min(m))
            print(min_trials)

            for mouse_id, data in response_amp_selection[cell_type].items():
                print(mouse_id)
                # If no cells of the specified type, skip.
                if data[0].shape[0] == 0:
                    continue
                for iday in range(len(days)):
                    temp = data[iday][:, :min_trials[iday]]
                    
                    # Average trials.
                    if temp.shape[1] % average_trials != 0:
                        # Add nan to complete the last average if needed by missing trials.
                        temp = np.concatenate([temp, np.full((temp.shape[0], average_trials - (temp.shape[1] % average_trials)), np.nan)], axis=1)
                    temp = np.nanmean(temp.reshape(temp.shape[0], -1, average_trials), axis=2)
                    response_amp_selection[cell_type][mouse_id][iday] = temp
            
            print([np.concatenate(data, axis=1).shape for _, data in response_amp_selection[cell_type].items()])
            pop_vectors = np.concatenate([np.concatenate(data, axis=1) for _, data in response_amp_selection[cell_type].items()], axis=0)
            # AR180 has some cells with 0. response whichi gives nan.
            pop_vectors = np.nan_to_num(pop_vectors, nan=0.0)
            if zscore:
                pop_vectors = (pop_vectors - np.mean(pop_vectors, axis=1, keepdims=True)) / np.std(pop_vectors, axis=1, keepdims=True)
                pop_vectors = np.nan_to_num(pop_vectors, nan=0.0)

            corr_matrix = np.corrcoef(pop_vectors.T)
            # corr_matrix = spearmanr(pop_vectors.T, axis=1)[0]

            # Set color map limit to the max without the diagonal.
            if reward_group == 'R+':
                vmax_matrix = np.percentile(corr_matrix[~np.eye(corr_matrix.shape[0], dtype=bool)], 98)
                vmin_matrix = np.percentile(corr_matrix, 1)

            f = plt.figure(figsize=(15, 15))
            im = plt.imshow(corr_matrix, vmin = 0, vmax=1, cmap='viridis')
            n_trials = [50 / average_trials] * 5
            edges = np.cumsum(n_trials)
            for i in edges[:-1] - 0.5:
                plt.axvline(x=i, color='#252525', linestyle='-', lw=0.5)
                plt.axhline(y=i, color='#252525', linestyle='-', lw=0.5)
            plt.xticks(edges - 0.5, edges)
            plt.yticks(edges - 0.5, edges)
            plt.title(f'Correlation over trials {reward_group} {cell_type}')
            cbar_ax = f.add_axes([0.85, 0.15, 0.05, 0.7])
            cbar = f.colorbar(im, cax=cbar_ax, ticks=[vmin_matrix, 0, vmax_matrix])
            cbar.ax.set_yticklabels([f'{vmin_matrix:.2f}', '0', f'> {vmax_matrix:.2f}'])
            cbar.ax.tick_params(size=0)

            pdf.savefig(dpi=300)
            plt.close()

[50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 49, 50, 50]
[50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 49, 49, 50, 50]
[50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50]
[50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 49, 50, 50]
[50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 49, 50, 50]
[np.int64(49), np.int64(49), np.int64(50), np.int64(49), np.int64(49)]
GF305
GF306
GF308
GF311
GF313
GF314
GF317
GF318
GF323
GF334
AR133
AR127
AR143
AR177
[(40, 5), (58, 5), (24, 5), (23, 5), (53, 5), (46, 5), (37, 5), (45, 5), (54, 5), (49, 5), (34, 5), (50, 5), (39, 5), (8, 5)]
[50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50]
[50, 50, 50, 50, 50, 50, 50, 50, 49, 50, 50, 50, 50]
[50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50]
[50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50]
[50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50]
[np.int64(50), np.int64(49), np.int64(50), np.int64(50), np.int64(50)]
GF319
GF348
GF350
MI062
MI069
MI072
MI076
AR132
AR137
AR139
AR178
AR179
AR180
[(23, 5), (31, 5), (

ValueError: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 1, the array at index 0 has size 5 and the array at index 3 has size 250

In [None]:

response_amp_selection[cell_type][mouse_id][iday].shape
cell_type
[np.concatenate(data, axis=1).shape for _, data in response_amp_selection[cell_type].items()]

[(5, 5),
 (1, 5),
 (2, 5),
 (0, 250),
 (3, 5),
 (7, 5),
 (5, 5),
 (10, 5),
 (14, 5),
 (9, 5),
 (7, 5),
 (4, 5),
 (2, 5)]

In [161]:
for mouse_id in mice:
    data = response_amp_selection[cell_type][mouse_id]
    for d in data:
        print(d.shape)


(2, 5)
(2, 5)
(2, 5)
(2, 5)
(2, 5)
(0, 50)
(0, 50)
(0, 50)
(0, 50)
(0, 50)
(2, 5)
(2, 5)
(2, 5)
(2, 5)
(2, 5)
(0, 50)
(0, 50)
(0, 50)
(0, 50)
(0, 50)
(1, 5)
(1, 5)
(1, 5)
(1, 5)
(1, 5)
(5, 5)
(5, 5)
(5, 5)
(5, 5)
(5, 5)
(3, 5)
(3, 5)
(3, 5)
(3, 5)
(3, 5)
(6, 5)
(6, 5)
(6, 5)
(6, 5)
(6, 5)


KeyError: 'GF319'

## 2.3. Quantify the correlation between days.

- average correlation inside each day -- response variability
- average correlation between pre and post learning 

In [146]:
# Cell selection.

cell_selection = 'responsive'
responsiveness_thr = 0.001
percent_best_lmi = 15
sns.set_theme(context='paper', style='ticks', palette='deep', font='sans-serif', font_scale=1)


In [162]:
output_dir = fr'//sv-nas1.rcp.epfl.ch/Petersen-Lab/analysis/Anthony_Renard/analysis_output/sensory_plasticity/correlation_matrices'
pdf_file = f'correlation_matrix_quantification_{win_length}_ms_cell_selection_{cell_selection}.pdf'

corr_avg_days = {}
corr_avg_pre_post = {}

# Because cell selection changes the data.
response_amp_copy = np.copy(response_amp).item()

for mouse_id in mouse_ids:
    corr_avg_days[mouse_id] = {}
    corr_avg_pre_post[mouse_id] = {}

    for cell_type in ['allcells', 'wS2', 'wM1']:

        if cell_type not in response_amp[mouse_id].keys():
            continue
        
        if cell_selection == 'no_selection':
            pop_vectors = np.concatenate(response_amp_copy[mouse_id][cell_type], axis=1)
        elif cell_selection == 'responsive':
            responsive_cells = globally_responsive[mouse_id][cell_type] <= responsiveness_thr
            pop_vectors = np.concatenate([response_amp_copy[mouse_id][cell_type][iday][responsive_cells, :] for iday in range(len(days))], axis=1)
            print(f'{mouse_id} {cell_type} {np.sum(responsive_cells)}/ {responsive_cells.shape[0]} responsive cells')
        elif cell_selection == 'lmi':
            lmi_thr = np.percentile(np.abs(np.concatenate([lmi[mouse_id]['allcells'] for mouse_id in mice])), 100-percent_best_lmi)
            modulated_cells = lmi[mouse_id][cell_type] >= lmi_thr
            pop_vectors = np.concatenate([response_amp_copy[mouse_id][cell_type][iday][modulated_cells, :] for iday in range(len(days))], axis=1)

        # if cell_selection == 'responsive_cells':
        #     corr_matrix = np.corrcoef(pop_vectors[mouse_id][cell_type][globally_responsive[mouse_id][cell_type], :].T)
        # elif cell_selection == 'lmi':
        #     corr_matrix = np.corrcoef(pop_vectors[mouse_id][cell_type][lmi[mouse_id][cell_type], :].T)
        # else:
        #     corr_matrix = np.corrcoef(pop_vectors[mouse_id][cell_type].T)

        corr_matrix = np.corrcoef(pop_vectors.T)

        # Compute average correlation inside each days.
        corr_avg_days[mouse_id][cell_type] = []
        n_trials = [arr.shape[1] for arr in response_amp[mouse_id][cell_type]]
        print(np.cumsum(n_trials))
        for start, end in zip(np.cumsum([0] + n_trials[:-1]), np.cumsum(n_trials)):
            upper_triangle = np.triu(corr_matrix[start:end, start:end], k=1)
            corr_avg_days[mouse_id][cell_type].append(np.mean(upper_triangle))

        # Compare correlation between inside pre training days,
        # inside post training days and between pre and post training days.
        trial_cumsum = np.cumsum([0] + n_trials)
        pre_in_start_x, pre_in_end_x = trial_cumsum[1], trial_cumsum[2]
        pre_in_start_y, pre_in_end_y = trial_cumsum[0], trial_cumsum[1]
        pre_in = np.mean(corr_matrix[pre_in_start_x:pre_in_end_x, pre_in_start_y:pre_in_end_y])

        post_in_start_x, post_in_end_x = trial_cumsum[4], trial_cumsum[5]
        post_in_start_y, post_in_end_y = trial_cumsum[3], trial_cumsum[4]
        post_in = np.mean(corr_matrix[post_in_start_x:post_in_end_x, post_in_start_y:post_in_end_y])

        pre_post_start_x, pre_post_end_x = trial_cumsum[3], trial_cumsum[5]
        pre_post_start_y, pre_post_end_y = trial_cumsum[0], trial_cumsum[2]
        pre_post = np.mean(corr_matrix[pre_post_start_x:pre_post_end_x, pre_post_start_y:pre_post_end_y])

        corr_avg_pre_post[mouse_id][cell_type] = [pre_in, post_in, pre_post]

# Convert to pandas.
# ------------------

df_corr_days = []
for mouse_id in mouse_ids:
    for cell_type in corr_avg_days[mouse_id].keys():
        for iday in range(len(days)):
            corr = corr_avg_days[mouse_id][cell_type][iday]
            temp = pd.DataFrame([[corr, days[iday], cell_type, mouse_id, metadata[mouse_id]['reward_group']]],
                                columns=['correlation','day', 'cell_type', 'mouse_id', 'reward_group'])
            df_corr_days.append(temp)
df_corr_days = pd.concat(df_corr_days, ignore_index=True)


compare = ['pre_in', 'post_in', 'pre_post']
df_corr_pre_post = []
for mouse_id in mouse_ids:
    for cell_type in corr_avg_days[mouse_id].keys():
        for i, comp in enumerate(compare):
            corr = corr_avg_pre_post[mouse_id][cell_type][i]
            temp = pd.DataFrame([[corr, comp, cell_type, mouse_id, metadata[mouse_id]['reward_group']]],
                                columns=['correlation', 'comparison', 'cell_type', 'mouse_id', 'reward_group'])
            df_corr_pre_post.append(temp)
df_corr_pre_post = pd.concat(df_corr_pre_post, ignore_index=True)


# Save plot and stats.
# --------------------

with PdfPages(os.path.join(output_dir, pdf_file)) as pdf:

    palette = sns.color_palette(['#238443', '#d51a1c'])
    for ct in ['allcells', 'wS2', 'wM1']:

        plt.figure()
        sns.barplot(data=df_corr_days.loc[df_corr_days.cell_type==ct], x='day', y='correlation', hue='reward_group', palette=palette, hue_order=['R+', 'R-'])
        sns.despine()
        plt.title(f'Correlation inside days - {ct}')
        pdf.savefig(dpi=300)
        plt.close()

        # Perform Mann-Whitney U test to check if the difference between the two reward groups is significant for each day.
        p_values = []
        for day in days:
            group_rew = df_corr_days[(df_corr_days['day'] == day) & (df_corr_days['reward_group'] == 'R+') & (df_corr_days.cell_type==ct)]['correlation']
            group_rew = group_rew[~np.isnan(group_rew)]
            group_unrew = df_corr_days[(df_corr_days['day'] == day) & (df_corr_days['reward_group'] == 'R-') & (df_corr_days.cell_type==ct)]['correlation']
            group_rew = group_rew[~np.isnan(group_rew)]
            stat, p = mannwhitneyu(group_rew, group_unrew)
            p_values.append(p)
            print(f'Day {day}: p-value = {p}')
        # Add p-values to the dataframe for visualization
        df_p_values = pd.DataFrame({'day': days, 'p_value': p_values})
        print(df_p_values)
        df_p_values.to_csv(os.path.join(output_dir, f'correlation_matrix_quantification_{win_length}_ms_{cell_selection}_inside_days.csv'), index=False)

        plt.figure()
        sns.barplot(data=df_corr_pre_post.loc[df_corr_pre_post.cell_type==ct], x='comparison', y='correlation', hue='reward_group', palette=palette, hue_order=['R+', 'R-'])
        sns.despine()
        plt.title(f'Correlation inside pre-training, post-training and across both - {ct}')
        pdf.savefig(dpi=300)
        plt.close()

        # Perform Mann-Whitney U test to check if the difference between the two reward groups is significant for each day.
        p_values = []
        for comp in compare:
            group_rew = df_corr_pre_post[(df_corr_pre_post['comparison'] == comp) & (df_corr_pre_post['reward_group'] == 'R+') & (df_corr_pre_post.cell_type==ct)]['correlation']
            group_rew = group_rew[~np.isnan(group_rew)]
            group_unrew = df_corr_pre_post[(df_corr_pre_post['comparison'] == comp) & (df_corr_pre_post['reward_group'] == 'R-') & (df_corr_pre_post.cell_type==ct)]['correlation']
            group_rew = group_rew[~np.isnan(group_rew)]
            stat, p = mannwhitneyu(group_rew, group_unrew)
            p_values.append(p)
            print(f'Comp {comp}: p-value = {p}')
        # Add p-values to the dataframe for visualization
        df_p_values = pd.DataFrame({'comp': compare, 'p_value': p_values})
        print(df_p_values)
        df_p_values.to_csv(os.path.join(output_dir, f'correlation_matrix_quantification_{win_length}_ms_{cell_selection}_across_pre_post.csv'), index=False)


GF305 allcells 55/ 133 responsive cells
[ 50 100 150 200 250]
GF305 wS2 7/ 12 responsive cells
[ 50 100 150 200 250]
GF305 wM1 9/ 20 responsive cells
[ 50 100 150 200 250]
GF306 allcells 95/ 215 responsive cells
[ 50 100 150 200 250]
GF306 wS2 1/ 3 responsive cells
[ 50 100 150 200 250]
GF306 wM1 6/ 15 responsive cells
[ 50 100 150 200 250]
GF308 allcells 42/ 147 responsive cells
[ 50 100 150 200 250]
GF308 wS2 8/ 23 responsive cells
[ 50 100 150 200 250]
GF308 wM1 2/ 16 responsive cells
[ 50 100 150 200 250]
GF311 allcells 42/ 105 responsive cells
[ 50 100 150 200 250]
GF311 wS2 0/ 4 responsive cells
[ 50 100 150 200 250]
GF311 wM1 9/ 17 responsive cells
[ 50 100 150 200 250]
GF313 allcells 71/ 164 responsive cells
[ 50 100 150 200 250]
GF313 wS2 7/ 15 responsive cells
[ 50 100 150 200 250]
GF313 wM1 7/ 14 responsive cells
[ 50 100 150 200 250]
GF314 allcells 116/ 197 responsive cells
[ 50 100 150 200 250]
GF314 wS2 13/ 18 responsive cells
[ 50 100 150 200 250]
GF314 wM1 4/ 10 respons

  c = cov(x, y, rowvar, dtype=dtype)
  c *= np.true_divide(1, fact)
  c *= np.true_divide(1, fact)
  avg = a.mean(axis, **keepdims_kw)
  ret = um.true_divide(


Day -2: p-value = 0.643011484398551
Day -1: p-value = 0.5239281449175213
Day 0: p-value = 0.03700438758047722
Day +1: p-value = 0.023850728855048315
Day +2: p-value = 0.017530018461674984
  day   p_value
0  -2  0.643011
1  -1  0.523928
2   0  0.037004
3  +1  0.023851
4  +2  0.017530
Comp pre_in: p-value = 1.0
Comp post_in: p-value = 0.014960455146298667
Comp pre_post: p-value = 1.0
       comp  p_value
0    pre_in  1.00000
1   post_in  0.01496
2  pre_post  1.00000
Day -2: p-value = nan
Day -1: p-value = nan
Day 0: p-value = nan
Day +1: p-value = nan
Day +2: p-value = nan
  day  p_value
0  -2      NaN
1  -1      NaN
2   0      NaN
3  +1      NaN
4  +2      NaN
Comp pre_in: p-value = nan
Comp post_in: p-value = nan
Comp pre_post: p-value = nan
       comp  p_value
0    pre_in      NaN
1   post_in      NaN
2  pre_post      NaN
Day -2: p-value = nan
Day -1: p-value = nan
Day 0: p-value = nan
Day +1: p-value = nan
Day +2: p-value = nan
  day  p_value
0  -2      NaN
1  -1      NaN
2   0     

### Quantify correlations across pre post learning on the global population matrix.

Variance is computated across pairs of trials rather than mice.

In [258]:
cell_selection = 'no_selection'
responsiveness_thr = 0.001
percent_best_lmi = 15
sns.set_theme(context='paper', style='ticks', palette='deep', font='sans-serif', font_scale=1)

In [268]:
output_dir = fr'//sv-nas1.rcp.epfl.ch/Petersen-Lab/analysis/Anthony_Renard/analysis_output/sensory_plasticity/correlation_matrices'
pdf_file = f'correlation_matrix_quantification_global_matrix_{win_length}_ms_cell_selection_{cell_selection}.pdf'

response_amp_selection = {}
df_corr_days = []
df_corr_pre_post = []

for cell_type in ['allcells', 'wS2', 'wM1']:
    for reward_group in ['R+', 'R-']:
        mice_ids = [mouse_id for mouse_id in mice if (metadata[mouse_id]['reward_group'] == reward_group)
                                                        and (cell_type in response_amp[mouse_id].keys())]

        # # Copying because I will take a subset of the data.
        # response_amp_selection = np.copy(response_amp).item()
        # response_amp_selection = {mouse_id: response_amp_selection[mouse_id] for mouse_id in mice_ids if cell_type in response_amp_selection[mouse_id].keys()}

        response_amp_selection[cell_type] = {}


        for mouse_id in mice_ids:

            response_amp_selection[cell_type][mouse_id] = []

            if cell_selection == 'no_selection':
                selected_cells = np.ones(response_amp[mouse_id][cell_type][0].shape[0], dtype=bool)
            elif cell_selection == 'responsive':
                selected_cells = globally_responsive[mouse_id][cell_type] <= responsiveness_thr
            elif cell_selection == 'lmi':
                lmi_thr = np.percentile(np.abs(np.concatenate([lmi[mouse_id]['allcells'] for mouse_id in mice])), 100-percent_best_lmi)
                selected_cells = lmi[mouse_id][cell_type] >= lmi_thr

            for iday in range(len(days)):
                temp = [np.copy(response_amp[mouse_id][cell_type][iday][selected_cells]) for iday in range(len(days))]
                response_amp_selection[cell_type][mouse_id] = temp

        # Even the number of trials per days across mice.
        min_trials = []
        for iday in range(len(days)):
            m = [data[iday].shape[1] for _, data in response_amp_selection[cell_type].items()]
            min_trials.append(np.min(m))
        print(min_trials)

        for mouse_id, data in response_amp_selection[cell_type].items():
            for iday in range(len(days)):
                response_amp_selection[cell_type][mouse_id][iday] = data[iday][:, :min_trials[iday]]
        
        pop_vectors = np.concatenate([np.concatenate(data, axis=1) for _, data in response_amp_selection[cell_type].items()], axis=0)
        corr_matrix = np.corrcoef(pop_vectors.T)
        # corr_matrix = spearmanr(pop_vectors.T, axis=1)[0]

        plt.figure()
        plt.imshow(corr_matrix)
        plt.title(f'{cell_type} {reward_group}')

        corr_avg_days = []
        corr_avg_pre_post = []

        # Compute average correlation inside each days.
        n_trials = min_trials
        for start, end in zip(np.cumsum([0] + n_trials[:-1]), np.cumsum(n_trials)):
            print(start, end)
            upper_triangle = np.triu(corr_matrix[start:end, start:end], k=1)
            corr_avg_days.append(upper_triangle.flatten())

        # Compare correlation between inside pre training days,
        # inside post training days and between pre and post training days.
        trial_cumsum = np.cumsum([0] + n_trials)
        pre_in_start_x, pre_in_end_x = trial_cumsum[1], trial_cumsum[2]
        pre_in_start_y, pre_in_end_y = trial_cumsum[0], trial_cumsum[1]
        pre_in = corr_matrix[pre_in_start_x:pre_in_end_x, pre_in_start_y:pre_in_end_y]
        print((pre_in_start_x,pre_in_end_x), (pre_in_start_y, pre_in_end_y))

        post_in_start_x, post_in_end_x = trial_cumsum[4], trial_cumsum[5]
        post_in_start_y, post_in_end_y = trial_cumsum[3], trial_cumsum[4]
        post_in = corr_matrix[post_in_start_x:post_in_end_x, post_in_start_y:post_in_end_y]
        print((post_in_start_x,post_in_end_x), (post_in_start_y, post_in_end_y))

        pre_post_start_x, pre_post_end_x = trial_cumsum[3], trial_cumsum[5]
        pre_post_start_y, pre_post_end_y = trial_cumsum[0], trial_cumsum[2]
        pre_post = corr_matrix[pre_post_start_x:pre_post_end_x, pre_post_start_y:pre_post_end_y]
        print((pre_post_start_x,pre_post_end_x), (pre_post_start_y, pre_post_end_y))

        corr_avg_pre_post = [pre_in.flatten(), post_in.flatten(), pre_post.flatten()]

        # Convert to pandas.
        # ------------------

        for iday in range(len(days)):
            for i, corr in enumerate(corr_avg_days[iday]):
                temp = pd.DataFrame([[corr, days[iday], i, cell_type, reward_group]],
                                    columns=['correlation', 'day', 'trial_pair', 'cell_type', 'reward_group'])
                df_corr_days.append(temp)

        compare = ['pre_in', 'post_in', 'pre_post']
        for icomp, comp in enumerate(compare):
            for ipair, corr in enumerate(corr_avg_pre_post[icomp]):
                temp = pd.DataFrame([[corr, comp, ipair, cell_type, reward_group]],
                                    columns=['correlation', 'comparison', 'trial_pair', 'cell_type', 'reward_group'])
                df_corr_pre_post.append(temp)

df_corr_days = pd.concat(df_corr_days, ignore_index=True)
df_corr_pre_post = pd.concat(df_corr_pre_post, ignore_index=True)


[np.int64(49), np.int64(49), np.int64(50), np.int64(49), np.int64(49)]
0 49
49 98
98 148
148 197
197 246
(np.int64(49), np.int64(98)) (np.int64(0), np.int64(49))
(np.int64(197), np.int64(246)) (np.int64(148), np.int64(197))
(np.int64(148), np.int64(246)) (np.int64(0), np.int64(98))
[np.int64(50), np.int64(42), np.int64(43), np.int64(49), np.int64(44)]
0 50
50 92
92 135
135 184
184 228
(np.int64(50), np.int64(92)) (np.int64(0), np.int64(50))
(np.int64(184), np.int64(228)) (np.int64(135), np.int64(184))
(np.int64(135), np.int64(228)) (np.int64(0), np.int64(92))
[np.int64(50), np.int64(49), np.int64(50), np.int64(50), np.int64(49)]
0 50
50 99
99 149
149 199
199 248
(np.int64(50), np.int64(99)) (np.int64(0), np.int64(50))
(np.int64(199), np.int64(248)) (np.int64(149), np.int64(199))
(np.int64(149), np.int64(248)) (np.int64(0), np.int64(99))
[np.int64(50), np.int64(42), np.int64(43), np.int64(49), np.int64(44)]
0 50
50 92
92 135
135 184
184 228
(np.int64(50), np.int64(92)) (np.int64(0), np.

In [265]:
# Save plot and stats.
# --------------------
with PdfPages(os.path.join(output_dir, pdf_file)) as pdf:

    palette = sns.color_palette(['#238443', '#d51a1c'])
    for ct in ['allcells', 'wS2', 'wM1']:

        plt.figure()
        sns.barplot(data=df_corr_days.loc[df_corr_days.cell_type==ct], x='day', y='correlation', hue='reward_group', palette=palette, hue_order=['R+', 'R-'])
        sns.despine()
        plt.title(f'Correlation inside days - {ct}')
        pdf.savefig(dpi=300)
        plt.close()

        # # Perform Mann-Whitney U test to check if the difference between the two reward groups is significant for each day.
        # p_values = []
        # for day in days:
        #     group_rew = df_corr_days[(df_corr_days['day'] == day) & (df_corr_days['reward_group'] == 'R+') & (df_corr_days.cell_type==ct)]['correlation']
        #     group_rew = group_rew[~np.isnan(group_rew)]
        #     group_unrew = df_corr_days[(df_corr_days['day'] == day) & (df_corr_days['reward_group'] == 'R-') & (df_corr_days.cell_type==ct)]['correlation']
        #     group_rew = group_rew[~np.isnan(group_rew)]
        #     stat, p = mannwhitneyu(group_rew, group_unrew)
        #     p_values.append(p)
        #     print(f'Day {day}: p-value = {p}')
        # # Add p-values to the dataframe for visualization
        # df_p_values = pd.DataFrame({'day': days, 'p_value': p_values})
        # print(df_p_values)
        # df_p_values.to_csv(os.path.join(output_dir, f'correlation_matrix_quantification_{win_length}_ms_{cell_selection}_inside_days.csv'), index=False)


        plt.figure()
        sns.barplot(data=df_corr_pre_post.loc[df_corr_pre_post.cell_type==ct], x='comparison', y='correlation', hue='reward_group', palette=palette, hue_order=['R+', 'R-'])
        sns.despine()
        plt.title(f'Correlation inside pre-training, post-training and across both - {ct}')
        pdf.savefig(dpi=300)
        plt.close()

        # # Perform Mann-Whitney U test to check if the difference between the two reward groups is significant for each day.
        # p_values = []
        # for comp in compare:
        #     group_rew = df_corr_pre_post[(df_corr_pre_post['comparison'] == comp) & (df_corr_pre_post['reward_group'] == 'R+') & (df_corr_pre_post.cell_type==ct)]['correlation']
        #     group_rew = group_rew[~np.isnan(group_rew)]
        #     group_unrew = df_corr_pre_post[(df_corr_pre_post['comparison'] == comp) & (df_corr_pre_post['reward_group'] == 'R-') & (df_corr_pre_post.cell_type==ct)]['correlation']
        #     group_rew = group_rew[~np.isnan(group_rew)]
        #     stat, p = mannwhitneyu(group_rew, group_unrew)
        #     p_values.append(p)
        #     print(f'Comp {comp}: p-value = {p}')
        # # Add p-values to the dataframe for visualization
        # df_p_values = pd.DataFrame({'comp': compare, 'p_value': p_values})
        # print(df_p_values)
        # df_p_values.to_csv(os.path.join(output_dir, f'correlation_matrix_quantification_{win_length}_ms_{cell_selection}_across_pre_post.csv'), index=False)



Plot the population vectors and lmi.

This is to show that lmi select cells that go on and off as expected.

In [186]:
f, axes = plt.subplots(1, 2, sharey=True)
im = axes[0].imshow(np.repeat(lmi[mouse_id][:, np.newaxis], 10, axis=1), cmap='viridis', vmin=-1, vmax=1)
plt.colorbar(im)
vmax = np.percentile(pop_vectors_dict[mouse_id]['allcells'], 99)
vmin = np.percentile(pop_vectors_dict[mouse_id]['allcells'], 1)
im = axes[1].imshow(pop_vectors_dict[mouse_id]['allcells'], cmap='viridis', vmin=vmin, vmax=vmax)
plt.colorbar(im)
print(vmin, vmax)

KeyError: (slice(None, None, None), None)

# 3 Correlation during learning.

When is the change of correlation triggered during D0 whisker learning?

- First, plot correlation matrix with WH trials stacked with UM.
- then point plot the correlation of each trial with the average maaping response of D+2
- select modulated cells with LMI and plot population vectors for WH and UM. Is there a graded response? a discret change? or do they respond strong since the very first trial?



In [20]:
# Load data needed to compute before and after learning.

sampling_rate = 30
win = (1, 1.180)  # from stimulus onset to 300 ms after.
win_length = f'{int(np.round((win[1]-win[0]) * 1000))}'  # for file naming.
win = (int(win[0] * sampling_rate), int(win[1] * sampling_rate))
baseline_win = (0, 1)
baseline_win = (int(baseline_win[0] * sampling_rate), int(baseline_win[1] * sampling_rate))
days = ['-2', '-1', '0', '+1', '+2']
trial_type = 'W'
plot_save_figs = False

_, _, mice, _ = io.select_sessions_from_db(db_path,
                                            nwb_dir,
                                            two_p_imaging='yes')
print(mice)
# excluded_mice = ['GF307', 'GF310', 'GF333', 'MI075', 'AR144', 'AR135', 'AR163', 'MI069', 'MI072', 'AR132']
excluded_mice = ['GF307', 'GF310', 'GF333', 'MI075', 'AR144', 'AR135', 'AR163',]
mice = [m for m in mice if m not in excluded_mice]

['GF305', 'GF306', 'GF307', 'GF308', 'GF310', 'GF311', 'GF313', 'GF314', 'GF317', 'GF318', 'GF319', 'GF323', 'GF333', 'GF334', 'GF348', 'GF350', 'MI062', 'MI069', 'MI072', 'MI075', 'MI076', 'AR132', 'AR133', 'AR137', 'AR139', 'AR127', 'AR143', 'AR163', 'AR177', 'AR178', 'AR179', 'AR180']


In [21]:
corr_avg_days = {}
corr_avg_pre_post = {}
metadata = {}
response_amp = {}
pop_vectors_dict = {}
n_trials = {}

globally_responsive = {}
responsive_p_values = {}

# mice = ['GF334']
# Load lmi dicts.
lmi = np.load(r'\\sv-nas1.rcp.epfl.ch\Petersen-Lab\analysis\Anthony_Renard\data_processed\lmi.npy', allow_pickle=True).item()
lmi_p = np.load(r'\\sv-nas1.rcp.epfl.ch\Petersen-Lab\analysis\Anthony_Renard\data_processed\lmi_p.npy', allow_pickle=True).item()


for mouse_id in mice:
    output_dir = fr'//sv-nas1.rcp.epfl.ch/Petersen-Lab/analysis/Anthony_Renard/analysis_output/mice/{mouse_id}'
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    session_list, nwb_files, _, db_filtered = io.select_sessions_from_db(db_path,
                                                                        nwb_dir,
                                                                        two_p_imaging='yes',
                                                                        subject_id=mouse_id,
                                                                        day=days,)
    print(session_list)
    
    data = []
    mdata_list = []
    for session_id in session_list:
        arr, mdata = imaging_utils.load_session_2p_imaging(mouse_id,
                                                            session_id,
                                                            processed_dir)
        arr = imaging_utils.substract_baseline(arr, 3, baseline_win)
        data.append(arr)
        mdata_list.append(mdata)
    reward_group = io.get_reward_group_from_db(db_path, session_list[0])
    metadata[mouse_id] = {}
    metadata[mouse_id]['reward_group'] = reward_group
    metadata[mouse_id]['rois'] = mdata_list[0]['rois']
    metadata[mouse_id]['cell_types'] = mdata_list[0]['cell_types']
    for d, mday in enumerate(mdata_list):
        metadata[mouse_id][days[d]] = {}
        metadata[mouse_id][days[d]]['trials'] = mdata_list[d]['trials']
        metadata[mouse_id][days[d]]['trial_types'] = mdata_list[d]['trial_types']
    
    # Extract UM and WH trials.
    if reward_group == 'R+':
        n_um = 45
        n_wh = 30
    else:
        n_um = 45
        n_wh = 5

    if trial_type == 'W':
        n_um = 45
        n_wh = 60
    
    if trial_type == 'WH':
        # Some days have no WH trials for the mouse.
        if mouse_id == 'AR132':
            continue

    activity = []
    arr = imaging_utils.extract_trials(data[0], mdata_list[0], 'UM', n_trials=n_um)
    activity.append(arr)
    arr = imaging_utils.extract_trials(data[1], mdata_list[1], 'UM', n_trials=n_um)
    activity.append(arr)
    arr = imaging_utils.extract_trials(data[2], mdata_list[2], trial_type, n_trials=n_wh)
    activity.append(arr)
    arr = imaging_utils.extract_trials(data[2], mdata_list[2], 'UM', n_trials=n_um)
    activity.append(arr)
    arr = imaging_utils.extract_trials(data[3], mdata_list[3], trial_type, n_trials=n_wh)
    activity.append(arr)
    arr = imaging_utils.extract_trials(data[3], mdata_list[3], 'UM', n_trials=n_um)
    activity.append(arr)
    arr = imaging_utils.extract_trials(data[4], mdata_list[4], trial_type, n_trials=n_wh)
    activity.append(arr)
    arr = imaging_utils.extract_trials(data[4], mdata_list[4], 'UM', n_trials=n_um)
    activity.append(arr)

    # # Print n trials.
    # print([arr.shape[1] for arr in activity])

    corr_avg_days[mouse_id] = {}
    corr_avg_pre_post[mouse_id] = {}
    response_amp[mouse_id] = {}
    pop_vectors_dict[mouse_id] = {}
    n_trials[mouse_id] = {}
    globally_responsive[mouse_id] = {}
    responsive_p_values[mouse_id] = {}

    for cell_type in ['allcells', 'wS2', 'wM1']:
        # Select cell type.
        if cell_type == 'allcells':
            activity_subtype = activity
        else:
            activity_subtype = []
            cell_type_mask = mdata_list[0]['cell_types']==cell_type
            activity_subtype = [arr[cell_type_mask] for arr in activity]
        
        # strong_cells = [3,11,33,48,57,67,80,86,104,153,166,175]
        # mask = np.ones(data_subtype[0].shape[0], dtype=bool)
        # mask[strong_cells] = False
        # data_subtype = [arr[mask] for arr in data_subtype]

        # If no cells of the specified type, skip.
        if activity_subtype[0].shape[0] == 0:
            continue

        # Compute average response for each trial, each day.
        response_avg = []
        for d in activity_subtype:
            response_avg.append(np.nanmean(d[:, :, win[0]:win[1]], axis=2))
        response_amp[mouse_id][cell_type] = response_avg
        pop_vectors = np.concatenate(response_avg, axis=1)
        pop_vectors_dict[mouse_id][cell_type] = pop_vectors
        
        # Compute LMI.
        # if cell_type == 'allcells':
        #     # pre = np.mean(np.concatenate(response_avg[0:2], axis=1), axis=1)
        #     # post = np.mean(np.concatenate((response_avg[5], response_avg[7]), axis=1), axis=1)
        #     # lmi[mouse_id] = (post - pre) / (np.abs(post) + np.abs(pre))
        #     # lmis = []
        #     # for icell in range(pop_vectors.shape[0]):
        #     #     # mapping trials of D-2, D-1, D+1, D+2.
        #     #     X = np.r_[response_avg[0][icell],
        #     #               response_avg[1][icell],
        #     #               response_avg[5][icell],
        #     #               response_avg[7][icell]]
        #     #     y = np.r_[np.zeros(response_avg[0][icell].shape[0]),
        #     #               np.zeros(response_avg[1][icell].shape[0]),
        #     #               np.ones(response_avg[5][icell].shape[0]),
        #     #               np.ones(response_avg[7][icell].shape[0])]
        #     #     fpr, tpr, _ = roc_curve(y, X)
        #     #     roc_auc = auc(fpr, tpr)
        #     #     lmis.append((roc_auc - 0.5) * 2)
        #     # lmi[mouse_id] = np.array(lmis)

        #     pre = np.concatenate(response_avg[0:2], axis=1)
        #     post = np.concatenate((response_avg[5], response_avg[7]), axis=1)
        #     lmi[mouse_id]['allcells'], lmi_p[mouse_id]['allcells'] = imaging_utils.compute_lmi(pre, post, nshuffles=None)
        # else:
        #     lmi[mouse_id][cell_type] = lmi[mouse_id]['allcells'][metadata[mouse_id]['cell_types'] == cell_type]


        #  # Test responsiveness.
        # if cell_type == 'allcells':
        #     base = []
        #     resp = []
        #     activity_UM = [activity_subtype[i] for i in [0,1,3,5,7]]
        #     for day in activity_UM:
        #         base.append(np.nanmean(day[:, :, baseline_win[0]:baseline_win[1]], axis=2))
        #         resp.append(np.nanmean(day[:, :, win[0]:win[1]], axis=2))

        #     # Test global responsiveness by pulling trials of all days together.
        #     n_cells = base[0].shape[0]
        #     base = np.concatenate(base, axis=1)
        #     resp = np.concatenate(resp, axis=1)
        #     p_values = np.zeros(n_cells)
        #     for icell in range(n_cells):
        #         if np.all(base[icell] == 0) or np.all(resp[icell] == 0):
        #             p_values[icell] = 1
        #         else:
        #             _, p_values[icell] = wilcoxon(base[icell], resp[icell])
        #             globally_responsive[mouse_id][cell_type] = p_values
        # else:
        #     globally_responsive[mouse_id][cell_type] = globally_responsive[mouse_id]['allcells'][metadata[mouse_id]['cell_types'] == cell_type]


        if plot_save_figs:
            corr_matrix = np.corrcoef(pop_vectors.T)
            # corr_matrix = cosine_similarity(pop_vectors.T)
            # corr_matrix = spearmanr(pop_vectors.T, axis=1)[0]

            # Compute average correlation inside each days.
            corr_avg_days[mouse_id][cell_type] = []
            n_trials[mouse_id] = [arr.shape[1] for arr in activity_subtype]
            for start, end in zip(np.cumsum([0] + n_trials[mouse_id][:-1]), np.cumsum(n_trials[mouse_id])):
                upper_triangle = np.triu(corr_matrix[start:end, start:end], k=1)
                corr_avg_days[mouse_id][cell_type].append(np.mean(upper_triangle))

            # Compare correlation between inside pre training days,
            # inside post training days and between pre and post training days.
            trial_cumsum = np.cumsum([0] + n_trials[mouse_id])
            pre_in_start_x, pre_in_end_x = trial_cumsum[1], trial_cumsum[2]
            pre_in_start_y, pre_in_end_y = trial_cumsum[0], trial_cumsum[1]
            pre_in = np.mean(corr_matrix[pre_in_start_x:pre_in_end_x, pre_in_start_y:pre_in_end_y])

            post_in_start_x, post_in_end_x = trial_cumsum[4], trial_cumsum[5]
            post_in_start_y, post_in_end_y = trial_cumsum[3], trial_cumsum[4]
            post_in = np.mean(corr_matrix[post_in_start_x:post_in_end_x, post_in_start_y:post_in_end_y])

            pre_post_start_x, pre_post_end_x = trial_cumsum[3], trial_cumsum[5]
            pre_post_start_y, pre_post_end_y = trial_cumsum[0], trial_cumsum[2]
            pre_post = np.mean(corr_matrix[pre_post_start_x:pre_post_end_x, pre_post_start_y:pre_post_end_y])

            corr_avg_pre_post[mouse_id][cell_type] = [pre_in, post_in, pre_post]


            # Plot population vectors.
            pdf_file = f'pop_vectors_learning_and_mapping_{mouse_id}_{cell_type}_{trial_type}.pdf'
            with PdfPages(os.path.join(output_dir, pdf_file)) as pdf:
                vmax = np.percentile(pop_vectors, 99)
                vmin = np.percentile(pop_vectors, 1)

                f = plt.figure()
                im = plt.imshow(pop_vectors, cmap='viridis', vmin=vmin, vmax=vmax)
                cbar = f.colorbar(im, ticks=[vmin, 0, vmax])
                cbar.ax.set_yticklabels([f'{vmin:.2f}', '0', f'> {vmax:.2f}'])
                cbar.ax.tick_params(size=0)
                pdf.savefig(dpi=300)
                plt.close()

            # Plot correlation matrix.
            pdf_file = f'correlation_matrices_learning_and_mapping_{mouse_id}_{cell_type}_{trial_type}.pdf'            
            with PdfPages(os.path.join(output_dir, pdf_file)) as pdf:
                
                # Set color map limit to the max without the diagonal.
                vmax = np.max(corr_matrix[~np.eye(corr_matrix.shape[0], dtype=bool)])
                vmin = np.min(corr_matrix)
                f = plt.figure()
                im = plt.imshow(corr_matrix, vmin = vmin, vmax=vmax, cmap='viridis')
                n_trials[mouse_id] = [arr.shape[1] for arr in activity]
                for i in np.cumsum(n_trials[mouse_id])[:-1]:
                    plt.axvline(x=i-1, color='#252525', linestyle='-', lw=0.5)
                    plt.axhline(y=i-1, color='#252525', linestyle='-', lw=0.5)
                if cell_type:
                    plt.title(f'{mouse_id} {reward_group} {cell_type}')
                else:
                    plt.title(f'{mouse_id} {reward_group} all cells')
                cbar_ax = f.add_axes([0.85, 0.15, 0.05, 0.7])
                cbar = f.colorbar(im, cax=cbar_ax, ticks=[vmin, 0, vmax])
                cbar.ax.set_yticklabels([f'{vmin:.2f}', '0', f'> {vmax:.2f}'])
                cbar.ax.tick_params(size=0)
                pdf.savefig(dpi=300)
                plt.close()

# Load lmi dicts.
# lmi = {}  

['GF305_27112020_083119', 'GF305_28112020_103938', 'GF305_29112020_103331', 'GF305_30112020_110255', 'GF305_02122020_132229']
here (133, 45, 181)
here (133, 45, 181)
here (133, 60, 181)
here (133, 45, 181)
here (133, 60, 181)
here (133, 45, 181)
here (133, 60, 181)
here (133, 45, 181)
['GF306_27112020_104436', 'GF306_28112020_125555', 'GF306_29112020_131929', 'GF306_30112020_133249', 'GF306_02122020_161611']
here (215, 45, 181)
here (215, 45, 181)
here (215, 60, 181)
here (215, 45, 181)
here (215, 60, 181)
here (215, 45, 181)
here (215, 60, 181)
here (215, 45, 181)
['GF308_17112020_105052', 'GF308_18112020_093627', 'GF308_19112020_103527', 'GF308_20112020_122826', 'GF308_21112020_135515']
here (147, 45, 181)
here (147, 45, 181)
here (147, 60, 181)
here (147, 45, 181)
here (147, 60, 181)
here (147, 45, 181)
here (147, 60, 181)
here (147, 45, 181)
['GF311_17112020_155501', 'GF311_18112020_151838', 'GF311_19112020_160412', 'GF311_20112020_171609', 'GF311_21112020_180049']
here (105, 45, 1

KeyboardInterrupt: 

In [112]:
lmi = np.load(r'\\sv-nas1.rcp.epfl.ch\Petersen-Lab\analysis\Anthony_Renard\data_processed\lmi.npy', allow_pickle=True).item()
lmi_p = np.load(r'\\sv-nas1.rcp.epfl.ch\Petersen-Lab\analysis\Anthony_Renard\data_processed\lmi_p.npy', allow_pickle=True).item()


Construction population vectors.


In [109]:
rewarded_mice

[]

In [114]:
# responsive_thr = 0.001
# lmi_percentile_top = 90
# lmi_percentile_bottom = 10

rewarded_mice = [mouse_id for mouse_id in pop_vectors_dict.keys() if metadata[mouse_id]['reward_group']=='R+']
unrewarded_mice = [mouse_id for mouse_id in pop_vectors_dict.keys() if metadata[mouse_id]['reward_group']=='R-']
# unrewarded_mice = [m for m in unrewarded_mice if m not in ['MI069', 'MI072', 'AR143']]


# # # Compute the LMI thresholds for the top 5% most modulated cells and bottom 5% least modulated cells
# lmi_thr_rew_top = np.percentile(np.concatenate([lmi[mouse_id]['allcells'] for mouse_id in rewarded_mice]), lmi_percentile_top)
# lmi_thr_rew_bottom = np.percentile(np.concatenate([lmi[mouse_id]['allcells'] for mouse_id in rewarded_mice]), lmi_percentile_bottom)
# lmi_thr_unrew_top = np.percentile(np.concatenate([lmi[mouse_id]['allcells'] for mouse_id in unrewarded_mice]), lmi_percentile_top)
# lmi_thr_unrew_bottom = np.percentile(np.concatenate([lmi[mouse_id]['allcells'] for mouse_id in unrewarded_mice]), lmi_percentile_bottom)

rewarded_pop_vectors = np.concatenate(
    [pop_vectors_dict[mouse_id]['allcells'] for mouse_id in rewarded_mice], axis=0)
unrewarded_pop_vectors = np.concatenate(
    [pop_vectors_dict[mouse_id]['allcells'] for mouse_id in unrewarded_mice], axis=0)

rewarded_pop_vectors_wS2 = np.concatenate(
    [pop_vectors_dict[mouse_id]['wS2'] for mouse_id in rewarded_mice if 'wS2' in pop_vectors_dict[mouse_id].keys()], axis=0)
unrewarded_pop_vectors_wS2 = np.concatenate(
    [pop_vectors_dict[mouse_id]['wS2'] for mouse_id in unrewarded_mice if 'wS2' in pop_vectors_dict[mouse_id].keys()], axis=0)

rewarded_pop_vectors_wM1 = np.concatenate(
    [pop_vectors_dict[mouse_id]['wM1'] for mouse_id in rewarded_mice if 'wM1' in pop_vectors_dict[mouse_id].keys()], axis=0)
unrewarded_pop_vectors_wM1 = np.concatenate(
    [pop_vectors_dict[mouse_id]['wM1'] for mouse_id in unrewarded_mice if 'wM1' in pop_vectors_dict[mouse_id].keys()], axis=0)


# Select the top most modulated cells and bottom least modulated cells for each rewarded mouse
rewarded_pop_vectors_top = np.concatenate(
    [pop_vectors_dict[mouse_id]['allcells'][lmi_p[mouse_id]['allcells'] > 0.975] for mouse_id in rewarded_mice], axis=0
)
rewarded_pop_vectors_bottom = np.concatenate(
    [pop_vectors_dict[mouse_id]['allcells'][lmi_p[mouse_id]['allcells'] < 0.025] for mouse_id in rewarded_mice], axis=0
)
rewarded_pop_vectors_lmi = np.concatenate((rewarded_pop_vectors_top, rewarded_pop_vectors_bottom))
rewarded_pop_vectors_lmi_wS2_top = np.concatenate(
    [pop_vectors_dict[mouse_id]['wS2'][lmi_p[mouse_id]['wS2'] > 0.975] for mouse_id in rewarded_mice if 'wS2' in lmi_p[mouse_id].keys()], axis=0
)
rewarded_pop_vectors_lmi_wM1_top = np.concatenate(
    [pop_vectors_dict[mouse_id]['wM1'][lmi_p[mouse_id]['wM1'] > 0.975] for mouse_id in rewarded_mice if 'wM1' in lmi_p[mouse_id].keys()], axis=0
)
rewarded_pop_vectors_lmi_wS2_bottom = np.concatenate(
    [pop_vectors_dict[mouse_id]['wS2'][lmi_p[mouse_id]['wS2'] < 0.025] for mouse_id in rewarded_mice if 'wS2' in lmi_p[mouse_id].keys()], axis=0
)
rewarded_pop_vectors_lmi_wM1_bottom = np.concatenate(
    [pop_vectors_dict[mouse_id]['wM1'][lmi_p[mouse_id]['wM1'] < 0.025] for mouse_id in rewarded_mice if 'wM1' in lmi_p[mouse_id].keys()], axis=0
)
rewarded_pop_vectors_lmi_wS2 = np.concatenate((rewarded_pop_vectors_lmi_wS2_top, rewarded_pop_vectors_lmi_wS2_bottom))
rewarded_pop_vectors_lmi_wM1 = np.concatenate((rewarded_pop_vectors_lmi_wM1_top, rewarded_pop_vectors_lmi_wM1_bottom))

# Select the top 5% most modulated cells and bottom 5% least modulated cells for each unrewarded mouse
unrewarded_pop_vectors_top = np.concatenate(
    [pop_vectors_dict[mouse_id]['allcells'][lmi_p[mouse_id]['allcells'] > 0.975] for mouse_id in unrewarded_mice], axis=0
)
unrewarded_pop_vectors_bottom = np.concatenate(
    [pop_vectors_dict[mouse_id]['allcells'][lmi_p[mouse_id]['allcells'] < 0.025] for mouse_id in unrewarded_mice], axis=0
)
unrewarded_pop_vectors_lmi = np.concatenate((unrewarded_pop_vectors_top, unrewarded_pop_vectors_bottom))
unrewarded_pop_vectors_lmi_wS2_top = np.concatenate(
    [pop_vectors_dict[mouse_id]['wS2'][lmi_p[mouse_id]['wS2'] > 0.975] for mouse_id in unrewarded_mice if 'wS2' in lmi[mouse_id].keys()], axis=0
)
unrewarded_pop_vectors_lmi_wM1_top = np.concatenate(
    [pop_vectors_dict[mouse_id]['wM1'][lmi_p[mouse_id]['wM1'] > 0.975] for mouse_id in unrewarded_mice if 'wM1' in lmi[mouse_id].keys()], axis=0
)
unrewarded_pop_vectors_lmi_wS2_bottom = np.concatenate(
    [pop_vectors_dict[mouse_id]['wS2'][lmi_p[mouse_id]['wS2'] < 0.025] for mouse_id in unrewarded_mice if 'wS2' in lmi[mouse_id].keys()], axis=0
)
unrewarded_pop_vectors_lmi_wM1_bottom = np.concatenate(
    [pop_vectors_dict[mouse_id]['wM1'][lmi_p[mouse_id]['wM1'] < 0.025] for mouse_id in unrewarded_mice if 'wM1' in lmi[mouse_id].keys()], axis=0
)
unrewarded_pop_vectors_lmi_wS2 = np.concatenate((unrewarded_pop_vectors_lmi_wS2_top, unrewarded_pop_vectors_lmi_wS2_bottom))
unrewarded_pop_vectors_lmi_wM1 = np.concatenate((unrewarded_pop_vectors_lmi_wM1_top, unrewarded_pop_vectors_lmi_wM1_bottom))


# # Responsive cells
# # Select the top 5% most modulated cells and bottom 5% least modulated cells for each rewarded mouse
# rewarded_pop_vectors_responsive = np.concatenate(
#     [pop_vectors_dict[mouse_id]['allcells'][globally_responsive[mouse_id]['allcells'] < responsive_thr] for mouse_id in rewarded_mice], axis=0)
# unrewarded_pop_vectors_responsive = np.concatenate(
#     [pop_vectors_dict[mouse_id]['allcells'][globally_responsive[mouse_id]['allcells'] < responsive_thr] for mouse_id in unrewarded_mice], axis=0)

# # Responsive cells
# # Select the top 5% most modulated cells and bottom 5% least modulated cells for each rewarded mouse
# rewarded_pop_vectors_responsive_wS2 = np.concatenate(
#     [pop_vectors_dict[mouse_id]['wS2'][globally_responsive[mouse_id]['wS2'] < responsive_thr] for mouse_id in rewarded_mice if 'wS2' in globally_responsive[mouse_id].keys()], axis=0)
# unrewarded_pop_vectors_responsive_wS2 = np.concatenate(
#     [pop_vectors_dict[mouse_id]['wS2'][globally_responsive[mouse_id]['wS2'] < responsive_thr] for mouse_id in unrewarded_mice if 'wS2' in globally_responsive[mouse_id].keys()], axis=0)

# # Responsive cells
# # Select the top 5% most modulated cells and bottom 5% least modulated cells for each rewarded mouse
# rewarded_pop_vectors_responsive_wM1 = np.concatenate(
#     [pop_vectors_dict[mouse_id]['wM1'][globally_responsive[mouse_id]['wM1'] < responsive_thr] for mouse_id in rewarded_mice if 'wM1' in globally_responsive[mouse_id].keys()], axis=0)
# unrewarded_pop_vectors_responsive_wM1 = np.concatenate(
#     [pop_vectors_dict[mouse_id]['wM1'][globally_responsive[mouse_id]['wM1'] < responsive_thr] for mouse_id in unrewarded_mice if 'wM1' in globally_responsive[mouse_id].keys()], axis=0)

Population matrices with learning. For all cells, modulated cells or responsive cells.

In [57]:
zscore = False

# Subset of cells.
vectors_rew = rewarded_pop_vectors_lmi_wM1
vectors_unrew = unrewarded_pop_vectors_lmi_wM1

n_trial_um = 45
n_trial_rew = 40
n_trial_nonrew = 40

if zscore:
    vectors_rew = (vectors_rew - np.mean(vectors_rew, axis=1, keepdims=True)) / np.std(vectors_rew, axis=1, keepdims=True)
    vectors_unrew = (vectors_unrew - np.mean(vectors_unrew, axis=1, keepdims=True)) / np.std(vectors_unrew, axis=1, keepdims=True)
    vectors_rew = np.nan_to_num(vectors_rew)
    vectors_unrew = np.nan_to_num(vectors_unrew)

block_edges_rew = np.cumsum([n_trial_um, n_trial_um, n_trial_rew, n_trial_um, n_trial_rew, n_trial_um, n_trial_rew, n_trial_um])
corr_matrix = np.corrcoef(vectors_rew.T)
vmax = np.percentile(corr_matrix[~np.eye(corr_matrix.shape[0], dtype=bool)], 99.5)
vmin = np.percentile(corr_matrix, .5)

plt.figure()
im = plt.imshow(corr_matrix, cmap='viridis', vmin=vmin, vmax=vmax)
cbar = f.colorbar(im, ticks=[vmin, 0, vmax])
cbar.ax.set_yticklabels([f'{vmin:.2f}', '0', f'> {vmax:.2f}'])
cbar.ax.tick_params(size=0)

for i in block_edges_rew-0.5:
    plt.axvline(x=i, color='white', linestyle='--', linewidth=1)
    plt.axhline(y=i, color='white', linestyle='--', linewidth=1)
plt.xticks(block_edges_rew-0.5, block_edges_rew)
plt.yticks(block_edges_rew-0.5, block_edges_rew)

block_edges_unrew = np.cumsum([n_trial_um, n_trial_um, n_trial_nonrew, n_trial_um, n_trial_nonrew, n_trial_um, n_trial_nonrew, n_trial_um])[:-1]
corr_matrix = np.corrcoef(vectors_unrew.T)
vmax = np.percentile(corr_matrix[~np.eye(corr_matrix.shape[0], dtype=bool)], 99.5)
vmin = np.percentile(corr_matrix, .5)

plt.figure()
im = plt.imshow(corr_matrix, cmap='viridis', vmin=vmin, vmax=vmax)
cbar = f.colorbar(im, ticks=[vmin, 0, vmax])
cbar.ax.set_yticklabels([f'{vmin:.2f}', '0', f'> {vmax:.2f}'])
cbar.ax.tick_params(size=0)

for i in block_edges_unrew-0.5:
    plt.axvline(x=i, color='white', linestyle='--', linewidth=1)
    plt.axhline(y=i, color='white', linestyle='--', linewidth=1)
plt.xticks(block_edges_unrew-0.5, block_edges_unrew)
plt.yticks(block_edges_unrew-0.5, block_edges_unrew)


  cbar = f.colorbar(im, ticks=[vmin, 0, vmax])
  cbar = f.colorbar(im, ticks=[vmin, 0, vmax])


([<matplotlib.axis.YTick at 0x1e865fa6360>,
  <matplotlib.axis.YTick at 0x1e865fa5f70>,
  <matplotlib.axis.YTick at 0x1e865e8a4e0>,
  <matplotlib.axis.YTick at 0x1e800ccb2f0>,
  <matplotlib.axis.YTick at 0x1e800cc9250>,
  <matplotlib.axis.YTick at 0x1e8052c5b20>,
  <matplotlib.axis.YTick at 0x1e800cc9280>],
 [Text(0, 44.5, '45'),
  Text(0, 89.5, '90'),
  Text(0, 129.5, '130'),
  Text(0, 174.5, '175'),
  Text(0, 214.5, '215'),
  Text(0, 259.5, '260'),
  Text(0, 299.5, '300')])

Population vector plot on the stim for the modulated cells. For WH and UM.

(700, 240)

In [77]:
if wh_trial_type == 'WH':
    block_edges_rew = np.cumsum([45, 45, 30, 45, 30, 45, 30, 45])[:-1]
    block_edges_unrew = np.cumsum([45, 45, 5, 45, 5, 45, 5, 45])[:-1]
elif wh_trial_type == 'WM':
    block_edges_rew = np.cumsum([45, 45, 30, 45, 30, 45, 30, 45])[:-1]
    block_edges_unrew = np.cumsum([45, 45, 10, 45, 10, 45, 10, 45])[:-1]
elif wh_trial_type == 'W':
    block_edges_rew = np.cumsum([45, 45, 40, 45, 40, 45, 40, 45])[:-1]
    block_edges_unrew = np.cumsum([45, 45, 40, 45, 40, 45, 40, 45])[:-1]


vmax = np.percentile(rewarded_pop_vectors_top, 99)
vmin = np.percentile(rewarded_pop_vectors_top, 1)

plt.figure()
plt.imshow(rewarded_pop_vectors, cmap='viridis', vmin=vmin, vmax=vmax)
for i in block_edges_rew-0.5:
    plt.axvline(x=i, color='white', linestyle='--', linewidth=1)
plt.xticks(block_edges_rew-0.5, block_edges_rew)
plt.colorbar()
plt.show()
plt.title('Positively modulated cells for each mouse -- R+')



# Plot for the top 5% positively modulated cells -- R-
vmax = np.percentile(unrewarded_pop_vectors_top, 99)
vmin = np.percentile(unrewarded_pop_vectors_top, 1)
plt.figure()
plt.imshow(unrewarded_pop_vectors, cmap='viridis', vmin=vmin, vmax=vmax)
for i in block_edges_unrew-0.5:
    plt.axvline(x=i, color='white', linestyle='--', linewidth=1)
plt.xticks(block_edges_unrew-0.5, block_edges_unrew)
plt.colorbar()
plt.show()
plt.title('Modulated cells -- R-')


Text(0.5, 1.0, 'Modulated cells -- R-')

Scatter plot of the correlation of each D0 WH with post learning UM.

First have a look with the global population.

In [None]:
sns.set_theme(context='paper', style='ticks', palette='deep', font='sans-serif', font_scale=1)

vectors = rewarded_pop_vectors

# block_edges = np.cumsum([45, 45, 30, 45, 30, 45, 30, 45])
block_edges = np.cumsum([45, 45, 40, 45, 40, 45, 40, 45])
print(block_edges_rew)
# pre = rewarded_pop_vectors[:, :block_edges_rew[0]]
# post = rewarded_pop_vectors[:, block_edges_rew[6]:block_edges_rew[7]]
# d0_learning = rewarded_pop_vectors[:,block_edges_rew[1]:block_edges_rew[2]]
pre = vectors[:, :block_edges[1]]
post1 = vectors[:,block_edges[4]:block_edges[5]]
post2 = vectors[:,block_edges[6]:block_edges[7]]
post = np.concatenate((post1, post2), axis=1)
pre_vect = np.mean(pre, axis=1)
post_vect = np.mean(post, axis=1)
learning_direction = post_vect - pre_vect
d0_learning = vectors[:,block_edges[1]:block_edges[2]]

# correlations = []

# # Average correlation each trial and all the post trials.
# # for i in range(pre.shape[1]):
# #     correlations.append(np.mean(np.corrcoef(pre[:,i], post.T)[1:, 0]))
# # for i in range(d0_learning.shape[1]):
# #     correlations.append(np.mean(np.corrcoef(d0_learning[:, i], post.T)[1:, 0]))
# # for i in range(post.shape[1]):
# #     correlations.append(np.mean(np.corrcoef(post[:, i], post.T)[1:, 0]))
# # correlation = np.array(correlations)

# # or correlation between each trial and the average of the post trials.
# for i in range(pre.shape[1]):
#     correlations.append(np.mean(np.corrcoef(pre[:,i], post_vect)[1:, 0]))
# for i in range(d0_learning.shape[1]):
#     correlations.append(np.mean(np.corrcoef(d0_learning[:, i], post_vect)[1:, 0]))
# for i in range(post.shape[1]):
#     correlations.append(np.mean(np.corrcoef(post[:, i], post_vect)[1:, 0]))
# correlation = np.array(correlations)

projections = []
for i in range(pre.shape[1]):
    projections.append(np.dot(pre[:,i], learning_direction))
for i in range(d0_learning.shape[1]):
    projections.append(np.dot(d0_learning[:, i], learning_direction))
for i in range(post.shape[1]):
    projections.append(np.dot(post[:, i], learning_direction))
projections = np.array(projections)


palette = sns.color_palette([sns.color_palette('deep')[0], '#238443'])
# palette = sns.color_palette([sns.color_palette('deep')[0], '#d51a1c'])
colors = [palette[0]] * pre.shape[1] + [palette[1]] * d0_learning.shape[1] + [palette[0]] * post.shape[1]
plt.figure()
# plt.scatter(range(correlation.shape[0]), correlation, color=colors)
data=pd.DataFrame({'projection': projections,
                   'trial': range(projections.shape[0]),
                   'block': ['pre'] * pre.shape[1] + ['d0_learning'] * d0_learning.shape[1] + ['post'] * post.shape[1]})
# sns.lmplot(data=data, x='trial', y='correlation', fit_reg=True, scatter=True, hue='block', palette=palette, ci=None)


# from scipy.optimize import curve_fit
# # Define the sigmoid function
# def sigmoid(x, L, x0, k, b):
#     return L / (1 + np.exp(-k * (x - x0))) + b

# # Fit the sigmoid function to the data
# p0 = [max(data['correlation']), np.median(data['trial']), 1, min(data['correlation'])]  # Initial guess for the parameters
# params, _ = curve_fit(sigmoid, data['trial'], data['correlation'], p0, method='dogbox')

# # Plot the fitted sigmoid curve
plt.scatter(range(projections.shape[0]), projections, color=colors)
# x_fit = np.linspace(min(data['trial']), max(data['trial']), 200)
# y_fit = sigmoid(x_fit, *params)
# plt.plot(x_fit, y_fit, label='Sigmoid fit', color='red')


# plt.title('Correlation pretraining and D0 learning with post training -- R-')
# edges = np.cumsum([45, 45, 30, 45, 45])
edges = np.cumsum([45, 45, 40, 45, 45])
plt.xticks(edges, edges)
# plt.ylim([-0.4, 1])
sns.despine()



[ 45  90 130 175 215 260 300]


Projection on learning dimension but across mice (not for global population).

In [94]:
# Define the cell type to use: 'allcells', 'responsive', or 'lmi'
cell_type_to_use = 'lmi'  # Change this to 'responsive' or 'lmi' as needed
zscore = True

rewarded_mice = [mouse_id for mouse_id in pop_vectors_dict.keys() if metadata[mouse_id]['reward_group']=='R+']
unrewarded_mice = [mouse_id for mouse_id in pop_vectors_dict.keys() if metadata[mouse_id]['reward_group']=='R-']

mice = unrewarded_mice
if mice == rewarded_mice:
    palette = sns.color_palette([sns.color_palette('deep')[0], '#238443'])
else:
    palette = sns.color_palette([sns.color_palette('deep')[0], '#d51a1c'])

# Plot the same projection across time but for each individual mouse
projections_all = []

for mouse_id in mice:
    if cell_type_to_use == 'allcells':
        vectors = pop_vectors_dict[mouse_id]['allcells']
    elif cell_type_to_use == 'responsive':
        vectors = pop_vectors_dict[mouse_id]['allcells'][globally_responsive[mouse_id]['allcells'] < responsive_thr]
    elif cell_type_to_use == 'lmi':
        vectors = pop_vectors_dict[mouse_id]['allcells'][(lmi_p[mouse_id]['allcells'] >= 0.975) | (lmi_p[mouse_id]['allcells'] <= 0.025)]
        # vectors = pop_vectors_dict[mouse_id]['allcells'][(lmi_p[mouse_id]['allcells'] >= 0.975) ]
        # vectors = pop_vectors_dict[mouse_id]['allcells'][(lmi_p[mouse_id]['allcells'] <= 0.025)]



    block_edges = np.cumsum([45, 45, 40, 45, 40, 45, 40, 45])
    pre = vectors[:, :block_edges[1]]

    post1 = vectors[:, block_edges[4]:block_edges[5]]
    post2 = vectors[:, block_edges[6]:block_edges[7]]
    post = np.concatenate((post1, post2), axis=1)
    d0_learning = vectors[:, block_edges[1]:block_edges[2]]


    # Average the 2D arrays pre and post by group of 5 vectors along axis 1
    # avg_size = 3
    # pre = np.concatenate([pre, np.zeros((pre.shape[0], avg_size - pre.shape[1] % avg_size))], axis=1)
    # post = np.concatenate([post, np.zeros((post.shape[0], avg_size - post.shape[1] % avg_size))], axis=1)
    # d0_learning = np.concatenate([d0_learning, np.zeros((d0_learning.shape[0], avg_size - d0_learning.shape[1] % avg_size))], axis=1)
    # pre = pre.reshape(pre.shape[0], -1, avg_size).mean(axis=2)
    # post = post.reshape(post.shape[0], -1, avg_size).mean(axis=2) 
    # d0_learning = d0_learning.reshape(d0_learning.shape[0], -1, avg_size).mean(axis=2)

    # if zscore:
    #     pre = (pre - np.mean(pre, axis=1, keepdims=True)) / np.std(pre, axis=1, keepdims=True)
    #     post = (post - np.mean(post, axis=1, keepdims=True)) / np.std(post, axis=1, keepdims=True)
    #     pre = np.nan_to_num(pre)
    #     post = np.nan_to_num(post)
    #     d0_learning = (d0_learning - np.mean(d0_learning, axis=1, keepdims=True)) / np.std(d0_learning, axis=1, keepdims=True)
    #     d0_learning = np.nan_to_num(d0_learning)

    pre_vect = np.mean(pre, axis=1)
    post_vect = np.mean(post, axis=1)
    learning_direction = post_vect - pre_vect


    projections = []
    for i in range(pre.shape[1]):
        projection = np.dot(pre[:, i], learning_direction) / np.linalg.norm(learning_direction)
        projections.append(projection)
    for i in range(d0_learning.shape[1]):
        projection = np.dot(d0_learning[:, i], learning_direction) / np.linalg.norm(learning_direction)
        projections.append(projection)
    for i in range(post.shape[1]):
        projection = np.dot(post[:, i], learning_direction) / np.linalg.norm(learning_direction)
        projections.append(projection)

    projections = np.array(projections)

    projections_all.append(pd.DataFrame({
        'projection': projections,
        'trial': range(projections.shape[0]),
        'block': ['pre'] * pre.shape[1] + ['d0_learning'] * d0_learning.shape[1] + ['post'] * post.shape[1],
        'mouse_id': mouse_id
    }))

projections_all = pd.concat(projections_all)


# Plot the line plot with variance across mice
plt.figure(figsize=(10, 6))
sns.lineplot(data=projections_all, x='trial', y='projection', hue='block', errorbar='ci', palette=palette)
# edges = np.cumsum([45, 45, 40, 45, 45])
# plt.xticks(edges, edges)
plt.title('Projection across time with variance across mice')
sns.despine()
plt.show()


The palette list has fewer values (2) than needed (3) and will cycle, which may produce an uninterpretable plot.
  sns.lineplot(data=projections_all, x='trial', y='projection', hue='block', errorbar='ci', palette=palette)


Amplitude of the response during D0 for positively and negatively modulated cells.

In [130]:
cell_type_to_use = 'lmi'  # Change this to 'responsive' or 'lmi' as needed
zscore = True

rewarded_mice = [mouse_id for mouse_id in pop_vectors_dict.keys() if metadata[mouse_id]['reward_group']=='R+']
unrewarded_mice = [mouse_id for mouse_id in pop_vectors_dict.keys() if metadata[mouse_id]['reward_group']=='R-']

mice = rewarded_mice
if mice == rewarded_mice:
    palette = sns.color_palette(['#238443'])
else:
    palette = sns.color_palette(['#d51a1c'])

block_edges = np.cumsum([45, 45, 60, 45, 60, 45, 60, 45])
# Plot the same projection across time but for each individual mouse
amplitude = []

for mouse_id in mice:

    if cell_type_to_use == 'allcells':
        vectors = pop_vectors_dict[mouse_id]['allcells']
    elif cell_type_to_use == 'responsive':
        vectors = pop_vectors_dict[mouse_id]['allcells'][globally_responsive[mouse_id]['allcells'] < responsive_thr]
    elif cell_type_to_use == 'lmi':
        # vectors = pop_vectors_dict[mouse_id]['allcells'][(lmi_p[mouse_id]['allcells'] >= 0.975) | (lmi_p[mouse_id]['allcells'] <= 0.025)]
        # vectors = pop_vectors_dict[mouse_id]['allcells'][(lmi_p[mouse_id]['allcells'] >= 0.975) ]
        vectors = pop_vectors_dict[mouse_id]['allcells'][(lmi_p[mouse_id]['allcells'] <= 0.025)]

    # Average activity across cells.
    d0_learning = (vectors[:, block_edges[1]:block_edges[2]]).mean(axis=0)

    amplitude.append(pd.DataFrame({
    'amplitude': d0_learning,
    'trial': range(d0_learning.shape[0]),
    'mouse_id': mouse_id
    }))

amplitude = pd.concat(amplitude)

sns.lineplot(data=amplitude, x='trial', y='amplitude', palette=palette)
    



  sns.lineplot(data=amplitude, x='trial', y='amplitude', palette=palette)


<Axes: xlabel='trial', ylabel='amplitude'>

Same scatter plot for amplitude of the population for each trial.

In [78]:
block_edges = np.cumsum([45, 45, 30, 45, 30, 45, 30, 45])
# block_edges = np.cumsum([45, 45, 5, 45, 5, 45, 5, 45])
print(block_edges)

pre = rewarded_pop_vectors_wS2[:, :block_edges[1]]
post1 = rewarded_pop_vectors_wS2[:,block_edges[4]:block_edges[5]]
post2 = rewarded_pop_vectors_wS2[:,block_edges[6]:block_edges[7]]
post = np.concatenate((post1, post2), axis=1)
post_vect = np.mean(post, axis=1)
d0_learning = rewarded_pop_vectors_wS2[:,block_edges[1]:block_edges[2]]

amp = np.concatenate((np.mean(pre, axis=0), np.mean(d0_learning, axis=0), np.mean(post, axis=0)))

palette = sns.color_palette([sns.color_palette('deep')[0], '#238443'])
# palette = sns.color_palette([sns.color_palette('deep')[0], '#d51a1c'])
colors = [palette[0]] * pre.shape[1] + [palette[1]] * d0_learning.shape[1] + [palette[0]] * post.shape[1]
plt.figure()
data=pd.DataFrame({'amplitude': amp,
                   'trial': range(amp.shape[0]),
                   'block': ['pre'] * pre.shape[1] + ['d0_learning'] * d0_learning.shape[1] + ['post'] * post.shape[1]})
# sns.regplot(data=data, x='trial', y='amplitude', fit_reg=True, scatter=True, ci=None, order=3)

# from scipy.optimize import curve_fit
# # Define the sigmoid function
# def sigmoid(x, L, x0, k, b):
#     return L / (1 + np.exp(-k * (x - x0))) + b

# # Fit the sigmoid function to the data
# p0 = [max(data['amplitude']), np.median(data['trial']), 1, min(data['amplitude'])]  # Initial guess for the parameters
# params, _ = curve_fit(sigmoid, data['trial'], data['amplitude'], p0, method='dogbox')

# Plot the fitted sigmoid curve
plt.scatter(range(amp.shape[0]), amp, color=colors)
# x_fit = np.linspace(min(data['trial']), max(data['trial']), 200)
# y_fit = sigmoid(x_fit, *params)
# plt.plot(x_fit, y_fit, label='Sigmoid fit', color='red')


plt.title('Amplitude population response pretraining and D0 learning with post training -- R+')
edges = np.cumsum([45, 45, 30, 45, 45])
# edges = np.cumsum([45, 45, 5, 45, 45])
plt.xticks(edges, edges)
# plt.ylim([-0.4, 1])
sns.despine()


[ 45  90 120 165 195 240 270 315]


## 3.X Quantify correlation between first WH and pre post learning.

Correlation matrix for resposnive cells.

In [126]:
cell_selection = 'responsive'
responsiveness_thr = 0.001
percent_best_lmi = 15
sns.set_theme(context='paper', style='ticks', palette='deep', font='sans-serif', font_scale=1)

In [86]:
output_dir = fr'//sv-nas1.rcp.epfl.ch/Petersen-Lab/analysis/Anthony_Renard/analysis_output/sensory_plasticity/correlation_matrices'
pdf_file = f'correlation_first_wh_with_avg_pre_post_{win_length}_ms_cell_selection_{cell_selection}.pdf'

corr_pre = {}
corr_post = {}

mice = [mouse_id for mouse_id in mice if mouse_id != 'AR132']

for mouse_id in mice:

    corr_pre[mouse_id] = {}
    corr_post[mouse_id] = {}

    for cell_type in ['allcells', 'wS2', 'wM1']:

        if cell_type not in response_amp[mouse_id].keys():
            continue
        
        if cell_selection == 'no_selection':
            selected_cells = np.ones(response_amp[mouse_id][cell_type][0].shape[0], dtype=bool)
        elif cell_selection == 'responsive':
            selected_cells = globally_responsive[mouse_id][cell_type] <= responsiveness_thr
        elif cell_selection == 'lmi':
            lmi_thr = np.percentile(np.abs(np.concatenate([lmi[mouse_id]['allcells'] for mouse_id in mice])), 100-percent_best_lmi)
            selected_cells = lmi[mouse_id][cell_type] >= lmi_thr

        pre = np.concatenate((response_amp[mouse_id][cell_type][0][selected_cells],
                             response_amp[mouse_id][cell_type][1][selected_cells]),
                        axis=1)
        # pre = np.mean(pre, axis=1)
        post = np.concatenate((response_amp[mouse_id][cell_type][5][selected_cells],
                                response_amp[mouse_id][cell_type][7][selected_cells]),
                                axis=1)
        # post = np.mean(post, axis=1)
        first_trial = response_amp[mouse_id][cell_type][2][selected_cells, 0]
        # first_trial = np.mean(response_amp[mouse_id][cell_type][2][selected_cells, :5], axis=1)

        # Compute correlation.
        corr_pre[mouse_id][cell_type] = np.sum(np.corrcoef(first_trial, pre.T)[0, :])
        corr_post[mouse_id][cell_type] = np.sum(np.corrcoef(first_trial, post.T)[0, :])


# Convert to pandas.
# ------------------

df_corr = []
for mouse_id in mice:
    for cell_type in corr_pre[mouse_id].keys():
        for comp in ['pre', 'post']:
            if comp == 'pre':
                corr = corr_pre[mouse_id][cell_type]
            else:
                corr = corr_post[mouse_id][cell_type]
            temp = pd.DataFrame([[corr, comp, cell_type, mouse_id, metadata[mouse_id]['reward_group']]],
                            columns=['corr', 'training_phase', 'cell_type', 'mouse_id', 'reward_group'])
            df_corr.append(temp)
df_corr = pd.concat(df_corr, ignore_index=True)



# Save plot and stats.
# --------------------

with PdfPages(os.path.join(output_dir, pdf_file)) as pdf:
    
    palette = sns.color_palette(['#238443', '#d51a1c'])
    for ct in ['allcells', 'wS2', 'wM1']:
        print(f'Cell type {ct}')
        plt.figure()
        sns.barplot(data=df_corr.loc[df_corr.cell_type==ct], x='training_phase', y='corr', hue='reward_group', palette=palette, hue_order=['R+', 'R-'])
        sns.despine()
        plt.title(f'Correlation between first trial and average pre/post training - {ct}')
        pdf.savefig(dpi=300)
        # plt.close()

        # Perform Mann-Whitney U test to check if the difference between the two reward groups is significant for each day.
        p_values = []
        for reward_group in ['R+', 'R-']:
            pre = df_corr[(df_corr['training_phase'] == 'pre') & (df_corr['reward_group'] == reward_group) & (df_corr.cell_type==ct)]['corr']
            pre = pre[~np.isnan(pre)]
            post = df_corr[(df_corr['training_phase'] == 'post') & (df_corr['reward_group'] == reward_group) & (df_corr.cell_type==ct)]['corr']
            post = post[~np.isnan(post)]
            stat, p = mannwhitneyu(pre, post)
            print(f'Cell type {ct} Reward group {reward_group} Comp {comp}: p-value = {p}')
            # p_values.append(p)
        # # Add p-values to the dataframe for visualization
        # df_p_values = pd.DataFrame({'comp': ['pre', 'post'], 'p_value': p_values, 'cell_type': ct})
        # print(df_p_values)
        # df_p_values.to_csv(os.path.join(output_dir, f'correlation_first_wh_with_avg_pre_post_{win_length}_ms_{cell_selection}.csv'), index=False)


  avg = a.mean(axis, **keepdims_kw)
  ret = um.true_divide(
  c = cov(x, y, rowvar, dtype=dtype)
  c *= np.true_divide(1, fact)
  c *= np.true_divide(1, fact)


Cell type allcells
Cell type allcells Reward group R+ Comp post: p-value = 0.3953072503151053
Cell type allcells Reward group R- Comp post: p-value = 0.7074539677020747
Cell type wS2
Cell type wS2 Reward group R+ Comp post: p-value = 1.0
Cell type wS2 Reward group R- Comp post: p-value = 0.06495726495726495
Cell type wM1
Cell type wM1 Reward group R+ Comp post: p-value = 0.8852339144732017
Cell type wM1 Reward group R- Comp post: p-value = 0.30952380952380953


In [270]:
dfs = []
block_labels = [f'block_{i}' for i in range(1, 8)]

df = []
for mouse_id in mouse_ids:
    for cell_type in corr_avg_days[mouse_id].keys():
        if pop_vectors_dict[mouse_id][cell_type].shape[0] < 5:
            continue
        trial_boundaries = np.cumsum([0] + n_trials[mouse_id])
        post_training = np.mean(pop_vectors_dict[mouse_id][cell_type][:, trial_boundaries[-2]:], axis=1, keepdims=True)
        corr = np.corrcoef(pop_vectors_dict[mouse_id][cell_type], post_training, rowvar=False)[-1, :-1]
        
        # blocks = [i for i in range(0, 8) for _ in range(trial_boundaries[i],trial_boundaries[i+1])]
        # trial_id_in_blocks = np.concat([np.arange(0, n_trials[mouse_id][i]) for i in range(8)])
        # block_trial_id = [(block, trial) for block, trial in zip(blocks, trial_id_in_blocks)]
        trial_ids = np.arange(corr.shape[0])

        # multi_index = pd.MultiIndex.from_tuples([('block', 'trial_id')], names=['level_1', 'level_2'])
        df.append(pd.DataFrame([[c, i, cell_type, mouse_id, metadata[mouse_id]['reward_group']] for c, i in zip(corr, trial_ids)],
                            columns=['correlation', 'trial', 'cell_type', 'mouse_id', 'reward_group']))
df = pd.concat(df, ignore_index=True)



In [271]:
sns.pointplot(data=df.loc[df.cell_type=='allcells'], x='trial', y='correlation', linestyles='none', errorbar=None)
plt.ylim([-1,1])
ax = plt.gca()
ax.set_xticks(range(0,280,20))

[<matplotlib.axis.XTick at 0x1fd779938c0>,
 <matplotlib.axis.XTick at 0x1fd5a1a7e60>,
 <matplotlib.axis.XTick at 0x1fd47bbafc0>,
 <matplotlib.axis.XTick at 0x1fd5d194110>,
 <matplotlib.axis.XTick at 0x1fd5d177320>,
 <matplotlib.axis.XTick at 0x1fd5d177e60>,
 <matplotlib.axis.XTick at 0x1fd5d1646b0>,
 <matplotlib.axis.XTick at 0x1fd5d165190>,
 <matplotlib.axis.XTick at 0x1fd5d194470>,
 <matplotlib.axis.XTick at 0x1fd5d165490>,
 <matplotlib.axis.XTick at 0x1fd5d166240>,
 <matplotlib.axis.XTick at 0x1fd5d166ae0>,
 <matplotlib.axis.XTick at 0x1fd5d167590>,
 <matplotlib.axis.XTick at 0x1fd5d167230>]

In [268]:
pop_vectors_dict[mouse_id][cell_type].shape

(1, 260)

In [23]:
pop_vectors_dict[mouse_id][cell_type].shape

(133, 315)

In [None]:
n_blocks = 8
mapping_block = [0, 1, 3, 5, 7]
learning_block = [2, 4, 6]

sns.set_theme(context='paper', style='ticks', palette='deep')
palette = sns.color_palette()
f, axes = plt.subplots(2, 1, figsize=(15, 6))

for i in range(n_blocks):
    if i in learning_block:
        # color = '#238443'
        color = 'red'

    else:
        color = '#eea429ff'
    axes[0].scatter(range(trial_boundaries[i], trial_boundaries[i+1]),
                    correlations[trial_boundaries[i]:trial_boundaries[i+1]],
                    color=color)
axes[0].set_ylim(-1, 1)
# if apply_pca:
#     plt.title('Correlation\n' \
#               f'mice {mouse_list} ' \
#               f'variance retained: {variance_to_retain}')
# else:   
plt.title('Correlation\n' \
            f'mice {mouse_id} ' \
            'full data (no dim reduction)')

behav_table = nwb_read.get_trial_table(nwb_files[2])
behav_table = compute_performance(behav_table, session_list[2], db_path)

palette = sns.color_palette()
plot_single_session(behav_table, session_list[2], axes[1])


# 4 Functional maps across learning days

- amplitude of response
- significance levels (p-value maps)
- LMI

In [16]:
# Load data needed to compute before and after learning.

sampling_rate = 30
win = (1, 1.3)  # from stimulus onset to 300 ms after.
win = (int(win[0] * sampling_rate), int(win[1] * sampling_rate))
baseline_win = (0, 1)
baseline_win = (int(baseline_win[0] * sampling_rate), int(baseline_win[1] * sampling_rate))
reward_group = 'R-'

_, _, mice, _ = io.select_sessions_from_db(db_path,
                                            nwb_dir,
                                            two_p_imaging='yes',
                                            reward_group=reward_group)
print(mice)

['GF319', 'GF348', 'GF350', 'MI062', 'MI069', 'MI072', 'MI075', 'MI076', 'AR132', 'AR137', 'AR139', 'AR131']


In [17]:
output_dir = r'//sv-nas1.rcp.epfl.ch/Petersen-Lab/analysis/Anthony_Renard/analysis_output/functional_maps'
pdf_file = f'functional_maps_{reward_group}.pdf'
with PdfPages(os.path.join(output_dir, pdf_file)) as pdf:
    for mouse_id in mice:
        print(mouse_id)
        session_list, nwb_files, _, db_filtered = io.select_sessions_from_db(db_path,
                                                                            nwb_dir,
                                                                            two_p_imaging='yes',
                                                                            day=days,
                                                                            subject_id=mouse_id)
        print(session_list)
        data = []
        for session_id in session_list:
            arr, metadata = imaging_utils.load_session_2p_imaging(mouse_id,
                                                                session_id,
                                                                processed_dir)
            arr = imaging_utils.substract_baseline(arr, 3, baseline_win)
            data.append(arr)

        # Select UM trials.
        data = [arr[:, -1] for arr in data]
        # Remove trials with NaNs.
        data = [arr[:, ~np.isnan(arr).all(axis=(0,2))] for arr in data]

        # Load image masks.
        roi_masks = nwb_read.get_image_mask(nwb_files[0])
        roi_masks = np.stack(roi_masks, axis=0)
        
        # Compute significance map.
        # -------------------------
        
        # Compute average response and baseline for each trial, each day.
        baseline_avg = []
        response_avg = []
        for day in data:
            baseline_avg.append(np.nanmean(day[:, :, baseline_win[0]:baseline_win[1]], axis=2))
            response_avg.append(np.nanmean(day[:, :, win[0]:win[1]], axis=2))

        # Compare response amplitude to baseline.
        n_cells = data[0].shape[0]
        p_values = [np.zeros(n_cells) for _ in range(len(data))]
        for iday, day in enumerate(data):
            for icell in range(n_cells):
                _, p_values[iday][icell] = wilcoxon(baseline_avg[iday][icell], response_avg[iday][icell])
        p_values = np.stack(p_values, axis=0)

        # Categories p-values.
        p_values_masks = np.copy(p_values)
        p_values_masks[p_values>0.05] = 1
        p_values_masks[p_values<=0.05] = 2
        p_values_masks[p_values<=0.01] = 3
        p_values_masks[p_values<=0.001] = 4
        
        map_significance = []
        for iday in range(5):
            maps = roi_masks * p_values_masks[iday, :, None, None]
            map_significance.append(np.max(maps, axis=0))
            

        # Compute amplitude map.
        # ----------------------
        
        # Compute average response amplitude for each cell.
        response_amplitude = []
        for day in response_avg:
            response_amplitude.append(np.nanmean(day, axis=1))
        response_amplitude = np.stack(response_amplitude, axis=0)
        
        map_amplitude = []
        for iday in range(5):
            maps = roi_masks * response_amplitude[iday, :, None, None]
            map_amplitude.append(np.max(maps, axis=0))
    

        # Plot maps.
        # ----------
        
        f, axes = plt.subplots(2,5, figsize=(20, 8), sharex=True, sharey=True)
        
        # Plot amplitude maps.
        cmap = sns.color_palette("viridis", as_cmap=True)
        # vmin = np.nanmin(response_amplitude)
        vmin = 0
        vmax = np.percentile(response_amplitude, 98)
        
        for iday in range(5):
            a = axes[0,iday].imshow(map_amplitude[iday],
                                interpolation='nearest',
                                cmap=cmap,
                                vmin=vmin, vmax=vmax)
        cbar_ax = f.add_axes([.91,.124,.04,.754])
        f.colorbar(a, cax=cbar_ax, location='right')
        
        
        cmap = ['white', '#d9d9d9', '#fdbb84', '#ef6548', '#990000']
        cmap = colors.ListedColormap(cmap)
        bounds = range(cmap.N+1)
        norm = colors.BoundaryNorm(bounds, cmap.N)
        # Plot responsivity maps.
        for iday in range(5):
            axes[1, iday].imshow(map_significance[iday], cmap=cmap, norm=norm, interpolation='nearest')
            # axes[iday].imshow(map_significance[iday])
            
        plt.suptitle(mouse_id)
        pdf.savefig()
        plt.close()




GF319
['GF319_24122020_120204', 'GF319_25122020_142951', 'GF319_26122020_144746', 'GF319_27122020_135842', 'GF319_28122020_132438']


  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."


Found image mask in plane segmentation
GF348
['GF348_29052021_100151', 'GF348_30052021_110107', 'GF348_31052021_102411', 'GF348_01062021_095758', 'GF348_02062021_084344']


  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."


Found image mask in plane segmentation
GF350
['GF350_29052021_124022', 'GF350_30052021_123155', 'GF350_31052021_135001', 'GF350_01062021_122420', 'GF350_02062021_142138']


  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."


Found image mask in plane segmentation
MI062
['MI062_30092021_091006', 'MI062_01102021_091233', 'MI062_02102021_105027', 'MI062_03102021_103851', 'MI062_04102021_092339']


  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."


Found image mask in plane segmentation
MI069
['MI069_19122021_100830', 'MI069_20122021_095058', 'MI069_21122021_090648', 'MI069_22122021_090212', 'MI069_23122021_085758']


  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."


Found image mask in plane segmentation
MI072
['MI072_19122021_140553', 'MI072_20122021_125805', 'MI072_21122021_132704', 'MI072_22122021_132651', 'MI072_23122021_132111']


  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."


Found image mask in plane segmentation
MI075
['MI075_19122021_152533', 'MI075_20122021_155245', 'MI075_21122021_151949', 'MI075_22122021_152806', 'MI075_23122021_150004']


  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."


Found image mask in plane segmentation
MI076
['MI076_19122021_120004', 'MI076_20122021_113038', 'MI076_21122021_112146', 'MI076_22122021_114039', 'MI076_23122021_113818']


  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."


Found image mask in plane segmentation
AR132
['AR132_20240424_112338', 'AR132_20240425_102625', 'AR132_20240426_093953', 'AR132_20240427_122605', 'AR132_20240428_122206']


  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."


Found image mask in plane segmentation
AR137
['AR137_20240424_172627', 'AR137_20240425_170755', 'AR137_20240426_152510', 'AR137_20240427_171535', 'AR137_20240428_163224']


  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("%s '%s': Length of data does not match length of timestamps. Your data may be transposed. "
  warn("%s '%s': Length of data does not match length of timestamps. Your data may be transposed. "
  warn("%s '%s': Length of data does not match length of timestamps. Your data may be transposed. "
  warn("%s '%s': Length of data does not match length of timestamps. Your data may be transposed. "


Found image mask in plane segmentation
AR139
['AR139_20240424_185913', 'AR139_20240425_181627', 'AR139_20240426_165725', 'AR139_20240427_183701', 'AR139_20240428_180459']


  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."


Found image mask in plane segmentation
AR131
['AR131_20240301_145952', 'AR131_20240302_123034', 'AR131_20240303_171032', 'AR131_20240304_133332', 'AR131_20240305_140141']


  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."


Found image mask in plane segmentation


# 5 Decoding

In [536]:
# Load data needed to compute before and after learning.

sampling_rate = 30
win = (1, 1.180)  # from stimulus onset to 300 ms after.
win = (int(win[0] * sampling_rate), int(win[1] * sampling_rate))
baseline_win = (0, 1)
baseline_win = (int(baseline_win[0] * sampling_rate), int(baseline_win[1] * sampling_rate))
reward_group = 'R+'
plot_save_figs = False
days = ['-2', '-1', '0', '+1', '+2']
wh_trial_type = 'WH'

_, _, mice, _ = io.select_sessions_from_db(db_path,
                                            nwb_dir,
                                            two_p_imaging='yes',
                                            reward_group=reward_group)
print(mice)


['GF305', 'GF306', 'GF307', 'GF308', 'GF310', 'GF311', 'GF313', 'GF314', 'GF317', 'GF318', 'GF323', 'GF333', 'GF334', 'AR133', 'AR135', 'AR127', 'AR143', 'AR144']


In [537]:
corr_avg_days = {}
corr_avg_pre_post = {}
metadata = {}
pop_vectors_dict = {}
lmi = {}

# Disregard these mice as the number of trials is too low.
mice =  [mouse for mouse in mice if mouse not in ['GF307', 'GF310', 'GF333', 'MI075', 'AR144', 'AR135']]

for mouse_id in mice:
    output_dir = fr'//sv-nas1.rcp.epfl.ch/Petersen-Lab/analysis/Anthony_Renard/analysis_output/mice/{mouse_id}'
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    session_list, nwb_files, _, db_filtered = io.select_sessions_from_db(db_path,
                                                                        nwb_dir,
                                                                        two_p_imaging='yes',
                                                                        subject_id=mouse_id,
                                                                        day=days,)
    print(session_list)
    
    data = []
    mdata_list = []
    for session_id in session_list:
        arr, mdata = imaging_utils.load_session_2p_imaging(mouse_id,
                                                            session_id,
                                                            processed_dir)
        arr = imaging_utils.substract_baseline(arr, 3, baseline_win)
        data.append(arr)
        mdata_list.append(mdata)

    reward_group = io.get_reward_group_from_db(db_path, session_list[0])
    metadata[mouse_id] = {}
    metadata[mouse_id]['reward_group'] = reward_group
    
    # Extract UM trials.
    for i, arr in enumerate(data):
        arr = imaging_utils.extract_trials(arr, mdata_list[i], 'UM', n_trials=45)
        data[i] = arr

    corr_avg_days[mouse_id] = {}
    corr_avg_pre_post[mouse_id] = {}
    pop_vectors_dict[mouse_id] = {}
    
    for cell_type in ['allcells', 'wS2', 'wM1']:
        # Select cell type.
        if cell_type == 'allcells':
            data_subtype = data
        else:
            data_subtype = []
            cell_type_mask = mdata_list[0]['cell_types']==cell_type
            data_subtype = [arr[cell_type_mask] for arr in data]

        # if cell_type == 'allcells':  
        #     # Example with and without strong cells for mouse AR127.
        #     strong_cells = [3,11,33,48,57,67,80,86,104,153,166,175]
        #     mask = np.ones(data_subtype[0].shape[0], dtype=bool)
        #     mask[strong_cells] = False
        #     data_subtype = [arr[mask] for arr in data_subtype]

        # If no cells of the specified type, skip.
        if data_subtype[0].shape[0] == 0:
            continue

        # Compute average response for each trial, each day.
        
        response_avg = []
        for day in data_subtype:
            response_avg.append(np.nanmean(day[:, :, win[0]:win[1]], axis=2))

        pop_vectors = np.concatenate(response_avg, axis=1)
        pop_vectors_dict[mouse_id][cell_type] = pop_vectors

                # Compute LMI.
        if cell_type == 'allcells':
            # pre = np.mean(np.concatenate(response_avg[0:2], axis=1), axis=1)
            # print(pre.shape)
            # post = np.mean(np.concatenate((response_avg[5], response_avg[7]), axis=1), axis=1)
            # lmi[mouse_id] = (post - pre) / (np.abs(post) + np.abs(pre))
            lmis = []
            for icell in range(pop_vectors.shape[0]):
                # mapping trials of D-2, D-1, D+1, D+2.
                X = np.r_[response_avg[0][icell],
                          response_avg[1][icell],
                          response_avg[3][icell],
                          response_avg[4][icell]]
                y = np.r_[np.zeros(response_avg[0][icell].shape[0]),
                          np.zeros(response_avg[1][icell].shape[0]),
                          np.ones(response_avg[3][icell].shape[0]),
                          np.ones(response_avg[4][icell].shape[0])]
                fpr, tpr, _ = roc_curve(y, X)
                roc_auc = auc(fpr, tpr)
                lmis.append((roc_auc - 0.5) * 2)
            lmi[mouse_id] = np.array(lmis)

['GF305_27112020_083119', 'GF305_28112020_103938', 'GF305_29112020_103331', 'GF305_30112020_110255', 'GF305_02122020_132229']
['GF306_27112020_104436', 'GF306_28112020_125555', 'GF306_29112020_131929', 'GF306_30112020_133249', 'GF306_02122020_161611']
['GF308_17112020_105052', 'GF308_18112020_093627', 'GF308_19112020_103527', 'GF308_20112020_122826', 'GF308_21112020_135515']
['GF311_17112020_155501', 'GF311_18112020_151838', 'GF311_19112020_160412', 'GF311_20112020_171609', 'GF311_21112020_180049']
['GF313_27112020_141857', 'GF313_28112020_154236', 'GF313_29112020_154625', 'GF313_30112020_154904', 'GF313_03122020_082147']
['GF314_27112020_160459', 'GF314_28112020_171800', 'GF314_29112020_174831', 'GF314_30112020_171906', 'GF314_03122020_102249']
['GF317_15122020_081931', 'GF317_16122020_082007', 'GF317_17122020_080715', 'GF317_18122020_104834', 'GF317_20122020_120604']
['GF318_15122020_095616', 'GF318_16122020_095516', 'GF318_17122020_144100', 'GF318_18122020_132105', 'GF318_19122020_1

In [538]:
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.model_selection import KFold



# Compute the LMI thresholds for the top 5% most modulated cells and bottom 5% least modulated cells
percent_to_keep = range(5, 105, 5)

# Initialize a list to store accuracy for each mouse
accuracy_list = []
for percent in percent_to_keep:
    pop_vectors_lmi = {}

    for mouse_id in mice:
        lmi_threshold = np.percentile(np.abs(lmi[mouse_id]), percent)
        temp = pop_vectors_dict[mouse_id]['allcells'][(np.abs(lmi[mouse_id]) <= lmi_threshold)]
        pop_vectors_lmi[mouse_id] = temp


    # Loop through each mouse in pop_vectors_dict
    for mouse_id  in pop_vectors_lmi.keys():
        # Prepare data for SVM
        pre = pop_vectors_lmi[mouse_id][:, :90]  # First 90 trials
        post = pop_vectors_lmi[mouse_id][:, -90:]  # Last 90 trials
        X = np.concatenate([pre, post], axis=1)
        y = np.concatenate([np.zeros(90), np.ones(90)])  # First 90 trials labeled as 0, last 90 trials labeled as 1

        # Z-score the data
        X = (X - np.mean(X, axis=1, keepdims=True)) / np.std(X, axis=1, keepdims=True)

        # Initialize KFold with the desired number of splits
        kf = KFold(n_splits=5, shuffle=True, random_state=42)

        # List to store accuracy for each fold
        fold_accuracies = []

        # Loop through each fold
        for train_index, test_index in kf.split(X.T):
            X_train, X_test = X.T[train_index], X.T[test_index] 
            y_train, y_test = y[train_index], y[test_index]

            # Train SVM classifier
            svm = SVC(kernel='linear')
            svm.fit(X_train, y_train)

            # Predict and evaluate
            y_pred = svm.predict(X_test)
            accuracy = accuracy_score(y_test, y_pred)
            fold_accuracies.append(accuracy)
        
        # Calculate the average accuracy for the folds
        avg_fold_accuracy = np.mean(fold_accuracies)
        accuracy_list.append({'mouse_id': mouse_id, 'accuracy': avg_fold_accuracy, 'percent': 100-percent})

# Convert accuracy list to DataFrame
df_accuracy = pd.DataFrame(accuracy_list)



In [539]:

sns.set_theme(context='paper', style='ticks', palette='deep', font='sans-serif', font_scale=1)
# Plot the average accuracy across mice as a function of percent
plt.figure(figsize=(10, 6))
sns.lineplot(data=df_accuracy, x='percent', y='accuracy', marker='o')
plt.xlabel('Percent of cell removed')
plt.ylabel('Average Accuracy')
plt.title('Average Accuracy Across Mice as a Function of Percent cells removed')
sns.despine()


In [519]:
percent_to_keep

range(5, 100, 5)

In [None]:
top