In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import pyfixest as pf

In [2]:
from saturated import test_treatment_heterogeneity
from dgp import panel_dgp_stagg, generate_treatment_effect
from plotters import diag_plot

In [3]:
def generate_dgp(
    num_periods=30,
    cohort_specs=None,  # dict with keys for each cohort
    sigma_i=2,
    sigma_t=1,
    sigma_epsilon=1,
    num_units=20_000,
):
    # Default cohort specs if none provided
    if cohort_specs is None:
        cohort_specs = {
            "cohort1": {
                "effect_type": "concave",
                "start_time": 10,
                "max_effect": 1,
                "size": num_units // 3,
            },
            "cohort2": {
                "effect_type": "concave",
                "start_time": 15,
                "max_effect": 1,
                "size": num_units // 3,
            },
            "cohort3": {
                "effect_type": "concave",
                "start_time": 20,
                "max_effect": 1,
                "size": num_units // 3,
            },
        }

    # Extract lists for panel_dgp_stagg
    treatment_starts = [spec["start_time"] for spec in cohort_specs.values()]
    cohort_sizes = [spec["size"] for spec in cohort_specs.values()]

    # Generate treatment effects
    base_treatment_effects = [
        generate_treatment_effect(
            effect_type=spec["effect_type"],
            T=num_periods,
            T0=spec["start_time"],
            max_effect=spec["max_effect"],
        )
        for spec in cohort_specs.values()
    ]
    return base_treatment_effects

    # Run DGP
    dgp = panel_dgp_stagg(
        num_units=num_units,
        num_treated=cohort_sizes,
        num_periods=num_periods,
        treatment_start_cohorts=treatment_starts,
        base_treatment_effects=base_treatment_effects,
        sigma_unit=sigma_i,
        sigma_time=sigma_t,
        sigma_epsilon=sigma_epsilon,
    )

    return dgp

In [4]:
homog_specs = {
    f"cohort{i}": {
        "effect_type": "concave",
        "start_time": start,
        "max_effect": 1,
        "size": size,
    }
    for i, (start, size) in enumerate(zip([10, 15, 20], [2500, 5000, 2500]))
}

df_homog = generate_dgp(cohort_specs=homog_specs)


In [5]:
df_homog

[array([0.09531018, 0.18232156, 0.26236426, 0.33647224, 0.40546511,
        0.47000363, 0.53062825, 0.58778666, 0.64185389, 0.69314718,
        0.74193734, 0.78845736, 0.83290912, 0.87546874, 0.91629073,
        0.95551145, 0.99325177, 1.02961942, 1.06471074, 1.09861229]),
 array([0.12516314, 0.23638878, 0.33647224, 0.42744401, 0.51082562,
        0.58778666, 0.65924563, 0.725937  , 0.78845736, 0.84729786,
        0.90286771, 0.95551145, 1.00552187, 1.05314991, 1.09861229]),
 array([0.18232156, 0.33647224, 0.47000363, 0.58778666, 0.69314718,
        0.78845736, 0.87546874, 0.95551145, 1.02961942, 1.09861229])]

In [None]:

# Extract parameters needed for diagnostic plot
treatment_starts = [spec["start_time"] for spec in homog_specs.values()]
base_effects = [
    generate_treatment_effect(
        effect_type=spec["effect_type"],
        T=30,
        T0=spec["start_time"],
        max_effect=spec["max_effect"],
    )
    for spec in homog_specs.values()
]

# Run diagnostic plot
diag_plot(df_homog, treatment_starts, base_effects)

In [6]:
test_treatment_heterogeneity(df_homog)

            The following variables are collinear: ['C(rel_time, contr.treatment(base=-1.0))[T.inf]', 'C(rel_time, contr.treatment(base=-1.0))[T.-20.0]:cohort_dummy_15', 'C(rel_time, contr.treatment(base=-1.0))[T.-19.0]:cohort_dummy_15', 'C(rel_time, contr.treatment(base=-1.0))[T.-18.0]:cohort_dummy_15', 'C(rel_time, contr.treatment(base=-1.0))[T.-17.0]:cohort_dummy_15', 'C(rel_time, contr.treatment(base=-1.0))[T.-16.0]:cohort_dummy_15', 'C(rel_time, contr.treatment(base=-1.0))[T.15.0]:cohort_dummy_15', 'C(rel_time, contr.treatment(base=-1.0))[T.16.0]:cohort_dummy_15', 'C(rel_time, contr.treatment(base=-1.0))[T.17.0]:cohort_dummy_15', 'C(rel_time, contr.treatment(base=-1.0))[T.18.0]:cohort_dummy_15', 'C(rel_time, contr.treatment(base=-1.0))[T.19.0]:cohort_dummy_15', 'C(rel_time, contr.treatment(base=-1.0))[T.inf]:cohort_dummy_15', 'C(rel_time, contr.treatment(base=-1.0))[T.-20.0]:cohort_dummy_20', 'C(rel_time, contr.treatment(base=-1.0))[T.-19.0]:cohort_dummy_20', 'C(rel_time, contr.tr

8.319654181015903e-05