# In [None]:
import streamlit as st
import pandas as pd
from tfcausalimpact import CausalImpact
import matplotlib.pyplot as plt

# Streamlit app: upload an Excel export, pick a target metric (plus optional
# control series) and an intervention date, then fit a CausalImpact model and
# show its summary, report and plot.
st.set_page_config(layout="wide")
st.title("📈 Causal Impact Analysis App")

# Rows skipped above the header row of the uploaded sheet (the layout this
# app expects from the export).
HEADER_ROWS_TO_SKIP = 6


def _clean_export(raw_df):
    """Return *raw_df* indexed by a sorted DatetimeIndex built from ``Date``.

    Rows without a ``Date`` value and columns whose label contains
    ``"Unnamed"`` are dropped.

    Raises:
        KeyError: if the frame has no ``Date`` column.
        ValueError: if a ``Date`` value does not match ``YYYYMMDD``.
    """
    if "Date" not in raw_df.columns:
        raise KeyError("Date")

    cleaned = raw_df.dropna(subset=["Date"])
    # str() keeps the check safe when a column label is not a string.
    unnamed = [col for col in cleaned.columns if "Unnamed" in str(col)]
    # Work on an explicit copy so the column assignment below never targets a
    # view of the raw frame.
    cleaned = cleaned.drop(columns=unnamed).copy()

    # A Date column that contained blanks is read as float, so 20240101 turns
    # into "20240101.0" once stringified; strip that artificial fraction.
    date_text = (
        cleaned["Date"].astype(str).str.strip().str.replace(r"\.0$", "", regex=True)
    )
    cleaned["Date"] = pd.to_datetime(date_text, format="%Y%m%d")

    # Sorting guarantees that the first/last rows are the earliest/latest dates,
    # which the pre/post period logic relies on.
    return cleaned.set_index("Date").sort_index()


# File uploader
uploaded_file = st.file_uploader("Upload your Excel file (GA4 export)", type=["xlsx"])

if uploaded_file:
    xls = pd.ExcelFile(uploaded_file)
    sheet = st.selectbox("Select the sheet to use", xls.sheet_names)

    # Reuse the already-opened workbook instead of re-reading the upload buffer.
    raw_df = pd.read_excel(xls, sheet_name=sheet, skiprows=HEADER_ROWS_TO_SKIP)

    # Clean data
    try:
        df = _clean_export(raw_df)
    except KeyError:
        st.error(
            'The selected sheet has no "Date" column after skipping '
            f"{HEADER_ROWS_TO_SKIP} header rows."
        )
        st.stop()
    except ValueError as e:
        st.error(f"Could not parse the Date column as YYYYMMDD: {e}")
        st.stop()

    if df.empty:
        st.error("The selected sheet contains no usable data rows.")
        st.stop()

    st.subheader("Preview of Cleaned Data")
    st.dataframe(df.head())

    target_metric = st.selectbox("Select target metric to evaluate", df.columns.tolist())
    control_metrics = st.multiselect(
        "Optional: Add control variables",
        [col for col in df.columns if col != target_metric],
    )

    # Date range selector
    all_dates = df.index
    min_date = all_dates.min().date()
    max_date = all_dates.max().date()
    # Give the widget an explicit default inside the data range; its implicit
    # default (today) can fall outside [min_date, max_date].
    default_date = all_dates[len(all_dates) // 2].date()

    st.markdown("### Define Intervention Date")
    intervention_date = st.date_input(
        "Intervention Date (e.g. campaign start)",
        value=default_date,
        min_value=min_date,
        max_value=max_date,
    )

    # Pre and post period auto selector.
    # st.date_input yields a datetime.date; convert it so it compares cleanly
    # against the DatetimeIndex.
    intervention_ts = pd.Timestamp(intervention_date)
    last_pos = len(all_dates) - 1
    # Position of the first row on/after the intervention. Working with
    # positions (rather than exact date lookups) tolerates missing days.
    # searchsorted returns a NumPy integer; convert to a plain int before the
    # positions are handed to CausalImpact.
    post_start = int(all_dates.searchsorted(intervention_ts, side="left"))
    pre_end = post_start - 1

    if pre_end < 0 or post_start > last_pos:
        st.warning(
            "The intervention date must leave at least one data point before it "
            "and one on or after it."
        )
        st.stop()

    st.write(f"Pre-period: {all_dates[0].date()} to {all_dates[pre_end].date()}")
    st.write(f"Post-period: {all_dates[post_start].date()} to {all_dates[last_pos].date()}")

    # Run Causal Impact
    if st.button("Run Causal Impact Analysis"):
        # Top-level UI boundary: any modelling failure is reported to the user
        # instead of crashing the app.
        try:
            # Target first, followed by any selected control series.
            model_data = df[[target_metric] + control_metrics]

            # Periods are passed as integer row positions.
            pre_idx = [0, pre_end]
            post_idx = [post_start, last_pos]

            impact = CausalImpact(model_data, pre_idx, post_idx)

            st.subheader("📊 Summary")
            st.text(impact.summary())

            st.subheader("📝 Detailed Report")
            st.text(impact.summary(output='report'))

            st.subheader("📉 Impact Plot")
            fig = impact.plot(figsize=(15, 6))
            # NOTE(review): plot() appears to draw on the current pyplot figure
            # and return None; fall back to that figure so st.pyplot receives
            # an actual Figure object — confirm against the installed version.
            if fig is None:
                fig = plt.gcf()
            st.pyplot(fig)

        except Exception as e:
            st.error(f"Error during analysis: {e}")
