In [None]:
import os
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib import colors
import matplotlib
import json
import pandas as pd
import seaborn as sns


def read_json(f):
    """Load and return the JSON document stored at path ``f``.

    The file is opened in a ``with`` block so the handle is always closed,
    even if parsing raises. JSON text is UTF-8 by specification, so the
    encoding is pinned instead of depending on the platform locale.
    """
    with open(f, encoding="utf-8") as fh:
        return json.load(fh)




def gen_plots(path_dir = "/home/*****/work/pid/results/vbn_final/dim_{}/setting_{}/seed_{}/bin{}/results_{}.json",
              bins=[0,1,2,3,4],
              seeds=[11,15,22,23,66],
              setting=3, 
              dim = 20,
              met = "o_inf"):
    """Collect ``out["e"]["simple"][met]`` for change / non-change results.

    For every ``seed`` in ``seeds`` and every ``bin`` in ``bins``, two JSON files
    are read: ``path_dir.format(dim, setting, seed, bin, "change")`` and the same
    path with ``"non_change"``. The scalar at ``["e"]["simple"][met]`` is taken
    from each.

    Parameters
    ----------
    path_dir : str
        Format template with five ``{}`` slots filled, in order, with
        dim, setting, seed, bin and the set name ("change"/"non_change").
    bins, seeds : iterable of int
        Bin indices and seeds to iterate over (neither is mutated).
    setting, dim : int
        Values substituted into ``path_dir``.
    met : str
        Key looked up inside ``out["e"]["simple"]``; also used as the name of
        the value column of the returned DataFrame.

    Returns
    -------
    df : pandas.DataFrame
        Columns ``['bins', 'seed', 'set', met]``, one row per (seed, bin, set).
        Rows are ordered most-recently-read first with index ``0..n-1``, the
        same order the previous row-prepending implementation produced.
    e : dict
        ``{bin: {"change": [...], "non_change": [...]}}``, one value per seed
        in ``seeds`` order.

    Raises
    ------
    OSError, KeyError, json.JSONDecodeError
        Propagated from reading/indexing the result files.
    """
    e = {b: {"change": [], "non_change": []} for b in bins}
    rows = []
    for seed in seeds:
        for bin_idx in bins:
            for set_name in ("change", "non_change"):
                out = read_json(path_dir.format(dim, setting, seed, bin_idx, set_name))
                value = out["e"]["simple"][met]
                e[bin_idx][set_name].append(value)
                rows.append([bin_idx, seed, set_name, value])
    # Build the frame once (O(n)) instead of prepending row by row (O(n^2)).
    # Reversing keeps the newest-row-first order of the old implementation.
    # The value column is named after `met` (it used to be hard-coded 'o_inf').
    df = pd.DataFrame(rows[::-1], columns=['bins', "seed", 'set', met])
    return df, e



def gen_plots_mice(path_dir = "/home/*****/work/pid/results/vbn/dim_{}/setting_{}/seed_{}/bin{}/results_{}.json",
              bins=[0,1,2,3,4],
              seeds=[88,15,77,23,66],
              setting=3, 
              dim = 20,
              met = "o_inf"):
    """Collect the per-element ``out["ses"][met]`` values for change / non-change results.

    For every ``seed`` in ``seeds`` and every ``bin`` in ``bins``, two JSON files
    are read: ``path_dir.format(dim, setting, seed, bin, "change")`` and the same
    path with ``"non_change"``. ``out["ses"][met]`` is expected to be a sequence
    (it is zipped/enumerated). Presumably there is one entry per mouse/session,
    judging by the ``mice`` column name (TODO confirm).

    Parameters
    ----------
    path_dir : str
        Format template with five ``{}`` slots filled, in order, with
        dim, setting, seed, bin and the set name ("change"/"non_change").
    bins, seeds : iterable of int
        Bin indices and seeds to iterate over (neither is mutated).
    setting, dim : int
        Values substituted into ``path_dir``.
    met : str
        Key looked up inside ``out["ses"]``; also the name of the value column.

    Returns
    -------
    df : pandas.DataFrame
        Columns ``['bins', 'seed', 'set', met, 'mice']``. ``mice`` is the
        position within the ``out["ses"][met]`` sequence. Rows are ordered
        most-recently-added first with index ``0..n-1``, the same order the
        previous row-prepending implementation produced.
    e : dict
        ``{bin: {"change": [seq, ...], "non_change": [seq, ...]}}``, one whole
        sequence per seed in ``seeds`` order.

    Raises
    ------
    OSError, KeyError, json.JSONDecodeError
        Propagated from reading/indexing the result files.
    """
    e = {b: {"change": [], "non_change": []} for b in bins}
    rows = []
    for seed in seeds:
        for bin_idx in bins:
            out_change = read_json(path_dir.format(dim, setting, seed, bin_idx, "change"))
            out_non_change = read_json(path_dir.format(dim, setting, seed, bin_idx, "non_change"))

            change = out_change["ses"][met]
            non_change = out_non_change["ses"][met]
            e[bin_idx]["change"].append(change)
            e[bin_idx]["non_change"].append(non_change)
            # NOTE(review): zip stops at the shorter sequence, so if the change and
            # non-change lists differ in length the extra entries are silently
            # left out of `df` (they are still kept in `e`).
            for idx, (mice_change, mice_non_change) in enumerate(zip(change, non_change)):
                rows.append([bin_idx, seed, "change", mice_change, idx])
                rows.append([bin_idx, seed, "non_change", mice_non_change, idx])
    # Build the frame once (O(n)) instead of prepending row by row (O(n^2)).
    # Reversing keeps the newest-row-first order of the old implementation.
    df = pd.DataFrame(rows[::-1], columns=['bins', "seed", 'set', met, 'mice'])
    return df, e

In [None]:
# Box + strip plots of the `o_inf` metric (labelled Omega on the y-axis), averaged
# over seeds for each (set, mice, bin), change vs non-change, for every
# (setting, dim) pair. One PNG per pair is written to plots_vbn/.
AREAS = ["VISp", "VISl", "VISal", "VISrl", "VISam", "VISpm"]
PLOT_DIR = "plots_vbn"
BIN_MS = 50  # bin index -> ms for the x-axis; assumes 50 ms bins (TODO confirm)
os.makedirs(PLOT_DIR, exist_ok=True)  # savefig does not create missing directories

metric = "o_inf"
palette = "GnBu"
hue_order = ['Non-change', 'Change']

for NB in [3, 6]:
    for dim in [10, 25, 50]:
        df, data = gen_plots_mice(dim=dim, setting=NB, met=metric)

        # Average only the metric column. Casting to float first keeps groupby().mean()
        # working even if the loader left object-dtype columns.
        df_plot = (
            df.astype({metric: float})
            .groupby(["set", "mice", "bins"])[metric]
            .mean()
            .reset_index()
        )
        df_plot["bins"] = df_plot["bins"] * BIN_MS
        df_plot["set"] = np.where(df_plot["set"] == "non_change", "Non-change", "Change")

        fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(7, 5))
        sns.boxplot(x="bins", y=metric, hue="set", data=df_plot,
                    palette=palette, hue_order=hue_order,
                    fliersize=0, ax=ax)
        sns.stripplot(x="bins", y=metric, hue="set", data=df_plot,
                      palette=palette, hue_order=hue_order,
                      dodge=True, jitter=True, linewidth=1, ax=ax)

        # Box and strip plots each add one legend entry per hue level; keep one set.
        handles, labels = ax.get_legend_handles_labels()
        ax.legend(handles[:2], labels[:2], title='')

        # Assumes setting N uses the first N areas of AREAS (matches the labels
        # previously hard-coded for N=3 and N=6).
        y_ax_label = " ({})".format(", ".join(AREAS[:NB]))
        ax.set_xlabel('Time after flash (ms)', fontsize=14)
        ax.set_ylabel(r'$\Omega$' + y_ax_label, fontsize=12)
        ax.set_ylim(bottom=0.5)

        fig.savefig(os.path.join(PLOT_DIR, "new_vbn_setting_{}_dim_{}.png".format(NB, dim)),
                    bbox_inches='tight', dpi=300)
        plt.show()      # render inline
        plt.close(fig)  # release the figure so figures don't pile up across the loop


In [None]:
# Box + strip plots of the `s_inf` metric (labelled S on the y-axis), averaged
# over seeds for each (set, mice, bin), change vs non-change, for every
# (setting, dim) pair. One PNG per pair is written to plots_vbn/.
AREAS = ["VISp", "VISl", "VISal", "VISrl", "VISam", "VISpm"]
PLOT_DIR = "plots_vbn"
BIN_MS = 50  # bin index -> ms for the x-axis; assumes 50 ms bins (TODO confirm)
os.makedirs(PLOT_DIR, exist_ok=True)  # savefig does not create missing directories

metric = "s_inf"
palette = "GnBu"
hue_order = ['Non-change', 'Change']

for NB in [3, 6]:
    for dim in [10, 25, 50]:
        df, data = gen_plots_mice(dim=dim, setting=NB, met=metric)

        # Average only the metric column. Casting to float first keeps groupby().mean()
        # working even if the loader left object-dtype columns.
        df_plot = (
            df.astype({metric: float})
            .groupby(["set", "mice", "bins"])[metric]
            .mean()
            .reset_index()
        )
        df_plot["bins"] = df_plot["bins"] * BIN_MS
        df_plot["set"] = np.where(df_plot["set"] == "non_change", "Non-change", "Change")

        fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(7, 5))
        sns.boxplot(x="bins", y=metric, hue="set", data=df_plot,
                    palette=palette, hue_order=hue_order,
                    fliersize=0, ax=ax)
        sns.stripplot(x="bins", y=metric, hue="set", data=df_plot,
                      palette=palette, hue_order=hue_order,
                      dodge=True, jitter=True, linewidth=1, ax=ax)

        # Box and strip plots each add one legend entry per hue level; keep one set.
        handles, labels = ax.get_legend_handles_labels()
        ax.legend(handles[:2], labels[:2], title='')

        # Assumes setting N uses the first N areas of AREAS (matches the labels
        # previously hard-coded for N=3 and N=6).
        y_ax_label = " ({})".format(", ".join(AREAS[:NB]))
        ax.set_xlabel('Time after flash (ms)', fontsize=14)
        ax.set_ylabel('S' + y_ax_label, fontsize=12)
        ax.set_ylim(bottom=0.5)

        fig.savefig(os.path.join(PLOT_DIR, "s_inf_vbn_setting_{}_dim_{}.png".format(NB, dim)),
                    bbox_inches='tight', dpi=300)
        plt.show()      # render inline
        plt.close(fig)  # release the figure so figures don't pile up across the loop
