In [1]:
# make imports from pa_lib possible (parent directory of file's directory)
import sys
from pathlib import Path

file_dir = Path.cwd()
parent_dir = file_dir.parent
sys.path.append(str(parent_dir))

%load_ext autoreload
%autoreload

import pandas as pd
import numpy as np

import seaborn as sns
import matplotlib.pyplot as plt

sns.set()

from collections import namedtuple

from pa_lib.file import (
    project_dir,
    load_bin,
    write_xlsx,
    store_bin,
)
from pa_lib.data import as_dtype, dtFactor, desc_col, lookup, clean_up_categoricals
from pa_lib.util import (
    collect,
    value,
    flatten,
    normalize_rows,
    normalize_cols,
    list_items,
)
from pa_lib.log import time_log

# display long columns completely, show more rows
pd.set_option("display.max_colwidth", 300)
pd.set_option("display.max_rows", 200)
pd.set_option("display.max_columns", 200)

# Load data

In [2]:
# Load the data set and the variable structure table from the "axinova" project directory.
with project_dir("axinova"):
    ax_data = load_bin("ax_data.feather")
    ax_var_struct = load_bin("ax_var_struct.feather")

# Per-variable metadata record: the variable's label, its code labels and their numbers.
Variable = namedtuple("Variable", ["Label", "Codes", "Order"])

# One Variable record per distinct value of the "Variable" column.
var_info = {
    var_name: Variable(
        var_rows["Variable_Label"].max(),
        var_rows["Label"].to_list(),
        var_rows["Label_Nr"].to_list(),
    )
    for var_name, var_rows in ax_var_struct.groupby("Variable")
}

17:59:08 [INFO] Started loading binary file ...
17:59:08 [INFO] Reading from file C:\Users\kpf\data\axinova\ax_data.feather
17:59:10 [INFO] ... finished loading binary file in 1.43s (1.73s CPU)
17:59:10 [INFO] Started loading binary file ...
17:59:10 [INFO] Reading from file C:\Users\kpf\data\axinova\ax_var_struct.feather
17:59:10 [INFO] ... finished loading binary file in 0.01s (0.02s CPU)


In [3]:
def var_label(variable):
    """Return the ``Label`` stored in ``var_info`` for ``variable``.

    Raises ``KeyError`` if ``variable`` is not a key of ``var_info``.
    """
    info = var_info[variable]
    return info.Label


def var_codes(variable):
    """Return the ``Codes`` list stored in ``var_info`` for ``variable``.

    Raises ``KeyError`` if ``variable`` is not a key of ``var_info``.
    """
    info = var_info[variable]
    return info.Codes

In [4]:
# Weekday categories in their categorical order; the first five entries are
# used as workdays, the remaining ones as weekend.
weekday_order = ax_data["DayOfWeek"].cat.categories.to_list()
workdays = weekday_order[:5]
weekend = weekday_order[5:]

# Time-slot categories in their categorical order.
timeSlot_order = ax_data["TimeSlot"].cat.categories.to_list()
# Every slot except the first one (presumably the night slot -- TODO confirm).
day = timeSlot_order[1:]
# NOTE(review): list_items is a pa_lib helper; presumably it picks the entries
# at the given positions -- verify against its implementation.
rush_hours = list_items(timeSlot_order, [1, 5])
day_no_rush = list_items(timeSlot_order, [2, 3, 4, 6])

# Station groups. The suffixes presumably stand for the language regions
# (d = German, f = French, i = Italian) -- TODO confirm.
# Note that "Biel/Bienne" is listed in both stations_d and stations_f.
stations_d = [
    "Aarau",
    "Basel SBB",
    "Bern",
    "Biel/Bienne",
    "Brig",
    "Chur",
    "Luzern",
    "Olten",
    "St. Gallen",
    "Winterthur",
    "Zug",
    "Zürich Enge",
    "Zürich Flughafen",
    "Zürich Flughafen - Airside",
    "Zürich Flughafen - Landside",
    "Zürich HB",
    "Zürich Hardbrücke",
    "Zürich Oerlikon",
    "Zürich Stadelhofen",
]
stations_f = [
    "Biel/Bienne",
    "Fribourg",
    "Genève Aéroport",
    "Genève Cornavin",
    "Lausanne",
    "M2",
    "Neuchatel",
]
stations_i = ["Bellinzona", "Lugano"]

In [5]:
# Preview the first rows of the loaded data to check its columns.
ax_data.head()

Unnamed: 0,Station,DayOfWeek,Time,Variable,Code,Value,Year,Month,logValue,VarDesc,TimeSlot,ShortTime,Hour,is_weekend,is_day,is_rush,is_day_no_rush,TimeSlot_cat,StationSprache
0,Aarau,Monday,00:15 - 00:30,g_220,Keines,0.954451,2019,5,0.670109,Anzahl Autos im Haushalt,Nacht,00:15,0,False,False,False,False,Night,Deutsch
1,Aarau,Monday,00:15 - 00:30,g_500,quoted,0.954451,2019,5,0.670109,Zeitung: 20 Minuten / 20 Minutes / 20 Minuti,Nacht,00:15,0,False,False,False,False,Night,Deutsch
2,Aarau,Monday,00:15 - 00:30,g_501,not quoted,0.954451,2019,5,0.670109,Zeitung: Blick,Nacht,00:15,0,False,False,False,False,Night,Deutsch
3,Aarau,Monday,00:15 - 00:30,g_502,not quoted,0.954451,2019,5,0.670109,Zeitung: Tages-Anzeiger,Nacht,00:15,0,False,False,False,False,Night,Deutsch
4,Aarau,Monday,00:15 - 00:30,g_503,not quoted,0.954451,2019,5,0.670109,Zeitung: Mittelland Zeitung,Nacht,00:15,0,False,False,False,False,Night,Deutsch


# Define functions

## Select data

In [None]:
def _check_selection(data, selection, allowed_columns):
    """Validate a ``column -> value(s)`` selection against the categories in ``data``.

    Parameters
    ----------
    data : pandas.DataFrame
        Frame to validate against. Every column that appears in ``selection``
        must be categorical, because its ``.cat.categories`` define the legal values.
    selection : dict
        Maps column names to the requested value(s). Each entry is passed
        through ``flatten`` (pa_lib helper; presumably it yields the scalar items
        of a scalar or nested collection -- TODO confirm).
    allowed_columns : list of str
        Column names that may be used as selection keys.

    Returns
    -------
    dict
        One entry per allowed column: the flat list of requested values, or
        ``None`` if the column is not part of ``selection``.

    Raises
    ------
    NameError
        If ``selection`` contains a key that is not in ``allowed_columns``.
    ValueError
        If a requested value is not among the column's categories.
    """
    # Report only the offending names, not the whole selection.
    unknown_columns = [column for column in selection if column not in allowed_columns]
    if unknown_columns:
        raise NameError(f"Unknown column name in selection: {unknown_columns}")

    clean_selection = {}
    for column in allowed_columns:
        if column not in selection:
            clean_selection[column] = None
            continue
        col_values = list(flatten(selection[column]))
        # Categories are looked up only for columns that are actually selected,
        # so unrelated columns do not have to be categorical.
        legal_values = set(data[column].cat.categories)
        illegal_values = [val for val in col_values if val not in legal_values]
        if illegal_values:
            raise ValueError(
                f"Illegal value(s) in parameter {column}: {illegal_values}"
            )
        clean_selection[column] = col_values
    return clean_selection


def select_data(all_data, **selection):
    """Return the rows of ``all_data`` that match every given column selection.

    Parameters
    ----------
    all_data : pandas.DataFrame
        Frame to filter.
    **selection
        ``column=value`` or ``column=[values]`` pairs. The allowed column names
        are listed in ``select_columns`` below; validation is delegated to
        ``_check_selection``.

    Returns
    -------
    pandas.DataFrame
        Matching rows, passed through ``clean_up_categoricals`` and re-indexed
        from 0. Columns without a selection do not restrict the result.
    """
    select_columns = [
        "DayOfWeek",
        "Station",
        "Variable",
        "Month",
        "TimeSlot",
        "Hour",
        "Time",
        "TimeSlot_cat",
        "StationSprache",
        "Code",
    ]
    selection = _check_selection(all_data, selection, allowed_columns=select_columns)
    # Build the mask on the frame's own index: a mask with a fresh RangeIndex
    # would be misaligned by `&=` and rejected by `.loc` as soon as `all_data`
    # carries any other index (e.g. after filtering or sorting).
    row_mask = pd.Series(True, index=all_data.index, dtype=bool)
    for col in select_columns:
        if selection[col] is not None:
            row_mask &= all_data[col].isin(selection[col])
    return all_data.loc[row_mask].pipe(clean_up_categoricals).reset_index(drop=True)

In [None]:
# Example: workday rows for station "Aarau", variable "md_ek", restricted to two codes.
select_data(
    ax_data,
    DayOfWeek=workdays,
    Station="Aarau",
    Variable="md_ek",
    Code=["Zwischen 9'001 und 12'000 CHF", "Mehr als 12'000 CHF"],
)

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

sns.set()

In [None]:
# Same selection as the example above, but for station "Zürich HB"; kept for plotting below.
subset = select_data(
    ax_data,
    DayOfWeek=workdays,
    Station="Zürich HB",
    Variable="md_ek",
    Code=["Zwischen 9'001 und 12'000 CHF", "Mehr als 12'000 CHF"],
)

In [None]:
# Bar chart of the summed "Value" per "Time" interval for the selected subset.
fig, ax_example = plt.subplots(
    nrows=1, ncols=1, sharey="all", figsize=(12, 8)
)
# estimator=np.sum aggregates all rows of a Time interval; ci=None disables error bars.
sns.barplot(
    data=subset,
    x="Time",
    y="Value",
    ci=None,
    estimator=np.sum,
    ax=ax_example,
)
# Rotate the x tick labels so the time intervals stay readable.
fig.autofmt_xdate(rotation=90, ha="left")

## Display selection as pivot table (Codes vs. Time) and heatmap

In [None]:
from scipy.stats import chi2_contingency


def _cont_table(var, data, index_by, aggfunc):
    """Pivot ``Value`` with ``index_by`` as rows and ``Code`` as columns.

    The columns are limited to those codes of ``var`` (as returned by
    ``var_codes``) that occur in ``data["Code"]``, in the order ``var_codes``
    returns them. Missing cells are filled with 0.
    """
    known_codes = pd.Series(var_codes(var))
    is_observed = known_codes.isin(data["Code"].unique())
    observed_codes = known_codes.loc[is_observed].values

    pivoted = data.pivot_table(
        values="Value",
        index=index_by,
        columns="Code",
        aggfunc=aggfunc,
        fill_value=0,
        margins=False,
    )
    return pivoted.loc[:, observed_codes]


def _show_chisq(var, actual, counts, resid_type):
    """Draw a heatmap of the deviations of ``actual`` from its chi-square expectation.

    Parameters
    ----------
    var : str
        Variable name; used for the plot title (via ``var_label``).
    actual : pandas.DataFrame
        Contingency table handed to ``scipy.stats.chi2_contingency``.
    counts
        Accepted for interface compatibility; not used by this function.
    resid_type : {"absolute", "percent"}
        "absolute" plots ``actual - expected``; "percent" plots the relative
        deviation ``actual / expected * 100 - 100``.

    Raises
    ------
    ValueError
        If ``resid_type`` is neither "absolute" nor "percent".
    """
    test_result = chi2_contingency(actual)
    p_val, expected = test_result[1], test_result[3]

    if resid_type not in ("absolute", "percent"):
        raise ValueError(
            f"Parameter resid_type not in ('absolute', 'percent') ('{resid_type}')"
        )
    if resid_type == "absolute":
        residuals = actual - expected
    else:
        residuals = (actual / expected * 100) - 100

    # Scale the figure with the table: two units per column, one per row.
    n_rows, n_cols = residuals.shape
    plt.figure(figsize=(n_cols * 2, n_rows))
    plt.title(f"Abweichung: {var_label(var)} ({var}), p={round(p_val, 4)}", pad=12)
    sns.heatmap(
        data=residuals,
        center=0,
        annot=residuals.round(1).values,
        linewidths=0,
        robust=True,
        fmt=".1f",
    )


def show_code_cont_tables(
    selection,
    index_by,
    aggfunc="sum",
    show_agg=True,
    show_normal=False,
    show_chisq=False,
    chisq_resid="absolute",
):
    """Show contingency tables (``index_by`` vs. ``Code``) for each selected variable.

    Parameters
    ----------
    selection : dict
        Keyword arguments for ``select_data``; applied to the global ``ax_data``.
    index_by : str or list of str
        Column(s) used as the row index of the tables.
    aggfunc : str or callable
        Aggregation applied to ``Value`` in each table cell.
    show_agg : bool
        Display the aggregated table, rounded to one decimal.
    show_normal : bool
        Display the table after ``normalize_rows``, scaled to percent.
    show_chisq : bool
        Plot the chi-square deviations via ``_show_chisq``.
    chisq_resid : str
        Passed to ``_show_chisq`` as ``resid_type``.
    """
    selected = select_data(ax_data, **selection)
    for var, var_data in selected.groupby("Variable"):
        # Aggregated values plus plain row counts for the same cells.
        prop_table = _cont_table(var, var_data, index_by, aggfunc)
        count_table = _cont_table(var, var_data, index_by, aggfunc="size")

        # The text header is only printed when a table follows it.
        if show_agg or show_normal:
            print(f"Variable: {var_label(var)} ({var})")
            criteria = [f"{col} = {values}" for col, values in selection.items()]
            print("Selection: " + ", ".join(criteria))
        if show_agg:
            display(prop_table.round(1))
        if show_normal:
            print("Code percentages:")
            display(normalize_rows(prop_table).round(3) * 100)
        if show_chisq:
            _show_chisq(
                var,
                actual=prop_table,
                counts=count_table,
                resid_type=chisq_resid,
            )

In [None]:
def _per_quarter(s):
    return s.sum() / s.size


def _extrapolate(s):
    """Return ``_per_quarter(s)`` scaled by the fixed factor 4250."""
    # NOTE(review): what 4250 stands for is not visible here -- confirm with the author.
    scale_factor = 4250
    return _per_quarter(s) * scale_factor


# Workday tables for "md_ek" at Zürich HB by time slot: summed values plus
# a heatmap of the absolute chi-square deviations.
show_code_cont_tables(
    selection=dict(
        DayOfWeek=workdays, Station="Zürich HB", Variable=["md_ek"],
    ),
    index_by=["TimeSlot"],
    aggfunc="sum",
    show_chisq=True,
    chisq_resid="absolute",
)

## Estimate cell median and confidence intervals (quantiles)

In [None]:
from scipy.stats import chi2_contingency


def var_factor_dependance(data, variable, factor, partitions):
    """Chi-square test of ``Code`` vs. ``factor`` for one variable, per partition.

    Parameters
    ----------
    data : pandas.DataFrame
        Frame with at least the columns ``Variable``, ``Code``, ``factor`` and
        the ``partitions`` column(s).
    variable : str
        Value of the ``Variable`` column to analyse.
    factor : str
        Column crossed against ``Code`` in the contingency table.
    partitions : str or list of str
        Column(s) to group by; one test is run per group.

    Returns
    -------
    pandas.DataFrame
        One row per partition with the columns ``Var``, ``Label`` (the group
        key) and ``p_<factor>`` (the p-value from ``chi2_contingency``).
    """
    subset = data.loc[data.Variable == variable].pipe(clean_up_categoricals)
    p_column = f"p_{factor}"
    # Collect the rows first and build the frame once: DataFrame.append was
    # removed in pandas 2.0 and copied the whole frame on every call.
    result_rows = []
    for label, partition in subset.groupby(partitions, observed=True):
        # Row counts per (Code, factor) combination within this partition.
        contingency_tab = partition.pivot_table(
            index="Code", columns=factor, aggfunc="size", fill_value=0, observed=True
        )
        p_factor = chi2_contingency(contingency_tab)[1]
        result_rows.append({"Var": variable, "Label": label, p_column: p_factor})
    return pd.DataFrame(result_rows, columns=["Var", "Label", p_column])

In [None]:
def quartiles(s):
    """Return ``[q25, median, q75]`` of ``s``, or ``-1`` if they cannot be computed.

    The ``-1`` fallback keeps the function usable as a best-effort aggregation
    function. Only ordinary errors are swallowed; ``KeyboardInterrupt`` and
    ``SystemExit`` propagate.
    """
    try:
        (q25, med, q75) = np.percentile(s, [25, 50, 75])
    except Exception:
        return -1
    # Alternative (not used): Tukey's outlier limits
    # [q25 - 1.5 * iqr, med, q75 + 1.5 * iqr] with iqr = q75 - q25.
    return [q25, med, q75]


def log_ci(s):
    """Return ``[lo, mid, hi]`` = mean -/+ one standard deviation of ``s`` in log space.

    The values are transformed with ``log1p``; mean and standard deviation are
    taken there, and the three bounds are mapped back with ``expm1``. Returns
    ``-1`` if the computation fails. Only ordinary errors are swallowed;
    ``KeyboardInterrupt`` and ``SystemExit`` propagate.
    """
    try:
        log_series = np.log1p(s)
        log_mean = np.mean(log_series)
        log_std = np.std(log_series)
        bounds = (log_mean - log_std, log_mean, log_mean + log_std)
        return list(np.expm1(bounds))
    except Exception:
        return -1

In [None]:
# Workday data for "md_ek" across all stations in stations_d.
selection = dict(DayOfWeek=workdays, Variable=["md_ek"], Station=stations_d,)

# Each table cell holds the [lo, mid, hi] list returned by log_ci for that cell.
show_code_cont_tables(selection, index_by="TimeSlot", aggfunc=log_ci)

## Visualize estimations with confidence intervals

In [None]:
def plot_estimates(
    data,
    target,
    time_scale,
    plot_kind,
    plot_ci=100,
    heatmap_aggr="sum",
    heatmap_display_table=False,
    heatmap_norm=None,
):
    """Plot ``target`` per code and ``time_scale`` value, one figure per variable.

    Parameters
    ----------
    data : pandas.DataFrame
        Data to plot; grouped by its ``Variable`` column.
    target : str
        Column holding the plotted values.
    time_scale : str
        Column used as the facet column (box/bar/point) or row index (heatmap).
    plot_kind : {"box", "bar", "point", "heatmap"}
        Kind of plot to draw.
    plot_ci
        Passed as ``ci`` to the bar and point plots.
    heatmap_aggr : str or callable
        Aggregation function for the heatmap pivot table.
    heatmap_display_table : bool
        If true, also display the heatmap values as a table with a "Total" margin.
    heatmap_norm : {None, "Time", "Code", "both"}
        Heatmap normalisation in percent: by row, by column, or by grand total.

    Raises
    ------
    ValueError
        For an unknown ``plot_kind`` (raised while processing the first
        variable) or an unknown ``heatmap_norm`` (heatmap only).
    """

    def _calc_heatmap(data, columns, **kwargs):
        """Pivot ``target`` (rows: ``time_scale``, columns: ``Code``) and normalise it."""
        table = data.pivot_table(
            values=target,
            index=time_scale,
            columns="Code",
            aggfunc=heatmap_aggr,
            fill_value=0,
            observed=True,
            **kwargs,
        )
        if heatmap_norm == "Time":
            table = normalize_rows(table) * 100
        elif heatmap_norm == "Code":
            table = normalize_cols(table) * 100
        elif heatmap_norm == "both":
            table = table / table.values.sum() * 100
        elif heatmap_norm is not None:
            raise ValueError(
                f"Parameter heatmap_norm must be one of [None, 'Time', 'Code', 'both'], is '{heatmap_norm}'"
            )
        # Keep only the requested columns that exist, in the requested order.
        present_cols = [col for col in columns if col in table.columns]
        return table.loc[:, present_cols]

    def _facet_grid(frame):
        """Create the facet grid shared by the box, bar and point plots."""
        return sns.FacetGrid(frame, col=time_scale, col_wrap=3, height=4, aspect=1.5)

    for var, data_subset in data.groupby("Variable"):
        if plot_kind in ("box", "bar", "point"):
            grid = _facet_grid(data_subset)
            if plot_kind == "box":
                grid.map(sns.boxenplot, "Code", target, order=var_codes(var))
            elif plot_kind == "bar":
                grid.map(
                    sns.barplot, "Code", target, order=var_codes(var), ci=plot_ci
                )
            else:
                grid.map(
                    sns.pointplot,
                    "Code",
                    target,
                    order=var_codes(var),
                    ci=plot_ci,
                    join=False,
                    errwidth=2,
                    capsize=0.1,
                )
            grid.fig.tight_layout(w_pad=1)
        elif plot_kind == "heatmap":
            code_order = var_codes(var)
            plot_values = _calc_heatmap(data=data_subset, columns=code_order)
            # Scale the figure with the table: two units per column, one per row.
            n_rows, n_cols = plot_values.shape
            plt.figure(figsize=(n_cols * 2, n_rows))
            plt.title(f"{heatmap_aggr} of {var_label(var)} ({var})", pad=12)
            sns.heatmap(
                data=plot_values,
                center=0,
                annot=plot_values.round(1).values,
                linewidths=0,
                robust=True,
                fmt=".1f",
            )
            if heatmap_display_table:
                table_with_totals = _calc_heatmap(
                    data=data_subset,
                    columns=code_order + ["Total"],
                    margins=True,
                    margins_name="Total",
                )
                display(table_with_totals.round(1))
        else:
            raise ValueError("Unknown plot_kind")

In [None]:
# Workday data for "md_ek" at station Lausanne.
selection = dict(Station="Lausanne", Variable=["md_ek"], DayOfWeek=workdays,)

# Point plot per time slot with ci=95. The heatmap_* arguments are only read
# when plot_kind="heatmap", so they have no effect on this call.
plot_estimates(
    data=select_data(ax_data, **selection),
    target="Value",
    time_scale="TimeSlot",
    plot_kind="point",
    plot_ci=95,
    heatmap_aggr="sum",
    heatmap_display_table=False,
    heatmap_norm="both",
)

## Chi-Square residuals

In [None]:
def _per_quarter(s):
    return s.sum() / s.size


def _extrapolate(s):
    """Return ``_per_quarter(s)`` scaled by the fixed factor 4250."""
    # NOTE(review): an earlier cell of this notebook defines the same function;
    # what 4250 stands for is not visible here -- confirm with the author.
    scale_factor = 4250
    return _per_quarter(s) * scale_factor


# Chi-square deviation heatmap only (show_agg=False suppresses the table)
# for "md_ek" on workdays at station Lausanne, by time slot.
show_code_cont_tables(
    selection=dict(DayOfWeek=workdays, Station="Lausanne", Variable=["md_ek"]),
    index_by=["TimeSlot"],
    aggfunc="sum",
    show_agg=False,
    show_chisq=True,
    chisq_resid="absolute",
)