In [3]:
from modeling import *
from behaviors import *
from peristimulus import *
import seaborn as sns

In [7]:
folder = "/Users/albertqu/Documents/7.Research/Wilbrecht_Lab/CADA_data/ProbSwitch_FP_data_new"
animal, session = "A2A-15B-B_RT", "p153_FP_LH"
files = encode_to_filename(folder, animal, session, ['processed'])
hfile = h5py.File(files)

In [120]:
class BehaviorMat:

    code_map = {1: ('center_in', 'center_in'),
                11: ('center_in', 'initiate'),
                2: ('center_out', 'center_out'),
                3: ('side_in', 'left'),
                4: ('side_out', 'left'),
                44: ('side_out', 'left'),
                5: ('side_in', 'right'),
                6: ('side_out', 'right'),
                66: ('side_out', 'right'),
                71.1: ('outcome', 'correct_unrewarded'),
                71.2: ('outcome', 'correct_rewarded'),
                72: ('outcome', 'incorrect_unrewarded'),
                73: ('outcome', 'missed'),  # saliency questionable
                74: ('outcome', 'abort')}  # saliency questionable

    # Always use efficient coding
    def __init__(self, animal, session, hfile, tau=np.inf):
        self.tau = tau
        self.animal = animal
        self.session = session
        if isinstance(hfile, str):
            print("For pipeline loaded hdf5 is recommended for performance")
            hfile = h5py.File(hfile, 'r')
        self.choice_sides = None
        self.exp_complexity = None  # Whether the ITI is complex (first round only analysis simple trials)
        self.struct_complexity = None
        self.trialN = 0
        self.hemisphere, self.region = None, None
        self.event_list = EventNode(None, None, None, None)
        self.initialize(hfile)

    def __str__(self):
        return f"BehaviorMat({self.animal}_{self.session}, tau={self.tau})"

    def initialize(self, hfile):
        self.hemisphere = 'right' if np.array(hfile["out/notes/hemisphere"]).item() else 'left'
        self.region = 'NAc' if np.array(hfile['out/notes/region']).item() else 'DMS'
        trialN = len(hfile['out/value/outcome'])
        self.trialN = trialN
        self.choice_sides = np.full(trialN, '', dtype='<U6')
        self.exp_complexity = np.full(trialN, True, dtype=bool) # default true detect back to back
        self.struct_complexity = np.full(trialN, False, dtype=bool) # default false detect double centers
        self.exp_complexity[0] = False
#         dup = {'correct_unrewarded': 0, 'correct_rewarded': 0, 'incorrect_unrewarded': 0, 
#                'missed': 0, 'abort': 0}
#         ndup = {'correct_unrewarded': 0, 'correct_rewarded': 0, 'incorrect_unrewarded': 0, 
#                'missed': 0, 'abort': 0}
#         self.struct_complexity[0] = False
        trial_event_mat = np.array(hfile['out/value/trial_event_mat'])

        # Parsing LinkedList
        prev_node = None
        # TODO: Careful of the 0.5 trial events
        for i in range(trial_event_mat.shape[0]):
            eventcode, etime, trial = trial_event_mat[i, :]
            old_eventcode = eventcode
            if eventcode == 44 or eventcode == 66:
                eventcode = eventcode // 10
            ctrial = int(np.ceil(trial))-1
            event, opt = BehaviorMat.code_map[eventcode]
#             if i % 1000 == 0:
#                 print(i, eventcode, old_eventcode, etime, trial, ctrial, event, opt)
            makenew = True

            if prev_node is not None:
                if eventcode > 70:
                    lat = prev_node.MLAT if eventcode < 73 else ""
                    self.choice_sides[ctrial] = lat
                    if prev_node.event == 'side_in':
                        prev_node.saliency = 'choice'
#                     if prev_node.etime == etime:
#                         dup[opt] += 1
#                     else:
#                         ndup[opt] += 1
                        #print('highlight', i, eventcode, etime, trial, prev_node.etime, prev_node.ecode)
                if prev_node.etime == etime:
                    if eventcode == prev_node.ecode:
                        makenew = False
                    elif eventcode < 70:
                        print(f"Warning! Duplicate timestamps({prev_node.ecode}, {eventcode}) in {str(self)}")
                    elif eventcode != 72:
                        print(f"Special Event Duplicate: {self.animal}, {self.session}, ", event, opt)
                elif eventcode == 72:
                    print(f"Unexpected non-duplicate for {trial}, {opt}, {self.animal}, {self.session}")
            else:
                assert eventcode < 70, 'outcome cannot be the first node'
            if makenew:
                # potentially fill out all properties here; then make merge an inheriting process
                evnode = self.event_list.append(event, etime, trial, eventcode)
                # Filling MLAT for side ports, Saliency for outcome and initiate
                if event == 'outcome':
                    assert self.choice_sides[ctrial] == prev_node.MLAT
                    evnode.MLAT = prev_node.MLAT
                if eventcode > 6:
                    evnode.saliency = opt
                elif eventcode > 2:
                    evnode.MLAT = opt
                prev_node = evnode
        # temporal adjacency merge
        assert not self.event_list.is_empty()
        curr_node = self.event_list.next
        while not curr_node.sentinel:
            if '_out' in curr_node.event:
                # COULD do an inner loop to make it look more straightforward
                next_node = curr_node.next
                prev_check = curr_node.prev
                if next_node.sentinel:
                    print(f"Weird early termination with port_out?! {str(curr_node)}")
                # TODO: sanity check: choice side_in does not have any mergeable port before them.
#                 print(curr_node.ecode, next_node.ecode, curr_node.etime, next_node.etime)
                if (next_node.ecode == curr_node.ecode-1) and (next_node.etime - curr_node.etime < self.tau):
                    merge_node = next_node.next
                    if merge_node.sentinel:
                        print(f"Weird early termination with port_in?! {str(next_node)}")
                    assert merge_node.ecode == curr_node.ecode, f"side in results in {str(merge_node)}"
                    merge_node.merged = True
                    self.event_list.remove_node(curr_node)
                    self.event_list.remove_node(next_node)
                    assert prev_check.next is merge_node and merge_node.prev is prev_check, "Data Structure BUG"
                    curr_node = prev_check  # jump back to previous node

            # Mark features so far saliency: only choice/outcome/initiate, MLAT: outcome/side_port
            if not curr_node.next.merged: # only trigger at "boundary events"
                # Make sure this is not a revisit due to merge
                prev_node = curr_node.prev
                next_node = curr_node.next
                if curr_node.event == 'center_in':
                    # just need MLAT
                    if prev_node.event == 'side_out':
                        curr_node.MLAT = prev_node.MLAT
                    # update structural complexity
                    if curr_node.saliency == 'initiate':
                        breakflag = False
                        cursor = curr_node.prev
                        while (not cursor.sentinel) and (cursor.event != 'outcome'):
                            if cursor.event == 'center_in':
                                self.struct_complexity[curr_node.trial_index()] = True
                                breakflag = True
                                break
                            cursor = cursor.prev
                        if not breakflag and cursor.MLAT:
                            assert (cursor.sentinel) or (cursor.next.event == 'side_out'), f"weird {cursor}, {cursor.next}"
                elif curr_node.event == 'center_out':
                    if next_node.event == 'side_in':
                        curr_node.MLAT = next_node.MLAT
                    if next_node.saliency == 'choice':
                        # assume "execution" is at center_out, recognizing that well trained animal might
                        # already have executed a program from side_out (denote side port using first/last)
                        curr_node.saliency = 'execution'
                elif curr_node.event == 'side_out':
                    sals = []
                    # TODO: with different TAU we might not want the first side out as salient event
                    if prev_node.event == 'outcome':
                        sals.append('first')
                    if next_node.event == 'center_in':
                        safe_last = True
                        cursor = next_node
                        print(curr_node.trial_index(), cursor.event, cursor.saliency)
                        while cursor.saliency != 'initiate':
                            print(cursor.event, cursor.saliency)
                            if cursor.sentinel:
                                print(f"Weird early termination with port_out?! {str(cursor.prev)}")
                            if cursor.event == 'side_in':
                                safe_last = False
                                break
                            cursor = cursor.next
                        print(curr_node.trial_index(), cursor.event, cursor.saliency, safe_last)
                        if safe_last:
                            sals.append('last')
                    curr_node.saliency = "_".join(sals)
                    if len(sals) == 2:
                        print(str(curr_node.prev), str(curr_node), str(curr_node.next))
                        self.exp_complexity[int(curr_node.trial)] = False
            curr_node = curr_node.next

    def get_event_times(self, event, simple=True, saliency=True):
        curr = self.event_list.next
        event_times = []
        trials = []
        sals = None
        if saliency and 'side_out' in event:
            event, sals = event.split("__")
        else:
            salmap = {'center_in': 'initiate',
                      'center_out': 'execution',
                      'side_out': 'choice',
                      'outcome': ['correct_unrewarded', 'correct_rewarded', 'incorrect_unrewarded']}
            sals = salmap[event]

        while not curr.sentinel:
            if curr.event == event:
                complex_check = True
                cti = curr.trial_index()
                if simple and event in ['center_in', 'side_out'] and \
                        (self.exp_complexity[cti] or self.struct_complexity[cti]):
                    complex_check = False

                if (not saliency) or (curr.saliency is not None and curr.saliency in sals) \
                        and complex_check:
                    event_times.append(curr.etime)
                    trials.append(curr.trial_index())
        event_times = np.array(event_times)
        trials = np.array(trials)
        if saliency:
            if simple:
                assert len(event_times) == np.sum(self.exp_complexity)
            else:
                assert len(event_times) == self.trialN
        return event_times, trials

    def get_trial_features(self):
        pass

    def get_inter_trial_stats(self, option='MT'):
        """
        :param option:
            'MT': movement times (vigor)
            'ITI_full': full ITI for decay
            'MT_full': movement times
        :return:
        """
        side_out_firsts, _ = self.get_event_times('side_out__first', False, True)
        initiates, _ = self.get_event_times('center_in', False, True)
        outcomes, _ = self.get_event_times('outcome', False, True)
        #
        if option == 'MT_full':
            results = initiates - side_out_firsts
        elif option == 'ITI_full':
            results = np.zeros(self.trialN)
            results[1:] = initiates[1:] - outcomes[:-1]
        else:
            raise NotImplementedError(f"{option} not implemented")
        return results


class EventNode:
    def __init__(self, event, etime, trial, ecode):
        self.event = event
        self.trial = trial
        self.etime = etime
        self.ecode = ecode # For debug purpose
        # Use "" for Null
        self.MLAT = ""
        #self.OLAT = "" OLAT should be more dynamic
        self.saliency = ""
        self.merged = False
        self.next = None
        self.prev = None
        if event is None:
            # Implements a circular LinkedList
            self.sentinel = True
            self.next = self
            self.prev = self
            self.size = 0
        else:
            self.sentinel = False

    def mvmt_dynamic(self):
        pass

    def trial_index(self):
        # 0.5 is ITI but considered in trial 0
        return int(np.ceil(self.trial)) - 1

    def __str__(self):
        return f"EventNode({self.event}, {self.trial}, {self.etime:.1f}ms, {self.ecode})"

    # Methods Reserved For Sentinel Node
    def __len__(self):
        assert self.sentinel, 'must be sentinel node to do this'
        return self.size

    def append(self, event, etime, trial, ecode):
        assert self.sentinel, 'must be sentinel node to do this'
        evn = EventNode(event, etime, trial, ecode)
        old_end = self.prev
        assert old_end.next is self, "what is happening"
        old_end.next = evn
        evn.prev = old_end
        self.prev = evn
        evn.next = self
        self.size += 1
        return evn

    def prepend(self):
        # Not important
        assert self.sentinel, 'must be sentinel node to do this'
        pass

    def remove_node(self, node):
        assert self.sentinel, 'must be sentinel node to do this'
        assert self.size, 'list must be non-empty'
        next_node = node.next
        prev_node = node.prev
        prev_node.next = next_node
        next_node.prev = prev_node
        node.next = None
        node.prev = None
        self.size -= 1

    def get_last(self):
        assert self.sentinel, 'must be sentinel node to do this'
        return self.prev

    def get_first(self):
        assert self.sentinel, 'must be sentinel node to do this'
        return self.next

    def is_empty(self):
        assert self.sentinel, 'must be sentinel node to do this'
        return self.size == 0

In [56]:
trial_event_mat.shape

(10365, 3)

In [121]:
bmat = BehaviorMat(animal, session, hfile)
# bmat.initialize(hfile)

0 center_in initiate
0 center_in initiate True
1 center_in initiate
1 center_in initiate True
EventNode(outcome, 1.0, 41464.4ms, 71.2) EventNode(side_out, 1.5, 43853.6ms, 4.0) EventNode(center_in, 2.0, 44417.5ms, 11.0)
2 center_in initiate
2 center_in initiate True
EventNode(outcome, 2.0, 44868.5ms, 71.2) EventNode(side_out, 2.5, 50770.5ms, 4.0) EventNode(center_in, 3.0, 51219.0ms, 11.0)
3 center_in initiate
3 center_in initiate True
EventNode(outcome, 3.0, 51624.5ms, 71.2) EventNode(side_out, 3.5, 52900.0ms, 4.0) EventNode(center_in, 4.0, 54265.6ms, 11.0)
4 center_in initiate
4 center_in initiate True
EventNode(outcome, 4.0, 54770.6ms, 71.2) EventNode(side_out, 4.5, 55367.1ms, 4.0) EventNode(center_in, 5.0, 55706.1ms, 11.0)
5 center_in initiate
5 center_in initiate True
EventNode(outcome, 5.0, 56168.6ms, 71.2) EventNode(side_out, 5.5, 56797.6ms, 4.0) EventNode(center_in, 6.0, 57228.6ms, 11.0)
6 center_in 
center_in 
center_out 
side_in 
6 side_in  False
NO bueno
6 center_in initiate
6

225 center_in initiate True
EventNode(outcome, 225.0, 1259903.3ms, 71.2) EventNode(side_out, 225.5, 1262142.3ms, 4.0) EventNode(center_in, 226.0, 1262610.3ms, 11.0)
226 center_in initiate
226 center_in initiate True
EventNode(outcome, 226.0, 1263122.3ms, 71.2) EventNode(side_out, 226.5, 1266345.6ms, 4.0) EventNode(center_in, 227.0, 1266798.6ms, 11.0)
227 center_in 
center_in 
center_out 
center_in 
center_out 
side_in 
227 side_in  False
NO bueno
227 center_in initiate
227 center_in initiate True
228 center_in initiate
228 center_in initiate True
EventNode(outcome, 228.0, 1270502.4ms, 71.2) EventNode(side_out, 228.5, 1273197.0ms, 4.0) EventNode(center_in, 229.0, 1273722.0ms, 11.0)
229 center_in 
center_in 
center_out 
229 center_in initiate True
EventNode(outcome, 229.0, 1274200.0ms, 71.1) EventNode(side_out, 229.5, 1274963.0ms, 4.0) EventNode(center_in, 229.5, 1275440.5ms, 1.0)
230 center_in 
center_in 
center_out 
230 center_in initiate True
EventNode(outcome, 230.0, 1276542.5ms, 72.

center_in 
center_out 
center_in 
center_out 
458 center_in initiate True
EventNode(outcome, 458.0, 2389201.8ms, 72.0) EventNode(side_out, 458.5, 2389495.3ms, 6.0) EventNode(center_in, 458.5, 2389889.3ms, 1.0)
459 center_in 
center_in 
center_out 
center_in 
center_out 
side_in 
459 side_in  False
NO bueno
459 center_in initiate
459 center_in initiate True
460 center_in 
center_in 
center_out 
460 center_in initiate True
EventNode(outcome, 460.0, 2395042.8ms, 71.1) EventNode(side_out, 460.5, 2395897.3ms, 4.0) EventNode(center_in, 460.5, 2396488.8ms, 1.0)
461 center_in 
center_in 
center_out 
center_in 
center_out 
side_in 
461 side_in  False
NO bueno
461 center_in initiate
461 center_in initiate True
462 center_in initiate
462 center_in initiate True
EventNode(outcome, 462.0, 2401499.8ms, 71.2) EventNode(side_out, 462.5, 2403505.0ms, 4.0) EventNode(center_in, 463.0, 2403886.0ms, 11.0)
463 center_in initiate
463 center_in initiate True
EventNode(outcome, 463.0, 2404500.1ms, 71.2) EventN

710 center_in initiate
710 center_in initiate True
711 center_in initiate
711 center_in initiate True
EventNode(outcome, 711.0, 3725928.9ms, 71.2) EventNode(side_out, 711.5, 3726741.9ms, 6.0) EventNode(center_in, 712.0, 3727231.4ms, 11.0)
712 center_in initiate
712 center_in initiate True
EventNode(outcome, 712.0, 3727676.4ms, 71.2) EventNode(side_out, 712.5, 3729519.4ms, 6.0) EventNode(center_in, 713.0, 3729955.4ms, 11.0)
713 center_in initiate
713 center_in initiate True
EventNode(outcome, 713.0, 3730797.8ms, 71.2) EventNode(side_out, 713.5, 3731291.3ms, 6.0) EventNode(center_in, 714.0, 3731687.8ms, 11.0)
714 center_in initiate
714 center_in initiate True
EventNode(outcome, 714.0, 3735685.7ms, 71.2) EventNode(side_out, 714.5, 3736232.2ms, 6.0) EventNode(center_in, 715.0, 3736607.2ms, 11.0)
715 center_in 
center_in 
center_out 
side_in 
715 side_in  False
NO bueno
715 center_in initiate
715 center_in initiate True
716 center_in initiate
716 center_in initiate True
EventNode(outcome, 7

927 side_in  False
NO bueno
927 center_in initiate
927 center_in initiate True
928 center_in initiate
928 center_in initiate True
EventNode(outcome, 928.0, 4888389.7ms, 71.2) EventNode(side_out, 928.5, 4889276.8ms, 6.0) EventNode(center_in, 929.0, 4889765.8ms, 11.0)
929 center_in initiate
929 center_in initiate True
EventNode(outcome, 929.0, 4890245.3ms, 71.2) EventNode(side_out, 929.5, 4892134.8ms, 6.0) EventNode(center_in, 930.0, 4893188.8ms, 11.0)
930 center_in initiate
930 center_in initiate True
EventNode(outcome, 930.0, 4893634.8ms, 71.2) EventNode(side_out, 930.5, 4896812.9ms, 6.0) EventNode(center_in, 931.0, 4897293.9ms, 11.0)
931 center_in initiate
931 center_in initiate True
EventNode(outcome, 931.0, 4897739.9ms, 71.1) EventNode(side_out, 931.5, 4899414.9ms, 6.0) EventNode(center_in, 932.0, 4899812.4ms, 11.0)
932 center_in initiate
932 center_in initiate True
EventNode(outcome, 932.0, 4900232.4ms, 71.2) EventNode(side_out, 932.5, 4902498.9ms, 6.0) EventNode(center_in, 933.0, 

In [111]:
for c in np.unique(bmat.choice_sides):
    print(np.sum(bmat.choice_sides == c))

1
549
549


In [105]:
np.sum(bmat.struct_complexity)

393

In [122]:
np.sum(bmat.exp_complexity)

284