# In [None]:
# ============================================================
# IED México (Nacional) – Secretaría de Economía (datos.gob.mx)
# ADAPTADO A PLOTLY/STREAMLIT con indentación uniforme y limpia
# ============================================================

import re, math, os
import pandas as pd
import plotly.graph_objects as go
import requests
from datetime import datetime
import streamlit as st

# -----------------------
# User-tunable parameters
# -----------------------
TOP_N_SECTORES = 10  # number of sectors kept by the sector aggregation
TOP_N_PAISES = 10    # number of countries kept before the rest is grouped
SAVE_PNG = False     # when True, try_save() attempts a PNG export
OUT_DIR = "./"       # target directory for the optional PNG export

# ---------------------------------------------
# Style defaults (can be overridden from Streamlit)
# ---------------------------------------------
DEFAULT_FONT = ", ".join(
    ("Aptos Light", "Aptos", "Segoe UI", "Arial", "Helvetica", "sans-serif")
)
COLOR_TEXT_DEFAULT = "#4a4a4a"
DEFAULT_PALETTE = [
    "#1f2a35",
    "#ff9f18",
    "#889064",
    "#2E7D32",
    "#546E7A",
    "#8E24AA",
    "#00838F",
    "#6D4C41",
    "#C62828",
    "#1565C0",
    "#7CB342",
]

# An `active_palette` / `active_font` global, when one exists in this
# module's namespace, takes precedence over the defaults above.
PALETTE = globals().get("active_palette", DEFAULT_PALETTE)
FONT = globals().get("active_font", DEFAULT_FONT)

# Style variables consumed by the chart builders
FONT_FAMILY = FONT
if len(PALETTE) > 0:
    COLOR_BAR = PALETTE[0]
else:
    # An empty override palette falls back to the first default colour.
    COLOR_BAR = DEFAULT_PALETTE[0]
COLOR_BG, COLOR_GRID = "#ffffff", "#d9dada"
COLOR_TEXT = COLOR_TEXT_DEFAULT

# -----------------------
# Endpoints (CKAN)
# -----------------------
CKAN_PACKAGE_ID = "inversion_extranjera_directa"
CKAN_API = "https://www.datos.gob.mx/api/3/action/package_show?id=" + CKAN_PACKAGE_ID

# --------------------------
# SCIAN dictionary (2-digit code -> sector label)
# --------------------------
# Each entry lists every 2-digit code that maps to the same sector label;
# the flat lookup table is derived from it below.
_SCIAN_GROUPS = (
    (("11",), "Agroindustria"),
    (("21",), "Miner√≠a"),
    (("22",), "Electricidad, agua y gas"),
    (("23",), "Construcci√≥n"),
    (("31", "32", "33"), "Manufactura"),
    (("43",), "Comercio al por mayor"),
    (("46",), "Comercio al por menor"),
    (("48", "49"), "Transporte, correo y almacenamiento"),
    (("51",), "Informaci√≥n en medios"),
    (("52",), "Servicios financieros y de seguros"),
    (("53",), "Servicios inmobiliarios"),
    (("54",), "Serv. profesionales, cient√≠ficos y t√©cnicos"),
    (("55",), "Direcci√≥n de corporativos"),
    (("56",), "Serv. de apoyo a los negocios"),
    (("61",), "Serv. educativos"),
    (("62",), "Salud y asistencia social"),
    (("71",), "Cultura, deporte y esparcimiento"),
    (("72",), "Alojamiento y preparaci√≥n de alimentos"),
    (("81",), "Otros servicios"),
)

SCIAN_SECTORES = {
    code: label
    for codes, label in _SCIAN_GROUPS
    for code in codes
}

# --------------------------
# Utilidades de formato
# --------------------------
def format_short(x: float) -> str:
    """Format a number in short notation (k/M/B) with a ``$`` prefix.

    Args:
        x: Amount to format. ``None`` and NaN are treated as missing.

    Returns:
        ``""`` for missing values; otherwise the amount scaled to thousands
        (``k``), millions (``M``) or billions (``B``) with two decimals, or
        with no decimals when its magnitude is below 1,000.
    """
    if x is None:
        return ""
    try:
        # math.isnan accepts any real-number type (numpy scalars, Decimal...),
        # not only the builtin float, so no NaN is rendered as "$nan".
        if math.isnan(x):
            return ""
    except TypeError:
        # Values that cannot be read as a number are treated as missing.
        return ""

    magnitude = abs(x)
    if magnitude >= 1_000_000_000:
        return f"${x/1_000_000_000:.2f} B"
    if magnitude >= 1_000_000:
        return f"${x/1_000_000:.2f} M"
    if magnitude >= 1_000:
        return f"${x/1_000:.2f} k"
    return f"${x:.0f}"

def ensure_dir(path: str):
    """Create directory ``path`` (including parents) when it does not exist.

    An empty ``path`` and an already existing path are both left untouched.
    """
    if not path:
        return
    if os.path.exists(path):
        return
    os.makedirs(path, exist_ok=True)

def try_save(fig, path, scale=3):
    """Export ``fig`` to a PNG file at ``path`` when ``SAVE_PNG`` is enabled.

    The export is best effort: any failure (missing 'kaleido', directory
    creation or image writing) is reported with ``print`` and not re-raised.

    Args:
        fig: Figure object exposing ``write_image``.
        path: Destination file path; its parent directory is created first.
        scale: Scale factor forwarded to ``write_image``.
    """
    if SAVE_PNG:
        try:
            # Imported only to fail early when the package is not installed.
            import kaleido  # noqa: F401

            parent_dir = os.path.dirname(path) or "."
            ensure_dir(parent_dir)
            fig.write_image(path, scale=scale)
            print(f"‚úÖ PNG guardado: {path}")
        except Exception as err:
            print("‚ö†Ô∏è No se pudo exportar a PNG. Instala 'kaleido' (pip install -U kaleido).")
            print("Error:", err)

# --------------------------
# Utilidades de periodo
# --------------------------
def _build_period(df: pd.DataFrame, mapping: dict) -> pd.Series:
    if mapping.get("periodo") and mapping["periodo"] in df.columns:
        return df[mapping["periodo"]]
    y = mapping.get("anio"); t = mapping.get("trimestre")
    if y and t and (y in df.columns) and (t in df.columns):
        return df[y].astype(str) + "T" + df[t].astype(str)
    elif y and (y in df.columns):
        return df[y].astype(str)
    return pd.Series(["NA"] * len(df))

def _parse_last_period(unique_periods):
    def parse_key(p):
        s = str(p)
        if "T" in s:
            y, tt = s.split("T", 1)
            try: return (int(y), int(tt))
            except: return (s, 0)
        try: return (int(s), 0)
        except: return (s, 0)
    return sorted(unique_periods, key=parse_key)[-1] if len(unique_periods) else None

# --------------------------
# CKAN helpers
# --------------------------
def fetch_ckan_resources():
    """Download the CKAN package description and return its resource list.

    Returns:
        The ``resources`` list of the package ``result`` (empty list when
        the key is absent).

    Raises:
        RuntimeError: If the response does not report ``success``.
        Exception: Whatever ``requests`` raises for transport or HTTP
            status errors.
    """
    response = requests.get(CKAN_API, timeout=60)
    response.raise_for_status()
    payload = response.json()
    if payload.get("success", False):
        return payload["result"].get("resources", [])
    raise RuntimeError("No se pudo consultar el package_show de CKAN.")

def _resource_format(res) -> str:
    """Return the lower-cased ``format`` of a resource, ``""`` when unset."""
    # The key may be present with a None value, hence `or ""`.
    return (res.get("format") or "").lower()


def _pick_resource(resources, key_names, error_message):
    """Pick the resource whose name best matches ``key_names``.

    ``key_names`` is ordered by priority: for each key, in order, the first
    resource (in listing order) whose lower-cased name/title contains the
    key and whose format is csv/txt/zip is returned. When no key matches,
    the first resource with format ``csv`` is returned.

    Raises:
        RuntimeError: With ``error_message`` when nothing qualifies.
    """
    # A list of pairs (not a dict) so resources sharing a name are all kept.
    named = [
        ((res.get("name") or res.get("title") or "").lower(), res)
        for res in resources
    ]
    for key in key_names:
        for name_l, res in named:
            if key in name_l and _resource_format(res) in {"csv", "txt", "zip"}:
                return res
    for res in resources:
        if _resource_format(res) == "csv":
            return res
    raise RuntimeError(error_message)


def pick_resource_for_sectors(resources):
    """Select the CKAN resource that carries the sector breakdown.

    Args:
        resources: List of CKAN resource dicts.

    Returns:
        The selected resource dict.

    Raises:
        RuntimeError: If no suitable resource is found.
    """
    key_names = [
        "por pa√≠s de origen, sector, subsector y rama",
        "por pais de origen, sector, subsector y rama",
        "por entidad y sector",
        "sector",
    ]
    return _pick_resource(
        resources,
        key_names,
        "No fue posible identificar un recurso CSV con sectorizaci√≥n.",
    )


def pick_resource_for_countries(resources):
    """Select the CKAN resource that carries the country-of-origin breakdown.

    Args:
        resources: List of CKAN resource dicts.

    Returns:
        The selected resource dict.

    Raises:
        RuntimeError: If no suitable resource is found.
    """
    key_names = [
        "por pa√≠s de origen y tipo",
        "por pais de origen y tipo",
        "por pa√≠s de origen, sector, subsector y rama",
        "por pais de origen, sector, subsector y rama",
        "pa√≠s de origen",
        "pais de origen",
    ]
    return _pick_resource(
        resources,
        key_names,
        "No fue posible identificar un recurso CSV con pa√≠s de origen.",
    )

def load_csv_resource(res) -> pd.DataFrame:
    """Download a CKAN resource as a DataFrame with normalised column names.

    The file is first read as comma-separated UTF-8. If that raises, it is
    read again as semicolon-separated latin-1. If the result has a single
    column whose header still contains the other delimiter, the separator
    guess was wrong and the file is read once more with that delimiter.

    Args:
        res: CKAN resource dict; ``url`` (or ``download_url``) is used.

    Returns:
        DataFrame whose column names are stripped and lower-cased.

    Raises:
        RuntimeError: If the resource has no download URL.
    """
    url = res.get("url") or res.get("download_url")
    if not url:
        raise RuntimeError("El recurso seleccionado no tiene URL de descarga.")

    encoding, sep = "utf-8", ","
    try:
        df = pd.read_csv(url, encoding=encoding, sep=sep, low_memory=False)
    except Exception:
        # Deliberately broad: any read failure triggers the alternate dialect.
        encoding, sep = "latin-1", ";"
        df = pd.read_csv(url, encoding=encoding, sep=sep, low_memory=False)

    # Reading with the wrong delimiter does not fail: it produces one column
    # named after the whole header line. Detect that and retry.
    other_sep = ";" if sep == "," else ","
    if df.shape[1] == 1 and other_sep in str(df.columns[0]):
        df = pd.read_csv(url, encoding=encoding, sep=other_sep, low_memory=False)

    df.columns = [str(c).strip().lower() for c in df.columns]
    return df

# --------------------------
# Mapeo de columnas
# --------------------------
def detect_fields(df: pd.DataFrame) -> dict:
    """Map logical field names to the matching column names of ``df``.

    For each logical field the candidate names are tried in order and the
    first one that is a column of ``df`` is used.

    Returns:
        Dict with keys ``anio``, ``trimestre``, ``periodo``, ``sector``,
        ``pais`` and ``monto_musd``; a value is ``None`` when no candidate
        column exists.
    """
    available = set(df.columns)
    candidates = {
        "anio": ["anio", "a√±o", "year"],
        "trimestre": ["trimestre", "quarter"],
        "periodo": ["periodo", "period", "fecha"],
        "sector": ["sector_subsector_rama", "sector", "sector_economico", "sector econ√≥mico"],
        "pais": ["pais", "pa√≠s", "pais_de_origen", "pa√≠s de origen", "pais de origen", "country"],
        "monto_musd": ["fn_millones_de_dolares", "millones_de_dolares"],
    }
    return {
        field: next((name for name in names if name in available), None)
        for field, names in candidates.items()
    }

# --------------------------
# Parsing sector (SCIAN 2 d√≠gitos)
# --------------------------
def parse_scian_sector_name(texto: str) -> str:
    """Translate a SCIAN-coded label into its 2-digit sector name.

    The first two of the leading digits of ``texto`` are looked up in
    ``SCIAN_SECTORES``.

    Returns:
        The sector name, or ``"No clasificado"`` when ``texto`` is missing,
        does not start with at least two digits, or the code is unknown.
    """
    unclassified = "No clasificado"
    if pd.isna(texto):
        return unclassified

    leading_digits = re.match(r"^(\d{2,4})", str(texto).strip())
    if leading_digits is None:
        return unclassified

    # Only the first two digits identify the sector.
    return SCIAN_SECTORES.get(leading_digits.group(1)[:2], unclassified)

# --------------------------
# Procesamiento ‚Äì Sectores
# --------------------------
def _latest_national_rows(df: pd.DataFrame, mapping: dict):
    """Keep only the rows of the latest period (and national scope).

    Adds a ``__periodo__`` column, keeps the rows of the latest period and,
    when an ``entidad`` column exists and at least one of its values looks
    like a national/total label, keeps only those rows.

    Returns:
        ``(rows, last_key)`` where ``rows`` is an independent copy and
        ``last_key`` is the latest period label as text (``None`` when no
        period could be determined).
    """
    work = df.copy()
    work["__periodo__"] = _build_period(work, mapping)

    period_text = work["__periodo__"].astype(str)
    candidates = (
        work["__periodo__"].dropna().astype(str).replace({"nan": None}).dropna().unique()
    )
    last_key = _parse_last_period(candidates)
    if last_key is not None:
        # Compare as text: last_key is a string even when the period column
        # is numeric, and comparing it with raw numbers would match no row.
        work = work[period_text == last_key]

    if "entidad" in work.columns:
        is_national = (
            work["entidad"].astype(str).str.lower()
            .str.contains(r"nacional|total|rep√∫blica|republica")
        )
        if is_national.any():
            work = work[is_national]

    # Copy so callers can add columns without writing into a filtered view.
    return work.copy(), last_key


def _musd_to_usd(serie_musd: pd.Series) -> pd.Series:
    """Convert an amount column expressed in millions into plain units.

    Non-numeric columns are cleaned first (``$``, commas, spaces and
    non-breaking spaces removed); values that still fail to parse become NaN.
    NOTE(review): commas are assumed to be thousands separators -- confirm
    against the published files.
    """
    if serie_musd.dtype.kind not in "biufc":
        cleaned = (
            serie_musd.astype(str)
            .str.replace("$", "", regex=False)
            .str.replace(",", "", regex=False)
            .str.replace(" ", "", regex=False)
            .str.replace("\u00a0", "", regex=False)
        )
        serie_musd = pd.to_numeric(cleaned, errors="coerce")
    return serie_musd * 1_000_000


# --------------------------
# Processing - Sectors
# --------------------------
def aggregate_national_sector_latest_USD(df: pd.DataFrame, mapping: dict, top_n=10) -> pd.DataFrame:
    """Total the latest-period amounts per 2-digit SCIAN sector.

    Args:
        df: Raw data as returned by ``load_csv_resource``.
        mapping: Field mapping as returned by ``detect_fields``.
        top_n: Maximum number of sectors to keep; falsy or non-positive
            keeps them all.

    Returns:
        DataFrame with columns ``sector``, ``monto_usd`` and ``periodo``.
        Rows are sorted by amount in descending order; when truncation to
        ``top_n`` happens, the kept rows are returned in ascending order.

    Raises:
        RuntimeError: If the sector or amount column was not detected.
    """
    sector_raw = mapping.get("sector")
    musd_col = mapping.get("monto_musd")
    if not sector_raw or not musd_col:
        raise RuntimeError("No se detectaron columnas 'sector'/'monto_musd'.")

    work, last_key = _latest_national_rows(df, mapping)
    work["__sector2__"] = work[sector_raw].apply(parse_scian_sector_name)
    work["__usd__"] = _musd_to_usd(work[musd_col])

    out = (
        work.groupby("__sector2__", dropna=False)["__usd__"].sum().reset_index()
            .rename(columns={"__sector2__": "sector", "__usd__": "monto_usd"})
            .sort_values("monto_usd", ascending=False)
    )
    if top_n and top_n > 0 and len(out) > top_n:
        out = out.head(top_n).sort_values("monto_usd", ascending=True)

    out["periodo"] = last_key if last_key is not None else "NA"
    return out


# --------------------------
# Processing - Countries
# --------------------------
def aggregate_national_by_country_USD(df: pd.DataFrame, mapping: dict, top_n=10) -> pd.DataFrame:
    """Total the latest-period positive amounts per country of origin.

    Only rows with a strictly positive amount are counted. When more than
    ``top_n`` countries remain, the rest are merged into one ``"Otros"`` row.

    Args:
        df: Raw data as returned by ``load_csv_resource``.
        mapping: Field mapping as returned by ``detect_fields``.
        top_n: Number of countries listed individually; falsy or
            non-positive keeps them all.

    Returns:
        DataFrame with columns ``pais``, ``monto_usd`` and ``periodo``,
        sorted by amount in descending order (``"Otros"`` last).

    Raises:
        RuntimeError: If the country or amount column was not detected.
    """
    pais_col = mapping.get("pais")
    musd_col = mapping.get("monto_musd")
    if not pais_col or not musd_col:
        raise RuntimeError("No se detectaron columnas 'pais'/'monto_musd'.")

    work, last_key = _latest_national_rows(df, mapping)
    work["__usd__"] = _musd_to_usd(work[musd_col])
    work = work[work["__usd__"] > 0]

    out = (
        work.groupby(pais_col, dropna=False)["__usd__"].sum().reset_index()
            .rename(columns={pais_col: "pais", "__usd__": "monto_usd"})
            .sort_values("monto_usd", ascending=False)
    )

    if top_n and top_n > 0 and len(out) > top_n:
        top = out.head(top_n).copy()
        otros = pd.DataFrame([{"pais": "Otros", "monto_usd": out["monto_usd"].iloc[top_n:].sum()}])
        out = pd.concat([top, otros], ignore_index=True)

    out["periodo"] = last_key if last_key is not None else "NA"
    return out

# --------------------------
# Gr√°ficas (Modificadas con Fuente Abajo-Izquierda)
# --------------------------
def plot_barras_sectores(df_sec: pd.DataFrame, titulo_sufijo: str = ""):
    """Build the horizontal bar chart of investment amount per sector.

    Args:
        df_sec: DataFrame with columns ``sector``, ``monto_usd`` and
            ``periodo``.
        titulo_sufijo: Text appended to the chart title.

    Returns:
        The Plotly figure, with a source note placed at the bottom left.
    """
    has_rows = bool(len(df_sec))
    peak = df_sec["monto_usd"].max() if has_rows else 0
    periodo = df_sec["periodo"].iloc[0] if has_rows else ""

    amounts = df_sec["monto_usd"]
    labels = [format_short(amount) for amount in amounts]
    # A label goes inside its bar only when the bar reaches 12% of the
    # longest one; otherwise it is drawn outside.
    placements = [
        "inside" if (peak and amount >= 0.12 * peak) else "outside"
        for amount in amounts
    ]

    bar_trace = go.Bar(
        x=df_sec["monto_usd"],
        y=df_sec["sector"],
        orientation="h",
        marker_color=COLOR_BAR,
        text=labels,
        textposition=placements,
        insidetextanchor="middle",
        textfont={"color": "white", "family": FONT_FAMILY, "size": 12},
        cliponaxis=False,
        name="Monto Inversi√≥n",
    )
    fig = go.Figure(bar_trace)

    title_text = (
        f"Monto de Inversi√≥n por sector econ√≥mico{titulo_sufijo}"
        f"<br><sup>Billones ‚Äì USD ‚Äì {periodo}</sup>"
    )
    x_axis = {
        "showgrid": True,
        "gridcolor": COLOR_GRID,
        "zeroline": False,
        "tickprefix": "$ ",
        "separatethousands": True,
        "title": "Monto USD",
    }
    y_axis = {
        "showgrid": False,
        "title": "",
        "automargin": True,
        "categoryorder": "total ascending",
    }
    fig.update_layout(
        title={
            "text": title_text,
            "font": {"family": FONT_FAMILY, "size": 26, "color": COLOR_TEXT},
            "x": 0.5,
            "xanchor": "center",
        },
        plot_bgcolor=COLOR_BG,
        paper_bgcolor=COLOR_BG,
        font={"family": FONT_FAMILY, "color": COLOR_TEXT},
        xaxis=x_axis,
        yaxis=y_axis,
        # Large bottom margin so both the x-axis title and the source note
        # below it fit.
        margin={"l": 170, "r": 40, "t": 90, "b": 130},
        height=540,
    )
    fig.update_traces(
        outsidetextfont={"color": COLOR_TEXT, "family": FONT_FAMILY, "size": 12}
    )

    # Source note at the bottom left. The slightly negative x lines it up
    # with the sector names; y is low enough to clear the x-axis title.
    fig.add_annotation(
        text="Fuente: Secretar√≠a de Econom√≠a (datos.gob.mx)",
        xref="paper",
        yref="paper",
        x=-0.15,
        y=-0.25,
        showarrow=False,
        xanchor="left",
        yanchor="top",
        font={"size": 11, "color": "gray", "family": FONT_FAMILY},
    )

    return fig

def plot_pastel_paises(df_c: pd.DataFrame, titulo_sufijo: str = ""):
    """Build the donut chart of investment share per country of origin.

    Args:
        df_c: DataFrame with columns ``pais``, ``monto_usd`` and ``periodo``.
        titulo_sufijo: Text appended to the chart title.

    Returns:
        The Plotly figure, with a source note placed at the bottom left.
    """
    labels = df_c["pais"]
    values = df_c["monto_usd"]
    periodo = df_c["periodo"].iloc[0] if len(df_c) else ""
    # Short-formatted amounts are shown in the hover box through customdata.
    hover_amounts = [format_short(amount) for amount in values]

    pie_trace = go.Pie(
        labels=labels,
        values=values,
        hole=0.35,
        marker={"colors": PALETTE[:len(labels)]},
        textinfo="percent",
        textfont={"family": FONT_FAMILY, "size": 12},
        hovertemplate="<b>%{label}</b><br>%{percent}<br>Monto: %{customdata}<extra></extra>",
        customdata=hover_amounts,
        sort=False,
    )
    fig = go.Figure(pie_trace)

    title_text = (
        f"Pa√≠ses de origen de la IED en M√©xico{titulo_sufijo}"
        f"<br><sup>USD corrientes (total del periodo) ‚Äì {periodo}</sup>"
    )
    legend_style = {
        "orientation": "v",
        "yanchor": "top",
        "y": 1.0,
        "xanchor": "left",
        "x": 1.02,
        "bgcolor": "rgba(255,255,255,0.0)",
        "font": {"family": FONT_FAMILY, "size": 12, "color": COLOR_TEXT},
    }
    fig.update_layout(
        title={
            "text": title_text,
            "font": {"family": FONT_FAMILY, "size": 22, "color": COLOR_TEXT},
            "x": 0.5,
            "xanchor": "center",
        },
        showlegend=True,
        legend=legend_style,
        paper_bgcolor=COLOR_BG,
        plot_bgcolor=COLOR_BG,
        # Bottom margin leaves room for the source note.
        margin={"l": 60, "r": 180, "t": 90, "b": 100},
        height=560,
    )

    # Source note at the bottom left: x=0 is the left edge of the plotting
    # area and the negative y places it below the chart.
    fig.add_annotation(
        text="Fuente: Secretar√≠a de Econom√≠a (datos.gob.mx)",
        xref="paper",
        yref="paper",
        x=0,
        y=-0.1,
        showarrow=False,
        xanchor="left",
        yanchor="top",
        font={"size": 11, "color": "gray", "family": FONT_FAMILY},
    )

    return fig

# --------------------------
# Pipeline
# --------------------------
def run_both_charts():
    """Download the data, aggregate it and build both charts.

    The download and aggregation step is cached by Streamlit for one hour.

    Returns:
        ``(fig_sec, fig_ctry, df_sec, df_ctry)``. The same 4-tuple shape is
        returned, filled with ``None``, when the data could not be fetched,
        so callers can always unpack the result.

    Raises:
        RuntimeError: Propagated from resource selection or aggregation.
    """
    @st.cache_data(ttl=3600)
    def fetch_and_process():
        try:
            resources = fetch_ckan_resources()
        except Exception as e:
            st.error(f"Error al conectar con datos.gob.mx: {e}")
            st.stop()
            # NOTE(review): only reached if st.stop() does not interrupt
            # execution (presumably when run outside a Streamlit session).
            return None, None

        # ===== Sectors =====
        res_sec = pick_resource_for_sectors(resources)
        df_sec_raw = load_csv_resource(res_sec)
        mapping_sec = detect_fields(df_sec_raw)
        df_sec = aggregate_national_sector_latest_USD(df_sec_raw, mapping_sec, top_n=TOP_N_SECTORES)

        # ===== Countries =====
        res_ctry = pick_resource_for_countries(resources)
        df_ctry_raw = load_csv_resource(res_ctry)
        mapping_ctry = detect_fields(df_ctry_raw)
        df_ctry = aggregate_national_by_country_USD(df_ctry_raw, mapping_ctry, top_n=TOP_N_PAISES)

        return df_sec, df_ctry

    def export_png(fig, df, prefix, top_n):
        """Save ``fig`` under a name derived from the period (best effort)."""
        # Fall back to today's date when there is no row to take a period from.
        tag = (df["periodo"].iloc[0] if len(df) else datetime.today().strftime("%Y%m%d")).replace("/", "-")
        try_save(fig, os.path.join(OUT_DIR, f"{prefix}_top{top_n}_{tag}.png"))

    df_sec, df_ctry = fetch_and_process()

    if df_sec is None or df_ctry is None:
        # Same arity as the success path: the caller unpacks four values.
        return None, None, None, None

    # Build the figures
    fig_sec = plot_barras_sectores(df_sec)
    fig_ctry = plot_pastel_paises(df_ctry)

    # Optional PNG export (try_save is a no-op unless SAVE_PNG is True)
    export_png(fig_sec, df_sec, "ied_sector", TOP_N_SECTORES)
    export_png(fig_ctry, df_ctry, "ied_paises", TOP_N_PAISES)

    # The DataFrames are returned as well so they can be shown as tables.
    return fig_sec, fig_ctry, df_sec, df_ctry

# --------------------------
# Ejecuci√≥n en Streamlit
# --------------------------

st.title("Inversi√≥n Extranjera Directa (IED) en M√©xico")

# Button that forces a fresh download by clearing the data cache
if st.button("Actualizar Gr√°ficas IED (Forzar descarga)"):
    st.cache_data.clear()


def _show_amount_table(df, label_col, label_header):
    """Render ``df`` as a table sorted by amount, largest first.

    Args:
        df: DataFrame with ``label_col``, ``monto_usd`` and ``periodo``.
        label_col: Name of the category column in ``df``.
        label_header: Header shown for the category column.
    """
    table = df[[label_col, "monto_usd", "periodo"]].copy()
    table = table.sort_values("monto_usd", ascending=False)
    table.columns = [label_header, "Monto (USD)", "Periodo"]

    st.dataframe(
        table,
        use_container_width=True,
        hide_index=True,
        column_config={
            # NOTE(review): confirm that the installed Streamlit renders the
            # ',' flag of this printf-style format as a thousands separator.
            "Monto (USD)": st.column_config.NumberColumn(format="$%,.2f"),
            label_header: st.column_config.TextColumn(),
        },
    )


try:
    # run_both_charts() yields four values; tolerate a bare None as well so
    # a failed download is not reported as an unexpected unpacking error.
    result = run_both_charts()
    if result is None:
        result = (None, None, None, None)
    fig_sec, fig_ctry, df_sec, df_ctry = result

    if fig_sec is not None and fig_ctry is not None:
        # --- SECTION 1: SECTORS ---
        st.subheader(f"1. üí∞ IED por Sector Econ√≥mico (Top {TOP_N_SECTORES})")
        st.plotly_chart(fig_sec, use_container_width=True)

        # Sector data table (always visible, no expander)
        st.markdown("**Datos detallados (Sectores):**")
        _show_amount_table(df_sec, "sector", "Sector")

        st.markdown("---")

        # --- SECTION 2: COUNTRIES ---
        st.subheader("2. üåé IED por Pa√≠s de Origen (Distribuci√≥n)")
        st.plotly_chart(fig_ctry, use_container_width=True)

        # Country data table (always visible, no expander)
        st.markdown("**Datos detallados (Pa√≠ses):**")
        _show_amount_table(df_ctry, "pais", "Pa√≠s")

except RuntimeError as e:
    st.error(f"Error fatal en el procesamiento de datos: {e}")
except Exception as e:
    st.error(f"Ocurri√≥ un error inesperado: {e}")