# Write/Manage animal tag file
## Contains functions to handle tag files
## DO NOT MODIFY unless you are sure of what you are doing
# 
 
#### The tag of a new session is taken from: 1) the last tag in the .tag file, 2) the tag stored in the old-system tag file (read from `<animal>/Tag`), or 3) asked directly from the user

In [None]:
import os, logging
import glob
import numpy as np
from platform import system as OS
import pandas as pd
import scipy.stats
import datetime
from copy import deepcopy
import fnmatch
from IPython.display import clear_output

# Notebook-only setup. At module level dir() lists the global names; __file__ is
# among them when this code runs as an imported .py module but (presumably) not
# inside a Jupyter kernel, so this block only runs interactively.
if "__file__" not in dir():
    # IPython magic, not plain Python: runs the notebook Animal_Tags.ipynb
    %run Animal_Tags.ipynb
    
    # default data root, chosen from the operating-system name
    if OS()=='Linux':
        root="/data"
    elif OS()=='Windows':
        root="C:\\data\\"
    else:
        root="/Users/davidrobbe/Documents/Data/"
    
    print('os:',OS(),'\nroot:',root,'\nImport successful!')


## 1. Backend functions

In [None]:
def get_current_animals(days_passed=4,root="/NAS02"):
    """
    Return the animals that had a session within the last `days_passed` days.

    Animals are the 'Rat???' entries of `root` (default "/NAS02"). The date of an
    animal's last session is read from the name of the last (sorted)
    root/<animal>/Experiments/Rat???_20??_* entry ("RatXXX_YYYY_MM_DD_..."); entries
    whose date part cannot be parsed are skipped in favour of the previous one.
    Returns a list of animal names in sorted order, or [] (with a warning) when
    `root` contains no 'Rat???' entry, e.g. because the NAS is not mounted.
    """
    now=datetime.date.today()
    all_animals=[os.path.basename(path) for path in sorted(glob.glob(os.path.join(root,"Rat???")))]
    if not all_animals:
        logging.warning('%s not mounted!'%root)
        return []

    animalList=[]
    for animal in all_animals:
        experimentsPath=os.path.join(root,animal,"Experiments")
        # a missing Experiments folder simply globs to nothing
        sessionList=sorted(os.path.basename(expPath)
                           for expPath in glob.glob(os.path.join(experimentsPath,"Rat???_20??_*")))
        lastSessionDate=None
        for session in reversed(sessionList):
            try:
                # characters 7..16 of "RatXXX_YYYY_MM_DD_..." hold the date
                lastSessionDate=datetime.datetime.strptime(session[7:17],'%Y_%m_%d').date()
                break
            except ValueError:
                continue    # malformed name: fall back to the previous session
        if lastSessionDate is not None and (now-lastSessionDate).days<=days_passed:
            animalList.append(animal)

    return animalList

In [None]:
def read_last_line(filePath,maxLineLength):
    """
    Return the last line of a text file, including its trailing newline if any.

    Only the final maxLineLength+1 bytes are read, so a last line longer than
    `maxLineLength` bytes comes back truncated at its start.
    filePath: path of the file
    maxLineLength: maximum assumable line length in BYTES
    Returns the decoded last line, '' for an empty file, or False if the file
    could not be opened/read.
    """
    try:
        with open(filePath,'rb') as f:
            fileSize=os.fstat(f.fileno()).st_size
            # one extra byte so the newline ending the previous line is included;
            # never seek before the start of the file (also covers size == maxLineLength)
            readSize=min(abs(maxLineLength)+1,fileSize)
            f.seek(-readSize,os.SEEK_END)
            lines=f.readlines()
    except OSError as e:
        logging.warning('couldn\'t open file:'+filePath)
        logging.info(repr(e))
        return False
    if not lines:
        return ''
    # the chunk may start in the middle of a multi-byte character: do not crash on it
    return lines[-1].decode(errors='replace')

In [None]:
def read_in_file(session,paramName,extension=".behav_param",exclude=None,valueType=str):
    '''
    Read the per-trial values of one parameter from a session text file.

    Use to read from .behav_param or .entrancetimes.
    Every line containing "paramName" (and not containing "exclude", if given) is
    split on white spaces and its LAST element is kept, converted to valueType.
    example: "treadmill speed:     30.00" becomes ["treadmill","speed:","30.00"] -> "30.00"

    A line containing "Trial #" sets the current trial: "Trial # N" sends the values
    that follow to index N-1 (values before any "Trial #" line go to index 0).
    If several matching lines fall in the same trial, the last one wins.

    session: path of the session files WITHOUT extension
    valueType: int and float accept a decimal comma (int goes through float first,
        "0.00" -> 0); bool is True only for "true" (case-insensitive); any other
        type is called on the raw string
    Returns a numpy array with one entry per trial index (np.nan where the
    parameter was not found), or False if the file does not exist.
    '''
    behav=session+extension
    if not os.path.exists(behav):
        logging.warning("No file %s"%behav)
        return False
    result=[]
    trials=[0]
    with open(behav,"r") as f:
        for line in f:
            if "Trial #" in line:
                trials.append(int(float(line.split()[-1]))-1)
            if paramName not in line:
                continue
            if (exclude is not None) and (exclude in line):
                continue
            res=line.split()[-1]
            #integer or float: replace comma by dots
            if valueType in [int,float]:
                res=res.replace(",",".")
            if valueType is int:
                res=int(float(res))
            elif valueType is bool:
                res=(res.lower()=="true")
            else:
                res=valueType(res)
            result.append((trials[-1],res))
    # size by the LARGEST trial index, not the last one: trial numbers that are not
    # monotonic would otherwise index past the end of `out`
    out=[np.nan]*(max(trials)+1)
    for trialIdx,value in result:
        out[trialIdx]=value
    return np.asarray(out)

In [None]:
def has_behavior(root,animal,sessionList=None):
    """
    Flag which sessions of `animal` have behavioral data files.

    A session counts as having behavior when both <session>.entrancetimes and
    <session>.behav_param exist in root/animal/Experiments/<session>/.
    sessionList: session names to check; by default every entry of
        root/animal/Experiments whose name starts with the animal name, sorted.
    Returns a list of booleans aligned with sessionList.
    """
    if sessionList is None:
        pattern=os.path.join(root,animal,"Experiments",animal+'*')
        sessionList=sorted(os.path.basename(candidate) for candidate in glob.glob(pattern))

    def _complete(name):
        stem=os.path.join(root,animal,'Experiments',name,name)
        return all(os.path.exists(stem+ext) for ext in ('.entrancetimes','.behav_param'))

    return [_complete(name) for name in sessionList]

In [None]:
def is_tag_valid(root,animal):
    """
    Check that root/animal/animal.tag exists and starts with a usable header.

    The header must parse (read_tag_header), name this animal, have a numeric
    initialSpeed and a rewardType of 'Progressive' or 'Not-Progressive'.
    Returns True if every check passes, False otherwise.
    """
    tagPath=os.path.join(root,animal,animal+'.tag')
    if not os.path.exists(tagPath):
        return False    #tag not available
    header=read_tag_header(tagPath)
    if not isinstance(header,dict):
        return False    #tag header not correct
    # .get: a header lacking one of these fields is invalid, not a KeyError crash
    if header.get('name') != animal:
        return False    #tag animal name not correct
    try:
        float(header.get('initialSpeed'))
    except (TypeError,ValueError):
        return False    #tag initial speed missing or not numeric
    if header.get('rewardType') not in ('Progressive','Not-Progressive'):
        return False    #tag rewardType not correct
    return True

In [None]:
def read_tag_header(tagPath):
    """
    Parse the '#key:value' header lines at the top of a tag file.

    Reading stops at the first line that does not start with '#'.
    Returns a dict {key: value}, where value is everything after the FIRST ':'
    without the line ending. Returns False if the file cannot be read or a
    header line has no ':'.
    """
    out=dict()
    try:
        with open(tagPath,'r') as f:
            for line in f:
                if not line.startswith('#'):
                    break
                # split once: values may themselves contain ':'
                key,value=line[1:].split(':',1)
                # rstrip instead of [:-1]: the last line may have no newline
                out[key]=value.rstrip('\n')
    except (OSError,ValueError):
        # ValueError also covers a header line without ':' and undecodable text
        return False
    return out
    

In [None]:
def write_tag_header(root,animal,rewardType='Progressive',initialSpeed=10):
    """
    Create (or truncate) root/animal/animal.tag and write a fresh header.

    Writes five '#key:value' header lines followed by the tab-separated
    column-name line of the session table. Any existing content is erased.
    Returns True on success, False if the file could not be written.
    """
    PlaceHolders={
        'name':animal,
        'rewardType':rewardType,
        'initialSpeed':initialSpeed
    }
    content=("#info: \n"
             "#name:%(name)s\n"
             "#rewardType:%(rewardType)s\n"
             "#initialSpeed:%(initialSpeed)s\n"
             "#option:not used\n"
             "Sessions\tTag\tSpeed\tType\tEvent\tLabel\n")%PlaceHolders

    tagPath=os.path.join(root,animal,animal+'.tag')
    try:
        with open(tagPath,'w') as f:
            f.write(content)
    except OSError as e:
        logging.error("couldn't write the tag header %s: %r"%(tagPath,e))
        return False
    return True

def update_tag_header(root,animal,sessionList,overwrite):
    """
    Make sure root/animal/animal.tag starts with a valid header.

    If the existing header is invalid (is_tag_valid) or `overwrite` is True, the
    header is recomputed from the FIRST session of `sessionList` and the tag file
    is rewritten by write_tag_header, which erases any session lines it held:
      - initialSpeed: most frequent "computed treadmill speed" of that session
      - rewardType: 'Not-Progressive' if the mean of "no reward if interruption"
        over the session's trials is NaN, 'Progressive' otherwise
    Returns True if the header is valid or was written, False otherwise.
    """
    if is_tag_valid(root,animal) and not overwrite:
        return True

    firstSession=os.path.join(root,animal,'Experiments',sessionList[0],sessionList[0])

    #compute the initial speed
    # valueType=float: read as str, trials without the value become the string
    # 'nan' (not dropped by nan_policy='omit') and recent scipy rejects strings
    speeds=read_in_file(firstSession,"computed treadmill speed",valueType=float)
    if speeds is False:
        logging.error('cannot read the first session of '+animal)
        return False
    # np.ravel: scipy<1.11 returns the mode as a 1-element array, newer versions as a scalar
    initialSpeed=np.ravel(scipy.stats.mode(speeds,nan_policy='omit').mode)[0]
    initialSpeed=int(float(initialSpeed))

    #compute the reward type: progressive,not-progressive
    rewardValues=read_in_file(session=firstSession,paramName='no reward if interruption',valueType=float)
    # NOTE(review): np.mean is NaN as soon as ONE trial lacks the parameter, so a
    # session logging it on only some trials is classed 'Not-Progressive'. Confirm
    # this is intended (np.all(np.isnan(...)) would test "never present").
    if np.isnan(np.mean(rewardValues)):
        rewardType='Not-Progressive'
    else:
        rewardType='Progressive'

    isHeaderWritten=write_tag_header(root,animal,rewardType,initialSpeed)
    if not isHeaderWritten:
        logging.error('failed to write the tag header')
        return False
    return True

In [None]:
def write_session_info(tagFile,data):
    """
    Append session rows to a tag file.

    data: iterable of rows, each an iterable of strings (one per table column).
    Each row is written as '%' + tab-joined fields + newline; the leading '%'
    marks session lines.
    All rows are formatted before the file is opened, so a malformed row
    (non-string field) writes nothing.
    Returns True on success, False on a malformed row or a write error.
    """
    try:
        content=''.join('%'+'\t'.join(session)+'\n' for session in data)
        with open(tagFile,'a') as f:
            f.write(content)
    except (OSError,TypeError) as e:
        logging.error("couldn't write to %s: %r"%(tagFile,e))
        return False
    return True

In [None]:
def write_animal_tag_file(root,animal,until_date='',overwrite=False):
    """
    Create or extend root/animal/animal.tag with one line per new session.

    Candidate sessions are the root/animal/Experiments/<animal>* entries that have
    behavior files (has_behavior). When `until_date` ('YYYY_MM_DD') parses, sessions
    dated after it are dropped ('' means no limit). The header is (re)written by
    update_tag_header when invalid or when overwrite=True. Only sessions after the
    last one already listed in the tag file are appended.

    Columns written per session:
      Sessions | Tag | Speed ('var' when the most frequent speed covers less than
      half of the trials) | Type ('Good') | Event ('-', 'SpeedChange' or 'TagChange')
      | Label ('NA')
    The tag is copied from the last written session; for a new file it is read from
    root/animal/Tag, or asked interactively if that file cannot be opened.

    Returns True if session lines were written, False otherwise (already up to
    date, no usable session, inconsistent tag file, or a write error).
    """
    tagFile=os.path.join(root,animal,animal+'.tag')

    #compute the valid session list
    sessionList=sorted(os.path.basename(expPath) for expPath in
                       glob.glob(os.path.join(root,animal,"Experiments",animal+'*')))
    try:
        untilDate=datetime.datetime.strptime(until_date,"%Y_%m_%d")
    except (TypeError,ValueError):
        untilDate=None    # '' or unparseable until_date: no date limit
    if untilDate is not None:
        for idx,session in enumerate(sessionList):
            try:
                sessionDate=datetime.datetime.strptime(session,animal+"_%Y_%m_%d_%H_%M")
            except ValueError:
                break     # unparseable session name: leave the list untouched from here
            if sessionDate > untilDate:
                sessionList=sessionList[:idx]
                break
    if len(sessionList)<1:
        logging.warning("no session included:"+animal)
        return False
    withBehav=has_behavior(root,animal,sessionList)
    sessionList=[goodSession for goodSession,ok in zip(sessionList,withBehav) if ok is True]
    if len(sessionList)<1:
        logging.warning("no session included:"+animal)
        return False

    #check or write the tag header
    if update_tag_header(root,animal,sessionList,overwrite) is False:
        return False

    #Getting the last written session
    lastLine=read_last_line(tagFile,maxLineLength=200)
    if lastLine is False:
        logging.warning("couldn't read the tag file of "+animal)
        return False
    sessionName=lastLine.find('%')    # session lines start with '%'; header lines do not
    lastWrittenTag=''
    lastWrittenSpeed=''
    if sessionName>=0:
        fields=lastLine[sessionName+1:].rstrip('\r\n').split('\t')
        if len(fields)<3:
            logging.warning("malformed last line in "+tagFile)
            return False
        lastWrittenSession,lastWrittenTag,lastWrittenSpeed=fields[0],fields[1],fields[2]
        try:
            idx=sessionList.index(lastWrittenSession)
        except ValueError:
            # the last written session is not among the valid sessions any more
            # (e.g. excluded by until_date or missing its behavior files)
            logging.warning("session list inconsistent!")
            return False
        if idx+1==len(sessionList):
            logging.info('No need to update the tag file')
            return False
        sessionList=sessionList[idx+1:]

    #data structure:sessions*[session|tag|speed|type|event|label]
    data=[]
    #filling the data matrix
    for ID,session in enumerate(sessionList):
        sessionInfo=[]

        #===============TAG FILE COLUMNS===========================================

        #"Session" name
        sessionInfo.append(session)

        #"Tag" value
        if sessionName>=0:
            tag=lastWrittenTag
        elif ID==0:
            try:
                with open(os.path.join(root,animal,'Tag')) as f:
                    # strip: the raw line keeps its '\n', which would split the row in two
                    tag=f.readline().strip()
            except OSError:
                tag=input("type the tag for "+animal+" and press ENTER:")
        else:
            tag=data[-1][1]
        sessionInfo.append(tag)
        #----------------------------------------------
        #"Speed" value (MUST be integer or a string):
        spdValues=read_in_file(os.path.join(root,animal,'Experiments',session,session),"computed treadmill speed",valueType=float)
        spdMODE=scipy.stats.mode(spdValues,nan_policy='omit')
        # np.ravel: scipy<1.11 returns 1-element arrays, newer versions scalars
        spd=np.ravel(spdMODE.mode)[0]
        spdCOUNT=np.ravel(spdMODE.count)[0]
        if spdCOUNT<len(spdValues)/2:
            sessionInfo.append('var')
        else:
            sessionInfo.append(str(int(spd)))
        #----------------------------------------------
        #"Type" value:
        sessionInfo.append('Good')
        #----------------------------------------------
        #"Event" value: compared with the previous new row, or with the last written row
        event='-'
        if ID>0:
            if sessionInfo[2] != data[-1][2]:
                event="SpeedChange"
        elif sessionName>=0 and lastWrittenTag != sessionInfo[1]:
            event='TagChange'
        elif sessionName>=0 and lastWrittenSpeed != sessionInfo[2]:
            event="SpeedChange"
        sessionInfo.append(event)
        #----------------------------------------------
        #"Label" value: (for later manual manipulation)
        sessionInfo.append('NA')
        #----------------------------------------------
        #Add other columns to the tag file below:
        #...
        #sessionInfo.append(newColumn)
        #=================TAG FILE END================================================
        data.append(sessionInfo)

    #writing the data to the tag file
    if write_session_info(tagFile,data) is False:
        logging.warning("could not write "+tagFile)
        return False
    logging.info("tag file is written for: "+animal)
    return True

## 2. Write tag files

In [None]:
if "__file__" not in dir():
    # Notebook-only example (skipped when __file__ is defined, i.e. when run as a module):
    # build or extend the tag file of Rat174 under /data without rewriting its header.
    a=write_animal_tag_file('/data','Rat174',until_date="",overwrite=False)
    print(a)

### Write tag files as a batch

In [None]:
def write_animal_tag_batch(root,animalList,until_date,overwrite=False):
    """
    Run write_animal_tag_file for every animal in animalList.

    Returns the list of animals for which write_animal_tag_file returned False.
    This includes animals whose tag file was already up to date.
    """
    failedAnimals=[]
    for animal in animalList:
        if write_animal_tag_file(root,animal,until_date=until_date,overwrite=overwrite) is False:
            failedAnimals.append(animal)
    # lazy %s formatting: concatenating a str with the list raised TypeError
    logging.info("failed animals: %s",failedAnimals)
    logging.info("DONE!")
    return failedAnimals
    

In [None]:
if "__file__" not in dir():
    # Notebook-only: build or extend the tag files of the animals listed below.
    root="/data"
    animalList=["Rat60"]
    # set to True to rewrite the tag headers (this erases the existing session lines)
    overwrite=False
    until_date=''
    #================
    write_animal_tag_batch(root,animalList,until_date,overwrite=overwrite)

## 3. Front-end functions to work with tag files

In [None]:
def get_session_profile(root,animal,session):
    """
    Return the tag-file row of one session as a dict {column: value}.

    Reads root/animal/animal.tag with read_tag_table. If the session is not listed
    (or the table has no 'Sessions' column), a dict of empty strings for the
    standard columns is returned.
    """
    tagFile=os.path.join(root,animal,animal+'.tag')
    table=read_tag_table(tagPath=tagFile)
    try:
        index=table['Sessions'].index(session)
    except (KeyError,ValueError):
        return {'Sessions':'','Tag':'','Speed':'','Type':'','Event':'','Label':''}
    return {key:column[index] for key,column in table.items()}

In [None]:
def read_tag_header(tagPath):
    """
    Parse the '#key:value' header lines at the top of a tag file.

    Reading stops at the first line that does not start with '#'.
    Returns a dict {key: value}, where value is everything after the FIRST ':'
    without the line ending. Returns False if the file cannot be read or a
    header line has no ':'.
    """
    out=dict()
    try:
        with open(tagPath,'r') as f:
            for line in f:
                if not line.startswith('#'):
                    break
                # split once: values may themselves contain ':'
                key,value=line[1:].split(':',1)
                # rstrip instead of [:-1]: the last line may have no newline
                out[key]=value.rstrip('\n')
    except (OSError,ValueError):
        # ValueError also covers a header line without ':' and undecodable text
        return False
    return out

In [None]:
def read_tag_table(tagPath,headerSize=range(5)):
    """
    Return the session table of a tag file as a dictionary {column name: list of values}.

    headerSize: rows skipped before the column-name line (the five '#' header lines by default)
    The leading '%' of session names is removed. pandas parses the cells, so
    numeric-only columns may come back as numbers and 'NA' cells as NaN.
    If the file cannot be read, a warning is logged and a dict with an empty
    'Sessions' list is returned.
    """
    try:
        # sep=r'\s+' is the documented replacement for the deprecated delim_whitespace=True
        table=pd.read_csv(tagPath,sep=r'\s+',skiprows=headerSize)
    except Exception as e:
        logging.warning(repr(e))
        return {'Sessions':[],'Tag':'','Speed':'','Type':'','Event':'','Label':''}
    table.replace(to_replace= {'Sessions': {'%': ''}},regex=True,inplace=True)
    return {label:list(column) for label,column in zip(table.columns.values,table.values.T)}

In [None]:
def get_session_list(tagFile,profile=None,until_date=''):
    """
    Return the rows of a tag file meeting ALL the conditions in 'profile'.

    profile: {column or header key: accepted value(s)}. A string is split on
        whitespace into alternatives, an int is compared as its str, any other
        iterable is a list of alternatives, any other scalar a single value;
        values are compared as strings.
        Keys present in the tag header (e.g. 'rewardType', 'initialSpeed') are
        checked once against the header: if one fails, no session is returned.
        Every other key must be a table column; an unknown column matches nothing.
        EX: profile={'Speed':['10','20'],'rewardType':'Progressive','Tag':'Early-DLS_Lesion','Type':'Good'}
    until_date: 'YYYY_MM_DD'; if it parses, matching sessions dated after it are
        dropped ('' means no limit; the cut stops at the first session name that
        cannot be parsed).
    Returns {column: list of str values} restricted to the matching rows, in file
    order. `profile` itself is not modified.
    """
    emptyResult={'Sessions':[],'Tag':[],'Speed':[],'Type':[],'Event':[],'Label':[]}
    table=read_tag_table(tagFile)
    sessions=table.get('Sessions',[])
    if not len(sessions):
        return emptyResult

    #normalise every condition to a list of accepted values (on a copy of profile)
    conditions={}
    for key,value in (profile or {}).items():
        if isinstance(value,str):
            conditions[key]=value.split()
        elif isinstance(value,int):
            conditions[key]=str(value).split()
        else:
            try:
                conditions[key]=list(value)
            except TypeError:
                conditions[key]=[value]

    #header keys are checked once against the header, not against table rows
    header=read_tag_header(tagFile)
    if isinstance(header,dict):
        for key in set(conditions).intersection(header):
            accepted=conditions.pop(key)
            if str(header[key]) not in [str(val) for val in accepted]:
                return emptyResult

    def _matches(index):
        # True when, for every condition, the row's cell equals one accepted value
        for column,accepted in conditions.items():
            if column not in table:
                return False
            cell=str(table[column][index])
            if not any(cell==str(val) for val in accepted):
                return False
        return True

    # rows are evaluated in FILE order (cells are looked up by that index); only
    # the resulting session names are sorted, for the date cut below
    goodSessions=sorted(session for index,session in enumerate(sessions) if _matches(index))

    try:
        untilDate=datetime.datetime.strptime(until_date,"%Y_%m_%d")
    except (TypeError,ValueError):
        untilDate=None
    if untilDate is not None:
        for idx,session in enumerate(goodSessions):
            try:
                # "<animal>_YYYY_MM_DD_HH_MM": parse what follows the first '_' so
                # animal names of any length work (e.g. 'Rat60', not only 'Rat123')
                sessionDate=datetime.datetime.strptime(session.split('_',1)[1],"%Y_%m_%d_%H_%M")
            except (IndexError,ValueError):
                break
            if sessionDate > untilDate:
                goodSessions=goodSessions[:idx]
                break

    goodSet=set(goodSessions)
    goodIndex=[x for x,s in enumerate(sessions) if s in goodSet]
    return {key:[str(table[key][idx]) for idx in goodIndex] for key in table.keys()}

In [None]:
if "__file__" not in dir():

    # Notebook-only example: Rat162 sessions of Type 'Good', Speed '10' and Tag
    # 'Late-Lesion_DMS', provided the tag header has rewardType 'Progressive' and an
    # initialSpeed among '10', '20', 'var'.
    profile2={'Type':'Good',
             'rewardType':'Progressive',
             'initialSpeed':['10','20','var'],
             'Speed':'10',
             'Tag':'Late-Lesion_DMS'
             }
    a=get_session_list('/data/Rat162/Rat162.tag',profile=profile2,until_date='')
    print(a['Sessions'])

In [None]:
def batch_get_session_list(root,animalList=None,profile=None,until_date=''):
    """
    Return the sessions matching 'profile' for all the animals in animalList.

    animalList: animal names; None, '' or [] means every 'Rat*' directory under root
        (the notebook cell output is then cleared at the end)
    profile: conditions passed to get_session_list (None means no condition)
    until_date: 'YYYY_MM_DD' date limit passed to get_session_list
    Returns a dict {column: list of values} with the rows of all animals
    concatenated, animals processed in sorted order.
    """
    if profile is None:
        profile={}
    clearScreen=False
    if animalList is None or animalList=='' or animalList==[]:
        # only directories can hold root/<animal>/<animal>.tag
        animalList=[os.path.basename(path) for path in glob.glob(os.path.join(root,'Rat*'))
                    if os.path.isdir(path)]
        clearScreen=True

    profileDict={}
    for animal in sorted(animalList):
        tagPath=os.path.join(root,animal,animal+'.tag')
        sessionProfile=get_session_list(tagFile=tagPath,profile=profile,until_date=until_date)
        for key,values in sessionProfile.items():
            profileDict.setdefault(key,[]).extend(values)

    if clearScreen:
        clear_output()
    return profileDict

In [None]:
if "__file__" not in dir():

    # Notebook-only example: every session of every animal under `root`, then the
    # position of the first 'Early-Lesion_DLS-DevaluedReward' tag in that combined
    # list (list.index raises ValueError if no session has this tag).
    a=batch_get_session_list(root,animalList=[],profile={})
    print(a['Tag'].index('Early-Lesion_DLS-DevaluedReward'))

In [None]:
def batch_get_animal_list(root,profile):
    """
    Return the sorted list of animals with at least one session matching "profile".

    See get_session_list for the profile format; every 'Rat*' animal under root is searched.
    """
    dic=batch_get_session_list(root,animalList=[],profile=profile,until_date='')
    # the animal is the text before the first '_' of "<animal>_<date>...", whatever its length
    return sorted({session.split('_',1)[0] for session in dic.get('Sessions',[])})

In [None]:
def update_tag_file(root,animal):
    """
    Update the tag file of the animal with the sessions available on this computer.

    Compares the last session listed in root/animal/animal.tag with the last
    (sorted) root/animal/Experiments/<animal>* entry. If the data has a newer
    session, or the tag file lists none, write_animal_tag_file is called.
    Returns write_animal_tag_file's result, or None when no update was attempted.
    """
    tagFile=os.path.join(root,animal,animal+'.tag')
    dataSessions=sorted(os.path.basename(expPath) for expPath in
                        glob.glob(os.path.join(root,animal,"Experiments",animal+'*')))
    if not dataSessions:
        logging.info('no session data for '+animal)
        return None

    tagSessions=get_session_list(tagFile,profile={},until_date='').get('Sessions',[])
    if tagSessions and dataSessions[-1] <= tagSessions[-1]:
        return None    # tag file already up to date

    logging.info('updating tag file...')
    return write_animal_tag_file(root,animal,until_date="",overwrite=False)

In [None]:
def lesion_type(root,animal):
    """
    Determine the type of the first lesion done in the animal.

    Scans the animal's session tags in tag-file order. The first tag containing
    'lesion' (case-insensitive) is searched for one of the regions
    ('DLS','DMS','DS','GPe','M1'), tested in that order.
    Returns (tag, region). Returns ('', '') when that first lesion tag names no
    known region, or when no tag mentions a lesion.
    """
    profile=batch_get_session_list(root,animalList=[animal],profile={},until_date='')
    tagList=profile.get('Tag',[])

    _types_=('DLS','DMS','DS','GPe','M1')
    for tag in tagList:
        folded=tag.casefold()
        if 'lesion' not in folded:
            continue
        for t in _types_:
            if t.casefold() in folded:
                return tag,t
        # only the FIRST lesion tag is considered, even if it names no known region
        return '',''
    return '',''

In [None]:
def event_detect(root,profile1,profile2,badAnimals=None):
    """
    Find the animals whose sessions match profile1 and then profile2 IN SUCCESSION.

    An animal qualifies when the first session matching profile2 is the 'Good'
    session immediately following the last session matching profile1.
    badAnimals: animals to ignore
    Returns (animalList, sessionDic):
      - animalList: the qualifying animals, sorted
      - sessionDic: maps every candidate animal (one having sessions matching both
        profiles) to [profile1 sessions, profile2 sessions]. Both lists stay empty
        for candidates that do not qualify.
    """
    excluded=set(badAnimals or [])
    # sorted: iterating a set gives an arbitrary, run-dependent order
    candidates=sorted((set(batch_get_animal_list(root,profile1))
                       & set(batch_get_animal_list(root,profile2))) - excluded)
    sessionDic={animal:[[],[]] for animal in candidates}
    animalList=[]
    for animal in candidates:
        sessions1=batch_get_session_list(root,animalList=[animal],profile=profile1,until_date='')
        sessions2=batch_get_session_list(root,animalList=[animal],profile=profile2,until_date='')
        sessionsGood=batch_get_session_list(root,animalList=[animal],profile={'Type':'Good'},until_date='')
        try:
            index=sessionsGood['Sessions'].index(sessions1['Sessions'][-1])
            if sessions2['Sessions'][0] == sessionsGood['Sessions'][index+1]:
                animalList.append(animal)
                sessionDic[animal][0]=sessions1['Sessions']
                sessionDic[animal][1]=sessions2['Sessions']
        except (KeyError,IndexError,ValueError) as e:
            # no session list, last profile1 session not 'Good', or nothing after it
            logging.warning(repr(e))
    return animalList,sessionDic

In [None]:
if "__file__" not in dir():

    # Notebook-only example: animals whose 'Late-Lesion_DMS' sessions are directly
    # followed by 'Late-Lesion_DMS_GPi' sessions. The print shows Rat162's profile1
    # sessions (KeyError if Rat162 is not a candidate for both profiles).
    profile1={'Type':'Good',
             'rewardType':'Progressive',
             'initialSpeed':'10',
             'Speed':'10',
             'Tag':'Late-Lesion_DMS'
             }

    profile2={'Type':'Good',
             'rewardType':'Progressive',
             'initialSpeed':'10',
             'Speed':'10',
             'Tag':'Late-Lesion_DMS_GPi'
             }
    _,a=event_detect(root,profile1,profile2,badAnimals=[])
    print(a['Rat162'][0])

In [None]:
def update_animal_table_file(root,days_to_check=5):
    """
    Write root/ALLRAT_Analysis/RatTable.html, an overview page of the tag files.

    The page contains:
      1. a 3-column table of every tag used by any animal under root;
      2. an animal x date table for the animals returned by
         get_current_animals(days_to_check). Each cell shows 'Tag@Speed', or
         'Bad Session' for sessions whose Type is not 'Good'. Dates run newest first.
    """
    outputPath=os.path.join(root,'ALLRAT_Analysis','RatTable.html')

    #recent animals: collect (animal, date, cell text) for each of their sessions
    animalList=get_current_animals(days_passed=days_to_check)
    # batch_get_session_list reads [] as "every animal": only query it when there
    # are current animals, so the table really is limited to them
    sessionProfile=batch_get_session_list(root,animalList) if animalList else {}
    Animals=[]
    dataList=[]
    for ID,session in enumerate(sessionProfile.get('Sessions',[])):
        # "<animal>_YYYY_MM_DD_HH_MM": split at the first '_' so names of any length work
        animal,_,stamp=session.partition('_')
        sessionDate=stamp[:-6]    # drop the trailing "_HH_MM"
        if animal not in Animals:
            Animals.append(animal)
        if sessionProfile['Type'][ID] != 'Good':
            cell='Bad Session'
        else:
            cell=sessionProfile['Tag'][ID]+'@'+sessionProfile['Speed'][ID]
        dataList.append((animal,sessionDate,cell))
    Dates=sorted({date for _,date,_ in dataList},reverse=True)
    dateRow={date:row for row,date in enumerate(Dates)}
    animalCol={animal:col for col,animal in enumerate(Animals)}
    Table=[['' for _ in Animals] for _ in Dates]
    for animal,date,cell in dataList:
        Table[dateRow[date]][animalCol[animal]]=cell
    tableHTML=pd.DataFrame(data=Table,index=Dates,columns=Animals).T.to_html()

    #adding the tag table: every distinct tag, 3 per row
    allProfile=batch_get_session_list(root)
    allTag=sorted(set(allProfile.get('Tag',[])))
    allTag.extend(['']*(-len(allTag)%3))    # pad to a multiple of 3 without a blank extra row
    tagRows=[allTag[i:i+3] for i in range(0,len(allTag),3)]
    tagHTML=pd.DataFrame(data=tagRows,columns=['Tags','Tags','Tags']).to_html()
    htmlStr="""<h1>List of ALL used tags</h1>
%s
<h1> Animal Table (<%s days)</h1>

    """%(str(tagHTML),str(days_to_check))

    with open(outputPath,'w') as f:
        f.write(htmlStr+str(tableHTML))

#### Selecting sessions by keyword patterns in tags

In [None]:
def get_pattern_session_list(tagFile,tagPattern=''):
    """
    Return the rows of a tag file whose Tag matches a shell-style pattern.

    tagPattern: fnmatch pattern applied to the Tag column ('*' any text,
        '?' one character, '[seq]' one of seq), e.g. '*DMS*'
    Returns a dict {column: list of str values} restricted to the matching rows,
    in file order.
    """
    table=read_tag_table(tagFile)
    # test each row directly (linear); str() so non-string cells such as NaN
    # do not make fnmatch raise
    goodIndex=[x for x,tag in enumerate(table['Tag']) if fnmatch.fnmatch(str(tag),tagPattern)]
    return {key:[str(table[key][idx]) for idx in goodIndex] for key in table.keys()}

In [None]:
if "__file__" not in dir():

    # Notebook-only example. profile2 is not used by the call below, which selects
    # the Rat162 sessions whose Tag matches the fnmatch pattern '*DMS*'.
    profile2={'Type':'Good',
             'rewardType':'Progressive',
             'initialSpeed':['10','20','var'],
             'Speed':'10',
             'Tag':'Late-Lesion_DMS'
             }
    a=get_pattern_session_list('/data/Rat162/Rat162.tag','*DMS*')
    print(a['Sessions'])

In [None]:
def batch_get_pattern_list(root,animalList=None,tagPattern=''):
    """
    Collect, over several animals, the tag-file rows whose Tag matches tagPattern.

    animalList: animal names; None, '' or [] means every 'Rat*' directory under root
        (the notebook cell output is then cleared at the end)
    tagPattern: fnmatch pattern applied to the Tag column (see get_pattern_session_list)
    Returns a dict {column: list of str values} with the rows of all animals
    concatenated, animals processed in sorted order.
    """
    clearScreen=False
    if animalList is None or animalList=='' or animalList==[]:
        # only directories can hold root/<animal>/<animal>.tag
        animalList=[os.path.basename(path) for path in glob.glob(os.path.join(root,'Rat*'))
                    if os.path.isdir(path)]
        clearScreen=True

    profileDict={}
    for animal in sorted(animalList):
        tagPath=os.path.join(root,animal,animal+'.tag')
        sessionProfile=get_pattern_session_list(tagFile=tagPath,tagPattern=tagPattern)
        for key,values in sessionProfile.items():
            profileDict.setdefault(key,[]).extend(values)

    if clearScreen:
        clear_output()
    return profileDict

In [None]:
if "__file__" not in dir():

    # Notebook-only example: tags of all animals matching the pattern below.
    # NOTE(review): fnmatch '*_DLS-' only matches tags ENDING in '_DLS-';
    # '*_DLS-*' may have been intended — confirm.
    a=batch_get_pattern_list(root,animalList=[],tagPattern='*_DLS-')
    print(a['Tag'])

In [None]:
def get_pattern_animalList(root,tagPattern):
    """
    Return the sorted list of animals with at least one session whose Tag matches tagPattern.

    tagPattern: fnmatch pattern (see get_pattern_session_list); every 'Rat*' animal under root is searched.
    """
    dic=batch_get_pattern_list(root,animalList=[],tagPattern=tagPattern)
    # the animal is the text before the first '_' of "<animal>_<date>...", whatever its length
    return sorted({session.split('_',1)[0] for session in dic.get('Sessions',[])})

In [None]:
if "__file__" not in dir():

    # Notebook-only example: animals having a tag that contains 'DLS' followed by
    # at least one more character (fnmatch '?' matches exactly one character).
    a=get_pattern_animalList(root,tagPattern='*DLS?*')
    print(a)