In [None]:
from math import isnan
import os

import numpy as np
import pandas as pd
from pandas.plotting._matplotlib.style import get_standard_colors
import matplotlib.pyplot as plt
from patsy import dmatrices
import statsmodels.api as sm

from src.evb import get_evb_incidence, get_evb_sick_notes
from src.rki import get_consultations, get_grippeweb

In [None]:
# Weekly series re-indexed by calendar week ("Kalenderwoche"): model_did and
# plot_results slice and plot these frames by that date index.
evb = (
    get_evb_incidence()
    .set_index(keys='Kalenderwoche')
)
consultations = (
    get_consultations()
    .set_index(keys='Kalenderwoche')
)
grippeweb = (
    get_grippeweb()
    .set_index(keys='Kalenderwoche')
)
# Sick-note data is kept exactly as the loader returns it (no index set here).
mvz = get_evb_sick_notes()

In [None]:
def model_did(
    treatment, control,
    start, end, intervention,
    plot=False
):
    """Fit a difference-in-differences (DiD) OLS model around one intervention date.

    Parameters
    ----------
    treatment, control : pandas.DataFrame
        Frames indexed by ``Kalenderwoche`` (datetime-like); only the first
        column is modelled, renamed to ``incidence``.
    start, end : str or datetime-like
        Bounds of the analysis window (label-based ``.loc`` slice, inclusive).
    intervention : str or datetime-like
        Weeks strictly after this date are post-treatment; the intervention
        week itself belongs to the pre-treatment period.
    plot : bool
        If true, plot observed incidence and fitted values for each group,
        control on the left axis and treatment on a twin right axis.

    Returns
    -------
    statsmodels regression results for
    ``incidence ~ group_is_treatment * after_treatment``; the DiD effect is
    the ``group_is_treatment:after_treatment`` coefficient.
    """
    window = slice(pd.to_datetime(start), pd.to_datetime(end))

    def _extract(frame, is_treatment=False):
        # .assign returns a new frame instead of writing a column into a
        # .loc slice of the caller's data (which risks SettingWithCopyWarning).
        return (
            frame.rename(columns={frame.columns[0]: 'incidence'})
            .loc[window]
            .assign(group_is_treatment=1 if is_treatment else 0)
        )

    df = pd.concat([_extract(treatment, is_treatment=True), _extract(control)]).reset_index()
    df['after_treatment'] = (df['Kalenderwoche'] > pd.to_datetime(intervention)).astype(int)
    # In patsy `a * b` expands to `a + b + a:b`, i.e. the standard DiD design.
    y_train, X_train = dmatrices(
        'incidence ~ group_is_treatment * after_treatment', df, return_type='dataframe'
    )
    did = sm.OLS(endog=y_train, exog=X_train).fit()

    if plot:
        # fittedvalues is indexed like X_train, so it realigns with df even if
        # patsy dropped rows with missing incidence (those rows get NaN).
        df['prediction'] = did.fittedvalues

        grouped = df.set_index('Kalenderwoche').groupby('group_is_treatment')
        colors = get_standard_colors(num_colors=len(grouped))
        fig, base_ax = plt.subplots(figsize=(10, 5))
        # groupby sorts its keys: group 0 (control) is drawn on the left axis,
        # group 1 (treatment) on the twin right axis.
        axes = [base_ax, base_ax.twinx()]
        legend_locs = ['upper left', 'upper right']  # keep the two legends apart

        for (group, group_df), color, group_ax, legend_loc in zip(grouped, colors, axes, legend_locs):
            name = 'Treatment' if group == 1 else 'Control'
            group_df['incidence'].plot(ax=group_ax, label=name, color=color)
            group_df['prediction'].plot(ax=group_ax, label='mean', color=color, alpha=0.6)
            group_ax.set_ylabel(name)
            group_ax.legend(loc=legend_loc)

    return did

In [None]:
def plot_results(
    treatment, control,
    start='2020-W1', years = 3.5,
    test_range = 2, p_threshold = 0.05,
    plot_p = False, plot_interventions = True,
    save_fig = None
):
    """Scan week by week for DiD effects between treatment and control and plot them.

    For every scanned week ``w`` whose window fits inside the scan,
    ``model_did`` is fitted on ``w - test_range`` .. ``w + test_range`` weeks
    with ``w`` as the intervention. Each such week is shaded red (positive
    interaction coefficient) or green (otherwise). The opacity grows as the
    interaction p-value falls below ``p_threshold``.

    Parameters
    ----------
    treatment, control : pandas.DataFrame
        Frames indexed by ``Kalenderwoche``; their first column is modelled and plotted.
    start : str
        ISO week ``'YYYY-Www'``; the scan starts on that week's Monday.
    years : float
        Scan length; the number of weeks is ``int(52 * years)``.
    test_range : int
        Half-width in weeks of each model window. The first and last
        ``test_range`` weeks have no model and get NaN.
    p_threshold : float
        Significance level; weeks with ``p >= p_threshold`` get zero opacity.
    plot_p : bool
        Also plot the p-value series on a log axis in a separate figure.
    plot_interventions : bool
        Mark relaxed (``o``) and tightened (``^``) measures from
        ``data/interventions.tsv`` if that file exists.
    save_fig : str or path-like, optional
        If given, the main figure is saved there.

    Returns
    -------
    pandas.DataFrame indexed by ``Kalenderwoche`` with columns ``'p-Wert'``
    (interaction p-value) and ``'Parameter'`` (interaction coefficient).
    """
    interaction = 'group_is_treatment:after_treatment'
    start_date = pd.to_datetime(f'{start}-1', format='%G-W%V-%u')
    n_weeks = int(52 * years)
    weeks = [start_date + pd.Timedelta(i, 'W') for i in range(n_weeks)]

    models = [
        model_did(treatment, control, weeks[i - test_range], weeks[i + test_range], weeks[i])
        for i in range(test_range, n_weeks - test_range)
    ]

    def _padded(data):
        # NaN-pad the edge weeks without a model so values line up with `weeks`.
        padding = [np.nan] * test_range
        return padding + data + padding

    df = pd.DataFrame({
        'Kalenderwoche': weeks,
        'p-Wert': _padded([model.pvalues[interaction] for model in models]),
    }).set_index('Kalenderwoche')

    if plot_p:
        p_ax = df.plot(figsize=(10, 5), logy=True)
        p_ax.axhline(y=p_threshold, color='r', linestyle='-', label=f'{p_threshold * 100}%')
        p_ax.legend()

    df['Parameter'] = _padded([model.params[interaction] for model in models])

    # Shading opacity: 0 at p == p_threshold, growing with -log(p) relative to
    # the observed spread of -log(p). The raw ratio exceeds 1 when every p-value
    # is below the threshold, or is inf when the spread is 0. matplotlib rejects
    # alpha > 1, so clip into [0, 1]. NaN (edge weeks) survives the clip and is
    # skipped below.
    neg_log_p = -np.log(df['p-Wert'])
    alphas = (neg_log_p + np.log(p_threshold)) / (neg_log_p.max() - neg_log_p.min())
    alphas = alphas.clip(lower=0, upper=1)

    fig, ax = plt.subplots(figsize=(10, 5))
    for frame in (treatment, control):
        label = frame.columns[0]
        # Plot against the frame's own index. The data's weeks need not
        # coincide one-to-one with `weeks` (gaps, different weekday), and that
        # would otherwise raise an x/y length mismatch.
        series = frame.loc[df.index[0]:df.index[-1], label]
        ax.plot(series.index, series, label=label)
    ax.set_ylabel("Inzidenz")
    fig.legend()

    half_week = pd.Timedelta(3, 'D')
    for week, alpha in alphas.items():
        if isnan(alpha):
            continue
        ax.axvspan(
            week - half_week, week + half_week,
            alpha=alpha,
            color='red' if df.loc[week, 'Parameter'] > 0 else 'green'
        )

    # Fixed y position for intervention markers.
    # NOTE(review): presumably chosen to lie below the plotted incidence; confirm.
    intervention_y = -1000
    # Invisible point: the y-range reaches intervention_y whether or not any
    # intervention markers are drawn.
    ax.scatter(start_date, intervention_y, alpha=0)
    if plot_interventions:
        if os.path.isfile('data/interventions.tsv'):
            interventions = pd.read_csv('data/interventions.tsv', sep='\t')
            for measure_type, marker in (('relaxed', 'o'), ('tightened', '^')):
                subset = interventions[interventions['measure_type_manually'] == measure_type]
                ax.scatter(
                    pd.to_datetime(subset['timestamp']), [intervention_y] * len(subset),
                    marker=marker, color='gray', alpha=0.6
                )
        else:
            print('data/interventions.tsv not found, skipping')

    if save_fig:
        fig.savefig(save_fig)
    return df

In [None]:
# Keep the weekly p-values/coefficients under a descriptive name. Assigning to
# `_` would shadow IPython's last-output variable and stop IPython updating it.
evb_vs_consultations = plot_results(evb, consultations, save_fig='evb-consultations-interventions.svg')