<a href="https://colab.research.google.com/github/Ansebi/causal_inference/blob/kls/Mutual_Information.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Imports

In [None]:
from sklearn.metrics import mutual_info_score
from scipy.stats.contingency import crosstab
from sklearn.metrics import mutual_info_score
import pandas as pd
import numpy as np
from sklearn.feature_selection import mutual_info_classif, mutual_info_regression
from scipy.stats import entropy

# Data

In [None]:
def normalize(
    array: np.ndarray,
    min_: float = None,
    max_: float = None
):
  """Min-max scale `array` so that `min_` maps to 0 and `max_` maps to 1.

  Args:
    array: 1-D array-like (an ndarray, or a pandas Series when used through
      `DataFrame.apply`).
    min_: value mapped to 0; defaults to `array.min()`.
    max_: value mapped to 1; defaults to `array.max()`.

  Returns:
    `(array - min_) / (max_ - min_)`. When the two bounds coincide (e.g. a
    constant array with no explicit bounds) every element becomes 0.5 rather
    than dividing by zero. An empty input is returned as an empty float array.
  """
  if len(array) == 0:
    # Nothing to scale; avoid `.min()` on an empty ndarray (ValueError).
    return array * 1.0
  # Fill in missing bounds from the data first, so that a constant array with
  # only one explicit bound is handled instead of computing `None - x`.
  if min_ is None:
    min_ = array.min()
  if max_ is None:
    max_ = array.max()
  if max_ == min_:
    # Degenerate range: there is nothing to scale against, use the midpoint.
    return np.ones_like(array) * 0.5
  return (array - min_) / (max_ - min_)

In [None]:
# Load the raw CSV, keep only complete rows, drop the CALDR_DT-derived columns
# and PROD_MAP, cast everything else to float, then min-max scale each column.
df = (
    pd.read_csv("PYATEROCHKA L7_Центральный_BEVERAGE.csv", sep=';')
    .dropna()
    # .head(3000)
    .drop(columns=['CALDR_DT', 'period_CALDR_DT_first', 'period_CALDR_DT_last', 'PROD_MAP'])
    .astype(float)
)
df_norm = df.apply(normalize)

In [None]:
# Load the two-year results table, drop its leading column (presumably a saved
# index - confirm), keep the first 300000 rows and drop 'Date'.
df = (
    pd.read_csv("/content/drive/MyDrive/Datasets/result_df_two_years.csv")
    .iloc[:, 1:]
    .head(300000)
    .drop(columns=['Date'])
)
# df['SL, %'] = df['SL, %'].astype(bool)
# target = df['SL, %']
df.head(n=1)

Unnamed: 0,Region,Unit,Channel,Chain,Category,Brand,Reason,"SL, %"
0,East,FAR EAST,Alco Group,ALYBION|L7,Energy,ADRENALINE,Брак внутри паллета/загр.склад,0.608696


In [None]:
# Discretized data: the saved preview shows one integer code per column.
df = pd.read_csv("/content/discretized_data_300000.csv")
df = df.iloc[:, 1:]  # drop the leading column (presumably a saved index - confirm)
df.head(n=5)

Unnamed: 0,Date,Region,Unit,Channel,Chain,Category,Brand,Reason,"SL, %"
0,0,0,1,0,148,0,2,0,5
1,0,0,1,0,148,0,2,17,1
2,0,0,1,0,148,0,2,26,5
3,0,0,1,0,148,0,2,64,1
4,0,0,1,0,148,0,2,66,7


In [None]:
# Toy dataset: `outcome` is driven by `magnitude` through
# `depends_on_magnitude`; `switch` is constant and the `garbage_*` columns are
# independent uniform noise.
SEED = 42
np.random.seed(SEED)  # make the synthetic draws reproducible across runs

N_GARBAGE = 3
# magnitude = np.array([1, 7, 9, 2, 1] * 10)
magnitude = np.random.randint(0, 9, 100)

depends_on_magnitude = magnitude * 10 + np.random.rand(len(magnitude)) / 10
switch = np.ones_like(magnitude)
garbage = {f'garbage_{i}': np.random.random(len(magnitude)) for i in range(N_GARBAGE)}
outcome = depends_on_magnitude + 1 + np.random.rand(len(magnitude)) / 10
df = pd.DataFrame(
    {
        'magnitude': magnitude,
        'depends_on_magnitude': depends_on_magnitude,
        'switch': switch,
        **garbage,
        'outcome': outcome
    }
)

df_norm = df.apply(normalize)

In [None]:
# DANGER! OVERKILL HAZARD
# Synthetic dataset with known structure: three independent sources each feed a
# two-step chain of deliberately convoluted transforms
# (source_* -> chain_*0 -> chain_*1), `outcome` is the sum of the three
# chain_*1 columns, and the garbage_* columns are unrelated noise.

import random

N_SAMPLES = 10000
N_GARBAGE = 5
SEED = 42

# Seed both generators used below (numpy for the columns, `random` for picking
# each garbage generator) so the dataset is reproducible across runs.
random.seed(SEED)
np.random.seed(SEED)


# Sources: ints in [-100, -10), ints in [0, 100), floats in [0, 1).
source_a = np.random.randint(-100, -10, N_SAMPLES)
source_b = np.random.randint(0, 100, N_SAMPLES)
source_c = np.random.random(N_SAMPLES)


# First link of each chain: piecewise functions of the matching source.
chain_a0 = np.ones(N_SAMPLES)
chain_b0 = np.ones(N_SAMPLES)
chain_c0 = np.ones(N_SAMPLES)

for i in range(N_SAMPLES):
  if source_a[i] >= -30:
    value = 5
  elif source_a[i] <= -70:
    value = 7
  else:
    value = np.sin(source_a[i])
  chain_a0[i] = value

for i in range(N_SAMPLES):
  if source_b[i] < 5:
    value = np.sin(source_b[i]) ** 2
  elif source_b[i] == 5:
    value = 2 * source_b[i]
  else:
    value = 1 / max([source_b[i],  1])
  chain_b0[i] = value

for i in range(N_SAMPLES):
  if source_c[i] < 0.5:
    value = (source_c[i] - 1) * 2
  else:
    value = source_c[i] / (np.cos(source_c[i]) - 1.1)
  chain_c0[i] = value


# Second link of each chain: piecewise functions of the first link.
chain_a1 = np.ones(N_SAMPLES)
chain_b1 = np.ones(N_SAMPLES)
chain_c1 = np.ones(N_SAMPLES)

for i in range(N_SAMPLES):
  value = chain_a0[i]
  # Branches depend on the decimal string representation of the value.
  if '2' in str(round(value, 3)):
    value = np.sin(value)
  elif sum([int(char) for char in str(np.abs(value)).replace('.', '')]) % 2:
    value = 0.35
  else:
    value = 1 / min([value, 0.1])
  chain_a1[i] = value

for i in range(N_SAMPLES):
  value = chain_b0[i]
  if value > 0.8:
    value = np.tan(value)
  elif np.sin(value) > 0.5:
    value = -0.12
  else:
    # NOTE(review): chain_b0 is never negative (sin**2, 10, or 1/max(b, 1)),
    # so min([value, -1]) is always -1 and this branch always yields 1.0.
    value = (1 / min([value, -1])) ** 2
  chain_b1[i] = value

for i in range(N_SAMPLES):
  value = chain_c0[i]
  # NOTE(review): chain_c0 is always negative ((c - 1) * 2 for c < 0.5, and
  # c / (cos(c) - 1.1) with a negative denominator otherwise), so the
  # `value = 6` branch below is never taken.
  if value < 0:
    value = np.tan(value)
  else:
    value = 6
  chain_c1[i] = value


outcome = chain_a1 + chain_b1 + chain_c1

# Each garbage column is one of three noise generators picked at random. All
# three arrays are drawn (consuming numpy draws) before one is chosen.
garbage = {
    f'garbage_{i}': random.choice(
        [
            np.random.randint(0, 9, N_SAMPLES),
            np.random.randint(0, 100, N_SAMPLES),
            np.random.random(N_SAMPLES)
        ]
    )
    for i in range(N_GARBAGE)
}

df_overkill = pd.DataFrame(
    {
      'source_a': source_a,
      'chain_a0': chain_a0,
      'chain_a1': chain_a1,
      'source_b': source_b,
      'chain_b0': chain_b0,
      'chain_b1': chain_b1,
      'source_c': source_c,
      'chain_c0': chain_c0,
      'chain_c1': chain_c1,
      'outcome': outcome,
      **garbage
    }
)

# Rebinds the notebook-wide `df` / `df_norm` to this synthetic data.
df = df_overkill.astype(float)
df_norm = df.apply(normalize)

In [None]:
# NOTE(review): the saved output below shows columns (volume_aggregate, ...)
# that are not df_overkill's columns, so it is stale - re-run to refresh.
df_norm.head(n=5)

Unnamed: 0,volume_aggregate,single_price,single_discount,price_rolling_vote,price_fill,label,period_single_price_mean,period_single_price_std,period_capped_volume_sum,period_capped_volume_std,period_duration,period_avg_daily_volume,pricing
0,0.000455,0.194941,0.0,0.194941,0.194941,0.0,0.194941,0.5,2e-05,0.0,0.0,0.001323,0.222772
1,0.047072,0.256624,0.300489,0.256624,0.256624,0.003205,0.256624,0.5,0.004922,0.067436,0.001938,0.092268,0.222772
2,0.023825,0.256624,0.300489,0.256624,0.256624,0.003205,0.256624,0.5,0.004922,0.067436,0.001938,0.092268,0.222772
3,0.032628,0.256624,0.300489,0.256624,0.256624,0.003205,0.256624,0.5,0.004922,0.067436,0.001938,0.092268,0.222772
4,2.1e-05,0.194941,0.0,0.194941,0.194941,0.00641,0.194941,0.5,3.4e-05,0.0,0.000969,0.001075,0.222772


In [None]:
def interpret_chain(chain, influence_threshold: float = 0.1):
  """Reduce an influence chain to the links that pass a threshold.

  Args:
    chain: mapping from a level key (e.g. '0_0::outcome') to a DataFrame
      indexed by actor name with an 'influence' column.
    influence_threshold: minimum influence (inclusive) for an actor to be kept.

  Returns:
    dict mapping every level key with at least one surviving actor to a list
    of (actor, influence) tuples, in the DataFrame's row order. Level keys
    with no surviving actor are omitted.
  """
  strong_links = {}
  for level_key, influences in chain.items():
    kept = influences[influences['influence'] >= influence_threshold]
    if kept.empty:
      continue
    strong_links[level_key] = [
        (actor, row['influence']) for actor, row in kept.iterrows()
    ]
  return strong_links


"""
e.g.
interpret_chain(
  chain=chain_from_bamt_bn_weights(bamt_bn_weights, 'outcome', 3, 0.1),
  influence_threshold=0.1
)
"""

"\ne.g.\ninterpret_chain(\n  chain=chain_from_bamt_bn_weights(bamt_bn_weights, 'outcome', 3, 0.1),\n  influence_threshold=0.1\n)\n"

# Influence chain from mutual-information scores (`MI_scores`)

In [None]:
from typing import List, Optional
def MI_scores(
    df,
    outcome_column_name: str,
    top_n: int,
    influence_threshold: float = 0.1,
    method: str = "regression",
    discrete: Optional[List[bool]] = None,
    normalize: bool = False,
    depth=None,
    branch=None,
    chain=None,
    random_state=None,
):
    """Recursively trace influences on an outcome using mutual information.

    Every column of `df` is scored against `outcome_column_name` with
    sklearn's `mutual_info_classif` (method="classif") or
    `mutual_info_regression` (method="regression"). The columns other than the
    outcome that score above zero are sorted by descending score and stored in
    `chain` under the key f"{depth}_{branch}::{outcome_column_name}". Up to
    `top_n` of them scoring >= `influence_threshold` (or the `top_n` best when
    the threshold is None) become "actors". The function then recurses on each
    actor as the new outcome, after dropping the current outcome and the
    sibling actors from `df`. Every level drops at least one column, so the
    recursion terminates.

    Args:
        df: DataFrame of features; must contain `outcome_column_name`.
        outcome_column_name: column whose influences are traced at this level.
        top_n: maximum number of actors to recurse into per level.
        influence_threshold: minimum score for an actor; None disables it.
        method: "classif" or "regression"; selects the sklearn estimator.
        discrete: optional per-column discreteness mask aligned with
            `df.columns` (or a single bool for all columns), forwarded to
            sklearn as `discrete_features`; 'auto' when None.
        normalize: divide the scores by the outcome's MI with itself. Inside
            this body the name shadows the module-level `normalize` function.
        depth, branch, chain: recursion state; leave as None at the top level.
        random_state: forwarded to the sklearn estimator, for reproducible
            scores.

    `method`, `discrete` (re-aligned to the remaining columns), `normalize`
    and `random_state` apply at every level of the recursion.

    NOTE(review): keys from different parents at the same depth collide when
    they share the branch index and actor name; the later entry overwrites
    the earlier one.

    Returns:
        The `chain` dict, mapping level keys to influence DataFrames.

    Raises:
        ValueError: if `method` is not "classif" or "regression".
    """
    if chain is None:
        chain = {}
    if depth is None:
        depth = 0
    if branch is None:
        branch = 0

    if method == "classif":
        MI = mutual_info_classif
    elif method == "regression":
        MI = mutual_info_regression
    else:
        raise ValueError(f"Invalid method '{method}'")

    discrete_features = 'auto' if discrete is None else discrete

    scores = MI(
        df,
        df[outcome_column_name],
        discrete_features=discrete_features,
        random_state=random_state,
    )

    if normalize:
        self_score = scores[df.columns.get_loc(outcome_column_name)]
        # A zero self-score means every score is zero; skip the 0/0 division.
        if self_score > 0:
            scores = scores / self_score

    df_influence = (
        pd.DataFrame({'influence': scores}, index=df.columns)
        .sort_values('influence', ascending=False)
        .drop(outcome_column_name)
        .query('influence > 0.0')
    )
    if df_influence.empty:
        return chain

    if influence_threshold is None:
        candidates = df_influence
    else:
        candidates = df_influence.query('influence >= @influence_threshold')
    actors = list(candidates.head(top_n).index)
    chain[f'{depth}_{branch}::{outcome_column_name}'] = df_influence

    # A per-column mask has to be re-aligned to each child's remaining columns.
    discrete_by_column = None
    if discrete is not None and np.ndim(discrete) == 1:
        discrete_by_column = dict(zip(df.columns, discrete))

    for child_branch, actor in enumerate(actors):
        siblings = actors[:child_branch] + actors[child_branch + 1:]
        child_df = df.drop(columns=[outcome_column_name] + siblings)
        if discrete_by_column is None:
            child_discrete = discrete
        else:
            child_discrete = [discrete_by_column[column] for column in child_df.columns]
        chain = MI_scores(
            child_df,
            actor,
            top_n,
            influence_threshold,
            method=method,
            discrete=child_discrete,
            normalize=normalize,
            depth=depth + 1,
            branch=child_branch,
            chain=chain,
            random_state=random_state,
        )

    return chain

In [None]:
# This cell needs `df` to be the discretized data from
# discretized_data_300000.csv. The synthetic-data cells between that load and
# here rebind `df`, so fail with a clear message if the wrong frame is live.
assert "SL, %" in df.columns, (
    "`df` has no 'SL, %' column - re-run the discretized-data loading cell "
    "right before this one"
)
MI_scores_chain = MI_scores(df, "SL, %", normalize=False, top_n=1, influence_threshold=0.01)

In [None]:
# Display the raw chain: one influence DataFrame per '<depth>_<branch>::<outcome>' key.
MI_scores_chain

{'0_0::SL, %':           influence
 Reason     0.319301
 Channel    0.049561
 Unit       0.008886
 Region     0.008609
 Brand      0.005153
 Category   0.004769
 Chain      0.003843
 Date       0.001387,
 '1_0::Reason':           influence
 Channel    0.087199
 Region     0.048903
 Unit       0.042949
 Category   0.026831
 Brand      0.026479
 Date       0.025112
 Chain      0.010254,
 '2_0::Channel':           influence
 Chain      0.138441
 Unit       0.082844
 Region     0.061633
 Category   0.031387
 Brand      0.012016
 Date       0.010933,
 '3_0::Chain':           influence
 Unit       0.741021
 Region     0.712782
 Category   0.084758
 Date       0.082571
 Brand      0.065787,
 '4_0::Unit':           influence
 Region     1.092461
 Date       0.178835
 Category   0.008437
 Brand      0.004134,
 '5_0::Region':           influence
 Date       0.081685
 Category   0.006683
 Brand      0.004824,
 '6_0::Date':           influence
 Category    0.00489}

In [None]:
# Keep only the links whose influence is at least 0.01.
interpret_chain(MI_scores_chain, influence_threshold=0.01)

{'0_0::SL, %': [('Reason', 0.3193006605267792),
  ('Channel', 0.049561334488466)],
 '1_0::Reason': [('Channel', 0.08719944837922355),
  ('Region', 0.048903000416197884),
  ('Unit', 0.04294931752213049),
  ('Category', 0.026830921572788213),
  ('Brand', 0.026479026487374213),
  ('Date', 0.025111630369210935),
  ('Chain', 0.01025386749326973)],
 '2_0::Channel': [('Chain', 0.13844067700492957),
  ('Unit', 0.08284428699013002),
  ('Region', 0.061633209655090226),
  ('Category', 0.03138739070107288),
  ('Brand', 0.012016322409471591),
  ('Date', 0.010933219006846073)],
 '3_0::Chain': [('Unit', 0.7410212571629771),
  ('Region', 0.7127818804131314),
  ('Category', 0.08475783755012634),
  ('Date', 0.08257140490948345),
  ('Brand', 0.06578693093917432)],
 '4_0::Unit': [('Region', 1.0924609205243723), ('Date', 0.17883528706528207)],
 '5_0::Region': [('Date', 0.08168504857364756)]}

# Influence chain from BAMT Bayesian-network edge weights (`chain_from_bamt_bn_weights`)

In [None]:
# Example pairwise edge weights. The node names match the df_overkill columns;
# presumably the weights were exported from a BAMT Bayesian network fitted on
# that data - confirm. Keys are (node, node) name pairs, values are weights.
bamt = {('chain_a1', 'chain_a0'): 0.14200882060084377,
 ('chain_b1', 'source_b'): 0.031185722570553725,
 ('source_c', 'chain_c1'): 0.7455230821178934,
 ('chain_a0', 'source_a'): 0.32612271818056227,
 ('chain_a1', 'source_a'): 0.008382118948296931,
 ('source_b', 'chain_b0'): 0.7626351993478976,
 ('chain_b1', 'chain_b0'): 0.008981732113565176,
 ('source_c', 'chain_c0'): 0.08648407629079621,
 ('chain_c1', 'chain_c0'): 0.20369053780858598,
 ('chain_a0', 'outcome'): 0.11397123701138077,
 ('chain_a1', 'outcome'): 0.04427446470407253,
 ('chain_b1', 'outcome'): 0.006637194695001037,
 ('chain_c0', 'outcome'): 0.4485699519482916,
 ('chain_c1', 'outcome'): 0.5368814146377209}

In [None]:
def chain_from_bamt_bn_weights(
    bamt_bn_weights,
    outcome_column_name: str,
    top_n: int,
    influence_threshold: float = 0.1,
    depth=None,
    branch=None,
    chain=None
):
  """Recursively build an influence chain from pairwise edge weights.

  The influences of `outcome_column_name` are the weights of every edge that
  contains it, keyed by the other node of that edge. The direction within a
  pair is ignored, and a later duplicate partner overwrites an earlier one.
  The influences are sorted descending and stored in `chain` under the key
  f"{depth}_{branch}::{outcome_column_name}". Up to `top_n` of them scoring
  >= `influence_threshold` (or the `top_n` best when the threshold is None)
  become actors. Each actor is explored recursively using the edges that touch
  neither the current outcome nor any sibling actor.

  Args:
    bamt_bn_weights: dict mapping (node, node) tuples to weights. It is not
      modified.
    outcome_column_name: node whose influences are traced at this level.
    top_n: maximum number of actors to recurse into per level.
    influence_threshold: minimum weight for an actor; None disables it.
    depth, branch, chain: recursion state; leave as None at the top level.

  Returns:
    The `chain` dict, mapping level keys to influence DataFrames with an
    'influence' column indexed by node name.
  """
  if chain is None:
    chain = {}
  if depth is None:
    depth = 0
  if branch is None:
    branch = 0

  # Split the edges into those touching the outcome (its influences) and the
  # rest, which are all that deeper levels may still use.
  influences = {}
  remaining_weights = {}
  for components, influence in bamt_bn_weights.items():
    if outcome_column_name not in components:
      remaining_weights[components] = influence
      continue
    partners = set(components) - {outcome_column_name}
    if partners:  # a self-loop (outcome, outcome) has no partner; skip it
      influences[partners.pop()] = influence
  if not influences:
    return chain

  df_influence = pd.DataFrame(
      {'influence': pd.Series(influences, dtype=float)}
  ).sort_values('influence', ascending=False)

  if influence_threshold is None:
    candidates = df_influence
  else:
    candidates = df_influence.query('influence >= @influence_threshold')
  actors = list(candidates.head(top_n).index)
  chain[f'{depth}_{branch}::{outcome_column_name}'] = df_influence

  for child_branch, actor in enumerate(actors):
    siblings = set(actors) - {actor}
    # Hide edges that touch a sibling actor, so each branch only follows its
    # own actor's influences.
    weights_to_go = {
        components: weight
        for components, weight in remaining_weights.items()
        if not siblings.intersection(components)
    }
    chain = chain_from_bamt_bn_weights(
        weights_to_go,
        actor,
        top_n,
        influence_threshold,
        depth + 1,
        child_branch,
        chain
    )
  return chain


"\ne.g.\ninterpret_chain(\n  chain=chain_from_bamt_bn_weights(bamt_bn_weights, 'outcome', 3, 0.1),\n  influence_threshold=0.1\n)\n"

In [None]:
chain_from_bamt_bn_weights(
    bamt_bn_weights=bamt,
    outcome_column_name="outcome",
    top_n=3,
)

{'0_0::outcome':           influence
 chain_c1   0.536881
 chain_c0   0.448570
 chain_a0   0.113971
 chain_a1   0.044274
 chain_b1   0.006637,
 '1_0::chain_c1':           influence
 source_c   0.745523,
 '1_1::chain_c0':           influence
 source_c   0.086484,
 '1_2::chain_a0':           influence
 source_a   0.326123
 chain_a1   0.142009}

In [None]:
# Build the BAMT-based chain and keep only links with influence >= 0.1.
interpret_chain(
  chain_from_bamt_bn_weights(bamt, 'outcome', top_n=3, influence_threshold=0.1),
  influence_threshold=0.1,
)

{'0_0::outcome': [('chain_c1', 0.5368814146377209),
  ('chain_c0', 0.4485699519482916),
  ('chain_a0', 0.11397123701138077)],
 '1_0::chain_c1': [('source_c', 0.7455230821178934)],
 '1_2::chain_a0': [('source_a', 0.32612271818056227),
  ('chain_a1', 0.14200882060084377)]}