In [1]:
import numpy as np
import os 
import sys
import importlib
import copy

In [2]:
# Put the folder holding the project's helper modules on the import path.
# When the working directory contains 'C:' a fixed absolute location is used;
# otherwise the helpers are expected in ./myutils relative to the working
# directory.
# NOTE(review): the absolute path is machine-specific -- consider making it
# configurable.
if 'C:' in os.getcwd():
    util_path = 'C:/ASM/Dropbox/Developments/Jupyter/Eating/myutils'
else:
    util_path = './myutils'
sys.path.append(util_path)

# Project-local file helper; it only resolves after the path tweak above.
import my_file_utils as mfileu

In [3]:
# Load the two inputs of this notebook from the 'generated_for_result' folder.
# Both objects are later indexed as obj[lab_free][clf] by
# get_frames_bites_all().
all_proba, all_pct_proba = (
    mfileu.read_file('generated_for_result', pkl_name)
    for pkl_name in ('all_proba.pkl', 'all_pct_proba.pkl')
)

In [4]:
def detect_bites(ix_proba, proba_th, min_interval=2*16):
    """Select bite positions from a probability trace.

    Two steps:

    1. Candidate peaks are interior samples that are local maxima of the
       probability column (strictly above the previous sample, not below the
       next one) and at or above ``proba_th``.
    2. Candidates are pruned repeatedly until nothing changes. A candidate
       survives a round when it beats both neighbouring candidates in
       probability, or when it is at least ``min_interval`` index units away
       from both of them. On equal probabilities the earlier candidate wins,
       so a tie between two close candidates keeps exactly one of them.

    Parameters
    ----------
    ix_proba : 2-D array-like
        Column 0 is read as the sample index and column 1 as the probability
        (presumably one row per frame, in time order -- confirm with caller).
    proba_th : float or 1-element array
        Minimum probability a candidate peak must reach.
    min_interval : number, optional
        Minimum index distance at which two neighbouring candidates are kept
        regardless of their probabilities. Defaults to ``2*16`` (presumably
        2 s at 16 Hz -- TODO confirm).

    Returns
    -------
    numpy.ndarray
        Integer array with the column-0 values of the selected peaks; empty
        when no sample qualifies.

    Notes
    -----
    Changes relative to the previous implementation:

    * The previous peak test (``prev < cur <= next``) selected rising-edge
      samples rather than local maxima.
    * Two close candidates with equal probability no longer eliminate each
      other.
    * The result is always an integer ndarray. Previously it was a list when
      empty and an uncast array for a single peak.
    """
    ix_proba = np.asarray(ix_proba)
    ix = ix_proba[:, 0]
    proba = ix_proba[:, 1]

    # Interior samples only: the first and last rows have a single neighbour
    # and are never treated as peaks.
    inner = proba[1:-1]
    is_peak = (proba[:-2] < inner) & (inner >= proba[2:]) & (inner >= proba_th)
    peaks = (np.flatnonzero(is_peak) + 1).tolist()

    # Iterative pruning. Every round either removes at least one candidate or
    # stops, so the loop terminates.
    while len(peaks) > 1:
        last = len(peaks) - 1
        kept = []
        for k, p in enumerate(peaks):
            # A missing neighbour (first/last candidate) never blocks a peak.
            beats_prev = k == 0 or proba[p] > proba[peaks[k - 1]]
            beats_next = k == last or proba[p] >= proba[peaks[k + 1]]
            far_prev = k == 0 or ix[p] - ix[peaks[k - 1]] >= min_interval
            far_next = k == last or ix[peaks[k + 1]] - ix[p] >= min_interval
            if (beats_prev and beats_next) or (far_prev and far_next):
                kept.append(p)

        if len(kept) == len(peaks):
            break
        peaks = kept

    return ix[np.asarray(peaks, dtype=int)].astype(int)

In [5]:
def get_frames_bites(ix_proba, percentile_proba, percentile_proba_val, pct_proba=None, off_on=None, blank_array = [], frame_offset=40):
    """Compute above-threshold frames and detected bites for every session.

    Parameters
    ----------
    ix_proba : nested sequence
        ``ix_proba[subj][sess]`` is a 2-D array whose first two columns are
        used as (index, probability). The input is not modified.
    percentile_proba : {"percentile", "proba"}
        ``"proba"`` uses ``percentile_proba_val`` directly as the threshold.
        ``"percentile"`` looks the threshold up in ``pct_proba``.
    percentile_proba_val : float
        The probability threshold, or the percentile to look up.
    pct_proba : 2-D array, required for "percentile"
        Rows are matched on column 0 == subj, column 1 == sess and
        column 2 == percentile. The threshold is read from the second-to-last
        column for ``off_on == "offline"`` and from the last column otherwise.
    off_on : {None, "offline", "online"}
        Selects the threshold column in percentile mode (see above).
    blank_array : nested list
        Template giving the subject/session layout; it is deep-copied for the
        two results and never mutated.
    frame_offset : number, optional
        Constant added to the index column before thresholding (default 40,
        the value that was previously hard-coded).

    Returns
    -------
    (frames, bites)
        Two structures shaped like ``blank_array``. ``frames[subj][sess]``
        holds the offset indices whose probability is >= the threshold;
        ``bites[subj][sess]`` holds the result of ``detect_bites``.

    Raises
    ------
    AssertionError
        On an invalid ``off_on`` / ``percentile_proba`` value, or when the
        percentile lookup does not match exactly one row.
    """
    assert off_on in [None, "offline", "online"]
    assert percentile_proba in ["percentile", "proba"]

    frames = copy.deepcopy(blank_array)
    bites = copy.deepcopy(blank_array)

    for subj in range(len(blank_array)):
        for sess in range(len(blank_array[subj])):
            # Slicing a NumPy array yields a view. Copy before shifting the
            # index column, otherwise the caller's data is modified and the
            # offset accumulates on every call.
            ix_p = ix_proba[subj][sess][:, :2].copy()
            ix_p[:, 0] = ix_p[:, 0] + frame_offset

            if percentile_proba == 'percentile':
                cond = (pct_proba[:, 0]==subj) & (pct_proba[:, 1]==sess) & (pct_proba[:, 2]==percentile_proba_val)
                assert np.sum(cond)==1
                th_col = -2 if off_on == "offline" else -1
                # Exactly one row matches, so unwrap it to a scalar threshold.
                proba_th = pct_proba[cond, th_col][0]
            else:
                proba_th = percentile_proba_val

            frames[subj][sess] = ix_p[ix_p[:, 1]>=proba_th, 0]
            bites[subj][sess] = detect_bites(ix_p, proba_th=proba_th)

    return frames, bites

In [6]:
def get_frames_bites_all(lab_free, clf):
    """Run get_frames_bites over the full grid of thresholds for one dataset/classifier.

    The blank layout is read from the 'data' folder (the lab file when
    ``lab_free == 'lab'``, the free file otherwise), and the probabilities and
    percentile table are taken from the module-level ``all_proba`` and
    ``all_pct_proba`` at ``[lab_free][clf]``.

    Thresholds swept:
      * "proba": 0.10 to 0.90 in steps of 0.05
      * "percentile_offline": 98.00 to 99.99 in steps of 0.01
      * "percentile_online": same percentiles, only when ``lab_free == 'free'``

    Each threshold value is printed as a progress indicator.

    Returns ``(frames, bites)``: two dicts keyed by the three mode names above,
    each mapping threshold value -> result of get_frames_bites.
    """
    blank_file = 'lab_data_steven_blank_array.pkl' if lab_free == 'lab' else 'free_data_steven_blank_array.pkl'
    ba = mfileu.read_file('data', blank_file)

    ix_proba = all_proba[lab_free][clf]
    pct_proba = all_pct_proba[lab_free][clf]

    mode_names = ("proba", "percentile_offline", "percentile_online")
    frames = {mode: {} for mode in mode_names}
    bites = {mode: {} for mode in mode_names}

    # Fixed probability thresholds.
    for pct_int in range(10, 95, 5):
        threshold = pct_int/100
        print(threshold, end=" | ")
        frames["proba"][threshold], bites["proba"][threshold] = get_frames_bites(
            ix_proba, "proba", percentile_proba_val=threshold, blank_array=ba)

    # Percentile-derived thresholds; the online variant exists only for the free data.
    variants = ["offline", "online"] if lab_free == 'free' else ["offline"]
    for hundredths in range(9800, 10000):
        percentile = hundredths/100
        print(percentile, end=" | ")

        for variant in variants:
            key = "percentile_" + variant
            frames[key][percentile], bites[key][percentile] = get_frames_bites(
                ix_proba, "percentile", percentile_proba_val=percentile,
                pct_proba=pct_proba, off_on=variant, blank_array=ba)

    return frames, bites
   

In [7]:
# Driver: build the frame/bite results for every dataset ('lab', 'free') and
# classifier ('RF', 'our'), then write both nested dicts
# (indexed [lab_free][clf]) to the 'generated_for_result' folder.
#res_lab_free = mfileu.read_file('generated_for_result', 'all_frames_bites.pkl')
all_frames = {}
all_bites = {}

for lab_free in ('lab', 'free'):
    frames_clf = {}
    bites_clf = {}
    for clf in ('RF', 'our'):
        print("\n\n--------------", lab_free, clf, "--------------")
        frames_clf[clf], bites_clf[clf] = get_frames_bites_all(lab_free, clf)

    all_frames[lab_free], all_bites[lab_free] = frames_clf, bites_clf

for out_name, payload in (('all_frames.pkl', all_frames), ('all_bites.pkl', all_bites)):
    mfileu.write_file('generated_for_result', out_name, payload)
print("Done!!!")



-------------- lab RF --------------
0.1 | 0.15 | 0.2 | 0.25 | 0.3 | 0.35 | 0.4 | 0.45 | 0.5 | 0.55 | 0.6 | 0.65 | 0.7 | 0.75 | 0.8 | 0.85 | 0.9 | 98.0 | 98.01 | 98.02 | 98.03 | 98.04 | 98.05 | 98.06 | 98.07 | 98.08 | 98.09 | 98.1 | 98.11 | 98.12 | 98.13 | 98.14 | 98.15 | 98.16 | 98.17 | 98.18 | 98.19 | 98.2 | 98.21 | 98.22 | 98.23 | 98.24 | 98.25 | 98.26 | 98.27 | 98.28 | 98.29 | 98.3 | 98.31 | 98.32 | 98.33 | 98.34 | 98.35 | 98.36 | 98.37 | 98.38 | 98.39 | 98.4 | 98.41 | 98.42 | 98.43 | 98.44 | 98.45 | 98.46 | 98.47 | 98.48 | 98.49 | 98.5 | 98.51 | 98.52 | 98.53 | 98.54 | 98.55 | 98.56 | 98.57 | 98.58 | 98.59 | 98.6 | 98.61 | 98.62 | 98.63 | 98.64 | 98.65 | 98.66 | 98.67 | 98.68 | 98.69 | 98.7 | 98.71 | 98.72 | 98.73 | 98.74 | 98.75 | 98.76 | 98.77 | 98.78 | 98.79 | 98.8 | 98.81 | 98.82 | 98.83 | 98.84 | 98.85 | 98.86 | 98.87 | 98.88 | 98.89 | 98.9 | 98.91 | 98.92 | 98.93 | 98.94 | 98.95 | 98.96 | 98.97 | 98.98 | 98.99 | 99.0 | 99.01 | 99.02 | 99.03 | 99.04 | 99.05 | 99.06 | 99.07 