# This notebook provides the code used in the paper "Real-time detection of neural oscillation bursts allows behaviourally relevant neurofeedback"

## The workflow is as follows:
In MATLAB:
    - The timepoints at which beta events occurred are read out from the recorded files [rwrds].
    - The timepoints of the frame triggers are read out [frame_times].
    Both are saved into a .mat file.
In this notebook:
    - We create JSON files for the train and eval sets, each containing a list of samples, where every sample is a list of paths to optical-flow files
    [samples[timepointofsample]]
    - We iterate through the data to train a StandardScaler and then an SVM model.
    - We evaluate the trained models on the eval dataset, which the models have not seen before.

In [None]:
#Import libraries
import glob,os,json,joblib,imageio,cv2,time
import matplotlib.pyplot as plt
import numpy as np
from scipy.io import loadmat
from scipy.stats import ks_2samp,ttest_ind
from flowutils import flow_to_image,readFlow,draw_flow
from sklearn.model_selection import train_test_split as train_test_split
from sklearn.decomposition import PCA,IncrementalPCA
from sklearn.metrics import accuracy_score, confusion_matrix,classification_report
from sklearn.svm import SVC,LinearSVC
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import StratifiedKFold,KFold
from sklearn.linear_model import SGDClassifier
from tqdm import tqdm,tqdm_notebook
from scipy.ndimage.filters import gaussian_filter
from scipy.stats import  ttest_ind
from scipy.stats import linregress

In [None]:
#Helper functions
class batchGenerator():
    """Create balanced, shuffled batches of positive and negative samples.

    Positive and negative samples are read from two JSON files, each holding a
    list of samples (in this notebook: lists of flow-file paths). Every call to
    ``createABatch`` returns up to ``batch_size_per_class`` samples of each
    class in random order, together with float labels (1.0 = positive,
    0.0 = negative).

    Parameters
    ----------
    pos_file, neg_file : str
        Paths to the JSON files with the positive / negative sample lists.
    batch_size_per_class : int
        Number of samples taken from each class per batch (must be >= 1).
    num_epochs : int
        Stored on the instance only; not used by this class.
    sample : int or None
        Index of the positive sample returned by ``get_specific_sample``.
    test : int
        If 0, both sample lists are shuffled on construction; otherwise the
        first epoch follows file order.

    Raises
    ------
    ValueError
        If ``batch_size_per_class`` is smaller than 1.
    """
    def __init__(self,pos_file,neg_file,batch_size_per_class=20,num_epochs=10,sample=None,test=0):
        if batch_size_per_class<1:
            raise ValueError('batch_size_per_class must be >= 1, got %r'%(batch_size_per_class,))
        self.batch_size=batch_size_per_class
        self.num_epochs=num_epochs
        self.epoch_counter=0
        self.batch_counter=0
        with open(pos_file) as fi:
            self.pos_data=json.load(fi)
        with open(neg_file) as fi:
            self.neg_data=json.load(fi)
        # At least one batch per epoch: a dataset smaller than one batch is then
        # returned whole on every call instead of yielding empty slices forever.
        self.batchesPerEpoch=max(1,len(self.pos_data)//self.batch_size)
        self.sample=sample
        self.test=test
        if self.test==0:
            self.shuffle()

    def checkSizes(self):
        """Print 'OK' if both classes have the same number of samples, otherwise raise."""
        if len(self.neg_data)==len(self.pos_data):
            print('OK')
        else:
            raise Exception ('Not equal number of pos and neg samples')

    def shuffle(self):
        """Shuffle the positive and the negative sample lists in place (independently)."""
        np.random.shuffle(self.pos_data)
        np.random.shuffle(self.neg_data)

    def createABatch(self):
        """Return the next batch as ``(samples, labels)``.

        ``samples`` is a list of positive and negative samples in random order;
        ``labels`` is a 1-D float array aligned with it (1.0 = positive,
        0.0 = negative). After the last batch of an epoch both sample lists are
        reshuffled and ``epoch_counter`` is incremented.
        """
        start=self.batch_counter*self.batch_size
        stop=start+self.batch_size
        pos=self.pos_data[start:stop]
        neg=self.neg_data[start:stop]
        self.batch_counter=self.batch_counter+1
        if self.batch_counter==self.batchesPerEpoch:
            self.shuffle()
            self.batch_counter=0
            self.epoch_counter=self.epoch_counter+1
            print('New Epoch>shuffling data')

        all_d=pos+neg
        # Labels come from the actual slice lengths, so a short slice (small
        # dataset, fewer negatives than positives) cannot misalign samples and labels.
        labels=np.concatenate((np.ones(len(pos)),np.zeros(len(neg))))
        order=np.random.permutation(len(all_d))
        return [all_d[i] for i in order],labels[order]

    def check(self):
        """Alias for ``checkSizes``."""
        self.checkSizes()

    def get_specific_sample(self):
        """Return the positive sample at index ``self.sample``."""
        return self.pos_data[self.sample]
    
def Get_number_evalSamples(file):
    """Return how many samples the JSON dataset stored at *file* contains.

    The file must hold a JSON array; its length is the sample count.
    """
    with open(file) as handle:
        return len(json.load(handle))

def get_neg_time4frames(labels,nroftimes,fs=976.5625,throwfirst=10,shift=500,min_gap=2):
    """Draw random negative-sample times that do not overlap positive samples.

    Parameters
    ----------
    labels : 1-D array
        Per-sample labels at rate *fs*; 1 (True) marks a positive sample,
        0 (False) a sample that may be used as a negative.
    nroftimes : int
        Number of negative times to return.
    fs : float
        Sampling rate of *labels* in Hz.
    throwfirst : float
        Initial period in seconds from which no negatives are drawn (baseline).
    shift : int
        Exclusion radius, in samples, around every positive sample.
    min_gap : float
        Minimum spacing in seconds between consecutive returned times.

    Returns
    -------
    numpy.ndarray
        Sorted array of exactly *nroftimes* times in seconds.
    """
    labels=np.asarray(labels)
    n_samples=labels.shape[0]
    idx=np.arange(n_samples)
    # Candidates: label 0, after the baseline period, and not in the last second.
    candidate=(labels==0)&(idx>=np.int64(throwfirst*fs))&(idx<n_samples-np.int64(fs))
    # Exclude every sample within +-shift samples of a positive one. A prefix sum
    # of positives answers "any positive in [i-shift, i+shift]?" for all i in O(n),
    # instead of materialising one index row per positive sample.
    pos_count=np.concatenate(([0],np.cumsum(labels==1)))
    lo=np.clip(idx-shift,0,n_samples)
    hi=np.clip(idx+shift+1,0,n_samples)
    near_pos=(pos_count[hi]-pos_count[lo])>0
    neg_times=idx[candidate&~near_pos]

    neg_time=np.sort(np.random.choice(neg_times,nroftimes)/fs)
    while True:
        too_close=np.diff(neg_time)<min_gap
        if not np.any(too_close):
            break
        # Keep the first time and every time at least min_gap after its predecessor.
        # Kept + dropped always equals the current length, so redrawing exactly the
        # dropped ones preserves nroftimes (a ">" here would silently lose samples).
        keep=np.insert(~too_close,0,True)
        redrawn=np.random.choice(neg_times,int(np.sum(too_close)))/fs
        neg_time=np.sort(np.concatenate((neg_time[keep],redrawn)))
    return neg_time

def split_json(name,json_dir='./json',out_dir=None,train_fraction=0.8):
    """Split every JSON dataset matching ``<json_dir>/*<name>`` into train and eval files.

    Each matching file must hold a JSON list of samples. The list is shuffled.
    The first ``round(len * train_fraction)`` samples form the training set.
    The remaining samples that do not also occur in the training set form the
    eval set, so duplicated samples cannot leak from train into eval. Samples
    are compared by their JSON serialisation.

    Results are written to ``<out_dir>/<basename>_train.json`` and
    ``<out_dir>/<basename>_eval.json``. ``out_dir`` defaults to
    ``<json_dir>/splitted2`` and is created if it does not exist.
    """
    if out_dir is None:
        out_dir=os.path.join(json_dir,'splitted2')
    os.makedirs(out_dir,exist_ok=True)
    for f in glob.glob(os.path.join(json_dir,'*'+name)):
        with open(f,'r') as fi:
            data=json.load(fi)
        np.random.shuffle(data) #shuffle for a random train/eval split
        n_train=int(round(len(data)*train_fraction))
        data_train=data[:n_train]
        # set lookup instead of "x not in data_train" (quadratic in dataset size)
        train_keys={json.dumps(x,sort_keys=True) for x in data_train}
        data_test=[x for x in data[n_train:] if json.dumps(x,sort_keys=True) not in train_keys]
        file_name=os.path.splitext(os.path.basename(f))[0]
        with open(os.path.join(out_dir,'%s_train.json'%file_name),'w') as fo:
            json.dump(data_train,fo)
        with open(os.path.join(out_dir,'%s_eval.json'%file_name),'w') as fo:
            json.dump(data_test,fo)
            
def _read_flow_cropped(path2flow_data,frame_name,X,Y):
    """Read one .flo file as float32 and crop it to (Y, X, 2), retrying with 8-digit frame numbers on failure."""
    try:
        return np.float32(readFlow(path2flow_data+frame_name))[:Y,:X,:]
    except Exception:
        # Names are written as '%06d-%06d.flo'; retry the same frame pair with
        # 8-digit zero padding. A failure here propagates (chained to the first).
        second=int(frame_name[-10:-4])
        frame_name=frame_name[:-17]+'%08d-%08d.flo'%(second-1,second)
        return np.float32(readFlow(path2flow_data+frame_name))[:Y,:X,:]


def load_flow(data_list,path2flow_data='/media/deeplearning/BCB24522B244E30E/Neurofeedback/',X=320,Y=240):
    """Load the optical-flow files of a batch into one float32 array.

    Parameters
    ----------
    data_list : list of list of str
        One entry per sample; each entry lists the relative paths of its flow
        files. All samples must have the same number of files.
    path2flow_data : str
        Directory prefix prepended to every relative path.
    X, Y : int
        Width and height that every flow field is cropped to.

    Returns
    -------
    numpy.ndarray
        Array of shape (n_samples, n_frames, Y, X, 2) and dtype float32.
        An empty ``data_list`` gives shape (0, 0, Y, X, 2).
    """
    n_frames=len(data_list[0]) if len(data_list) else 0
    flow_data=np.empty((len(data_list),n_frames,Y,X,2),dtype=np.float32)
    for flow_id,sample in enumerate(data_list):
        for frame_id,frame_name in enumerate(sample):
            flow_data[flow_id,frame_id]=_read_flow_cropped(path2flow_data,frame_name,X,Y)
    return flow_data

def FlowNames2PNG(data_list):
    """Map every flow-file path to the PNG frame path it corresponds to.

    The last 11 characters of each path (``-MMMMMM.flo`` for the 6-digit names
    used here) are replaced by ``.png``. The nesting (samples -> time steps)
    is preserved.
    """
    return [[flow_name[:-11] + '.png' for flow_name in sample] for sample in data_list]


In [None]:
"""
This cell reads the frame and beta-event timepoints from .mat files,
calculates corresponding framenumbers around the beta event,
creates negative examples,
writes samples into json file,
splits json into train/eval set,
""" 
lfp_fs=976.5625 #Sampling frequency of beta_reward frame_times

for ratNr in [373]: #iterate over rats
    for runId in [1,2,3]: #iterate over kfold splits
        #Create jsons for all_training 
        shiftB=1.3 # time in sec relative to reward when our samples start
        frames_for_window=50 # how many frames are in individual samples
        
        #path to mat files
        path2matfiles='/mat_files_beta/'        
        allsessions=glob.glob(path2matfiles+'data4python*') #
        
        #which sessions to include
        sessions=['G373_190307_1','G373_190307_2','G373_190308_1','G373_190310_1','G373_190310_2','G373_190311_1','G373_190312_1','G373_190313_1','G373_190314_1']
        flow_pos=list()
        flow_neg=list()
        for session in sessions: #iterate over sessions
            sess=[s for s in allsessions if session in s][0] # find the correspondind mat file
            print(sess)
            mat_file=loadmat(sess)
            rwrd_times=mat_file['rwrds']
            frame_times=np.squeeze(mat_file['frame_times'])
            rwrd_times=rwrd_times[(rwrd_times[:,1] - frame_times[-1])<-1,:] #if there is reward too close to video end
            
            path2flow_data='/media/deeplearning/'+session+'/'
           
            # create a vector with 1 where there are betaevents to2used to get negative timepoints
            beta_events_vector=np.zeros((np.ceil((rwrd_times[-1,1]+2)*lfp_fs).astype(np.int64)),np.bool)
            for rwrd in rwrd_times:
                beta_events_vector[(rwrd[0]*lfp_fs).astype(np.int64):(rwrd[1]*lfp_fs).astype(np.int64)]=1
            # create neg samples from randomly choose non overlapping timepoints    
            neg_times=get_neg_time4frames(beta_events_vector,int(rwrd_times.shape[0]),fs=976.5625,throwfirst=frame_times[0],shift=3000)
            neg_times.sort()
            
            # find the closest frame to the beta_event/neg_sample
            frame_nr_pos=list()
            frame_nr_neg=list()
            for rwrd_time,neg_time in zip(rwrd_times[:,0],neg_times):
                frame_nr_pos.append(np.argmin(np.power(frame_times-(rwrd_time-shiftB),2)))
                frame_nr_neg.append(np.argmin(np.power(frame_times-neg_time,2)))
            
            # We create a json file with paths to files but not load them yet into memory.. 
            
            # we create a list of all samples, each sample is a list of all flo files in this sample
            for frame_id,frame_nr in enumerate(frame_nr_pos):
                f_list=list()
                for flow_id in range(frames_for_window):        
                    f=frame_nr+flow_id
                    flow_file_name=session+'/cam2_all/%06d-%06d.flo' % (f , (f+1))
                    png_file_name=session+'/cam2_all/%06d.png' % (f)   
                    f_list.append(flow_file_name)
                flow_pos.append(f_list)
            for frame_id,frame_nr in enumerate(frame_nr_neg):
                f_list=list()
                for flow_id in range(frames_for_window):        
                    f=frame_nr+flow_id
                    flow_file_name=session + '/cam2_all/%06d-%06d.flo' % (f , (f+1))
                    png_file_name=session+'/cam2_all/%06d.png' % (f)  
                    f_list.append(flow_file_name)
                flow_neg.append(f_list)
        #Save the jsons        
        with open('./json/ALL_flow_pos_%d_%0.1f_run%d_rat%d.json'%(frames_for_window,abs(shiftB),runId,ratNr),'w') as fi:
            json.dump(flow_pos,fi)
        with open('./json/ALL_flow_neg_%d_%0.1f_run%d_rat%d.json'%(frames_for_window,abs(shiftB),runId,ratNr),'w') as fi:
            json.dump(flow_neg,fi)            
        #Split jsons into train/eval
        split_json('ALL_flow_*_%d_%0.1f_run%d_rat%d.json' %(frames_for_window,abs(shiftB),runId,ratNr))

## Now we have prepared our datasets and are ready to train the SVMs
We first standardize the data with a StandardScaler to improve SVM learning; because the dataset is large, this is also done in batches.
Then we train the model in batches.
As a control, we use a model trained on the same data with shuffled labels.

In [None]:
ratNrs=[373] #371 done
shiftB=1.3
frames_for_window=50
X=320
Y=240
batch_scaling=70 #batch sizes
batch_training=70
batch_eval=70
epochs=60# aim for total 10e6 samples
label_classes=np.array([0.,1.]) # labels produced by batchGenerator: 0=negative, 1=positive


def _scaled_batch(data_list,scaler):
    """Load a batch of flow samples, standardise it frame-wise with scaler and flatten each sample into one row."""
    batch=load_flow(data_list)
    n_samples,n_frames=batch.shape[:2]
    # the scaler works on single frames, so transform one frame per row ...
    batch=scaler.transform(np.reshape(batch,(n_samples*n_frames,-1)))
    # ... then concatenate all frames of a sample into its feature vector
    return np.reshape(batch,(n_samples,-1))


for ratNr in ratNrs:
    for runId in [1,2,3]:
        print('Working on rat %d run %d' %(ratNr,runId))
        # load the train data(only paths)
        pos_file='/home/deeplearning/json/splitted2/ALL_flow_pos_%d_%0.1f_run%d_rat%d_train.json'%(frames_for_window,abs(shiftB),runId,ratNr)
        neg_file='/home/deeplearning/json/splitted2/ALL_flow_neg_%d_%0.1f_run%d_rat%d_train.json'%(frames_for_window,abs(shiftB),runId,ratNr)

        clf = SGDClassifier(loss="hinge", penalty="l2", fit_intercept=False,tol=1e-3) #linear SVM trained by SGD
        clfS = SGDClassifier(loss="hinge", penalty="l2", fit_intercept=False,tol=1e-3) #SVM for shuffled data

        with open(pos_file) as fi:
            pos_data=json.load(fi)

        # create the batch generators
        gen_scaler=batchGenerator(pos_file,neg_file,batch_size_per_class=batch_scaling)
        gen_train=batchGenerator(pos_file,neg_file,batch_size_per_class=batch_training)
        gen_eval=batchGenerator(pos_file[:-10]+'eval.json',neg_file[:-10]+'eval.json',batch_size_per_class=batch_eval)

        acc=list()
        acc_shuffled=list()
        t0=time.time()

        # Check if we already have a Scaler trained for this dataset.
        # NOTE(review): the file name has no runId, so runs 2 and 3 reuse the scaler
        # fitted on run 1's training set, which may overlap their eval samples.
        scaler_file='./json/Scaler_%d_%0.1f_rat%d.joblib'%(frames_for_window,abs(shiftB),ratNr)
        try:
            scaler=joblib.load(scaler_file)
            print('Found a Scaler, skip to training')
        except FileNotFoundError:
            print('No Scaler found, fitting one..')
            scaler=StandardScaler()
            for batch_id in tqdm(range(len(pos_data)//batch_scaling)):
                [data_list,_]=gen_scaler.createABatch()
                t2=time.time()
                scale_batch=load_flow(data_list)
                print("Time loading %d"%(time.time()-t2))
                scaler.partial_fit(np.reshape(scale_batch,(len(data_list)*len(data_list[0]),-1)))
                print("Time passed %d"%(time.time()-t0))
            #save the trained scaler
            joblib.dump(scaler,scaler_file)

        for batch_id in tqdm(range((len(pos_data)//batch_training)*epochs)):
            [data_list,keys]=gen_train.createABatch()
            train_batch=_scaled_batch(data_list,scaler)
            clf.partial_fit(train_batch,keys,classes=label_classes) #train model on batch
            keys_shuffled=np.random.permutation(keys) # shuffled labels for the control model
            clfS.partial_fit(train_batch,keys_shuffled,classes=label_classes) #train nonsense
            if batch_id%20==0: #evaluate the models every 20 batches
                [eval_list,eval_keys]=gen_eval.createABatch()
                eval_batch=_scaled_batch(eval_list,scaler)
                step_acc=accuracy_score(eval_keys,clf.predict(eval_batch))
                step_acc_shuffled=accuracy_score(eval_keys,clfS.predict(eval_batch))
                acc.append(step_acc)
                acc_shuffled.append(step_acc_shuffled)
                print('Step:%d Acc:%f'%(batch_id,step_acc))
                print('Shuffled Acc:%f'%(step_acc_shuffled))
        joblib.dump(clf,'./json/Model_%d_%0.1f_run%d_rat%d.joblib'%(frames_for_window,abs(shiftB),runId,ratNr))
        joblib.dump(clfS,'./json/ModelS_%d_%0.1f_run%d_rat%d.joblib'%(frames_for_window,abs(shiftB),runId,ratNr))


In [None]:
#Evaluate the models

resultsPred={}    # rat -> mean eval accuracy of each run (real labels)
resultsPred_S={}  # rat -> mean eval accuracy of each run (shuffled-label control)

time_att=list()      # one entry per (rat, run): attention time courses of correctly classified positives
time_att_all=list()  # one entry per (rat, run): attention time courses of all eval samples

ratNrs=[371,373,206]
X=320
Y=240


def _scaled_batch(data_list,scaler):
    """Load a batch of flow samples, standardise it frame-wise with scaler and flatten each sample into one row."""
    batch=load_flow(data_list)
    n_samples,n_frames=batch.shape[:2]
    batch=scaler.transform(np.reshape(batch,(n_samples*n_frames,-1)))
    return np.reshape(batch,(n_samples,-1))


for ratNr in ratNrs:
    all_m=list()
    all_m_S=list()
    run_means=list()
    run_means_S=list()
    for runId in [100,101,102]:
        pos_file='/home/deeplearning/json/splitted2/ALL_flow_pos_%d_%0.1f_run%d_rat%d_train.json'%(frames_for_window,abs(shiftB),runId,ratNr)
        neg_file='/home/deeplearning/json/splitted2/ALL_flow_neg_%d_%0.1f_run%d_rat%d_train.json'%(frames_for_window,abs(shiftB),runId,ratNr)
        scaler=joblib.load('./json/Scaler_%d_%0.1f_rat%d.joblib'%(frames_for_window,abs(shiftB),ratNr))
        clf=joblib.load('./json/Model_%d_%0.1f_run%d_rat%d.joblib'%(frames_for_window, abs(shiftB), runId, ratNr))
        clfS=joblib.load('./json/ModelS_%d_%0.1f_run%d_rat%d.joblib'%(frames_for_window,abs(shiftB),runId,ratNr))
        eval_pos_file=pos_file[:-10]+'eval.json'
        eval_neg_file=neg_file[:-10]+'eval.json'

        # batches of a quarter of the eval set, in file order (any remainder is skipped)
        samples2evaluate=Get_number_evalSamples(eval_pos_file)
        samplesperbatch=samples2evaluate//4
        gen_eval=batchGenerator(eval_pos_file,eval_neg_file,batch_size_per_class=samplesperbatch,test=1)
        run_acc=list()
        run_acc_S=list()
        time_courses=list()
        time_courses_all=list()
        for batch_id in tqdm(range(samples2evaluate//samplesperbatch)):
            [data_list,keys]=gen_eval.createABatch()
            eval_batch=_scaled_batch(data_list,scaler)
            res=clf.predict(eval_batch)
            resS=clfS.predict(eval_batch)
            acc_real=accuracy_score(keys,res)
            acc_shuffled=accuracy_score(keys,resS)
            print('Acc:%f'%(acc_real))
            print('Acc:%f'%(acc_shuffled))
            run_acc.append(acc_real)
            run_acc_S.append(acc_shuffled)

            # per-feature contribution w_j * x_j to the decision value
            attention=eval_batch*np.squeeze(clf.coef_)
            attention=np.reshape(attention,(len(data_list),len(data_list[0]),Y,X,-1))
            # positive contributions summed over space and flow components -> (samples, frames)
            positive_att=np.sum(np.clip(attention,a_min=0,a_max=None),axis=(2,3,4))
            time_courses_all.append(positive_att)
            # flatnonzero keeps a 1-D index array even when exactly one positive is correct
            correctly_predPos=np.flatnonzero((res==keys)&(res==1))
            time_courses.append(positive_att[correctly_predPos])
        all_m.extend(run_acc)
        all_m_S.extend(run_acc_S)
        run_means.append(np.mean(run_acc))
        run_means_S.append(np.mean(run_acc_S))
        print('TAcc:%f'%np.mean(all_m))
        print('TAcc:%f'%np.mean(all_m_S))
        time_att.append(np.concatenate(time_courses,axis=0))
        time_att_all.append(np.concatenate(time_courses_all,axis=0))

    # One mean accuracy per run. Fixed slices of all_m do not line up with runs,
    # because each run contributes samples2evaluate//samplesperbatch (>= 4) batches.
    resultsPred[str(ratNr)]=run_means
    resultsPred_S[str(ratNr)]=run_means_S


In [None]:
#plot the results of evaluation
# One grey line per rat (real vs shuffled mean accuracy), a black line for the
# mean over all rats/runs with SEM error bars across rats, and a significance bracket.
plt.figure(18)
plt.clf()
plt.ylim([0.45,0.68])
plt.xlim([-0.2,1.3])
ax = plt.gca()

rat_order=('206','371','373')
rat_marker={'206':'^','371':'v','373':'s'}
per_rat_real=[np.mean(resultsPred[r]) for r in rat_order]
per_rat_shuffled=[np.mean(resultsPred_S[r]) for r in rat_order]
for r,real_acc,shuffled_acc in zip(rat_order,per_rat_real,per_rat_shuffled):
    ax.plot([1,0.05],[real_acc,shuffled_acc],linestyle='-',marker=rat_marker[r],color=[0.5,0.5,0.5])

pooled_real=np.mean([resultsPred[r] for r in rat_order])
pooled_shuffled=np.mean([resultsPred_S[r] for r in rat_order])
ax.plot([1.05,0],[pooled_real,pooled_shuffled],linestyle='-',marker='o',color='k',linewidth=2)
ax.errorbar(0,pooled_shuffled,yerr=np.std(per_rat_shuffled)/np.sqrt(3),color='k',capsize=2)
ax.errorbar(1.05,pooled_real,yerr=np.std(per_rat_real)/np.sqrt(3),color='k',capsize=2)
ax.legend(['Rat1','Rat2','Rat3','Mean'],loc='lower right')
ax.set_xticks([0.05,1])
ax.set_xticklabels(['Shuffled','Real'])
print(ttest_ind(np.asarray(per_rat_real),np.asarray(per_rat_shuffled),equal_var=False))

bracket_segments=(([0,1.05],[0.65,0.65]),([0,0],[0.64,0.65]),([1.05,1.05],[0.64,0.65]))
for seg_x,seg_y in bracket_segments:
    ax.plot(seg_x,seg_y,'k',linewidth=0.7)
plt.rcParams.update({'font.size': 14})
plt.text(1.05/2,0.655,'*',fontsize=13)


In [None]:
# Code for figure 3
# Layout on a 9x11 grid: two large panels on the left (a, b) and a 3x2 block of
# small panels on the right (attention traces in column 5, frames in column 8).
plt.figure(181,figsize=(11,9))
plt.clf()
grid_shape=(9,11)
ax1=plt.subplot2grid(grid_shape,(0,0),colspan=4,rowspan=4)
ax2=plt.subplot2grid(grid_shape,(5,0),colspan=4,rowspan=4)
ax3,ax4,ax5=[plt.subplot2grid(grid_shape,(row,5),colspan=3,rowspan=3) for row in (0,3,6)]
ax6,ax7,ax8=[plt.subplot2grid(grid_shape,(row,8),colspan=3,rowspan=3) for row in (0,3,6)]
xticks_time=[-1,-0.6,-0.2]
y_lims_att=[0.1,1.1]
yticks_time=[0.2,0.4,0.6,0.8,1]

#Figure1: classification accuracy, real vs shuffled labels
panel_rats=(('206','^'),('371','v'),('373','s'))
ax1.axes.set_ylim([0.45,0.68])
ax1.axes.set_xlim([-0.2,1.3])
for rat_key,rat_marker_style in panel_rats:
    ax1.plot([1,0.05],[np.mean(resultsPred[rat_key]),np.mean(resultsPred_S[rat_key])],linestyle='-',marker=rat_marker_style,color=[0.5,0.5,0.5])
real_all=np.mean([resultsPred[k] for k,_ in panel_rats])
shuffled_all=np.mean([resultsPred_S[k] for k,_ in panel_rats])
ax1.plot([1.05,0],[real_all,shuffled_all],linestyle='-',marker='o',color='k',linewidth=2)
ax1.errorbar(0,shuffled_all,yerr=np.std([np.mean(resultsPred_S[k]) for k,_ in panel_rats])/np.sqrt(3),color='k',capsize=2)
ax1.errorbar(1.05,real_all,yerr=np.std([np.mean(resultsPred[k]) for k,_ in panel_rats])/np.sqrt(3),color='k',capsize=2)
ax1.legend(['Rat1','Rat2','Rat3','Mean'],loc='lower right')
ax1.set_xticks([0.05,1])
ax1.set_xticklabels(['Shuffled','Real'])
for seg_x,seg_y in (([0,1.05],[0.65,0.65]),([0,0],[0.64,0.65]),([1.05,1.05],[0.64,0.65])):
    ax1.plot(seg_x,seg_y,'k',linewidth=0.7)
ax1.text(1.05/2,0.655,'*',fontsize=13)
ax1.set_ylabel('Classification accuracy',fontsize=17)

#Figure2: mean attention time course of correctly classified positives, SEM across rats
t_values=[-(shiftB-i/50)+0.2 for i in range(frames_for_window)] # frame step of 1/50 s (hard-coded)
smoothed=[gaussian_filter(tc,(0,2)) for tc in time_att] # smooth each run along time (axis 1) only
values=np.concatenate(smoothed,axis=0)
mean_tc=np.mean(values,axis=0)
peak=np.max(mean_tc)
# time_att holds 3 consecutive runs per rat (in evaluation order). Average within each
# rat, then take the spread of the 3 rat courses per time point. Previously the slices
# were taken over samples and collapsed to scalars, giving a constant band.
rat_slices=(slice(0,3),slice(3,6),slice(6,None))
rat_tcs=[np.mean(np.concatenate(smoothed[s],axis=0),axis=0) for s in rat_slices]
sem=np.std(rat_tcs,axis=0)/np.sqrt(len(rat_tcs))
ax2.plot(t_values,mean_tc/peak,'k')
ax2.fill_between(t_values,(mean_tc+sem)/peak,(mean_tc-sem)/peak,facecolor='k',alpha=0.2)
slope, intercept, r_value, p_value, std_err=linregress(t_values,mean_tc)
regress_vals=[slope*t_values[0]+intercept,slope*t_values[-1]+intercept]
ax2.plot([t_values[0],t_values[-1]],[regress_vals[0]/peak,regress_vals[-1]/peak],'k--')
ax2.legend(['Attention','Linear regression','SEM'],loc='upper left')
ax2.set_xlabel('Time before burst (s)',fontsize=17)
ax2.set_ylabel('Attention (normalized)',fontsize=17)
ax2.set_xticks(xticks_time)
ax2.set_yticks([0.75,0.8,0.85,0.9,0.95,1])

#Figure3: attention of one example positive eval sample per rat (panel c)
X=320
Y=240
path2flow_data='/media/deeplearning/Neurofeedback/' # root of the PNG frames shown under the attention maps
path2save='/media/deeplearning/temp_images/Attention/'
t_values=[-(shiftB-i/50)+0.2 for i in range(frames_for_window)] # one value per frame of the sample


def _plot_example_attention(ax_trace,ax_img,ratNr,runId,sample_tag,title):
    """Plot the attention time course and the peak-attention frame of one eval sample.

    The sample is the first positive eval sample of (ratNr, runId) whose first
    flow-file name contains sample_tag. Attention is the positive part of the
    SVM weights times the standardised features, smoothed with a Gaussian over
    time and space. ax_trace shows the normalised per-frame attention with the
    peak marked by a red x. ax_img shows the PNG frame at the peak with the
    attention map overlaid.
    """
    pos_file='/home/deeplearning/json/splitted2/ALL_flow_pos_%d_%0.1f_run%d_rat%d_train.json'%(frames_for_window,abs(shiftB),runId,ratNr)
    neg_file='/home/deeplearning/json/splitted2/ALL_flow_neg_%d_%0.1f_run%d_rat%d_train.json'%(frames_for_window,abs(shiftB),runId,ratNr)
    scaler=joblib.load('./json/Scaler_%d_%0.1f_rat%d.joblib'%(frames_for_window,abs(shiftB),ratNr))
    clf=joblib.load('./json/Model_%d_%0.1f_run%d_rat%d.joblib'%(frames_for_window, abs(shiftB), runId, ratNr))
    eval_pos_file=pos_file[:-10]+'eval.json'
    with open(eval_pos_file) as fi:
        pos_data=json.load(fi)
    matches=[idx for idx,sample in enumerate(pos_data) if sample_tag in sample[0]]
    if not matches:
        raise ValueError('No eval sample of rat %d run %d contains %r'%(ratNr,runId,sample_tag))
    gen_eval=batchGenerator(eval_pos_file,neg_file[:-10]+'eval.json',sample=matches[0],test=1)
    data_list=[gen_eval.get_specific_sample()]
    png_list=FlowNames2PNG(data_list)
    n_frames=len(data_list[0])

    features=scaler.transform(np.reshape(load_flow(data_list),(n_frames,-1)))
    features=np.reshape(features,(1,-1))
    attention=np.reshape(features*np.squeeze(clf.coef_),(1,n_frames,Y,X,-1))
    attention=gaussian_filter(attention,(0,2,3,3,0))
    positive=np.clip(attention[0],a_min=0,a_max=None) # (frames, Y, X, flow components)
    time_course=np.sum(positive,axis=(1,2,3))
    norm_tc=time_course/np.max(time_course)
    # Peak frame. A z-score is a monotonic transform, so argmax of the raw course selects the same frame.
    time_id=int(np.argmax(time_course))

    ax_trace.plot(t_values,norm_tc,'k')
    ax_trace.plot(t_values[time_id],norm_tc[time_id],'rx')
    ax_trace.set_xticks(xticks_time)
    ax_trace.axes.set_ylim(y_lims_att)
    ax_trace.set_yticks(yticks_time)

    ax_img.imshow(imageio.imread(path2flow_data+png_list[0][time_id]))
    ax_img.imshow(np.squeeze(np.sum(positive[time_id],axis=-1)),alpha=0.4,cmap='hot')
    ax_img.axis('off')
    ax_img.axes.set_title(title)


_plot_example_attention(ax3,ax6,206,101,'016842','Rat1')
_plot_example_attention(ax4,ax7,371,101,'37039','Rat2')
ax4.set_ylabel('Attention (normalized)',fontsize=17)
_plot_example_attention(ax5,ax8,373,100,'87158','Rat3')
ax5.set_xlabel('Time before burst (s)',fontsize=17)

plt.subplots_adjust(wspace=0.2, hspace=0.1)
# panel letters; plt.text uses the current axes' data coordinates (presumably ax8's image pixels)
plt.text(-1016,-700,'a',fontsize=20,fontweight='bold')
plt.text(-439,-700,'c',fontsize=20,fontweight='bold')
plt.text(-1016,-140,'b',fontsize=20,fontweight='bold')
ax1.text(-0.34,0.508,'chance \nlevel',fontsize=10,horizontalalignment='center',fontweight='bold')