In [None]:
import pandas as pd
import numpy as np
import os
import re
from tqdm import tqdm
from pyfixest.estimation import fepois
from enum import IntEnum

# =========================================================================
# 0. CONFIGURATION
# =========================================================================

class ControlMethod(IntEnum):
    """Choice of denominator used to build the leave-one-out size control."""

    # Control derived from the firm-wide monthly FTE total.
    COMPANY_WIDE = 1
    # Control derived from the firm's monthly FTE total within one seniority tier.
    COMPANY_SENIORITY = 2

# *** SELECT YOUR METHOD HERE ***
CONTROL_METHOD = ControlMethod.COMPANY_SENIORITY

# DATES
# PRE_START / EXP_END bound the monthly grid used when balancing the panel;
# REF_PERIOD is the month omitted as the reference level in the regression.
PRE_START = '2021-01-01'
EXP_END = '2025-09-01'
REF_PERIOD = '2022-10-01'

# PATHS
FILE_PATH = "./data/seniority_DWA_data_CLEAN/"
RESULTS_DIR = "./results"

# --- UPDATED PATHS BELOW ---
PATH_TO_FIRM_TOTALS = "./data/firm_level_totals/firm_month_totals.parquet"
PATH_TO_SENIORITY_TOTALS = "./data/firm_level_totals/firm_seniority_totals.parquet"
# ---------------------------

# exist_ok avoids the check-then-create race and is a no-op when the
# directory is already there.
os.makedirs(RESULTS_DIR, exist_ok=True)

# Sorted so input files are processed in a deterministic order.
files_seniority_DWA = sorted(os.listdir(FILE_PATH))

# Pulls the YYYY-MM-DD date out of coefficient names containing "[T.<date>]".
date_pattern = re.compile(r"\[T\.(\d{4}-\d{2}-\d{2})\]")

# =========================================================================
# 1. PRE-LOAD GLOBAL DATA
# =========================================================================

print("Loading Firm Totals Reference (Global Denominator)...")
# Only the read is guarded, so the "could not load" message is not printed
# for unrelated failures further down (e.g. a missing column).
try:
    df_totals = pd.read_parquet(PATH_TO_FIRM_TOTALS)
except Exception as e:
    print(f"❌ CRITICAL ERROR: Could not load firm totals from {PATH_TO_FIRM_TOTALS}. {e}")
    raise

df_totals['month'] = pd.to_datetime(df_totals['month'])

print("Calculating firm lifespans (Entry/Exit windows)...")
# First and last month in which each firm appears in the totals file.
firm_lifespans = df_totals.groupby('firm_id')['month'].agg(['min', 'max']).reset_index()
firm_lifespans.rename(columns={'min': 'firm_start', 'max': 'firm_end'}, inplace=True)

# Load Seniority Totals (Only needed if using Method 2)
df_seniority_totals = None
if os.path.exists(PATH_TO_SENIORITY_TOTALS):
    print("Loading Seniority Totals Reference (Method 2 Denominator)...")
    df_seniority_totals = pd.read_parquet(PATH_TO_SENIORITY_TOTALS)
    df_seniority_totals['month'] = pd.to_datetime(df_seniority_totals['month'])
elif CONTROL_METHOD == ControlMethod.COMPANY_SENIORITY:
    # Fail now rather than after the first (expensive) panel balancing step.
    raise ValueError("Method 2 selected but 'firm_seniority_totals.parquet' is missing.")
else:
    print(f"⚠️ Warning: Seniority totals file not found at {PATH_TO_SENIORITY_TOTALS}. Method 2 will fail if selected.")

# =========================================================================
# 2. SMART BALANCING FUNCTION
# =========================================================================

def balance_panel_smart(df, df_totals, firm_lifespans, start=None, end=None):
    """
    Balances the panel ONLY within the active lifespan of each firm.

    Every (firm_id, dwa_id, seniority) combination found in ``df`` is expanded
    to one row per month-start between ``start`` and ``end``, restricted to
    the months inside that firm's [firm_start, firm_end] window. Months with
    no observation in ``df`` get ``FTE = 0``.

    Parameters
    ----------
    df : DataFrame
        Task-level data with columns ``firm_id``, ``dwa_id``, ``seniority``,
        ``month`` and ``FTE``. It is NOT modified.
    df_totals : DataFrame
        Firm-month totals with columns ``firm_id``, ``month`` and
        ``firm_month_total_fte_all``.
    firm_lifespans : DataFrame
        One row per firm with columns ``firm_id``, ``firm_start``, ``firm_end``.
    start, end : optional
        Bounds of the monthly grid. Default to the module-level ``PRE_START``
        and ``EXP_END``.

    Returns
    -------
    DataFrame
        The balanced panel with ``firm_month_total_fte_all`` attached; rows
        without a matching firm-month total are dropped.
    """
    print("  > 1/5: Preparing Dates and Indices...")
    if start is None:
        start = PRE_START
    if end is None:
        end = EXP_END

    # Convert on a new frame so the caller's DataFrame is left untouched;
    # skip the copy entirely when the column is already datetime.
    if not pd.api.types.is_datetime64_any_dtype(df['month']):
        df = df.assign(month=pd.to_datetime(df['month']))

    # A fresh 0..n-1 index is required: the cross join below keys on the
    # index, and duplicated labels inherited from `df` would duplicate rows.
    unique_firm_tasks = (
        df[['firm_id', 'dwa_id', 'seniority']]
        .drop_duplicates()
        .reset_index(drop=True)
    )
    all_months = pd.date_range(start=start, end=end, freq='MS')

    # Efficient Cross Join
    print(f"  > 2/5: Creating Skeleton Grid ({len(unique_firm_tasks):,} firm-task pairs x {len(all_months)} months)...")
    index = pd.MultiIndex.from_product(
        [unique_firm_tasks.index, all_months],
        names=['_temp_idx', 'month']
    )
    skeleton = pd.DataFrame(index=index).reset_index()
    skeleton = skeleton.merge(unique_firm_tasks, left_on='_temp_idx', right_index=True).drop(columns=['_temp_idx'])

    print("  > 3/5: Applying Lifespan Filter (Dropping ghosts)...")
    # Inner join: firms absent from firm_lifespans are removed altogether.
    skeleton = skeleton.merge(firm_lifespans, on='firm_id', how='inner')
    skeleton = skeleton[
        (skeleton['month'] >= skeleton['firm_start']) &
        (skeleton['month'] <= skeleton['firm_end'])
    ]
    skeleton = skeleton.drop(columns=['firm_start', 'firm_end'])

    print("  > 4/5: Merging Task Data (Numerator)...")
    balanced = skeleton.merge(
        df,
        on=['firm_id', 'dwa_id', 'seniority', 'month'],
        how='left'
    )
    # Grid cells with no observation mean the task was not performed that month.
    balanced['FTE'] = balanced['FTE'].fillna(0)

    print("  > 5/5: Merging Firm Totals (Global Denominator)...")
    # Drop any total carried in by `df` so the reference table is the only source.
    if 'firm_month_total_fte_all' in balanced.columns:
        balanced = balanced.drop(columns=['firm_month_total_fte_all'])

    balanced = balanced.merge(df_totals, on=['firm_id', 'month'], how='left')
    balanced = balanced.dropna(subset=['firm_month_total_fte_all'])

    return balanced

# =========================================================================
# 3. MAIN EXECUTION LOOP
# =========================================================================

# For each input parquet file: balance the panel, build the leave-one-out
# control, run one firm-FE Poisson regression per (dwa_id, seniority) group,
# and write the month coefficients to a CSV in RESULTS_DIR.
for file_path in files_seniority_DWA:
    if not file_path.endswith('.parquet'):
        continue

    print(f"\nProcessing File: {file_path}")
    company_DWA_df = pd.read_parquet(os.path.join(FILE_PATH, file_path))

    # 1. Balance Smartly
    balanced_df = balance_panel_smart(company_DWA_df, df_totals, firm_lifespans)

    # 2. Create Controls
    # String form of the month; the regression treats it as a categorical.
    balanced_df['month_str'] = balanced_df['month'].dt.strftime('%Y-%m-%d')

    # -------------------------------------------------------
    # CONTROL METHOD LOGIC (Robust "Leave-One-Out")
    # -------------------------------------------------------
    # NOTE: `suffix` is set per method but the output filename below is built
    # from CONTROL_METHOD.name, not from this variable.
    suffix = ""

    if CONTROL_METHOD == ControlMethod.COMPANY_WIDE:
        # Control: Log(Total Firm Size - This Task Size + 1)
        val = (balanced_df["firm_month_total_fte_all"] - balanced_df["FTE"]).clip(lower=0)
        balanced_df["Z_control"] = np.log(val + 1)
        suffix = "_method_company"

    elif CONTROL_METHOD == ControlMethod.COMPANY_SENIORITY:
        if df_seniority_totals is None:
            raise ValueError("Method 2 selected but 'firm_seniority_totals.parquet' is missing.")

        print("  > Merging Seniority Totals for Method 2...")
        # Drop any seniority total carried in by the input file so the merge
        # does not produce _x/_y suffixed columns (same treatment the
        # balancing step gives the firm-wide total).
        if 'firm_seniority_total_fte' in balanced_df.columns:
            balanced_df = balanced_df.drop(columns=['firm_seniority_total_fte'])

        # Merge the TRUE seniority totals (calculated before filtering)
        balanced_df = balanced_df.merge(
            df_seniority_totals,
            on=['firm_id', 'month', 'seniority'],
            how='left'
        )
        # Note: If firm_seniority_total_fte is NaN (rare), it means the firm had 0 people in that seniority tier
        balanced_df['firm_seniority_total_fte'] = balanced_df['firm_seniority_total_fte'].fillna(0)

        # Control: Log(True Seniority Tier Size - This Task Size + 1)
        val = (balanced_df['firm_seniority_total_fte'] - balanced_df['FTE']).clip(lower=0)
        balanced_df["Z_control"] = np.log(val + 1)
        suffix = "_method_company_seniority"

    else:
        # Without this, Z_control would be missing and every regression
        # below would fail without any visible error.
        raise ValueError(f"Unsupported CONTROL_METHOD: {CONTROL_METHOD!r}")
    # -------------------------------------------------------

    # 3. Regression Loop
    results_list = []
    failed_tasks = []  # (dwa_id, seniority, error) for tasks whose model raised
    grouped = balanced_df.groupby(['dwa_id', 'seniority'])

    print(f"  > Starting Regressions for {len(grouped)} tasks...")

    for (dwa_id, seniority), subset in tqdm(grouped, desc="Regressing Tasks", leave=False):
        try:
            # Singleton Filter
            obs_count = subset.groupby('firm_id').size()
            valid_firms = obs_count[obs_count > 1].index

            if len(valid_firms) < 2:
                continue

            subset_clean = subset[subset['firm_id'].isin(valid_firms)]

            # Remove firms that NEVER perform this task (Separation Check)
            # A firm with sum(FTE) = 0 for this task contributes no variance to the FE estimator.
            firm_sums = subset_clean.groupby('firm_id')['FTE'].sum()
            active_firms = firm_sums[firm_sums > 0].index
            subset_clean = subset_clean[subset_clean['firm_id'].isin(active_firms)].copy()

            # Final safety check
            if subset_clean.empty or subset_clean['FTE'].sum() == 0:
                continue

            # Re-apply the two-firm minimum after the active-firm filter:
            # standard errors are clustered by firm, which needs >= 2 clusters.
            if subset_clean['firm_id'].nunique() < 2:
                continue

            # Run PPML (Poisson Pseudo Maximum Likelihood)
            model = fepois(
                fml = f"FTE ~ Z_control + i(month_str, ref='{REF_PERIOD}') | firm_id",
                data = subset_clean,
                vcov = {"CRV1": "firm_id"}
            )

            # Extract Results
            coefs = model.coef()
            se = model.se()

            # Keep only the month coefficients (names that carry a date).
            for name, beta in coefs.items():
                match = date_pattern.search(name)
                if match:
                    results_list.append({
                        'dwa_id': dwa_id,
                        'seniority': seniority,
                        'month': match.group(1),
                        'beta': beta,
                        'se': se.get(name, 0)
                    })

            # Add Ref Period
            results_list.append({
                'dwa_id': dwa_id, 'seniority': seniority,
                'month': REF_PERIOD, 'beta': 0.0, 'se': 0.0
            })

        except Exception as exc:
            # Best effort: a task that cannot be estimated (e.g. singular
            # matrix) is skipped, but recorded so real bugs stay visible.
            failed_tasks.append((dwa_id, seniority, repr(exc)))

    if failed_tasks:
        print(f"  ⚠️ {len(failed_tasks)} of {len(grouped)} tasks raised an error and were skipped.")
        for failed_dwa, failed_seniority, err in failed_tasks[:5]:
            print(f"     - dwa_id={failed_dwa}, seniority={failed_seniority}: {err}")

    # 4. Save Results
    if results_list:
        res_df = pd.DataFrame(results_list)

        # dynamic filename based on input
        # splitext strips only the trailing extension, unlike str.replace.
        base_name = os.path.splitext(os.path.basename(file_path))[0]
        out_name = f"results_{base_name}_{CONTROL_METHOD.name}.csv"
        out_path = os.path.join(RESULTS_DIR, out_name)

        print(f"✅ Saving {len(res_df)} coefficients to: {out_path}")
        res_df.to_csv(out_path, index=False)
    else:
        print(f"⚠️ No results generated for {file_path}")

print("\n*** All Regressions Completed ***")

Loading Firm Totals Reference (Global Denominator)...
Calculating firm lifespans (Entry/Exit windows)...
Loading Seniority Totals Reference (Method 2 Denominator)...

Processing File: seniority_1_data.parquet
  > 1/5: Preparing Dates and Indices...
  > 2/5: Creating Skeleton Grid (2,006,421 firm-task pairs x 57 months)...
  > 3/5: Applying Lifespan Filter (Dropping ghosts)...
  > 4/5: Merging Task Data (Numerator)...
  > 5/5: Merging Firm Totals (Global Denominator)...
  > Merging Seniority Totals for Method 2...
  > Starting Regressions for 1664 tasks...


Regressing Tasks:  23%|██▎       | 388/1664 [01:10<04:26,  4.79it/s]