In [None]:
import pandas as pd
import numpy as np

from scipy.stats import entropy

from sklearn.metrics import mutual_info_score, adjusted_mutual_info_score, normalized_mutual_info_score

### For discussion

- `information_gain` (cell below) and the scikit-learn functions `mutual_info_score`, `adjusted_mutual_info_score` and `normalized_mutual_info_score` have the same interface, so they can be called from `igr` interchangeably

- `information_gain` and `mutual_info_score` compute the same function; the others are different

- ACE produces something different to all of these (on e.g. Nutrients.csv)

In [None]:
def information_gain(members, split):
    '''
    Measure the reduction in entropy of ``members`` after grouping by ``split``.

    *** THIS FUNCTION IS ADAPTED FROM https://stackoverflow.com/questions/46752650/information-gain-calculation-with-scikit-learn

    Computes ``H(members) - sum_g P(split == g) * H(members | split == g)``
    using ``scipy.stats.entropy`` (natural logarithm by default).

    :param members: Pandas Series of the values whose entropy is measured
    :param split: Pandas Series of group labels, sharing ``members``' index
    :return: the information gain as a float
    '''
    # Work on renamed copies so the caller's Series keep their own names;
    # the fixed names are needed as column labels for the pivot below.
    members = members.rename('members')
    split = split.rename('split')

    entropy_before = entropy(members.value_counts(normalize=True))

    # One row per split group, one column per member value, holding the
    # within-group relative frequency (0 where a value never occurs).
    grouped_distrib = (
        members.groupby(split)
        .value_counts(normalize=True)
        .reset_index(name='count')
        .pivot_table(index='split', columns='members', values='count')
        .fillna(0)
    )

    # Keep the group labels attached to the per-group entropies so that the
    # weighting below aligns by label. The pivot index is sorted while
    # value_counts is not, so a positional multiply would pair each group's
    # entropy with another group's weight.
    entropy_per_group = pd.Series(
        entropy(grouped_distrib, axis=1), index=grouped_distrib.index
    )
    group_weights = split.value_counts(normalize=True)
    entropy_after = (entropy_per_group * group_weights).sum()

    return entropy_before - entropy_after

In [None]:
def igr(df_features, target, score_func=None):
    """
    Score each feature in a dataframe against the target.

    By default the score is scikit-learn's ``normalized_mutual_info_score``,
    which this notebook uses as its information gain ratio measure. Any
    function with the same ``(feature, target) -> score`` interface (e.g.
    ``information_gain``, ``mutual_info_score`` or
    ``adjusted_mutual_info_score``) can be supplied instead.

    Parameters
    ----------
    df_features : Dataframe
        The features for which the score will be calculated
    target : Series
        The targets against which each feature is scored
    score_func : callable, optional
        Function called as ``score_func(df_features[col], target)`` for each
        column. Defaults to ``normalized_mutual_info_score``.

    Returns
    -------
    A dictionary of feature names to score, for each feature in df_features.
    """
    # Resolve the default at call time so callers that omit score_func get
    # exactly the previous behaviour.
    if score_func is None:
        score_func = normalized_mutual_info_score

    return {
        col: score_func(df_features[col], target)
        for col in df_features.columns
    }