# Data Gathering from the Allen Brain Observatory

### Install AllenSDK into your local environment

In [None]:
pip install allensdk

### Import Packages:

In [None]:
import os
import shutil
import allensdk
import pprint
from pathlib import Path

import numpy as np
import pandas as pd
import scipy.stats as st

from sklearn import svm
from sklearn.model_selection import KFold
from sklearn.metrics import confusion_matrix
from allensdk.brain_observatory.ecephys.visualization import plot_mean_waveforms, plot_spike_counts, raster_plot

import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
sns.set_context('notebook', font_scale=1.5, rc={'lines.markeredgewidth': 2})

In [None]:
# this code block should only be run if you are working with the neuropixels data
from allensdk.brain_observatory.ecephys.ecephys_project_cache import EcephysProjectCache

# Directory that holds the cache manifest. Defaults to the course's shared
# directory; set the ECEPHYS_DATA_DIR environment variable to use another one.
data_directory = os.environ.get('ECEPHYS_DATA_DIR', '/overflow/NSCI274/projects/ecephysdata/')

manifest_path = os.path.join(data_directory, "manifest.json")

# AllenSDK ecephys project cache backed by the manifest above.
cache = EcephysProjectCache.from_warehouse(manifest=manifest_path)

In [None]:
# Table of ecephys sessions from the project cache. Its index values are the
# session IDs later passed to cache.get_session_data; later cells filter the
# rows by recorded structure and by sex.
session = cache.get_session_table()

In [None]:
# Keep only sessions whose recorded structures include DG, then show how many remain.
has_dg = ['DG' in structure_list for structure_list in session['ecephys_structure_acronyms']]
session = session[has_dg]
len(session)

In [None]:
# Stimulus-category label for every natural_scenes frame number (0-117),
# stored as a pandas DataFrame indexed by frame.
listP = [0,1,2,5,6,7,10,11,12,13,14,15,16,17,18,19,22,23,25,27,29,33,38,39,44,45,47,49,50,52,53,55,58,102]  # 'predator'
listNonP = [3,4,8,9,21,24,26,28,32,34,35,36,40,41,42,46,48,51,54,56,57,112]  # 'non_predator'
listNonA = [20,30,31,37,43,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,103,104,105,106,107,108,109,110,111,113,114,115,116,117]  # 'non_animal'

listFrames = listP + listNonP + listNonA
# Derive each label's repeat count from its list so the two can never drift apart.
listLabels = (
    ['predator'] * len(listP)
    + ['non_predator'] * len(listNonP)
    + ['non_animal'] * len(listNonA)
)

# A frame may belong to only one category.
if len(set(listFrames)) != len(listFrames):
    raise ValueError('a frame appears in more than one label list')

labeldict = {'frame': listFrames, 'labels': listLabels}
labels = pd.DataFrame(labeldict).set_index('frame').sort_index()

In [None]:
# Session IDs (index values of the filtered session table), in table order,
# split by the 'sex' column: first 'M', then 'F'.
malessid, femalessid = (
    list(session.loc[session['sex'] == sex].index.values) for sex in ('M', 'F')
)

In [None]:
# function for getting spike mean
def get_spike_means(sessionID):
    """Return the mean DG spike count per natural_scenes frame for one session.

    Loads the session through the module-level ``cache``. It keeps units whose
    ``ecephys_structure_acronym`` is ``'DG'`` and runs
    ``conditionwise_spike_statistics`` over the session's ``natural_scenes``
    presentations. The per-(unit, condition) ``spike_mean`` values are then
    averaged over all rows that share the same ``frame``.

    Parameters
    ----------
    sessionID
        A session ID accepted by ``cache.get_session_data``.

    Returns
    -------
    pandas.DataFrame
        Indexed by ``frame`` with a ``spike_mean`` column.

    Raises
    ------
    IndexError
        If the session has no ``natural_scenes`` presentations. The session
        loops catch ``IndexError`` and skip such sessions.
    """
    ss = cache.get_session_data(sessionID)

    # IDs of every natural_scenes presentation in this session.
    pid = ss.stimulus_presentations.loc[
        ss.stimulus_presentations['stimulus_name'] == 'natural_scenes'
    ].index.values
    if len(pid) == 0:
        # Nothing to analyse. Raise IndexError rather than returning None: the
        # callers skip a session on IndexError but would crash on None.drop().
        raise IndexError(f'session {sessionID} has no natural_scenes presentations')

    # Units recorded in DG.
    uid = ss.units[ss.units['ecephys_structure_acronym'] == 'DG'].index.values

    stat = ss.conditionwise_spike_statistics(
        stimulus_presentation_ids=pid,
        unit_ids=uid,
    )[['spike_mean']]

    # Attach each stimulus condition's frame number, then average over every
    # (unit, condition) row that shares a frame.
    fullchart = pd.merge(
        stat,
        ss.stimulus_conditions['frame'],
        left_on='stimulus_condition_id',
        right_index=True,
    )
    return fullchart.groupby(['frame']).mean()

## Get Male Overall Data

Iterating over every session can take up to an hour without a GPU.

In [None]:
# Collect one frame-indexed spike-mean table per MALE session, then average them.
dfs = []
for i in malessid:
    try:
        df = get_spike_means(i)
    except IndexError:
        # Session without usable natural_scenes data: skip it.
        continue
    # Remove the blank image (frame -1); tolerate sessions that have none.
    dfs.append(df.drop(index=-1, errors='ignore'))

if not dfs:
    raise ValueError('no male session produced natural_scenes spike means')

# Average across sessions per frame, keeping 'frame' as the index. With
# as_index=False the result gets a fresh 0..N-1 RangeIndex. That index only
# lines up with frame numbers when every frame 0..N-1 is present.
concatedM = pd.concat(dfs).groupby(level='frame').mean()

In [None]:
# Convert male mean spike counts to firing rates. The factor 4 presumably
# assumes 0.25 s natural-scene presentations (count / 0.25 s = Hz); TODO confirm.
if 'frame' in concatedM.columns:
    # A groupby(..., as_index=False) result can carry 'frame' as a column;
    # make it the index so the frame IDs are not scaled below.
    concatedM = concatedM.set_index('frame')
else:
    concatedM.index.name = 'frame'
# Scale only the spike-count column.
frequencyM = concatedM[['spike_mean']].mul(4).rename(columns={'spike_mean': 'firing_rate'})

## Get Female Overall Data

In [None]:
# Collect one frame-indexed spike-mean table per FEMALE session, then average them.
dfs = []
for i in femalessid:
    try:
        df = get_spike_means(i)
    except IndexError:
        # Session without usable natural_scenes data: skip it.
        continue
    # Remove the blank image (frame -1); tolerate sessions that have none.
    dfs.append(df.drop(index=-1, errors='ignore'))

if not dfs:
    raise ValueError('no female session produced natural_scenes spike means')

# Average across sessions per frame, keeping 'frame' as the index. With
# as_index=False the result gets a fresh 0..N-1 RangeIndex. That index only
# lines up with frame numbers when every frame 0..N-1 is present.
concatedF = pd.concat(dfs).groupby(level='frame').mean()

In [None]:
# Convert female mean spike counts to firing rates. The factor 4 presumably
# assumes 0.25 s natural-scene presentations (count / 0.25 s = Hz); TODO confirm.
if 'frame' in concatedF.columns:
    # A groupby(..., as_index=False) result can carry 'frame' as a column;
    # make it the index so the frame IDs are not scaled below.
    concatedF = concatedF.set_index('frame')
else:
    concatedF.index.name = 'frame'
# Scale only the spike-count column.
frequencyF = concatedF[['spike_mean']].mul(4).rename(columns={'spike_mean': 'firing_rate'})

## Group data by labels

In [None]:
# Attach each frame's stimulus category to the male firing rates, then save.
labeledM = frequencyM.merge(labels['labels'], left_on='frame', right_index=True)
labeledM.to_csv('raw_labeled_male.tsv', sep='\t')
# Reload later with: labeledM = pd.read_csv('raw_labeled_male.tsv', sep='\t')

In [None]:
# Attach each frame's stimulus category to the female firing rates, then save.
labeledF = frequencyF.merge(labels['labels'], left_on='frame', right_index=True)
labeledF.to_csv('raw_labeled_female.tsv', sep='\t')
# Reload later with: labeledF = pd.read_csv('raw_labeled_female.tsv', sep='\t')

## Data Structuring

In [None]:
# Reload the labeled firing-rate tables saved above under the names the
# analysis cells use, so the notebook can resume from this point. No index_col
# is passed, so the saved index comes back as an ordinary column.
labeledM = pd.read_csv('raw_labeled_male.tsv', sep='\t')
labeledF = pd.read_csv('raw_labeled_female.tsv', sep='\t')
# Earlier misspelled names, kept as aliases so existing references still resolve.
labeldM, labeldF = labeledM, labeledF

In [None]:
def _per_label(df, labels, reducer):
    """Apply ``reducer`` to the 'firing_rate' rows of each label, in ``labels`` order."""
    rates = df['firing_rate']
    return [reducer(rates[df['labels'] == label]) for label in labels]


def get_group_means(df, labels):
    """Mean 'firing_rate' for each label in ``labels`` (NaN for absent labels)."""
    return _per_label(df, labels, pd.Series.mean)


def get_group_sems(df, labels):
    """Standard error of 'firing_rate' for each label in ``labels``."""
    return _per_label(df, labels, pd.Series.sem)

In [None]:
# Stimulus categories; every per-category list below follows this order.
labeling = ['predator', 'non_predator', 'non_animal']

male_means, male_sems = (
    get_group_means(labeledM, labeling),
    get_group_sems(labeledM, labeling),
)
female_means, female_sems = (
    get_group_means(labeledF, labeling),
    get_group_sems(labeledF, labeling),
)

In [None]:
# calculate N: sessions x images per category, in the AIO table's row order
# (non-animal, non-predator, predator), identically for both sexes.
# 40 / 11 are hard-coded, presumably the male / female session counts.
# TODO confirm against the sessions actually used, since sessions raising
# IndexError are skipped.
images_per_category = [len(listNonA), len(listNonP), len(listP)]
male_N = [40 * n for n in images_per_category]
female_N = [11 * n for n in images_per_category]

In [None]:
# AIOdf: all-in-one summary table with one row per (sex, stimulus category).
# male_/female_means and _sems follow `labeling` order ('predator',
# 'non_predator', 'non_animal'). The rows here run Non-Animal Control,
# Non-Predator, Predator, which is the order male_N/female_N are assumed to
# use (confirm). Reorder the statistics to the row order so each value sits
# next to its own label.
aio_order = ['non_animal', 'non_predator', 'predator']
aio_names = {'non_animal': 'Non-Animal Control', 'non_predator': 'Non-Predator', 'predator': 'Predator'}
aio_pos = [labeling.index(key) for key in aio_order]

aio_sex = ['Male'] * len(aio_order) + ['Female'] * len(aio_order)
aio_labeling = [aio_names[key] for key in aio_order] * 2
aio_mean = [male_means[i] for i in aio_pos] + [female_means[i] for i in aio_pos]
aio_sem = [male_sems[i] for i in aio_pos] + [female_sems[i] for i in aio_pos]
aio_N = male_N + female_N

aio_df = pd.DataFrame({'sex': aio_sex, 'label': aio_labeling, 'mean_frequency': aio_mean, 'sem': aio_sem, 'N': aio_N})
aio_df.to_csv('AIO.tsv', sep='\t')

In [None]:
# Show the all-in-one summary table (a bare final expression is rendered by Jupyter).
aio_df