In [None]:
from datetime import datetime, timedelta
from typing import Union

import pandas as pd
import numpy as np
from scipy.stats import entropy
import sys
import matplotlib.pyplot as plt
import matplotlib.collections as collections
import matplotlib.markers as markers
import os
import csv

# debug = True
debug = False


class task_data:
    def __init__(self, mice: list, tasks, logpath):
        global debug
        self.data_file = ""
        self.data = None
        self.data_ci = None
        self.delta = None
        self.mouse_no = mice
        self.tasks = tasks
        self.pattern_prob = {}
        self.probability = None
        self.mice_task = None
        self.task_prob = {}
        self.mice_delta = {}
        self.entropy_analyze = None
        self.mice_entropy = None
        self.logpath = logpath
        self.session_id = 0
        self.burst_id = 0
        self.data_not_omission = None
        self.fig_prob_tmp = None
        self.fig_prob = {}
        self.bit = 4

        print('reading data...', end='')

        # TODO debug
        # 1.self.data, 2.probability, 3.task_prob, 4.self.delta, 5.self.fig_prob_tmp, 6.pattern, 7.self.entropy_analyze
        # to:dict, add:dict[task]
        # to:dict, add:dict(fig) or DF
        def append_dataframe(to: Union[pd.DataFrame, dict, None], add: Union[pd.DataFrame, dict, None], mouse_id: int,
                             task=None, fig_num=None):
            if isinstance(add, dict):

                # 二回目の場合Fig確定
                if not isinstance(task, type(None)):
                    # print(add)
                    ret_val = {}
                    [ret_val.update(append_dataframe(to.get(task, {}), add[fig], mouse_id, task=task, fig_num=fig)) for
                     fig in ["fig1", "fig2", "fig3"]]
                    return {task: ret_val}
                # taskごと
                # to;dict, add:dict[task]
                for add_task, add_dict in add.items():
                    # append_dataframe(to, append_dataframe(to, add_dict, mouse_id, task=add_task), mouse_id)
                    to.update(append_dataframe(to, add_dict, mouse_id, task=add_task))
                return to
            if isinstance(to, dict):
                # Fig二回目入力
                if not isinstance(fig_num, type(None)):
                    return {fig_num: append_dataframe(to.get(fig_num, None), add, mouse_id)}
                return {
                    task: append_dataframe(to.get(task, None), add, mouse_id)}
            if isinstance(to, type(None)):
                return add.assign(mouse_id=mouse_id)
            else:
                return to.append(add.assign(mouse_id=mouse_id), ignore_index=True)

        if debug:
            for mouse_id in self.mouse_no:
                print('mouse_id={}'.format(mouse_id))
                self.mice_task[mouse_id], self.probability[mouse_id], self.task_prob[mouse_id], self.mice_delta[
                    mouse_id], self.fig_prob[mouse_id] = self.dev_read_data(mouse_id)
                # tmp = self.dev_read_data(mouse_id)
                # self.mice_task = append_dataframe(self.mice_task, tmp[0], mouse_id)
                # self.probability = append_dataframe(self.probability, tmp[1], mouse_id)
                # self.task_prob = append_dataframe(self.task_prob, tmp[2], mouse_id)
                # self.mice_delta = append_dataframe(self.mice_delta, tmp[3], mouse_id)
                # # append_dataframe(self.fig_prob, tmp[4], mouse_id)
                # self.fig_prob[mouse_id] = self.fig_prob[mouse_id].append(tmp[4])
                # self.pattern_prob = append_dataframe(self.pattern_prob, tmp[5], mouse_id)
                # TODO entropy_analyze
        else:
            # all 0: 2.probability
            # 3.task_prob, 4.self.delta, 5.self.fig_prob_tmp, 6.pattern
            for mouse_id in self.mouse_no:
                try:
                    print('mouse_id={}'.format(mouse_id))
                    # self.data_file = "{}no{:03d}_action.csv".format(self.logpath, mouse_id)
                    # self.mice_task[mouse_id], self.probability[mouse_id], self.task_prob[mouse_id], self.mice_delta[
                    #     mouse_id], self.fig_prob[mouse_id], self.pattern_prob[mouse_id] = self.read_data()
                    self.data_file = os.path.join(self.logpath, "no{:03d}_action.csv".format(mouse_id))
                    tmp = self.read_data()
                    self.mice_task = append_dataframe(self.mice_task, tmp[0], mouse_id)
                    # 0
#                     self.probability = append_dataframe(self.probability, tmp[1], mouse_id)
                    self.task_prob = append_dataframe(self.task_prob, tmp[2], mouse_id)
#                     self.mice_delta = append_dataframe(self.mice_delta, tmp[3], mouse_id)
                    # 単体
#                     self.fig_prob = append_dataframe(self.fig_prob, tmp[4], mouse_id)
#                     self.pattern_prob = append_dataframe(self.pattern_prob, tmp[5], mouse_id)
#                     self.mice_entropy = append_dataframe(self.mice_entropy, tmp[6], mouse_id)
                except Exception as e:
                    print("error! no {}".format(mouse_id))
                    print(e)
                    continue
            self.export_csv()
        print('done')

    def read_data(self):

        def rehash_session_id():
            data = pd.read_csv(self.data_file, names=header, parse_dates=[0], dtype={'hole_no': 'str'})
            self.session_id = 0
            print("max_id_col:{}".format(len(data)))

            def remove_terminate(index):
                if data.at[index, "event_type"] == data.at[index + 1, "event_type"] and data.at[
                    index, "event_type"] == "start":
                    data.drop(index, inplace=True)

            def rehash(x_index):
                start_task = data.head(1).task.values[0]
                if data.at[data.index[x_index], "task"] == start_task:
                    if (x_index == 0 or data.shift(1).at[data.index[x_index], "event_type"] == "start") and \
                            len(data[:x_index][data.session_id == 0]) == 0:
                        self.session_id = 0
                        return 0
                    self.session_id = self.session_id + 1
                    return self.session_id
                if data.at[data.index[x_index], "event_type"] == "start":
                    self.session_id = self.session_id + 1
                    return self.session_id
                else:
                    return self.session_id

            list(map(remove_terminate, data.index[:-1]))
            data.reset_index(drop=True, inplace=True)
            data["session_id"] = list(map(rehash, data.index))
            data = data[
                data.session_id.isin(data.session_id[data.event_type.isin(["reward", "failure", "time over"])])]
            data.reset_index(drop=True, inplace=True)
            self.session_id = 0
            data["session_id"] = list(map(rehash, data.index))
            print("{} ; {} done".format(datetime.now(), sys._getframe().f_code.co_name))
            return data

        def add_timedelta():
            data = self.data
            data = data[data.session_id.isin(data[data.event_type.isin(['reward', 'failure'])]["session_id"])]
            deltas = {}
            for task in self.tasks:
                def calculate(session):
                    delta_df = pd.DataFrame()
                    # reaction time
                    current_target = data[data.session_id.isin([session])]
                    if bool(sum(current_target["event_type"].isin(["task called"]))):
                        task_call = current_target[current_target["event_type"] == "task called"]
                        task_end = current_target[current_target["event_type"].isin(["nose poke", "failure"])]
                        reaction_time = task_end.at[task_end.index[0], "timestamps"] - task_call.at[
                            task_call.index[0], "timestamps"]
                        # 連続無報酬期間
                        previous_reward = data[
                            (data["event_type"] == "reward") & (
                                    data["timestamps"] < task_call.at[task_call.index[0], "timestamps"])].tail(1)
                        norewarded_time = task_call.at[task_call.index[0], "timestamps"] - previous_reward.at[
                            previous_reward.index[0], "timestamps"]
                        correct_failure = "correct" if bool(
                            sum(current_target["event_type"].isin(["reward"]))) else "failure"
                        # df 追加
                        delta_df = delta_df.append(
                            {'session_id': session,
                             'type': 'reaction_time',
                             'noreward_duration_sec': pd.to_timedelta(norewarded_time) / np.timedelta64(1, 's'),
                             'reaction_time_sec': pd.to_timedelta(reaction_time) / np.timedelta64(1, 's'),
                             'correct_failure': correct_failure},
                            ignore_index=True)
                    # reward latency
                    if bool(sum(current_target["event_type"].isin(["reward"]))) and bool(
                            sum(current_target["event_type"].isin(["task called"]))):
                        nose_poke = current_target[current_target["event_type"] == "nose poke"]
                        reward_latency = current_target[current_target["event_type"] == "magazine nose poked"]
                        reward_latency = reward_latency.at[reward_latency.index[0], "timestamps"] - \
                                         nose_poke.at[nose_poke.index[0], "timestamps"]
                        previous_reward = data[
                            (data["event_type"] == "reward") & (
                                    data["timestamps"] < nose_poke.at[nose_poke.index[0], "timestamps"])].tail(1)
                        norewarded_time = nose_poke.at[nose_poke.index[0], "timestamps"] - previous_reward.at[
                            previous_reward.index[0], "timestamps"]
                        delta_df = delta_df.append(
                            {'session_id': session,
                             'type': 'reward_latency',
                             'noreward_duration_sec': pd.to_timedelta(norewarded_time) / np.timedelta64(1, 's'),
                             'reward_latency_sec': pd.to_timedelta(reward_latency) / np.timedelta64(1, 's')
                             }, ignore_index=True)
                    return delta_df

                delta_df = data[data.task == task].session_id.drop_duplicates().map(calculate)
                deltas[task] = pd.concat(list(delta_df), sort=False) if len(delta_df) else delta_df
            print("{} ; {} done".format(datetime.now(), sys._getframe().f_code.co_name))
            return deltas

        def add_hot_vector():
            #    data = data[data["event_type"].isin(["reward", "failure", "time over"])]
            data = self.data
            # data = data[data[".seevent_type"].isin(["reward", "failure", "time over"])]
            # data = data[data["task"].isin(self.tasks)]

            data = data.reset_index(drop=True)
            # task interval
            task_start_index = [0]
            for i in range(1, len(data)):
                if not data["task"][i] == data["task"][i - 1]:
                    task_start_index.append(i)

            data["hole_correct"] = -1
            data["hole_failure"] = -1
            data["is_correct"] = -1
            data["is_failure"] = -1
            data["is_omission"] = -1
            data["cumsum_correct"] = -1
            data["cumsum_failure"] = -1
            data["cumsum_omission"] = -1
            data["cumsum_correct_taskreset"] = -1
            data["cumsum_failure_taskreset"] = -1
            data["cumsum_omission_taskreset"] = -1
            for hole_no in range(1, 9 + 1, 2):
                data["is_hole{}".format(str(hole_no))] = -1

            # print(data)

            # hole_information
            # warning mettyaderu zone
            # SettingWithCopyWarning
            data.loc[data["event_type"].isin(["reward"]), 'hole_correct'] = data["hole_no"]
            data.loc[~data["event_type"].isin(["reward"]), 'hole_correct'] = np.nan
            data.loc[data["event_type"].isin(["failure"]), 'hole_failure'] = data["hole_no"]
            data.loc[~data["event_type"].isin(["failure"]), 'hole_failure'] = np.nan

            data.loc[data["event_type"].isin(["reward"]), 'is_correct'] = 1
            data.loc[~data["event_type"].isin(["reward"]), 'is_correct'] = np.nan
            data.loc[data["event_type"].isin(["failure"]), 'is_failure'] = 1
            data.loc[~data["event_type"].isin(["failure"]), 'is_failure'] = np.nan
            data.loc[data["event_type"].isin(["time over"]), 'is_omission'] = 1
            data.loc[~data["event_type"].isin(["time over"]), 'is_omission'] = np.nan

            data["cumsum_correct"] = data["is_correct"].cumsum(axis=0)
            data["cumsum_failure"] = data["is_failure"].cumsum(axis=0)
            data["cumsum_omission"] = data["is_omission"].cumsum(axis=0)

            for hole_no in range(1, 9 + 1, 2):
                data.loc[data['hole_no'] == str(hole_no), "is_hole{}".format(hole_no)] = 1
                data.loc[~(data['hole_no'] == str(hole_no)), "is_hole{}".format(hole_no)] = None

            # cumsum
            # for i in range(0, len(task_start_index)):
            #     index_start = task_start_index[i]
            #     index_end = len(data)
            #     if i < len(task_start_index) - 1:
            #         index_end = task_start_index[i + 1]
            #     pre_correct = data["cumsum_correct"][index_start] if not i == 0 else 0
            #     pre_incorrect = data["cumsum_incorrect"][index_start] if not i == 0 else 0
            #     pre_omission = data["cumsum_omission"][index_start] if not i == 0 else 0
            #     # warning mettyaderu zone
            #     data["cumsum_correct_taskreset"][index_start:index_end] = data["cumsum_correct"][
            #                                                               index_start:index_end] - \
            #                                                               pre_correct
            #     data["cumsum_incorrect_taskreset"][index_start:index_end] = data["cumsum_incorrect"][
            #                                                                 index_start:index_end] - \
            #                                                                 pre_incorrect
            #     data["cumsum_omission_taskreset"][index_start:index_end] = data["cumsum_omission"][
            #                                                                index_start:index_end] - \
            #                                                                pre_omission
            def add_cumsum():
                data["cumsum_correct_taskreset"] = data["is_correct"].fillna(0)
                data["cumsum_failure_taskreset"] = data["is_failure"].fillna(0)
                data["cumsum_omission_taskreset"] = data["is_omission"].fillna(0)
                data["cumsum_correct_taskreset"] = data.groupby("task")["cumsum_correct_taskreset"].cumsum()
                data["cumsum_failure_taskreset"] = data.groupby("task")["cumsum_failure_taskreset"].cumsum()
                data["cumsum_omission_taskreset"] = data.groupby("task")["cumsum_omission_taskreset"].cumsum()

#             add_cumsum()

            # burst
            # data["burst_group"] = 1
            # for i in range(1, len(data)):
            #     if data["timestamps"][i] - data["timestamps"][i - 1] <= datetime.timedelta(seconds=60):
            #         data["burst_group"][i] = data["burst_group"][i - 1]
            #         continue
            #     data["burst_group"][i] = data["burst_group"][i - 1] + 1
            print("{} ; {} done".format(datetime.now(), sys._getframe().f_code.co_name))
            return data

        def calc_entropy(section=150):
            data = self.data[self.data.event_type.isin(['reward', 'failure'])]

            def min_max(x, axis=None):
                np.array(x)
                min = np.array(x).min(axis=axis)
                max = np.array(x).max(axis=axis)
                result = (x - min) / (max - min)
                return result

            # entropy
            ent = [np.nan] * section
            for i in range(0, len(data) - section):
                denominator = float(section)
                # sum([data["is_hole{}".format(str(hole_no))][i:i + 150].sum() for hole_no in range(1, 9 + 1, 2)])
                current_entropy = min_max(
                    [data["is_hole{}".format(str(hole_no))][i:i + section].sum() /
                     denominator for hole_no in [1, 3, 5, 7, 9]])
                ent.append(entropy(current_entropy, base=2))
            # region Description
            # data[data.event_type.isin(['reward', 'failure'])]["hole_choice_entropy"] = ent
            print("{} ; {} done".format(datetime.now(), sys._getframe().f_code.co_name))
            return pd.DataFrame(ent).fillna(0.0).values.tolist()
            # endregion

        def count_task() -> dict:
            dc = self.data[self.data["event_type"].isin(["reward", "failure"])]
            # dc = self.data[self.data["event_type"].isin(["reward", "failure", "time over"])]

            dc = dc.reset_index()

            after_c_all_task = {}
            after_f_all_task = {}

            after_c_starts_task = {}
            after_f_starts_task = {}

            prob_index = ["c_same", "c_diff", "c_omit", "c_checksum", "f_same", "f_diff", "f_omit", "f_checksum",
                          "c_NotMax",
                          "f_NotMax", "o_NotMax"]
            forward_trace = 10

            for task in self.tasks:
                after_c_starts_task[task] = dc[(dc["is_correct"] == 1) & (dc["task"] == task)]
                after_f_starts_task[task] = dc[(dc["is_failure"] == 1) & (dc["task"] == task)]
                after_c_all_task[task] = float(len(after_c_starts_task[task]))
                after_f_all_task[task] = float(len(after_f_starts_task[task]))

                prob = pd.DataFrame(columns=prob_index, index=range(1, forward_trace)).fillna(0.0)
                # correctスタート
                for idx, dt in after_c_starts_task[task].iterrows():
                    for j in range(1, min(forward_trace, len(dc) - idx)):
                        #                    for j in range(1, min(forward_trace, len(self.data_cio) - idx)):
                        # 報酬を得たときと同じ選択(CF両方)をしたときの処理
                        if dt["hole_no"] == dc["hole_no"][idx + j]:
                            prob["c_same"][j] = prob["c_same"][j] + 1
                        # omissionの場合
                        elif dc["is_omission"][idx + j] == 1:
                            prob["c_omit"][j] = prob["c_omit"][j] + 1
                        elif dt["hole_no"] != dc["hole_no"][idx + j]:
                            prob["c_diff"][j] = prob["c_diff"][j] + 1

                # incorrectスタート
                for idx, dt in after_f_starts_task[task].iterrows():
                    for j in range(1, min(forward_trace, len(dc) - idx)):
                        #                    for j in range(1, min(forward_trace, len(self.data_cio) - idx)):
                        if dt["hole_no"] == dc["hole_no"][idx + j]:
                            prob["f_same"][j] = prob["f_same"][j] + 1
                        elif dc["is_omission"][idx + j] == 1:
                            prob["f_omit"][j] = prob["f_omit"][j] + 1
                        elif dt["hole_no"] != dc["hole_no"][idx + j]:
                            prob["f_diff"][j] = prob["f_diff"][j] + 1

                # calculate
                prob["c_same"] = prob["c_same"] / after_c_all_task[task] if not after_c_all_task[task] == 0 else 0.0
                prob["c_diff"] = prob["c_diff"] / after_c_all_task[task] if not after_c_all_task[task] == 0 else 0.0
                prob["c_omit"] = prob["c_omit"] / after_c_all_task[task] if not after_c_all_task[task] == 0 else 0.0
                prob["c_checksum"] = prob["c_same"] + prob["c_diff"] + prob["c_omit"]
                prob["f_same"] = prob["f_same"] / after_f_all_task[task] if not after_f_all_task[task] == 0 else 0.0
                prob["f_diff"] = prob["f_diff"] / after_f_all_task[task] if not after_f_all_task[task] == 0 else 0.0
                prob["f_omit"] = prob["f_omit"] / after_f_all_task[task] if not after_f_all_task[task] == 0 else 0.0
                prob["f_checksum"] = prob["f_same"] + prob["f_diff"] + prob["f_omit"]

                task_prob[task] = prob
            print("{} ; {} done".format(datetime.now(), sys._getframe().f_code.co_name))

        # TODO 結構な確率でエラー吐く
        def analyze_pattern(bit=self.bit):
            fig_prob = {}
            pattern_range = range(0, pow(2, bit))
            for task in self.tasks:
                pattern[task] = {}
                fig_prob[task] = {"fig1": pd.DataFrame(columns=["{:b}".format(i).zfill(bit) for i in pattern_range]
                                                       ).fillna(0.0),
                                  "fig2": pd.DataFrame(columns=["{:b}".format(i).zfill(bit) for i in pattern_range]
                                                       ).fillna(0.0),
                                  "fig3": pd.DataFrame(columns=["{:b}".format(i).zfill(bit) for i in pattern_range],
                                                       ).fillna(0.0)}
                data = self.data[
                    (self.data.task == task) & (
                        self.data.event_type.isin(["reward", "failure", "time over"]))].reset_index(drop=True)
                data_ci = data[data.event_type.isin(["reward", "failure"])].reset_index(drop=True)
                # search pattern

                f_pattern_matching = lambda x: sum([
                    (not np.isnan(data_ci.at[x + (bit - i - 1), "is_correct"])) * pow(2, i)
                    for i in range(0, bit)])
                pattern[task] = data_ci[:-(bit - 1)].assign(pattern=data_ci[:-(bit - 1)].index.map(f_pattern_matching))
                # count

                f_same_base = lambda x: [data_ci.at[data_ci[data_ci.session_id == x].index[0], "hole_no"] == \
                                         data_ci.at[data_ci[data_ci.session_id == x].index[0] + idx, "hole_no"] for idx
                                         in range(1, bit)]
                f_same_prev = lambda x: [data_ci.at[data_ci[data_ci.session_id == x].index[0] + idx - 1, "hole_no"] == \
                                         data_ci.at[data_ci[data_ci.session_id == x].index[0] + idx, "hole_no"] for idx
                                         in range(1, bit)]
                f_omit = lambda x: [bool(data.at[data[data.session_id == x].index[0] + idx, "is_omission"]) for idx in
                                    range(1, bit)]
                functions = lambda x: [f_same_base(x), f_same_prev(x), f_omit(x)]
                # pattern count -> probability
                for pat_tmp in pattern_range:
                    f_p = pd.DataFrame(list(pattern[task][pattern[task].pattern == pat_tmp].session_id.map(functions)),
                                       columns=["fig1", "fig2", "fig3"]).fillna(0.0)
                    if len(f_p):
                        fig_prob[task]["fig1"]["{:b}".format(pat_tmp).zfill(bit)] = pd.DataFrame(
                            list(f_p.fig1)).sum().fillna(0.0) / len(pattern[task][pattern[task].pattern == pat_tmp])
                        fig_prob[task]["fig2"]["{:b}".format(pat_tmp).zfill(bit)] = pd.DataFrame(
                            list(f_p.fig2)).sum().fillna(0.0) / len(pattern[task][pattern[task].pattern == pat_tmp])
                        fig_prob[task]["fig3"]["{:b}".format(pat_tmp).zfill(bit)] = pd.DataFrame(
                            list(f_p.fig3)).sum().fillna(0.0) / len(pattern[task][pattern[task].pattern == pat_tmp])
                    else:
                        for figure in list(f_p.columns):
                            fig_prob[task][figure]["{:b}".format(pat_tmp).zfill(bit)] = fig_prob[task][figure][
                                "{:b}".format(pat_tmp).zfill(bit)].fillna(0.0)
                    for figure in list(f_p.columns):
                        fig_prob[task][figure].at["n", "{:b}".format(pat_tmp).zfill(bit)] = len(
                            pattern[task][pattern[task].pattern == pat_tmp])
            # save
            self.fig_prob_tmp = fig_prob
            print("{} ; {} done".format(datetime.now(), sys._getframe().f_code.co_name))
            return pattern

        def burst():
            def calc_burst(session):
                if session == 0:
                    self.burst_id = 0
                    return self.burst_id
                if data.at[data.index[data.session_id == session][0], "timestamps"] - \
                        data.at[data.index[data.session_id == session - 1][0], "timestamps"] >= timedelta(
                    seconds=60):
                    self.burst_id = self.burst_id + 1
                return self.burst_id

            data = self.data[self.data.event_type.isin(["reward", "failure", "time over"])]
            self.data = self.data.merge(
                pd.DataFrame({"session_id": self.data.session_id.unique(),
                              "burst": list(map(calc_burst, self.data.session_id.unique()))}),
                on="session_id", how="left")
            print("{} ; {} done".format(datetime.now(), sys._getframe().f_code.co_name))

        def entropy_analyzing(section=10, bit=self.bit):
            data = self.data[(self.data.event_type.isin(["reward", "failure"])) & (self.data.task.isin(self.tasks))]
            entropy_df = data[
                ["session_id", "task", "entropy_{}".format(section), "entropy_after_{}".format(section), "pattern"]]
            count_correct = lambda pat: np.nan if np.isnan(pat) else "{:b}".format(int(pat)).zfill(bit).count("1")
            entropy_df["correctnum_{}bit".format(bit)] = list(map(count_correct, entropy_df.pattern))
            # entropy_df["mouse_no"] = self.mouse_no
            print("{} ; {} done".format(datetime.now(), sys._getframe().f_code.co_name))
            return entropy_df

        # main
        header = ["timestamps", "task", "session_id", "correct_times", "event_type", "hole_no"]
        pattern = {}
        task_prob = {}
        self.data = rehash_session_id()
        self.data = add_hot_vector()
        self.data_ci = self.data
#         self.data.loc[
#             self.data.index[self.data.event_type.isin(['reward', 'failure'])], "hole_choice_entropy"] = calc_entropy()
        # ent_section = 10
        # self.data.loc[
        #     self.data.index[self.data.event_type.isin(['reward', 'failure'])], "entropy_10"] = calc_entropy(ent_section)
        # self.data.loc[
        #     self.data.index[self.data.event_type.isin(['reward', 'failure'])], "entropy_after_10"] = \
        #     self.data.loc[self.data.index[self.data.event_type.isin(
        #         ['reward', 'failure'])], "entropy_10"][(ent_section + self.bit - 1):].to_list() + \
        #     ([np.nan] * (ent_section + self.bit - 1))
#         ent_section3 = 150
#         self.data.loc[
#             self.data.index[self.data.event_type.isin(['reward', 'failure'])], "entropy_{}".format(
#                 ent_section3)] = calc_entropy(ent_section3)
#         ent_section2 = 300
#         self.data.loc[
#             self.data.index[self.data.event_type.isin(['reward', 'failure'])], "entropy_{}".format(
#                 ent_section2)] = calc_entropy(ent_section2)
#         ent_section = 50
#         self.data.loc[
#             self.data.index[self.data.event_type.isin(['reward', 'failure'])], "entropy_{}".format(
#                 ent_section)] = calc_entropy(ent_section)
#         self.data.loc[
#             self.data.index[self.data.event_type.isin(['reward', 'failure'])], "entropy_after_{}".format(
#                 ent_section)] = \
#             self.data.loc[self.data.index[self.data.event_type.isin(
#                 ['reward', 'failure'])], "entropy_{}".format(ent_section)][(ent_section + self.bit - 1):].tolist() + \
#             ([np.nan] * (ent_section + self.bit - 1))
#         self.delta = add_timedelta()
        self.data_not_omission = self.data[
            ~self.data.session_id.isin(self.data.session_id[self.data.event_type.isin(["time over"])])]

        # action Probability
        after_c_all = float(len(self.data[self.data["is_correct"] == 1]))
        after_f_all = float(len(self.data[self.data["is_failure"] == 1]))
        after_c_starts = self.data[self.data["is_correct"] == 1]
        after_f_starts = self.data[self.data["is_failure"] == 1]
        after_c_all_task = {}
        after_f_all_task = {}
        after_c_starts_task = {}
        after_f_starts_task = {}
        for task in self.tasks:
            after_c_starts_task[task] = self.data[(self.data["is_correct"] == 1) & (self.data["task"] == task)]
            after_f_starts_task[task] = self.data[(self.data["is_failure"] == 1) & (self.data["task"] == task)]
            after_c_all_task[task] = float(len(after_c_starts_task[task]))
            after_f_all_task[task] = float(len(after_f_starts_task[task]))

        # after_o_all = len(data[data["event_type"] == "time over"])
        forward_trace = 5
        prob_index = ["c_same", "c_diff", "c_omit", "c_checksum", "f_same", "f_diff", "f_omit", "f_checksum",
                      "c_NotMax",
                      "f_NotMax", "o_NotMax"]
        probability = pd.DataFrame(columns=prob_index, index=range(1, forward_trace + 1)).fillna(0.0)

        #        count_all()
        count_task()
        # bit analyze
#         pp = analyze_pattern(self.bit)
#         pp = pd.concat([pp[task].loc[:, pp[task].columns.isin(["session_id", "pattern"])] for task in self.tasks])
#         self.data = pd.merge(self.data, pp, how='left')
#         # 2 bit analyze
#         pp = analyze_pattern(2)
#         pp = pd.concat([pp[task].loc[:, pp[task].columns.isin(["session_id", "pattern"])] for task in self.tasks])
#         pp = pp.rename(columns={"pattern": "pattern_2bit"})
#         self.data = pd.merge(self.data, pp, how='left')
#         burst()
        # entropy analyzing
#         self.entropy_analyze = entropy_analyzing(section=50)
        # self.entropy_analyze.concat(entropy_analyzing(section=50))
        return self.data, probability, task_prob, self.delta, self.fig_prob_tmp, pattern, self.entropy_analyze

    def dev_read_data(self, mouse_no):
        task_prob = {}
        delta = {}
        fig_prob = {}
        pattern_prob = {}
        data = pd.read_csv(os.path.join(self.logpath, 'data/no{:03d}_{}_data.csv'.format(mouse_no, "all")))
        probability = pd.read_csv(os.path.join(self.logpath, 'data/no{:03d}_{}_prob.csv'.format(mouse_no, "all")))

        for task in self.tasks:
            delta[task] = pd.read_csv(os.path.join(self.logpath, 'data/no{:03d}_{}_time.csv'.format(mouse_no, task)))
            task_prob[task] = pd.read_csv(
                os.path.join(self.logpath, 'data/no{:03d}_{}_prob.csv'.format(mouse_no, task)))
            fig_prob[task] = {}
            for fig_num in ["fig1", "fig2", "fig3"]:
                fig_prob[task][fig_num] = pd.read_csv(
                    os.path.join(self.logpath, 'data/no{:03d}_{}_{}_prob_fig.csv'.format(mouse_no, task, fig_num)),
                    index_col=0)
            pattern_prob[task] = pd.read_csv(
                os.path.join(self.logpath, 'data/no{:03d}_{}_pattern.csv'.format(mouse_no, task)))
        return data, probability, task_prob, delta, fig_prob, pattern_prob

    def export_csv(self, mouse_no=None):
        self.mice_task.to_csv(os.path.join(self.logpath, 'data/{}_data.csv'.format("all")))
#         self.probability.to_csv(
#             os.path.join(self.logpath, 'data/{}_prob.csv'.format("all")))
        for task in self.tasks:
#             self.mice_delta[task].to_csv(os.path.join(self.logpath, 'data/{}_time.csv'.format(task)))
            # AttributeError: 'Series' object has no attribute 'type'
#             reward_latency_data = self.mice_delta[task][self.mice_delta[task].type == "reward_latency"]
#             reward_latency_data.to_csv(os.path.join(self.logpath, 'data/{}_rewardlatency.csv'.format(task)))
            self.task_prob[task].to_csv(os.path.join(self.logpath, 'data/{}_prob.csv'.format(task)))
#             self.pattern_prob[task].to_csv(os.path.join(self.logpath, 'data/{}_pattern.csv'.format(task)))
#             [self.fig_prob[task][fig_num].to_csv(
#                 os.path.join(self.logpath, 'data/prob_fig{}_{}.csv'.format(fig_num, task))) for
#                 fig_num in ["fig1", "fig2", "fig3"]]
            # pattern
            # [self.entropy_analyze[
            #      (self.entropy_analyze["correctnum_{}bit".format(10,self.bit)] == count) &
            #      (self.entropy_analyze["task"] == task)  # & (
            #      # self.entropy_analyze["mouse_no"] == mouse_no)
            #      ][10:-10].to_csv(
            #     '{}data/pattern_entropy/summary/no{:03d}_{}_entropy_pattern_count_{}_summary.csv'.format(
            #         self.logpath, mouse_no, task, int(count))) for count in
            #     self.entropy_analyze["correctnum_{}bit".format(10,self.bit)][
            #         ~np.isnan(self.entropy_analyze["correctnum_{}bit".format(10,self.bit)])].unique()]
            # [self.entropy_analyze[
            #      (self.entropy_analyze["pattern"] == pattern) &
            #      (self.entropy_analyze["task"] == task)  # & (
            #      # self.entropy_analyze["mouse_no"] == mouse_no)
            #      ][10:-10].to_csv(
            #     '{}data/pattern_entropy/no{:03d}_{}_entropy_pattern_{:04b}.csv'.format(
            #         self.logpath, mouse_no, task, int(pattern))) for
            #     pattern in self.data.pattern[~np.isnan(self.data.pattern)].unique()]
#             [self.mice_entropy[(self.mice_entropy["task"] == task)  # & (
#                  # self.entropy_analyze["mouse_no"] == mouse_no)
#              ][50:-50][(self.mice_entropy["correctnum_{}bit".format(self.bit)] == count)].to_csv(
#                 '{}data/pattern_entropy/summary/{}_entropy_pattern{:d}_count_{}_summary.csv'.format(
#                     self.logpath, task, 50, int(count))) for count in
#                 # self.entropy_analyze["correctnum_{}bit".format(self.bit)][
#                 # ~np.isnan(self.entropy_analyze["correctnum_{}bit".format(self.bit)])].unique()]
#                 range(0, self.bit)]
#             [self.mice_entropy[
#                  (self.mice_entropy["task"] == task)  # & (
#                  # self.entropy_analyze["mouse_no"] == mouse_no)
#              ][50:-50][(self.mice_entropy["pattern"] == pattern)].to_csv(
#                 '{}/data/pattern_entropy/{}_entropy{:d}_pattern_{:04b}.csv'.format(
#                     self.logpath, task, 50, int(pattern))) for
#                 pattern in self.data.pattern[~np.isnan(self.data.pattern)].unique()]

        print("{} ; {} done".format(datetime.now(), sys._getframe().f_code.co_name))


def export_onehole_csv(tdata, mice, tasks):
    df = dict(zip(tasks, [pd.DataFrame() for _ in tasks]))
    for task in tasks:
        for mouse_id in mice:
            data_timedelta = tdata.mice_delta[task][(tdata.mice_delta[task].mouse_id.isin([mouse_id]))]
            data_action = tdata.mice_task[
                (tdata.mice_task.task.isin([task])) & (tdata.mice_task.mouse_id.isin([mouse_id]))]
            tmp_df = data_action.merge(
                data_timedelta[(data_timedelta.type.isin(["reward_latency"]))][
                    ["reward_latency_sec", "session_id"]],
                on="session_id", how='left')
            tmp_df = tmp_df.merge(
                data_timedelta[(data_timedelta.type.isin(["reaction_time"]))].drop(
                    columns="reward_latency_sec"),
                on="session_id", how='left')
            tmp_df.loc[
                tmp_df.index[tmp_df.event_type.isin(["failure", "omission"])].to_list(), "reaction_time_sec"] = \
                data_timedelta.reaction_time_sec[
                    (data_timedelta.type.isin(["reaction_time"])) & (~data_timedelta.session_id.isin(
                        (data_timedelta.session_id[
                             data_timedelta.type.isin(["reward_latency"])].drop_duplicates().to_list())))].to_list()

            # nosepoke after reward

            def check_poke_after(correct_num):
                data = data_action[data_action.cumsum_correct_taskreset.isin([correct_num])]
                return bool(sum(data.event_type.isin(["nose poke after rew"])))

            poke_after = list(map(check_poke_after, data_action[
                data_action.event_type.isin(["reward", "failure", "omission"])].cumsum_correct_taskreset.unique()))
            tmp_df = tmp_df.assign(poke_after=np.nan)[tmp_df.event_type.isin(["reward", "failure", "time over"])]
            tmp_df.loc[tmp_df.index[tmp_df.event_type.isin(["reward"])], "poke_after"] = poke_after
            tmp_df = tmp_df[
                ["timestamps", "task", "event_type", "reaction_time_sec", "reward_latency_sec", "poke_after"]]
            df[task] = pd.concat([df[task], tmp_df])
        df[task].to_csv(os.path.join("data", "{}_task-{}_1holedata.csv".format("all", task)), index=False)


def view_averaged_prob_same_prev(tdata, mice, tasks):
    m = []
    t = []
    csame = []
    fsame = []

    for mouse_id in mice:
        for task in tasks:
            m += [mouse_id]
            t += [task]
            csame += [tdata.task_prob[task][tdata.task_prob[task].mouse_id == mouse_id]['c_same']]
            fsame += [tdata.task_prob[task][tdata.task_prob[task].mouse_id == mouse_id]['f_same']]

    after_prob_df = pd.DataFrame(
        data={'mouse_id': m, 'task': t, 'c_same': csame, 'f_same': fsame},
        columns=['mouse_id', 'task', 'c_same', 'f_same']
    )

    plt.style.use('default')
    fig = plt.figure(figsize=(8, 4), dpi=100)
    for task in tasks:
        plt.subplot(1, len(tasks), tasks.index(task) + 1)

        c_same = np.array(after_prob_df[after_prob_df['task'].isin([task])]['c_same'].to_list())
        c_same_avg = np.mean(c_same, axis=0)
        c_same_std = np.std(c_same, axis=0)
        c_same_var = np.var(c_same, axis=0)

        f_same = np.array(after_prob_df[after_prob_df['task'].isin([task])]['f_same'].to_list())
        f_same_avg = np.mean(f_same, axis=0)
        f_same_var = np.var(f_same, axis=0)

        xlen = len(c_same_avg)
        xax = np.array(range(1, xlen + 1))
        plt.plot(xax, c_same_avg, label="rewarded start")
        plt.errorbar(xax, c_same_avg, yerr=c_same_var, capsize=2, fmt='o', markersize=1, ecolor='black',
                     markeredgecolor="black", color='w', lolims=True)

        plt.plot(np.array(range(1, xlen + 1)), f_same_avg, label="no-rewarded start")
        plt.errorbar(xax, f_same_avg, yerr=f_same_var, capsize=2, fmt='o', markersize=1, ecolor='black',
                     markeredgecolor="black", color='w', uplims=True)

        # plt.ion()
        plt.xticks(np.arange(1, xlen + 1, 1))
        plt.xlim(0.5, xlen + 0.5)
        plt.ylim(0, 1.05)
        if tasks.index(task) == 0:
            plt.ylabel('P (same choice)')
            plt.legend()
        plt.xlabel('Trial')
        plt.title('{}'.format(task))
        
    # plt.savefig('fig/{}_prob_all4.png'.format(graph_ins.exportpath))
    plt.rcParams["font.size"] = 18
    plt.savefig('fig/prob_all4_{}.png'.format(tasks[0]))
    plt.show()


def view_summary(tdata, mice, tasks, x="session_id"):
    for mouse_id in mice:
        def plot(mdf, task="all"):
            labels = ["failure", "correct", "omission"]
            df = mdf[mdf["event_type"].isin(["reward", "failure", "time over"])]

            # past time
            df.timestamps = (df.timestamps - df.iat[0, 0]).apply(lambda time: time.total_seconds())

            # entropy
            fig, ax = plt.subplots(4, 1, sharex="all", figsize=(15, 8), dpi=100)
            fig.suptitle('no{:03} summary {}'.format(mouse_id, task), y=1.0)
            plt.subplots_adjust(hspace=0, bottom=0)

            ax[0].plot(df[df.event_type.isin(["reward", "failure"])][x],
                       df[df.event_type.isin(["reward", "failure"])]['hole_choice_entropy'])
            ax[0].set_ylabel('Entropy (bit)')
            ax[0].set_xlim(df[x].min(), df[x].max())
            if task == "all":
                collection = collections.BrokenBarHCollection.span_where(df[x].to_numpy(), ymin=-100, ymax=100,
                                                                         where=(df.task.isin(tasks[0::2])),
                                                                         facecolor='lightblue', alpha=0.3)
                ax[0].add_collection(collection)

            # scatter
            colors = ["red", "blue", "black"]
            size = dict(zip(labels, [25, 50, 25]))
            pos = dict(zip(labels, ["bottom", "full", "bottom"]))
            datasets = ([df[df["is_{}".format(flag)] == 1] for flag in labels])
            leg = []
            for dt, la, cl in zip(datasets, labels, colors):
                marker = markers.MarkerStyle("|", pos[la])
                leg.append(ax[1].scatter(dt[x], dt['is_hole1'] * 1, s=size[la], color=cl, marker=marker, label=la))
                ax[1].scatter(dt[x], dt['is_hole3'] * 2, s=size[la], color=cl, marker=marker)
                ax[1].scatter(dt[x], dt['is_hole5'] * 3, s=size[la], color=cl, marker=marker)
                ax[1].scatter(dt[x], dt['is_hole7'] * 4, s=size[la], color=cl, marker=marker)
                ax[1].scatter(dt[x], dt['is_hole9'] * 5, s=size[la], color=cl, marker=marker)
                ax[1].scatter(dt[x], dt['is_omission'] * 0, s=size[la], color=cl, marker=marker)
            ax[1].set_ylabel("Hole")
            ax[1].set_yticks([1, 2, 3, 4, 5])
            ax[1].legend()
            if task == "all":
                collection = collections.BrokenBarHCollection.span_where(df[x].to_numpy(), ymin=-2, ymax=6,
                                                                         where=(df.task.isin(tasks[0::2])),
                                                                         facecolor='lightblue', alpha=0.3)
                ax[1].add_collection(collection)

            # cumsum
            ax[2].plot(df[x], df['cumsum_correct_taskreset'])
            ax[2].plot(df[x], df['cumsum_incorrect_taskreset'])
            ax[2].plot(df[x], df['cumsum_omission_taskreset'])
            ax[2].set_ylabel('Cumulative')
            # ax[2].set_xlabel('Trial')
            ax[2].legend(["correct", "incorrect", "omission"])
            if task == "all":
                collection = collections.BrokenBarHCollection.span_where(df[x].to_numpy(), ymin=-20, ymax=1000,
                                                                         where=(df.task.isin(tasks[0::2])),
                                                                         facecolor='lightblue', alpha=0.3)
                ax[2].add_collection(collection)

            # 100 step move average
            # make dataframe
            df_o = df[df["event_type"].isin(["reward", "failure"])]
            # data = pd.DataFrame(columns=["is_hole{}".format(i) for i in range(1, 10, 2)])
            data_tmp = pd.DataFrame()
            add_average = lambda idx: [df_o[max(0, idx - 100):idx]["is_hole{}".format(i)].sum() /
                                       max(df_o[max(0, idx - 100):idx]["is_hole{}".format(i)].size, 1) for i in
                                       range(1, 10, 2)]
            data_tmp = data_tmp.append(list(map(add_average, list(range(0, len(df_o))))), ignore_index=True)
            # plot
            ax[3].plot(df_o[x], data_tmp)
            ax[3].set_ylabel("moving average action rate")
            # legend
            ax[3].legend(["hole{}".format(i) for i in range(1, 10, 2)])
            if task == "all":
                collection = collections.BrokenBarHCollection.span_where(df_o[x].to_numpy(), ymin=-20,
                                                                         ymax=1000,
                                                                         where=(df_o.task.isin(tasks[0::2])),
                                                                         facecolor='lightblue', alpha=0.3)
                ax[3].add_collection(collection)
            # savefig
            fig.savefig('fig/no{:03d}_{}_summary_{}.png'.format(mouse_id, task, x))
            fig.show()

        data = tdata.mice_task[tdata.mice_task.mouse_id == mouse_id]
        plot(data)
        list(map(plot, [data[data.task == task] for task in tdata.tasks], tdata.tasks))


def view_trial_per_datetime(tdata, mice=[18], task="All5_30"):
    """ for debug """
    # for mouse_no in mice:
    data = tdata.data[
        (tdata.data.event_type.isin(["reward", "failure", "time over"]))
        & (tdata.data.task == task)
        # &(tdata.data.mouse_id == mouse_no)
        ].set_index("timestamps").resample("1H").sum()

    fig = plt.figure(figsize=(15, 8), dpi=100)
    data.plot.bar(y=["is_correct", "is_incorrect", "is_omission"], stacked=True)
    plt.show()


def view_scatter_vs_times_with_burst(tdata, mice=[18], task="All5_30", burst=1):
    """ fig1 B """
    for mouse_id in mice:

        labels = ["correct", "incorrect", "omission"]

        data = tdata.data.assign(
            timestamps=(tdata.data.timestamps - tdata.data.timestamps[0]).dt.total_seconds())  # [mouse_id]
        data = data[data["event_type"].isin(["reward", "failure", "time over"])]
        # data = data[data.burst.isin(data.burst.unique()[data.groupby("burst").burst.count() > burst])]
        burst_time = list(data.burst.unique()[data.groupby("burst").burst.count() > burst])
        fig = plt.figure(figsize=(15, 8), dpi=100)
        fig_subplot = fig.add_subplot(1, 1, 1)
        # plt.title('{:03} summary'.format(mouse_id))
        #    nose_poke_raster(mouse_id, fig.add_subplot(3, 1, 2))

        colors = ["blue", "red", "black"]
        for single_burst in burst_time:
            d = data[data.burst == single_burst]
            datasets = [(d[d["is_{}".format(flag)] == 1]) for flag in labels]
            for dt, la, cl in zip(datasets, labels, colors):
                plt.scatter(dt.timestamps, dt['is_hole1'] * 1, s=15, c=cl)
                plt.scatter(dt.timestamps, dt['is_hole3'] * 2, s=15, c=cl)
                plt.scatter(dt.timestamps, dt['is_hole5'] * 3, s=15, c=cl)
                plt.scatter(dt.timestamps, dt['is_hole7'] * 4, s=15, c=cl)
                plt.scatter(dt.timestamps, dt['is_hole9'] * 5, s=15, c=cl)
                plt.scatter(dt.timestamps, dt['is_omission'] * 0, s=15, c=cl)
            plt.ylabel("Hole")
            plt.xlim(d.timestamps.min() - 30, d.timestamps.max() + 30)
            plt.ylim(0, 5)
            #    plt.xlim(0, len(mdf))

            collection = collections.BrokenBarHCollection.span_where(data.timestamps.to_numpy(), ymin=0, ymax=5,
                                                                     where=(data.burst.isin(burst_time)),
                                                                     facecolor='pink', alpha=0.3)
            fig_subplot.add_collection(collection)
            # save
            # plt.show()
            burst_len = d.timestamps.count()
            if not os.path.isdir(os.path.join(os.getcwd(), "fig", "burst", "len" + str(burst_len))):
                os.mkdir(os.path.join(os.getcwd(), "fig", "burst", "len" + str(burst_len)))
            plt.savefig(os.path.join(os.getcwd(), 'fig', 'burst', "len" + str(burst_len),
                                     'no{:03d}_burst{}_hole_pasttime_burst.png'.format(mouse_id, single_burst)))


def view_trial_per_time(tdata, mice=[18], task="All5_30"):
    """ fig1 C """
    data = tdata.data[
        (tdata.data.event_type.isin(["reward", "failure", "time over"])) &
        (tdata.data.task == task)
        ].set_index("timestamps").resample("1H").sum()
    data = data.set_index(data.index.time).groupby(level=0).mean()
    fig = plt.figure(figsize=(15, 8), dpi=100)
    data.plot.bar(y=["is_correct", "is_incorrect", "is_omission"], stacked=True)
    plt.show()


def view_prob_same_choice_burst(tdata, mice, tasks, burst=1):
    """ fig4 """
    tdata_ci = tdata[tdata.event_type.isin(["reward", "failure"])]
    tdata_ci = tdata_ci[
        tdata_ci.burst.isin(tdata_ci.burst.unique()[tdata_ci.groupby("burst").burst.count() > burst])].reset_index()

    # burst_len limit なし
    # after_prob_df = pd.concat([tdata.task_prob[task].assign(task=task) for task in tasks])

    plt.style.use('default')
    fig, ax = plt.subplots(1, len(tasks), sharey="all", sharex="all", figsize=(8, 4), dpi=100)
    forward_trace = 7

    def calc(mouse_id):
        prob_index = ["c_same", "f_same", "task", "mouse_id"]
        after_prob_df = pd.DataFrame(columns=prob_index)
        lgnd = None
        for task in tasks:
            data_tmp = tdata_ci[(tdata_ci.task.isin([task])) & (tdata_ci.mouse_id == mouse_id)]  # .groupby("burst")
            "burst ごと確率を出す"
            for bst in data_tmp.burst.unique():
                data = data_tmp[data_tmp.burst.isin([bst])].reset_index(drop=True)
                after_correct_all = data.burst[:-forward_trace][data.is_correct == 1].count()
                after_incorrect_all = data.burst[:-forward_trace][data.is_incorrect == 1].count()
                correct_index = data[:-forward_trace][data.is_correct == 1].index
                incorrect_index = data[:-forward_trace][data.is_incorrect == 1].index
                df = pd.DataFrame(columns=range(forward_trace))
                same_correct = \
                    df.append([data[idx:idx + min(forward_trace, len(data))].hole_no == data.hole_no[idx] for idx in
                               correct_index]).sum() if len(correct_index) else df.sum()
                same_incorrect = \
                    df.append([data[idx:idx + min(forward_trace, len(data))].hole_no == data.hole_no[idx] for idx in
                               incorrect_index]).sum() if len(incorrect_index) else df.sum()
                after_prob_df = after_prob_df.append(pd.DataFrame({"c_same": same_correct / after_correct_all,
                                                                   "f_same": same_incorrect / after_incorrect_all,
                                                                   "task": task, "mouse_id": mouse_id,
                                                                   "burst": bst}).fillna(0.0))

                # after_prob_df = after_prob_df.append(pd.DataFrame({"c_same": (same_correct / after_correct_all).mean(),
                #                                                    "f_same": (same_incorrect / after_incorrect_all).mean(),
                #                                                    "task": task, "mouse_id": mouse_id,
                #                                                    "burst": bst}).fillna(0.0), ignore_index=True)
            """ burstごと確率 を平均する"""
            c_same = after_prob_df[
                (after_prob_df['task'].isin([task])) &
                (after_prob_df["mouse_id"] == mouse_id)
                ]['c_same'].groupby(level=0)
            c_same_avg = c_same.mean()[:forward_trace + 1]
            c_same_var = c_same.var()[:forward_trace + 1]

            f_same = after_prob_df[
                (after_prob_df['task'].isin([task])) &
                (after_prob_df["mouse_id"] == mouse_id)
                ]['f_same'].groupby(level=0)
            f_same_avg = f_same.mean()[:forward_trace + 1]
            f_same_var = f_same.var()[:forward_trace + 1]

            # ここから描画
            xlen = c_same_avg.size
            # xax = np.array(range(1, xlen + 1))
            xax = np.array(range(forward_trace + 1))
            ax.plot(xax, c_same_avg, color="orange", label="rewarded start")
            ax.errorbar(xax, c_same_avg, yerr=c_same_var, capsize=2, fmt='o', markersize=1,
                        ecolor='black',
                        markeredgecolor="black", color='w', lolims=True)

            ax.plot(xax, f_same_avg, color="blue", label="no-rewarded start")
            ax.errorbar(xax, f_same_avg, yerr=f_same_var, capsize=2, fmt='o', markersize=1,
                        ecolor='black',
                        markeredgecolor="black", color='w', uplims=True)

            # plt.ion()
            ax.set_xticks(xax)
            ax.set_xlim(-0.5, xlen + 0.5)
            ax.set_ylim(0, 1.05)
            if tasks.index(task) == 0:
                ax.set_ylabel('P (same choice)')
            if tasks.index(task) == int(len(tasks) / 2):
                lgnd = ax.legend(loc="upper center", bbox_to_anchor=(0.5, -0.05), ncol=2,
                                 mode="expand")
            if tasks.index(task) in [0, len(tasks) - 1]:
                ax.set_xlabel('Trial')
            ax.set_title('{}'.format(task))
        # label
        # plt.legend()
        lgnd.get_frame().set_linewidth(0.0)
        plt.savefig("no{:03d}_prob4.png".format(mouse_id))
        plt.show()

    list(map(calc, mice))


def view_sigletask_prob(tdata, mice, task):
    """ fig5 A """
    tdata_ci = tdata.mice_task[tdata.mice_task.event_type.isin(["reward", "failure"])]
    tdata_ci = tdata_ci[tdata_ci.task.isin([task])].reset_index(drop=True)

    plt.style.use('default')
    fig, ax = plt.subplots(1, 2, sharey="all", sharex="all", figsize=(8, 4), dpi=100)
    forward_trace = 7
    range_lim = 50

    def calc(mouse_id):
        prob_index = ["c_same", "f_same", "task", "mouse_id"]
        data_tmp = tdata_ci[(tdata_ci.mouse_id == mouse_id)].reset_index(drop=True)

        for data, index in zip([data_tmp[:range_lim], data_tmp[-range_lim:]], [0, 1]):
            "確率を出す"
            after_prob_df = pd.DataFrame(columns=prob_index)
            data = data.reset_index(drop=True)
            after_correct_all = data.is_correct[data.is_correct == 1].count()
            after_incorrect_all = data.is_incorrect[data.is_incorrect == 1].count()
            correct_index = data[data.is_correct == 1].index
            incorrect_index = data[data.is_incorrect == 1].index
            df = pd.DataFrame(columns=range(forward_trace))
            same_correct = df.append(
                [(data[idx + 1:idx + min(forward_trace + 1, len(data[idx + 1:]))].hole_no == data.hole_no[
                    idx]).reset_index(drop=True).T for idx in correct_index]) * 1 if len(correct_index) else df
            same_incorrect = df.append(
                [(data[idx + 1:idx + min(forward_trace + 1, len(data[idx + 1:]))].hole_no == data.hole_no[
                    idx]).reset_index(drop=True).T for idx in incorrect_index]) * 1 if len(incorrect_index) else df
            same_correct.columns = same_correct.columns + 1
            same_incorrect.columns = same_incorrect.columns + 1
            after_prob_df = after_prob_df.append(pd.DataFrame({"c_same": same_correct.sum() / after_correct_all,
                                                               "f_same": same_incorrect.sum() / after_incorrect_all,
                                                               "task": task, "mouse_id": mouse_id}).fillna(0.0))

            c_same = after_prob_df[
                (after_prob_df['task'].isin([task])) &
                (after_prob_df["mouse_id"] == mouse_id)
                ]['c_same']

            f_same = after_prob_df[
                (after_prob_df['task'].isin([task])) &
                (after_prob_df["mouse_id"] == mouse_id)
                ]['f_same']

            # ここから描画
            xlen = c_same.size
            xax = np.array(range(1, forward_trace + 1))
            ax[index].plot(c_same, color="orange", label="rewarded start")
            ax[index].plot(f_same, color="skyblue", label="no-rewarded start")
            ax[index].set_xticks(xax)
            ax[index].set_xlim(0.5, xlen + 0.5)
            ax[index].set_ylim(0, 1.05)
            ax[index].set_xlabel('Trial')
        # label
        ax[0].set_ylabel('P (same choice)')
        ax[0].set_title('no{:03d}_{}_first{}step'.format(mouse_id, task, range_lim))
        ax[1].set_title('no{:03d}_{}_last{}step'.format(mouse_id, task, range_lim))
        plt.subplots_adjust(top=0.8)
        plt.legend()
        plt.savefig("no{:03d}_prob5_{}.png".format(mouse_id, task))
        plt.show()
        plt.close()

    list(map(calc, mice))


def view_pattern_entropy_summary(tdata, mice, task=None):
    data = tdata.mice_entropy
    average_all = None
    for mouse_id in mice:
        data_tmp = data[mouse_id].groupby(
            ["task", "correctnum_{}bit".format(tdata.bit)])
        mean = data_tmp.mean().reset_index()
        sd = data_tmp.std().reset_index()
        data_tmp = pd.merge(mean, sd, on=["task", "correctnum_{}bit".format(tdata.bit)], suffixes=["_mean", "_sd"])
        data_tmp = data_tmp.loc[:, data_tmp.columns.str.startswith(("task", "correctnum", "entropy"))].assign(
            mouse_id=mouse_id)
        average_all = data_tmp if isinstance(average_all, type(None)) else average_all.append(data_tmp)
    for group_info, data_tmp in average_all.groupby(["task", "correctnum_{}bit".format(tdata.bit)]):
        fig, ax = plt.subplots(1, 1)
        # error bar
        # ax.errorbar(data_tmp.loc[:, data_tmp.columns.str.endswith("mean")].columns.to_numpy().reshape(2),
        #             data_tmp.loc[:, data_tmp.columns.str.endswith("mean")].to_numpy(),
        #             yerr=data_tmp.loc[:, data_tmp.columns.str.endswith("sd")].to_numpy(),
        #             ecolor="black")
        # ax.errorbar(data_tmp.loc[:, data_tmp.columns.str.endswith("mean")].columns,
        #             data_tmp.loc[:, data_tmp.columns.str.endswith("mean")].to_numpy().reshape(2, )[1],
        #             yerr=data_tmp.loc[:, data_tmp.columns.str.endswith("sd")].to_numpy().reshape(2, )[1],
        #             ecolor="black")
        ax.errorbar(data_tmp.loc[:, data_tmp.columns.str.endswith("mean")].columns.to_numpy(),
                    data_tmp.groupby(["task", "correctnum_{}bit".format(tdata.bit)]).mean().loc[:,
                    data_tmp.groupby(["task", "correctnum_{}bit".format(tdata.bit)]).mean().columns.str.endswith(
                        "mean")].to_numpy().T,
                    yerr=np.mean(data_tmp.loc[:, data_tmp.columns.str.endswith("sd")].to_numpy(), axis=0),
                    ecolor="blue")
        # ax.errorbar(data_tmp.loc[:, data_tmp.columns.str.endswith("mean")].columns[1],
        #             data_tmp.groupby(["task", "correctnum_{}bit".format(tdata.bit)]).mean().loc[:,
        #             data_tmp.groupby(["task", "correctnum_{}bit".format(tdata.bit)]).mean().columns.str.endswith(
        #                 "mean")].to_numpy().reshape(2, )[1],
        #             yerr=data_tmp.loc[:, data_tmp.columns.str.endswith("sd")].to_numpy().reshape(2, )[1],
        #             ecolor="black")
        # mean
        ax.plot(data_tmp.loc[:, data_tmp.columns.str.endswith("mean")].columns,
                data_tmp.loc[:, data_tmp.columns.str.endswith("mean")].to_numpy().T,
                marker="o", color="black")
        # all average
        ax.plot(data_tmp.loc[:, data_tmp.columns.str.endswith("mean")].columns,
                data_tmp.groupby(["task", "correctnum_{}bit".format(tdata.bit)]).mean().loc[:,
                data_tmp.groupby(["task", "correctnum_{}bit".format(tdata.bit)]).mean().columns.str.endswith(
                    "mean")].to_numpy().T,
                marker="x", color="blue")
        # plt.show(block=True)
        plt.savefig(os.path.join(os.getcwd(), 'fig', 'pattern_ent',
                                 'pattern_ent_average_{}_correct{}.png'.format(group_info[0], group_info[1])))


def export_2bit_analyze(tdata, mice, tasks, bit=2, burst_len=10):
    """ task ごとにパターンの確率を算出して csv 出力 """
    pattern_range = range(pow(bit, 2))
    tmp = pd.DataFrame(columns=["{:02b}".format(i) for i in pattern_range]).fillna(0.0)
    prob_all = dict(zip(tasks, [tmp.copy() for _ in range(len(tasks))]))
    count_all = dict(zip(tasks,
                         [pd.DataFrame(columns=["{:02b}".format(pattern) for pattern in pattern_range]).copy() for _ in
                          range(len(tasks))]))
    # row_data = dict(
    #     zip(tasks, [pd.DataFrame(columns=["{:02b}".format(pattern) for pattern in pattern_range]) for _ in tasks]))
    for mouse_no in mice:
        data = tdata.mice_task[tdata.mice_task.mouse_id == mouse_no]
        data_ci = data[data.event_type.isin(["reward", "failure"])].reset_index(drop=True)
        f_same_prev = lambda x: data_ci.at[data_ci[data_ci.session_id == x].index[0] + 1, "hole_no"] == \
                                data_ci.at[data_ci[data_ci.session_id == x].index[0], "hole_no"]
        functions = lambda x: f_same_prev(x)
        data_bursts = data_ci[data_ci.burst.isin(
            data_ci.burst.unique()[data_ci.groupby("burst").burst.count() > burst_len])]
        bit_prob = dict(zip(tasks, [dict(zip(["f_same_prev"], [tmp.copy()])) for _ in range(len(tasks))]))
        for task in tasks:
            data_tmp = data_bursts[(data_ci.task.isin([task]))]
            tmp_df = []
            tmp_count = []
            for pat_tmp in pattern_range:
                # pattern count -> probability
                f_p = pd.DataFrame(list(data_tmp[data_tmp.pattern_2bit == pat_tmp].session_id[1:].map(functions)),
                                   columns=["f_same_prev"]).fillna(0.0)
                # row_data[task]["{:02b}".format(pat_tmp)] = row_data[task]["{:02b}".format(pat_tmp)].append((f_p * 1))
                # f_p.count().to_csv(os.path.join("data", "2bit",
                #                                 "no{:03d}_{}_pat{:02b}_burst_2bit.csv".format(mouse_no, task, pat_tmp)))
                tmp_count.append(len(data_tmp[data_tmp.pattern_2bit == pat_tmp]))
                if len(f_p):
                    """ 一例以上あった場合確率として計算 """
                    # Series
                    tmp_df.append((pd.DataFrame(list(f_p.f_same_prev)).sum().fillna(0.0) /
                                   len(data_tmp[data_tmp.pattern_2bit == pat_tmp])).values[0])

                else:
                    """ 一回もパターンが出ていない場合 """
                    tmp_df.append(np.nan)
            bit_prob[task]["f_same_prev"] = bit_prob[task]["f_same_prev"].append(
                pd.Series(tmp_df, index=bit_prob[task]["f_same_prev"].columns), ignore_index=True)
            prob_all[task] = prob_all[task].append(pd.Series(tmp_df, index=bit_prob[task]["f_same_prev"].columns),
                                                   ignore_index=True)
            count_all[task].loc["no{}".format(mouse_no)] = tmp_count
            # export
            bit_prob[task]["f_same_prev"].to_csv(
                os.path.join("data", "2bit_prob_task-{}_no{}.csv".format(task, mouse_no)), index=False, header=False)
            # graph
            fig = bit_prob[task]["f_same_prev"].T.plot.line(title="2bit no{:03d} task:{}".format(mouse_no, task),
                                                            style="bo-", ylim=(0.0, 1.0), ms=10)
            plt.savefig(os.path.join("fig", "2bit", "no{:03d}_{}_2bit.png".format(mouse_no, task)))
            # plt.show()
            plt.close()
    for task in tasks:
        fig = prob_all[task].mean().T.plot.line(title="2bit {} task:{}".format("all", task),
                                                style="ro-", ylim=(0.0, 1.0), ms=10)
        plt.savefig(os.path.join("fig", "2bit", "{}_{}_2bit.png".format("all", task)))
        plt.show()
        plt.close()
        prob_all[task].to_csv(os.path.join("data", "2bit", "allmice_{}_burst_2bit_prob.csv".format(task)), index=False)
        count_all[task].to_csv(os.path.join("data", "2bit", "n_{}_burst_2bit_count.csv".format(task)))
        count_all[task].count()
        # [row_data[task]["{:02b}".format(pat)].to_csv(
        #     os.path.join("data", "2bit", "{}_{}_pat{:02b}_burst_2bit.csv".format("allmice", task, pat))) for pat in
        #     pattern_range]
        # return bit_prob


def export_all_entropy(tdata, mice, tasks=["All5_90", "All5_30", "All5_30_drug"]):
    target_task = tasks
    ret_val = pd.DataFrame()
    datas = tdata if isinstance(tdata, list) else [tdata]

    def min_max(x, axis=None):
        np.array(x)
        min = np.array(x).min(axis=axis)
        max = np.array(x).max(axis=axis)
        result = (x - min) / (max - min)
        return result

    for d in datas:
        for mouse_id in mice:
            tmp = [mouse_id]
            for task in target_task:
                data = d.mice_task[
                    (d.mice_task.event_type.isin(["reward", "failure"]))
                    & (d.mice_task.task.isin([task]))
                    & (d.mice_task.mouse_id.isin([str(mouse_id)]))]
                if not len(data):
                    continue
                current_entropy = min_max([data["is_hole{}".format(str(hole_no))].sum() /
                                           len(data) for hole_no in [1, 3, 5, 7, 9]])
                tmp += [task, entropy(current_entropy, base=2)]
            if len(tmp) == 1:
                continue
            ret_val = ret_val.append([tmp])

    ret_val.to_csv(os.path.join("data", "entropy", "entropy_tasks_{}.csv".format("_".join(target_task))), index=False,
                   header=False)
    # [ret_val[ret_val.task.isin([task])].to_csv(os.path.join("data", "entropy", "entropy_task_{}.csv".format(task)),
    #                                            index=False, header=False) for task in target_task]


def view_converse_reaction_time(tdata, mice, tasks):
    tasks = tasks if isinstance(tasks, list) else [tasks]
    tdata = tdata if isinstance(tdata, list) else [tdata]
    for task in tasks:
        for data in tdata:
            data = data.mice_delta[task][
                (data.mice_delta[task].type.isin(["reaction_time"]))  # &
                # (tdata.mice_delta.mouse_id == mice)
            ].reaction_time_sec
            fig = plt.figure(figsize=(15, 8), dpi=100)
            data.plot.hist(bins=100)
            plt.title("reaction time task:{}".format(task))
            plt.xlabel("reaction time(s)")
            plt.rcParams["font.size"] = 18
            plt.savefig(os.path.join("fig", "task-{}_reaction_time.png".format(task)))
            plt.show()


def view_converse_reward_latency(tdata, mice, tasks, bin=100):
    tasks = tasks if isinstance(tasks, list) else [tasks]
    tdata = tdata if isinstance(tdata, list) else [tdata]
    for task in tasks:
        for data in tdata:
            data = data.mice_delta[task][
                (data.mice_delta[task].type.isin(["reward_latency"]))  # &
                # (tdata.mice_delta.mouse_id == mice)
            ].noreward_duration_sec
            fig = plt.figure(figsize=(15, 8), dpi=100)
            data.plot.hist(bins=bin)
            plt.xlabel("reward latency(s)")
            plt.title("reward latency task:{}".format(task))
            plt.rcParams["font.size"] = 18
            plt.savefig(os.path.join("fig", "task-{}_reward_latency.png".format(task)))
            plt.show()
            plt.close()
            # under 100
            fig = plt.figure(figsize=(15, 8), dpi=100)
            data[data <= 300].plot.hist(bins=bin)
            plt.xlabel("reward latency(s)")
            plt.title("reward latency task:{}".format(task))
            plt.rcParams["font.size"] = 18
            plt.savefig(os.path.join("fig", "task-{}_reward_latency_u300.png".format(task)))
            plt.show()
            plt.close()
            # under 30000
            fig = plt.figure(figsize=(15, 8), dpi=100)
            data[data <= 10000].plot.hist(bins=bin)
            plt.xlabel("reward latency(s)")
            plt.title("reward latency task:{}".format(task))
            plt.rcParams["font.size"] = 18
            plt.savefig(os.path.join("fig", "task-{}_reward_latency_u10000.png".format(task)))
            plt.show()
            plt.close()


def view_50step_entropy(tdata, mice, tasks):
    data = tdata
    if isinstance(tdata, list):
        data = pd.concat([d.mice_task for d in tdata])
    else:
        data = tdata.mice_task
    # entropy
    data = data[
        (data.task.isin(tasks)) &
        (data.event_type.isin(["reward"]))]
    for task in tasks:
        df = data[(data.task == task)].groupby(["cumsum_correct_taskreset"]).mean().head(150)
        fig, ax = plt.subplots(1, 1, figsize=(15, 8), dpi=100)
        fig.suptitle('50step entropy task:{}'.format(task))
        ax.plot(df.index, df['entropy_50'])
        ax.set_ylabel('Entropy (bit)')
        plt.rcParams["font.size"] = 18
        plt.savefig(os.path.join("fig", "task-{}_entropy_upto150.png".format(task)))
        plt.show()
        plt.close()
    for task in tasks:
        for mouse_id in mice:
            df = data[(data.task == task) & (data.mouse_id == mouse_id)].set_index("cumsum_correct_taskreset")
            fig, ax = plt.subplots(1, 1, figsize=(15, 8), dpi=100)
            fig.suptitle('50step entropy task:{}'.format(task))
            ax.plot(df.index, df['entropy_50'])
            ax.set_ylabel('Entropy (bit)')
            plt.savefig(os.path.join("fig", "no{}_task-{}_entropy.png".format(mouse_id, task)))
            plt.show()
            plt.close()


def export_previeous_entropy(tdata, mice, tasks):
    # entropy
    data = tdata.mice_task[
        (tdata.mice_task.task.isin(tasks)) &
        (tdata.mice_task.event_type.isin(["reward", "failure"]))]
    ret_val = dict(zip(tasks, [pd.DataFrame() for _ in tasks]))
    for task in tasks:
        for mouse_id in mice:
            df = data[(data.task == task) & (data.mouse_id == mouse_id)].tail(100).head(50).reset_index()
            ret_val[task] = ret_val[task].append(
                df["entropy_50"].to_frame().assign(mouse_id=mouse_id))
            data[(data.task == task) & (data.mouse_id == mouse_id)].head(100).reset_index().tail(50)[
                "entropy_50"].to_frame().to_csv(
                os.path.join("data", "pre_entropy_no{}_task_{}.csv".format(mouse_id, task)))
            data[(data.task == task) & (data.mouse_id == mouse_id)].reset_index().tail(50)[
                "entropy_50"].to_frame().to_csv(
                os.path.join("data", "post_entropy_no{}_task_{}.csv".format(mouse_id, task)))
        ret_val[task].to_csv(os.path.join("data", "allmice_{}_previous_100step_entropy.csv".format(task)))
        ret_val[task].groupby(level=0).mean().entropy_50.to_csv(
            os.path.join("data", "mean_{}_previous_100step_entropy.csv".format(task)), index=False)


def export_prepost_entropy(tdata, mice, tasks):
    # entropy
    cidata = tdata.mice_task[
        (tdata.mice_task.task.isin(tasks)) &
        (tdata.mice_task.event_type.isin(["reward","failure"]))]
    for task in tasks:
        with open('data/prepost_entropy_task_{}.csv'.format(task), 'w', newline="") as f:
            writer = csv.writer(f)
            for mouse_id in mice:
                sz = cidata[(cidata.task == task) & (cidata.mouse_id == mouse_id)].reset_index()["entropy_50"].size
                print("size = {}".format(sz))
                if sz >= 300:
                    pre  = cidata[(cidata.task == task) & (cidata.mouse_id == mouse_id)].reset_index()["entropy_50"].iat[50]
                    post = cidata[(cidata.task == task) & (cidata.mouse_id == mouse_id)].reset_index()["entropy_50"].iat[300]
                    writer.writerow([mouse_id, pre, post])
                else:
                    print("error")


def view_averaged_prob_same_prev(tdata, mice, tasks):
    m = []
    t = []
    csame = []
    fsame = []

    for mouse_id in mice:
        for task in tasks:
            m += [mouse_id]
            t += [task]
            csame += [tdata.task_prob[task][tdata.task_prob[task].mouse_id == mouse_id]['c_same']]
            fsame += [tdata.task_prob[task][tdata.task_prob[task].mouse_id == mouse_id]['f_same']]

    after_prob_df = pd.DataFrame(
        data={'mouse_id': m, 'task': t, 'c_same': csame, 'f_same': fsame},
        columns=['mouse_id', 'task', 'c_same', 'f_same']
    )

    plt.style.use('default')
    fig = plt.figure(figsize=(8, 4), dpi=100)
    for task in tasks:
        plt.subplot(1, len(tasks), tasks.index(task) + 1)

        c_same = np.array(after_prob_df[after_prob_df['task'].isin([task])]['c_same'].to_list())
        c_same_avg = np.mean(c_same, axis=0)
        c_same_std = np.std(c_same, axis=0)
        c_same_var = np.var(c_same, axis=0)

        f_same = np.array(after_prob_df[after_prob_df['task'].isin([task])]['f_same'].to_list())
        f_same_avg = np.mean(f_same, axis=0)
        f_same_var = np.var(f_same, axis=0)

        # 各個体ごと平均確率からの差
        c_same_indiv_avg = c_same.groupby("mouse_id").mean()
        f_same_indiv_avg = f_same.groupby("mouse_id").mean()
        
        task_prob_df = after_prob_df[after_prob_df['task'].isin([task])]
        c_same_norm = ((task_prob_df['c_same']) - (pd.merge(task_prob_df['mouse_id'],c_same_indiv_avg,how='left',on="mouse_id")["c_same"])).to_numpy()
        f_same_norm =((task_prob_df['f_same']) - (pd.merge(task_prob_df['mouse_id'],c_same_indiv_avg,how='left',on="mouse_id")["f_same"])).to_numpy()
        after_prob_df = after_prob_df.assign(c_same_norm=c_same_norm)
        after_prob_df = after_prob_df.assign(f_same_norm=f_same_norm)
        
        xlen = len(c_same_avg)
        xax = np.array(range(1, xlen + 1))
        plt.plot(xax, c_same_avg, label="rewarded start")
        plt.errorbar(xax, c_same_avg, yerr=c_same_norm, capsize=2, fmt='o', markersize=1, ecolor='black',
                     markeredgecolor="black", color='w', lolims=True)

        plt.plot(np.array(range(1, xlen + 1)), f_same_avg, label="no-rewarded start")
        plt.errorbar(xax, f_same_avg, yerr=f_same_norm, capsize=2, fmt='o', markersize=1, ecolor='black',
                     markeredgecolor="black", color='w', uplims=True)

        # plt.ion()
        plt.xticks(np.arange(1, xlen + 1, 1))
        plt.xlim(0.5, xlen + 0.5)
        plt.ylim(0.25, 0.6)
        if tasks.index(task) == 0:
            plt.ylabel('P (same choice)')
            plt.legend()
        plt.xlabel('Trial')
        plt.title('{}'.format(task))
    # plt.savefig('fig/{}_prob_all4.png'.format(graph_ins.exportpath))
    plt.rcParams["font.size"] = 18
    plt.savefig('fig/prob_all4_{}.png'.format(tasks[0]))
    plt.show()
    
#    for mouse_id in mice:
#        for task in tasks:
#            after_prob_df[after_prob_df['mouse_id'].isin([mouse_id])][after_prob_df['task'].isin([task])][c_same].to_csv('data/{0}_{1}_correct.csv'.format(mouse_id, task))
#           after_prob_df[after_prob_df['mouse_id'].isin([mouse_id])][after_prob_df['task'].isin([task])][f_same].to_csv('data/{0}_{1}_correct.csv'.format(mouse_id, task))

#     after_prob_df[after_prob_df['task'].isin([task])].to_csv('data/{0}_ciprob.csv'.format(task))


In [None]:
mice90 = [35,36,38,39,40,41,42,43]
tasks90 = ["All5_90"]

logpath = '~/PycharmProjects/RaspSkinnerBox/MiceAnalysis'
tdata90 = task_data(mice90, tasks90, logpath)

#view_averaged_prob_same_prev(tdata90, mice90, tasks90)
#view_summary(tdata, mice, tasks)

# 45(体重無し?) ファイルが無い・要探索!
# 27(体重無し)
# 33: 317まで
mice50 = [27,30,31,33,47,49,50]
tasks50 = ["All5_50"]
tdata50 = task_data(mice50, tasks50, logpath)


In [None]:
export_P_20(tdata30, mice30, tasks30).to_csv("{}_aaaaa.csv".format(tasks30[0]))
export_P_20(tdata50, mice50, tasks50).to_csv("{}_aaaaa.csv".format(tasks50[0]))
export_P_20(tdata90, mice90, tasks90).to_csv("{}_aaaaa.csv".format(tasks90[0]))

In [None]:

def view_averaged_prob_same_prev_avg(tdata, mice, tasks):
    m = []
    t = []
    csame = []
    fsame = []
    m_average = []

    for mouse_id in mice:
        for task in tasks:
            m += [mouse_id]
            t += [task]
            csame += [tdata.task_prob[task][tdata.task_prob[task].mouse_id == mouse_id]['c_same']]
            fsame += [tdata.task_prob[task][tdata.task_prob[task].mouse_id == mouse_id]['f_same']]
            m_average += [export_P_20(tdata, mouse_id, task).AVG]

    after_prob_df = pd.DataFrame(
        data={'mouse_id': m, 'task': t, 'c_same': csame, 'f_same': fsame, 'm_average': m_average},
        columns=['mouse_id', 'task', 'c_same', 'f_same', 'm_average']
    )

    plt.style.use('default')
    fig = plt.figure(figsize=(8, 4), dpi=100)
    for task in tasks:
        plt.subplot(1, len(tasks), tasks.index(task) + 1)

        c_same = np.array(after_prob_df[after_prob_df['task'].isin([task])]['c_same'].to_list())
        c_same_avg = np.mean(c_same, axis=0)
        c_same_std = np.std(c_same, axis=0)
        c_same_var = np.var(c_same, axis=0)
        c_same_norm = c_same - np.repeat(np.mean(c_same,axis=1),9).reshape(-1,9)
        c_same_norm_var = np.var(c_same_norm,axis=0)
        c_same_norm_avg = np.mean(c_same_norm, axis=0)
        

        f_same = np.array(after_prob_df[after_prob_df['task'].isin([task])]['f_same'].to_list())
        f_same_avg = np.mean(f_same, axis=0)
        f_same_var = np.var(f_same, axis=0)
#        f_same_norm = f_same - np.repeat(np.mean(f_same,axis=1),9).reshape(-1,9)
        f_same_norm = f_same - after_prob_df[]

        f_same_norm_var = np.var(f_same_norm,axis=0)
        f_same_norm_avg = np.mean(f_same_norm, axis=0)

        xlen = len(c_same_avg)
        xax = np.array(range(1, xlen + 1))
        plt.plot(xax, c_same_norm_avg, label="rewarded start")
        plt.errorbar(xax, c_same_norm_avg, yerr=c_same_norm_var, capsize=2, fmt='o', markersize=1, ecolor='black',
                     markeredgecolor="black", color='w', lolims=True)

        plt.plot(np.array(range(1, xlen + 1)), f_same_norm_avg, label="no-rewarded start")
        plt.errorbar(xax, f_same_norm_avg, yerr=f_same_norm_var, capsize=2, fmt='o', markersize=1, ecolor='black',
                     markeredgecolor="black", color='w', uplims=True)

        # plt.ion()
        plt.xticks(np.arange(1, xlen + 1, 1))
        plt.xlim(0.5, xlen + 0.5)
#        plt.ylim(0.25, 0.6)
        plt.ylim(-0.1, 0.1)
        if tasks.index(task) == 0:
            plt.ylabel('P (same choice)')
            plt.legend()
        plt.xlabel('Trial')
        plt.title('{}'.format(task))
#        plt.hlines(average,-11,11)
        plt.hlines(0,-11,11)
    plt.rcParams["font.size"] = 18
    plt.savefig('fig/WSLS_{}.png'.format(tasks[0]))
    plt.show()
    
    #TODO: 1. c_same, f_sameに加えて，各マウスのaverageを各マウスのc_sameから引いて正規化したc_same_norm, f_same_normを計算
    #TODO: 2. mice_id, c_same, f_same, c_same_norm, f_same_norm CSV出力 (グラフは両方)
    
def export_entropy300(tdata, mice, tasks):
    # entropy
    cidata = tdata.mice_task[
        (tdata.mice_task.task.isin(tasks)) &
        (tdata.mice_task.event_type.isin(["reward","failure"]))]
    for task in tasks:
        with open('data/entropy_300_task_{}.csv'.format(task), 'w', newline="") as f:
            writer = csv.writer(f)
            for mouse_id in mice:
                sz = cidata[(cidata.task == task) & (cidata.mouse_id == mouse_id)]["entropy_50"].size
#                 print(cidata[(cidata.task == task) & (cidata.mouse_id == mouse_id)])
#                 print("size = {}".format(sz))
                if sz >= 300:
                    post = cidata[(cidata.task == task) & (cidata.mouse_id == mouse_id)].reset_index()["entropy_150"].iat[300]
                    writer.writerow([mouse_id, post])
                else:
                    print("error")


# logpath = '~/PycharmProjects/RaspSkinnerBox/MiceAnalysis'
logpath = "./"
#45(体重無し?) ファイルが無い・要探索!
# 27(体重無し)
# 33: 317まで
#mice50 = [27,30,31,33,47,49,50]
#mice50 = [52, 64, 67, 68, 70, 71,  27,30,31,33,47,49,50]
mice50 = [64, 67, 68, 70, 71,  27,30,31,33,47,49,50]

tasks50 = ["All5_50"]
tdata50 = task_data(mice50, tasks50, logpath)

export_entropy300(tdata50, mice50, tasks50)
view_averaged_prob_same_prev_avg(tdata50, mice50, tasks50)
#view_averaged_prob_same_prev(tdata50, mice50, tasks50)

# logpath = '~/PycharmProjects/RaspSkinnerBox/MiceAnalysis'
mice30 = [6, 7, 8, 11, 12, 13, 14, 17, 18, 19, 21, 22, 23, 24]
tasks30 = ["All5_30"]
tdata30 = task_data(mice30, tasks30, logpath)
export_entropy300(tdata30, mice30, tasks30)
view_averaged_prob_same_prev_avg(tdata30, mice30, tasks30)



In [None]:
# 10先~19先で基準と同じ穴を選ぶ確率を平均する
def export_P_20(tdata, mice, tasks, indiv=False):
    data = tdata.mice_task[(tdata.mice_task.mouse_id.isin(mice))&(tdata.mice_task.event_type.isin(["reward","failure"]))]
    avg = []
    avg_c = dict(zip(mice,[[] for _ in mice]))
    avg_i = dict(zip(mice,[[] for _ in mice]))
    for task in tasks:
        task_data = data[data.task.isin([task])]
        # 全mice平均
        for no in mice:
            mice_data = task_data[task_data.mouse_id.isin([no])]
            prob = []
            prob_co = []
            prob_in = []
            all_data = mice_data.reset_index()
            correct_data = mice_data[mice_data.event_type.isin(["reward"])].reset_index()
            incorrect_data = mice_data[mice_data.event_type.isin(["failure"])].reset_index()
            length = len(mice_data)- 20
            if length <= 0:
                continue
            for idx, dat in all_data.drop(range(len(all_data)-20,len(all_data))).iterrows():
                tmp = 0
                for j in range(10,20):
                    if dat["hole_no"] == all_data["hole_no"][idx + j]:
                        tmp += 1
                prob.append(tmp/10)
            avg.append(np.average(np.array(prob)))
            # correctデータのみを対象， taskぶち抜き
            for idx, dat in correct_data.drop(range(len(correct_data)-20,len(correct_data))).iterrows():
                tmp = 0
                for j in range(10,20):
                    if dat["hole_no"] == correct_data["hole_no"][idx + j]:
                        tmp += 1
                prob_co.append(tmp/10)
            avg_c[no].append(np.array(np.average(np.array(prob_co))))
            # incorrect
            for idx, dat in incorrect_data.drop(range(len(incorrect_data)-20,len(incorrect_data))).iterrows():
                tmp = 0
                for j in range(10,20):
                    if dat["hole_no"] == incorrect_data["hole_no"][idx + j]:
                        tmp += 1
                prob_in.append(tmp/10)
            avg_i[no].append(np.average(np.array(prob_in)))
    for mouse_id in mice:
        avg_c[mouse_id] = np.average(avg_c[mouse_id])
        avg_i[mouse_id] = np.average(avg_i[mouse_id])
    if indiv:
        return {"c_same":[np.average(np.array(list(avg_c.values())))],"f_same":[np.average(np.array(list(avg_i.values())))],
               "c_same_mice":avg_c,"f_same_mice":avg_i}
    return pd.DataFrame(data={"task":tasks,"AVG":[np.average(avg)]})

In [None]:
logpath = "./"
mice50 = [64, 67, 68, 70, 71,  27,30,31,33,47,49,50]
tasks50 = ["All5_50"]
tdata50 = task_data(mice50, tasks50, logpath)

In [None]:
# 選択確率(マウスごとの選択偏りを引いた)グラフ
mice = mice50
tdata = tdata50
tasks = tasks50

m = []
t = []
csame = []
fsame = []
m_average = []

for mouse_id in mice:
    for task in tasks:
        m += [mouse_id]
        t += [task]
        csame += [tdata.task_prob[task][tdata.task_prob[task].mouse_id == mouse_id]['c_same']]
        fsame += [tdata.task_prob[task][tdata.task_prob[task].mouse_id == mouse_id]['f_same']]
        m_average += [export_P_20(tdata, [mouse_id], [task]).AVG]

after_prob_df = pd.DataFrame(
    data={'mouse_id': m, 'task': t, 'c_same': csame, 'f_same': fsame, 'm_average': m_average},
    columns=['mouse_id', 'task', 'c_same', 'f_same', 'm_average']
)

for task in tasks:

    c_same = np.array(after_prob_df[after_prob_df['task'].isin([task])]['c_same'].to_list())
    c_same_avg = np.mean(c_same, axis=0)
    c_same_std = np.std(c_same, axis=0)
    c_same_var = np.var(c_same, axis=0)

    c_same_norm = c_same - np.repeat(np.mean(c_same,axis=1),9).reshape(-1,9)
    print(c_same_norm)
    av = after_prob_df[after_prob_df['task'].isin([task])]['m_average'].to_numpy()
    print(after_prob_df[after_prob_df['task'].isin([task])]['m_average'])
    avm = np.repeat(av,9).reshape(-1,9)

In [None]:
import numpy as np
import matplotlib as plt

x = 0.5
p = np.linspace(0,1,100)
q = p*x / (p*x + (1-p))

plt.style.use('default')
#fig = plt.figure(figsize=(8, 4), dpi=100)
plt.plot(p, q, label="rewarded start")

plt.show()


In [None]:
# 位相性
def export_phase_data(tdata, mice, tasks):
    for task in tasks:
        data = tdata.mice_task[
                (tdata.mice_task.task == task) &
                (tdata.mice_task.event_type.isin(["nose poke"]))]
        distance = np.array([
            [0,1,2,3,4],
            [1,0,1,2,3],
            [2,1,0,1,2],
            [3,2,1,0,1],
            [4,3,2,1,0]])
        mice_hole_after_times = {}
        mice_hole_select_times = np.empty((0,6), int)
        mice_distance = {}
        mice_correct = {}
        mice_incorrect = {}

        for mouse_id in mice:
            mice_data = data[data.mouse_id.isin([mouse_id])]
            tmp = np.zeros(5*5).reshape(5,5)
            if mice_data.empty:
                continue

            # 1a. 各マウス、各タスク毎に5x5の割合のマトリックスをcsvで出す。
            all_data = mice_data.reset_index()
            length = len(mice_data)- 1
            if length <= 0:
                continue
            for idx, dat in all_data.drop(range(len(all_data)-1,len(all_data))).iterrows():
                tmp[int((int(all_data["hole_no"][idx])-1)/2)][int((int(all_data["hole_no"][idx + 1])-1)/2)] += 1
            mice_hole_after_times.update({mouse_id:tmp/len(all_data)})
            # export
            np.savetxt("data/1a_no{}_task{}_selection.csv".format(mouse_id,task),mice_hole_after_times[mouse_id],fmt="%f",delimiter=',')
            
            # 2a. 各選択がサブタスク内でどのような割合であるかを示す基礎データをcsvに出す（1列目マウスID, 2～6列目1,3,5,7,9の選択数)。
            tmp_select_times = (all_data["hole_no"].value_counts() + pd.Series([0,0,0,0,0],index=["1","3","5","7","9"])).fillna(0)
            mouse_hole_select_times = np.concatenate((np.array([mouse_id]),tmp_select_times.to_numpy()/np.sum(tmp_select_times.to_numpy())))
            mice_hole_select_times = np.append(mice_hole_select_times,[mouse_hole_select_times],axis=0)

            # 3a. 横軸距離(1→1なら0, 1→3なら1, 1→5なら2, 1→7なら3, 1→9なら4, 3→5なら1 ...)、縦軸度数の割合（0～4の距離） のcsvを各マウス、各タスク毎に出力
            mouse_dist = [0,0,0,0,0]
            for dist,idx in zip(distance,range(len(distance))):
                for num,col in zip(dist,range(len(dist))):
                    mouse_dist[num] += tmp[idx,col]
            mouse_dist /= sum(mouse_dist)
            mice_distance.update({mouse_id:mouse_dist})
            # export 
            np.savetxt("data/3a_no{}_task{}_selection_distance.csv".format(mouse_id,task),mouse_dist,delimiter=',',fmt="%f")
            
            # 4a correctの次の移動距離とincorrectの次の移動距離
            dt = tdata.mice_task[(tdata.mice_task.task.isin([task]))]
            correct = [0,0,0,0,0]
            incorrect = [0,0,0,0,0]
            correct_session = dt[dt.is_correct.isin([1])].session_id.values.tolist()
            incorrect_session = dt[dt.is_failure.isin([1])].session_id.values.tolist()
            for idx, dat in all_data[all_data.session_id.isin(correct_session)].iterrows():
                try:
                    correct[int(abs((int(all_data["hole_no"][idx])-1)/2-
                            (int(all_data["hole_no"][idx+1])-1)/2))] += 1 
                except KeyError:
                    continue
            for idx, dat in all_data[all_data.session_id.isin(incorrect_session)].iterrows():
                try:
                    incorrect[int(abs((int(all_data["hole_no"][idx])-1)/2-
                            (int(all_data["hole_no"][idx+1])-1)/2))] += 1 
                except KeyError:
                    continue
            mice_correct.update({mouse_id:correct})
            mice_incorrect.update({mouse_id:incorrect})
            np.savetxt("data/4a_no{}_task{}_correct.csv".format(mouse_id,task),correct,delimiter=',',fmt="%f")
            np.savetxt("data/4a_no{}_task{}_incorrect.csv".format(mouse_id,task),incorrect,delimiter=',',fmt="%f")
          
        # 2a export
        np.savetxt("data/2a_task{}_selection_rate.csv".format(task),mice_hole_select_times,delimiter=',',fmt="%f")
        # 4a 
        np.savetxt("data/4a_no{}_task{}_correct.csv".format(mouse_id,task),correct,delimiter=',',fmt="%f")
          
        # 1b. 1aを基に各タスク毎、全マウス分の平均をcsvで出す。
        all_mean_1 = np.mean(np.array(list(mice_hole_after_times.values())),axis=0)
        np.savetxt("data/1b_allmice_task{}_selection_mean.csv".format(task),all_mean_1,delimiter=',',fmt="%f")
        # 2b. 2aを基に各タスク毎、全マウス分の平均をcsvで出す。
        all_mean_2 = np.delete(np.mean(mice_hole_select_times,axis=0),0,0)
        np.savetxt("data/2b_allmice_task{}_selection_rate_mean.csv".format(task),all_mean_2,delimiter=',',fmt="%f")
        # 3b. 3aを基に各タスク毎、全マウス分の平均をcsvで出力。            
        all_mean_3 = np.mean(np.array(list(mice_distance.values())), axis=0)
        np.savetxt("data/3b_allmice_task{}_selection_distance_mean.csv".format(task),all_mean_3,delimiter=',',fmt="%f")
        # 4b
        all_mean_4_co = np.mean(np.array(list(mice_correct.values())),axis=0)
        all_mean_4_in = np.mean(np.array(list(mice_incorrect.values())),axis=0)
        np.savetxt("data/4b_allmice_task{}_correct_mean.csv".format(task),all_mean_4_co,delimiter=',',fmt="%f")
        np.savetxt("data/4b_allmice_task{}_incorrect_mean.csv".format(task),all_mean_4_in,delimiter=',',fmt="%f")

        
    # 出力
    # 1a. 各マウス、各タスク毎に5x5の割合のマトリックスをcsvで出す。
    # 1b. 1aを基に各タスク毎、全マウス分の平均をcsvで出す。
    # 2a. 各選択がサブタスク内でどのような割合であるかを示す基礎データをcsvに出す（1列目マウスID, 2～6列目1,3,5,7,9の選択数)。
    # 2b. 2aを基に各タスク毎、全マウス分の平均をcsvで出す。
    # 3a. 横軸距離(1→1なら0, 1→3なら1, 1→5なら2, 1→7なら3, 1→9なら4, 3→5なら1 ...)、縦軸度数の割合（0～4の距離） のcsvを各マウス、各タスク毎に出力
    # 3b. 3aを基に各タスク毎、全マウス分の平均をcsvで出力。

In [None]:
export_phase_data(tdata50,mice50,tasks50)

In [None]:
print(count_task(tdata50,mice50,tasks50,[1,3,7,9]))
# print(count_task(tdata50,mice50,tasks50,5))
# print(count_task(tdata50,mice50,tasks50),[1,3,5,7,9])

In [None]:
import copy
def export_P_20_filter(dc,mice):
    hole_range = range(1,10)
    avg_c = np.empty((0,len(hole_range)))
    avg_f = np.empty((0,len(hole_range)))
    avg_c_base = np.empty(0)
    avg_i_base = np.empty(0)
    
    for no in mice:
        mice_data = dc[dc.mouse_id.isin([no])]
        prob_co = []
        prob_in = []
        prob_co_base = []
        prob_in_base = []
        correct_data = mice_data[mice_data.event_type.isin(["reward"])].reset_index()
        incorrect_data = mice_data[mice_data.event_type.isin(["failure"])].reset_index()
        length = len(mice_data)- 20
        if length <= 0:
            continue
        
        # c_same
        # export: mice * hole prob
        for idx, dat in correct_data.drop(range(len(correct_data)-10,len(correct_data))).iterrows():
            tmp = np.zeros((len(hole_range),))
            for j in hole_range:
                if dat["hole_no"] == correct_data["hole_no"][idx + j]:
                    tmp[j-1] += 1
            prob_co.append(tmp)
        avg_c = np.append(avg_c, np.array([np.sum(np.array(prob_co),axis=0)/(len(correct_data)-10)]),axis=0)
        
        # f_same
        # export: mice * hole prob
        for idx, dat in incorrect_data.drop(range(len(incorrect_data)-10,len(incorrect_data))).iterrows():
            tmp = np.zeros((len(hole_range),))
            for j in hole_range:
                if dat["hole_no"] == incorrect_data["hole_no"][idx + j]:
                    tmp[j-1] += 1
            prob_in.append(tmp)
        avg_f = np.append(avg_f, np.array([np.sum(np.array(prob_in),axis=0)/(len(incorrect_data)-10)]),axis=0)

        # base correctデータのみを対象， taskぶち抜き 
        # export: mice 
        for idx, dat in correct_data.drop(range(len(correct_data)-20,len(correct_data))).iterrows():
            tmp = 0
            for j in range(10,20):
                if dat["hole_no"] == correct_data["hole_no"][idx + j]:
                    tmp += 1
            prob_co_base.append(tmp/10)
        avg_c_base = np.append(avg_c_base, np.average(np.array(prob_co_base)))
        
        # incorrect base
        # export: mice
        for idx, dat in incorrect_data.drop(range(len(incorrect_data)-20,len(incorrect_data))).iterrows():
            tmp = 0
            for j in range(10,20):
                if dat["hole_no"] == incorrect_data["hole_no"][idx + j]:
                    tmp += 1
            prob_in_base.append(tmp/10)
        avg_i_base = np.append(avg_i_base, np.average(np.array(prob_in_base)))
        
    return {"c_same":avg_c,"f_same":avg_f,"c_same_base":avg_c_base,"f_same_base":avg_i_base}


In [None]:
def count_task(tdata, mice, tasks,selection=[1,3,5,7,9]) -> dict:
    # count_taskをclass task_dataの外で下記の仕様で再実装（classから消す必要はない）ßß
        if isinstance(selection,str):
            selection = [selection]
        elif isinstance(selection,int):
            selection = [str(selection)]
        selection = [str(num)for num in selection]
        dc = tdata.mice_task[tdata.mice_task["event_type"].isin(["reward", "failure"]) & tdata.mice_task.task.isin(tasks)]
        dc = dc.reset_index()

        after_c_all_task = {}
        after_f_all_task = {}

        after_c_starts_task = {}
        after_f_starts_task = {}

        prob_index = ["c_same", "f_same"]
        forward_trace = 10
        task_prob = {}
        task_prob_hole = dict(zip(selection,[[] for _ in selection]))
        tmp_dt = dc[dc["hole_no"].isin(selection)]
        prob = export_P_20_filter(tmp_dt,mice)
        prob["c_same"] = np.hstack((prob["c_same_base"].reshape((len(mice),1)), prob["c_same"]))
        prob["f_same"] = np.hstack((prob["f_same_base"].reshape((len(mice),1)), prob["f_same"]))
        
        #         prob["c_same"].to_csv("./data/mices_task{}_cistart.csv".format(,"".join(task))
        

        correct_data = pd.DataFrame(prob["c_same"],index=mice)
        incorrect_data = pd.DataFrame(prob["f_same"],index=mice)
        # 3のタスク毎の任意の複数の選択肢毎の全マウス平均をcsv出力 
        correct_data.to_csv("./data/mice_task{}_hole{}_cstart.csv".format("-".join(tasks),"".join(selection)))
        incorrect_data.to_csv("./data/mice_task{}_hole{}_fstart.csv".format("-".join(tasks),"".join(selection)))

        return task_prob


In [None]:
count_task(tdata50,mice50,tasks50,[1,3,7,9])

In [None]:
logpath = "~/Downloads/takarada"
mice30 = [6, 7, 8, 11, 12, 13, 14, 17, 18, 19, 21, 22, 23, 24]
tasks30 = ["All5_30"]
tdata30 = task_data(mice30, tasks30, logpath)
export_entropy300(tdata30, mice30, tasks30)
view_averaged_prob_same_prev_avg(tdata30, mice30, tasks30)

In [None]:
from pyentrp import entropy as pent
def calc_permutation_entropy(tdata,mice,tasks):
    for tsk in tasks:
        en_list1 = []
        en_list2 = []
        len_list = []
        for mouse_no in mice:
            data = tdata.mice_task[
                (tdata.mice_task.mouse_id == mouse_no)
                &(tdata.mice_task.task == tsk)
                &(tdata.mice_task.hole_no.isin(["1","3","5","7","9"]))
                &(tdata.mice_task.event_type.isin(["reward","failure"]))
            ][["mouse_id","task","hole_no","event_type","correct_times"]]
            data = data.reset_index()
            leng = len(data)
            choice_data1 = data.head(300).hole_no.astype(int).values
            choice_data2 = data.head(150).hole_no.astype(int).values
            en1 = pent.shannon_entropy(choice_data1)
            en2 = pent.shannon_entropy(choice_data2)
            en_list1.append(en1)
            en_list2.append(en2)
        en_list_vert = np.stack([mice, en_list1, en_list2])
        np.savetxt(f"./data/entropy/allmice_{tsk}.csv",en_list_vert,delimiter=',',fmt="%f")
    return en_list1, en_list2

In [9]:
a = export_daily_feeds(tdata30,mice30)

In [None]:
a[a.feed < 70]

In [None]:
b = export_daily_feeds(tdata50,mice50)

In [None]:
b[b.feed<70]

In [10]:
import datetime as dt
def export_daily_feeds(mice, end_task=""):
    feeds_data = pd.DataFrame(columns=["id","timestamp","day","feed","tasks"])
    # mouseid, 日付, 何日目, 粒数, そのときにやっていたタスクのリストをハイフン繋ぎ
    for mouse_id in mice:
        file = "./no{:03d}_action.csv".format(mouse_id)
        data = pd.read_csv(file,names=["timestamps", "task", "session_id", "correct_times", "event_type", "hole_no"],parse_dates=[0]) 
        if isinstance(data.iloc[0].timestamps, str):
            data = pd.read_csv(file,parse_dates=[0])
            data.columns = ["timestamps", "task", "session_id", "correct_times", "event_type", "hole_no"]
        data = data[data.event_type.isin(["reward"])].reset_index()
        if data.empty:
            continue
        start_timestamp = data.timestamps.iloc[0]
        start_date = start_timestamp.date() 
        base_time = start_timestamp.time()
        if end_task == "":
            end_task = list(data.task.unique())[-1]
        last_taskday = data[:data[data.task.isin([end_task])].index[-1]].timestamps.iat[-1]
        
        finish_date = last_taskday.date() - dt.timedelta(days=1)*(last_taskday.time() < base_time)
        data = data[data.timestamps < dt.datetime.combine(finish_date+dt.timedelta(days=1),base_time)]
        # task
        task = "-".join(list(data.task.unique()))
        # day
        for d in pd.date_range(start_date,finish_date,freq="D"):
            # index:d
            # range
            rew = data.event_type[(data.timestamps > dt.datetime.combine(d,base_time)) & 
                 (data.timestamps < dt.datetime.combine(d+dt.timedelta(days=1),base_time))].count()
            feeds_data = feeds_data.append(
                pd.Series([mouse_id,d,(d.date()-start_timestamp.date() + dt.timedelta(days=1)).days,rew,task],index=["id","timestamp","day","feed","tasks"]
                                                                         ),ignore_index=True)
        feeds_data[feeds_data.id.isin([mouse_id])].to_csv("./data/no{:03d}_feeds_summary_{}.csv".format(mouse_id,"-".join(list(data.task.unique()))))
    feeds_data.to_csv("./data/allmice_feeds_summary_{}.csv".format("-".join(list(data[:data[data.task.isin([end_task])].index[-1]].task.unique()))))
    return feeds_data

In [11]:
mice_lever = [139]
export_daily_feeds(mice_lever,"left")

Unnamed: 0,id,timestamp,day,feed,tasks
0,139,2020-08-12,1,120,T0-left-right
1,139,2020-08-13,2,73,T0-left-right
2,139,2020-08-14,3,77,T0-left-right


In [None]:
from datetime import datetime, timedelta
from typing import Union

import pandas as pd
import numpy as np
from scipy.stats import entropy
import sys
import matplotlib.pyplot as plt
import matplotlib.collections as collections
import matplotlib.markers as markers
import os
import csv

# サブタスクの実施時間をCSV出力
def export_subtask_duration_csv(mouse_id):
    if not isinstance(mouse_id,list):
        mouse_id = [mouse_id]
    for no in mouse_id:
        file = "./no{:03d}_action.csv".format(no)
        data = pd.read_csv(file,names=["timestamps", "task", "session_id", "correct_times", "event_type", "hole_no"],parse_dates=[0])
        if isinstance(data.iloc[0].timestamps, str):
            data = pd.read_csv(file,parse_dates=[0])
            data.columns = ["timestamps", "task", "session_id", "correct_times", "event_type", "hole_no"]
        data = data[["timestamps","event_type","task","hole_no"]]
        tasks = data.task.unique()
        print(tasks)
        event_times = pd.pivot_table(data[data.event_type.isin(["reward","failure","time over"])],index="event_type",columns="task",aggfunc="count").timestamps
        task_duration = pd.DataFrame(data.groupby("task").timestamps.max()-data.groupby("task").timestamps.min())
        task_duration.timestamps = task_duration.timestamps / np.timedelta64(1, 'h')
        ret_val = event_times.append(task_duration.T).fillna(0)
        ret_val = ret_val.rename(index={"timestamps":"duration in hours"})
        # 列の順番をtaskをやった順でソート
        ret_val = ret_val.loc[:,tasks]
        ret_val.to_csv("./data/no{:03d}_summary.csv".format(no))
        # 各タスクの各選択肢毎のerror数, correct数, total trial数を出したい
        tasks_df = []
        for task in tasks: 
            tasks_data = data[data.task.isin([task])]
            tasks_df.append(tasks_data[tasks_data.event_type.isin(["failure"])])
            tasks_data = pd.pivot_table(tasks_data[tasks_data.event_type.isin(["reward","failure","time over"])],index="event_type",columns="hole_no",aggfunc="count").timestamps.fillna(0)
            tasks_data.loc["total_trials"] = tasks_data.sum()
            tasks_data.to_csv("./data/no{:03d}_{}_selection_trials.csv".format(no,task))
        # return tasks_df

In [None]:
a = export_subtask_duration_csv(145)
a[1]

In [None]:
a[1].to_csv("./data/no145_left_fail.csv")