In [0]:
from pyspark.sql import functions as F
from typing import List, Optional, Dict

# Fully-qualified table names per trip category. Both layers follow the
# pattern workspace.nyc_taxi.<category>_trips_<layer>, so they are derived
# from the same ordered category list to keep the two maps aligned.
BRONZE_TABLES = {
    category: f"workspace.nyc_taxi.{category}_trips_bronze"
    for category in ("yellow", "green", "fhv", "fhvhv")
}

SILVER_TABLES = {
    category: f"workspace.nyc_taxi.{category}_trips_silver"
    for category in ("yellow", "green", "fhv", "fhvhv")
}


def _category_cfg(pickup_col, dropoff_col, *, extra_essentials=(), passenger_count_col=None):
    """Build the cleaning config for one trip category.

    The essential columns are always "anomes", then any extra essentials,
    then the pickup and dropoff timestamp columns. A fresh dict (and list)
    is returned on every call, so categories never share mutable state.
    """
    return {
        "anomes_col": "anomes",
        "essential_cols": ["anomes", *extra_essentials, pickup_col, dropoff_col],
        "pickup_col": pickup_col,
        "dropoff_col": dropoff_col,
        "passenger_count_col": passenger_count_col,
    }


# Per-category column configuration consumed by the cleaning pipeline.
CFG: Dict[str, Dict] = {
    "yellow": _category_cfg(
        "tpep_pickup_datetime",
        "tpep_dropoff_datetime",
        extra_essentials=("vendorid",),
        passenger_count_col="passenger_count",
    ),
    "green": _category_cfg(
        "lpep_pickup_datetime",
        "lpep_dropoff_datetime",
        extra_essentials=("vendorid",),
        passenger_count_col="passenger_count",
    ),
    # FHV has no passenger-count column.
    "fhv": _category_cfg("pickup_datetime", "dropoff_datetime"),
    # FHVHV has no passenger-count column.
    "fhvhv": _category_cfg("pickup_datetime", "dropoff_datetime"),
}

# ---------------------------------------------
# Funções genéricas (sem valores fixos internos)
# ---------------------------------------------
def remove_null_essentials(df, essential_cols: List[str]) -> (int, "DataFrame"):
    """Drop rows that have NULL in ANY of the essential columns.

    Args:
        df: input DataFrame.
        essential_cols: column names that must all be non-null for a row to
            be kept. When empty, no filter is applied.

    Returns:
        A ``(removed_rows, filtered_df)`` tuple, where ``removed_rows`` is the
        difference between the row counts before and after filtering.
    """
    rows_in = df.count()
    not_null_checks = [F.col(name).isNotNull() for name in essential_cols]
    if not_null_checks:
        # AND the per-column checks together, left to right.
        keep_row = not_null_checks[0]
        for check in not_null_checks[1:]:
            keep_row = keep_row & check
        df = df.filter(keep_row)
    rows_out = df.count()
    return rows_in - rows_out, df

def remove_small_partitions(df, anomes_col: str, min_pct_of_max: float = 0.05) -> (int, "DataFrame"):
    """Drop every partition whose row count is below a fraction of the largest one.

    A partition (one distinct value of ``anomes_col``) is kept only when its
    row count is >= ``min_pct_of_max * max_count``, where ``max_count`` is the
    row count of the largest partition.

    Args:
        df: input DataFrame.
        anomes_col: name of the partition column to group by.
        min_pct_of_max: fraction of the largest partition's count used as the
            minimum size to keep a partition.

    Returns:
        A ``(removed_rows, filtered_df)`` tuple. ``(0, df)`` is returned
        unchanged when there is no maximum count (None) or it is zero.
    """
    rows_in = df.count()
    per_partition = df.groupBy(anomes_col).count()
    peak = per_partition.agg(F.max("count")).collect()[0][0]
    if not peak:
        # None (no groups) or 0: nothing to compare against.
        return 0, df

    cutoff = peak * min_pct_of_max
    large_enough = per_partition.where(F.col("count") >= F.lit(cutoff)).select(anomes_col)
    # Inner join keeps only rows whose partition value survived the cutoff.
    kept = df.join(large_enough, on=anomes_col, how="inner")
    rows_out = kept.count()
    return rows_in - rows_out, kept

def remove_bad_time_order(df, pickup_col: str, dropoff_col: str) -> (int, "DataFrame"):
    """Drop records whose pickup/dropoff timestamps are in an invalid order.

    A record is treated as **invalid** when dropoff < pickup; only rows with
    ``dropoff >= pickup`` are kept. (To remove pickup < dropoff instead,
    invert the comparison below.)

    Returns:
        A ``(removed_rows, filtered_df)`` tuple.
    """
    rows_in = df.count()
    in_order = F.col(dropoff_col) >= F.col(pickup_col)
    kept = df.where(in_order)
    rows_out = kept.count()
    return rows_in - rows_out, kept

def remove_nonpositive_passengers(df, passenger_count_col: Optional[str]) -> (int, "DataFrame"):
    """Drop rows whose passenger count is <= 0, when that column is available.

    The rule is skipped, returning ``(0, df)`` unchanged, when
    ``passenger_count_col`` is None/empty or is not one of ``df.columns``.

    Returns:
        A ``(removed_rows, filtered_df)`` tuple.
    """
    rule_applies = bool(passenger_count_col) and passenger_count_col in df.columns
    if rule_applies:
        rows_in = df.count()
        kept = df.where(F.col(passenger_count_col) > F.lit(0))
        rows_out = kept.count()
        return rows_in - rows_out, kept
    return 0, df

def write_delta(df, table: str, anomes_col: str) -> None:
    """Write the Silver table as Delta, partitioned by ``anomes_col``, replacing it.

    The target table is dropped (if it exists) and then re-created from
    ``df`` with write mode "overwrite". The data is repartitioned by
    ``anomes_col`` before being written.

    Args:
        df: DataFrame to persist.
        table: fully-qualified name of the target table.
        anomes_col: partition column name.
    """
    # NOTE(review): the table is dropped before the write starts, so a failed
    # write presumably leaves no table behind -- confirm this is acceptable.
    spark.sql(f"DROP TABLE IF EXISTS {table}")

    by_partition = df.repartition(anomes_col)
    writer = by_partition.write.format("delta").mode("overwrite")
    writer = writer.partitionBy(anomes_col)
    writer.saveAsTable(table)

def run_silver_simple(
    bronze_table: str,
    silver_table: str,
    anomes_col: str,
    essential_cols: List[str],
    pickup_col: str,
    dropoff_col: str,
    passenger_count_col: Optional[str],
    small_part_pct: float = 0.05
) -> None:
    """Run the simple cleaning pipeline for one table and write the Silver output.

    Reads ``bronze_table``, applies the cleaning rules in this order, prints
    how many rows each rule removed, and writes the result to
    ``silver_table``:

    1. rows with NULL in any essential column;
    2. partitions smaller than ``small_part_pct`` of the largest partition;
    3. rows where dropoff < pickup;
    4. rows with passenger count <= 0 (skipped when the column is absent).

    Args:
        bronze_table: name of the source table.
        silver_table: name of the destination table (replaced on write).
        anomes_col: partition column name.
        essential_cols: columns that must be non-null.
        pickup_col: pickup timestamp column name.
        dropoff_col: dropoff timestamp column name.
        passenger_count_col: passenger count column name, or None when the
            category has no such column.
        small_part_pct: fraction of the largest partition's row count below
            which a partition is removed.
    """
    print(f"\n===== SILVER: {silver_table} =====")
    df = spark.table(bronze_table)
    total_ini = df.count()
    print(f"📥 Bronze lida: {bronze_table} | Linhas iniciais: {total_ini:,}")

    removed_ess, df = remove_null_essentials(df, essential_cols)
    print(f"🧹 Removidas por colunas essenciais nulas: {removed_ess:,}")

    removed_part, df = remove_small_partitions(df, anomes_col, min_pct_of_max=small_part_pct)
    # Format with `:g` instead of int(pct * 100): truncating the float product
    # mislabels the threshold (0.29 * 100 == 28.999... printed "28%", and
    # 0.005 printed "0%"). Output for the default 0.05 is still "5%".
    print(f"🧱 Removidas por partições < {small_part_pct * 100:g}% do pico: {removed_part:,}")

    removed_time, df = remove_bad_time_order(df, pickup_col, dropoff_col)
    print(f"⏱️ Removidas por ordem temporal inválida (dropoff < pickup): {removed_time:,}")

    removed_pass, df = remove_nonpositive_passengers(df, passenger_count_col)
    # Mirror the helper's own skip condition so the message says whether the
    # rule actually ran.
    if passenger_count_col and passenger_count_col in df.columns:
        print(f"🧍 Removidas por passenger_count ≤ 0: {removed_pass:,}")
    else:
        print("🧍 Coluna de passageiros ausente – regra ignorada.")

    total_fim = df.count()
    print(f"✅ Linhas finais após limpeza: {total_fim:,} (removidas no total: {(total_ini - total_fim):,})")

    write_delta(df, silver_table, anomes_col)
    print(f"💾 Silver gravada: {silver_table}")

# ---------------------------------------------
# Execução por categoria (ligue/desligue à vontade)
# ---------------------------------------------

# # Yellow
# run_silver_simple(
#     bronze_table=BRONZE_TABLES["yellow"],
#     silver_table=SILVER_TABLES["yellow"],
#     anomes_col=CFG["yellow"]["anomes_col"],
#     essential_cols=CFG["yellow"]["essential_cols"],
#     pickup_col=CFG["yellow"]["pickup_col"],
#     dropoff_col=CFG["yellow"]["dropoff_col"],
#     passenger_count_col=CFG["yellow"]["passenger_count_col"],
#     small_part_pct=0.05
# )

# # Green
# run_silver_simple(
#     bronze_table=BRONZE_TABLES["green"],
#     silver_table=SILVER_TABLES["green"],
#     anomes_col=CFG["green"]["anomes_col"],
#     essential_cols=CFG["green"]["essential_cols"],
#     pickup_col=CFG["green"]["pickup_col"],
#     dropoff_col=CFG["green"]["dropoff_col"],
#     passenger_count_col=CFG["green"]["passenger_count_col"],
#     small_part_pct=0.05
# )

# FHV, then FHVHV: same pipeline, driven by each category's entry in CFG.
for _category in ("fhv", "fhvhv"):
    _cfg = CFG[_category]
    run_silver_simple(
        bronze_table=BRONZE_TABLES[_category],
        silver_table=SILVER_TABLES[_category],
        anomes_col=_cfg["anomes_col"],
        essential_cols=_cfg["essential_cols"],
        pickup_col=_cfg["pickup_col"],
        dropoff_col=_cfg["dropoff_col"],
        passenger_count_col=_cfg["passenger_count_col"],
        small_part_pct=0.05,
    )
