# Bin data on time

## Import packages

In [1]:
import math
from operator import itemgetter

import numpy as np
import pandas as pd
import xarray as xr

## Convert dataframe to xarray

In [3]:
# https://notebook.community/jhamman/xarray/examples/xarray_multidimensional_coords

class DataBinner:
    """Bin profile measurements on pressure and group the result on time.

    The binner wraps a flat pandas DataFrame.  The methods below read the
    columns ``pres``, ``temp``, ``psal``, ``profile_id`` and ``date`` plus
    the remaining names returned by :meth:`get_coords`.
    """

    # Quality-control columns that are discarded until real QC filtering exists.
    _QC_VARIABLES = ('pres_qc', 'psal_qc', 'temp_qc')

    # Largest value each supported time interval can take (day of month, month).
    _TIME_INTERVAL_MAX = {'day': 31, 'month': 12}

    def __init__(self, df, pressure_range, pressure_bin_size):
        """Store the inputs; no work is done until a method is called.

        Parameters
        ----------
        df : pandas.DataFrame
            Table of measurements to bin.
        pressure_range : sequence
            ``pressure_range[0]`` and ``pressure_range[1]`` are used as the
            first bin edge and the (exclusive) stop of the bin edges.
        pressure_bin_size : number
            Spacing between consecutive pressure bin edges.
        """
        self.df = df
        self.pressure_range = pressure_range
        self.pressure_bin_size = pressure_bin_size

    def get_coords(self):
        """Return the names of the variables that are promoted to coordinates."""
        return ('pres', 'profile_id', 'cycle_number', 'lat', 'lon', 'date', 'year', 'month', 'day')

    def apply_qc(self, ds):
        """Return ``ds`` without its quality-control variables.

        QC variables that are not present are skipped, so a dataset missing
        some (or all) of them is returned with the remaining ones removed.
        """
        # TODO: use the QC flags to filter measurements.
        # For now the QC variables are simply dropped.
        return ds.drop_vars(list(self._QC_VARIABLES), errors='ignore')

    def create_mean_ds(self, ds_grouped):
        """Collapse the pressure bins of one profile into a single dataset.

        Parameters
        ----------
        ds_grouped : iterable of (label, xarray.Dataset)
            Pressure-binned groups of one profile, as produced by
            ``groupby_bins('pres', ...)``.

        Returns
        -------
        xarray.Dataset
            Indexed by ``pres`` (the bin labels).  ``temp`` and ``psal`` hold
            the mean of each bin; every other coordinate of the groups holds
            the first value found in each bin.
        """
        binned_pres = []
        temp = []
        psal = []
        metadata = {}  # coordinate name -> one value per pressure bin

        for bin_label, group in ds_grouped:

            # The label is the (numeric) centre of the pressure bin
            binned_pres.append(bin_label)
            temp.append(group['temp'].mean().data.item())
            psal.append(group['psal'].mean().data.item())

            for coord_name in group.coords:

                if coord_name in ('index', 'pres'):
                    continue

                # Only the first value in the bin is kept for each coordinate,
                # e.g. 'profile_id': '3902235_17', 'lat': -14.35528,
                # 'date': '2020-01-30T22:17:05.002Z', 'year': 2020, ...
                first_value = group[coord_name][0].values.item()
                metadata.setdefault(coord_name, []).append(first_value)

        pres_coord = {'pres': binned_pres}

        ds_new = xr.Dataset({
            'temp': xr.DataArray(data=temp, dims=['pres'], coords=pres_coord),
            'psal': xr.DataArray(data=psal, dims=['pres'], coords=pres_coord),
        })

        for coord_name, values in metadata.items():
            ds_new[coord_name] = xr.DataArray(data=values, dims=['pres'], coords=pres_coord)

        return ds_new.set_coords(self.get_coords())

    def bin_on_pressure(self, ds_all_profile_groups):
        """Average every profile over fixed pressure bins and concatenate them.

        Parameters
        ----------
        ds_all_profile_groups : iterable of (label, xarray.Dataset)
            One dataset per profile, e.g. the result of ``ds.groupby(...)``.

        Returns
        -------
        xarray.Dataset
            The per-profile binned means concatenated along ``pres``.
        """
        min_pres = self.pressure_range[0]
        max_pres = self.pressure_range[1]

        # The edges are identical for every profile, so build them once.
        # NOTE(review): np.arange excludes its stop value, so the last edge is
        # below max_pres, and with xarray's default right-closed bins a sample
        # at exactly min_pres is not binned - confirm this is intended.
        pres_bins = np.arange(min_pres, max_pres, self.pressure_bin_size)

        # Label each bin with its central pressure.  Deriving the centres from
        # the edges guarantees exactly one label per bin.
        pres_center = (pres_bins[:-1] + pres_bins[1:]) / 2

        ds_mean_list = [
            self.create_mean_ds(
                profile_group.groupby_bins('pres', pres_bins, labels=pres_center)
            )
            for _, profile_group in ds_all_profile_groups
        ]

        return xr.concat(ds_mean_list, dim='pres')

    def create_xarray(self):
        """Convert the stored DataFrame to a pressure-binned xarray Dataset.

        Rows are grouped per profile using an id built from ``profile_id`` and
        ``date``, the QC variables are dropped, and each group is averaged
        over the pressure bins.

        Returns
        -------
        xarray.Dataset
            Result of :meth:`bin_on_pressure` over all profiles.
        """
        ds = self.df.to_xarray()

        # Create a unique id per profile from profile_id + date
        profile_id = ds['profile_id'].data
        date = ds['date'].data
        unique_col = np.array(['_'.join(pair) for pair in zip(profile_id, date)])

        # NOTE(review): assumes the DataFrame index is unnamed, so that the
        # converted dataset has a single dimension called 'index' - confirm.
        ds['unique_id'] = (['index'], unique_col)

        ds = ds.set_coords(self.get_coords())
        ds = ds.set_coords('unique_id')

        ds = self.apply_qc(ds)

        return self.bin_on_pressure(ds.groupby('unique_id'))

    def get_datetime_bounds(self):
        """Return ``(earliest, latest)`` time found in the stored DataFrame.

        The ``datetime`` column is used when it exists; otherwise the ``date``
        column is parsed with ``pandas.to_datetime``.
        """
        if 'datetime' in self.df.columns:
            times = self.df['datetime']
        else:
            times = pd.to_datetime(self.df['date'])

        return times.min(), times.max()

    def bin_on_time(self, ds_all, interval):
        """Group ``ds_all`` into unit-wide bins of its ``day`` or ``month``.

        Parameters
        ----------
        ds_all : xarray.Dataset
            Dataset carrying a ``day`` and/or ``month`` variable.
        interval : {'day', 'month'}
            Name of the variable to bin on.

        Returns
        -------
        The object returned by ``ds_all.groupby_bins``.  Bin edges run from 0
        up to the largest valid value (31 days, 12 months) in steps of one and
        each bin is labelled with its centre, so value ``v`` gets the label
        ``v - 0.5``.

        Raises
        ------
        ValueError
            If ``interval`` is neither ``'day'`` nor ``'month'``.
        """
        if interval not in self._TIME_INTERVAL_MAX:
            raise ValueError(
                "Unsupported interval {!r}; expected 'day' or 'month'".format(interval)
            )

        largest = self._TIME_INTERVAL_MAX[interval]

        # xarray bins are right-closed by default, so the edges must start at 0
        # to keep the first value (day 1 / January) and end at the largest
        # value to keep the last one (day 31 / December).
        bins = np.arange(0, largest + 1)

        # define a label for each bin corresponding to the central value
        bin_center = np.arange(0.5, largest + 0.5)

        return ds_all.groupby_bins(interval, bins, labels=bin_center)
