In [None]:
from datetime import datetime, timezone
from typing import Dict, List, Optional, Tuple

import logging

import pyspark.sql.functions as F
from pyspark.sql import DataFrame
from delta.tables import DeltaTable
from pyspark.sql.types import StructType, StructField, TimestampType, StringType

# Catalog prefix for every table/schema identifier. When empty, identifiers are
# built as two-part names (schema.table) without a catalog prefix.
CATALOG = ""

# Unity Catalog volume whose folder tree holds the landing CSV files.
VOLUME_CATALOG = "main"
VOLUME_SCHEMA = "engenharia_dados"
VOLUME_NAME = "aviacao_landing"

# Root of the landing zone; each table lives in <root>/<table>/batch_id=<id>/.
LANDING_CSV_BASE_PATH = "/".join(
    ["/Volumes", VOLUME_CATALOG, VOLUME_SCHEMA, VOLUME_NAME, "aviacao", "landing"]
)

BRONZE_SCHEMA = "aviacao_bronze"
META_SCHEMA = "aviacao_meta"
# Value written to the origem_sistema column of every Bronze row.
ORIGEM_SISTEMA = "postgres-aviacao"

# Per-table configuration, in processing order. Every table currently comes from
# the "aviacao" source schema and is keyed by its "id" column; a fresh dict and
# key list are created per table so entries can be customised independently.
TABLE_CONFIGS: Dict[str, Dict] = {
    table: {"schema": "aviacao", "business_key": ["id"]}
    for table in (
        "companhias_aereas",
        "modelos_avioes",
        "aeroportos",
        "aeronaves",
        "funcionarios",
        "clientes",
        "voos",
        "reservas",
        "bilhetes",
        "bagagens",
        "manutencoes",
        "tripulacao_voo",
    )
}

# Optional explicit landing CSV schemas, keyed by table name. Tables without an
# entry are read with inferSchema (see get_landing_schema / the landing reader).
TABLE_SCHEMAS: Dict[str, StructType] = {}

logger = logging.getLogger("aviacao_bronze")
# Attach a stream handler only once, so re-running this cell does not stack
# duplicate handlers (and duplicate log lines) on the same named logger.
if not logger.handlers:
    formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(name)s - %(message)s")
    handler = logging.StreamHandler()
    handler.setFormatter(formatter)
    logger.addHandler(handler)
logger.setLevel(logging.INFO)


def qname(schema: str, table: str) -> str:
    """Build a dotted table identifier, prefixed with CATALOG only when it is set.

    Returns ``"<CATALOG>.<schema>.<table>"`` when CATALOG is non-empty, otherwise
    ``"<schema>.<table>"``.
    """
    parts = [CATALOG, schema, table] if CATALOG else [schema, table]
    return ".".join(parts)


def now_utc() -> datetime:
    """Return the current instant as a timezone-aware datetime in UTC."""
    return datetime.now(tz=timezone.utc)


def init_schema(schema_name: str) -> None:
    """Create ``schema_name`` (qualified with CATALOG when set) if it does not exist."""
    schema_qualified = schema_name
    if CATALOG:
        schema_qualified = f"{CATALOG}.{schema_name}"
    spark.sql(f"CREATE SCHEMA IF NOT EXISTS {schema_qualified}")


def bronze_table_name(table_name: str) -> str:
    """Return the qualified Bronze change-log table name (``<table>_changelog``)."""
    changelog_table = table_name + "_changelog"
    return qname(BRONZE_SCHEMA, changelog_table)


def meta_bronze_batches_table() -> str:
    """Return the qualified name of the table that records processed landing batches."""
    return qname(schema=META_SCHEMA, table="bronze_landing_batches")


def get_landing_schema(table_name: str) -> Optional[StructType]:
    """Return the explicit landing schema registered for ``table_name``, or None."""
    if table_name in TABLE_SCHEMAS:
        return TABLE_SCHEMAS[table_name]
    return None


def path_exists(path: str) -> bool:
    """Return True when listing ``path`` with dbutils yields at least one entry.

    A directory that lists no entries is reported as False. Any exception raised
    while listing is swallowed and also reported as False.
    """
    try:
        return bool(dbutils.fs.ls(path))
    except Exception:
        return False


def init_meta_bronze_batches() -> None:
    """Ensure the metadata schema and the processed-batches control table exist.

    The control table holds one row per (table, landing batch_id) already loaded
    into Bronze; it is created empty on first use and left untouched afterwards.
    """
    init_schema(META_SCHEMA)
    meta_table = meta_bronze_batches_table()
    ddl = f"""
        CREATE TABLE IF NOT EXISTS {meta_table} (
            tabela STRING NOT NULL,
            landing_batch_id STRING NOT NULL,
            landing_path STRING NOT NULL,
            first_seen_ts TIMESTAMP NOT NULL,
            processed_ts TIMESTAMP NOT NULL
        )
        USING DELTA
    """
    spark.sql(ddl)


def list_landing_batches(table_name: str) -> List[Tuple[str, str]]:
    """List the ``batch_id=<id>`` directories under a table's landing folder.

    Returns ``(batch_id, path)`` pairs in the order ``dbutils.fs.ls`` yields them.
    Entries whose last path segment does not start with ``batch_id=`` are ignored.
    Returns an empty list (and logs it) when ``path_exists`` reports the table
    folder as missing.
    """
    base_path = f"{LANDING_CSV_BASE_PATH}/{table_name}"
    if not path_exists(base_path):
        logger.info(f"[{table_name}] Nenhum diretório de landing encontrado em {base_path}.")
        return []

    marker = "batch_id="
    # Pair each listed path with its last segment; the trailing "/" is stripped
    # first because directory paths may end with one.
    named_entries = [
        (entry.path, entry.path.rstrip("/").split("/")[-1])
        for entry in dbutils.fs.ls(base_path)
    ]
    batches: List[Tuple[str, str]] = [
        (leaf.split(marker)[-1], full_path)
        for full_path, leaf in named_entries
        if leaf.startswith(marker)
    ]
    logger.info(f"[{table_name}] Batches encontrados na landing: {[b[0] for b in batches]}")
    return batches


def list_processed_batches(table_name: str) -> List[str]:
    """Return the distinct landing batch_ids already registered for ``table_name``.

    Returns an empty list when the metadata table does not exist yet.
    """
    meta_table = meta_bronze_batches_table()
    if not spark.catalog.tableExists(meta_table):
        return []

    distinct_ids = (
        spark.table(meta_table)
        .filter(F.col("tabela") == table_name)
        .select("landing_batch_id")
        .distinct()
        .collect()
    )
    rows = [record["landing_batch_id"] for record in distinct_ids]
    logger.info(f"[{table_name}] Batches já processados na Bronze: {rows}")
    return rows


def _build_landing_reader(table_name: str):
    """Return a ';'-delimited, header-aware CSV reader, using TABLE_SCHEMAS when available."""
    reader = (
        spark.read
        .option("header", "true")
        .option("delimiter", ";")
    )

    schema = get_landing_schema(table_name)
    if schema is not None:
        return reader.schema(schema)

    logger.warning(
        f"[{table_name}] Usando inferSchema na landing. "
        f"Configure TABLE_SCHEMAS para produção."
    )
    return reader.option("inferSchema", "true")


def _read_landing_batches(
    table_name: str,
    batches: List[Tuple[str, str]],
) -> Tuple[Optional[DataFrame], int]:
    """Read and union the CSVs of ``batches``; return the union and its raw row count."""
    reader = _build_landing_reader(table_name)

    df_all: Optional[DataFrame] = None
    total_raw = 0

    for batch_id, path in batches:
        df_part = reader.csv(path)
        df_part = df_part.withColumn("data_ref", F.col("data_ref").cast("timestamp"))
        if "landing_load_ts" in df_part.columns:
            df_part = df_part.withColumn("landing_load_ts", F.col("landing_load_ts").cast("timestamp"))
        else:
            logger.warning(f"[{table_name}] Coluna landing_load_ts não encontrada no batch {batch_id}.")
            # Materialise the column as typed NULLs so every batch exposes the same
            # columns: otherwise unionByName fails when batches disagree, and the
            # landing_load_ts min/max statistics computed afterwards would reference
            # a column that does not exist.
            df_part = df_part.withColumn("landing_load_ts", F.lit(None).cast("timestamp"))

        # One count per batch: only used for logging and the null-data_ref tally.
        count_part = df_part.count()
        total_raw += count_part
        logger.info(f"[{table_name}] Registros brutos no batch {batch_id}: {count_part}")

        df_all = df_part if df_all is None else df_all.unionByName(df_part)

    return df_all, total_raw


def read_landing_incremental(table_name: str) -> Tuple[Optional[DataFrame], int, List[Tuple[str, str]]]:
    """Read every landing batch of ``table_name`` not yet loaded into Bronze.

    Lists the ``batch_id=`` directories in the landing zone, skips the ids already
    registered in the metadata table, reads the remaining CSVs, casts ``data_ref``
    and ``landing_load_ts`` to timestamp (adding ``landing_load_ts`` as NULL when a
    batch lacks it), drops rows with a null ``data_ref`` and deduplicates on the
    configured business key plus ``data_ref``.

    Args:
        table_name: key of TABLE_CONFIGS whose landing folder is read.

    Returns:
        ``(df, row_count, new_batches)`` where ``new_batches`` holds the
        ``(batch_id, path)`` pairs that were read. ``(None, 0, [])`` when there is
        nothing to load: no landing batches, no new batch, or no valid row left.
    """
    landing_batches = list_landing_batches(table_name)
    if not landing_batches:
        return None, 0, []

    processed_ids = set(list_processed_batches(table_name))
    new_batches = [(bid, path) for (bid, path) in landing_batches if bid not in processed_ids]

    if not new_batches:
        logger.info(f"[{table_name}] Nenhum batch_id novo na landing. Nada a processar.")
        return None, 0, []

    logger.info(f"[{table_name}] Batches novos a serem processados: {[b[0] for b in new_batches]}")

    df_all, total_raw = _read_landing_batches(table_name, new_batches)
    if df_all is None:
        return None, 0, []

    # data_ref is part of the dedup key below, so rows without it are discarded.
    df_valid = df_all.filter(F.col("data_ref").isNotNull())
    total_valid = df_valid.count()
    null_count = total_raw - total_valid
    if null_count > 0:
        logger.warning(
            f"[{table_name}] {null_count} registros descartados por data_ref nula."
        )

    if total_valid > 0:
        stats = df_valid.agg(
            F.min("data_ref").alias("min_data_ref"),
            F.max("data_ref").alias("max_data_ref"),
            F.min("landing_load_ts").alias("min_landing_ts"),
            F.max("landing_load_ts").alias("max_landing_ts"),
        ).collect()[0]
        logger.info(
            f"[{table_name}] Faixa nos batches novos - "
            f"data_ref: [{stats['min_data_ref']}, {stats['max_data_ref']}], "
            f"landing_load_ts: [{stats['min_landing_ts']}, {stats['max_landing_ts']}]"
        )

    if "change_op" in df_valid.columns:
        dist = df_valid.groupBy("change_op").count().collect()
        for row in dist:
            logger.info(
                f"[{table_name}] (landing filtrada) change_op={row['change_op']} "
                f"-> {row['count']} registros"
            )

    business_key_cols = TABLE_CONFIGS[table_name]["business_key"]
    df_valid = df_valid.dropDuplicates(business_key_cols + ["data_ref"])
    total_final = df_valid.count()
    logger.info(
        f"[{table_name}] Registros após deduplicação (batches novos): {total_final}"
    )

    # NOTE(review): when every row of the new batches is invalid, the batches are not
    # returned, so the caller never registers them and they are re-read on every
    # run. Confirm this is intended (vs. registering them as processed-but-empty).
    if total_final == 0:
        return None, 0, []

    return df_valid, total_final, new_batches


def ensure_bronze_changelog_table(table_name: str, df_sample: DataFrame) -> None:
    """Create the empty Bronze change-log table for ``table_name`` if it is missing.

    The table schema is ``df_sample``'s schema plus the non-nullable metadata
    columns ``bronze_load_ts`` (timestamp), ``bronze_batch_id`` and
    ``origem_sistema`` (strings); a metadata column already present in
    ``df_sample`` is not added twice. Does nothing when the table already exists.
    """
    bronze_table = bronze_table_name(table_name)
    if spark.catalog.tableExists(bronze_table):
        return

    logger.info(f"[{table_name}] Criando tabela Bronze change-log vazia: {bronze_table}")

    base_schema: StructType = df_sample.schema
    existing_cols = {field.name for field in base_schema.fields}

    candidate_metadata = [
        ("bronze_load_ts", TimestampType()),
        ("bronze_batch_id", StringType()),
        ("origem_sistema", StringType()),
    ]
    metadata_fields = [
        StructField(col_name, col_type, nullable=False)
        for col_name, col_type in candidate_metadata
        if col_name not in existing_cols
    ]

    bronze_schema = StructType(list(base_schema.fields) + metadata_fields)

    # Write zero rows just to materialise the Delta table with the desired schema.
    (
        spark.createDataFrame([], bronze_schema)
        .write
        .format("delta")
        .mode("overwrite")
        .saveAsTable(bronze_table)
    )

    logger.info(f"[{table_name}] Tabela Bronze criada como append-only change-log.")


def register_processed_batches(table_name: str, batches: List[Tuple[str, str]]) -> None:
    """Append one metadata row per ``(batch_id, path)`` marking it as loaded into Bronze.

    ``first_seen_ts`` and ``processed_ts`` both receive the same current UTC
    instant. Does nothing for an empty ``batches`` list.
    """
    if not batches:
        return

    meta_table = meta_bronze_batches_table()
    stamp = now_utc()
    columns = ["tabela", "landing_batch_id", "landing_path", "first_seen_ts", "processed_ts"]
    records = [
        (table_name, landing_id, landing_path, stamp, stamp)
        for landing_id, landing_path in batches
    ]

    df_meta = spark.createDataFrame(records, columns)
    df_meta.write.format("delta").mode("append").saveAsTable(meta_table)

    logger.info(f"[{table_name}] Batches registrados como processados: {[b[0] for b in batches]}")


def _quote_identifier(name: str) -> str:
    """Backtick-quote a column name for safe use inside a Spark SQL expression."""
    return "`" + name.replace("`", "``") + "`"


def merge_into_bronze_changelog(
    df: DataFrame,
    table_name: str,
    business_key_cols: List[str],
    batch_id: str,
) -> int:
    """Insert the rows of ``df`` that are not yet present in the Bronze change-log.

    Creates the Bronze table if needed, stamps every row with ``bronze_load_ts``
    (one UTC instant per call), ``bronze_batch_id`` and ``origem_sistema``,
    deduplicates on business key + ``data_ref`` and runs an insert-only MERGE on
    that same key: existing target rows are never updated.

    Args:
        df: incoming rows; must contain ``business_key_cols`` and ``data_ref``.
        table_name: logical table name (key of TABLE_CONFIGS).
        business_key_cols: columns identifying an entity.
        batch_id: identifier written to ``bronze_batch_id``.

    Returns:
        Number of candidate rows sent to the MERGE after deduplication (0 when
        nothing was sent). This is not necessarily the number of inserted rows:
        rows already present in the target are skipped by the MERGE but counted.
    """
    bronze_table = bronze_table_name(table_name)

    ensure_bronze_changelog_table(table_name, df)

    df_enriched = (
        df
        .withColumn("bronze_load_ts", F.lit(now_utc()))
        .withColumn("bronze_batch_id", F.lit(batch_id))
        .withColumn("origem_sistema", F.lit(ORIGEM_SISTEMA).cast("string"))
    )

    dedup_cols = business_key_cols + ["data_ref"]
    df_enriched = df_enriched.dropDuplicates(dedup_cols)

    total_to_merge = df_enriched.count()
    logger.info(f"[{table_name}] Registros a serem mesclados no Bronze: {total_to_merge}")

    if total_to_merge == 0:
        return 0

    if "change_op" in df_enriched.columns:
        dist = df_enriched.groupBy("change_op").count().collect()
        for row in dist:
            logger.info(
                f"[{table_name}] (bronze) change_op={row['change_op']} "
                f"-> {row['count']} registros"
            )

    # Column names are interpolated into a SQL expression string, so quote them:
    # reserved words or names with special characters would otherwise break parsing.
    merge_condition = " AND ".join(
        f"tgt.{_quote_identifier(col)} = src.{_quote_identifier(col)}"
        for col in dedup_cols
    )

    delta_tbl = DeltaTable.forName(spark, bronze_table)

    (
        delta_tbl.alias("tgt")
        .merge(
            df_enriched.alias("src"),
            merge_condition
        )
        .whenNotMatchedInsertAll()
        .execute()
    )

    logger.info(f"[{table_name}] MERGE em Bronze change-log concluído.")
    return total_to_merge


def process_table(table_name: str) -> None:
    """Load the new landing batches of ``table_name`` into its Bronze change-log.

    Reads only batches not yet registered in the metadata table and MERGEs them
    into Bronze. When rows were sent to the MERGE, the batches are registered as
    processed so later runs skip them.

    Raises:
        ValueError: if ``table_name`` is not configured in TABLE_CONFIGS.
    """
    if table_name not in TABLE_CONFIGS:
        raise ValueError(f"Tabela '{table_name}' não está configurada em TABLE_CONFIGS.")

    business_key_cols = TABLE_CONFIGS[table_name]["business_key"]

    logger.info(f"================ INÍCIO BRONZE: {table_name} ================")

    df_src, total_inc, new_batches = read_landing_incremental(table_name)

    if df_src is None or total_inc == 0:
        logger.info(f"[{table_name}] Nenhum dado incremental para processar.")
        logger.info(f"================ FIM BRONZE (sem dados): {table_name} ================")
        return

    # Use the same UTC clock as bronze_load_ts and the metadata timestamps so the
    # batch id does not depend on the cluster's local timezone setting.
    batch_id = now_utc().strftime("%Y%m%d%H%M%S")

    merged = merge_into_bronze_changelog(df_src, table_name, business_key_cols, batch_id)

    # Registration happens only after the MERGE returns. If it fails afterwards, the
    # batches are re-read next run; the insert-only MERGE on (business key, data_ref)
    # should then skip rows already loaded.
    if merged > 0:
        register_processed_batches(table_name, new_batches)

    logger.info(
        f"[{table_name}] Resumo Bronze: lidos_incremental={total_inc}, "
        f"mesclados_no_bronze={merged}, batch_id={batch_id}"
    )
    logger.info(f"================ FIM BRONZE: {table_name} ================")


def main(tables: Optional[List[str]] = None) -> None:
    """Run the Bronze load for ``tables``, sequentially.

    Creates the Bronze schema and the metadata control table first if missing.
    When ``tables`` is None, every table in TABLE_CONFIGS is processed in
    configuration order.
    """
    init_schema(BRONZE_SCHEMA)
    init_meta_bronze_batches()

    targets = list(TABLE_CONFIGS) if tables is None else tables
    for tbl in targets:
        process_table(tbl)


# Entry point when run as a script/notebook: load every configured table.
if __name__ == "__main__":
    main(tables=None)
