# Bin data on time

## Import packages

In [1]:
import math

import numpy as np
import pandas as pd
import xarray as xr

## Convert dataframe to xarray

In [3]:
# https://notebook.community/jhamman/xarray/examples/xarray_multidimensional_coords

class DataBinner:
    """Reshape a profile DataFrame and bin its variables on pressure or month.

    The frame passed to the constructor is stored on ``self.df``; several
    methods replace it or mutate it in place.  ``create_xarray`` must have
    been called (so that ``self.ds`` exists) before ``bin_on_pressure`` or
    ``bin_on_time`` is used.
    """

    # Columns used both as the index of ``df_indexed`` and as dataset coords.
    _COORD_COLUMNS = ('cycle_number', 'lat', 'lon', 'datetime',
                      'year', 'month', 'day', 'pres')

    def __init__(self, df):
        """Store the DataFrame to work on."""
        self.df = df

    def set_df_index(self):
        """Store on ``self.df_indexed`` a copy of ``self.df`` indexed on the coordinate columns."""
        self.df_indexed = self.df.set_index(list(self._COORD_COLUMNS))

    def index_on_profile_id(self):
        """Replace ``self.df`` with a version indexed on ``profile_id``."""
        self.df = self.df.set_index(['profile_id'])

    def get_df_grouped_by_id(self):
        """Group the rows of ``self.df`` on ``date`` and collect the other columns per group.

        Side effects: ``self.df`` is re-indexed on ``profile_id`` and gains a
        ``datetime`` column parsed from ``date``.

        NOTE(review): despite the method name, the grouping key is ``date``,
        not ``profile_id`` -- confirm this is intended.

        Returns:
            A DataFrame indexed by the unique ``date`` values in which every
            remaining column holds, per group, the array of values that fell
            into that group.
        """
        # https://stackoverflow.com/questions/22219004/how-to-group-dataframe-rows-into-list-in-pandas-groupby

        self.index_on_profile_id()

        # Create a datetime column
        self.df['datetime'] = pd.to_datetime(self.df['date'])

        def f_multi(df, col_names):
            """Group ``df`` on ``col_names`` and gather the other columns into arrays."""
            if not isinstance(col_names, list):
                col_names = [col_names]

            # sort on the keys so that equal keys are contiguous, then work column-major
            values = df.sort_values(col_names).values.T

            col_idcs = [df.columns.get_loc(cn) for cn in col_names]
            other_col_names = [name for idx, name in enumerate(df.columns) if idx not in col_idcs]
            other_col_idcs = [df.columns.get_loc(cn) for cn in other_col_names]

            # split df into indexing columns (=keys) and data columns (=vals)
            keys = values[col_idcs, :]
            vals = values[other_col_idcs, :]

            # list of tuples of key pairs
            multikeys = list(zip(*keys))

            # remember unique key pairs and their first indices
            ukeys, index = np.unique(multikeys, return_index=True, axis=0)

            # split data columns according to those indices
            arrays = np.split(vals, index[1:], axis=1)

            # resulting list of subarrays has same number of subarrays as unique key pairs
            # each subarray has the following shape:
            #    rows = number of non-grouped data columns
            #    cols = number of data points grouped into that unique key pair

            # prepare multi index
            idx = pd.MultiIndex.from_arrays(ukeys.T, names=col_names)

            list_agg_vals = dict()
            for tup in zip(*arrays, other_col_names):
                col_vals = tup[:-1]  # first entries are the subarrays from above
                col_name = tup[-1]   # last entry is data-column name

                list_agg_vals[col_name] = col_vals

            return pd.DataFrame(data=list_agg_vals, index=idx)

        return f_multi(self.df, ['date'])

    def create_dataframe(self, df):
        """Expand the first column of ``df`` into a frame and store it on ``self.df``.

        Each cell of the first column is expanded with ``pd.Series``, the
        resulting columns are renamed to ``tag_<label>``, and the frame is
        transposed with its index reset.

        Args:
            df: DataFrame whose first column holds list-like cells.

        Returns:
            The reshaped DataFrame (also stored on ``self.df``).
        """
        # Only the first column is used
        column_name = df.columns[0]

        # expand column into its own dataframe
        exploded = df[column_name].apply(pd.Series)

        # rename each variable
        exploded = exploded.rename(columns=lambda x: 'tag_' + str(x))

        # transpose the df and remove index column
        self.df = exploded.T.reset_index(drop=True)

        return self.df

    def apply_qc(self):
        """Drop every column of ``self.df`` whose name contains ``'_qc'``.

        Non-string column labels are left untouched instead of raising.
        """
        # TODO
        # Use qc here to filter things

        # For now, just drop the qc columns
        cols_to_drop = [
            col for col in self.df.columns
            if isinstance(col, str) and '_qc' in col
        ]

        self.df = self.df.drop(columns=cols_to_drop)

    def create_xarray(self):
        """Build ``self.ds`` from ``self.df`` and mark the coordinate columns as coords.

        NOTE(review): ``set_df_index`` is called (populating
        ``self.df_indexed``) but the dataset is built from ``self.df``, not
        from the indexed frame -- confirm this is intended.

        Returns:
            The resulting ``xarray.Dataset`` (also stored on ``self.ds``).
        """
        self.set_df_index()

        self.ds = xr.Dataset.from_dataframe(self.df)

        self.ds = self.ds.set_coords(self._COORD_COLUMNS)

        return self.ds

    def get_pressure_bounds(self):
        """Return ``(floor(min), ceil(max))`` of the ``pres`` column of ``self.df``."""
        min_pres = math.floor(self.df['pres'].min())
        max_pres = math.ceil(self.df['pres'].max())

        return min_pres, max_pres

    @staticmethod
    def _mean_over_all_dims(grouped):
        """Average a grouped object over all of its dimensions."""
        # Fall back to a literal Ellipsis when this xarray has no
        # ``ALL_DIMS`` attribute; use the attribute where it still exists.
        return grouped.mean(dim=getattr(xr, 'ALL_DIMS', ...))

    def bin_on_pressure(self, param, bin_size, min_pres=0, max_pres=2000):
        """Average ``param`` of ``self.ds`` in pressure bins of width ``bin_size``.

        Args:
            param: Name of the variable in ``self.ds`` to average.
            bin_size: Width of each pressure bin; must be positive.
            min_pres: Lower edge of the first bin (default 0).
            max_pres: Exclusive upper limit used to generate the bin edges
                (default 2000).

        Returns:
            The mean of ``param`` per pressure bin, labelled with the
            central pressure of each bin.

        Raises:
            ValueError: If ``bin_size`` is not positive.
        """
        if bin_size <= 0:
            raise ValueError(f"bin_size must be positive, got {bin_size!r}")

        pres_bins = np.arange(min_pres, max_pres, bin_size)

        # define a label for each bin corresponding to the central pressure
        pres_center = np.arange(min_pres + bin_size / 2, max_pres - bin_size / 2, bin_size)

        # group according to those bins and take the mean
        grouped = self.ds[param].groupby_bins('pres', pres_bins, labels=pres_center)

        return self._mean_over_all_dims(grouped)

    def get_datetime_bounds(self):
        """Return the ``(min, max)`` of the ``datetime`` column of ``self.df``."""
        min_time = self.df['datetime'].min()
        max_time = self.df['datetime'].max()

        return min_time, max_time

    def bin_on_time(self, param, time_interval='month'):
        """Average ``param`` of ``self.ds`` in bins of the ``month`` coordinate.

        Args:
            param: Name of the variable in ``self.ds`` to average.
            time_interval: Binning interval; only ``'month'`` is supported.

        Returns:
            The mean of ``param`` per month bin, labelled with the bin centre.

        Raises:
            ValueError: If ``time_interval`` is not ``'month'``.
        """
        if time_interval != 'month':
            raise ValueError(
                f"Unsupported time_interval {time_interval!r}; only 'month' is implemented"
            )

        # NOTE(review): edges 1..12 yield 11 bins, and with groupby_bins'
        # default right-closed intervals month == 1 appears to fall outside
        # the first bin -- confirm whether January is meant to be excluded.
        time_bins = np.arange(1, 13)

        # define a label for each bin corresponding to the central value
        time_center = np.arange(1.5, 12.5)

        # group according to those bins and take the mean
        grouped = self.ds[param].groupby_bins('month', time_bins, labels=time_center)

        return self._mean_over_all_dims(grouped)
    