# Load Data

In [None]:
"""
Write a python function to load a CSV file with the following columns into a Pandas dataframe:
	
- instant: record index
- dteday : date
- season : season (1:winter, 2:spring, 3:summer, 4:fall)
- yr : year (0: 2011, 1:2012)
- mnth : month ( 1 to 12)
- hr : hour (0 to 23)
- holiday : whether day is holiday or not (extracted from http://dchr.dc.gov/page/holiday-schedule)
- weekday : day of the week
- workingday : if day is neither weekend nor holiday is 1, otherwise is 0.
- weathersit : 
    - 1: Clear, Few clouds, Partly cloudy, Partly cloudy
    - 2: Mist + Cloudy, Mist + Broken clouds, Mist + Few clouds, Mist
    - 3: Light Snow, Light Rain + Thunderstorm + Scattered clouds, Light Rain + Scattered clouds
    - 4: Heavy Rain + Ice Pallets + Thunderstorm + Mist, Snow + Fog
- temp : Normalized temperature in Celsius. The values are derived via (t-t_min)/(t_max-t_min), t_min=-8, t_max=+39 (only in hourly scale)
- atemp: Normalized feeling temperature in Celsius. The values are derived via (t-t_min)/(t_max-t_min), t_min=-16, t_max=+50 (only in hourly scale)
- hum: Normalized humidity. The values are divided to 100 (max)
- windspeed: Normalized wind speed. The values are divided to 67 (max)
- casual: count of casual users
- registered: count of registered users
- cnt: count of total rental bikes including both casual and registered

Convert the following column names as follows:

- dteday -> date
- yr -> year
- mnth -> month
- hr -> hour
- holiday -> is_holiday
- weekday -> day_of_week
- workingday -> is_workingday
- weathersit -> weather
- hum -> humidity
- casual -> num_casual_users
- registered -> num_registered_users
- cnt -> num_total_users

Convert the values in `season` column into categorical values using this map: (1:winter, 2:spring, 3:summer, 4:fall)
Convert values in year using this map: (0: 2011, 1:2012)
"""
import pandas as pd

def load_hour_data(filename):
    """
    Read the hourly bike-sharing CSV into a cleaned dataframe.

    The raw columns are renamed to descriptive names, ``season`` is turned
    into a categorical of season names, and a ``timestamp`` column is built
    from the date string and the hour. The ``date`` and ``year`` columns are
    not part of the returned dataframe.
    """
    column_names = {
        'dteday': 'date',
        'yr': 'year',
        'mnth': 'month',
        'hr': 'hour',
        'holiday': 'is_holiday',
        'weekday': 'day_of_week',
        'workingday': 'is_workingday',
        'weathersit': 'weather',
        'hum': 'humidity',
        'casual': 'num_casual_users',
        'registered': 'num_registered_users',
        'cnt': 'num_total_users',
    }
    season_names = {1: 'winter', 2: 'spring', 3: 'summer', 4: 'fall'}
    calendar_years = {0: 2011, 1: 2012}

    frame = pd.read_csv(filename).rename(columns=column_names)
    frame['season'] = frame['season'].map(season_names).astype('category')
    frame['year'] = frame['year'].map(calendar_years)

    # "<date> <hour>:00:00" is parsed into one datetime per row
    stamp_text = frame['date'] + ' ' + frame['hour'].astype(str) + ':00:00'
    frame['timestamp'] = pd.to_datetime(stamp_text)
    return frame.drop(['date', 'year'], axis=1)


def load_day_data(filename):
    """
    Read the daily bike-sharing CSV into a cleaned dataframe.

    The raw columns are renamed to descriptive names, ``season`` is turned
    into a categorical of season names, and a ``timestamp`` column is parsed
    from the date string. The ``date`` and ``year`` columns are not part of
    the returned dataframe.
    """
    column_names = {
        'dteday': 'date',
        'yr': 'year',
        'mnth': 'month',
        'holiday': 'is_holiday',
        'weekday': 'day_of_week',
        'workingday': 'is_workingday',
        'weathersit': 'weather',
        'hum': 'humidity',
        'casual': 'num_casual_users',
        'registered': 'num_registered_users',
        'cnt': 'num_total_users',
    }
    season_names = {1: 'winter', 2: 'spring', 3: 'summer', 4: 'fall'}
    calendar_years = {0: 2011, 1: 2012}

    frame = pd.read_csv(filename).rename(columns=column_names)
    frame['season'] = frame['season'].map(season_names).astype('category')
    frame['year'] = frame['year'].map(calendar_years)

    frame['timestamp'] = pd.to_datetime(frame['date'])
    return frame.drop(['date', 'year'], axis=1)

In [None]:
# Absolute, machine-specific locations of the hourly and daily CSV files.
# Edit these to point at a local copy of the dataset before running the cells below.
hour_file = '/Users/klogram/Downloads/Bike-Sharing-Dataset/hour.csv'
day_file = '/Users/klogram/Downloads/Bike-Sharing-Dataset/day.csv'

In these datasets, the instant (or the index) acts as a time-step variable that we can use to compute time differences. In the hour dataset, the time step is one hour. In the day dataset, the time step is one day.

In [None]:
# Load the hourly data, print its (rows, columns) shape and preview the first rows.
df_hour = load_hour_data(hour_file)
print(df_hour.shape)
df_hour.head()

In [None]:
# Load the daily data, print its (rows, columns) shape and preview the first rows.
df_day = load_day_data(day_file)
print(df_day.shape)
df_day.head()

# EDA

## Hour-level data

In [None]:
import matplotlib.pyplot as plt

# Histogram of the hourly total-user counts, with axis labels.
# plt.show() renders the figure explicitly.
df_hour['num_total_users'].hist()
plt.xlabel('Total Users')
plt.ylabel('Frequency')
plt.show()

In [None]:
# Histogram of the hourly casual-user counts, with axis labels.
df_hour['num_casual_users'].hist()
plt.xlabel('Casual Users')
plt.ylabel('Frequency')
# Render explicitly, like the total-users histogram, so the cell does not end
# on the value returned by plt.ylabel().
plt.show()

In [None]:
# Histogram of the hourly registered-user counts, with axis labels.
df_hour['num_registered_users'].hist()
plt.xlabel('Registered Users')
plt.ylabel('Frequency')
# Render explicitly, like the total-users histogram, so the cell does not end
# on the value returned by plt.ylabel().
plt.show()

# Time-dependent Causal Inference

We follow the description of how the day-level bike sharing dataset was analyzed in Zheng and Kleinberg, 2017. Like them, we divide the continuous variables into three bins of equal size.

In [None]:
"""
Write a function that takes a dataframe and column name as input and 
adds a new column to the dataframe that labels each row as 'low', 'medium', or 'high',
based on whether the value in the original column is more than one standard
deviation below the mean ('low'), more than two standard deviations above the
mean ('high'), or anywhere in between ('medium').
"""
def label_by_std(df, col_name, n_std_low=1.0, n_std_high=2.0):
    """
    Add a ``<col_name>_bin`` column labelling each row 'low', 'medium' or 'high'.

    A value is 'low' when it is smaller than ``mean - n_std_low * std`` and
    'high' when it is larger than ``mean + n_std_high * std``, where ``mean``
    and ``std`` come from ``df[col_name].mean()`` and ``df[col_name].std()``.
    Every other value, including a missing one, is 'medium'.

    The defaults keep the thresholds this function has always applied. Note
    that they are asymmetric: one standard deviation below the mean, but two
    above it. Pass ``n_std_high=1`` for a symmetric one-sigma band.

    Parameters
    ----------
    df : pandas.DataFrame, modified in place
    col_name : str, name of the numeric column to label
    n_std_low : float, number of standard deviations below the mean for 'low'
    n_std_high : float, number of standard deviations above the mean for 'high'

    Returns
    -------
    The same dataframe object, with the new column added.
    """
    values = df[col_name]
    std = values.std()
    mean = values.mean()
    low_cut = mean - n_std_low * std
    high_cut = mean + n_std_high * std

    # Start from 'medium' everywhere and overwrite the two tails; comparisons
    # against NaN are False, so missing values stay 'medium'.
    labels = pd.Series('medium', index=values.index, dtype=object)
    labels = labels.mask(values < low_cut, 'low').mask(values > high_cut, 'high')
    df[f'{col_name}_bin'] = labels
    return df

# Add the num_total_users_bin, num_casual_users_bin and num_registered_users_bin
# label columns to the hourly data.
df_hour = label_by_std(df_hour, 'num_total_users')
df_hour = label_by_std(df_hour, 'num_casual_users')
df_hour = label_by_std(df_hour, 'num_registered_users')

In [None]:
# Number of hourly rows carrying each total-users label.
df_hour["num_total_users_bin"].value_counts()

In [None]:
"""
write a function that takes a dataframe column and converts the values into 
three bins of equal size, based on the values in the column.
"""
def convert_to_bins(df, column):
    """
    Split ``df[column]`` into three quantile bins labelled 'low', 'medium', 'high'.

    Returns the binned values produced by ``pd.qcut``; ``df`` is not modified.
    """
    bin_labels = ['low', 'medium', 'high']
    series = df[column]
    return pd.qcut(series, 3, labels=bin_labels)

# Each continuous weather column gets a companion "<name>_bin" column holding
# its low / medium / high quantile label.

# convert the temp column into three bins
df_hour['temp_bin'] = convert_to_bins(df_hour, 'temp')

# convert the atemp column into three bins
df_hour['atemp_bin'] = convert_to_bins(df_hour, 'atemp')

# convert the humidity column into three bins
df_hour['humidity_bin'] = convert_to_bins(df_hour, 'humidity')

# convert the windspeed column into three bins
df_hour['windspeed_bin'] = convert_to_bins(df_hour, 'windspeed')

# convert the num_total_users column into three bins
# df_hour['num_total_users_bin'] = convert_to_bins(df_hour, 'num_total_users')

# convert the num_casual_users column into three bins
# df_hour['num_casual_users_bin'] = convert_to_bins(df_hour, 'num_casual_users')

# convert the num_registered_users column into three bins
# df_hour['num_registered_users_bin'] = convert_to_bins(df_hour, 'num_registered_users')


## Variables

The variables represent the things that are potential causes. Here are the variables:

- is_holiday
- is_workingday
- is_weekend
- is_raining
- is_bad_weather
- is_mild_precipitation
- is_daytime
- is_nighttime
- is_`<season>`
- is_high_temp
- is_high_atemp
- is_low_temp
- is_low_atemp
- is_high_humidity
- is_low_humidity
- is_windy
- not_windy

### Functions for determining the state

For each variable, we define a boolean function that tells us what is true or false at any point in time in the dataset. 


In [None]:
def is_holiday(df, instant):
    """Return 1 if ``is_holiday`` equals 1 in the row labelled ``instant``, else 0."""
    flag = df.loc[instant, 'is_holiday']
    if flag == 1:
        return 1
    return 0

def is_workingday(df, instant):
    """Return 1 if ``is_workingday`` equals 1 in the row labelled ``instant``, else 0."""
    flag = df.loc[instant, 'is_workingday']
    if flag == 1:
        return 1
    return 0

def is_weekend(df, instant):
    """Return 1 if ``day_of_week`` is 0 or 6 in the row labelled ``instant``, else 0."""
    weekend_codes = (0, 6)
    day = df.loc[instant, 'day_of_week']
    if day in weekend_codes:
        return 1
    return 0
    
def is_bad_weather(df, instant):
    """Return 1 if the ``weather`` code is 3 or 4 in the row labelled ``instant``, else 0."""
    bad_codes = (3, 4)
    code = df.loc[instant, 'weather']
    if code in bad_codes:
        return 1
    return 0

def is_mild_precipitation(df, instant):
    """Return 1 if the ``weather`` code is 3 in the row labelled ``instant``, else 0."""
    code = df.loc[instant, 'weather']
    if code == 3:
        return 1
    return 0

def is_rush_hour(df, instant):
    """Return 1 if ``hour`` is one of 7, 8, 9, 17, 18, 19 in the row labelled ``instant``, else 0."""
    rush_hours = (7, 8, 9, 17, 18, 19)
    hour = df.loc[instant, 'hour']
    if hour in rush_hours:
        return 1
    return 0

def is_daytime(df, instant):
    """Return 1 if ``hour`` is one of 6 through 15 in the row labelled ``instant``, else 0."""
    day_hours = (6, 7, 8, 9, 10, 11, 12, 13, 14, 15)
    hour = df.loc[instant, 'hour']
    if hour in day_hours:
        return 1
    return 0

def is_nighttime(df, instant):
    """Return 1 if ``hour`` is one of 0-5 or 20-23 in the row labelled ``instant``, else 0."""
    night_hours = (0, 1, 2, 3, 4, 5, 20, 21, 22, 23)
    hour = df.loc[instant, 'hour']
    if hour in night_hours:
        return 1
    return 0

def is_spring(df, instant):
    """Return 1 if ``season`` is 'spring' in the row labelled ``instant``, else 0."""
    season = df.loc[instant, 'season']
    if season == 'spring':
        return 1
    return 0

def is_summer(df, instant):
    """Return 1 if ``season`` is 'summer' in the row labelled ``instant``, else 0."""
    season = df.loc[instant, 'season']
    if season == 'summer':
        return 1
    return 0

def is_fall(df, instant):
    """Return 1 if ``season`` is 'fall' in the row labelled ``instant``, else 0."""
    season = df.loc[instant, 'season']
    if season == 'fall':
        return 1
    return 0

def is_winter(df, instant):
    """Return 1 if ``season`` is 'winter' in the row labelled ``instant``, else 0."""
    season = df.loc[instant, 'season']
    if season == 'winter':
        return 1
    return 0

def is_high_temp(df, instant):
    """Return 1 if ``temp_bin`` is 'high' in the row labelled ``instant``, else 0."""
    level = df.loc[instant, 'temp_bin']
    if level == 'high':
        return 1
    return 0

def is_low_temp(df, instant):
    """Return 1 if ``temp_bin`` is 'low' in the row labelled ``instant``, else 0."""
    level = df.loc[instant, 'temp_bin']
    if level == 'low':
        return 1
    return 0

def is_high_atemp(df, instant):
    """Return 1 if ``atemp_bin`` is 'high' in the row labelled ``instant``, else 0."""
    level = df.loc[instant, 'atemp_bin']
    if level == 'high':
        return 1
    return 0

def is_low_atemp(df, instant):
    """Return 1 if ``atemp_bin`` is 'low' in the row labelled ``instant``, else 0."""
    level = df.loc[instant, 'atemp_bin']
    if level == 'low':
        return 1
    return 0

def is_high_humidity(df, instant):
    """Return 1 if ``humidity_bin`` is 'high' in the row labelled ``instant``, else 0."""
    level = df.loc[instant, 'humidity_bin']
    if level == 'high':
        return 1
    return 0

def is_low_humidity(df, instant):
    """Return 1 if ``humidity_bin`` is 'low' in the row labelled ``instant``, else 0."""
    level = df.loc[instant, 'humidity_bin']
    if level == 'low':
        return 1
    return 0

def is_windy(df, instant):
    """Return 1 if ``windspeed_bin`` is 'high' in the row labelled ``instant``, else 0."""
    level = df.loc[instant, 'windspeed_bin']
    if level == 'high':
        return 1
    return 0

def not_windy(df, instant):
    """Return 1 if ``windspeed_bin`` is 'low' in the row labelled ``instant``, else 0."""
    level = df.loc[instant, 'windspeed_bin']
    if level == 'low':
        return 1
    return 0

# Variable name -> indicator function. The insertion order of this dict fixes
# the numbering used by idx2labeler and the order of label_fns below.
labelers = {
    'is_holiday': is_holiday,
    'is_workingday': is_workingday,
    'is_weekend': is_weekend,
    'is_bad_weather': is_bad_weather,
    'is_mild_precipitation': is_mild_precipitation,
    'is_rush_hour': is_rush_hour,
    'is_daytime': is_daytime,
    'is_nighttime': is_nighttime,
    'is_spring': is_spring,
    'is_summer': is_summer,
    'is_fall': is_fall,
    'is_winter': is_winter,
    'is_high_temp': is_high_temp,
    'is_low_temp': is_low_temp,
    'is_high_atemp': is_high_atemp,
    'is_low_atemp': is_low_atemp,
    'is_high_humidity': is_high_humidity,
    'is_low_humidity': is_low_humidity,
    'is_windy': is_windy,
    'not_windy': not_windy
}

# position -> variable name
idx2labeler = dict(enumerate(labelers))

# indicator functions, in the same positional order
label_fns = list(labelers.values())

We also want some boolean functions for the effect values.

In [None]:
def high_total_users(df, instant):
    """Return 1 if ``num_total_users_bin`` is 'high' in the row labelled ``instant``, else 0."""
    level = df.loc[instant, 'num_total_users_bin']
    if level == 'high':
        return 1
    return 0

def low_total_users(df, instant):
    """Return 1 if ``num_total_users_bin`` is 'low' in the row labelled ``instant``, else 0."""
    level = df.loc[instant, 'num_total_users_bin']
    if level == 'low':
        return 1
    return 0

def high_casual_users(df, instant):
    """Return 1 if ``num_casual_users_bin`` is 'high' in the row labelled ``instant``, else 0."""
    level = df.loc[instant, 'num_casual_users_bin']
    if level == 'high':
        return 1
    return 0

def low_casual_users(df, instant):
    """Return 1 if ``num_casual_users_bin`` is 'low' in the row labelled ``instant``, else 0."""
    level = df.loc[instant, 'num_casual_users_bin']
    if level == 'low':
        return 1
    return 0

def high_registered_users(df, instant):
    """Return 1 if ``num_registered_users_bin`` is 'high' in the row labelled ``instant``, else 0."""
    level = df.loc[instant, 'num_registered_users_bin']
    if level == 'high':
        return 1
    return 0

def low_registered_users(df, instant):
    """Return 1 if ``num_registered_users_bin`` is 'low' in the row labelled ``instant``, else 0."""
    level = df.loc[instant, 'num_registered_users_bin']
    if level == 'low':
        return 1
    return 0

## Using type-level relationships to explain observed events

Here, we identify type-level relationships that explain why bike rentals are high or low. 

Here are the steps in the process:

1. Identify the token-level relationships in the dataset that are between the variables listed above
2. compute the significance of the type-level relationships, $\epsilon_{avg}(c_{r-s},e)$

We aim to evaluate the significance of each type-level cause for each instance of high and low bike rentals. For each instance of high and low bike rentals, we then use a significance threshold or only keep the top $k$ relationships to partition the relationships/events into those for which sufficiently significant type-level causes have been found and those for which sufficiently significant type-level causes have *not* been found. 

### Identify token-level relationships

We will pass through the dataset, and at each point, we will use the values of the various variables to select *prima facie* causes that can be given a significance score in the next step. 

In [None]:
import numpy as np

# Build the T x V matrix of 0/1 indicator values: one row per time step of
# df_hour, one column per labeler.
V = len(labelers)  # number of variables
T = df_hour.shape[0]  # number of time steps
D = np.zeros((T, V), dtype=np.int64)  # value of each variable at each time step

print("Number of potential type-level causes: %d" % V)
print("Number of time steps: %d" % T)

# t is passed to each labeler as a row label, so this relies on df_hour
# being indexed 0..T-1.
for t in range(T):
    for i in range(V):
        D[t, i] = label_fns[i](df_hour, t)

# column names follow the same order as label_fns
variable_names = idx2labeler.values()
df_D = pd.DataFrame(D, columns=variable_names)

# add effect columns to the dataframe
# this makes it easy to filter when computing probabilities
df_D["high_total_users"] = [high_total_users(df_hour, t) for t in range(T)]
df_D["low_total_users"] = [low_total_users(df_hour, t) for t in range(T)]
df_D["high_casual_users"] = [high_casual_users(df_hour, t) for t in range(T)]
df_D["low_casual_users"] = [low_casual_users(df_hour, t) for t in range(T)]
df_D["high_registered_users"] = [high_registered_users(df_hour, t) for t in range(T)]
df_D["low_registered_users"] = [low_registered_users(df_hour, t) for t in range(T)]

# df_hour = pd.concat([df_hour, df_D], axis=1)

# remove duplicate columns
# df_hour = df_hour.loc[:,~df_hour.columns.duplicated()]

# df_hour.head()

## Data Structures for Causal Relationships

Here we introduce several data structures that we will use to work with type-level and token-level relationships.
First, we have two classes that let us store the names of causal variables, so that we don't have to keep passing around strings.

In [None]:
from typing import List

VariableIndex = int


class BidirectionalDict(dict):
    """
    A dict that also keeps a reverse (value -> key) index.

    The reverse index is kept in step with item assignment (``d[k] = v``)
    and item deletion (``del d[k]``). Values must be hashable. If two keys
    are given the same value, the reverse index points at the key assigned
    last.

    NOTE: other inherited mutators (``update``, ``pop``, ``setdefault``,
    ``clear``, ...) are not overridden and do not touch the reverse index.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._backward = {v: k for k, v in self.items()}

    def _drop_backward_entry(self, key):
        # Forget the reverse entry for `key`'s current value, but only when
        # that entry still points at `key` (another key may have claimed it).
        if key in self:
            old_value = super().__getitem__(key)
            if old_value in self._backward and self._backward[old_value] == key:
                del self._backward[old_value]

    def __setitem__(self, key, value):
        # Without this, re-assigning a key would leave its previous value
        # resolvable through backward_lookup() and listed by values().
        self._drop_backward_entry(key)
        super().__setitem__(key, value)
        self._backward[value] = key

    def __delitem__(self, key):
        self._drop_backward_entry(key)
        super().__delitem__(key)

    def forward_lookup(self, key):
        """Return the value stored for ``key`` (raises KeyError if absent)."""
        return self[key]

    def backward_lookup(self, value):
        """Return the key that maps to ``value`` (raises KeyError if absent)."""
        return self._backward[value]

    def keys(self):
        """Return the keys as a list."""
        return list(super().keys())

    def values(self):
        """Return the values known to the reverse index as a list."""
        return list(self._backward.keys())


class VariableStore:
    """
    Registry that gives each variable name an integer index.

    A name receives, on first registration, the number of names stored at
    that moment; registering a known name again changes nothing.
    """

    def __init__(self):
        self.storage = BidirectionalDict()

    def add(self, variable_name: str):
        """Register ``variable_name`` unless it is already present."""
        if variable_name in self.storage:
            return
        next_index = len(self.storage)
        self.storage[variable_name] = next_index

    def lookup_by_name(self, name: str) -> VariableIndex:
        """Return the index registered for ``name``."""
        return self.storage.forward_lookup(name)

    def lookup_by_index(self, index: VariableIndex) -> str:
        """Return the name registered under ``index``."""
        return self.storage.backward_lookup(index)

    def __len__(self) -> int:
        return len(self.storage)

    def __contains__(self, name) -> bool:
        return name in self.storage

    @property
    def names(self) -> List[str]:
        """All registered names, sorted."""
        registered = self.storage.keys()
        return sorted(registered)

    @property
    def ids(self) -> List[VariableIndex]:
        """All registered indices, sorted."""
        registered = self.storage.values()
        return sorted(registered)


Next, we have data structures for time window, type-level and token-level causal relationships, and significance scores.

In [None]:
from dataclasses import dataclass
from typing import List, Optional, Set, Union

class Window:
    """
    A closed time interval ``[start, end]`` with non-negative bounds.

    Windows with equal bounds compare equal and hash alike.
    """

    def __init__(self, start: int, end: int):
        # the ordering check runs before the sign check
        if end < start:
            raise ValueError("Window start must be <= than end")
        if end < 0 or start < 0:
            raise ValueError("Window start and end must be >= 0")
        self.start = start
        self.end = end

    def __repr__(self):
        return "Window({}, {})".format(self.start, self.end)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Window):
            return False
        return self.start == other.start and self.end == other.end

    def __hash__(self) -> int:
        bounds = (self.start, self.end)
        return hash(bounds)


@dataclass(frozen=True)
class CausalRelation:
    """
    A cause and effect pair, each identified by its variable index.

    Declared as a frozen dataclass, so the fields cannot be reassigned after
    construction.
    """

    cause: VariableIndex  # index of the cause variable
    effect: VariableIndex  # index of the effect variable


@dataclass(frozen=True)
class TokenCause:
    """
    A token event that supports a potential cause: one concrete occurrence of
    the cause at ``t_cause`` paired with an occurrence of the effect at
    ``t_effect``.
    """

    relation: CausalRelation  # the cause/effect pair this event is an instance of
    t_cause: int  # time step where cause is true
    t_effect: int  # time step where effect is true

    @property
    def lag(self) -> int:
        """The lag between the cause and effect, i.e. ``t_effect - t_cause``"""
        return self.t_effect - self.t_cause


class TypeLevelCause:
    r"""
    A candidate cause of an effect: the cause/effect relation, the time
    window in which the effect follows the cause, the probability attached to
    that, and the token events backing it up. In PCTL language,
    `c \leadsto {}^{\geq r, \leq s}_{\geq p} e`
    """

    def __init__(
        self, relation: CausalRelation, window: Window, prob: float, token_events: List[TokenCause]
    ):
        """
        Store the four components of a type-level cause.

        Parameters
        ----------
        relation : CausalRelation, the cause and effect
        window : Window, the time window in which the effect occurs after the cause
        prob : float, the probability of the effect occurring in the window
        token_events : List[TokenCause], the token events that support this type-level cause
        """
        self.relation, self.window = relation, window
        self.prob, self.token_events = prob, token_events

    def __repr__(self):
        return "TypeLevelCause({}, {}, {})".format(self.relation, self.window, self.prob)

    # Equality and hashing both take relation, window, prob and the token
    # events into account.
    def __eq__(self, other):
        if not isinstance(other, TypeLevelCause):
            return False
        if not (self.relation == other.relation):
            return False
        if not (self.window == other.window):
            return False
        if not (self.prob == other.prob):
            return False
        return self.token_events == other.token_events

    def __hash__(self):
        # the event list is frozen into a tuple so that it can be hashed
        events = tuple(self.token_events)
        return hash((self.relation, self.window, self.prob, events))


@dataclass(frozen=True)
class CompositeScore:
    """
    A set of per-lag scores together with their total.
    """

    # array of scores; presumably one entry per lag, as the name suggests
    lag_scores: np.ndarray

    @property
    def score(self):
        """The overall score: the sum of ``lag_scores``."""
        return np.sum(self.lag_scores)


Now we add a class that identifies the causal relationships and computes their relative significance. This makes it easy to compute things with a few simple method calls.

In [None]:
class CausalTree:
    """
    Finds candidate type-level causes of an effect in a binary time series
    and scores their relative significance.

    ``df`` holds one row per time step and one 0/1 column per variable. Time
    steps are addressed as ``0 .. len(df) - 1``, so ``df`` is expected to
    carry the default integer index.

    Typical use: ``build(effect)``, then ``compute_significance()``, then
    ``prune()``; results are in ``type_level_causes`` and
    ``type_level_significance_scores`` (parallel lists).
    """

    def __init__(self, df: pd.DataFrame, cause_names: List[str], effect_name: str):
        """
        Validate the variables and register them in the variable store.

        Raises
        ------
        ValueError, if a named variable is missing from ``df`` or is not 0/1-valued
        """
        CausalTree.validate_causal_variables(df, cause_names, effect_name)
        self.df = df
        self.store = VariableStore()

        for name in cause_names + [effect_name]:
            self.store.add(name)

        self.type_level_causes: List[TypeLevelCause] = []
        self.type_level_significance_scores: List[CompositeScore] = []
        self.max_lag = -1  # set by build()

    @staticmethod
    def validate_causal_variables(df: pd.DataFrame, cause_names: List[str], effect_name: str):
        """
        Check that every named variable is a column of ``df`` holding only 0/1.

        Only the causal variables are checked; other columns of ``df`` may
        hold anything.

        Raises
        ------
        ValueError, listing the missing or the non-binary variables
        """
        # check that all variables are in the dataframe
        missing_names = set(cause_names) - set(df.columns)
        if effect_name not in df.columns:
            missing_names.add(effect_name)
        if missing_names:
            raise ValueError(f"Variables not in df: {missing_names}")

        # check that all variables are binary-valued (each name once, in order)
        variables = list(dict.fromkeys([*cause_names, effect_name]))
        non_binary_cols = [col for col in variables if not df[col].isin([0, 1]).all()]
        if non_binary_cols:
            raise ValueError(f"Non-binary columns: {non_binary_cols}")

    def get_variable_name(self, i: VariableIndex) -> str:
        """Return the name of the variable with index ``i``."""
        return self.store.lookup_by_index(i)

    def get_variable_index(self, name: str) -> VariableIndex:
        """Return the index of the variable called ``name``."""
        return self.store.lookup_by_name(name)

    def get_relation(self, cause: str, effect: str) -> CausalRelation:
        """Return the ``CausalRelation`` for the named cause and effect."""
        cause_idx = self.store.lookup_by_name(cause)
        effect_idx = self.store.lookup_by_name(effect)
        return CausalRelation(cause_idx, effect_idx)

    def cause_holds_at(self, t: int, cause: Union[str, VariableIndex]) -> bool:
        """Return true if the given cause is true at the given time step."""
        cause_name = cause if isinstance(cause, str) else self.get_variable_name(cause)
        return self.df.at[t, cause_name] == 1

    def effect_holds_in(self, window: Window, effect: Union[str, VariableIndex]) -> bool:
        """Return true if the given effect happened in the given (absolute) window."""
        e = effect if isinstance(effect, str) else self.get_variable_name(effect)
        # label slicing: both ends are included
        return self.df.loc[window.start : window.end, e].sum() > 0

    def get_token_effect_times(self, relation: CausalRelation, window: Window, t: int) -> List[int]:
        """
        Return the time steps in ``[t + window.start, t + window.end]`` where
        the effect is true, provided the cause is true at ``t``; otherwise an
        empty list.
        """
        if not self.cause_holds_at(t, relation.cause):
            return []
        w = Window(t + window.start, t + window.end)
        effect_name = self.get_variable_name(relation.effect)
        # find each time step where effect is true
        subset = self.df.loc[w.start : w.end, effect_name]
        return list(subset[subset == 1].index)

    def c_leadsto_e_in(self, window: Window, relation: CausalRelation, t: int) -> int:
        """Return 1 if cause at ``t`` leads to effect inside window, 0 otherwise."""
        w = Window(t + window.start, t + window.end)
        cause_holds = self.cause_holds_at(t, relation.cause)
        effect_holds_in_window = self.effect_holds_in(w, relation.effect)

        if cause_holds and effect_holds_in_window:
            return 1
        return 0

    @property
    def num_time_steps(self) -> int:
        """Number of rows (time steps) in the dataframe."""
        return self.df.shape[0]

    def identify_potential_cause(
        self, relation: CausalRelation, window: Window
    ) -> Optional[TypeLevelCause]:
        """
        Find the type-level cause that precedes the effect inside the window.

        Returns a ``TypeLevelCause`` with the supporting token events and the
        fraction of cause occurrences followed by the effect in the window,
        or None when the cause never leads to the effect.
        """
        dt = window.end - window.start
        token_causes = []
        num_c_leadsto_e = 0

        for t in range(self.num_time_steps - dt):
            # A non-empty list means exactly "cause at t and effect somewhere
            # in the window", so one lookup serves both the count and the
            # token events.
            effect_times = self.get_token_effect_times(relation, window, t)
            if effect_times:
                num_c_leadsto_e += 1
                token_causes.extend(
                    TokenCause(relation, t, t_effect) for t_effect in effect_times
                )

        if num_c_leadsto_e > 0:
            cause_column = self.get_variable_name(relation.cause)
            prob = num_c_leadsto_e / self.df[cause_column].sum()
            return TypeLevelCause(relation, window, prob, token_causes)
        return None

    def build(self, effect: str, max_lag: int = 5, verbose: bool = False):
        """
        Collect the potential type-level causes of ``effect``.

        Every window ``[start, end]`` with ``1 <= start < end <= max_lag`` is
        tried for every other variable in the store; the causes found are
        appended to ``type_level_causes``.

        Raises
        ------
        ValueError, if ``effect`` is not a registered variable
        """
        if effect not in self.store:
            raise ValueError(f"Effect variable {effect} not in store")

        self.max_lag = max_lag  # maximum lag to consider
        # create all possible windows
        windows = []
        for start in range(1, max_lag):
            for end in range(start + 1, max_lag + 1):
                windows.append(Window(start, end))
        print(f"Created {len(windows)} windows.")
        if verbose:
            for window in windows:
                print(window)

        # identify all potential type-level causes
        for cause in self.store.names:
            if cause == effect:
                continue
            if verbose:
                print(f"Finding valid windows for {cause} => {effect}")
            relation = self.get_relation(cause, effect)
            num_causes = 0
            num_token_events = 0

            for window in windows:
                type_level_cause = self.identify_potential_cause(relation, window)
                if type_level_cause is not None:
                    self.type_level_causes.append(type_level_cause)
                    num_causes += 1
                    num_token_events += len(type_level_cause.token_events)
            if verbose:
                print(f"Found {num_causes} type-level relations for {cause} => {effect}")
                print(f"Found {num_token_events} token events for {cause} => {effect}")

    def _lagged_effect_prob(
        self, c: TypeLevelCause, x: TypeLevelCause, cause_value: int
    ) -> np.ndarray:
        """
        Shared implementation of the two conditional probabilities.

        Selects the time steps where ``x``'s cause is 1 and ``c``'s cause
        equals ``cause_value``. For each lag in ``c``'s window it counts how
        often the effect is 1 that many steps later, then divides the counts
        by the number of selected time steps. The result has
        ``max_lag + 1`` entries, indexed by lag.
        """
        window = c.window
        cause_col = self.df[self.get_variable_name(c.relation.cause)].to_numpy()
        effect_col = self.df[self.get_variable_name(c.relation.effect)].to_numpy()
        x_col = self.df[self.get_variable_name(x.relation.cause)].to_numpy()

        selected = (cause_col == cause_value) & (x_col == 1)
        num_selected = int(selected.sum())
        counts = np.zeros(self.max_lag + 1)

        num_steps = self.num_time_steps
        # start times run over range(num_steps - dt), as in identify_potential_cause
        num_starts = num_steps - (window.end - window.start)
        for lag in range(window.start, window.end + 1):
            # only start times t with t + lag still inside the data contribute
            n = min(num_starts, num_steps - lag)
            if n > 0:
                counts[lag] = effect_col[lag : lag + n][selected[:n]].sum()

        if counts.sum() == 0 or num_selected == 0:
            return np.zeros(self.max_lag + 1)
        return counts / num_selected

    def _prob_given_c_and_x(self, c: TypeLevelCause, x: TypeLevelCause) -> np.ndarray:
        """Compute $P(e\\vert c \\land x)$ for all lags."""
        return self._lagged_effect_prob(c, x, 1)

    def _prob_given_notc_and_x(self, c: TypeLevelCause, x: TypeLevelCause) -> np.ndarray:
        """Compute $P(e\\vert \\neg{c} \\land x)$, for each lag."""
        return self._lagged_effect_prob(c, x, 0)

    def filter_type_level_causes_by_window(self, window: Window) -> List[TypeLevelCause]:
        """Return the stored type-level causes whose window equals ``window``."""
        return [c for c in self.type_level_causes if c.window == window]

    def get_other_causes_in_window(self, c: TypeLevelCause) -> Set[TypeLevelCause]:
        """Return the stored causes sharing ``c``'s window, excluding ``c`` itself."""
        return set(self.filter_type_level_causes_by_window(c.window)) - {c}

    def _compute_type_level_significance(self, cause: TypeLevelCause) -> np.ndarray:
        """
        Average, over the other causes ``x`` in the same window, of
        ``P(e | c and x) - P(e | not c and x)``, per lag.

        Returns all zeros when there is no other cause to compare against.
        """
        e_avg = np.zeros(self.max_lag + 1)
        X = self.get_other_causes_in_window(cause)
        if not X:
            # nothing to average over; dividing by len(X) == 0 would yield NaN
            return e_avg
        for x in X:
            e_avg += self._prob_given_c_and_x(cause, x) - self._prob_given_notc_and_x(cause, x)
        e_avg /= len(X)
        return e_avg

    def compute_significance(
        self,
        verbose: bool = False,
    ) -> List[CompositeScore]:
        """
        Compute the significance of each type-level cause.

        The scores are stored in ``type_level_significance_scores`` (same
        order as ``type_level_causes``) and also returned.
        """
        scores = []
        for cause in self.type_level_causes:
            if verbose:
                print("Computing significance of type-level cause: %s" % cause)
            e_avg = self._compute_type_level_significance(cause)
            scores.append(CompositeScore(e_avg))

        self.type_level_significance_scores = scores
        return scores

    def prune(self):
        """Remove type-level causes whose total significance score is not positive."""
        pruned_causes = []
        pruned_scores = []
        for cause, score in zip(self.type_level_causes, self.type_level_significance_scores):
            if score.score.sum() > 0:
                pruned_causes.append(cause)
                pruned_scores.append(score)
        self.type_level_causes = pruned_causes
        self.type_level_significance_scores = pruned_scores


Now we can identify the type-level relationships and find the most significant ones.

In [None]:
# Find the candidate causes of "high_total_users" for lags up to 3 time steps,
# score them, and drop those whose total score is not positive.
tree = CausalTree(df_D, list(variable_names), "high_total_users")
tree.build("high_total_users", max_lag=3)
tree.compute_significance()
tree.prune()

In [None]:
from dataclasses import dataclass

@dataclass
class ScoredTypeLevelCause:
    """A type-level cause together with its overall significance score."""

    cause: TypeLevelCause
    score: float

    def __repr__(self):
        # repr and str share one textual form
        return str(self)

    def __str__(self):
        return "{}: {:.2f}".format(self.cause, self.score)

# Pair every cause kept in the tree with its overall score, then order the
# pairs from highest to lowest score.
scored_causes = []
for cause, composite_score in zip(tree.type_level_causes, tree.type_level_significance_scores):
    scored = ScoredTypeLevelCause(cause, composite_score.score)
    scored_causes.append(scored)

scored_causes = sorted(scored_causes, key=lambda x: x.score, reverse=True)

In [None]:
# Print the ten highest-scoring relationships: significance score, the
# probability stored on the type-level cause, and the time window.
for c in scored_causes[:10]:
    cause = tree.get_variable_name(c.cause.relation.cause)
    effect = tree.get_variable_name(c.cause.relation.effect)
    # `prob` is the fraction of cause occurrences followed by the effect inside
    # the window -- a conditional probability, not a p-value.
    print(f"{cause} => {effect}: {c.score:.2f}, {c.cause.prob:.2f} probability, {c.cause.window} ")
