# Updating downsampling procedure iteratively

The code in this notebook validates the data preprocessing and downsampling steps before they are incorporated into the project source code (`src/`).

In [None]:
# Re-import edited modules (e.g. src.utilities.pandas_helpers) automatically before each cell runs.
%load_ext autoreload 
%autoreload 2 

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf

from src.utilities.pandas_helpers import flatten_dataframe, strip_columns

In [None]:
def plot_data(df, title, ylim=(-2, 2), ax=None):
    """Draw a thin line plot of ``signal`` against ``time`` for one slice of the data.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``time`` and ``signal`` columns.
    title : str
        Title for the axes.
    ylim : tuple of (float, float) or None, default (-2, 2)
        y-axis limits. Pass None to let matplotlib autoscale.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on. When None, seaborn draws on the current axes.

    Returns
    -------
    matplotlib.axes.Axes
        The axes that were drawn on.
    """
    ax = sns.lineplot(data=df, x='time', y='signal', linewidth=0.2, ax=ax)
    if ylim is not None:
        ax.set_ylim(*ylim)
    ax.set_title(title)
    return ax
    

## load raw data

In [None]:
# Location of the aggregated recordings on the cluster filesystem.
RAW_PATH = '/projects/p31961/gaby_data/aggregated_data/aggregated_data.parquet.gzp'
raw_data = pd.read_parquet(path=RAW_PATH)

In [None]:
raw_data.iloc[:5]  # first five rows of the raw frame

#### Missing values in aggregated data

In [None]:
test_df = raw_data.head()
# A bare expression only renders when it is the cell's last line, so display explicitly.
display(test_df)

def tweak_name(df):
    """Return a copy of ``df`` with an extra constant column ``RENAME``.

    ``RENAME`` holds the overall mean of ``df['signal']``. The column is created
    as ``signal_mean`` and then renamed. This is a scratch check of ``assign``
    followed by ``rename``. The input frame is not modified.
    """
    overall_signal_mean = df['signal'].mean()
    with_mean_column = df.assign(signal_mean=overall_signal_mean)
    return with_mean_column.rename(columns={'signal_mean': 'RENAME'})
tweak_df = test_df.pipe(tweak_name)
tweak_df

In [None]:
# Rows with at least one NaN in any column.
nan_vals = raw_data[raw_data.isna().any(axis=1)]
group_by = ['mouse_id', 'trial',  'day', 'event', 'sensor']
# dropna=False keeps groups whose *key* is NaN. With the default (dropna=True),
# rows whose NaN sits in a grouping column would silently vanish from this summary.
# .count() reports, per group, the number of non-NaN values in each other column.
nan_vals.groupby(by=group_by, as_index=False, dropna=False).count()

## Function to aggregate, sort, and downsample the dataframe

In [None]:
def down_sample_data(df, group_by_cols, agg_dict, ignore_for_sorting, downsample_rate):
    """Aggregate ``df`` per group, order each series, and keep every Nth sample of each series.

    Processing steps:
      1. Drop rows containing any NaN.
      2. Aggregate per unique combination of ``group_by_cols`` using ``agg_dict``.
      3. Flatten and clean the column names with the project helpers
         ``flatten_dataframe`` and ``strip_columns``.
      4. Drop their ``index`` column.
      5. Sort.
      6. Rename ``signal_mean`` to ``signal``.

    A *series* is one combination of the grouping columns other than
    ``ignore_for_sorting``. Downsampling is done per series. Each series keeps
    its first row and every ``downsample_rate``-th row after it. Striding over
    the whole concatenated frame instead would start every series at a
    different phase, depending on the lengths of the series before it. The
    downsampled time points would then not line up across trials.

    Parameters
    ----------
    df : pandas.DataFrame
        Input rows. Must contain ``group_by_cols`` and the columns named in ``agg_dict``.
    group_by_cols : list of str
        Columns to aggregate over, including the within-series ordering column.
    agg_dict : dict
        Aggregation spec passed to ``DataFrameGroupBy.agg``.
    ignore_for_sorting : str
        Within-series ordering column (e.g. ``'time'``). It is sorted last, so
        each series is contiguous and increasing in this column.
    downsample_rate : int
        Keep one of every ``downsample_rate`` rows per series. Must be >= 1.

    Returns
    -------
    pandas.DataFrame
        Aggregated, sorted and downsampled rows. The original index labels are kept.

    Raises
    ------
    ValueError
        If ``downsample_rate`` is less than 1.
    """
    if downsample_rate < 1:
        raise ValueError(f'downsample_rate must be >= 1, got {downsample_rate!r}')

    # Columns that identify one series (everything except the ordering column).
    sort_by_list = [col for col in group_by_cols if col != ignore_for_sorting]
    # Sort the ordering column last, explicitly. Within-series order then does
    # not depend on sort stability or on the row order groupby emitted.
    order_by = sort_by_list + [col for col in group_by_cols if col == ignore_for_sorting]

    aggregated = (
        df
        .dropna(axis=0, how='any')  # drop any rows with NaNs
        .groupby(by=group_by_cols, as_index=False).agg(agg_dict)
        .pipe(flatten_dataframe)  # flatten the multi-index columns
        .pipe(strip_columns)  # fixes the column names by stripping '_'
        .drop(columns='index')  # drop the 'index' column left by the helpers
        .sort_values(by=order_by)
        .rename(columns={'signal_mean': 'signal'})
    )

    if not sort_by_list:
        # Only the ordering column was grouped on, so the whole frame is one series.
        return aggregated[::downsample_rate]

    # 0, 1, 2, ... within each series. Rows are already in series/time order.
    position_in_series = aggregated.groupby(sort_by_list, sort=False, observed=True).cumcount()
    return aggregated[position_in_series % downsample_rate == 0]


In [None]:
group_by_list = ['time', 'sensor', 'trial', 'mouse_id', 'day', 'event']  # one aggregated row per unique combination
agg_dict = {'signal': ['mean']}  # signal is aggregated to its mean only (no SEM is computed)

In [None]:
# Aggregate, sort, and keep every 100th row of the raw data.
downsampled = down_sample_data(
    df=raw_data,
    group_by_cols=group_by_list,
    agg_dict=agg_dict,
    ignore_for_sorting='time',
    downsample_rate=100,
)

In [None]:
# Inspect the downsampled result.
downsampled

#### same process for raw data without downsampling

In [None]:
# function to query raw data
def agg_data_no_downsample(df, group_by_cols, agg_dict, ignore_for_sorting='time'):
    """Aggregate ``df`` per group and order each series, keeping every row.

    Processing steps:
      1. Drop rows containing any NaN.
      2. Aggregate per unique combination of ``group_by_cols`` using ``agg_dict``.
      3. Flatten and clean the column names with ``flatten_dataframe`` and
         ``strip_columns``.
      4. Drop their ``index`` column.
      5. Sort by the series-identifying columns, then by ``ignore_for_sorting``.

    Unlike ``down_sample_data``, the aggregated column is NOT renamed. It keeps
    the flattened name (e.g. ``signal_mean``).

    NOTE(review): ``plot_data`` plots a ``signal`` column. Rename
    ``signal_mean`` to ``signal`` before plotting this output with it.

    Parameters
    ----------
    df : pandas.DataFrame
        Input rows. Must contain ``group_by_cols`` and the columns named in ``agg_dict``.
    group_by_cols : list of str
        Columns to aggregate over.
    agg_dict : dict
        Aggregation spec passed to ``DataFrameGroupBy.agg``.
    ignore_for_sorting : str, default 'time'
        Within-series ordering column. It is sorted last.

    Returns
    -------
    pandas.DataFrame
    """
    series_cols = [col for col in group_by_cols if col != ignore_for_sorting]
    # Put the ordering column last explicitly instead of relying on sort stability.
    order_by = series_cols + [col for col in group_by_cols if col == ignore_for_sorting]

    return (
        df
        .dropna(axis=0, how='any')  # drop any rows with NaNs
        .groupby(by=group_by_cols, as_index=False).agg(agg_dict)
        .pipe(flatten_dataframe)  # flatten the multi-index columns
        .pipe(strip_columns)  # fixes the column names by stripping '_'
        .drop(columns='index')  # drop the 'index' column left by the helpers
        .sort_values(by=order_by)
        )


In [None]:
# Same aggregation as the downsampled data, but every row is kept.
grouped_raw_data = agg_data_no_downsample(
    df=raw_data,
    group_by_cols=group_by_list,
    agg_dict=agg_dict,
)
                               

Query: day 1, cue event, dopamine (DA) sensor, trial 1.

In [None]:
# Day 1 / cue / DA / trial 1. The value matches the variable name, the section heading and the plot.
# NOTE(review): there is no mouse_id filter, so every mouse's matching series is selected.
da1_cue_dopamine_query = 'day==1 & event=="cue" & sensor=="DA" & trial == 1'

In [None]:
# Select the rows of the downsampled data that match `da1_cue_dopamine_query`.
d1_cue_da_ds = downsampled.query(expr=da1_cue_dopamine_query)

In [None]:
#query raw data
# Currently commented out. NOTE(review): `grouped_raw_data` is not renamed from
# `signal_mean` to `signal`, so `plot_data` cannot plot this slice as-is.
# d1_cue_raw_no_ds = grouped_raw_data.query(da1_cue_dopamine_query)

#### plot a sample query to make sure data looks correct

In [None]:
# The title is built from the query string, so it always describes the plotted selection.
# The query has no mouse_id filter, so a hard-coded mouse in the title would be wrong.
plot_data(d1_cue_da_ds, f'Downsampled DA signal ({da1_cue_dopamine_query})');
# NOTE(review): the raw frame keeps `signal_mean`; rename it to `signal` before re-enabling.
# plot_data(d1_cue_raw_no_ds, 'Raw DA signal for mouse 142_237 on day 1 during cue event')

In [None]:
day1_cue_da_query = 'day==1 & event=="cue" & sensor=="DA" & trial == 1'
downsampled.query(day1_cue_da_query)