<a href="https://colab.research.google.com/github/LucasMirandaVS/estudos_python/blob/main/Data_modeling_modular.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from pyspark.sql import functions as F
from pyspark.sql.window import Window


# Configuration for the ACE report transformation.

# Input CSV location and destination path for the enriched report.
INPUT_PATH = "/content/raw_data_test_version.csv"
OUTPUT_PATH = "ace_report_enriched_sparke.csv"

# Header names of the entry summary number and entry summary line number columns.
COL_ENTRY_NUM = "Entry Summary Number"
COL_LINE_NUM = "Entry Summary Line Number"

# Duty / fee amount columns (values for the `tariff_cols` parameter).
TARIFF_COLS = [
    "Line Tariff Duty Amount",
    "Line MPF Amount", "Line HMF Amount",
    "Antidumping Duty Amount", "Countervailing Duty Amount",
]

# Additional columns to remove from the output (values for the
# `drop_extra_cols` parameter): goods value plus quantity / UOM pairs 1-3.
DROP_EXTRA_COLS = [
    "Line Tariff Goods Value Amount",
    "Line Tariff Quantity (1)", "Line Tariff UOM (1) Code",
    "Line Tariff Quantity (2)", "Line Tariff UOM (2) Code",
    "Line Tariff Quantity (3)", "Line Tariff UOM (3) Code",
]

"""
The function below performs the following transformations:
0. Load data
1. Add entry summary code
2. Cast columns to double
3. Aggregate columns
4. Join aggregates
5. Add aggregated columns
6. Drop unneeded columns
7. Drop rows whose aggregated values are all zero
8. Write output
"""

def ace_report_transform(
    input_path,
    output_path,
    col_entry_num,
    col_line_num,
    tariff_cols,
    drop_extra_cols
):
    """Enrich an ACE report CSV with aggregated tariff amounts and write it out.

    Steps:
      0. Load the CSV at ``input_path`` (header row, inferred schema).
      1. Add the key column "Entry summary number code", built as
         "<entry number>-<line number>" with a trailing ".0" stripped from
         each part.
      2. Cast the tariff columns and the line-number column to double.
      3. Sum every tariff column per key.
      4. Join the sums back onto the rows.
      5. Add one "Aggregated <col>" column per tariff column: the sum on the
         last row of each key, 0.0 on every other row.
      6. Drop the original tariff columns, the temporary sum columns and
         ``drop_extra_cols``.
      7. Keep only rows where at least one aggregated column is non-zero.
      8. Write the result to ``output_path`` as CSV (single partition,
         overwrite mode, with header).

    Tariff columns that are not present in the input are skipped in every
    step. If none are present, steps 3-5 and 7 are skipped entirely.

    Relies on a global ``spark`` object (the Spark session) that is defined
    outside this function.

    Args:
        input_path: Path of the CSV file to read.
        output_path: Destination path handed to the Spark CSV writer.
        col_entry_num: Name of the entry summary number column.
        col_line_num: Name of the entry summary line number column.
        tariff_cols: Names of the amount columns to aggregate.
        drop_extra_cols: Names of further columns to drop from the output.

    Returns:
        The transformed DataFrame (the same one that was written).
    """
    code_col = "Entry summary number code"

    def rm_dotzero(colname):
        # Render the column as trimmed text and strip a trailing ".0".
        # The pattern must be r"\.0$": the former r"\\.0$" meant "a literal
        # backslash, any character, then 0", which never matches ".0".
        as_text = F.trim(F.col(colname).cast("string"))
        return F.regexp_replace(as_text, r"\.0$", "")

    # Step 0: Load data
    df = spark.read.option("header", True).option("inferSchema", True).csv(input_path)

    # Step 1: Add entry summary code
    df = df.withColumn(
        code_col,
        F.concat_ws("-", rm_dotzero(col_entry_num), rm_dotzero(col_line_num)),
    )

    # Only tariff columns that exist in the input take part in the steps
    # below, so a missing column is skipped consistently instead of being
    # ignored by the cast/drop steps and then failing in the aggregation.
    present_tariff_cols = [c for c in tariff_cols if c in df.columns]

    # Step 2: Cast columns to double
    for c in present_tariff_cols:
        df = df.withColumn(c, F.col(c).cast("double"))
    if col_line_num in df.columns:
        df = df.withColumn(col_line_num, F.col(col_line_num).cast("double"))

    if present_tariff_cols:
        # Step 3: Aggregate columns
        agg_df = df.groupBy(code_col).agg(
            *[F.sum(F.col(c)).alias(f"{c}__agg") for c in present_tariff_cols]
        )

        # Step 4: Join aggregates
        df = df.join(agg_df, on=code_col, how="left")

        # Step 5: Add aggregated columns
        # lead() is null only on the last row of each partition, so exactly
        # one row per key carries the totals.
        # NOTE(review): the key already contains the line number, so rows in
        # a partition tie on this ordering; when a key has several rows the
        # ordering does not determine which one keeps the totals - confirm
        # this is acceptable.
        w = Window.partitionBy(code_col).orderBy(F.col(col_line_num).asc())
        is_last = F.lead(F.col(code_col)).over(w).isNull()
        for c in present_tariff_cols:
            df = df.withColumn(
                f"Aggregated {c}",
                F.when(is_last, F.col(f"{c}__agg")).otherwise(F.lit(0.0)),
            )

    # Step 6: Drop unneeded columns
    cols_to_drop = (
        present_tariff_cols
        + [f"{c}__agg" for c in present_tariff_cols]
        + [c for c in drop_extra_cols if c in df.columns]
    )
    if cols_to_drop:
        df = df.drop(*cols_to_drop)

    # Step 7: Filter out rows where all aggregated values are zero
    aggregated_cols = [
        f"Aggregated {c}"
        for c in present_tariff_cols
        if f"Aggregated {c}" in df.columns
    ]
    if aggregated_cols:
        non_zero_cond = F.col(aggregated_cols[0]) != 0
        for c in aggregated_cols[1:]:
            non_zero_cond = non_zero_cond | (F.col(c) != 0)
        df = df.filter(non_zero_cond)

    # Step 8: Write output
    df.coalesce(1).write.mode("overwrite").option("header", True).csv(output_path)
    return df

# Run the transformation with the configuration constants defined near the
# top of the file, so each value lives in one place rather than being
# repeated here as a literal.
result_df = ace_report_transform(
    input_path=INPUT_PATH,
    output_path=OUTPUT_PATH,
    col_entry_num=COL_ENTRY_NUM,
    col_line_num=COL_LINE_NUM,
    tariff_cols=TARIFF_COLS,
    drop_extra_cols=DROP_EXTRA_COLS,
)


