In [None]:
""" Load data across all days, which have been exgtracted for stepwise (i.e. in the microstim analyses in drawmonkey)
And plot summaries.
"""

### Stepwise - load across days

In [11]:
import glob
import pandas as pd
import numpy as np
import glob

def _convert_mec_to_bregion(mec, date):
    
    regions = map_date_bregion[date]
        
    off = mec.find("off")>-1
    ttl3 = mec.find("TTL3")>-1    
    ttl4 = mec.find("TTL4")>-1
    
    assert sum([off, ttl3, ttl4])==1
    
    if off:
        return "off"
    elif ttl3:
        return regions[0]
    elif ttl4:
        return regions[1]
    else:
        print(off, ttl3, ttl4)
        assert False   
    
    
def _extract_epochs_from_path(path, s):
    """
    e.g., 
    s = "epochsetchar"
    path = "/mnt/Freiwald_kgupta/kgupta/analyses/main/stepwise/Pancho_231201_gramdirstimpancho3b/actions_and_trials-epochsetchar-('AnBmCk2|TTL3-fgon', 'AnBmCk2|off', 'L|TTL3-fgon', 'L|off')"
    returns ['AnBmCk2|TTL3-fgon', 'AnBmCk2|off', 'L|TTL3-fgon', 'L|off']
    """
    from pythonlib.tools.stringtools import decompose_string
    ind = path.find(f"{s}-")
    s_sub = path[ind+len(s)+1:-1]
    epochs = decompose_string(s_sub, ", ")
    return sorted([e[2:-1] for e in epochs])

    
SDIR = "/mnt/Freiwald_kgupta/kgupta/analyses/main/stepwise"
ANIMAL = "Diego"

# To load less data, restrict to a specific bregion
# bregion = "vlPFC"
bregion = None # get all

if ANIMAL=="Pancho":
    dates = [231011, 231013, 231018, 231023, 231024, 231026, 231101, 231102, 231103, 231106, 231109, 231110, 231118, 231120, 
             231121, 231122, 231128, 231129, 231201]
    
    # for each date, note which brain regions.
    map_date_bregion = {
        231011: ("M1", "preSMA"),
        231013: ("M1", "preSMA"),
        231018: ("dlPFC", "vlPFC"),
        231023: ("FP", "SMA"),
        231024: ("PMv", "PMd"),
        231026: ("M1", "preSMA"),
        231101: ("M1", "preSMA"),
        231102: ("vlPFC", "preSMA"),
        231103: ("dlPFC", "PMd_a"),
        231106: ("preSMA", "M1"),
        231109: ("PMv", "SMA"),
        231110: ("PMd_p", "FP"),
        231118: ("dlPFC",),
        231120: ("vlPFC",),
        231121: ("preSMA",),
        231122: ("PMd_a",),
        231128: ("FP",),
        231129: ("FP",),
        231201: ("dlPFC",),
    }
    
    DATES_TO_SKIP = [
        231011, # AnBm too easy
        231026 # bad beahvior
    ]
elif ANIMAL=="Diego":
    dates = [231101, 231102, 231103, 231106, 231107, 231108, 231109, 231110, 231113, 231114, 231115]
    
    # for each date, note which brain regions.
    map_date_bregion = {
        231101: ("M1", "preSMA"),
        231102: ("vlPFC", "preSMA"),
        231103: ("dlPFC", "PMd"),
        231106: ("preSMA",),
        231107: ("dlPFC", "PMd_a"), 
        231108: ("vlPFC",),
        231109: ("preSMA",),
        231110: ("dlPFC", "vlPFC"),
        231113: ("PMd_a", "PMv"),
        231114: ("SMA", "FP"),
        231115: ("M1-PMd_p", "preSMA"),
    }
    
    DATES_TO_SKIP = []
else:
    assert False
for v in map_date_bregion.values():
    assert isinstance(v, tuple)
    
def load_concat_all_days(WHICH_DATA, trials_or_actions="actions", dates_to_skip=None):
    dict_df = {}
    dict_params = {}
    dict_list_epochs = {}
        
    if dates_to_skip is None:
        dates_to_skip = DATES_TO_SKIP
        
    for d in dates:
        
        if d in dates_to_skip:
            continue
            
        if WHICH_DATA=="alldata":
            path = f"{SDIR}/{ANIMAL}_{d}_*/actions_and_trials-{WHICH_DATA}*"
        else:        
            path = f"{SDIR}/{ANIMAL}_{d}_*/actions_and_trials-{WHICH_DATA}-*"
        import glob
        this = glob.glob(path)
        
        # Take the longer one - it prob has more epochs...
        indtmp = np.argmax([len(t) for t in this])
        sdir = this[indtmp]
        
        list_epoch = _extract_epochs_from_path(sdir, WHICH_DATA)
        
        print(sdir)
        print("EPOCHS: ", list_epoch)
    
        # df_dat = pd.read_pickle(f"{sdir}/df_dat.pkl")    
        if trials_or_actions == "actions":
            df_actions = pd.read_pickle(f"{sdir}/df_actions.pkl")
        elif trials_or_actions == "trials":
            df_actions = pd.read_pickle(f"{sdir}/df_dat.pkl")
        else:
            print(trials_or_actions)
            assert False
        params = pd.read_pickle(f"{sdir}/Params.pkl")    
        # df_dat["date"] = d
        df_actions["date"] = d
            
        # Figure out the bregions for each epoch code, and append as new
        # column in df_actions
        from pythonlib.tools.pandastools import applyFunctionToAllRows    
        def F(x):
            mec = x["microstim_epoch_code"]
            date = x["date"]
            return _convert_mec_to_bregion(mec, date)
        df_actions = applyFunctionToAllRows(df_actions, F, "microstim_bregion")    
        # df_dat = applyFunctionToAllRows(df_dat, F, "microstim_bregion")    
        
        # count how many rules today
        epochs_today = sorted(df_actions["epoch_orig"].unique())
        # nepochs = len(df_actions["epoch_orig"].unique())
        epochs_today = "__".join([str(e) for e in epochs_today])
        df_actions["epochs_today"] =epochs_today
                
        if bregion is None or bregion in df_actions["microstim_bregion"].tolist():
            print("** Keeping:", path)
            
            if "choice_code" in df_actions.columns:
                # convert choice code into a string
                def F(x):
                    return "".join([str(int(c)) for c in x["choice_code"]])
                df_actions = applyFunctionToAllRows(df_actions, F, "choice_code_str")
                
            dict_df[d] = df_actions
            dict_params[d] = params
            dict_list_epochs[d] = list_epoch
            
            ## SANITY CHECKS
            # check that one to one correpsonsdence between epochs in filename and in data
            # this is mainly a sanity check of code. it must be true.
            if not WHICH_DATA=="alldata":
                list_epoch_data = sorted(df_actions["epoch"].unique().tolist())
                assert list_epoch_data==list_epoch
        else:
            assert False
            
    # Concat
    df_actions_all = pd.concat(dict_df.values()).reset_index(drop=True)
    
    from pythonlib.tools.pandastools import applyFunctionToAllRows
    def F(x):
        if x["microstim_epoch_code"]=="off":
            return "off"
        else:
            return "on"
    df_actions_all = applyFunctionToAllRows(df_actions_all, F, "microstim_status")
    
    # Duplicate "off" trials for each day, one for each bregion that is stimmed that day.
    list_df = []
    list_date = sorted(df_actions_all["date"].unique().tolist())
    for date in list_date:
        dfthis = df_actions_all[df_actions_all["date"]==date]
        regions = [r for r in dfthis["microstim_bregion"].unique().tolist() if not r=="off"]
        for r in regions:
            # Get a copy of the stim "off" trials for this (day, region)
            dftmp = dfthis[dfthis["microstim_bregion"]=="off"].copy()
            dftmp["bregion_expt"] = r
            list_df.append(dftmp)
            
            # Same, but stim on:
            dftmp = dfthis[dfthis["microstim_bregion"]==r].copy()
            dftmp["bregion_expt"] = r
            list_df.append(dftmp)
    df_actions_all = pd.concat(list_df).reset_index(drop=True)
    
    return df_actions_all, dict_df, dict_params, dict_list_epochs

    

In [None]:
trials_or_actions = "actions"
for WHICH_DATA in ["alldata", "epochsetchar"]:
    # WHICH_DATA = "alldata"
    # WHICH_DATA = "epochsetchar"
    # s = "alldata"
    
    
    # s = WHICH_DATA
    # s = "epochsetchar"
    # s = "alldata"
    
    df_actions_all, dict_df, dict_params, dict_list_epochs = load_concat_all_days(WHICH_DATA, trials_or_actions=trials_or_actions)
    
    # Lumping more together here
    map_codegrp_to_code = {
        "dir": ("1100", "0100"),
        "random":("000", "0000"),
        "sameshape":("1001", "101", "0001", "001"), # incorrect                        
        "shapes":("011", "0011", "1011", "111"),
        "close":("100", "1000"),
        "others":("1111","0111")
    } 
    map_code_to_codegrp = {}
    for k, v in map_codegrp_to_code.items():
        assert isinstance(v, tuple)
        for vv in v:
            assert vv not in map_code_to_codegrp.keys()
            map_code_to_codegrp[vv] = k
            
    # make sure you got all codes
    for c in df_actions_all["choice_code_str"].unique():
        print(c, " --> ", map_code_to_codegrp[c])
        assert c in map_code_to_codegrp
        
    # append column
    from pythonlib.tools.pandastools import applyFunctionToAllRows
    def F(x):
        return map_code_to_codegrp[x["choice_code_str"]]
    df_actions_all = applyFunctionToAllRows(df_actions_all, F, "choice_code_grp")
    
    
    ##### Plot summary across all days
    ##### Same, but compare bregions in a single plot
            
    
    if False:
        from pythonlib.tools.pandastools import grouping_print_n_samples
        grouping_print_n_samples(df_actions_all, ["date", "microstim_bregion"]);
        from pythonlib.tools.pandastools import grouping_print_n_samples
        grouping_print_n_samples(df_actions_all_duploff, ["date", "microstim_bregion", "bregion_expt"]);
    # Give a generic label "off"/"on" for microstim.

    # Expand binary variable
    from pythonlib.tools.pandastools import expand_categorical_variable_to_binary_variables, plot_45scatter_means_flexible_grouping
    from pythonlib.tools.plottools import savefig
    import matplotlib.pyplot as plt
    
    
    for idx in [0,1]:
        if idx==0:
            # sbuplot is region
            VAR_FIGURE = "bregion_expt"
            VAR_SUBPLOT = "code"
        elif idx==1:
            # subplot is code./
            VAR_FIGURE = "code"
            VAR_SUBPLOT = "bregion_expt"
        else:
            assert False
        
        # GEneral params
        SDIR_SAVE = f"/gorilla1/analyses/main/stepwise/MULT_DATES/{ANIMAL}/{WHICH_DATA}/group_by_{VAR_FIGURE}"
        var_value = "value"
        var_manip = "microstim_status"
        x_lev_manip = "off"
        y_lev_manip = "on"
        
        for var in ["choice_code_str", "beh_label_code", "choice_code_grp"]:
        # for var in ["choice_code_grp"]:
            if VAR_SUBPLOT == "code":
                var_subplot = var
            else:
                var_subplot = VAR_SUBPLOT
            
            if VAR_FIGURE == "code":
                var_figure = var
                # sharex=sharey=True
                shareaxes = True
            else:
                var_figure = VAR_FIGURE
                # sharex=sharey=False
                shareaxes=False
            sdir = f"{SDIR_SAVE}/{var}"
            import os
            os.makedirs(sdir, exist_ok=True)
            
            df_actions_expanded = expand_categorical_variable_to_binary_variables(df_actions_all_duploff, var)
            
            ##### 45deg scatter plots - good!
            # list_bregion = sorted(df_actions_expanded["microstim_bregion"].unique().tolist())
            list_var_figure = sorted(df_actions_expanded[var_figure].unique().tolist())
            # list_var_figure = [b for b in list_bregion if not b==x_lev_manip]
            list_epoch_orig = sorted(df_actions_expanded["epoch_orig"].unique().tolist())
            list_code = sorted(df_actions_all_duploff[var].unique().tolist())
            for bregion in list_var_figure:
                for epoch_orig in list_epoch_orig: 
                    print("This figure: ", bregion, epoch_orig)
                    dfthis = df_actions_expanded[(df_actions_expanded[var_figure]==bregion) & (df_actions_expanded["epoch_orig"]==epoch_orig)].reset_index(drop=True)
                    
                    if x_lev_manip in dfthis[var_manip].tolist() and y_lev_manip in dfthis[var_manip].tolist():
                        dfres, fig = plot_45scatter_means_flexible_grouping(dfthis, var_manip, x_lev_manip, y_lev_manip,
                                                                   var_subplot, var_value, "date", shareaxes=shareaxes)
                        path = f"{sdir}/scatter45-{bregion}-{epoch_orig}.pdf"
                        print(path)
                        savefig(fig, path)
                        plt.close("all")

##### Search for example trials

##### each trial is a single datapt - correct vs failure.

In [23]:
if ANIMAL=="Diego":
    dates_exist = sorted(df_trials_all["date"].unique())
    print(dates_exist)
    # for diego trials, only keep those that are 11/9/23 or after (i.e., those with harder tasks).
    dates_to_skip = list(range(dates_exist[0], 231109))
    print(dates_to_skip)
else:
    dates_to_skip = None

[231101, 231102, 231103, 231106, 231107, 231108, 231109, 231110, 231113, 231114, 231115]
[231101, 231102, 231103, 231104, 231105, 231106, 231107, 231108]


In [24]:
from pythonlib.tools.pandastools import plot_45scatter_means_flexible_grouping
from pythonlib.tools.plottools import  savefig

import matplotlib.pyplot as plt
for WHICH_DATA in ["alldata", "epochsetchar"]:
    df_trials_all, dict_df, dict_params, dict_list_epochs = load_concat_all_days(WHICH_DATA, 
                                                                                 trials_or_actions="trials", 
                                                                                 dates_to_skip=dates_to_skip)

    # Preprocessing
    # - Remove aborts
    df_trials_all = df_trials_all[df_trials_all["exclude_because_online_abort"]==False].reset_index(drop=True)
    df_trials_all["microstim_status"].value_counts()
    
    # each trial is a single datapt - correct vs failure.
    var_manip = "microstim_status"
    x_lev_manip = "off"
    y_lev_manip = "on"
    var_subplot = "bregion_expt"
    var_value = "success_binary_quick"
    var_datapt = "date"
    shareaxes = True
    
    list_epoch_orig = df_trials_all["epoch_orig"].unique().tolist()
    list_epochs_today = df_trials_all["epochs_today"].unique().tolist()
    
    SDIR_SAVE = f"/gorilla1/analyses/main/stepwise/MULT_DATES/{ANIMAL}/{WHICH_DATA}/TRIAL_LEVEL"
    import os
    os.makedirs(SDIR_SAVE, exist_ok=True)
    
    for epoch_orig in list_epoch_orig:
        for epochs_today in list_epochs_today:     
            dfthis = df_trials_all[(df_trials_all["epoch_orig"]==epoch_orig) & (df_trials_all["epochs_today"]==epochs_today)]
            if len(dfthis)>0:
                dfres, fig = plot_45scatter_means_flexible_grouping(dfthis, var_manip, x_lev_manip, y_lev_manip,
                                                       var_subplot, var_value, var_datapt, shareaxes=shareaxes)
                path = f"{SDIR_SAVE}/epochstoday_{epochs_today}-scatter45-{epoch_orig}.pdf"
                savefig(fig, path)
                plt.close("all")
        
        # Also combine all
        dfthis = df_trials_all[(df_trials_all["epoch_orig"]==epoch_orig)]
        if len(dfthis)>0:
            dfres, fig = plot_45scatter_means_flexible_grouping(dfthis, var_manip, x_lev_manip, y_lev_manip,
                                                   var_subplot, var_value, var_datapt, shareaxes=shareaxes)
            path = f"{SDIR_SAVE}/epochstoday_ALL-scatter45-{epoch_orig}.pdf"
            savefig(fig, path)
            plt.close("all")



/mnt/Freiwald_kgupta/kgupta/analyses/main/stepwise/Diego_231109_gramstimdiego2f/actions_and_trials-alldata
EPOCHS:  []
** Keeping: /mnt/Freiwald_kgupta/kgupta/analyses/main/stepwise/Diego_231109_*/actions_and_trials-alldata*
/mnt/Freiwald_kgupta/kgupta/analyses/main/stepwise/Diego_231110_gramstimdiego2f/actions_and_trials-alldata
EPOCHS:  []
** Keeping: /mnt/Freiwald_kgupta/kgupta/analyses/main/stepwise/Diego_231110_*/actions_and_trials-alldata*
/mnt/Freiwald_kgupta/kgupta/analyses/main/stepwise/Diego_231113_gramstimdiego2f/actions_and_trials-alldata
EPOCHS:  []
** Keeping: /mnt/Freiwald_kgupta/kgupta/analyses/main/stepwise/Diego_231113_*/actions_and_trials-alldata*
/mnt/Freiwald_kgupta/kgupta/analyses/main/stepwise/Diego_231114_gramstimdiego2f/actions_and_trials-alldata
EPOCHS:  []
** Keeping: /mnt/Freiwald_kgupta/kgupta/analyses/main/stepwise/Diego_231114_*/actions_and_trials-alldata*
/mnt/Freiwald_kgupta/kgupta/analyses/main/stepwise/Diego_231115_gramstimdiego2f/actions_and_trials-a

In [None]:
,# Pick out a specific 


In [None]:
# 1. exclude bad dates.
# 2. group across single/dual rule days (e.g., 000, 0000)
# 3. Bar plot across brain areas --> semantic label.