In [1]:
import numpy as np
import pandas as pd
import time

from csdid.att_gt import ATTgt
from diff_diff import CallawaySantAnna

# Seed NumPy's legacy global RNG; the np.random.* draws later in the notebook depend on it.
np.random.seed(123)


In [2]:
# Synthetic staggered-adoption panel similar to the Callaway & Sant'Anna examples.
# Re-seed here so that re-running this cell on its own rebuilds exactly the same data;
# every draw below goes through NumPy's legacy global RNG, in the order
# treatment cohorts -> unit fixed effects -> idiosyncratic noise.
np.random.seed(123)

n_units = 1_000_000
n_periods = 10
years = np.arange(2010, 2010 + n_periods)

# Balanced panel: one row per (id, year) pair, with id varying slowest.
unit_ids = np.arange(n_units)
panel = pd.MultiIndex.from_product([unit_ids, years], names=['id', 'year'])
df = panel.to_frame(index=False)

# Assign first treatment year or never treated (0).
treat_years = np.random.choice([0, 2012, 2014, 2016, 2018], size=n_units, p=[0.3, 0.2, 0.2, 0.2, 0.1])
# unit_ids is exactly 0..n_units-1, so the id column can index the per-unit arrays
# directly; this avoids a 10M-row merge and a 1M-entry Python dict lookup.
unit_pos = df['id'].to_numpy()
df['first_treat'] = treat_years[unit_pos]

# Treatment indicator: switched on from the first treatment year onward.
df['treated'] = (df['first_treat'] > 0) & (df['year'] >= df['first_treat'])

# Unit and time fixed effects + treatment effect.
unit_fe = np.random.normal(0, 1, size=n_units)
time_fe = dict(zip(years, np.linspace(-0.2, 0.2, n_periods)))
df['unit_fe'] = unit_fe[unit_pos]
df['time_fe'] = df['year'].map(time_fe)
true_tau = 1.5
noise = np.random.normal(0, 1, size=len(df))
# Keep this summation order: it fixes the floating-point result of y.
df['y'] = 2 + df['unit_fe'] + df['time_fe'] + true_tau * df['treated'].astype(int) + noise

print('Rows:', len(df), 'Units:', df['id'].nunique())


Rows: 10000000 Units: 1000000


In [3]:
# diff_diff (CallawaySantAnna): time one fit() call on the synthetic panel.
# The event-study aggregation is requested through `aggregate`, so it is part of the timed call.
cs = CallawaySantAnna()
start = time.perf_counter()
cs_results = cs.fit(
    df,
    unit='id',
    time='year',
    outcome='y',
    first_treat='first_treat',
    aggregate='event_study',
)
cs_time = time.perf_counter() - start

# Post-treatment ATT: unweighted mean of the event-study effects whose key is >= 0.
post_effects = [
    estimate['effect']
    for horizon, estimate in cs_results.event_study_effects.items()
    if horizon >= 0
]
cs_att = float(np.mean(post_effects))
print(
    f'diff_diff CallawaySantAnna time: {cs_time:.2f}s',
    f'diff_diff post ATT (avg): {cs_att:.3f}',
    sep='\n',
)


diff_diff CallawaySantAnna time: 4.96s
diff_diff post ATT (avg): 1.500


In [None]:
# csdid (ATTgt): the timed span covers constructing ATTgt and calling fit(est_method='dr').
# NOTE(review): aggte() below runs outside the timer, while the diff_diff cell requests its
# event-study aggregation inside the timed fit() call -- confirm the timings are comparable.
# NOTE(review): R's `did` spells this option 'nevertreated'; confirm csdid accepts
# 'never_treated' and does not silently fall back to another comparison group.
start = time.perf_counter()
att_out = ATTgt(
    data=df,
    idname='id',
    tname='year',
    gname='first_treat',
    yname='y',
    control_group='never_treated',
).fit(est_method='dr')
att_time = time.perf_counter() - start

# Dynamic aggregation of the fitted estimates.
agg = att_out.aggte(na_rm=True, typec='dynamic')

def extract_dynamic(agg_obj):
    """Return the ``'overall_att'`` entry of ``agg_obj.summ_attgt().atte``.

    ``agg_obj`` must expose ``summ_attgt()``; the object it returns must carry
    an ``atte`` mapping with an ``'overall_att'`` key.
    """
    summary = agg_obj.summ_attgt()
    overall = summary.atte['overall_att']
    return overall


# Pull the overall ATT out of the dynamic aggregation and report it next to the fit time.
att_egt = extract_dynamic(agg)
print(
    f'csdid ATTgt time: {att_time:.2f}s',
    f'csdid post ATT (avg): {att_egt:.3f}',
    sep='\n',
)





Overall summary of ATT's based on event-study/dynamic aggregation:
   ATT Std. Error [95.0%  Conf. Int.]  
1.4995     0.0019 1.4958       1.5032 *


Dynamic Effects:
    Event time  Estimate  Std. Error  [95.0% Simult.   Conf. Band   
0           -7   -0.0002      0.0049          -0.0098      0.0094   
1           -6    0.0033      0.0049          -0.0064      0.0130   
2           -5   -0.0007      0.0029          -0.0063      0.0050   
3           -4    0.0036      0.0029          -0.0021      0.0094   
4           -3   -0.0003      0.0022          -0.0047      0.0041   
5           -2   -0.0015      0.0024          -0.0062      0.0032   
6           -1    0.0022      0.0019          -0.0016      0.0060   
7            0    1.4987      0.0020           1.4948      1.5026  *
8            1    1.4984      0.0019           1.4946      1.5021  *
9            2    1.4982      0.0021           1.4941      1.5024  *
10           3    1.4991      0.0022           1.4949      1.5034  *
11  