In [None]:
# print(plt.style.available)
# matplotlib.use('Agg')
# warnings.filterwarnings("ignore", category=UserWarning)

Nous cherchons à savoir comment le système visuel traite les informations visuelles pour les transformer en une représentation interne de l'environnement externe, en prenant en compte les informations contextuelles que sont les mouvements oculaires et de la tête.

Le système vestibulaire implique de façon précoce les noyaux vestibulaires et les noyaux cérébelleux profonds (DCN). Par injection de virus transsynaptiques, on constate que les DCN, notamment, sont à l'origine d'un certain nombre de projections axonales vers une région thalamique : le pulvinar.

Or, on sait que les neurones du pulvinar ont une activité modulée par les saccades.

La question sous-tendue est la suivante : les neurones du pulvinar sont-ils modulés par les mouvements de tête ?

Les neurones du pulvinar sont effectivement modulés par les rotations de la tête.

Comment se distribuent-ils ?

Nous voyons que la modulation est majoritairement observée pour une rotation CW, dans le pulvinar de l'hémisphère gauche. 

La question suivante est : que préfèrent les neurones, lorsqu'une modulation existe ?

# Importer les modules et fonctions

In [None]:
from MyModule import *

# Paths

In [None]:
# Resolve the working directory and its parent; both are echoed as a quick sanity check.
current_dir = os.getcwd()
parent_dir = os.path.dirname(current_dir)
print(current_dir, parent_dir, sep='\n')

# The census workbook and the output folder both sit in the parent directory.
census_path, saving_path = (
    os.path.join(parent_dir, leaf) for leaf in ('census.xlsx', 'Analyses')
)

In [None]:
%matplotlib qt

In [None]:
# Matplotlib warns once more than this many figures are open at the same time.
plt.rcParams['figure.max_open_warning'] = 20

# Importer les données Excel

In [None]:
# Load the 'Study' sheet of the census workbook into a DataFrame.
df_data = pd.read_excel(io=census_path, sheet_name='Study')

# Exécuter les scripts et charger les données

In [None]:
# Select the Qt backend so that figures open in separate interactive windows.
%matplotlib qt
# Silence every UserWarning raised from this point on (the filter is global to the session).
warnings.filterwarnings("ignore", category=UserWarning)
# Execute the companion notebook; the names it defines become available in this session.
# NOTE(review): presumably this is what defines `rotation_data` used by the cells below - confirm.
%run multi_animal_function.ipynb

# Trier les neurones en fonction de leur modulation

In [None]:
from itertools import chain

# A unit's label is its selectivity string followed by its modulation-type string.
# The lists below spell out which labels count as a suppressed / excited response for
# each rotation direction. For 'both' units the first sign is matched against CW and
# the second against CCW, as the lists show.
response_labels = {
    'CWsuppressed': ['both-+', 'both--', 'CW-'],
    'CWexcited': ['both++', 'both+-', 'CW+'],
    'CCWsuppressed': ['both+-', 'both--', 'CCW-'],
    'CCWexcited': ['both++', 'both-+', 'CCW+'],
}
flags_per_animal = {name: [] for name in response_labels}

for animal in rotation_data:
    modulationSelectivity = rotation_data[animal]['modulation']['selectivity']
    modulationType = rotation_data[animal]['modulation']['type']
    unit_labels = [x + y for x, y in zip(modulationSelectivity, modulationType)]
    for name, accepted in response_labels.items():
        flags_per_animal[name].append([label in accepted for label in unit_labels])

# Flatten the per-animal lists into one boolean array per response class
# (units ordered animal by animal, in the iteration order of rotation_data).
CWsuppressed = np.array(list(chain.from_iterable(flags_per_animal['CWsuppressed'])))
CWexcited = np.array(list(chain.from_iterable(flags_per_animal['CWexcited'])))
CCWsuppressed = np.array(list(chain.from_iterable(flags_per_animal['CCWsuppressed'])))
CCWexcited = np.array(list(chain.from_iterable(flags_per_animal['CCWexcited'])))

# A unit is modulated in a direction when it is either excited or suppressed by it.
CWmodulated = CWexcited | CWsuppressed
CCWmodulated = CCWexcited | CCWsuppressed

# Afficher la modulation moyenne de chaque type de réponse, et trier les neurones en fonction de leur FR baseline

In [None]:
%matplotlib qt

# Population PSTHs, one panel per rotation direction: mean z-scored firing rate of the
# excited (red) and suppressed (blue) units, with the scaled table velocity in gray.
# Side product: `tri_boolean`, True for units whose baseline firing rate exceeds the
# threshold below (computed on the last condition processed, i.e. 'CCW').
binResolution = 0.05          # PSTH bin width, in the unit of `duration` (x axis is labelled in s)
BASELINE_FR_THRESHOLD = 0.4   # baseline firing rate a unit must exceed to be kept

fig, axes = plt.subplots(1, 2, figsize=(16, 7))

for ax, condition, Suppression, Excitation in zip(axes, ['CW', 'CCW'], [CWsuppressed, CCWsuppressed], [CWexcited, CCWexcited]):
    zscore_all = []
    std_all = []
    baseMean_all = []

    for animal in rotation_data:
        timeObject = rotation_data[animal]['duration']
        # The bin edges only depend on the animal's time axis: build them once per animal.
        edges = np.arange(timeObject[0], round(timeObject[-1]) + binResolution, binResolution)

        for neuron in range(len(rotation_data[animal]['SpikeTimes'][condition])):
            StudiedSpikeTimes = rotation_data[animal]['SpikeTimes'][condition][neuron]

            # Per-trial firing rate = spike count per bin divided by the bin width.
            FR_per_trial = []
            for trial in range(len(StudiedSpikeTimes)):
                spike_count, _ = np.histogram(StudiedSpikeTimes[trial], bins=edges)
                FR_per_trial.append(spike_count / binResolution)

            mean_trial = np.mean(FR_per_trial, axis=0)

            # Baseline statistics come from the bins whose left edge lies before time 0.
            baseline = mean_trial[edges[:-1] < 0]
            baseMean = np.mean(baseline)
            stdMean = np.std(baseline)

            # Z-score against the baseline: divide by the standard deviation itself
            # (the previous code divided by its square root). A unit with a perfectly
            # flat baseline gets an all-zero trace instead of inf/NaN values, which
            # would otherwise contaminate the population mean.
            if stdMean == 0:
                zscore = np.zeros(len(mean_trial))
            else:
                zscore = (mean_trial - baseMean) / stdMean

            std_all.append(stdMean)
            zscore_all.append(zscore)
            baseMean_all.append(baseMean)

    zscore_all = np.array(zscore_all)

    # Excited units first, then suppressed ones: `mean_zscore_all` therefore ends up
    # holding the suppressed mean, which is what scales the velocity trace below.
    for group_mask, group_color in ((Excitation, 'red'), (Suppression, 'blue')):
        group = zscore_all[group_mask]
        mean_zscore_all = np.mean(group, axis=0)
        sem = np.std(group, axis=0) / np.sqrt(len(group))

        ax.plot(edges[:-1], mean_zscore_all, color=group_color)
        ax.fill_between(edges[:-1], mean_zscore_all - sem, mean_zscore_all + sem, alpha=0.25, color=group_color)

    # Table velocity of the last animal iterated, normalised by its rotation speed and
    # rescaled to the amplitude of the suppressed mean so that it fits on the same axis.
    ax.plot(rotation_data[animal]['duration'], rotation_data[animal]['MeanRotation'][condition]/rotation_data[animal]['rotationSpeed']*max(abs(mean_zscore_all)), color='gray')

    if condition == 'CW':
        # Only the left panel carries the y label.
        ax.set_ylabel('Zscore FR')
    ax.set_xlabel('Time (s)')
    ax.set_title(condition)

    tri_boolean = np.array(baseMean_all) > BASELINE_FR_THRESHOLD

fig.suptitle('Mean excitation and suppression responses')
plt.show()

# Neurones exclus

In [None]:
# Total number of units, summed over every animal.
total_units = np.sum([rotation_data[animal]['Nclust'] for animal in rotation_data])
print(f'Nombre total de neurones : {total_units}')

In [None]:
# tri_boolean is True for the units that are kept, so the excluded ones are its complement.
excluded_count = np.count_nonzero(~tri_boolean)
print(f"{excluded_count} neurones exclus, sur un total de {len(tri_boolean)}")

In [None]:
# Show the CW raster (with PSTH) of every unit rejected by `tri_boolean`.
# The flat index runs over animals, then over the units of each animal.
unit_ids = [
    (animal, neuron)
    for animal in rotation_data
    for neuron in range(rotation_data[animal]['Nclust'])
]

for i, (animal, neuron) in enumerate(unit_ids):
    if tri_boolean[i]:
        continue  # unit is kept: nothing to show
    plt.figure()
    plotRaster(rotation_data, animal, 'CW', neuron, psth=True)
    plt.show()

# Désigner les neurones dont la vMI est particulière

In [None]:
# Search for remarkable units: largest, smallest and closest-to-zero vMI among the
# units kept by `tri_boolean`, separately for each rotation direction.
max_vMI, min_vMI, nul_vMI = {}, {}, {}

for direction in ['CW', 'CCW']:
    kept_vMI = np.concatenate([rotation_data[animal]['vMI'][direction] for animal in rotation_data])[tri_boolean]
    max_vMI[direction] = max(kept_vMI)
    min_vMI[direction] = min(kept_vMI)
    nul_vMI[direction] = min(np.abs(kept_vMI))

# Illustration des réponses

In [None]:
# Hand-picked example units, each stored as [session identifier, unit index].
# NOTE(review): presumably the units with the largest / smallest / closest-to-zero vMI
# found by the search cell above - confirm they still match the current data.
unit_max_CW, unit_max_CCW = ['animal21_a45d1s2', 10], ['animal21_a21d1s1', 24]
unit_min_CW, unit_min_CCW = ['animal21_a17d1s1', 23], ['animal21_a21d1s1', 2]
unit_nul_CW, unit_nul_CCW = ['animal21_a53d1s1', 36], ['animal21_a53d1s1', 36]

# Each group is ordered [CW example, CCW example].
unit_max, unit_min, unit_nul = (
    [unit_max_CW, unit_max_CCW],
    [unit_min_CW, unit_min_CCW],
    [unit_nul_CW, unit_nul_CCW],
)

In [None]:
# 3 x 2 grid of rasters: one row per example group (max, min, nul), one column per
# rotation direction (CW on the left, CCW on the right).
plt.figure()

panels = [
    (one_targeted, condition)
    for targeted in [unit_max, unit_min, unit_nul]
    for one_targeted, condition in zip(targeted, ['CW', 'CCW'])
]

for i, (one_targeted, condition) in enumerate(panels, start=1):
    plt.subplot(3, 2, i)
    color = 'cyan' if condition != 'CW' else 'magenta'
    # Only the bottom row (panels 5 and 6) gets an x label, only the left column a y label.
    xlabel = 'Time (s)' if i >= 5 else ''
    ylabel = '# Trial' if i % 2 == 1 else ''
    title = f'unit {one_targeted[1]} of {one_targeted[0]}'
    plotRaster(
        rotation_data, one_targeted[0], condition, one_targeted[1],
        psth=True, psthcolor='black', shadedcolor='k',
        show=False,
        plotvelocity=True,
        velocitycolor=color, color=color,
        title=title, xlabel=xlabel, ylabel=ylabel,
    )

plt.tight_layout()
plt.show()

In [None]:
import seaborn as sns


# For each rotation direction and each stereotaxic axis: scatter of vMI against the
# (jittered) recording position, next to the relative density of excited and suppressed
# units along that axis.
for condition, excited, suppressed in zip(['CW', 'CCW'], [CWexcited, CCWexcited], [CWsuppressed, CCWsuppressed]):
    unit_vMI = np.concatenate([rotation_data[animal]['vMI'][condition] for animal in rotation_data])
    not_modulated = ~suppressed & ~excited

    for position_label, position in zip(['ML_pos', 'AP_pos'], ['Mediolateral', 'Anteroposterior']):
        # Two panels sharing the same y axis (position).
        fig, (ax1, ax2) = plt.subplots(1, 2, sharey=True, figsize=(10, 6), gridspec_kw={'width_ratios': [3, 1]})

        # One position per unit (every unit of an animal shares that animal's coordinate),
        # plus a small Gaussian jitter so that the units of one site do not overlap.
        bruit = np.random.normal(0, 0.009, sum([rotation_data[animal]['Nclust'] for animal in rotation_data]))
        unit_position = np.concatenate([rotation_data[animal][position_label] * np.ones(rotation_data[animal]['Nclust']) for animal in rotation_data]) + bruit

        # Left panel: scatter plot, one colour per response class.
        for group_mask, group_color, group_name in (
            (excited, 'red', 'Excited'),
            (suppressed, 'blue', 'Suppressed'),
            (not_modulated, 'gray', 'Not Modulated'),
        ):
            ax1.scatter(unit_vMI[group_mask], unit_position[group_mask], color=group_color, alpha=0.5, label=group_name)

        ax1.axvline(0, ls='--', color='gray')

        site_positions = [rotation_data[animal][position_label] for animal in rotation_data]
        ax1.set_ylabel(f'{position} position (mm)')
        ax1.set_xlabel('vMI')
        ax1.set_title(f'vMI for {condition} rotation')
        ax1.set_ylim(min(site_positions) - 0.1, max(site_positions) + 0.1)

        # Right panel: density of each class relative to the density of all units.
        # The three histograms must share the SAME bin edges, otherwise bin i of a
        # subgroup does not cover the same positions as bin i of the whole population
        # and the bin-wise ratio is meaningless. The edges of the total histogram are
        # therefore reused for the subgroups.
        total_density, edges = np.histogram(unit_position, bins=30, density=True)
        excited_density, _ = np.histogram(unit_position[excited], bins=edges, density=True)
        suppressed_density, _ = np.histogram(unit_position[suppressed], bins=edges, density=True)

        # Normalise only where the population density is non-zero (avoids dividing by 0).
        populated = total_density != 0
        excited_density[populated] /= total_density[populated]
        suppressed_density[populated] /= total_density[populated]

        ax2.plot(excited_density, edges[:-1], color='red')
        ax2.plot(suppressed_density, edges[:-1], color='blue')

        ax2.set_xlabel('Density')

        plt.tight_layout()
        plt.show()

# vMI & dirMI

In [None]:
# ## Plot des indices de modulation -- intérêt limité, cf stats cellule du dessous.
# %matplotlib qt

# def modulation_figure(df_foo):
#     df_foo = pd.DataFrame(df_foo)   

#     bruit = np.random.normal(0, 0.009, len(df_foo['position'][tri_boolean]))

#     g = sns.jointplot(x=df_foo["position"][tri_boolean]+bruit, y=df_foo["depth"][tri_boolean]*-1, c=df_foo["modulation"][tri_boolean],
#                         alpha=0.5, s=100,
#                         joint_kws={"color":None, 'cmap':'coolwarm'},
#                         marginal_kws=dict(bins=50),
#                         marginal_ticks=True)

#     g.figure.set_size_inches(12, 8)
#     g.figure.colorbar(g.ax_joint.collections[0], ax=[g.ax_joint, g.ax_marg_y, g.ax_marg_x], use_gridspec=True, orientation='vertical')
#     g.figure.suptitle(df_foo["title"][0])
#     plt.show()






# # vMI CW

# df_foo = {
#     'position' : np.concatenate([np.full(rotation_data[animal]['Nclust'], rotation_data[animal]['ML_pos']) for animal in rotation_data]),
#     'modulation' : np.concatenate([rotation_data[animal]['vMI']['CW'] for animal in rotation_data]),
#     'depth' : np.concatenate([rotation_data[animal]['AllDepth'] for animal in rotation_data]),
#     'title' : 'vMI CW MedioLateral'
# }

# modulation_figure(df_foo)

# df_foo = {
#     'position' : np.concatenate([np.full(rotation_data[animal]['Nclust'], rotation_data[animal]['AP_pos']) for animal in rotation_data]),
#     'modulation' : np.concatenate([rotation_data[animal]['vMI']['CW'] for animal in rotation_data]),
#     'depth' : np.concatenate([rotation_data[animal]['AllDepth'] for animal in rotation_data]),
#     'title' : 'vMI CW AnteroPosterior'
# }

# modulation_figure(df_foo)






# ## vMI CCW

# df_foo = {
#     'position' : np.concatenate([np.full(rotation_data[animal]['Nclust'], rotation_data[animal]['ML_pos']) for animal in rotation_data]),
#     'modulation' : np.concatenate([rotation_data[animal]['vMI']['CCW'] for animal in rotation_data]),
#     'depth' : np.concatenate([rotation_data[animal]['AllDepth'] for animal in rotation_data]),
#     'title' : 'vMI CCW MedioLateral'
# }

# modulation_figure(df_foo)

# df_foo = {
#     'position' : np.concatenate([np.full(rotation_data[animal]['Nclust'], rotation_data[animal]['AP_pos']) for animal in rotation_data]),
#     'modulation' : np.concatenate([rotation_data[animal]['vMI']['CCW'] for animal in rotation_data]),
#     'depth' : np.concatenate([rotation_data[animal]['AllDepth'] for animal in rotation_data]),
#     'title' : 'vMI CCW AnteroPosterior'
# }

# modulation_figure(df_foo)





# # dirMI ML

# df_foo = {
#     'position' : np.concatenate([np.full(rotation_data[animal]['Nclust'], rotation_data[animal]['ML_pos']) for animal in rotation_data]),
#     'modulation' : np.concatenate([rotation_data[animal]['dirMI'] for animal in rotation_data]),
#     'depth' : np.concatenate([rotation_data[animal]['AllDepth'] for animal in rotation_data]),
#     'title' : 'dirMI MedioLateral'
# }

# modulation_figure(df_foo)

# # dirMI AP

# df_foo = {
#     'position' : np.concatenate([np.full(rotation_data[animal]['Nclust'], rotation_data[animal]['AP_pos']) for animal in rotation_data]),
#     'modulation' : np.concatenate([rotation_data[animal]['dirMI'] for animal in rotation_data]),
#     'depth' : np.concatenate([rotation_data[animal]['AllDepth'] for animal in rotation_data]),
#     'title' : 'dirMI AnteroPosterior'
# }

# modulation_figure(df_foo)

# Repartition along AP and ML axis

In [None]:
def modulation_figure(df_foo, hue_groups, xlim=None, invert=False, keep_mask=None, no_modulation_mask=None):
    """Joint plot of unit position against depth, coloured by a categorical label.

    Parameters
    ----------
    df_foo : dict or DataFrame
        Must provide the columns 'Position', 'Depth', 'title' and the column named by
        `hue_groups`. Only the first entry of 'title' is used (figure title).
    hue_groups : str
        Name of the column holding the category of each unit (two categories are
        expected, drawn in red and blue).
    xlim : tuple, optional
        Limits applied to the x axis of the joint panel; left untouched when None.
    invert : bool, optional
        Invert the x axis of the joint panel when True.
    keep_mask : boolean array, optional
        Units drawn in the coloured joint plot. Defaults to the global `tri_boolean`.
    no_modulation_mask : boolean array, optional
        Units overlaid in gray as 'No modulation'. Defaults to the global `NoModulation`.

    Notes
    -----
    A Gaussian jitter (sd 0.009) is added to the positions; it is drawn from NumPy's
    global random state, so figures differ between runs unless a seed is set.
    Depth is plotted with its sign flipped.
    """
    df_foo = pd.DataFrame(df_foo)

    # Fall back on the notebook-level masks so that existing calls keep working.
    if keep_mask is None:
        keep_mask = tri_boolean
    if no_modulation_mask is None:
        no_modulation_mask = NoModulation

    bruit = np.random.normal(0, 0.009, len(df_foo['Position']))

    g = sns.jointplot(x=df_foo["Position"][keep_mask]+bruit[keep_mask], y=df_foo["Depth"][keep_mask]*-1,
                        alpha=0.5, s=75,
                        marginal_ticks=True,
                        hue=df_foo[hue_groups][keep_mask],
                        palette=['red', 'blue'])

    # Units without modulation are drawn separately, in gray, on the joint panel.
    g.ax_joint.scatter(df_foo["Position"][no_modulation_mask]+bruit[no_modulation_mask], df_foo["Depth"][no_modulation_mask]*-1, color='gray', alpha=0.25, s=75, label='No modulation')
    g.ax_joint.legend()

    g.figure.set_size_inches(12, 8)
    if invert:
        g.ax_joint.invert_xaxis()
    g.figure.suptitle(df_foo["title"][0])
    if xlim is not None:
        g.ax_joint.set_xlim(xlim)
    plt.show()



##########################################################################
############################   vMI CW / CCW    ###########################
##########################################################################

# Labels (selectivity + modulation type) counted as an excited / suppressed response
# for each rotation direction.
excited_labels = {'CW': ('both+-', 'both++', 'CW+'), 'CCW': ('both++', 'both-+', 'CCW+')}
suppressed_labels = {'CW': ('both--', 'both-+', 'CW-'), 'CCW': ('both--', 'both+-', 'CCW-')}

# (position key, axis name used in the figure title, options given to modulation_figure)
position_axes = [
    ('ML_pos', 'MedioLateral', {'invert': True, 'xlim': (1.3, 0.6)}),
    ('AP_pos', 'AnteroPosterior', {'xlim': (1.6, 2.2)}),
]

for direction in ['CW', 'CCW']:
    selectivity = []
    NoModulation = []

    for animal in rotation_data:
        for modulationSelectivity, modulationType in zip(rotation_data[animal]['modulation']['selectivity'], rotation_data[animal]['modulation']['type']):
            unit_label = modulationSelectivity + modulationType
            if unit_label in excited_labels[direction]:
                selectivity.append('Excited')
                NoModulation.append(False)
            elif unit_label in suppressed_labels[direction]:
                selectivity.append('Suppressed')
                NoModulation.append(False)
            else:
                selectivity.append(None)
                NoModulation.append(True)
    # modulation_figure reads this global to draw the gray 'No modulation' overlay.
    NoModulation = np.array(NoModulation)

    for position_key, axis_name, figure_options in position_axes:
        df_foo = {
            'Position' : np.concatenate([np.full(rotation_data[animal]['Nclust'], rotation_data[animal][position_key]) for animal in rotation_data]),
            'Modulation' : np.concatenate([rotation_data[animal]['vMI'][direction] for animal in rotation_data]),
            'Depth' : np.concatenate([rotation_data[animal]['AllDepth'] for animal in rotation_data]),
            'title' : f'vMI {direction} {axis_name}',
            'Selectivity' : selectivity
        }

        modulation_figure(df_foo, 'Selectivity', **figure_options)


##########################################################################
############################   dirMI ML / AP   ###########################
##########################################################################

preference = []
alpha = []  # not used by the figures below; kept as a notebook-level variable

for animal in rotation_data:
    for preferency in rotation_data[animal]['preference']:
        if preferency == 'CW':
            preference.append('CW-preferred')
            alpha.append(0.5)
        elif preferency == 'CCW':
            preference.append('CCW-preferred')
            alpha.append(0.5)
        else:
            preference.append(None)
            alpha.append(0.3)
alpha = np.array(alpha)

# The gray overlay must follow the units WITHOUT a preferred direction. Without this
# line modulation_figure would silently reuse the mask left over from the CCW vMI
# section above, which describes a different set of units.
NoModulation = np.array([label is None for label in preference])

for position_key, axis_name, figure_options in position_axes:
    df_foo = {
        'Position' : np.concatenate([np.full(rotation_data[animal]['Nclust'], rotation_data[animal][position_key]) for animal in rotation_data]),
        'Modulation' : np.concatenate([rotation_data[animal]['dirMI'] for animal in rotation_data]),
        'Depth' : np.concatenate([rotation_data[animal]['AllDepth'] for animal in rotation_data]),
        'title' : f'dirMI {axis_name}',
        'Preference' : preference
    }

    modulation_figure(df_foo, 'Preference', **figure_options)

# In 3D

## Colorbar

### dirMI

In [None]:
%matplotlib qt

# 3D view of every unit, coloured by its direction modulation index (dirMI).
# x / y: recording-site coordinates of the animal with a small Gaussian jitter (one draw
# per unit); z: unit depth.
x = np.concatenate([np.random.normal(loc=rotation_data[animal]['AP_pos'], scale=0.015, size=(rotation_data[animal]['Nclust'], 1)) for animal in rotation_data])
y = np.concatenate([np.random.normal(loc=rotation_data[animal]['ML_pos'], scale=0.015, size=(rotation_data[animal]['Nclust'], 1)) for animal in rotation_data])
z = np.concatenate([rotation_data[animal]['AllDepth'] for animal in rotation_data])
colors = np.concatenate([rotation_data[animal]['dirMI'] for animal in rotation_data])
# The colours are dirMI values, so the colorbar carries the dirMI formula (the previous
# label was the vMI one, copied from the vMI cell).
colorlabel = r'Direction Modulation Index : $\frac{FR_{CW}-FR_{CCW}}{FR_{CW}+FR_{CCW}}$'
xlabel = 'Anteroposterior (mm)'
ylabel = 'Mediolateral (mm)'
zlabel = 'Depth (µm)'
title = "dirMI of units in 3D space"
filename = "3D_dirMI.gif"
filepath = os.path.join(saving_path, '3Dmodulation')
save=False  # NOTE(review): defined but not forwarded to scatter3D - confirm whether it should be

show=True
anim=False


scatter3D(x,y,z,colors,colorlabel,xlabel,ylabel,zlabel,title,filename,filepath,anim=anim,show=show)

### vMI

In [None]:
%matplotlib qt

# 3D view of every unit, coloured by its vMI, one figure per rotation direction.
for direction in ('CW', 'CCW'):
    # Recording-site coordinates with a small Gaussian jitter (redrawn for each direction).
    x = np.concatenate([
        np.random.normal(loc=rotation_data[animal]['AP_pos'], scale=0.015, size=(rotation_data[animal]['Nclust'], 1))
        for animal in rotation_data
    ])
    y = np.concatenate([
        np.random.normal(loc=rotation_data[animal]['ML_pos'], scale=0.015, size=(rotation_data[animal]['Nclust'], 1))
        for animal in rotation_data
    ])
    z = np.concatenate([rotation_data[animal]['AllDepth'] for animal in rotation_data])
    colors = np.concatenate([rotation_data[animal]['vMI'][direction] for animal in rotation_data])

    # Labels and output location.
    colorlabel = r'Vestibular Modulation Index : $\frac{during-before}{during+before}$'
    xlabel, ylabel, zlabel = 'Anteroposterior (mm)', 'Mediolateral (mm)', 'Depth (µm)'
    title = direction + " vMI of units in 3D space"
    filename = "3D_" + direction + "_vMI.gif"
    filepath = os.path.join(saving_path, '3Dmodulation')

    # Display options (`save` is set but not passed on to scatter3D).
    save, show, anim = False, True, False

    scatter3D(x, y, z, colors, colorlabel, xlabel, ylabel, zlabel, title, filename, filepath, anim=anim, show=show)

## Stats

### dirMI

In [None]:
# 3D view of every unit, coloured by its preferred rotation direction:
# magenta = CW, cyan = CCW, gray = no preference.
x = np.concatenate([np.random.normal(loc=rotation_data[animal]['AP_pos'], scale=0.015, size=(rotation_data[animal]['Nclust'], 1)) for animal in rotation_data])
y = np.concatenate([np.random.normal(loc=rotation_data[animal]['ML_pos'], scale=0.015, size=(rotation_data[animal]['Nclust'], 1)) for animal in rotation_data])
z = np.concatenate([rotation_data[animal]['AllDepth'] for animal in rotation_data])
colors = []

for animal in rotation_data:
    for neuron in range(rotation_data[animal]['Nclust']):
        preferred = rotation_data[animal]['preference'][neuron]
        colors.append('magenta' if preferred == 'CW' else 'cyan' if preferred == 'CCW' else 'gray')

xlabel = 'Anteroposterior (mm)'
ylabel = 'Mediolateral (mm)'
zlabel = 'Depth (µm)'
title = "dirMI of units in 3D space"
filename = "3D_dirMI.gif"
filepath = os.path.join(saving_path, '3Dmodulation')
save=False
show=True
anim=False


fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

scatter = ax.scatter(x, y, z, color=colors, marker='o', s=30, alpha=0.5)


ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)
ax.set_zlabel(zlabel)
ax.invert_zaxis()  # depth increases downwards

plt.title(title)

if anim:
    def update(frame):
        # Rotate the camera around the vertical axis, one azimuth per frame.
        ax.view_init(30, frame)
        return scatter,

    # Keep a reference to the animation, otherwise it is garbage-collected and stops.
    ani = FuncAnimation(fig, update, frames=np.arange(0, 360, 2), interval=100)

if save:
    os.makedirs(filepath, exist_ok=True)
    if anim:
        ani.save(os.path.join(filepath, filename), writer='pillow')
    else:
        # The static figure is written as a PDF: give the file a matching extension
        # instead of storing PDF data under the '.gif' name.
        fig.savefig(os.path.join(filepath, os.path.splitext(filename)[0] + '.pdf'), format='pdf')

if show:
    plt.show()

### vMI

In [None]:
# 3D view of every unit, coloured by its response class for each rotation direction:
# red = excited, blue = suppressed, gray = neither.
response_masks = {'CW': (CWexcited, CWsuppressed), 'CCW': (CCWexcited, CCWsuppressed)}
animations = []  # references kept so that no animation is garbage-collected

for direction in ['CW', 'CCW']:
    x = np.concatenate([np.random.normal(loc=rotation_data[animal]['AP_pos'], scale=0.015, size=(rotation_data[animal]['Nclust'], 1)) for animal in rotation_data])
    y = np.concatenate([np.random.normal(loc=rotation_data[animal]['ML_pos'], scale=0.015, size=(rotation_data[animal]['Nclust'], 1)) for animal in rotation_data])
    z = np.concatenate([rotation_data[animal]['AllDepth'] for animal in rotation_data])

    # The masks are flat arrays ordered animal by animal, like x, y and z.
    excited_mask, suppressed_mask = response_masks[direction]
    unit_count = sum(rotation_data[animal]['Nclust'] for animal in rotation_data)
    colors = [
        'red' if excited_mask[i] else 'blue' if suppressed_mask[i] else 'gray'
        for i in range(unit_count)
    ]

    xlabel = 'Anteroposterior (mm)'
    ylabel = 'Mediolateral (mm)'
    zlabel = 'Depth (µm)'
    title = f"{direction} vMI of units in 3D space"
    filename = f"3D_{direction}_vMI.gif"
    filepath = os.path.join(saving_path, '3Dmodulation')
    save=False
    show=True
    anim=False


    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    scatter = ax.scatter(x, y, z, color=colors, marker='o', s=30, alpha=0.5)


    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_zlabel(zlabel)
    ax.invert_zaxis()  # depth increases downwards

    plt.title(title)

    if anim:
        # Bind this iteration's axes and artist as defaults: a plain closure would look
        # up the notebook-level `ax` / `scatter` when called and end up rotating the
        # figure of the LAST direction only.
        def update(frame, ax=ax, scatter=scatter):
            ax.view_init(30, frame)
            return scatter,

        ani = FuncAnimation(fig, update, frames=np.arange(0, 360, 2), interval=100)
        animations.append(ani)

    if save:
        os.makedirs(filepath, exist_ok=True)
        if anim:
            ani.save(os.path.join(filepath, filename), writer='pillow')
        else:
            # The static figure is written as a PDF: give the file a matching extension
            # instead of storing PDF data under the '.gif' name.
            fig.savefig(os.path.join(filepath, os.path.splitext(filename)[0] + '.pdf'), format='pdf')

    if show:
        plt.show()


## Projected Scatter (relevance ?)

In [None]:
# CW vMI only: units projected on the mediolateral / anteroposterior plane.

# Jitter each unit around the coordinates of its recording site so that the units of one
# site do not sit exactly on top of each other.
x = np.concatenate([
    np.random.normal(loc=rotation_data[animal]['ML_pos'], scale=0.015, size=(rotation_data[animal]['Nclust'], 1))
    for animal in rotation_data
])
y = np.concatenate([
    np.random.normal(loc=rotation_data[animal]['AP_pos'], scale=0.015, size=(rotation_data[animal]['Nclust'], 1))
    for animal in rotation_data
])
colors = np.concatenate([rotation_data[animal]['vMI']['CW'] for animal in rotation_data])

fig, ax = plt.subplots()

# One marker per unit, coloured by its CW vMI.
points = ax.scatter(x, y, c=colors, cmap='coolwarm', marker='o', alpha=0.6, s=100)

# Both axes are drawn reversed.
ax.invert_xaxis()
ax.invert_yaxis()
ax.set_xlabel('Mediolateral (mm)')
ax.set_ylabel('Anteroposterior (mm)')
ax.set_title('CW vMI in the mediolateral vs anteroposterior axis')

fig.colorbar(points, ax=ax, label='Vestibular Modulation Index')

plt.show()


# How to calculate vMI

In [None]:
# Illustration of the vMI computation on one example unit: z-scored PSTH with the
# "rot" window (green) and the "pre" window of equal length (magenta), plus the table
# velocity on a secondary axis.
fig, ax1 = plt.subplots(figsize=(14,4))

animal, neuron = unit_max_CW
condition = 'CW'

edges, Zscore, _ = getPSTHparameters(rotation_data[animal]['SpikeTimes'][condition][neuron], rotation_data[animal]['duration'], 0.03)

# Samples where the mean table velocity exceeds 40 (axis labelled in °/s).
# NOTE(review): 30000 is presumably the sampling rate (Hz) and -2 the time of the first
# sample relative to rotation onset (s) - confirm against the acquisition settings.
fast_samples = np.where(rotation_data[animal]['MeanRotation'][condition]>40)[0]
rot_start = fast_samples[0]/30000-2
rot_stop = fast_samples[-1]/30000-2
pre_start = -1
pre_stop = -1+len(fast_samples)/30000

# Smoothed z-score (Savitzky-Golay filter, window 15, polynomial order 1).
ax1.plot(edges[:-1], savgol_filter(Zscore,15,1))

# Rotation window.
ax1.axvline(rot_start, ls='--', c='g')
ax1.axvline(rot_stop, ls='--', c='g')
ax1.axvspan(rot_start, rot_stop, alpha=0.1, color='g')

# Pre-rotation window: same number of samples, starting at -1.
ax1.axvline(pre_start, ls='--', c='m')
ax1.axvline(pre_stop, ls='--', c='m')
ax1.axvspan(pre_start, pre_stop, alpha=0.1, color='m')
ax1.set_xlim(-2,4)

# Table velocity on a twin y axis.
ax2 = ax1.twinx()
ax2.plot(rotation_data[animal]['duration'], rotation_data[animal]['MeanRotation'][condition], color='gray', lw=0.5)

# Annotations: formula and window names.
ax1.text(2.5,3, r'$vMI_{CW}=\frac{FR_{rot}-FR_{pre}}{FR_{rot}+FR_{pre}}$', fontsize=15)
ax1.text(-0.8,1, r'$pre$', fontsize=15)
ax1.text(0.85,1, r'$rot$', fontsize=15)

ax1.set_xlabel('Time (s)')
ax1.set_ylabel('Z-Score FR')
ax2.set_ylabel('Table velocity (°/s)')
plt.title('vMI calculation')
plt.margins(y=0)
plt.show()

# How to calculate dirMI

In [None]:
# Illustration of the dirMI computation on one example unit: z-scored PSTH during CW
# (top) and CCW (bottom) rotations, each with the table velocity on a secondary axis.
fig, ax = plt.subplots(2,1,figsize=(14,6))

animal, neuron = unit_max_CW
unit_spikes = rotation_data[animal]['SpikeTimes']
time_axis = rotation_data[animal]['duration']
velocity = rotation_data[animal]['MeanRotation']

# Window in which the CW velocity exceeds 40, converted from sample index to time.
# NOTE(review): the SAME CW-derived window is drawn on the CCW panel (the original code
# tested MeanRotation['CW'] in both panels) - confirm this is intended rather than a
# copy-paste, e.g. because the two velocity profiles are mirror images.
# NOTE(review): 30000 is presumably the sampling rate (Hz) and -2 the time of the first
# sample (s) - confirm.
cw_fast_samples = np.where(velocity['CW']>40)[0]
window_start = cw_fast_samples[0]/30000-2
window_stop = cw_fast_samples[-1]/30000-2

# --- Top panel: CW rotation ---
edges1, Zscore1, _ = getPSTHparameters(unit_spikes['CW'][neuron], time_axis, 0.03)
ax[0].plot(edges1[:-1], savgol_filter(Zscore1, 15, 1))
ax[0].set_xlim(-2, 4)
ax2_1 = ax[0].twinx()
ax2_1.plot(time_axis, velocity['CW'], color='gray', lw=0.5)
ax[0].axvline(window_start, ls='--', c='magenta')
ax[0].axvline(window_stop, ls='--', c='magenta')
ax[0].axvspan(window_start, window_stop, alpha=0.1, color='magenta')
ax[0].text(0.85,1, r'$CW$', fontsize=15)
ax[0].text(2.5,3, r'$dirMI=\frac{FR_{CW}-FR_{CCW}}{FR_{CW}+FR_{CCW}}$', fontsize=15)
plt.margins(y=0)
ax[0].set_ylabel('Z-Score FR')
ax[0].set_title('dirMI calculation')
ax2_1.set_ylabel('Table velocity (°/s)')

# --- Bottom panel: CCW rotation ---
edges2, Zscore2, _ = getPSTHparameters(unit_spikes['CCW'][neuron], time_axis, 0.03)
ax[1].plot(edges2[:-1], savgol_filter(Zscore2, 15, 1))
ax[1].set_xlim(-2, 4)
ax2_2 = ax[1].twinx()
ax2_2.plot(time_axis, velocity['CCW'], color='gray', lw=0.5)
ax[1].axvline(window_start, ls='--', c='cyan')
ax[1].axvline(window_stop, ls='--', c='cyan')
ax[1].axvspan(window_start, window_stop, alpha=0.1, color='cyan')
ax[1].text(0.8,1, r'$CCW$', fontsize=15)
plt.margins(y=0)
ax[1].set_ylabel('Z-Score FR')
ax[1].set_xlabel('Time (s)')
ax2_2.set_ylabel('Table velocity (°/s)')

# Display the figure.
plt.show()

# Some statistics over population

In [None]:
import copy

# Analysis constants (previously inlined as magic numbers).
BIN_WIDTH_S = 0.03         # PSTH bin width, in seconds
SAMPLING_RATE_HZ = 30000   # NOTE(review): presumably the acquisition sampling rate -- confirm
VELOCITY_THRESHOLD = 40    # mean table velocity above which the table is considered rotating

# Per-animal containers. Each one starts as a deep copy of the animal's
# 'SpikeTimes' structure so that it can be indexed [condition][neuron]([trial])
# exactly like the raw data, and is then overwritten entry by entry below.
spike_numbers, FR, mean_FR, mean_baseline, zscores, Zscore_During = ({} for _ in range(6))

for animal in rotation_data:
    (spike_numbers[animal], FR[animal], mean_FR[animal],
     mean_baseline[animal], zscores[animal], Zscore_During[animal]) = (
        copy.deepcopy(rotation_data[animal]['SpikeTimes']) for _ in range(6)
    )

for animal in rotation_data:
    # Time base and rotation duration are taken from the animal being processed.
    # The previous version used the time base of whichever animal the loop above
    # ended on and the rotation duration of a randomly drawn animal (unseeded
    # random.choice), which made the results non-reproducible.
    timeObject = rotation_data[animal]['duration']
    time_len = np.count_nonzero(rotation_data[animal]['MeanRotation']['CW'] > VELOCITY_THRESHOLD) / SAMPLING_RATE_HZ

    edges = np.arange(timeObject[0], round(timeObject[-1]) + BIN_WIDTH_S, BIN_WIDTH_S)
    n_bins = len(edges) - 1
    # Index of the last edge located before t = 0: bins [0, n_baseline_bins)
    # lie entirely in the pre-rotation period.
    n_baseline_bins = np.where(edges < 0)[0][-1]

    for condition in ['CW', 'CCW']:
        for neuron in range(rotation_data[animal]['Nclust']):
            trials = rotation_data[animal]['SpikeTimes'][condition][neuron]
            per_trial_rates = []
            for trial in range(len(trials)):
                spike_numbers[animal][condition][neuron][trial], _ = np.histogram(trials[trial], bins=edges)
                FR[animal][condition][neuron][trial] = spike_numbers[animal][condition][neuron][trial] / BIN_WIDTH_S
                per_trial_rates.append(FR[animal][condition][neuron][trial])

            # (n_trials, n_bins) matrix of firing rates; the reshape keeps the
            # 2-D layout even when a neuron has no trial.
            trial_rates = np.reshape(np.asarray(per_trial_rates, dtype=float), (len(trials), n_bins))

            mean_FR[animal][condition][neuron] = np.mean(trial_rates, axis=0)
            # Baseline = trial-averaged rate restricted to the pre-rotation bins.
            # The former expression `FR[...][neuron][:][:k]` sliced the *trial*
            # axis (`[:]` is a no-op copy), so the "baseline" silently spanned
            # the whole trace, rotation period included.
            mean_baseline[animal][condition][neuron] = np.mean(trial_rates[:, :n_baseline_bins], axis=0)

            baseline_mean = np.mean(mean_baseline[animal][condition][neuron])
            baseline_std = np.std(mean_baseline[animal][condition][neuron])
            if baseline_std == 0:
                # Flat baseline: the z-score is undefined, report zeros instead of dividing by zero.
                zscores[animal][condition][neuron] = np.zeros(len(mean_FR[animal][condition][neuron]))
                Zscore_During[animal][condition][neuron] = 0
            else:
                zscores[animal][condition][neuron] = (mean_FR[animal][condition][neuron] - baseline_mean) / baseline_std
                # Mean spike count during rotation divided by the rotation duration
                # gives a rate, z-scored against the same baseline.
                rate_during = np.mean(rotation_data[animal]['numberDur'][condition][neuron]) / time_len
                Zscore_During[animal][condition][neuron] = (rate_during - baseline_mean) / baseline_std

In [None]:
# During-rotation Z-scores of the excited units: CW-excited units first, then CCW-excited units.
# For each direction the per-animal values are pooled before the excitation mask is applied.
np.concatenate([
    np.concatenate([Zscore_During[name][direction] for name in rotation_data])[mask]
    for direction, mask in (('CW', CWexcited), ('CCW', CCWexcited))
])

In [None]:
from scipy import stats

# Compare the during-rotation Z-scores of CW-excited and CCW-excited units.
# For each direction the values of all animals are pooled, then restricted to
# the units flagged as excited for that direction.
values_by_group = {
    direction: np.concatenate([Zscore_During[animal][direction] for animal in rotation_data])[mask]
    for direction, mask in (('CW', CWexcited), ('CCW', CCWexcited))
}

# Group labels are sized from the selected values themselves so that both
# columns always have the same length.
df = pd.DataFrame({
    'Group': np.concatenate([np.full(len(values), direction, dtype='<U3') for direction, values in values_by_group.items()]),
    'Value': np.concatenate(list(values_by_group.values()))
})
# (A stray expression that recomputed the pooled vMI values and discarded the
# result used to sit here; it had no effect and was removed.)

fig, ax = plt.subplots(figsize=(8, 6))
sns.boxplot(x='Group', y='Value', data=df, hue='Group', legend=False, palette=['orchid', 'lightskyblue'], ax=ax)
sns.swarmplot(x='Group', y='Value', data=df, color='black', alpha=0.5, ax=ax)

# Two-sided Mann-Whitney U test between the two groups.
group_A = df[df['Group'] == 'CW']['Value']
group_B = df[df['Group'] == 'CCW']['Value']
_, p_value_AB = stats.mannwhitneyu(group_A, group_B)

# Significance bracket drawn just above the highest data point.
y, h, col = df['Value'].max()*1.025, 0.01, 'k'
ax.plot([0, 1], [y, y], lw=1.5, c=col)
text = next(
    (stars for threshold, stars in ((0.001, '***'), (0.01, '**'), (0.05, '*')) if p_value_AB < threshold),
    'n.s.'
)
ax.text(0.5, y + h, text, ha='center', va='bottom', color=col)
ax.set_ylim(top=1.05*y+h)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

ax.set_xlabel('')
ax.set_ylabel("Z-score FR")
ax.set_title("Comparison of Z-score FR of neurons excited by CW or CCW rotation")
plt.show()

In [None]:
from scipy import stats

# Compare the vMI of CW-excited and CCW-excited units.
# For each direction the per-animal 'vMI' values are pooled, then restricted to
# the units flagged as excited for that direction.
values_by_group = {
    direction: np.concatenate([rotation_data[animal]['vMI'][direction] for animal in rotation_data])[mask]
    for direction, mask in (('CW', CWexcited), ('CCW', CCWexcited))
}

# Group labels are sized from the selected values themselves so that both
# columns always have the same length.
df = pd.DataFrame({
    'Group': np.concatenate([np.full(len(values), direction, dtype='<U3') for direction, values in values_by_group.items()]),
    'Value': np.concatenate(list(values_by_group.values()))
})

fig, ax = plt.subplots(figsize=(8, 6))
sns.boxplot(x='Group', y='Value', data=df, hue='Group', legend=False, palette=['orchid', 'lightskyblue'], ax=ax)
sns.swarmplot(x='Group', y='Value', data=df, color='black', alpha=0.5, ax=ax)

# Two-sided Mann-Whitney U test between the two groups.
group_A = df[df['Group'] == 'CW']['Value']
group_B = df[df['Group'] == 'CCW']['Value']
_, p_value_AB = stats.mannwhitneyu(group_A, group_B)

# Significance bracket drawn just above the highest data point.
y, h, col = df['Value'].max()*1.025, 0.01, 'k'
ax.plot([0, 1], [y, y], lw=1.5, c=col)
text = next(
    (stars for threshold, stars in ((0.001, '***'), (0.01, '**'), (0.05, '*')) if p_value_AB < threshold),
    'n.s.'
)
ax.text(0.5, y + h, text, ha='center', va='bottom', color=col)
ax.set_ylim(top=1.05*y+h)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

# The plotted quantity is the vMI, not a Z-score: the axis label and title were
# left over from the Z-score figure and have been corrected.
ax.set_xlabel('')
ax.set_ylabel("vMI")
ax.set_title("Comparison of vMI of neurons excited by CW or CCW rotation")
plt.show()

In [None]:
from scipy import stats

# Compare the vMI of CW-suppressed and CCW-suppressed units.
# For each direction the per-animal 'vMI' values are pooled, then restricted to
# the units flagged as suppressed for that direction.
values_by_group = {
    direction: np.concatenate([rotation_data[animal]['vMI'][direction] for animal in rotation_data])[mask]
    for direction, mask in (('CW', CWsuppressed), ('CCW', CCWsuppressed))
}

# Group labels are sized from the selected values themselves so that both
# columns always have the same length.
df = pd.DataFrame({
    'Group': np.concatenate([np.full(len(values), direction, dtype='<U3') for direction, values in values_by_group.items()]),
    'Value': np.concatenate(list(values_by_group.values()))
})

fig, ax = plt.subplots(figsize=(8, 6))
sns.boxplot(x='Group', y='Value', data=df, hue='Group', legend=False, palette=['orchid', 'lightskyblue'], ax=ax)
sns.swarmplot(x='Group', y='Value', data=df, color='black', alpha=0.5, ax=ax)

# Two-sided Mann-Whitney U test between the two groups.
group_A = df[df['Group'] == 'CW']['Value']
group_B = df[df['Group'] == 'CCW']['Value']
_, p_value_AB = stats.mannwhitneyu(group_A, group_B)

# Significance bracket drawn slightly above the highest (least negative) data point.
y, h, col = df['Value'].max()+0.05, 0.01, 'k'
ax.plot([0, 1], [y, y], lw=1.5, c=col)
text = next(
    (stars for threshold, stars in ((0.001, '***'), (0.01, '**'), (0.05, '*')) if p_value_AB < threshold),
    'n.s.'
)
ax.text(0.5, y + h, text, ha='center', va='bottom', color=col)
# Keep the [-1, 0] window, but extend the top when the bracket would otherwise
# fall outside the axes (values close to 0 put it above y = 0, where the fixed
# upper limit used to clip it).
ax.set_ylim(-1, max(0, y + 0.05))
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

# The plotted quantity is the vMI, not a Z-score: the axis label and title were
# left over from the Z-score figure and have been corrected.
ax.set_xlabel('')
ax.set_ylabel("vMI")
ax.set_title("Comparison of vMI of neurons suppressed by CW or CCW rotation")
plt.show()

# Phototagging

## Phototagging response of phototagged neurons

In [None]:
# One figure per phototagged unit: rasters on the top row, PSTHs on the bottom
# row, each drawn twice (red on the left, blue on the right) for the
# 'phototagging' condition.
for animal in phototagging_data:
    print(animal)
    # Units whose 'modulation' flag equals 1 are the phototagged ones.
    phototagged_units = np.where(phototagging_data[animal]['modulation'] == 1)[0]
    print(phototagged_units)
    for unit in phototagged_units:
        plt.figure(figsize=(10, 4))
        for first_panel, draw_panel in ((1, plotRaster), (3, plotPSTH)):
            # Left panel of the row.
            plt.subplot(2, 2, first_panel)
            draw_panel(phototagging_data, animal, 'phototagging', unit,
                       color='red', show=False, xlim=(-0.5, 0.5), smooth=False)
            # Right panel of the row, drawn without its y-label.
            plt.subplot(2, 2, first_panel + 1)
            draw_panel(phototagging_data, animal, 'phototagging', unit,
                       color='blue', show=False, ylabel=False, xlim=(-0.5, 0.5), smooth=False)
        plt.suptitle(f"{animal} unit {unit}, phototagged")
        plt.show()

## Rotation response of phototagged neurons

In [None]:
# One figure per phototagged unit showing its response to table rotation:
# rasters on the top row, PSTHs on the bottom row, CW (red) on the left and
# CCW (blue) on the right.
for animal in phototagging_data:
    print(animal)
    # Units whose 'modulation' flag equals 1 are the phototagged ones.
    phototagged_units = np.where(phototagging_data[animal]['modulation'] == 1)[0]
    print(phototagged_units)
    for unit in phototagged_units:
        plt.figure(figsize=(10, 4))
        for first_panel, is_psth in ((1, False), (3, True)):
            for shift, (direction, tint) in enumerate((("CW", 'red'), ("CCW", 'blue'))):
                plt.subplot(2, 2, first_panel + shift)
                options = dict(plotvelocity=True, velocitycolor=tint, color=tint, show=False)
                if shift:
                    # Only the left-hand panel keeps its y-label.
                    options['ylabel'] = False
                if is_psth:
                    options['shadedcolor'] = tint
                    plotPSTH(rotation_data, animal, direction, unit, **options)
                else:
                    plotRaster(rotation_data, animal, direction, unit, **options)
        plt.suptitle(f"{animal} unit {unit}, phototagged")
        plt.show()

## Gaussian fit of rotation response in order to study kinetics

In [None]:
from scipy.optimize import curve_fit

def Gauss(x, a, x0, sigma, offset):
    """Evaluate a Gaussian bump sitting on a constant offset.

    Computes ``a * exp(-(x - x0)**2 / (2 * sigma**2)) + offset``.

    Parameters
    ----------
    x : scalar or array-like supporting NumPy arithmetic
        Point(s) at which the curve is evaluated.
    a : float
        Amplitude of the bump (value added to ``offset`` at ``x == x0``).
    x0 : float
        Position of the centre of the bump.
    sigma : float
        Width parameter (standard deviation) of the bump.
    offset : float
        Constant added to the whole curve.

    Returns
    -------
    Same shape as ``x``: the curve evaluated at ``x``.
    """
    squared_distance = (x - x0)**2
    twice_variance = 2 * sigma**2
    bump = np.exp(-squared_distance / twice_variance)
    return a * bump + offset





def _fit_gauss(x, y):
    """Fit `Gauss` to (x, y); return the optimal parameters, or None if the optimiser does not converge."""
    initial_guess = [np.max(y), np.median(x), np.std(x), np.min(y)]
    try:
        popt, _ = curve_fit(Gauss, x, y, p0=initial_guess, maxfev=20000)
    except RuntimeError:
        # curve_fit raises RuntimeError when the least-squares minimisation
        # fails; one unit that cannot be fitted must not abort the whole loop.
        return None
    return popt


# For every phototagged unit, pool the spikes of all trials per rotation
# direction, build a 100-bin histogram and fit a Gaussian to it.
for animal in phototagging_data:
    phototagged_units = np.where(phototagging_data[animal]['modulation']==1)[0]
    for unit in phototagged_units:
        spikesCW = np.concatenate([trial for trial in rotation_data[animal]['SpikeTimes']['CW'][unit]])
        spikesCCW = np.concatenate([trial for trial in rotation_data[animal]['SpikeTimes']['CCW'][unit]])

        # Both directions are binned on the SAME edges. Previously each
        # histogram chose its own range and `edges` was overwritten by the CCW
        # call, so the CW counts were plotted and fitted against the CCW axis.
        edges = np.histogram_bin_edges(np.concatenate([spikesCW, spikesCCW]), bins=100)
        yCW, _ = np.histogram(spikesCW, bins=edges)
        yCCW, _ = np.histogram(spikesCCW, bins=edges)

        # Counts are attached to bin centres (left edges would shift the fitted
        # centre x0 by half a bin).
        x = 0.5 * (edges[:-1] + edges[1:])
        x_fit = np.linspace(np.min(x), np.max(x), 1000)

        poptCW = _fit_gauss(x, yCW)
        poptCCW = _fit_gauss(x, yCCW)

        plt.figure(figsize=(10,4))
        for panel, (direction, counts, popt) in enumerate((('CW', yCW, poptCW), ('CCW', yCCW, poptCCW)), start=1):
            plt.subplot(1, 2, panel)
            plt.plot(x, counts, label='raw')
            if popt is not None:
                plt.plot(x_fit, Gauss(x_fit, *popt), 'r-', label='fit')
            plt.legend()
            plt.title(direction)
            plt.margins(0,0.05)
        plt.suptitle(f"{animal} unit {unit}")
        plt.show()

# Archive

In [None]:
# def vMI_function(AllData,
#                 s=100, alpha=0.6, color='c', scale=0.008, binsNumber=30,
#                 ML=True, AP=True, CW=True, CCW=True, stats=False, hist=False, interest=None,
#                 size=(13,10),
#                 xfontsize=10, yfontsize=10, suptitlefontsize=14,
#                 save=False, filename='',
#                 show=True):
#     # INITIATION
#     stereotaxic_title, stereotaxic_label, direction_selected = [], [], []

#     if ML:
#         stereotaxic_title.append('Mediolateral')
#         stereotaxic_label.append('ML_pos')
#     if AP:
#         stereotaxic_title.append('Anteroposterior')
#         stereotaxic_label.append('AP_pos')
#     if CW:
#         direction_selected.append('CW')
#     if CCW:
#         direction_selected.append('CCW')

#     modulation_quantity = dict() if (stats and hist) else None


#     # LOOPS
#     for pos_title, posOrientation in zip(stereotaxic_title, stereotaxic_label):
#         for direction in direction_selected:

#             # IMPORTATION OF DATA
#             NtotClust = np.sum([AllData[animal]['Nclust'] for animal in AllData])
#             AllDepth = np.concatenate([AllData[animal]['AllDepth'] for animal in AllData])
#             vMI = np.concatenate([np.array(AllData[animal]['vMI'][direction]) for animal in AllData])
#             pos = [AllData[animal][posOrientation] for animal in AllData]
#             position = np.concatenate([np.random.normal(loc=pos[i], scale=scale, size=(sum(AllData[animal]['good_baseline']), 1)) for animal, i in zip(AllData, range(len(pos)))])


#             maximum_value = max([max(np.array(AllData[animal]['vMI'][direction])) for animal in AllData])
#             minimum_value = min([min(np.array(AllData[animal]['vMI'][direction])) for animal in AllData])
#             nul_value = min([min(np.array(AllData[animal]['vMI'][direction]), key=lambda x: abs(x)) for animal in AllData])


            

#             # CREATING DATA FOR PLOT
#             if stats:
#                 AllHow = np.concatenate([AllData[animal]['modulation']['type'] for animal in AllData])
#                 AllWho = np.concatenate([AllData[animal]['modulation']['selectivity'] for animal in AllData])

#                 condition_positive, condition_negative, condition_unmodulated = [], [], []

#                 if direction == 'CW':
#                     for i in range(NtotClust):
#                         condition_positive.append((AllWho[i] == 'CW' or AllWho[i] == 'both') and (AllHow[i] == '+' or AllHow[i] == '+/+' or AllHow[i] == '+/-'))
#                         condition_negative.append((AllWho[i] == 'CW' or AllWho[i] == 'both') and (AllHow[i] == '-' or AllHow[i] == '-/+' or AllHow[i] == '-/-'))
#                         condition_unmodulated.append(AllWho[i] == 'unmodulated' or AllWho[i] == 'CCW')
#                 elif direction == 'CCW':
#                     for i in range(NtotClust):
#                         condition_positive.append((AllWho[i] == 'CCW' or AllWho[i] == 'both') and (AllHow[i] == '+' or AllHow[i] == '+/+' or AllHow[i] == '-/+'))
#                         condition_negative.append((AllWho[i] == 'CCW' or AllWho[i] == 'both') and (AllHow[i] == '-' or AllHow[i] == '+/-' or AllHow[i] == '-/-'))
#                         condition_unmodulated.append(AllWho[i] == 'unmodulated' or AllWho[i] == 'CW')


#                 positionSigni_P = [position[i] for i in range(NtotClust) if condition_positive[i]]
#                 depthSigni_P = [AllDepth[i] for i in range(NtotClust) if condition_positive[i]]
#                 vMISigni_P = [vMI[i] for i in range(NtotClust) if condition_positive[i]]

#                 positionSigni_N = [position[i] for i in range(NtotClust) if condition_negative[i]]
#                 depthSigni_N = [AllDepth[i] for i in range(NtotClust) if condition_negative[i]]
#                 vMISigni_N = [vMI[i] for i in range(NtotClust) if condition_negative[i]]

#                 positionNot = [position[i] for i in range(NtotClust) if condition_unmodulated[i]]
#                 depthNot = [AllDepth[i] for i in range(NtotClust) if condition_unmodulated[i]]
#                 vMINot = [vMI[i] for i in range(NtotClust) if condition_unmodulated[i]]

#                 positionData = [positionSigni_P, positionSigni_N, positionNot]
#                 depthData = [depthSigni_P, depthSigni_N, depthNot]
#                 vMIData = [vMISigni_P, vMISigni_N, vMINot]

#                 if hist:
#                     histdata2 = [] ; bins2 = [] ; histdataDist2 = []
#                     histdata3 = [] ; bins3 = [] ; histdataDist3 = []

#                     for data in [depthSigni_P, depthSigni_N]:
#                         histdatafoo, binsfoo = np.histogram(data, bins=binsNumber)
#                         histdata2.append(histdatafoo) ; bins2.append(binsfoo)
#                         histdataDist2.append(histdatafoo / NtotClust * 100)

#                     for data in [positionSigni_P, positionSigni_N]:
#                         histdatafoo, binsfoo = np.histogram(data, bins=binsNumber)
#                         histdata3.append(histdatafoo) ; bins3.append(binsfoo)
#                         histdataDist3.append(histdatafoo / NtotClust * 100)

#                     modulation_label = ['Excitation', 'Suppression', 'No modulation']
#                     modulation_quantity[direction] = [len(positionSigni_P), len(positionSigni_N), len(positionNot)]
#             elif not stats:
#                 if hist:
#                     histdata3, bins3 = np.histogram(position, bins=binsNumber)
#                     histdataDist3 = histdata3 / NtotClust * 100

#                     histdata2, bins2 = np.histogram(AllDepth, bins=binsNumber)
#                     histdataDist2 = histdata2 / NtotClust * 100

#                 positionData = position
#                 depthData = AllDepth
#                 vMIData = vMI
#             ############################################################





#             # mettre en évidence les neurones d'intérêt
#             if interest!=None:
#                 wanted_array = []
#                 for value in interest:
#                     if value == 'max':
#                         wanted_array.append(maximum_value)
#                     elif value == 'min':
#                         wanted_array.append(minimum_value)
#                     elif value == 'nul':
#                         wanted_array.append(nul_value)
#                     else:
#                         wanted_array.append(value)
#             ############################################################





#             # PLOTTING DATA
#             fig = plt.figure(figsize=size)

#             ## Creating axis
#             if hist:
#                 gs = GridSpec(nrows=4, ncols=4)
#                 ax1 = fig.add_subplot(gs[1:4,0:3]) ; # scatter plot on the left
#                 ax2 = fig.add_subplot(gs[1:4,3], sharey=ax1) ;  # histogram on the right
#                 ax3 = fig.add_subplot(gs[0,0:3], sharex=ax1) ;  # histogram on the top
#                 # ax4 = fig.add_subplot(gs[0,3]) if stats else None
#             ############################################################




#             ## Highlighting units of interest
#             if interest!=None:
#                 for value in wanted_array:
#                     ax = ax1 if hist else plt
#                     if type(value) == list:
#                         values_of_interest = [vMI_value for vMI_value in vMI if value[0] <= vMI_value <= value[1]]
#                         for value in values_of_interest:
#                             unit_index = np.where(vMI == value)[0]
#                             ax.scatter(position[unit_index], AllDepth[unit_index], marker='s', s=s*4, edgecolors='black', facecolors='none', linewidths=2) if len(unit_index) > 0 else None
#                     else:
#                         unit_index = np.where(vMI == value)[0]
#                         if len(unit_index) > 0:
#                             ax.scatter(position[unit_index], AllDepth[unit_index], marker='s', s=s*4, edgecolors='black', facecolors='none', linewidths=2)
#             ############################################################




#             ## Plots of data
#             if stats:                    
#                 if hist:
#                     for position, depth, transparence, color,  label in zip(positionData, depthData, [alpha, alpha, alpha/3], ['red', 'blue', 'grey'], ['Excited units', 'Suppressed units', 'Non-significant modulation']):
#                         ax1.scatter(position, depth, s=s, alpha=transparence, color=color, label=label)
#                     ax1.legend()
#                     for histdata, bins, color in zip(histdataDist2, bins2, ['red', 'blue']):
#                         ax2.plot(histdata, bins[:-1], color=color)
#                     for histdata, bins, color in zip(histdataDist3, bins3, ['red', 'blue']):
#                         ax3.plot(bins[:-1], histdata, color=color)
#                     # ax4.bar(modulation_label, modulation_quantity[direction], color=['red', 'blue', 'grey'], alpha=1)
#                 else:
#                     for position, depth, transparence, color,  label in zip(positionData, depthData, [alpha, alpha, alpha/3], ['red', 'blue', 'grey'], ['Excited units', 'Suppressed units', 'Non-significant modulation']):
#                         plt.scatter(position, depth, s=s, alpha=transparence, color=color, label=label)
#                     plt.legend()
#             else:
#                 if hist:
#                     ax1.scatter(positionData, depthData, s=s, alpha=alpha, color=color)
#                     ax2.plot(histdataDist2, bins2[:-1], color=color)
#                     ax3.plot(bins3[:-1], histdataDist3, color=color)
#                 else:
#                     plt.scatter(positionData, depthData, c=vMI, cmap='coolwarm', s=s, alpha=alpha, clim=(minimum_value, maximum_value))                
#                     plt.colorbar(label=f"{direction} modulation index"+r" : $(n_{during} - n_{before})/(n_{during} + n_{before})$") if ((not hist) and (not stats)) else None


#             ## MEP
#             xlabel = f"{pos_title} position (mm)"
#             ylabel = r"Depth ($\mu$m)"
#             histlabel = 'Density (%)' if hist else None
#             # statsylabel = 'Number of units' if stats else None #(ax4)
#             suptitle = f"{direction} modulation in {pos_title} axis"

#             xfontsize = xfontsize
#             yfontsize = yfontsize
#             # statsyfontsize = 12 if stats else None #(ax4)
#             suptitlefontsize = suptitlefontsize
#             ############################################################

#             ## Aesthetics
#             if hist:
#                 ax1.invert_yaxis()
#                 ax1.invert_xaxis() if posOrientation == 'ML_pos' else None
#                 ax1.set_xlabel(xlabel, fontsize=xfontsize)
#                 ax1.set_ylabel(ylabel, fontsize=yfontsize)

#                 ax2.set_xlabel(histlabel, fontsize=xfontsize)
#                 ax2.spines['top'].set_visible(False)
#                 ax2.spines['right'].set_visible(False) 
#                 ax2.spines['bottom'].set_visible(True) 
#                 ax2.spines['left'].set_visible(False)
#                 ax2.tick_params(axis='both', which='both', bottom=True, top=False, left=False, right=False, labelleft=False, labelbottom=True)

#                 ax3.set_ylabel(histlabel, fontsize=yfontsize)
#                 ax3.spines['top'].set_visible(False)
#                 ax3.spines['right'].set_visible(False)
#                 ax3.spines['bottom'].set_visible(False)
#                 ax3.spines['left'].set_visible(True)
#                 ax3.tick_params(axis='both', which='both', bottom=False, top=False, left=True, right=False, labelleft=True, labelbottom=False)

#                 # if stats:
#                 #     ax4.set_ylabel(statsylabel, fontsize=statsyfontsize)
#                 #     ax4.set_xticks(np.arange(len(modulation_label)))
#                 #     ax4.set_xticklabels(modulation_label, rotation=20)
#             else:
#                 plt.gca().invert_yaxis()
#                 plt.gca().invert_xaxis() if posOrientation == 'ML_pos' else None
#                 plt.xlabel(xlabel, fontsize=xfontsize)
#                 plt.ylabel(ylabel, fontsize=yfontsize)
                    
#             plt.suptitle(suptitle, fontsize=suptitlefontsize)
#             ############################################################


#             # Saving and showing
#             if save:
#                 os.makedirs(os.path.dirname(filename), exist_ok=True)
#                 if filename=='':
#                     plt.savefig(os.path.join(saving_path , f"{direction}_modulation_{pos_title}.png"))
#                 else:
#                     plt.savefig(f"{filename}.png")
            
#             plt.show() if show else plt.close()
#             ############################################################

#             foo=dict()
#             for animal in AllData:
#                 foo[animal] = AllData[animal][posOrientation]
#             sorted_keys = [key for key, value in sorted(foo.items(), key=itemgetter(1))] 

#             myTable = PrettyTable(["Animal", pos_title]) 

#             for animal in sorted_keys:
#                 myTable.add_row([animal, foo[animal]])

#             print(myTable)

#         if stats and hist:
#             plotdata = pd.DataFrame({'CW':modulation_quantity['CW'], 'CCW':modulation_quantity['CCW']}, index=modulation_label)
#             plotdata.plot(kind="bar",figsize=(10, 5), color=['red', 'blue'], rot=25)
#             plt.ylabel('Number of units')
#             plt.title('Rotation modulation')
#             if save:
#                 os.makedirs(os.path.dirname(filename), exist_ok=True)
#                 if filename=='':
#                     plt.savefig(os.path.join(saving_path , f"modulation_{pos_title}_barplot.png"))
#                 else:
#                     plt.savefig(f"{filename}.png")
#             if show:
#                 plt.show()
#             else:
#                 plt.close()
#             print('\n')









# vMI_function(rotation_data, save=False, s=50, alpha=0.7, stats=True, hist=True)

In [None]:
# def vMI_function(AllData,
#                 s=100, alpha=0.6, color='c', scale=0.008, binsNumber=30,
#                 ML=True, AP=True, CW=True, CCW=True, stats=False, hist=False, interest=None,
#                 size=(13,10),
#                 xfontsize=10, yfontsize=10, suptitlefontsize=14,
#                 save=False, filename='',
#                 show=True):
#     # INITIATION
#     stereotaxic_title, stereotaxic_label, direction_selected = [], [], []

#     if ML:
#         stereotaxic_title.append('Mediolateral')
#         stereotaxic_label.append('ML_pos')
#     if AP:
#         stereotaxic_title.append('Anteroposterior')
#         stereotaxic_label.append('AP_pos')
#     if CW:
#         direction_selected.append('CW')
#     if CCW:
#         direction_selected.append('CCW')

#     modulation_quantity = dict() if (stats and hist) else None


#     # LOOPS
#     for pos_title, posOrientation in zip(stereotaxic_title, stereotaxic_label):
#         for direction in direction_selected:

#             # IMPORTATION OF DATA
#             NtotClust = np.sum([sum(AllData[animal]['good_baseline']) for animal in AllData])
#             AllDepth = np.concatenate([AllData[animal]['AllDepth'] for animal in AllData])
#             vMI = np.concatenate([np.array(AllData[animal]['vMI'][direction]) for animal in AllData])
#             pos = [AllData[animal][posOrientation] for animal in AllData]
#             position = np.concatenate([np.random.normal(loc=pos[i], scale=scale, size=(sum(AllData[animal]['good_baseline']), 1)) for animal, i in zip(AllData, range(len(pos)))])


#             maximum_value = max([max(np.array(AllData[animal]['vMI'][direction])) for animal in AllData])
#             minimum_value = min([min(np.array(AllData[animal]['vMI'][direction])) for animal in AllData])
#             nul_value = min([min(np.array(AllData[animal]['vMI'][direction]), key=lambda x: abs(x)) for animal in AllData])


            

#             # CREATING DATA FOR PLOT
#             if stats:
#                 AllHow = np.concatenate([AllData[animal]['modulation']['type'] for animal in AllData])
#                 AllWho = np.concatenate([AllData[animal]['modulation']['selectivity'] for animal in AllData])

#                 condition_positive, condition_negative, condition_unmodulated = [], [], []

#                 if direction == 'CW':
#                     for i in range(NtotClust):
#                         condition_positive.append((AllWho[i] == 'CW' or AllWho[i] == 'both') and (AllHow[i] == '+' or AllHow[i] == '+/+' or AllHow[i] == '+/-'))
#                         condition_negative.append((AllWho[i] == 'CW' or AllWho[i] == 'both') and (AllHow[i] == '-' or AllHow[i] == '-/+' or AllHow[i] == '-/-'))
#                         condition_unmodulated.append(AllWho[i] == 'unmodulated' or AllWho[i] == 'CCW')
#                 elif direction == 'CCW':
#                     for i in range(NtotClust):
#                         condition_positive.append((AllWho[i] == 'CCW' or AllWho[i] == 'both') and (AllHow[i] == '+' or AllHow[i] == '+/+' or AllHow[i] == '-/+'))
#                         condition_negative.append((AllWho[i] == 'CCW' or AllWho[i] == 'both') and (AllHow[i] == '-' or AllHow[i] == '+/-' or AllHow[i] == '-/-'))
#                         condition_unmodulated.append(AllWho[i] == 'unmodulated' or AllWho[i] == 'CW')


#                 positionSigni_P = [position[i] for i in range(NtotClust) if condition_positive[i]]
#                 depthSigni_P = [AllDepth[i] for i in range(NtotClust) if condition_positive[i]]
#                 vMISigni_P = [vMI[i] for i in range(NtotClust) if condition_positive[i]]

#                 positionSigni_N = [position[i] for i in range(NtotClust) if condition_negative[i]]
#                 depthSigni_N = [AllDepth[i] for i in range(NtotClust) if condition_negative[i]]
#                 vMISigni_N = [vMI[i] for i in range(NtotClust) if condition_negative[i]]

#                 positionNot = [position[i] for i in range(NtotClust) if condition_unmodulated[i]]
#                 depthNot = [AllDepth[i] for i in range(NtotClust) if condition_unmodulated[i]]
#                 vMINot = [vMI[i] for i in range(NtotClust) if condition_unmodulated[i]]

#                 positionData = [positionSigni_P, positionSigni_N, positionNot]
#                 depthData = [depthSigni_P, depthSigni_N, depthNot]
#                 vMIData = [vMISigni_P, vMISigni_N, vMINot]

#                 if hist:
#                     histdata2 = [] ; bins2 = [] ; histdataDist2 = []
#                     histdata3 = [] ; bins3 = [] ; histdataDist3 = []

#                     for data in [depthSigni_P, depthSigni_N]:
#                         histdatafoo, binsfoo = np.histogram(data, bins=binsNumber)
#                         histdata2.append(histdatafoo) ; bins2.append(binsfoo)
#                         histdataDist2.append(histdatafoo / NtotClust * 100)

#                     for data in [positionSigni_P, positionSigni_N]:
#                         histdatafoo, binsfoo = np.histogram(data, bins=binsNumber)
#                         histdata3.append(histdatafoo) ; bins3.append(binsfoo)
#                         histdataDist3.append(histdatafoo / NtotClust * 100)

#                     modulation_label = ['Excitation', 'Suppression', 'No modulation']
#                     modulation_quantity[direction] = [len(positionSigni_P), len(positionSigni_N), len(positionNot)]
#             elif not stats:
#                 if hist:
#                     histdata3, bins3 = np.histogram(position, bins=binsNumber)
#                     histdataDist3 = histdata3 / NtotClust * 100

#                     histdata2, bins2 = np.histogram(AllDepth, bins=binsNumber)
#                     histdataDist2 = histdata2 / NtotClust * 100

#                 positionData = position
#                 depthData = AllDepth
#                 vMIData = vMI
#             ############################################################





#             # mettre en évidence les neurones d'intérêt
#             if interest!=None:
#                 wanted_array = []
#                 for value in interest:
#                     if value == 'max':
#                         wanted_array.append(maximum_value)
#                     elif value == 'min':
#                         wanted_array.append(minimum_value)
#                     elif value == 'nul':
#                         wanted_array.append(nul_value)
#                     else:
#                         wanted_array.append(value)
#             ############################################################





#             # PLOTTING DATA
#             fig = plt.figure(figsize=size)

#             ## Creating axis
#             if hist:
#                 gs = GridSpec(nrows=4, ncols=4)
#                 ax1 = fig.add_subplot(gs[1:4,0:3]) ; # scatter plot on the left
#                 ax2 = fig.add_subplot(gs[1:4,3], sharey=ax1) ;  # histogram on the right
#                 ax3 = fig.add_subplot(gs[0,0:3], sharex=ax1) ;  # histogram on the top
#                 # ax4 = fig.add_subplot(gs[0,3]) if stats else None
#             ############################################################




#             ## Highlighting units of interest
#             if interest!=None:
#                 for value in wanted_array:
#                     ax = ax1 if hist else plt
#                     if type(value) == list:
#                         values_of_interest = [vMI_value for vMI_value in vMI if value[0] <= vMI_value <= value[1]]
#                         for value in values_of_interest:
#                             unit_index = np.where(vMI == value)[0]
#                             ax.scatter(position[unit_index], AllDepth[unit_index], marker='s', s=s*4, edgecolors='black', facecolors='none', linewidths=2) if len(unit_index) > 0 else None
#                     else:
#                         unit_index = np.where(vMI == value)[0]
#                         if len(unit_index) > 0:
#                             ax.scatter(position[unit_index], AllDepth[unit_index], marker='s', s=s*4, edgecolors='black', facecolors='none', linewidths=2)
#             ############################################################




#             ## Plots of data
#             if stats:                    
#                 if hist:
#                     for position, depth, transparence, color,  label in zip(positionData, depthData, [alpha, alpha, alpha/3], ['red', 'blue', 'grey'], ['Excited units', 'Suppressed units', 'Non-significant modulation']):
#                         ax1.scatter(position, depth, s=s, alpha=transparence, color=color, label=label)
#                     ax1.legend()
#                     for histdata, bins, color in zip(histdataDist2, bins2, ['red', 'blue']):
#                         ax2.plot(histdata, bins[:-1], color=color)
#                     for histdata, bins, color in zip(histdataDist3, bins3, ['red', 'blue']):
#                         ax3.plot(bins[:-1], histdata, color=color)
#                     # ax4.bar(modulation_label, modulation_quantity[direction], color=['red', 'blue', 'grey'], alpha=1)
#                 else:
#                     for position, depth, transparence, color,  label in zip(positionData, depthData, [alpha, alpha, alpha/3], ['red', 'blue', 'grey'], ['Excited units', 'Suppressed units', 'Non-significant modulation']):
#                         plt.scatter(position, depth, s=s, alpha=transparence, color=color, label=label)
#                     plt.legend()
#             else:
#                 if hist:
#                     ax1.scatter(positionData, depthData, s=s, alpha=alpha, color=color)
#                     ax2.plot(histdataDist2, bins2[:-1], color=color)
#                     ax3.plot(bins3[:-1], histdataDist3, color=color)
#                 else:
#                     plt.scatter(positionData, depthData, c=vMI, cmap='coolwarm', s=s, alpha=alpha, clim=(minimum_value, maximum_value))                
#                     plt.colorbar(label=f"{direction} modulation index"+r" : $(n_{during} - n_{before})/(n_{during} + n_{before})$") if ((not hist) and (not stats)) else None


#             ## MEP
#             xlabel = f"{pos_title} position (mm)"
#             ylabel = r"Depth ($\mu$m)"
#             histlabel = 'Density (%)' if hist else None
#             # statsylabel = 'Number of units' if stats else None #(ax4)
#             suptitle = f"{direction} modulation in {pos_title} axis"

#             xfontsize = xfontsize
#             yfontsize = yfontsize
#             # statsyfontsize = 12 if stats else None #(ax4)
#             suptitlefontsize = suptitlefontsize
#             ############################################################

#             ## Aesthetics
#             if hist:
#                 ax1.invert_yaxis()
#                 ax1.invert_xaxis() if posOrientation == 'ML_pos' else None
#                 ax1.set_xlabel(xlabel, fontsize=xfontsize)
#                 ax1.set_ylabel(ylabel, fontsize=yfontsize)

#                 ax2.set_xlabel(histlabel, fontsize=xfontsize)
#                 ax2.spines['top'].set_visible(False)
#                 ax2.spines['right'].set_visible(False) 
#                 ax2.spines['bottom'].set_visible(True) 
#                 ax2.spines['left'].set_visible(False)
#                 ax2.tick_params(axis='both', which='both', bottom=True, top=False, left=False, right=False, labelleft=False, labelbottom=True)

#                 ax3.set_ylabel(histlabel, fontsize=yfontsize)
#                 ax3.spines['top'].set_visible(False)
#                 ax3.spines['right'].set_visible(False)
#                 ax3.spines['bottom'].set_visible(False)
#                 ax3.spines['left'].set_visible(True)
#                 ax3.tick_params(axis='both', which='both', bottom=False, top=False, left=True, right=False, labelleft=True, labelbottom=False)

#                 # if stats:
#                 #     ax4.set_ylabel(statsylabel, fontsize=statsyfontsize)
#                 #     ax4.set_xticks(np.arange(len(modulation_label)))
#                 #     ax4.set_xticklabels(modulation_label, rotation=20)
#             else:
#                 plt.gca().invert_yaxis()
#                 plt.gca().invert_xaxis() if posOrientation == 'ML_pos' else None
#                 plt.xlabel(xlabel, fontsize=xfontsize)
#                 plt.ylabel(ylabel, fontsize=yfontsize)
                    
#             plt.suptitle(suptitle, fontsize=suptitlefontsize)
#             ############################################################


#             # Saving and showing
#             if save:
#                 os.makedirs(os.path.dirname(filename), exist_ok=True)
#                 if filename=='':
#                     plt.savefig(os.path.join(saving_path , f"{direction}_modulation_{pos_title}.png"))
#                 else:
#                     plt.savefig(f"{filename}.png")
            
#             plt.show() if show else plt.close()
#             ############################################################

#             foo=dict()
#             for animal in AllData:
#                 foo[animal] = AllData[animal][posOrientation]
#             sorted_keys = [key for key, value in sorted(foo.items(), key=itemgetter(1))] 

#             myTable = PrettyTable(["Animal", pos_title]) 

#             for animal in sorted_keys:
#                 myTable.add_row([animal, foo[animal]])

#             print(myTable)

#         if stats and hist:
#             plotdata = pd.DataFrame({'CW':modulation_quantity['CW'], 'CCW':modulation_quantity['CCW']}, index=modulation_label)
#             plotdata.plot(kind="bar",figsize=(10, 5), color=['red', 'blue'], rot=25)
#             plt.ylabel('Number of units')
#             plt.title('Rotation modulation')
#             if save:
#                 os.makedirs(os.path.dirname(filename), exist_ok=True)
#                 if filename=='':
#                     plt.savefig(os.path.join(saving_path , f"modulation_{pos_title}_barplot.png"))
#                 else:
#                     plt.savefig(f"{filename}.png")
#             if show:
#                 plt.show()
#             else:
#                 plt.close()
#             print('\n')

# vMI_function(rotation_data, save=False, s=50, alpha=0.5, stats=True, hist=True, interest=['max', 'min', 'nul'])

In [None]:
# dirMI_function(rotation_data, stats=True, hist=True, s=50)

# Test

In [None]:
# # Disabled scratch cell: draws a sine curve on the main axes, shows brain.jpeg on the
# # same axes, and repeats the curve on a second axes added at a figure-fraction rectangle.
# x = np.linspace(0, 10, 100)
# y = np.sin(x)




# img = Image.open(os.path.join(saving_path, 'brain.jpeg'))

# fig, ax_main = plt.subplots()

# ax_main.plot(x, y)

# # Rectangle [left, bottom, width, height] passed to fig.add_axes for the second axes
# left, width = 0.5, 0.5
# bottom, height = 0.75, 0.2
# rect_image = [left, bottom, width, height]
# ax_image = fig.add_axes(rect_image)


# # Display the loaded image on the main axes
# ax_main.imshow(img)
# # ax_main.axis('off')
# ax_image.plot(x, y)
# # ax_image.axis('off')  # Turn off the axes for the image


# plt.show()


In [None]:
# # Disabled cell: shows brain.jpeg full-figure and overlays a scatter of CW vMI values
# # on a smaller axes positioned with fig.add_axes.
# img = Image.open(os.path.join(saving_path, 'brain.jpeg'))

# with load_theme("arctic_light"):
#     # Create a new figure
#     fig = plt.figure(figsize=(10, 5))

#     ax_image = fig.add_subplot()
#     ax_image.imshow(img)

#     # Rectangle [left, bottom, width, height] for the overlaid scatter axes
#     left, width = 0.25, 0.5
#     bottom, height = 0.35, 0.2
#     rect_image = [left, bottom, width, height]
#     ax_plot = fig.add_axes(rect_image)


#     # Nclust[animal] Gaussian samples (scale=0.015) centred on each animal's ML_pos / AP_pos,
#     # presumably one jittered coordinate per unit. No random seed is set in this cell,
#     # so the point layout changes on every run.
#     x = np.concatenate([np.random.normal(loc=ML_pos[animal], scale=0.015, size=(Nclust[animal], 1)) for animal in ML_pos])
#     y = np.concatenate([np.random.normal(loc=AP_pos[animal], scale=0.015, size=(Nclust[animal], 1)) for animal in AP_pos])
#     colors = np.concatenate([vMI[animal]['second']['CW'] for animal in vMI])


#     im1 = ax_plot.scatter(x, y, c=colors, cmap='coolwarm', marker='o', alpha=0.6, s=50)
#     ax_plot.set_ylabel('Anteroposterior (mm)')
#     ax_plot.set_xlabel('Mediolateral (mm)')

#     ax_plot.set_title('CW vMI')
#     ax_plot.invert_xaxis()

# # NOTE(review): the colorbar is added after leaving the load_theme block; confirm this is intended.
# fig.colorbar(im1, ax=ax_plot, label='Vestibular Modulation Index')


# plt.show()


In [None]:



# # Disabled cell: two-panel figure with the CW vMI scatter on the left (gs[0,0])
# # and brain.jpeg on the right (gs[0,1]).
# # Load the image with Pillow
# img = Image.open(os.path.join(saving_path, 'brain.jpeg'))

# with load_theme("arctic_light"):
#     # Create a new figure
#     fig = plt.figure(figsize=(10, 5))
#     gs = GridSpec(nrows=1, ncols=2)

#     ax1 = fig.add_subplot(gs[0,1])
#     # Show the image without axes
#     ax1.axis('off')
#     ax1.imshow(img)



#     ax2 = fig.add_subplot(gs[0,0])
#     # Nclust[animal] Gaussian samples (scale=0.015) centred on each animal's ML_pos / AP_pos,
#     # presumably one jittered coordinate per unit. No random seed is set in this cell.
#     x = np.concatenate([np.random.normal(loc=ML_pos[animal], scale=0.015, size=(Nclust[animal], 1)) for animal in ML_pos])
#     y = np.concatenate([np.random.normal(loc=AP_pos[animal], scale=0.015, size=(Nclust[animal], 1)) for animal in AP_pos])
#     colors = np.concatenate([vMI[animal]['second']['CW'] for animal in vMI])


#     # 2D scatter plot (x: mediolateral, y: anteroposterior) coloured by CW vMI
#     # (the original comment said "3D", but the call below is a plain 2D ax2.scatter)
#     im1 = ax2.scatter(x, y, c=colors, cmap='coolwarm', marker='o', alpha=0.6, s=50)

#     fig.colorbar(im1, ax=ax2, label='Vestibular Modulation Index')

#     #plt.gca().invert_xaxis()
#     ax2.set_ylabel('Anteroposterior (mm)')
#     ax2.set_xlabel('Mediolateral (mm)')

#     ax2.set_title('CW vMI')
#     ax2.invert_xaxis()

# # Show the figure
# plt.show()


In [None]:
# # Disabled scratch cell: a dummy line plot in a 6x7 GridSpec layout, plus a PSTH of the
# # first entry of SpikeTimes['animal21_a53d1s1']['second']['CW'] drawn in the top-right slot.
# with load_theme("arctic_light"):
#     fig = plt.figure()

#     gs = GridSpec(nrows=6, ncols=7)

#     ax = fig.add_subplot(gs[1:4,0:4])
#     ax.plot([1, 2, 3], [4, 5, 6])

# # NOTE(review): ax3 is created after leaving the load_theme block, unlike `ax`; confirm this is intended.
# ax3 = fig.add_subplot(gs[0,5:7])
# PSTH(SpikeTimes['animal21_a53d1s1']['second']['CW'][0], xlabel='Time (s)', ylabel='Z-Score', title='PSTH', ax=ax3)
# # Alternative kept for reference: plot the Z-score with its SEM band by hand
# # Zscore, SEM, edges = get_Zscore(SpikeTimes['animal21_a53d1s1']['second']['CW'][0])
# # ax3.plot(edges[:-1], Zscore)
# # ax3.fill_between(edges[:-1], Zscore-SEM, Zscore+SEM, alpha=0.5)

# plt.show()

## vMI and dirMI in another way

In [None]:
# # Disabled cell: for each position axis (ML, AP) and each rotation direction (CW, CCW),
# # scatter vMI (x) against depth (y), coloured by the animal's position on that axis,
# # and save the figure under <saving_path>/Direction_modulation/bis.
# # Accumulators pooled over all animals; emptied again after each figure.
# vMIplot = []
# posplot = []
# AllDepthplot = []

# for pos_title, posOrientation in zip(['Mediolateral', 'Anteroposterior'], ['ML_pos', 'AP_pos']):
#         for condition in ['second']:
#             for direction in ['CW','CCW']:
#                 # Pool the units of every animal; the animal's position value is repeated once per vMI entry
#                 for animal in AllData:
#                     vMIplot.extend(AllData[animal]['Statistics_data']['vMI'][condition][direction])
#                     posplot.extend(AllData[animal]['informative_data'][posOrientation]*np.ones(len(AllData[animal]['Statistics_data']['vMI'][condition][direction])))
#                     AllDepthplot.extend(AllData[animal]['MUA_data']['AllDepth'])
                
#                 with load_theme('arctic_light'):
#                     plt.figure(figsize=(17, 8))

#                     plt.gca().invert_yaxis()
#                     plt.scatter(vMIplot, AllDepthplot, c=posplot, cmap='coolwarm', s=200, alpha=0.3)
#                     plt.axvline(0, color='gray', linestyle='--')

#                     plt.xlabel("Vestibular Modulation Index\n(during - before) / (during + before)")
#                     plt.ylabel("Depth (µm)")
#                     plt.title(f"{direction} modulation in {pos_title} axis")

#                 # NOTE(review): the colorbar is added after leaving the load_theme block; confirm this is intended.
#                 plt.colorbar(label=f"{pos_title} position (mm)")
                
#                 direction_modulation_folder_bis = os.path.join(saving_path, 'Direction_modulation', 'bis')
#                 os.makedirs(direction_modulation_folder_bis, exist_ok=True)
#                 plt.savefig(os.path.join(direction_modulation_folder_bis , f"{direction}_modulation_{pos_title}.png"))

#                 plt.show()
#                 # Reset the accumulators so the next figure starts from empty lists
#                 vMIplot = []
#                 posplot = []
#                 AllDepthplot = []

In [None]:
# # Disabled cell: for each position axis (ML, AP), scatter dirMI (x) against depth (y),
# # coloured by the animal's position on that axis, and save the figure under
# # <saving_path>/Direction_preference/bis.
# # Accumulators pooled over all animals; emptied again after each figure.
# dirMIplot = []
# posplot = []
# AllDepthplot = []
# # NOTE(review): mycm is defined here but not used anywhere in this cell.
# mycm = ListedColormap(['blue', 'lawngreen', 'green', 'orange', 'purple', 'red', 'orchid', 'cyan', 'magenta'])

# for pos_title, posOrientation in zip(['Mediolateral', 'Anteroposterior'], ['ML_pos', 'AP_pos']):
#         for condition in ['second']:
#             # Pool the units of every animal; the animal's position value is repeated once per dirMI entry
#             for animal in AllData:
#                 dirMIplot.extend(AllData[animal]['Statistics_data']['dirMI'][condition])
#                 posplot.extend(AllData[animal]['informative_data'][posOrientation]*np.ones(len(AllData[animal]['Statistics_data']['dirMI'][condition])))
#                 AllDepthplot.extend(AllData[animal]['MUA_data']['AllDepth'])
            
#             plt.figure(figsize=(17, 8))

#             plt.gca().invert_yaxis()
#             plt.scatter(dirMIplot, AllDepthplot, c=posplot, cmap='coolwarm', s=200, alpha=0.3)
#             plt.colorbar(label=f"{pos_title} position (mm)")
#             plt.axvline(0, color='gray', linestyle='--')

#             plt.xlabel("Direction Modulation Index\n(CW - CCW) / (CW + CCW)")
#             plt.ylabel("Depth (µm)")
#             plt.title(f"CW vs CCW preference in {pos_title} axis")

#             direction_preference_folder_bis = os.path.join(saving_path, 'Direction_preference', 'bis')
#             os.makedirs(direction_preference_folder_bis, exist_ok=True)
#             # NOTE(review): `direction` is never assigned in this cell, so the filename below depends on
#             # a value left over from another cell (NameError on a fresh kernel). dirMI has no direction
#             # component here; consider dropping it from the filename before re-enabling this cell.
#             plt.savefig(os.path.join(direction_preference_folder_bis , f"{direction}_preference_{pos_title}.png"))

#             plt.show()
#             # Reset the accumulators so the next figure starts from empty lists
#             dirMIplot = []
#             posplot = []
#             AllDepthplot = []


## Responding / non-responding distribution

In [None]:
# # Disabled cell: pie charts describing photo-tagged units and their response to rotation,
# # pooled over every animal in AllData.
# # Boolean masks concatenated across animals. The selectivity masks hold one entry per unit
# # (neuron in range(Nclust[animal])); phototagged_ones is presumably aligned the same way - TODO confirm.
# phototagged_ones = np.concatenate([modulation[animal]['first'] for animal in AllData])==1
# both_ones = np.concatenate([[modulation[animal]['second'][neuron]['selectivity'] for neuron in range(Nclust[animal])] for animal in AllData])=='both'
# CW_ones = np.concatenate([[modulation[animal]['second'][neuron]['selectivity'] for neuron in range(Nclust[animal])] for animal in AllData])=='CW'
# CCW_ones = np.concatenate([[modulation[animal]['second'][neuron]['selectivity'] for neuron in range(Nclust[animal])] for animal in AllData])=='CCW'
# unmodulated_ones = np.concatenate([[modulation[animal]['second'][neuron]['selectivity'] for neuron in range(Nclust[animal])] for animal in AllData])=='unmodulated'

# # NOTE(review): the loop variable `condition` is never read in the body; 'first' is hard-coded in the mask above.
# for condition in ['first']:
#     # Counts of photo-tagged units, split by their rotation selectivity
#     phototagged = len(np.where(phototagged_ones)[0])
#     photo_respBOTH = len(np.where((phototagged_ones) & (both_ones))[0])
#     photo_respCW = len(np.where((phototagged_ones) & (CW_ones))[0])
#     photo_respCCW = len(np.where((phototagged_ones) & (CCW_ones))[0])
#     photo_resp = photo_respBOTH + photo_respCCW + photo_respCW
#     photo_notresp = len(np.where((phototagged_ones) & (unmodulated_ones))[0])

#     # Every unit that is not photo-tagged (total unit count minus tagged count)
#     non_phototagged = np.sum([Nclust[animal] for animal in Nclust]) - phototagged

#     plt.rcParams.update({
#         "text.usetex": False,
#         # "font.family": "Helvetica"
#     })

#     # Case 1: at least one photo-tagged unit responds -> three pies (tagging, response, selectivity)
#     if photo_resp != 0:
#         plt.figure(figsize=(15,5))
#         plt.subplot(1,3,1)

#         labels = ['Photo-tagged', 'Not photo-tagged']
#         valeurs = [phototagged, non_phototagged]
#         print(f"phototagged or not : {valeurs}")
#         plt.pie(valeurs, labels=["{} ({})".format(label, valeur) for label, valeur in zip(labels, valeurs)], autopct='%1.1f%%', startangle=90, colors=['#0099ff', '#d1d1e0'])
#         plt.title('Photo-tagging')


#         plt.subplot(1,3,2)
#         labels = ['Responding', 'Not responding']
#         valeurs = [photo_resp, photo_notresp]
#         print(f"Phototagged neurons responding or not : {valeurs}")
#         plt.pie(valeurs, labels=["{} ({})".format(label, valeur) for label, valeur in zip(labels, valeurs)], autopct='%1.1f%%', startangle=90, colors=['#0F6EDF', '#FF7762'])
#         plt.title('Response to rotation\nof photo-tagged neurons')


#         plt.subplot(1,3,3)
#         labels = ['CW', 'CCW', 'Both']
#         valeurs = [photo_respCW, photo_respCCW, photo_respBOTH]
#         print(f"Modulation of phototagged neurons : {valeurs}")
#         plt.pie(valeurs, labels=["{} ({})".format(label, valeur) for label, valeur in zip(labels, valeurs)], autopct='%1.1f%%', startangle=90, colors=['#5FA8FF', '#3385E7', '#1F6ECC'])
#         plt.title('Modulation of photo-tagged neurons')

#         plt.tight_layout()

#         # if Saving_boolean:
#         #     os.makedirs(os.path.join(saving_folder, 'Distribution'), exist_ok=True)
#         #     plt.savefig(os.path.join(saving_folder, 'Distribution', f"Distribution_{exp_id}_phototagging.png"))

#         plt.show()
#         print("\n")
#     # Case 2: photo-tagged units exist but none responds -> two pies (tagging, response)
#     elif phototagged != 0:
#         plt.figure(figsize=(15,5))
#         plt.subplot(1,2,1)

#         labels = ['Photo-tagged', 'Not photo-tagged']
#         valeurs = [phototagged, non_phototagged]
#         print(f"phototagged or not : {valeurs}")
#         plt.pie(valeurs, labels=["{} ({})".format(label, valeur) for label, valeur in zip(labels, valeurs)], autopct='%1.1f%%', startangle=90, colors=['#0099ff', '#d1d1e0'])
#         plt.title('Photo-tagging')


#         plt.subplot(1,2,2)
#         labels = ['Responding', 'Not responding']
#         valeurs = [photo_resp, photo_notresp]
#         print(f"Phototagged neurons responding or not : {valeurs}")
#         plt.pie(valeurs, labels=["{} ({})".format(label, valeur) for label, valeur in zip(labels, valeurs)], autopct='%1.1f%%', startangle=90, colors=['#0F6EDF', '#FF7762'])
#         plt.title('Response to rotation\nof photo-tagged neurons')

#         plt.tight_layout()

#         # if Saving_boolean:
#         #     os.makedirs(os.path.join(saving_folder, 'Distribution'), exist_ok=True)
#         #     plt.savefig(os.path.join(saving_folder, 'Distribution', f"Distribution_{exp_id}_phototagging.png"))

#         plt.show()
#         print("\n")
#     # Case 3: no photo-tagged unit at all -> single tagging pie
#     else:
#         plt.figure(figsize=(15,5))

#         labels = ['Photo-tagged', 'Not photo-tagged']
#         valeurs = [phototagged, non_phototagged]
#         print(f"phototagged or not : {valeurs}")
#         plt.pie(valeurs, labels=["{} ({})".format(label, valeur) for label, valeur in zip(labels, valeurs)], autopct='%1.1f%%', startangle=90, colors=['#0099ff', '#d1d1e0'])
#         plt.title('Photo-tagging')

#         # if Saving_boolean:
#         #     os.makedirs(os.path.join(saving_folder, 'Distribution'), exist_ok=True)
#         #     plt.savefig(os.path.join(saving_folder, 'Distribution', f"Distribution_{exp_id}_phototagging.png"))

#         plt.show()
#         print("\n")








# # Disabled block: pie charts of responding vs non-responding units, and of the
# # selectivity (CW / CCW / both) of the responding ones, pooled over every animal in AllData.
# for condition in ['second']:
#     # Number of units whose 'selectivity' entry equals 'CW', 'CCW' or 'both'
#     CWmod = np.count_nonzero(np.concatenate([[modulation[animal][condition][neuron]['selectivity'] == 'CW' for neuron in range(Nclust[animal])] for animal in AllData]))
#     CCWmod = np.count_nonzero(np.concatenate([[modulation[animal][condition][neuron]['selectivity'] == 'CCW' for neuron in range(Nclust[animal])] for animal in AllData]))
#     BOTHmod = np.count_nonzero(np.concatenate([[modulation[animal][condition][neuron]['selectivity'] == 'both' for neuron in range(Nclust[animal])] for animal in AllData]))
#     resp_units = CWmod + CCWmod + BOTHmod
#     # Everything that is not CW / CCW / both counts as non-responding
#     nonresp_units = np.sum([Nclust[animal] for animal in Nclust]) - resp_units


#     labels = ['Responding', 'Not responding']
#     valeurs = [resp_units, nonresp_units]

#     plt.rcParams.update({
#         "text.usetex": False,
#         # "font.family": "Helvetica"
#     })

#     # At least one responding unit -> two pies (response, selectivity); otherwise only the response pie
#     if resp_units != 0:
#         plt.figure(figsize=(15,5))
#         plt.subplot(1, 2, 1)
#         plt.pie(valeurs, labels=["{} ({})".format(label, valeur) for label, valeur in zip(labels, valeurs)], autopct='%1.1f%%', startangle=90, colors=['#0F6EDF', '#FF7762'])
#         plt.title('Response to rotation')



#         labels_selectivite = ['CW', 'CCW', 'Both']
#         valeurs_selectivite = [CWmod, CCWmod, BOTHmod]

#         print("\n")
#         plt.subplot(1, 2, 2)
#         plt.pie(valeurs_selectivite, labels=["{} ({})".format(label, valeur) for label, valeur in zip(labels_selectivite, valeurs_selectivite)], autopct='%1.1f%%', startangle=90, colors=['#5FA8FF', '#3385E7', '#1F6ECC'])
#         # plt.text(1,1, '', ha='center', va='center', fontsize=12, color='red')
#         plt.title('Selectivity of responding neurons')

#         plt.tight_layout()
#     else:
#         plt.figure(figsize=(15,5))
#         plt.pie(valeurs, labels=["{} ({})".format(label, valeur) for label, valeur in zip(labels, valeurs)], autopct='%1.1f%%', startangle=90, colors=['#0F6EDF', '#FF7762'])
#         plt.title('Response to rotation')

#     # plt.suptitle(f"{condition} condition", fontsize=16)

#     # if Saving_boolean:
#     #     os.makedirs(os.path.join(saving_folder, 'Distribution'), exist_ok=True)
#     #     plt.savefig(os.path.join(saving_folder, 'Distribution', f"Distribution_{exp_id}_{condition}_condition.png"))
    
#     plt.show()