# In [None]:
import pytest
from pyspark.sql import SparkSession, functions as F

@pytest.fixture(scope="session")
def spark():
    """Provide a local SparkSession shared by all tests in the pytest session.

    Yields:
        The SparkSession obtained from ``getOrCreate()`` with master
        ``local[2]``.
    """
    session = SparkSession.builder.master("local[2]").getOrCreate()
    yield session
    # Teardown: release the session's resources once the test session ends
    # instead of leaving it running until interpreter exit.
    session.stop()

def test_to_usd_and_missing_rate(spark):
    """Check USD conversion and the rate-status flag for known and unknown currencies.

    Builds a two-row exchange-rate table (EUR, USD) and three priced projects,
    one of which (GBP) has no rate. Expects converted prices for EUR/USD with
    status "OK", and a null price with status "MISSING_RATE" for GBP.
    """
    from pyspark.sql.types import DoubleType, StringType

    exch = spark.createDataFrame(
        [("EUR", 1.1), ("USD", 1.0)],
        "Currency string, AvgRate double",
    )
    df = spark.createDataFrame(
        [("P1", "EUR", 100.0), ("P2", "USD", 50.0), ("P3", "GBP", 10.0)],
        "ProjectID string, Currency string, Price double",
    )

    # Collect the rate table to the driver; the UDFs below close over this dict.
    rates = {r["Currency"]: r["AvgRate"] for r in exch.collect()}

    @F.udf(DoubleType())
    def to_usd(price, currency):
        if currency is None or price is None:
            return None
        rate = rates.get(currency)
        return price * rate if rate is not None else None

    @F.udf(StringType())
    def usd_flag(currency):
        return "OK" if currency in rates else "MISSING_RATE"

    out = (
        df.withColumn("Price_USD", to_usd("Price", "Currency"))
        .withColumn("USD_Rate_Status", usd_flag("Currency"))
    )

    rows = {
        r["ProjectID"]: (r["Price_USD"], r["USD_Rate_Status"])
        for r in out.collect()
    }

    # 100.0 * 1.1 is not exactly 110.0 in binary floating point, so compare
    # converted prices with a tolerance rather than exact equality.
    p1_price, p1_status = rows["P1"]
    assert p1_price == pytest.approx(110.0)
    assert p1_status == "OK"

    p2_price, p2_status = rows["P2"]
    assert p2_price == pytest.approx(50.0)
    assert p2_status == "OK"

    p3_price, p3_status = rows["P3"]
    assert p3_price is None
    assert p3_status == "MISSING_RATE"

def test_integrity_flags_written(spark):
    """Check that each of the three integrity checks flags the expected projects.

    The fixture data is arranged so that:
      * B and D have an "After_RtB" phase but no "Before_RtB" phase,
      * C has two distinct material categories,
      * D has two distinct non-null RtB budgets in the "After_RtB" phase.
    """
    cons = spark.createDataFrame(
        [
            ("A", "Before_RtB", "Cat1", None),
            ("A", "After_RtB", "Cat1", 1000.0),
            ("B", "After_RtB", "CatX", 500.0),
            ("C", "Before_RtB", "Cat1", None),
            ("C", "Before_RtB", "Cat2", None),
            ("D", "After_RtB", "Cat1", 100.0),
            ("D", "After_RtB", "Cat1", 200.0),
        ],
        "Project ID string, Project phase string, "
        "Material Category string, RtB Budget double",
    )

    issues = []   # one labelled DataFrame per check that found offenders
    flagged = {}  # issue label -> set of offending project IDs

    def record(offenders, label):
        # Collect once and reuse the result, instead of a count() followed by
        # a second evaluation of the same plan.
        ids = {r["Project ID"] for r in offenders.select("Project ID").collect()}
        if ids:
            issues.append(offenders.withColumn("issue", F.lit(label)))
            flagged[label] = ids

    # Check 1: projects with an After_RtB phase but no Before_RtB phase.
    # collect_set already de-duplicates, so no prior dropDuplicates is needed.
    proj_phase = cons.groupBy("Project ID").agg(
        F.collect_set("Project phase").alias("phases")
    )
    after_without_before = proj_phase.where(
        F.array_contains("phases", "After_RtB")
        & ~F.array_contains("phases", "Before_RtB")
    )
    record(after_without_before, "PHASE_AFTER_WITHOUT_BEFORE")

    # Check 2: projects with more than one distinct material category.
    dup_matcat = (
        cons.groupBy("Project ID")
        .agg(F.countDistinct("Material Category").alias("n"))
        .where("n>1")
    )
    record(dup_matcat, "MULTIPLE_MATCAT_PER_PROJECT")

    # Check 3: projects with more than one distinct non-null After_RtB budget.
    dup_budget = (
        cons.where(
            (F.col("Project phase") == "After_RtB")
            & F.col("RtB Budget").isNotNull()
        )
        .groupBy("Project ID")
        .agg(F.countDistinct("RtB Budget").alias("n"))
        .where("n>1")
    )
    record(dup_budget, "MULTIPLE_RTB_BUDGET_PER_PROJECT")

    assert len(issues) == 3  # all three checks found something
    # Pin the offenders too, so a check that flags the wrong project fails.
    assert flagged == {
        "PHASE_AFTER_WITHOUT_BEFORE": {"B", "D"},
        "MULTIPLE_MATCAT_PER_PROJECT": {"C"},
        "MULTIPLE_RTB_BUDGET_PER_PROJECT": {"D"},
    }
