<center> <h1> Descriptive statistics </h1>

<h4> Description: This notebook contains functions for calculating descriptive statistics on categorical and numerical variables. These functions are tested on a hand-generated dataset. </h4> </center>
<h5> Information: If the descriptive statistics are computed on a dataframe for a <b> binary classification task </b>, the resulting categorical and numerical descriptive statistics dataframes can be saved in xlsx format. Their content can then be copied and pasted directly into the "raw_cat_statsdesc" and "raw_num_statsdesc" sheets of the Excel file "statsdesc_template.xlsx" (located in the same folder) to obtain a harmonised Excel file. </h5>

## Table of contents

* [1. Imports](#chapter1)
* [2. Descriptive statistics functions](#chapter2)
* [3. Hand-generated dataset](#chapter3)
* [4. Compute descriptive statistics](#chapter4)

# 1. Imports <a class="anchor" id="chapter1"></a>

In [6]:
import pandas as pd
import numpy as np
from typing import Union, List

# 2. Descriptive statistics functions <a class="anchor" id="chapter2"></a>

In [7]:
def q1(x):
    """
    Returns the first quartile (25th percentile) of x, ignoring NaN values
    """
    return np.nanquantile(x, 0.25)

def q3(x):
    """
    Returns the third quartile (75th percentile) of x, ignoring NaN values
    """
    return np.nanquantile(x, 0.75)


def set_to_str(df: pd.DataFrame,
              list_cols: Union[str, List[str]]) -> pd.DataFrame:
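    """Casts the columns listed in list_cols to str type (in place) and returns the dataframe

    Args:
        df (pd.DataFrame): data
        list_cols (Union[str, List[str]]): name of a column or list of column names to convert to str

    Returns:
        pd.DataFrame: the same dataframe with the listed columns converted to str
    """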
    if not isinstance(list_cols, list):
        list_cols = [list_cols]
    for col in list_cols:
        df[col] = df[col].astype(str)
    return df


def flatten_multiindex(col_names: list,
                       preproc: bool=True) -> list:
    """Flattens a column MultiIndex into a one-dimensional list of column names.
    The levels of each column tuple are joined with '_'.
    If preproc is True, column names are put into uppercase and ' ' is replaced by '_'.

    Args:
        col_names (list): iterable of column tuples (e.g. a pandas MultiIndex); all levels must be str
        preproc (bool, optional): Whether to preprocess column names. Defaults to True.

    Returns:
        list: one column name per tuple in col_names

    Example:
        col_names = [("Placebo", "MEAN"), ("Placebo", "STD"), ("Treatment", "MEAN"), ("Treatment", "STD")]
        preproc = True
        will return 4 new column names: ["PLACEBO_MEAN", "PLACEBO_STD", "TREATMENT_MEAN", "TREATMENT_STD"]
    """
    col_levels = [list(col) for col in col_names]
    col_names_new = ["_".join(col) for col in col_levels]

    # preprocessing of column names
    if preproc:
        col_names_new = [col.replace(' ', '_').upper() for col in col_names_new]

    return col_names_new

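def replace_na_in_col(df: pd.DataFrame,
                      list_col: List[str],
                      value: Union[int, str]) -> pd.DataFrame:
    """Replaces NaN values by value in the columns listed in list_col

    Args:
        df (pd.DataFrame): data
        list_col (List[str]): names of the columns in which NaN values are replaced
        value (Union[int, str]): replacement value

    Returns:
        pd.DataFrame: copy of df with NaN values replaced in the listed columns
    """
    mapping_col_values = {col: {np.nan: value} for col in list_col}
    df = df.replace(mapping_col_values)
    return df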

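def get_stats_descs(df: pd.DataFrame,
                    var_type: str,
                    list_var: List[str],
                    list_col_aggr: List[str],
                    list_agg_func: list,
                    col_id: str) -> pd.DataFrame:
    """Computes descriptive stats for the variables in list_var, aggregated on the modalities of list_col_aggr

    Args:
        df (pd.DataFrame): 1 row / PTID | variable_name
        var_type (str): 'categorical' (stats per variable & category) or 'numerical' (stats per variable);
            any value other than 'categorical' is treated as numerical
        list_var (List[str]): list of variables to compute stats on
        list_col_aggr (List[str]): list of columns whose modalities stats are aggregated on
        list_agg_func (list): aggregation functions passed to pd.pivot_table, e.g. ["count"] or ['mean', q1, 'count']
        col_id (str): name of patient id column

    Returns:
        pd.DataFrame: 1 row / variable (& category), 1 col / function & modality;
            in the categorical case, TOT_* and PCT_* columns are added
    """
    # quality check on listed variables to compute stats for
    list_var_not_in_df = [var for var in list_var if var not in df.columns]
    assert len(list_var_not_in_df) == 0, f"The following variables of list_var are not present in the dataframe: {list_var_not_in_df}"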

    # cast the columns to aggregate on into str
    # for simplicity when creating column names from the pivot MultiIndex
    df = set_to_str(df=df, list_cols=list_col_aggr)

    # create long table with 1 row / PTID & variable: list_col_aggr | PTID | Variable | Category
    df_melt = pd.melt(df, id_vars=list_col_aggr + [col_id], value_vars=list_var, var_name="Variable", value_name="Category")

    # compute stats with pivot function 
    if var_type == 'categorical':
        index = ["Variable", "Category"]
    else:
        index = ["Variable"]
        # drop the id column in the numerical case:
        # otherwise pivot_table would also aggregate it as a value column
        df_melt = df_melt.drop(columns=[col_id])
    df_stats = pd.pivot_table(df_melt,
                              index=index,
                              columns=list_col_aggr,
                              aggfunc=list_agg_func)

    # flatten the columns' MultiIndex (function, value column, modality) to get a 1-dimensional column index
    df_stats = df_stats.droplevel(1, axis=1)
    df_stats.columns = flatten_multiindex(col_names=df_stats.columns)

    # get names of aggregation groups from the COUNT_* columns, used to compute derived columns per group (categorical case only)
    aggr_groups = [col.replace("COUNT_", "") for col in df_stats.columns]
    df_stats = df_stats.reset_index()

    # replace na by 0 in count columns
    list_col_count = df_stats.filter(regex="^COUNT").columns.tolist()
    df_stats = replace_na_in_col(df=df_stats, list_col=list_col_count, value=0)
 
    if var_type == 'categorical':
        # compute total count per variable (sum over its categories)
        df_count_tot = df_stats.drop(columns=["Category"]).groupby("Variable").sum()

        # formatting: rename COUNT_* columns into TOT_* in the total count dataframe
        df_count_tot.columns = [col.replace("COUNT", "TOT") for col in df_count_tot.columns]

        # add total count information
        df_stats = df_stats.merge(df_count_tot, on="Variable", how="inner")

        # compute the proportion of each category within each aggregation group
        for col in aggr_groups:
            df_stats[f"PCT_{col}"] = df_stats[f"COUNT_{col}"] / df_stats[f"TOT_{col}"]

    return df_stats
  

def main_cat_stats_descs(df: pd.DataFrame,
                         list_var_cat: List[str],
                         list_col_aggr: List[str],
                         col_id: str) -> pd.DataFrame:
    """Retrieves categorical stat descs from a feature matrix given a list of variables and modalities to aggregate stats on.
    Stats are computed for each modality of list_col_aggr and for the whole dataframe (modality "all").
    Only counts are meaningful for categorical variables, so get_stats_descs is called with list_agg_func=['count'].
    NB: null values are kept as a dedicated "Missing value" category.

    Args:
        df (pd.DataFrame): 1 row / PTID | variable_name
        list_var_cat (List[str]): list of variables to compute stats on
        list_col_aggr (List[str]): list of columns whose modalities stats are aggregated on
        col_id (str): name of patient id column

    Returns:
        pd.DataFrame: 1 row / variable & category, 1 col / function & modality

    Example:
    Inputs:
        df: 1 row / PTID | gender (F or M) | label (0 or 1)
        list_var_cat = ["gender"]
        list_col_aggr = ["label"]
        col_id = "PTID"
    Output:
        df: 1 row / gender category: Variable | Category | COUNT_0 | COUNT_1 | COUNT_ALL | TOT_0 | TOT_1 | TOT_ALL | PCT_0 | PCT_1 | PCT_ALL
    """
    # keep null values as a dedicated category
    df = df.fillna("Missing value")

    # compute stats for entire dataframe without any aggregation
    # (rows are duplicated with the aggregation columns set to "all")
    df_all = df.copy()
    df_all[list_col_aggr] = "all"
    df = pd.concat([df, df_all], axis=0)
    df = get_stats_descs(df=df,
                         var_type='categorical',
                         list_var=list_var_cat,
                         list_col_aggr=list_col_aggr,
                         list_agg_func=["count"],
                         col_id=col_id)

    # optional: lift of class 1 vs class 0 (binary classification only)
    # df = df.assign(LIFT=df.PCT_1 / df.PCT_0)

    return df

def main_num_stats_descs(df: pd.DataFrame,
                         list_var_num: List[str],
                         list_col_aggr: List[str],
                         col_id: str) -> pd.DataFrame:
    """Retrieves numerical stat descs from a feature matrix given a list of variables and modalities to aggregate stats on.
    Stats are computed for each modality of list_col_aggr and for the whole dataframe (modality "all").
    The computed functions are fixed: 'mean', 'std', 'min', q1, 'median', q3, 'max', 'count'.
    NB: null values are ignored by every function ('count' returns the number of non-null values).

    Args:
        df (pd.DataFrame): 1 row / PTID | variable_name
        list_var_num (List[str]): list of variables to compute stats on
        list_col_aggr (List[str]): list of columns whose modalities stats are aggregated on
        col_id (str): name of patient id column

    Returns:
        pd.DataFrame: 1 row / variable, 1 col / function & modality

    Example:
    Inputs:
        df: 1 row / PTID | age | label (0 or 1)
        list_var_num = ["age"]
        list_col_aggr = ["label"]
        col_id = "PTID"
    Output:
        df: 1 row / variable: Variable | MEAN_0 | MEAN_1 | MEAN_ALL | STD_0 | ... | COUNT_1 | COUNT_ALL
    """
    # compute stats for entire dataframe without any aggregation
    # (rows are duplicated with the aggregation columns set to "all")
    df_all = df.copy()
    df_all[list_col_aggr] = "all"
    df = pd.concat([df, df_all], axis=0)

    df = get_stats_descs(df=df,
                         var_type='numerical',
                         list_var=list_var_num,
                         list_col_aggr=list_col_aggr,
                         list_agg_func=['mean', 'std', 'min', q1, 'median', q3, 'max', 'count'],
                         col_id=col_id)

    return df

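To make the reshaping inside `get_stats_descs` more concrete, here is a minimal standalone sketch of its categorical path on a hypothetical four-patient toy dataframe (values invented for illustration). It melts the data to long format, counts patients with `pd.pivot_table`, drops the value-column level of the resulting three-level column MultiIndex and joins the remaining levels as `flatten_multiindex` does, then fills absent combinations with 0 as `replace_na_in_col` does.

```python
import pandas as pd

# Hypothetical toy data: 1 row / patient; label already cast to str (as set_to_str does)
df = pd.DataFrame({
    "PTID": ["PT1", "PT2", "PT3", "PT4"],
    "gender": ["F", "M", "F", "F"],
    "label": ["0", "0", "1", "1"],
})

# Long format: 1 row / (PTID, Variable), the observed value is stored in "Category"
df_melt = pd.melt(df, id_vars=["label", "PTID"], value_vars=["gender"],
                  var_name="Variable", value_name="Category")

# Count non-null PTIDs per (Variable, Category) and label;
# the columns form a MultiIndex (function, value column, label), e.g. ("count", "PTID", "0")
df_stats = pd.pivot_table(df_melt, index=["Variable", "Category"],
                          columns=["label"], aggfunc=["count"])

# Drop the value-column level, then join and uppercase the remaining levels
df_stats = df_stats.droplevel(1, axis=1)
df_stats.columns = ["_".join(col).upper() for col in df_stats.columns]

# Absent (category, label) combinations (here M with label 1) are replaced by 0
df_stats = df_stats.reset_index().fillna(0)
print(df_stats)
```

Setting the label column to str beforehand is what allows `"_".join` to build names such as `COUNT_0` without a type error.
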
# 3. Hand-generated dataset <a class="anchor" id="chapter3"></a>

In [8]:
# Create a raw dataframe

# Fix randomness
np.random.seed(42)

# Define the number of patients in this dataframe
nb_patients = 30

df = pd.DataFrame({
    'PTID': ['PT' + str(i).zfill(3) for i in range(1, (nb_patients+1))],
    'gender': np.random.choice(['F', 'M'], size=nb_patients),
    'diagnosis': np.random.choice(['diag_1', 'diag_2', 'diag_3'], size=nb_patients),
    'prescription': np.random.choice(['presc_1', 'presc_2', 'presc_3'], size=nb_patients),
    'symptom_1': np.random.randint(2, size=nb_patients),
    'symptom_2': np.random.randint(2, size=nb_patients),
    'age': np.random.randint(18, 80, size=nb_patients),
    'lab_measure': np.random.random(size=nb_patients),
    'label': np.random.randint(2, size=nb_patients),
    'cluster': np.random.randint(5, size=nb_patients)
})

# Define list of columns
list_var_num = ["age", "lab_measure"]
list_var_cat = ["gender", "diagnosis", "prescription", "symptom_1", "symptom_2"]
col_id = "PTID"
# Aggregation columns: ["label"] for a classification task, or ["cluster"] to stratify by cluster (e.g. after a clustering for a regression task)
list_col_aggr = ["label"]
list_var = list_var_num + list_var_cat
nb_var = len(list_var)

# Randomly replace variable values with NaN (probability 0.3 per cell)
mask = np.random.choice([True, False], size=(nb_patients, nb_var), p=[0.3, 0.7])
df[list_var] = df[list_var].mask(mask)

# Check dtypes
# print(df.dtypes)

display(df.head())

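| | PTID | gender | diagnosis | prescription | symptom_1 | symptom_2 | age | lab_measure | label | cluster |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | PT001 | F | diag_1 | presc_2 | 0.0 | 1.0 | 57.0 | NaN | 1 | 4 |
| 1 | PT002 | M | diag_1 | NaN | NaN | 0.0 | 39.0 | NaN | 1 | 0 |
| 2 | PT003 | F | diag_2 | presc_2 | 0.0 | NaN | 44.0 | NaN | 0 | 0 |
| 3 | PT004 | F | diag_2 | presc_1 | 1.0 | 0.0 | NaN | 0.410383 | 0 | 2 |
| 4 | PT005 | F | diag_1 | presc_2 | 1.0 | 0.0 | 18.0 | 0.755551 | 1 | 1 |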

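The masking step relies on `DataFrame.mask`, which replaces every entry where the boolean array is True with NaN. The toy sketch below (hypothetical data) uses a fixed, hand-written mask instead of a random one so that the effect is visible. It also shows why `symptom_1`, `symptom_2` and `age` appear as `0.0`, `1.0` or `57.0`: once an integer column contains NaN, pandas upcasts it to float, so the symptom categories later become `0.0` and `1.0` rather than `0` and `1`.

```python
import numpy as np
import pandas as pd

# Hypothetical toy data: an integer and a string column without missing values
df = pd.DataFrame({
    "symptom_1": [0, 1, 1, 0],
    "gender": ["F", "M", "F", "M"],
})

# Hand-written mask (True = replace by NaN), same shape as df
mask = np.array([
    [True, False],
    [False, False],
    [False, True],
    [False, False],
])
df_masked = df.mask(mask)

print(df_masked)
print(df_masked.dtypes)          # symptom_1 is now float64
print(df_masked.isna().mean())   # share of missing values per column
```

With the random mask of the notebook, `df[list_var].isna().mean()` should be around 0.3 per column, with noticeable variation since there are only 30 patients.
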

# 4. Compute descriptive statistics <a class="anchor" id="chapter4"></a>

In [9]:
# Categorical statsdesc
df_cat_stats_descs = main_cat_stats_descs(df=df, list_var_cat=list_var_cat, list_col_aggr=list_col_aggr, col_id=col_id)
display(df_cat_stats_descs)

# Numerical statsdesc
df_num_stats_descs = main_num_stats_descs(df=df, list_var_num=list_var_num, list_col_aggr=list_col_aggr, col_id=col_id)
display(df_num_stats_descs)

| | Variable | Category | COUNT_0 | COUNT_1 | COUNT_ALL | TOT_0 | TOT_1 | TOT_ALL | PCT_0 | PCT_1 | PCT_ALL |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | diagnosis | Missing value | 3.0 | 3.0 | 6.0 | 13.0 | 17.0 | 30.0 | 0.230769 | 0.176471 | 0.2 |
| 1 | diagnosis | diag_1 | 0.0 | 7.0 | 7.0 | 13.0 | 17.0 | 30.0 | 0.0 | 0.411765 | 0.233333 |
| 2 | diagnosis | diag_2 | 5.0 | 3.0 | 8.0 | 13.0 | 17.0 | 30.0 | 0.384615 | 0.176471 | 0.266667 |
| 3 | diagnosis | diag_3 | 5.0 | 4.0 | 9.0 | 13.0 | 17.0 | 30.0 | 0.384615 | 0.235294 | 0.3 |
| 4 | gender | F | 6.0 | 4.0 | 10.0 | 13.0 | 17.0 | 30.0 | 0.461538 | 0.235294 | 0.333333 |
| 5 | gender | M | 6.0 | 6.0 | 12.0 | 13.0 | 17.0 | 30.0 | 0.461538 | 0.352941 | 0.4 |
| 6 | gender | Missing value | 1.0 | 7.0 | 8.0 | 13.0 | 17.0 | 30.0 | 0.076923 | 0.411765 | 0.266667 |
| 7 | prescription | Missing value | 2.0 | 3.0 | 5.0 | 13.0 | 17.0 | 30.0 | 0.153846 | 0.176471 | 0.166667 |
| 8 | prescription | presc_1 | 2.0 | 2.0 | 4.0 | 13.0 | 17.0 | 30.0 | 0.153846 | 0.117647 | 0.133333 |
| 9 | prescription | presc_2 | 7.0 | 7.0 | 14.0 | 13.0 | 17.0 | 30.0 | 0.538462 | 0.411765 | 0.466667 |

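The TOT_* and PCT_* columns follow from simple arithmetic on the counts. The sketch below recomputes them for the three gender rows of the output above, with the counts copied from that table. It uses `groupby().transform("sum")` instead of the `groupby().sum()` plus `merge` used in `get_stats_descs`: a swapped-in but equivalent way to attach per-variable totals. It also evaluates the lift `PCT_1 / PCT_0` that is left commented out in `main_cat_stats_descs` (the `LIFT` name comes from that comment).

```python
import pandas as pd

# Counts copied from the "gender" rows of the categorical output above
df = pd.DataFrame({
    "Variable": ["gender", "gender", "gender"],
    "Category": ["F", "M", "Missing value"],
    "COUNT_0": [6.0, 6.0, 1.0],
    "COUNT_1": [4.0, 6.0, 7.0],
})

for grp in ["0", "1"]:
    # TOT_<grp>: total count of the variable in the group (sum over its categories)
    df[f"TOT_{grp}"] = df.groupby("Variable")[f"COUNT_{grp}"].transform("sum")
    # PCT_<grp>: share of the group's patients falling in each category
    df[f"PCT_{grp}"] = df[f"COUNT_{grp}"] / df[f"TOT_{grp}"]

# Lift of class 1 vs class 0 (the commented-out line in main_cat_stats_descs)
df = df.assign(LIFT=df.PCT_1 / df.PCT_0)
print(df.round(6))
```

For instance, PCT_0 for F is 6 / 13 ≈ 0.461538, as in the table. A lift above 1 (about 5.35 for "Missing value", i.e. (7/17) / (1/13)) means the category is over-represented in class 1 relative to class 0.
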

| | Variable | MEAN_0 | MEAN_1 | MEAN_ALL | STD_0 | STD_1 | STD_ALL | MIN_0 | MIN_1 | MIN_ALL | ... | MEDIAN_ALL | Q3_0 | Q3_1 | Q3_ALL | MAX_0 | MAX_1 | MAX_ALL | COUNT_0 | COUNT_1 | COUNT_ALL |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | age | 43.888889 | 47.846154 | 46.227273 | 19.757558 | 20.309465 | 19.707218 | 18.0 | 18.0 | 18.0 | ... | 44.0 | 56.0 | 64.0 | 62.25 | 77.0 | 78.0 | 78.0 | 9 | 13 | 22 |
| 1 | lab_measure | 0.459456 | 0.477144 | 0.469283 | 0.343201 | 0.313158 | 0.317018 | 0.006952 | 0.119865 | 0.006952 | ... | 0.413897 | 0.80761 | 0.791642 | 0.806498 | 0.860731 | 0.929698 | 0.929698 | 8 | 10 | 18 |

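Values such as Q3_ALL = 62.25 for age are not observed ages: `q1` and `q3` call `np.nanquantile`, which discards NaN values and, with NumPy's default linear method, interpolates between neighbouring sorted values. The sketch below checks this by hand on hypothetical toy ages. It also shows that when callables are passed as aggregation functions, their `__name__` (`q1`, `q3`) becomes the column label, which is where the Q1_* and Q3_* names come from. For brevity it uses `groupby().agg()` rather than the `pd.pivot_table` call of `get_stats_descs`.

```python
import numpy as np
import pandas as pd

def q1(x):
    """First quartile ignoring NaN (same definition as in the notebook)"""
    return np.nanquantile(x, 0.25)

def q3(x):
    """Third quartile ignoring NaN (same definition as in the notebook)"""
    return np.nanquantile(x, 0.75)

# Hypothetical toy ages with one missing value
ages = np.array([18.0, 30.0, np.nan, 50.0, 70.0])

# NaN is ignored; the 4 remaining sorted values are interpolated linearly:
# Q1 position = 0.25 * (4 - 1) = 0.75 -> 18 + 0.75 * (30 - 18) = 27.0
# Q3 position = 0.75 * (4 - 1) = 2.25 -> 50 + 0.25 * (70 - 50) = 55.0
print(q1(ages), q3(ages))

# Callables are labelled by their __name__ in the aggregated columns
df = pd.DataFrame({"label": ["0", "0", "1", "1", "1"], "age": ages})
stats = df.groupby("label")["age"].agg(["mean", q1, q3, "count"])
print(stats)
```

In group "1" the NaN is skipped as well, so `count` is 2 and `q1` is interpolated between 50 and 70.
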

In [10]:
# Save statsdesc dataframes (requires openpyxl: pip install openpyxl)
# df_cat_stats_descs.to_excel('cat_statsdesc.xlsx', sheet_name='Sheet1', engine="openpyxl", index=False)
# df_num_stats_descs.to_excel('num_statsdesc.xlsx', sheet_name='Sheet1', engine="openpyxl", index=False)