In [None]:
"""
基于 UMe 茶饮销售数据，评估天气、节假日、促销等因素的因果影响
"""

import pandas as pd
import numpy as np
pd.options.display.max_columns = None

from datetime import datetime, timedelta
import clickhouse_connect
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
from sklearn.model_selection import train_test_split
import warnings; warnings.filterwarnings("ignore")

# 因果推断
from dowhy import CausalModel
import statsmodels.api as sm

# EconML（异质效应分析）
from econml.metalearners import TLearner
from econml.dml import CausalForestDML, LinearDML

# 可视化
import plotly.graph_objects as go
import plotly.io as pio
pio.renderers.default = "browser"  # 或者 "notebook" 如果你希望直接在 Jupyter 中查看


class FBRCausalInference:
    """FBR 因果推断分析引擎"""

    # ------------------------------------------------------------
    # 0. 初始化
    # ------------------------------------------------------------
    def __init__(self, ch_config: dict):
        self.ch_client = clickhouse_connect.get_client(**ch_config)
        self.scaler = StandardScaler()

    # ------------------------------------------------------------
    # 1. 数据抽取
    # ------------------------------------------------------------
    def load_integrated_data(self, start_date: str, end_date: str) -> pd.DataFrame:
        """拉取日 × 店铺销售聚合 + 类别特征"""
        sales_query = f"""
        WITH
            toDate(created_at_pt)          AS dt,
            toDayOfWeek(created_at_pt)     AS day_of_week,
            toHour(created_at_pt)          AS hour_of_day,
            substring(location_name, position(location_name,'-')+1, 2) AS state
        SELECT
            dt  AS date,
            location_id,
            location_name,
            state,
            day_of_week,
            hour_of_day,
            /* 指标 */
            countDistinct(order_id)                              AS order_count,
            sum(item_total_amt)                                  AS total_revenue,
            avg(item_total_amt)                                  AS avg_order_value,
            sum(item_discount)                                   AS total_discount,
            sum(item_discount>0)                                 AS discount_orders,
            countDistinct(customer_id)                           AS unique_customers,
            sum(is_loyalty)                                      AS loyalty_orders,
            sum(arrayExists(x->x='BOGO',assumeNotNull(campaign_names))) AS bogo_orders,
            countDistinct(category_name)                         AS category_diversity
        FROM dw.fact_order_item_variations
        WHERE
            created_at_pt >= '{start_date}'
            AND created_at_pt <= '{end_date}'
            AND pay_status = 'COMPLETED'
        GROUP BY
            date, location_id, location_name, state, day_of_week, hour_of_day
        ORDER BY
            date, location_id
        """
        sales_df = self.ch_client.query_df(sales_query)

        # —— 类别特征（示例）———————————————————————————
        category_query = f"""
        SELECT
            toDate(created_at_pt) AS date,
            location_id,
            category_name,
            count()               AS category_orders,
            sum(item_total_amt)   AS category_revenue
        FROM dw.fact_order_item_variations
        WHERE created_at_pt BETWEEN '{start_date}' AND '{end_date}'
              AND pay_status='COMPLETED'
        GROUP BY date, location_id, category_name
        """
        cat_df = self.ch_client.query_df(category_query)

        hot_categories = ['Milk Tea', 'Fruit Tea', 'Coffee', 'Snacks']
        for cat in hot_categories:
            tmp = cat_df[cat_df['category_name'] == cat].copy()
            tmp = tmp.rename(columns={
                'category_orders': f"{cat.lower().replace(' ','_')}_orders",
                'category_revenue': f"{cat.lower().replace(' ','_')}_revenue"
            })
            sales_df = sales_df.merge(
                tmp[['date','location_id',
                     f"{cat.lower().replace(' ','_')}_orders",
                     f"{cat.lower().replace(' ','_')}_revenue"]],
                on=['date','location_id'],
                how='left'
            )

        return sales_df.fillna(0)

    # ------------------------------------------------------------
    # 2. 衍生处理变量
    # ------------------------------------------------------------
    @staticmethod
    def create_treatment_variables(df: pd.DataFrame) -> pd.DataFrame:
        df = df.copy()
        df['has_promotion']      = (df['total_discount'] > 0).astype(int)
        df['promotion_intensity'] = df['total_discount'].astype(float) /(df['total_revenue'].astype(float)  + df['total_discount'].astype(float)  + 1e-3)

        df['has_bogo']    = (df['bogo_orders'] > 0).astype(int)
        df['is_weekend']  = df['day_of_week'].isin([6,7]).astype(int)
        df['is_member_day'] = (df['day_of_week'] == 3).astype(int)
        df['is_peak_hour'] = df.groupby('location_id')['order_count'].transform(lambda x: (x > x.quantile(.75)).astype(int))
        return df

    # ------------------------------------------------------------
    # 3. 促销效应分析（DoWhy + EconML）
    # ------------------------------------------------------------
    def analyze_promotion_effect(self, df: pd.DataFrame, method='both') -> dict:
        """
        分析促销效应
        method: 'dowhy', 'econml', 'both'
        """
        results = {}
        
        if method in ['dowhy', 'both']:
            dowhy_results = self._analyze_promotion_dowhy(df)
            results['DoWhy'] = dowhy_results
            
        if method in ['econml', 'both']:
            econml_results = self._analyze_promotion_econml(df)
            results['EconML'] = econml_results
            
        return results

    # ------------------------------------------------------------
    # 3.1 DoWhy 方法（修复版）
    # ------------------------------------------------------------
    def _analyze_promotion_dowhy(self, df: pd.DataFrame) -> dict:
        print("\n=== 促销活动因果效应分析 (DoWhy) ===\n")

        treatment = 'has_promotion'
        outcome   = 'total_revenue'
        confs     = ['day_of_week','unique_customers',
                     'category_diversity','loyalty_orders','is_weekend']

        # 数值化处理
        df = self._force_numeric(df, [treatment,outcome]+confs).dropna()

        # 因果图定义
        graph = """
        digraph {
            is_member_day -> has_promotion;
            has_promotion -> total_revenue;

            day_of_week -> {has_promotion total_revenue};
            unique_customers -> total_revenue;
            category_diversity -> total_revenue;
            loyalty_orders -> {has_promotion total_revenue};
            is_weekend -> {has_promotion total_revenue};
        }
        """

        model = CausalModel(df, treatment, outcome, graph)
        ident = model.identify_effect(proceed_when_unidentifiable=True)
        print("识别的因果效应:\n", ident)

        results = {}

        # ① PSM
        try:
            psm = model.estimate_effect(ident,
                     method_name="backdoor.propensity_score_matching")
            results['PSM'] = round(psm.value,2)
            print(f"倾向得分匹配估计: {psm.value:.2f}")
        except Exception as e:
            print("PSM 失败:", e)

        # ② 线性回归
        lr = model.estimate_effect(ident,
                     method_name="backdoor.linear_regression")
        results['LinearRegression'] = round(lr.value,2)
        print(f"线性回归估计: {lr.value:.2f}")

        # ③ 工具变量
        try:
            iv = model.estimate_effect(ident,
                method_name="iv.instrumental_variable",
                method_params={'iv_instrument_name':'is_member_day'})
            results['IV'] = round(iv.value,2)
            print(f"工具变量估计: {iv.value:.2f}")
        except Exception as e:
            print("IV 估计失败:", e)

        # ④ 反事实分析（修复版）
        self._counterfactual_dowhy_fixed(model, ident, df, lr,
                                         treatment=treatment, outcome=outcome)

        return results

    # ------------------------------------------------------------
    # 3.2 修复的反事实分析
    # ------------------------------------------------------------
    def _counterfactual_dowhy_fixed(self, model, ident, df, estimator, treatment, outcome):
        """修复版：使用已估计的模型进行反事实推断"""
        try:
            # 获取当前平均值
            y_actual = df[outcome].mean()
            
            # 使用估计器的ATE值进行反事实计算
            ate = estimator.value  # 这是估计的平均处理效应
            
            # 当前促销比例
            promo_rate = df[treatment].mean()
            
            # 反事实计算：
            # 1. 全部促销：当前营收 + ATE * (未促销的比例)
            y_all_promo = y_actual + ate * (1 - promo_rate)
            
            # 2. 全部不促销：当前营收 - ATE * (已促销的比例)
            y_no_promo = y_actual - ate * promo_rate
            
            print("\n=== 反事实推断 (基于线性回归 ATE) ===")
            print(f"促销覆盖率          : {promo_rate*100:.1f}%")
            print(f"估计的ATE          : ${ate:.2f}")
            print(f"当前平均营收        : ${y_actual:.2f}")
            print(f"全部促销 反事实营收 : ${y_all_promo:.2f}  (+{(y_all_promo/y_actual-1)*100:4.1f}%)")
            print(f"全部停促销反事实营收: ${y_no_promo:.2f}  ({(1-y_no_promo/y_actual)*100:4.1f}% ↓)")
            
            # 尝试获取更详细的模型信息
            if hasattr(estimator, 'estimator'):
                print("\n模型详情:")
                if hasattr(estimator.estimator, 'model'):
                    model_obj = estimator.estimator.model
                    if hasattr(model_obj, 'params'):
                        print(f"模型系数: {dict(model_obj.params)}")
                    if hasattr(model_obj, 'summary'):
                        print("\n回归模型摘要:")
                        print(model_obj.summary())
                        
        except Exception as e:
            print(f"反事实分析出错: {e}")
            print(f"错误详情: {type(e).__name__}")
            import traceback
            traceback.print_exc()

    # ------------------------------------------------------------
    # 3.3 EconML 方法
    # ------------------------------------------------------------
    def _analyze_promotion_econml(self, df: pd.DataFrame) -> dict:
        """使用EconML的CausalForest进行促销效应分析"""
        print("\n=== 促销活动因果效应分析 (EconML) ===\n")
        
        treatment = 'has_promotion'
        outcome = 'total_revenue'
        feature_cols = ['day_of_week', 'unique_customers', 'category_diversity',
                       'loyalty_orders', 'is_weekend']
        
        # 数据预处理
        df = df.copy()
        df[feature_cols + [outcome]] = df[feature_cols + [outcome]].apply(
            pd.to_numeric, errors='coerce')
        df = df.dropna(subset=feature_cols + [outcome])
        
        Y = df[outcome].values
        T = df[treatment].values.astype(int)  # 确保是整数类型
        X = df[feature_cols].values
        
        # 分割训练集和测试集
        X_tr, X_te, T_tr, T_te, Y_tr, Y_te = train_test_split(
            X, T, Y, test_size=0.2, random_state=42)
        
        # 训练CausalForest模型
        # 对于二值处理变量，使用 discrete_treatment=True
        cf = CausalForestDML(
            model_t='auto',  # 让模型自动选择合适的第一阶段模型
            model_y=RandomForestRegressor(n_estimators=200, max_depth=6),
            discrete_treatment=True,  # 指定处理变量是离散的
            n_estimators=500,  # 减少树的数量以加快训练
            min_samples_leaf=50,
            random_state=42
        )
        
        try:
            cf.fit(Y_tr, T_tr, X=X_tr)
            
            # 计算效应
            cate_te = cf.effect(X_te)
            ate = float(cate_te.mean())
            
            # 反事实分析
            rev_actual = Y_te.mean()
            
            # 对未促销的样本，计算如果促销的效应
            no_promo_idx = T_te == 0
            promo_effect_on_no_promo = cate_te[no_promo_idx].mean() if no_promo_idx.any() else 0
            
            # 对已促销的样本，计算如果不促销的效应（负效应）
            promo_idx = T_te == 1
            no_promo_effect_on_promo = -cate_te[promo_idx].mean() if promo_idx.any() else 0
            
            # 计算反事实场景
            promo_rate = T_te.mean()
            rev_all_promo = rev_actual + promo_effect_on_no_promo * (1 - promo_rate)
            rev_no_promo = rev_actual + no_promo_effect_on_promo * promo_rate
            
            print(f"CausalForest ATE: ${ate:.2f}")
            print(f"\n=== 反事实推断 (EconML) ===")
            print(f"当前平均营收        : ${rev_actual:.2f}")
            print(f"全部促销 反事实营收 : ${rev_all_promo:.2f}  (+{(rev_all_promo/rev_actual-1)*100:4.1f}%)")
            print(f"全部停促销反事实营收: ${rev_no_promo:.2f}  ({(1-rev_no_promo/rev_actual)*100:4.1f}% ↓)")
            
            results = {
                'CausalForest_ATE': ate,
                'Actual_Revenue': rev_actual,
                'AllPromo_Revenue': rev_all_promo,
                'NoPromo_Revenue': rev_no_promo,
                'CATE_std': float(cate_te.std())
            }
            
            return results
            
        except Exception as e:
            print(f"EconML分析失败: {e}")
            print("尝试使用更简单的模型...")
            
            # 降级到线性DML模型
            ldml = LinearDML(
                model_t='auto',
                model_y=RandomForestRegressor(n_estimators=100, max_depth=5),
                discrete_treatment=True,
                random_state=42
            )
            
            ldml.fit(Y_tr, T_tr, X=X_tr)
            ate = ldml.ate(X_te)
            
            print(f"\nLinearDML ATE: ${ate:.2f}")
            
            return {
                'LinearDML_ATE': ate,
                'Actual_Revenue': Y_te.mean(),
                'Fallback': True
            }

    # ------------------------------------------------------------
    # 4. 天气效应（EconML TLearner）
    # ------------------------------------------------------------
    def analyze_weather_effect(self, sales_df, weather_df):
        print("\n=== 天气因果效应分析 (EconML TLearner) ===\n")

        merged = sales_df.merge(weather_df, on=['date','state'], how='left')
        merged['is_hot'] = (merged['temperature_2m_max'] > 30).astype(int)

        Y = merged['total_revenue'].values
        T = merged['is_hot'].values
        X = merged[['day_of_week','unique_customers','category_diversity']].values

        t_learner = TLearner(models=RandomForestRegressor(n_estimators=200,
                                                          random_state=42))
        t_learner.fit(Y, T, X=X)
        cate = t_learner.effect(X)

        print(f"高温天气平均因果效应: ${cate.mean():.2f} (±{cate.std():.2f})")
        return merged, cate

    # ------------------------------------------------------------
    # 5. 小工具
    # ------------------------------------------------------------
    @staticmethod
    def _force_numeric(df, cols):
        """把 object / Decimal 全转 float64，UInt* 转 int64"""
        out = df.copy()
        for c in cols:
            if c in out.columns:
                out[c] = pd.to_numeric(out[c], errors='coerce')
        return out

    # ------------------------------------------------------------
    # 6. 可视化
    # ------------------------------------------------------------
    @staticmethod
    def visualize_causal_effects(effects: dict):
        """可视化因果效应对比"""
        if 'DoWhy' in effects and 'EconML' in effects:
            # 组合显示
            dowhy_vals = effects['DoWhy']
            econml_res = effects['EconML']
            
            if 'Fallback' in econml_res and econml_res['Fallback']:
                methods = list(dowhy_vals.keys()) + ['LinearDML']
                values = list(dowhy_vals.values()) + [econml_res.get('LinearDML_ATE', 0)]
            else:
                methods = list(dowhy_vals.keys()) + ['CausalForest']
                values = list(dowhy_vals.values()) + [econml_res.get('CausalForest_ATE', 0)]
            
            fig = go.Figure(go.Bar(
                x=methods,
                y=values,
                text=[f"${v:.0f}" for v in values],
                textposition='auto'
            ))
            fig.update_yaxes(title='平均因果效应 ($)')
            fig.update_layout(title='促销因果效应估计对比 (DoWhy vs EconML)')
            
        elif 'DoWhy' in effects:
            fig = go.Figure(go.Bar(
                x=list(effects['DoWhy'].keys()),
                y=list(effects['DoWhy'].values()),
                text=[f"${v:.0f}" for v in effects['DoWhy'].values()],
                textposition='auto'
            ))
            fig.update_yaxes(title='平均因果效应 ($)')
            fig.update_layout(title='促销因果效应估计对比 (DoWhy)')
            
        else:
            # EconML only
            econml_res = effects['EconML']
            if 'Fallback' in econml_res and econml_res['Fallback']:
                ate = econml_res.get('LinearDML_ATE', 0)
                title = '促销因果效应 (LinearDML)'
                label = 'LinearDML ATE'
            else:
                ate = econml_res.get('CausalForest_ATE', 0)
                title = '促销因果效应 (CausalForest)'
                label = 'CausalForest ATE'
                
            fig = go.Figure(go.Bar(
                x=[label],
                y=[ate],
                text=[f"${ate:.0f}"],
                textposition='auto'
            ))
            fig.update_yaxes(title='平均因果效应 ($)')
            fig.update_layout(title=title)
            
        return fig

    @staticmethod
    def visualize_counterfactual_scenarios(results: dict):
        """可视化反事实场景对比"""
        scenarios = ['当前实际', '全部促销', '全部不促销']
        
        if 'EconML' in results:
            econml_res = results['EconML']
            
            # 检查是否是 fallback 模式
            if 'Fallback' in econml_res and econml_res['Fallback']:
                print("使用 LinearDML 结果，反事实场景可视化不可用")
                return None
                
            values = [
                econml_res.get('Actual_Revenue', 0),
                econml_res.get('AllPromo_Revenue', 0),
                econml_res.get('NoPromo_Revenue', 0)
            ]
        else:
            # 从DoWhy结果推算（简化版）
            return None
            
        fig = go.Figure(go.Bar(
            x=scenarios,
            y=values,
            text=[f"${v:.0f}" for v in values],
            textposition='auto',
            marker_color=['blue', 'green', 'red']
        ))
        
        fig.update_yaxes(title='平均营收 ($)')
        fig.update_layout(title='反事实场景分析')
        
        return fig


# ──────────────────────────────────────────────
# 使用示例
# ──────────────────────────────────────────────
if __name__ == "__main__":
    CLICKHOUSE_CONFIG = dict(
        host="clickhouse-0-0.umetea.net",
        port=443,
        database="dw",
        user="ml_ume",
        password="hDAoDvg8x552bH",
        verify=False,
    )

    ci = FBRCausalInference(CLICKHOUSE_CONFIG)

    # 加载数据
    start_date, end_date = "2025-06-01", "2025-07-31"
    sales = ci.load_integrated_data(start_date, end_date)
    sales = ci.create_treatment_variables(sales)

    # 方法1：分析促销效应（使用两种方法）
    print("="*60)
    print("运行完整分析（DoWhy + EconML）...")
    print("="*60)
    promo_results = ci.analyze_promotion_effect(sales, method='both')
    
    # 可视化结果
    fig1 = ci.visualize_causal_effects(promo_results)
    fig1.show()
    
    # 可视化反事实场景
    fig2 = ci.visualize_counterfactual_scenarios(promo_results)
    if fig2:
        fig2.show()
        
    # 方法2：仅使用 EconML（推荐，更稳定）
    print("\n" + "="*60)
    print("仅运行 EconML 分析...")
    print("="*60)
    econml_only = ci.analyze_promotion_effect(sales, method='econml')
    fig3 = ci.visualize_causal_effects(econml_only)
    fig3.show()

In [1]:
"""
增强版 UMe 茶饮销售数据因果推断分析引擎
新增：天气、节假日、周末等外部因素的因果影响分析
"""

import pandas as pd
import numpy as np
pd.options.display.max_columns = None

from datetime import datetime, timedelta
import clickhouse_connect
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
from sklearn.model_selection import train_test_split
from typing import Dict, List, Any
import warnings; warnings.filterwarnings("ignore")

# 因果推断
from dowhy import CausalModel
import statsmodels.api as sm

# EconML（异质效应分析）
from econml.metalearners import TLearner
from econml.dml import CausalForestDML, LinearDML

# 天气数据API
import requests
import json

# 节假日数据
import holidays

# 可视化
import plotly.graph_objects as go
import plotly.express as px
import plotly.io as pio
from plotly.subplots import make_subplots
pio.renderers.default = "browser"


class EnhancedFBRCausalInference:
    """增强版 FBR 因果推断分析引擎 - 包含天气和节假日分析"""

    # ------------------------------------------------------------
    # 0. 初始化
    # ------------------------------------------------------------
    def __init__(self, ch_config: dict, weather_api_key: str = None):
        self.ch_client = clickhouse_connect.get_client(**ch_config)
        self.scaler = StandardScaler()
        self.weather_api_key = weather_api_key

        # 美国节假日（可根据需要扩展到其他国家）
        self.us_holidays = holidays.US()

    # ------------------------------------------------------------
    # 1. 数据抽取（增强版）
    # ------------------------------------------------------------
    def load_integrated_data(self, start_date: str, end_date: str) -> pd.DataFrame:
        """拉取日 × 店铺销售聚合 + 类别特征（增强版）"""
        sales_query = f"""
        WITH
            toDate(created_at_pt)          AS dt,
            toDayOfWeek(created_at_pt)     AS day_of_week,
            toHour(created_at_pt)          AS hour_of_day,
            substring(location_name, position(location_name,'-')+1, 2) AS state,
            toMonth(created_at_pt)         AS month,
            toDayOfMonth(created_at_pt)    AS day_of_month
        SELECT
            dt  AS date,
            location_id,
            location_name,
            state,
            day_of_week,
            hour_of_day,
            month,
            day_of_month,
            /* 基础指标 */
            countDistinct(order_id)                              AS order_count,
            sum(item_total_amt)                                  AS total_revenue,
            avg(item_total_amt)                                  AS avg_order_value,
            sum(item_discount)                                   AS total_discount,
            sum(item_discount>0)                                 AS discount_orders,
            countDistinct(customer_id)                           AS unique_customers,
            sum(is_loyalty)                                      AS loyalty_orders,
            sum(arrayExists(x->x='BOGO',assumeNotNull(campaign_names))) AS bogo_orders,
            countDistinct(category_name)                         AS category_diversity,

            /* 时段相关指标 */
            sum(if(hour_of_day BETWEEN 7 AND 10, 1, 0))         AS morning_orders,
            sum(if(hour_of_day BETWEEN 11 AND 14, 1, 0))        AS lunch_orders,
            sum(if(hour_of_day BETWEEN 15 AND 17, 1, 0))        AS afternoon_orders,
            sum(if(hour_of_day BETWEEN 18 AND 21, 1, 0))        AS evening_orders,

            /* 产品类型指标 */
            sum(if(category_name IN ('Milk Tea', 'Fruit Tea'), 1, 0)) AS cold_drink_orders,
            sum(if(category_name = 'Coffee', 1, 0))              AS hot_drink_orders,
            sum(if(category_name = 'Snacks', 1, 0))              AS food_orders
        FROM dw.fact_order_item_variations
        WHERE
            created_at_pt >= '{start_date}'
            AND created_at_pt <= '{end_date}'
            AND pay_status = 'COMPLETED'
        GROUP BY
            date, location_id, location_name, state, day_of_week, hour_of_day, month, day_of_month
        ORDER BY
            date, location_id
        """
        sales_df = self.ch_client.query_df(sales_query)

        # 聚合到日级别
        daily_agg = sales_df.groupby(['date', 'location_id', 'location_name', 'state']).agg({
            'order_count': 'sum',
            'total_revenue': 'sum',
            'avg_order_value': 'mean',
            'total_discount': 'sum',
            'discount_orders': 'sum',
            'unique_customers': 'sum',
            'loyalty_orders': 'sum',
            'bogo_orders': 'sum',
            'category_diversity': 'max',
            'morning_orders': 'sum',
            'lunch_orders': 'sum',
            'afternoon_orders': 'sum',
            'evening_orders': 'sum',
            'cold_drink_orders': 'sum',
            'hot_drink_orders': 'sum',
            'food_orders': 'sum',
            'day_of_week': 'first',
            'month': 'first',
            'day_of_month': 'first'
        }).reset_index()

        # 关键修复：立即转换数值类型
        numeric_cols = [
            'order_count', 'total_revenue', 'avg_order_value', 'total_discount',
            'discount_orders', 'unique_customers', 'loyalty_orders', 'bogo_orders',
            'category_diversity', 'morning_orders', 'lunch_orders', 'afternoon_orders',
            'evening_orders', 'cold_drink_orders', 'hot_drink_orders', 'food_orders',
            'day_of_week', 'month', 'day_of_month'
        ]

        for col in numeric_cols:
            if col in daily_agg.columns:
                daily_agg[col] = pd.to_numeric(daily_agg[col], errors='coerce')

        return daily_agg.fillna(0)

    # ------------------------------------------------------------
    # 2. 天气数据获取
    # ------------------------------------------------------------
    def get_weather_data(self, start_date: str, end_date: str, locations: pd.DataFrame) -> pd.DataFrame:
        """获取天气数据"""
        print("\n=== 获取天气数据 ===")

        weather_data_list = []

        # 获取各州的主要城市坐标（简化版，实际应该根据具体店铺位置）
        state_coords = {
            'CA': {'lat': 37.7749, 'lon': -122.4194, 'city': 'San Francisco'},  # 加州
            'IL': {'lat': 41.8781, 'lon': -87.6298, 'city': 'Chicago'},       # 伊利诺伊州
            'AZ': {'lat': 33.4484, 'lon': -112.0740, 'city': 'Phoenix'},      # 亚利桑那州
            'TX': {'lat': 29.7604, 'lon': -95.3698, 'city': 'Houston'},       # 德克萨斯州
        }

        unique_states = locations['state'].unique()

        for state in unique_states:
            if state not in state_coords:
                print(f"跳过未知州: {state}")
                continue

            coords = state_coords[state]
            weather_data = self._fetch_weather_api(
                start_date, end_date,
                coords['lat'], coords['lon'], state
            )

            if weather_data is not None:
                weather_data_list.append(weather_data)

        if weather_data_list:
            weather_df = pd.concat(weather_data_list, ignore_index=True)
            print(f"成功获取 {len(weather_df)} 条天气记录")
            return weather_df
        else:
            # 如果无法获取真实天气数据，生成模拟数据
            print("无法获取真实天气数据，生成模拟数据...")
            return self._generate_mock_weather_data(start_date, end_date, unique_states)

    def _fetch_weather_api(self, start_date: str, end_date: str, lat: float, lon: float, state: str) -> pd.DataFrame:
        """从天气API获取数据（使用Open-Meteo免费API）"""
        try:
            # 使用 Open-Meteo API（免费，无需API key）
            url = "https://archive-api.open-meteo.com/v1/archive"
            params = {
                'latitude': lat,
                'longitude': lon,
                'start_date': start_date,
                'end_date': end_date,
                'daily': [
                    'temperature_2m_max', 'temperature_2m_min', 'temperature_2m_mean',
                    'precipitation_sum', 'rain_sum', 'snowfall_sum',
                    'windspeed_10m_max', 'sunshine_duration'
                ],
                'timezone': 'America/Los_Angeles'
            }

            response = requests.get(url, params=params, timeout=30)

            if response.status_code == 200:
                data = response.json()

                weather_df = pd.DataFrame({
                    'date': pd.to_datetime(data['daily']['time']),
                    'state': state,
                    'temperature_max': data['daily']['temperature_2m_max'],
                    'temperature_min': data['daily']['temperature_2m_min'],
                    'temperature_mean': data['daily']['temperature_2m_mean'],
                    'precipitation': data['daily']['precipitation_sum'],
                    'rain': data['daily']['rain_sum'],
                    'snow': data['daily']['snowfall_sum'],
                    'wind_speed': data['daily']['windspeed_10m_max'],
                    'sunshine_hours': data['daily']['sunshine_duration']
                })

                return weather_df
            else:
                print(f"天气API请求失败: {response.status_code}")
                return None

        except Exception as e:
            print(f"天气数据获取失败: {e}")
            return None

    def _generate_mock_weather_data(self, start_date: str, end_date: str, states: list) -> pd.DataFrame:
        """生成模拟天气数据（用于演示）"""
        date_range = pd.date_range(start=start_date, end=end_date, freq='D')
        weather_data = []

        for state in states:
            # 根据州设置基础温度（简化）
            base_temp = {
                'CA': 22, 'IL': 15, 'AZ': 30, 'TX': 25
            }.get(state, 20)

            for date in date_range:
                # 生成带季节性的模拟天气数据
                day_of_year = date.timetuple().tm_yday
                seasonal_factor = np.sin(2 * np.pi * day_of_year / 365.25)

                temp_max = base_temp + 8 * seasonal_factor + np.random.normal(0, 3)
                temp_min = temp_max - np.random.uniform(5, 15)
                temp_mean = (temp_max + temp_min) / 2

                weather_data.append({
                    'date': date,
                    'state': state,
                    'temperature_max': temp_max,
                    'temperature_min': temp_min,
                    'temperature_mean': temp_mean,
                    'precipitation': max(0, np.random.exponential(2) - 1),
                    'rain': max(0, np.random.exponential(1.5) - 0.5),
                    'snow': 0 if state in ['CA', 'AZ', 'TX'] else max(0, np.random.exponential(0.5) - 2),
                    'wind_speed': np.random.uniform(5, 25),
                    'sunshine_hours': np.random.uniform(4, 12)
                })

        return pd.DataFrame(weather_data)

    # ------------------------------------------------------------
    # 3. 节假日和特殊日期数据
    # ------------------------------------------------------------
    def add_calendar_features(self, df: pd.DataFrame) -> pd.DataFrame:
        """添加日历特征（节假日、周末等）"""
        df = df.copy()
        df['date'] = pd.to_datetime(df['date'])

        # 基础日期特征
        df['is_weekend'] = df['date'].dt.dayofweek.isin([5, 6]).astype(int)
        df['is_monday'] = (df['date'].dt.dayofweek == 0).astype(int)
        df['is_friday'] = (df['date'].dt.dayofweek == 4).astype(int)

        # 节假日特征
        df['is_holiday'] = df['date'].apply(lambda x: x.date() in self.us_holidays).astype(int)
        df['is_holiday_week'] = df['date'].apply(
            lambda x: any((x + timedelta(days=i)).date() in self.us_holidays for i in range(-3, 4))
        ).astype(int)

        # 特殊节日
        df['is_valentine'] = ((df['date'].dt.month == 2) & (df['date'].dt.day == 14)).astype(int)
        df['is_christmas_season'] = ((df['date'].dt.month == 12) & (df['date'].dt.day >= 15)).astype(int)
        df['is_thanksgiving_week'] = df['date'].apply(self._is_thanksgiving_week).astype(int)
        df['is_summer'] = df['date'].dt.month.isin([6, 7, 8]).astype(int)
        df['is_winter'] = df['date'].dt.month.isin([12, 1, 2]).astype(int)

        # 学期特征（影响学生客流）
        df['is_school_term'] = df['date'].apply(self._is_school_term).astype(int)
        df['is_finals_week'] = df['date'].apply(self._is_finals_week).astype(int)

        # 发薪日特征（影响消费力）
        df['is_payday'] = df['date'].apply(self._is_payday).astype(int)

        return df

    def _is_thanksgiving_week(self, date):
        """判断是否为感恩节周"""
        # 感恩节是11月的第四个星期四
        year = date.year
        november_first = datetime(year, 11, 1)
        days_to_thursday = (3 - november_first.weekday()) % 7
        first_thursday = november_first + timedelta(days=days_to_thursday)
        thanksgiving = first_thursday + timedelta(weeks=3)

        return abs((date - thanksgiving).days) <= 3

    def _is_school_term(self, date):
        """判断是否为学期内（简化版）"""
        month = date.month
        # 秋季学期: 9月-12月中旬，春季学期: 1月-5月
        return month in [1, 2, 3, 4, 5, 9, 10, 11] or (month == 12 and date.day <= 15)

    def _is_finals_week(self, date):
        """判断是否为期末考试周（简化版）"""
        month, day = date.month, date.day
        # 秋季期末: 12月第二周，春季期末: 5月第二周
        return (month == 12 and 8 <= day <= 15) or (month == 5 and 8 <= day <= 15)

    def _is_payday(self, date):
        """判断是否为发薪日（通常是每月15日和月末）"""
        day = date.day
        month_end = (date + timedelta(days=1)).day == 1  # 下一天是新月第一天
        return day == 15 or month_end

    # ------------------------------------------------------------
    # 4. 天气特征工程
    # ------------------------------------------------------------
    def add_weather_features(self, df: pd.DataFrame, weather_df: pd.DataFrame) -> pd.DataFrame:
        """添加天气特征"""
        if weather_df is None or len(weather_df) == 0:
            print("警告: 无天气数据，跳过天气特征")
            return df

        df = df.copy()
        weather_df = weather_df.copy()

        # 确保日期格式一致
        df['date'] = pd.to_datetime(df['date'])
        weather_df['date'] = pd.to_datetime(weather_df['date'])

        # 合并天气数据
        merged = df.merge(weather_df, on=['date', 'state'], how='left')

        # 填充缺失的天气数据
        merged = merged.fillna(method='ffill').fillna(method='bfill')

        # 天气分类特征
        merged['is_hot'] = (merged['temperature_max'] > 30).astype(int)  # 30°C以上为热天
        merged['is_cold'] = (merged['temperature_max'] < 10).astype(int)  # 10°C以下为冷天
        merged['is_mild'] = ((merged['temperature_max'] >= 15) & (merged['temperature_max'] <= 25)).astype(int)

        merged['is_rainy'] = (merged['precipitation'] > 2).astype(int)  # 2mm以上降水
        merged['is_heavy_rain'] = (merged['precipitation'] > 10).astype(int)  # 10mm以上大雨
        merged['is_snowy'] = (merged['snow'] > 0).astype(int)

        merged['is_sunny'] = (merged['sunshine_hours'] > 8).astype(int)  # 8小时以上日照
        merged['is_windy'] = (merged['wind_speed'] > 20).astype(int)  # 20km/h以上大风

        # 舒适度指数（简化版）
        merged['comfort_index'] = (
            (merged['temperature_mean'] - 20).abs() * (-0.1) +  # 距离20°C越远越不舒适
            merged['sunshine_hours'] * 0.1 -  # 日照时间越长越舒适
            merged['precipitation'] * 0.05 -  # 降水越多越不舒适
            merged['wind_speed'] * 0.02  # 风速越大越不舒适
        )

        # 温度变化特征
        merged = merged.sort_values(['state', 'date'])
        merged['temp_change'] = merged.groupby('state')['temperature_mean'].diff()
        merged['temp_volatility'] = merged.groupby('state')['temperature_mean'].rolling(7).std().values

        return merged.fillna(0)

    # ------------------------------------------------------------
    # 5. 增强的处理变量创建
    # ------------------------------------------------------------
    def create_enhanced_treatment_variables(self, df: pd.DataFrame) -> pd.DataFrame:
        """创建增强版处理变量"""
        df = df.copy()

        # 原有的促销处理变量
        df['has_promotion'] = (df['total_discount'] > 0).astype(int)
        df['promotion_intensity'] = df['total_discount'].astype(float) / (df['total_revenue'].astype(float) + df['total_discount'].astype(float) + 1e-3)
        df['has_bogo'] = (df['bogo_orders'] > 0).astype(int)

        # 时间相关处理变量
        df['is_member_day'] = (df['date'].dt.dayofweek == 2).astype(int)  # 周三会员日
        df['is_peak_day'] = ((df['date'].dt.dayofweek == 5) | (df['date'].dt.dayofweek == 6)).astype(int)  # 周五周六

        # 天气相关处理变量（如果有天气数据）
        if 'temperature_max' in df.columns:
            df['extreme_weather'] = (
                # (df['is_very_hot'] == 1) |
                # (df['is_very_cold'] == 1) |
                (df['is_heavy_rain'] == 1) |
                (df['is_snowy'] == 1)
            ).astype(int)

            df['good_weather'] = (
                (df['is_mild'] == 1) &
                (df['is_sunny'] == 1) &
                (df['is_rainy'] == 0)
            ).astype(int)

        # 竞争压力（基于营收表现）
        df['low_performance'] = df.groupby('location_id')['total_revenue'].transform(
            lambda x: (x < x.quantile(0.25))#.astype(int)
        )

        return df

    # ------------------------------------------------------------
    # 6. 多因素因果分析
    # ------------------------------------------------------------
    def analyze_multi_factor_effects(self, df: pd.DataFrame) -> Dict[str, Any]:
        """分析多因素（促销、天气、节假日）的因果效应"""
        print("\n=== 多因素因果效应分析 ===\n")

        results = {}

        # 1. 促销效应分析
        if 'has_promotion' in df.columns:
            promo_results = self._analyze_factor_effect(
                df, 'has_promotion', '促销活动',
                confounders=['is_weekend', 'is_holiday', 'day_of_week', 'unique_customers']
            )
            results['promotion'] = promo_results

        # 2. 天气效应分析
        if 'is_hot' in df.columns:
            weather_results = self._analyze_factor_effect(
                df, 'is_hot', '高温天气',
                confounders=['is_weekend', 'is_holiday', 'day_of_week', 'unique_customers']
            )
            results['hot_weather'] = weather_results

        if 'is_rainy' in df.columns:
            rain_results = self._analyze_factor_effect(
                df, 'is_rainy', '雨天',
                confounders=['is_weekend', 'is_holiday', 'day_of_week', 'temperature_mean']
            )
            results['rainy_weather'] = rain_results

        # 3. 节假日效应分析
        if 'is_holiday' in df.columns:
            holiday_results = self._analyze_factor_effect(
                df, 'is_holiday', '节假日',
                confounders=['day_of_week', 'unique_customers', 'has_promotion']
            )
            results['holiday'] = holiday_results

        # 4. 周末效应分析
        weekend_results = self._analyze_factor_effect(
            df, 'is_weekend', '周末',
            confounders=['is_holiday', 'unique_customers', 'has_promotion']
        )
        results['weekend'] = weekend_results

        return results

    def _analyze_factor_effect(self, df: pd.DataFrame, treatment: str, treatment_name: str,
                             confounders: List[str], outcome: str = 'total_revenue') -> Dict[str, Any]:
        """分析单个因素的因果效应"""
        print(f"\n--- 分析 {treatment_name} 的因果效应 ---")

        # 数据预处理
        analysis_cols = [treatment, outcome] + confounders
        clean_df = self._force_numeric(df, analysis_cols).dropna(subset=analysis_cols)

        if len(clean_df) < 50:
            print(f"警告: {treatment_name} 数据不足，跳过分析")
            return {'error': '数据不足'}

        # 使用EconML进行分析（更稳定）
        try:
            Y = clean_df[outcome].values
            T = clean_df[treatment].values.astype(int)
            X = clean_df[confounders].values

            # 分割数据
            X_tr, X_te, T_tr, T_te, Y_tr, Y_te = train_test_split(
                X, T, Y, test_size=0.2, random_state=42
            )

            # 使用LinearDML（更稳定）
            ldml = LinearDML(
                model_t='auto',
                model_y=RandomForestRegressor(n_estimators=100, max_depth=5),
                discrete_treatment=True,
                random_state=42
            )

            ldml.fit(Y_tr, T_tr, X=X_tr)
            ate = float(ldml.ate(X_te))

            # 计算处理组和控制组的基线
            treatment_group_revenue = clean_df[clean_df[treatment] == 1][outcome].mean()
            control_group_revenue = clean_df[clean_df[treatment] == 0][outcome].mean()
            observed_diff = treatment_group_revenue - control_group_revenue

            # 计算处理率
            treatment_rate = clean_df[treatment].mean()

            print(f"{treatment_name} 分析结果:")
            print(f"  因果效应 (ATE): ${ate:.2f}")
            print(f"  观察到的差异: ${observed_diff:.2f}")
            print(f"  处理率: {treatment_rate:.1%}")
            print(f"  样本量: {len(clean_df)}")

            return {
                'ate': ate,
                'observed_diff': observed_diff,
                'treatment_rate': treatment_rate,
                'sample_size': len(clean_df),
                'treatment_group_mean': treatment_group_revenue,
                'control_group_mean': control_group_revenue
            }

        except Exception as e:
            print(f"{treatment_name} 分析失败: {e}")
            return {'error': str(e)}

    # ------------------------------------------------------------
    # 7. 交互效应分析
    # ------------------------------------------------------------
    def analyze_interaction_effects(self, df: pd.DataFrame) -> Dict[str, Any]:
        """分析交互效应（如：雨天 × 促销活动）"""
        print("\n=== 交互效应分析 ===\n")

        interactions = {}

        # 1. 天气 × 促销交互效应
        if all(col in df.columns for col in ['is_rainy', 'has_promotion']):
            rain_promo = self._analyze_interaction(
                df, 'is_rainy', 'has_promotion', '雨天促销交互效应'
            )
            interactions['rain_promotion'] = rain_promo

        # 2. 高温 × 产品类型交互效应
        if 'is_hot' in df.columns:
            hot_cold_drinks = self._analyze_product_weather_interaction(
                df, 'is_hot', '高温天气对不同产品的影响'
            )
            interactions['hot_weather_products'] = hot_cold_drinks

        # 3. 节假日 × 周末交互效应
        if 'is_holiday' in df.columns:
            holiday_weekend = self._analyze_interaction(
                df, 'is_holiday', 'is_weekend', '节假日周末交互效应'
            )
            interactions['holiday_weekend'] = holiday_weekend

        return interactions

    def _analyze_interaction(self, df: pd.DataFrame, factor1: str, factor2: str,
                           interaction_name: str) -> Dict[str, Any]:
        """分析两个因素的交互效应"""
        print(f"\n--- {interaction_name} ---")

        # 创建交互项
        df_temp = df.copy()
        interaction_term = f"{factor1}_x_{factor2}"
        df_temp[interaction_term] = df_temp[factor1] * df_temp[factor2]

        # 计算各组合的平均营收
        results = {}
        for val1 in [0, 1]:
            for val2 in [0, 1]:
                mask = (df_temp[factor1] == val1) & (df_temp[factor2] == val2)
                group_revenue = df_temp[mask]['total_revenue'].mean()
                group_size = mask.sum()
                results[f"{factor1}_{val1}_{factor2}_{val2}"] = {
                    'revenue': group_revenue,
                    'count': group_size
                }

        # 计算交互效应
        # 交互效应 = (A=1,B=1的效应) - (A=1,B=0的效应) - (A=0,B=1的效应) + (A=0,B=0的效应)
        baseline = results[f"{factor1}_0_{factor2}_0"]['revenue']
        factor1_effect = results[f"{factor1}_1_{factor2}_0"]['revenue'] - baseline
        factor2_effect = results[f"{factor1}_0_{factor2}_1"]['revenue'] - baseline
        combined_effect = results[f"{factor1}_1_{factor2}_1"]['revenue'] - baseline

        interaction_effect = combined_effect - factor1_effect - factor2_effect

        print(f"  {factor1} 单独效应: ${factor1_effect:.2f}")
        print(f"  {factor2} 单独效应: ${factor2_effect:.2f}")
        print(f"  交互效应: ${interaction_effect:.2f}")

        return {
            'factor1_effect': factor1_effect,
            'factor2_effect': factor2_effect,
            'interaction_effect': interaction_effect,
            'group_details': results
        }

    def _analyze_product_weather_interaction(self, df: pd.DataFrame, weather_factor: str,
                                           analysis_name: str) -> Dict[str, Any]:
        """分析天气对不同产品类型的影响"""
        print(f"\n--- {analysis_name} ---")

        product_effects = {}
        product_cols = ['cold_drink_orders', 'hot_drink_orders', 'food_orders']

        for product in product_cols:
            if product in df.columns:
                weather_group = df[df[weather_factor] == 1][product].mean()
                normal_group = df[df[weather_factor] == 0][product].mean()
                effect = weather_group - normal_group

                product_effects[product] = {
                    'weather_avg': weather_group,
                    'normal_avg': normal_group,
                    'effect': effect,
                    'effect_pct': (effect / normal_group * 100) if normal_group > 0 else 0
                }

                print(f"  {product}: {effect:+.1f} ({effect/normal_group*100:+.1f}%)")

        return product_effects

    # ------------------------------------------------------------
    # 8. 增强的可视化
    # ------------------------------------------------------------
    def visualize_multi_factor_effects(self, results: Dict[str, Any]) -> go.Figure:
        """可视化多因素效应对比"""
        factors = []
        effects = []
        colors = []

        color_map = {
            'promotion': 'blue',
            'hot_weather': 'red',
            'rainy_weather': 'gray',
            'holiday': 'green',
            'weekend': 'orange'
        }

        for factor_name, result in results.items():
            if 'ate' in result:
                factors.append(factor_name.replace('_', ' ').title())
                effects.append(result['ate'])
                colors.append(color_map.get(factor_name, 'purple'))

        fig = go.Figure(go.Bar(
            x=factors,
            y=effects,
            text=[f"${v:.0f}" for v in effects],
            textposition='auto',
            marker_color=colors
        ))

        fig.update_layout(
            title='各因素对营收的因果效应对比',
            xaxis_title='影响因素',
            yaxis_title='平均因果效应 ($)',
            showlegend=False
        )

        return fig

    def visualize_weather_impact_by_product(self, interaction_results: Dict[str, Any]) -> go.Figure:
        """可视化天气对不同产品的影响"""
        if 'hot_weather_products' not in interaction_results:
            return None

        weather_effects = interaction_results['hot_weather_products']

        products = []
        effects_pct = []

        for product, data in weather_effects.items():
            products.append(product.replace('_', ' ').title())
            effects_pct.append(data['effect_pct'])

        fig = go.Figure(go.Bar(
            x=products,
            y=effects_pct,
            text=[f"{v:+.1f}%" for v in effects_pct],
            textposition='auto',
            marker_color=['red' if v > 0 else 'blue' for v in effects_pct]
        ))

        fig.update_layout(
            title='高温天气对不同产品销量的影响',
            xaxis_title='产品类型',
            yaxis_title='销量变化 (%)',
            showlegend=False
        )

        return fig

    def visualize_seasonal_patterns(self, df: pd.DataFrame) -> go.Figure:
        """可视化季节性模式"""
        df_temp = df.copy()
        df_temp['date'] = pd.to_datetime(df_temp['date'])
        df_temp['month'] = df_temp['date'].dt.month

        # 按月份聚合
        monthly_data = df_temp.groupby('month').agg({
            'total_revenue': 'mean',
            'order_count': 'mean',
            'cold_drink_orders': 'mean',
            'hot_drink_orders': 'mean'
        }).reset_index()

        fig = make_subplots(
            rows=2, cols=2,
            subplot_titles=['月度营收', '月度订单量', '冷饮订单', '热饮订单']
        )

        # 营收
        fig.add_trace(
            go.Scatter(x=monthly_data['month'], y=monthly_data['total_revenue'],
                      mode='lines+markers', name='营收'),
            row=1, col=1
        )

        # 订单量
        fig.add_trace(
            go.Scatter(x=monthly_data['month'], y=monthly_data['order_count'],
                      mode='lines+markers', name='订单量'),
            row=1, col=2
        )

        # 冷饮
        if 'cold_drink_orders' in monthly_data.columns:
            fig.add_trace(
                go.Scatter(x=monthly_data['month'], y=monthly_data['cold_drink_orders'],
                          mode='lines+markers', name='冷饮订单'),
                row=2, col=1
            )

        # 热饮
        if 'hot_drink_orders' in monthly_data.columns:
            fig.add_trace(
                go.Scatter(x=monthly_data['month'], y=monthly_data['hot_drink_orders'],
                          mode='lines+markers', name='热饮订单'),
                row=2, col=2
            )

        fig.update_layout(
            title='季节性销售模式分析',
            showlegend=False
        )

        return fig

    # ------------------------------------------------------------
    # 9. 综合洞察报告
    # ------------------------------------------------------------
    def generate_comprehensive_insights(self, df: pd.DataFrame,
                                      multi_factor_results: Dict[str, Any],
                                      interaction_results: Dict[str, Any]) -> str:
        """生成综合洞察报告"""
        insights = []
        insights.append("# UMe 茶饮销售因果分析洞察报告\n")

        # 1. 总体概览
        insights.append("## 📊 总体概览")
        avg_revenue = df['total_revenue'].mean()
        avg_orders = df['order_count'].mean()
        insights.append(f"- 平均日营收: ${avg_revenue:.0f}")
        insights.append(f"- 平均日订单: {avg_orders:.0f}单")
        insights.append(f"- 分析期间: {df['date'].min()} 至 {df['date'].max()}")
        insights.append("")

        # 2. 主要影响因素
        insights.append("## 🎯 主要影响因素")
        sorted_factors = sorted(
            [(k, v) for k, v in multi_factor_results.items() if 'ate' in v],
            key=lambda x: abs(x[1]['ate']), reverse=True
        )

        for i, (factor, result) in enumerate(sorted_factors[:5]):
            effect = result['ate']
            rate = result['treatment_rate']
            insights.append(f"{i+1}. **{factor.replace('_', ' ').title()}**: {effect:+.0f}$ (发生率: {rate:.1%})")

        insights.append("")

        # 3. 关键发现
        insights.append("## 🔍 关键发现")

        # 促销效应分析
        if 'promotion' in multi_factor_results and 'ate' in multi_factor_results['promotion']:
            promo_effect = multi_factor_results['promotion']['ate']
            if promo_effect > 0:
                insights.append(f"- 促销活动平均带来 ${promo_effect:.0f} 的营收提升")
            else:
                insights.append(f"- 促销活动可能存在负面效应，营收下降 ${abs(promo_effect):.0f}")

        # 天气效应分析
        if 'hot_weather' in multi_factor_results and 'ate' in multi_factor_results['hot_weather']:
            hot_effect = multi_factor_results['hot_weather']['ate']
            insights.append(f"- 高温天气对营收的影响: {hot_effect:+.0f}$")

        if 'rainy_weather' in multi_factor_results and 'ate' in multi_factor_results['rainy_weather']:
            rain_effect = multi_factor_results['rainy_weather']['ate']
            insights.append(f"- 雨天对营收的影响: {rain_effect:+.0f}$")

        # 节假日效应
        if 'holiday' in multi_factor_results and 'ate' in multi_factor_results['holiday']:
            holiday_effect = multi_factor_results['holiday']['ate']
            insights.append(f"- 节假日带来营收提升: ${holiday_effect:.0f}")

        insights.append("")

        # 4. 交互效应洞察
        if interaction_results:
            insights.append("## 🔄 交互效应洞察")

            if 'rain_promotion' in interaction_results:
                rain_promo = interaction_results['rain_promotion']
                interaction_effect = rain_promo['interaction_effect']
                insights.append(f"- 雨天促销交互效应: {interaction_effect:+.0f}$")
                if interaction_effect > 0:
                    insights.append("  → 雨天进行促销活动特别有效")
                else:
                    insights.append("  → 雨天促销效果不如预期")

            insights.append("")

        # 5. 行动建议
        insights.append("## 💡 行动建议")

        # 基于最强影响因素给出建议
        if sorted_factors:
            strongest_factor, strongest_result = sorted_factors[0]
            if strongest_result['ate'] > 0:
                insights.append(f"1. **优化 {strongest_factor.replace('_', ' ')}**: 这是最有效的营收提升手段")

            # 基于天气给出建议
            if 'hot_weather' in multi_factor_results:
                insights.append("2. **天气适应性营销**: 根据天气预报调整产品推广和库存")

            # 基于促销给出建议
            if 'promotion' in multi_factor_results:
                promo_effect = multi_factor_results['promotion']['ate']
                if promo_effect > 0:
                    insights.append("3. **促销策略优化**: 当前促销有效，可以扩大范围")
                else:
                    insights.append("3. **重新评估促销策略**: 当前促销可能过度，建议精准化")

        return "\n".join(insights)

    # ------------------------------------------------------------
    # 10. 工具函数
    # ------------------------------------------------------------
    @staticmethod
    def _force_numeric(df, cols):
        """强制转换为数值类型"""
        out = df.copy()
        for c in cols:
            if c in out.columns:
                out[c] = pd.to_numeric(out[c], errors='coerce')
        return out


# ──────────────────────────────────────────────
# 完整使用示例
# ──────────────────────────────────────────────
if __name__ == "__main__":
    # 配置
    CLICKHOUSE_CONFIG = dict(
        host="clickhouse-0-0.umetea.net",
        port=443,
        database="dw",
        user="ml_ume",
        password="hDAoDvg8x552bH",
        verify=False,
    )

    WEATHER_API_KEY = "your-weather-api-key"  # 实际使用时填入真实API key

    # 初始化分析引擎
    enhanced_ci = EnhancedFBRCausalInference(CLICKHOUSE_CONFIG, WEATHER_API_KEY)

    # 设置分析时间范围
    start_date, end_date = "2025-06-01", "2025-07-31"

    print("="*80)
    print("UMe 茶饮增强版因果推断分析")
    print("="*80)

    # 1. 加载销售数据
    print("\n1️⃣ 加载销售数据...")
    sales_df = enhanced_ci.load_integrated_data(start_date, end_date)
    print(f"加载了 {len(sales_df)} 条销售记录")

    # 2. 获取天气数据
    print("\n2️⃣ 获取天气数据...")
    weather_df = enhanced_ci.get_weather_data(start_date, end_date, sales_df)

    # 3. 添加日历特征
    print("\n3️⃣ 添加日历特征...")
    sales_df = enhanced_ci.add_calendar_features(sales_df)

    # 4. 合并天气特征
    print("\n4️⃣ 合并天气特征...")
    if weather_df is not None:
        enhanced_df = enhanced_ci.add_weather_features(sales_df, weather_df)
    else:
        enhanced_df = sales_df

    # 5. 创建处理变量
    print("\n5️⃣ 创建处理变量...")
    enhanced_df = enhanced_ci.create_enhanced_treatment_variables(enhanced_df)

    # 6. 多因素因果分析
    print("\n6️⃣ 执行多因素因果分析...")
    multi_factor_results = enhanced_ci.analyze_multi_factor_effects(enhanced_df)

    # 7. 交互效应分析
    print("\n7️⃣ 执行交互效应分析...")
    interaction_results = enhanced_ci.analyze_interaction_effects(enhanced_df)

    # 8. 生成可视化
    print("\n8️⃣ 生成可视化...")

    # 多因素效应对比
    fig1 = enhanced_ci.visualize_multi_factor_effects(multi_factor_results)
    fig1.show()

    # 天气对产品的影响
    fig2 = enhanced_ci.visualize_weather_impact_by_product(interaction_results)
    if fig2:
        fig2.show()

    # 季节性模式
    fig3 = enhanced_ci.visualize_seasonal_patterns(enhanced_df)
    fig3.show()

    # 9. 生成综合洞察报告
    print("\n9️⃣ 生成综合分析报告...")
    insights_report = enhanced_ci.generate_comprehensive_insights(
        enhanced_df, multi_factor_results, interaction_results
    )

    print("\n" + "="*80)
    print("📋 综合分析报告")
    print("="*80)
    print(insights_report)

    print("\n✅ 分析完成！")
    print("💡 建议保存分析结果用于决策参考")

UMe 茶饮增强版因果推断分析

1️⃣ 加载销售数据...
加载了 1362 条销售记录

2️⃣ 获取天气数据...

=== 获取天气数据 ===
跳过未知州: UM
跳过未知州: Me
成功获取 183 条天气记录

3️⃣ 添加日历特征...

4️⃣ 合并天气特征...

5️⃣ 创建处理变量...

6️⃣ 执行多因素因果分析...

=== 多因素因果效应分析 ===


--- 分析 促销活动 的因果效应 ---
促销活动 分析结果:
  因果效应 (ATE): $358.09
  观察到的差异: $2172.52
  处理率: 93.5%
  样本量: 1362

--- 分析 高温天气 的因果效应 ---
高温天气 分析结果:
  因果效应 (ATE): $-206.01
  观察到的差异: $-1721.68
  处理率: 12.7%
  样本量: 1362

--- 分析 雨天 的因果效应 ---
雨天 分析结果:
  因果效应 (ATE): $-1142.95
  观察到的差异: $-919.79
  处理率: 5.1%
  样本量: 1362

--- 分析 节假日 的因果效应 ---
节假日 分析结果:
  因果效应 (ATE): $-630.23
  观察到的差异: $91.02
  处理率: 3.3%
  样本量: 1362

--- 分析 周末 的因果效应 ---
周末 分析结果:
  因果效应 (ATE): $-301.11
  观察到的差异: $473.88
  处理率: 28.2%
  样本量: 1362

7️⃣ 执行交互效应分析...

=== 交互效应分析 ===


--- 雨天促销交互效应 ---
  is_rainy 单独效应: $276.23
  has_promotion 单独效应: $2241.00
  交互效应: $-1236.21

--- 高温天气对不同产品的影响 ---
  cold_drink_orders: -72.4 (-59.8%)
  hot_drink_orders: +0.0 (+nan%)
  food_orders: -40.1 (-83.2%)

--- 节假日周末交互效应 ---
  is_holiday 单独效应: $232.30
  is_weekend 单独效应: $4