In [0]:
%pip install polars

In [0]:
from datetime import datetime
import polars as pl
import pyarrow as pa
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql import functions as F

# Default upper bound on full IPF passes (each pass adjusts every raking variable once).
MAX_ITERATIONS = 20
# Default convergence threshold on the largest absolute gap between an actual and a target share.
TOLERANCE = 1e-4
# NOTE(review): not referenced by any cell shown here — TODO remove or wire it in.
CUSTOMER_COUNT = 10

In [0]:
# Snapshot the TPC-DS source tables to DBFS so the Polars cell can read them as files.
# The Polars cell globs the *.parquet data files directly instead of going through the
# Delta log. A Delta overwrite keeps the replaced data files on disk until VACUUM, so on
# a re-run the glob would pick up stale and current files together (duplicated rows).
# Removing the directory first guarantees the glob sees only this snapshot.
for table_name in ("customer", "customer_demographics", "customer_address"):
    target_path = f"/tmp/tpcds_sf1000/{table_name}"
    dbutils.fs.rm(target_path, True)
    (
        spark.table(f"samples.tpcds_sf1000.{table_name}")
        .write.format("delta")
        .mode("overwrite")
        .save(target_path)
    )

In [0]:
class RimWeightTPCDSBase:
    """Shared configuration and entry point for RIM (raking / IPF) weighting on TPC-DS.

    Engine-specific subclasses implement :meth:`run_ipf`. ``self.targets`` maps
    each raking variable listed in ``self.variables`` to a
    ``{category: target share}`` dictionary.

    NOTE(review): the engine implementations in this notebook compute correction
    factors only for categories that occur in the data. If a target category
    never occurs (possibly gender "U" or the non-US countries in the TPC-DS
    sample — TODO confirm), the remaining categories cannot all reach their
    targets and the run stops at ``max_iterations`` without converging.
    """

    # Allowed deviation of a variable's summed target shares from 1.0.
    _TARGET_SUM_TOLERANCE = 1e-6

    def __init__(
        self,
        unique_col: str,
        max_iterations: int = MAX_ITERATIONS,
        tolerance: float = TOLERANCE,
    ):
        """
        Args:
            unique_col: row identifier column carried through to the output.
            max_iterations: maximum number of full IPF passes over all variables.
            tolerance: convergence threshold on the largest absolute gap between
                an actual and a target share.
        """
        self.unique_col = unique_col
        self.max_iterations = max_iterations
        self.tolerance = tolerance

        # Raking variables, adjusted in this order within every IPF iteration.
        self.variables = [
            "cd_gender",
            "birth_bucket",
            "country_bucket",
        ]

        # Target share per category; the shares of each variable must sum to 1.
        self.targets = {
            "cd_gender": {"M": 0.48, "F": 0.50, "U": 0.02},
            # "cd_credit_rating": {
            #     "Excellent": 0.25,
            #     "Good": 0.35,
            #     "Fair": 0.25,
            #     "Poor": 0.15,
            # },
            "birth_bucket": {
                "1900_1945": 0.10,
                "1946_1964": 0.30,
                "1965_1980": 0.25,
                "1981_1996": 0.25,
                "1997_plus": 0.10,
            },
            "country_bucket": {
                "United States": 0.70,
                "Canada": 0.15,
                "United Kingdom": 0.10,
                "Other": 0.05,
            },
        }

    def log(self, msg: str):
        """Print ``msg`` prefixed with the current local timestamp and a ``[RIM]`` tag."""
        print(f"[{datetime.now()}][RIM] {msg}")

    def get_target_map(self, var: str):
        """Return a copy of the ``{category: target share}`` map for ``var``.

        A copy is returned so callers cannot mutate ``self.targets`` by accident.

        Raises:
            KeyError: if ``var`` has no target map.
        """
        return dict(self.targets[var])

    def validate_targets(self) -> None:
        """Check that every raking variable has non-negative target shares summing to 1.

        Raises:
            ValueError: if a variable in ``self.variables`` has no target map,
                has a negative share, or its shares do not sum to 1.
        """
        for var in self.variables:
            if var not in self.targets:
                raise ValueError(f"No target shares defined for variable {var!r}")
            shares = list(self.targets[var].values())
            if any(share < 0 for share in shares):
                raise ValueError(f"Negative target share for variable {var!r}")
            total = sum(shares)
            if abs(total - 1.0) > self._TARGET_SUM_TOLERANCE:
                raise ValueError(
                    f"Target shares for {var!r} sum to {total}, expected 1.0"
                )

    def run_ipf(self, df):
        """Engine-specific IPF implementation; must be overridden by subclasses."""
        raise NotImplementedError("run_ipf must be implemented in subclasses.")

    def run(self, df):
        """Validate the targets, then run :meth:`run_ipf` between start/finish log lines.

        Returns:
            Whatever the subclass's :meth:`run_ipf` returns.
        """
        # Fail fast: IPF cannot converge towards shares that do not sum to 1.
        self.validate_targets()
        self.log("Старт RIM...")  # "Starting RIM..."
        df_out = self.run_ipf(df)
        self.log("Завершено RIM.")  # "RIM finished."
        return df_out

# Spark RIM

In [0]:
class RimWeightSparkTPCDS(RimWeightTPCDSBase):
    """RIM (raking / IPF) weighting computed with Spark DataFrames.

    Per-category weight sums are aggregated on the cluster, correction factors
    are derived on the driver, and they are applied back as a CASE WHEN
    expression. The DataFrame lineage is checkpointed after every full iteration.
    """

    def __init__(self, unique_col="c_customer_sk", **kwargs):
        super().__init__(unique_col, **kwargs)
        # Required by DataFrame.checkpoint() in run_ipf; uses the notebook-global `spark`.
        spark.sparkContext.setCheckpointDir("/tmp/checkpoints")

    def prepare(self, df: DataFrame) -> DataFrame:
        """Add the raking buckets, keep only the needed columns and start every weight at 1.0.

        Args:
            df: joined customer / demographics / address DataFrame; must contain
                ``self.unique_col``, ``cd_gender``, ``c_birth_year`` and ``ca_country``.

        Returns:
            DataFrame with ``self.unique_col``, ``cd_gender``, ``birth_bucket``,
            ``country_bucket`` and ``weight``.
        """
        self.log("Spark: bucketization + select")

        birth_year = F.col("c_birth_year")
        # NOTE(review): a NULL c_birth_year fails every comparison and lands in
        # "1997_plus" — confirm that unknown birth years belong there.
        df = df.withColumn(
            "birth_bucket",
            F.when(birth_year <= 1945, "1900_1945")
            .when(birth_year <= 1964, "1946_1964")
            .when(birth_year <= 1980, "1965_1980")
            .when(birth_year <= 1996, "1981_1996")
            .otherwise("1997_plus"),
        )

        # Any country outside the explicit list (NULL included) becomes "Other".
        df = df.withColumn(
            "country_bucket",
            F.when(
                F.col("ca_country").isin("United States", "Canada", "United Kingdom"),
                F.col("ca_country"),
            ).otherwise("Other"),
        )

        df = df.select(self.unique_col, "cd_gender", "birth_bucket", "country_bucket")
        return df.withColumn("weight", F.lit(1.0))

    def build_broadcast_target_map(self, spark, var):
        """Return a broadcast of ``var``'s target map (category -> target share).

        Kept for backward compatibility. run_ipf reads the targets on the driver
        only, so it no longer creates broadcasts.
        """
        mp = self.get_target_map(var)
        return spark.sparkContext.broadcast(mp)

    def run_ipf(self, df: DataFrame):
        """Run IPF and return the prepared DataFrame with a raked ``weight`` column.

        For each variable, every row's weight is multiplied by ``target / actual``
        for its category. The weights are then rescaled so they sum to the row
        count. NULL categories and categories without a target keep factor 1.0.
        Iteration stops early once the largest |actual - target| share gap seen
        in an iteration (measured before each variable's adjustment) is within
        ``self.tolerance``.

        Raises:
            ValueError: if an adjustment would drive the total weight to zero.
        """
        # DataFrame.sparkSession (Spark 3.3+) replaces the deprecated sql_ctx accessor.
        spark = df.sparkSession

        spark.conf.set("spark.sql.adaptive.enabled", "true")
        spark.conf.set("spark.sql.shuffle.partitions", spark.sparkContext.defaultParallelism)

        df = self.prepare(df)
        df = df.repartition(spark.sparkContext.defaultParallelism)

        total_units = df.count()
        if total_units == 0:
            # Nothing to rake; also avoids dividing by a zero total below.
            self.log("Spark: empty input, skipping IPF")
            return df

        for it in range(self.max_iterations):
            self.log(f"--- Spark IPF iteration {it+1} ---")
            max_diff = 0.0

            for var in self.variables:
                self.log(f"processing variable: {var}")
                target_map = self.get_target_map(var)

                # The only Spark job per variable: weight sum per category
                # (a NULL category appears under the key None).
                w_sums = {
                    row[var]: row["w_sum"]
                    for row in df.groupBy(var).agg(F.sum("weight").alias("w_sum")).collect()
                }
                total_w = sum(w_sums.values())

                factor_map = {}
                for key, w_sum in w_sums.items():
                    actual_share = w_sum / total_w
                    # Categories without a target count as already on target.
                    target_share = target_map.get(key, actual_share)
                    max_diff = max(max_diff, abs(actual_share - target_share))
                    factor_map[key] = target_share / actual_share if actual_share > 0 else 1.0

                # Weight total after the adjustment, derived from the per-category
                # sums instead of running a second aggregation over the data.
                # NULL rows are pinned to factor 1.0 by the expression below.
                new_total = sum(
                    w_sum * (1.0 if key is None else factor_map[key])
                    for key, w_sum in w_sums.items()
                )
                if new_total <= 0:
                    raise ValueError(f"Adjusting {var!r} would make the total weight zero")
                scale = total_units / new_total

                factor_expr = F.when(F.col(var).isNull(), F.lit(1.0))
                for key, factor in factor_map.items():
                    if key is not None:
                        factor_expr = factor_expr.when(F.col(var) == F.lit(key), F.lit(factor))
                factor_expr = factor_expr.otherwise(F.lit(1.0))

                # Adjust and rescale in one projection so the weights keep summing to total_units.
                df = df.withColumn("weight", F.col("weight") * factor_expr * F.lit(scale))

            self.log(f"iteration max_diff={max_diff}")

            # Truncate the lineage that grows with every adjustment.
            df = df.checkpoint(eager=True)

            if max_diff <= self.tolerance:
                self.log("Converged.")
                break

        return df

In [0]:
# Source tables from the Databricks TPC-DS sample catalog.
# NOTE(review): limit() without orderBy() keeps an arbitrary subset of customers,
# so the rows weighted here need not match the ones used by the other engines.
customer = spark.table("samples.tpcds_sf1000.customer").limit(11579572)
demo = spark.table("samples.tpcds_sf1000.customer_demographics")
addr = spark.table("samples.tpcds_sf1000.customer_address")

# Attach the current demographics and address record to every customer.
df = customer.join(
    demo, on=customer.c_current_cdemo_sk == demo.cd_demo_sk, how="inner"
).join(
    addr, on=customer.c_current_addr_sk == addr.ca_address_sk, how="inner"
)

rim = RimWeightSparkTPCDS()
weighted = rim.run(df)

print(weighted.count())
display(weighted.limit(10))

# Polars RIM

In [0]:
class RimWeightPolarsTPCDS(RimWeightTPCDSBase):
    """RIM (raking / IPF) weighting computed in memory with Polars."""

    def __init__(self, unique_col="c_customer_sk", **kwargs):
        super().__init__(unique_col, **kwargs)

    def prepare(self, df: pl.DataFrame) -> pl.DataFrame:
        """Add the raking buckets, keep only the needed columns and start every weight at 1.0.

        Args:
            df: joined customer / demographics / address frame.

        Returns:
            Frame with ``self.unique_col``, ``cd_gender``, ``birth_bucket``,
            ``country_bucket`` and ``weight``.

        Raises:
            ValueError: if a column read here (including ``self.unique_col``) is missing.
        """
        self.log("Polars: bucketization + select")

        required = (self.unique_col, "cd_gender", "c_birth_year", "ca_country")
        missing = [c for c in required if c not in df.columns]
        if missing:
            raise ValueError(f"Missing required columns: {missing}")

        birth_year = pl.col("c_birth_year")
        df = df.with_columns([
            # NOTE(review): a null c_birth_year matches no branch and lands in
            # "1997_plus" — confirm that unknown birth years belong there.
            pl.when(birth_year <= 1945).then(pl.lit("1900_1945"))
            .when(birth_year <= 1964).then(pl.lit("1946_1964"))
            .when(birth_year <= 1980).then(pl.lit("1965_1980"))
            .when(birth_year <= 1996).then(pl.lit("1981_1996"))
            .otherwise(pl.lit("1997_plus"))
            .alias("birth_bucket"),
            # Any country outside the explicit list (null included) becomes "Other".
            pl.when(pl.col("ca_country").is_in(["United States", "Canada", "United Kingdom"]))
            .then(pl.col("ca_country"))
            .otherwise(pl.lit("Other"))
            .alias("country_bucket"),
            pl.lit(1.0).alias("weight"),
        ])

        return df.select(
            self.unique_col,
            "cd_gender",
            # "cd_credit_rating",
            "birth_bucket",
            "country_bucket",
            "weight",
        )

    def run_ipf(self, df: pl.DataFrame):
        """Run IPF and return the prepared frame with a raked ``weight`` column.

        For each variable, every row's weight is multiplied by ``target / actual``
        for its category. The weights are then rescaled so they sum to the row
        count. Null categories, categories without a target and targets of 0 keep
        factor 1.0. Iteration stops early once the largest |actual - target| share
        gap seen in an iteration is within ``self.tolerance``.
        """
        df = self.prepare(df)
        total_units = df.height
        if total_units == 0:
            # Nothing to rake; also avoids dividing by a zero total below.
            self.log("Polars: empty input, skipping IPF")
            return df

        for it in range(self.max_iterations):
            self.log(f"--- Polars IPF iteration {it+1} ---")
            max_diff = 0.0

            for var in self.variables:
                self.log(f"processing variable: {var}")

                mp = self.get_target_map(var)
                target_df = pl.DataFrame({
                    var: list(mp.keys()),
                    "target_share": list(mp.values()),
                })

                total_w = df["weight"].sum()
                actual = (
                    df.group_by(var)
                    .agg(pl.sum("weight").alias("w_sum"))
                    .with_columns((pl.col("w_sum") / total_w).alias("actual_share"))
                )

                # Fill first, then measure the gap: categories without a target count as on target.
                adj = (
                    actual.join(target_df, on=var, how="left")
                    .with_columns(pl.col("target_share").fill_null(pl.col("actual_share")))
                    .with_columns(
                        (pl.col("target_share") - pl.col("actual_share")).abs().alias("diff")
                    )
                )

                diff = adj["diff"].max()
                # No rounding: the comparison against self.tolerance must use the raw gap.
                if diff is not None:
                    max_diff = max(max_diff, diff)

                factor_df = adj.select(
                    var,
                    pl.when((pl.col("actual_share") > 0) & (pl.col("target_share") > 0))
                    .then(pl.col("target_share") / pl.col("actual_share"))
                    .otherwise(1.0)
                    .alias("factor"),
                )

                # Null keys never match in the join; keep those weights unchanged
                # (factor 1.0) instead of turning them into null weights.
                df = (
                    df.join(factor_df, on=var, how="left")
                    .with_columns(
                        (pl.col("weight") * pl.col("factor").fill_null(1.0)).alias("weight")
                    )
                    .drop("factor")
                )

                new_total = df["weight"].sum()
                df = df.with_columns((pl.col("weight") / new_total * total_units).alias("weight"))

            self.log(f"iteration max_diff={max_diff}")
            if max_diff <= self.tolerance:
                self.log("Converged.")
                break

        return df

In [0]:
# Read the Delta snapshots written earlier as raw parquet files via the /dbfs mount.
# NOTE(review): .limit() keeps the first rows in file order, while the Spark cells use
# an unordered limit — the customer subsets of the engines may differ.
customer = pl.read_parquet("/dbfs/tmp/tpcds_sf1000/customer/*.parquet").limit(11579572)
demo = pl.read_parquet("/dbfs/tmp/tpcds_sf1000/customer_demographics/*.parquet")
addr = pl.read_parquet("/dbfs/tmp/tpcds_sf1000/customer_address/*.parquet")

df = (
    customer
    .join(demo, left_on="c_current_cdemo_sk", right_on="cd_demo_sk", how="inner")
    .join(addr, left_on="c_current_addr_sk", right_on="ca_address_sk", how="inner")
)

rim = RimWeightPolarsTPCDS()
result = rim.run(df)

# In Polars, DataFrame.count() returns per-column non-null counts; .height is the row count.
print(result.height)
print(result.head(10))

# Hybrid RIM

In [0]:
class RimWeightHybridTPCDS(RimWeightTPCDSBase):
    """RIM weighting that buckets in Spark, rakes in Polars on the driver, and returns Spark.

    The prepared (narrow) Spark DataFrame is collected to the driver through
    Arrow, so it has to fit in driver memory.
    """

    def __init__(self, unique_col="c_customer_sk", **kwargs):
        super().__init__(unique_col, **kwargs)

    @staticmethod
    def spark_to_polars(df_spark: DataFrame) -> pl.DataFrame:
        """Collect ``df_spark`` to the driver as Arrow record batches and wrap them in Polars.

        NOTE(review): ``_collect_as_arrow`` is a private PySpark API. For an empty
        result it may return no batches, and then ``pa.Table.from_batches`` raises
        because no schema is passed. On Spark 4+, the public ``DataFrame.toArrow()``
        could replace it — TODO confirm the runtime version.
        """
        batches = df_spark._collect_as_arrow()
        return pl.from_arrow(pa.Table.from_batches(batches))

    @staticmethod
    def polars_to_spark(spark, df_polars: pl.DataFrame) -> DataFrame:
        """Create a Spark DataFrame from ``df_polars`` by way of a pandas DataFrame."""
        return spark.createDataFrame(df_polars.to_pandas())

    def prepare_spark(self, df: DataFrame) -> DataFrame:
        """Add the raking buckets in Spark and keep only the columns the Polars IPF needs.

        Returns:
            DataFrame with ``self.unique_col``, ``cd_gender``, ``birth_bucket`` and
            ``country_bucket`` (no weight column; run_ipf_polars adds it).
        """
        self.log("Hybrid: Spark bucketization")

        birth_year = F.col("c_birth_year")
        # NOTE(review): a NULL c_birth_year fails every comparison and lands in
        # "1997_plus" — confirm that unknown birth years belong there.
        df = df.withColumn(
            "birth_bucket",
            F.when(birth_year <= 1945, "1900_1945")
            .when(birth_year <= 1964, "1946_1964")
            .when(birth_year <= 1980, "1965_1980")
            .when(birth_year <= 1996, "1981_1996")
            .otherwise("1997_plus"),
        )

        # Any country outside the explicit list (NULL included) becomes "Other".
        df = df.withColumn(
            "country_bucket",
            F.when(
                F.col("ca_country").isin("United States", "Canada", "United Kingdom"),
                F.col("ca_country"),
            ).otherwise("Other"),
        )

        return df.select(
            self.unique_col,
            "cd_gender",
            # "cd_credit_rating",
            "birth_bucket",
            "country_bucket",
        )

    def run_ipf_polars(self, df_pl: pl.DataFrame) -> pl.DataFrame:
        """Run IPF on ``df_pl`` in Polars and return it with a raked ``weight`` column.

        Weights start at 1.0. For each variable, every row's weight is multiplied
        by ``target / actual`` for its category and then rescaled so the weights
        sum to the row count. Null categories, categories without a target and
        targets of 0 keep factor 1.0.
        """
        self.log("Hybrid: Polars IPF start")
        df_pl = df_pl.with_columns(pl.lit(1.0).alias("weight"))
        total_units = df_pl.height
        if total_units == 0:
            # Nothing to rake; also avoids dividing by a zero total below.
            self.log("Hybrid: empty input, skipping IPF")
            return df_pl

        for it in range(self.max_iterations):
            self.log(f"--- Polars IPF iteration {it+1} ---")
            max_diff = 0.0

            for var in self.variables:
                self.log(f"processing var: {var}")

                mp = self.get_target_map(var)
                target_df = pl.DataFrame({
                    var: list(mp.keys()),
                    "target_share": list(mp.values()),
                })

                total_w = df_pl["weight"].sum()
                actual = (
                    df_pl.group_by(var)
                    .agg(pl.sum("weight").alias("w_sum"))
                    .with_columns((pl.col("w_sum") / total_w).alias("actual_share"))
                )

                # Fill first, then measure the gap: categories without a target count as on target.
                adj = (
                    actual.join(target_df, on=var, how="left")
                    .with_columns(pl.col("target_share").fill_null(pl.col("actual_share")))
                    .with_columns(
                        (pl.col("target_share") - pl.col("actual_share")).abs().alias("diff")
                    )
                )

                diff = adj["diff"].max()
                if diff is not None:
                    max_diff = max(max_diff, diff)

                factor_df = adj.select(
                    var,
                    pl.when((pl.col("actual_share") > 0) & (pl.col("target_share") > 0))
                    .then(pl.col("target_share") / pl.col("actual_share"))
                    .otherwise(1.0)
                    .alias("factor"),
                )

                # Null keys never match in the join; keep those weights unchanged
                # (factor 1.0) instead of turning them into null weights.
                df_pl = (
                    df_pl.join(factor_df, on=var, how="left")
                    .with_columns(
                        (pl.col("weight") * pl.col("factor").fill_null(1.0)).alias("weight")
                    )
                    .drop("factor")
                )

                new_total = df_pl["weight"].sum()
                df_pl = df_pl.with_columns(
                    (pl.col("weight") / new_total * total_units).alias("weight")
                )

            self.log(f"iteration max_diff={max_diff}")
            if max_diff <= self.tolerance:
                self.log("Polars converged.")
                break

        return df_pl

    def run_ipf(self, df_spark: DataFrame):
        """Bucket in Spark, rake in Polars, and return the weighted rows as a Spark DataFrame."""
        # DataFrame.sparkSession (Spark 3.3+) replaces the deprecated sql_ctx accessor.
        spark = df_spark.sparkSession

        df_prep = self.prepare_spark(df_spark)
        df_pl = self.spark_to_polars(df_prep)
        df_pl_weighted = self.run_ipf_polars(df_pl)
        return self.polars_to_spark(spark, df_pl_weighted)

In [0]:
# Source tables from the Databricks TPC-DS sample catalog.
# NOTE(review): limit() without orderBy() keeps an arbitrary subset of customers,
# so the rows weighted here need not match the ones used by the other engines.
customer = spark.table("samples.tpcds_sf1000.customer").limit(11579572)
demo = spark.table("samples.tpcds_sf1000.customer_demographics")
addr = spark.table("samples.tpcds_sf1000.customer_address")

# Attach the current demographics and address record to every customer.
df = customer.join(
    demo, on=customer.c_current_cdemo_sk == demo.cd_demo_sk, how="inner"
).join(
    addr, on=customer.c_current_addr_sk == addr.ca_address_sk, how="inner"
)

rim = RimWeightHybridTPCDS()
result = rim.run(df)
print(result.count())
display(result.limit(10))