In [0]:
# Install Polars into this notebook's Python environment (%pip, unlike !pip, targets the running kernel).
# NOTE(review): unpinned — pin a version (polars==X.Y.Z) if runs must be reproducible.
%pip install polars

In [0]:
from abc import ABC, abstractmethod
from datetime import datetime

import polars as pl
import pyarrow as pa
from polars import DataFrame as PolarsDataFrame
from pyspark.sql import functions as f, DataFrame as SparkDataFrame, SparkSession

# Raking (IPF) settings shared by the Spark, Polars and Hybrid implementations below.
MAX_ITERATIONS = 20      # hard cap on full passes over all raking variables
TOLERANCE = 1e-4         # converged once the largest |actual share - target share| is <= this
CUSTOMER_COUNT = 10      # number of customer rows pulled for the demo runs

# The shared base class accepts either engine's DataFrame.
SPARK_OR_POLARS_DF = SparkDataFrame | PolarsDataFrame

# Reuse the notebook's active session when there is one; getActiveSession() returns None
# otherwise, which would leave `spark` unusable in every later cell, so fall back to the builder.
spark = SparkSession.getActiveSession() or SparkSession.builder.getOrCreate()

In [0]:
class RimWeightTPCDSBase(ABC):
    """Engine-agnostic RIM (raking) configuration and driver for TPC-DS customers.

    Concrete subclasses implement `prepare` (derive the bucket columns used for
    raking) and `run_ipf` (the iterative proportional fitting loop); `run` wraps
    `run_ipf` with start/finish log lines.
    """

    def __init__(
            self,
            unique_col: str,
            max_iterations: int = MAX_ITERATIONS,
            tolerance: float = TOLERANCE,
    ):
        """Store the run settings and the target marginal distributions.

        Args:
            unique_col: name of the row identifier column carried through to the output.
            max_iterations: maximum number of raking passes a subclass performs.
            tolerance: a subclass stops once its largest |actual - target| share gap
                is at or below this value.
        """
        self.unique_col = unique_col
        self.max_iterations = max_iterations
        self.tolerance = tolerance

        # Target share per category, keyed by raking variable. Each inner mapping sums to 1.
        self.targets = dict(
            cd_gender={"M": 0.48, "F": 0.50, "U": 0.02},
            birth_bucket={
                "1900_1945": 0.10,
                "1946_1964": 0.30,
                "1965_1980": 0.25,
                "1981_1996": 0.25,
                "1997_plus": 0.10,
            },
            country_bucket={
                "United States": 0.70,
                "Canada": 0.15,
                "United Kingdom": 0.10,
                "Other": 0.05,
            },
        )

        # Variables are raked in this order: the insertion order of `targets`.
        self.variables = list(self.targets)

    @staticmethod
    def log(msg: str):
        """Print `msg` prefixed with the current local timestamp and a [RIM] tag."""
        stamp = datetime.now()
        print("[{}][RIM] {}".format(stamp, msg))

    def get_target_map(self, var: str):
        """Return the {category: target share} mapping for raking variable `var` (KeyError if unknown)."""
        targets_for_var = self.targets[var]
        return targets_for_var

    @abstractmethod
    def prepare(self, raw_df: SPARK_OR_POLARS_DF) -> SPARK_OR_POLARS_DF:
        """Derive the columns named in `self.variables` from `raw_df` (engine-specific)."""
        raise NotImplementedError("`prepare` method must be implemented in subclasses.")

    @abstractmethod
    def run_ipf(self, raw_df: SPARK_OR_POLARS_DF) -> SPARK_OR_POLARS_DF:
        """Run the raking loop and return the frame carrying the final weights."""
        raise NotImplementedError("`run_ipf` method must be implemented in subclasses.")

    def run(self, raw_df: SPARK_OR_POLARS_DF) -> SPARK_OR_POLARS_DF:
        """Execute `run_ipf` between start/finish log lines and return its result."""
        self.log("Старт RIM...")  # "Starting RIM..."
        weighted = self.run_ipf(raw_df)
        self.log("Завершено RIM.")  # "RIM finished."
        return weighted

# Spark RIM

In [0]:
from pyspark import Broadcast


class RimWeightSparkTPCDS(RimWeightTPCDSBase):
    """RIM weighting where bucketing and every raking step run as Spark jobs.

    NOTE(review): the constructor changes settings on the shared SparkSession
    (checkpoint dir, AQE, shuffle partitions), so it affects other queries in
    the same session.
    """

    def __init__(self, spark: SparkSession, unique_col="c_customer_sk", **kwargs):
        """
        Args:
            spark: session used for configuration, checkpointing and broadcasts.
            unique_col: customer identifier column carried to the output.
            **kwargs: forwarded to RimWeightTPCDSBase (max_iterations, tolerance).
        """
        super().__init__(unique_col, **kwargs)
        self._spark = spark

        # run_ipf checkpoints once per iteration to truncate lineage; that requires a checkpoint dir.
        self._spark.sparkContext.setCheckpointDir("/tmp/checkpoints")
        self._spark.conf.set("spark.sql.adaptive.enabled", "true")
        self._spark.conf.set("spark.sql.shuffle.partitions", self._spark.sparkContext.defaultParallelism)

    def prepare(self, raw_df: SparkDataFrame) -> SparkDataFrame:
        """Bucket birth year and country, and start every row at weight 1.0.

        Returns:
            Columns: unique_col, cd_gender, birth_bucket, country_bucket, weight.
        """
        self.log("Spark: bucketization + select")

        # NOTE(review): a NULL c_birth_year satisfies none of the `when` conditions and falls
        # through to "1997_plus"; a NULL ca_country becomes "Other". Confirm this is intended.
        raw_df = raw_df.withColumn(
            "birth_bucket",
            f.when(f.col("c_birth_year") <= 1945, "1900_1945")
            .when(f.col("c_birth_year") <= 1964, "1946_1964")
            .when(f.col("c_birth_year") <= 1980, "1965_1980")
            .when(f.col("c_birth_year") <= 1996, "1981_1996")
            .otherwise("1997_plus")
        )

        raw_df = raw_df.withColumn(
            "country_bucket",
            f.when(
                f.col("ca_country").isin("United States", "Canada", "United Kingdom"),
                f.col("ca_country")
            ).otherwise("Other")
        )

        return raw_df.select(
            self.unique_col,
            "cd_gender",
            "birth_bucket",
            "country_bucket",
            f.lit(1.0).alias("weight")
        )

    def _build_broadcast_target_map(self, var: str) -> Broadcast:
        """Returns broadcast dict: value -> target_share

        Not used by run_ipf, which only reads the targets on the driver; kept so
        existing callers keep working.
        """

        mp = self.get_target_map(var)
        return self._spark.sparkContext.broadcast(mp)

    def run_ipf(self, raw_df: SparkDataFrame) -> SparkDataFrame:
        """Rake `weight` so each variable's weighted shares approach its targets.

        Each iteration processes the variables in order: compute every category's
        weighted share, multiply weights by target/actual, then rescale so the
        weights sum to the row count again. Stops after `max_iterations` or once the
        largest pre-adjustment gap in an iteration is <= `tolerance`.

        Returns:
            The prepared frame (unique_col, bucket columns, weight), checkpointed.
        """
        prepared_df = self.prepare(raw_df).repartition(self._spark.sparkContext.defaultParallelism)
        total_units = prepared_df.count()
        if total_units == 0:
            # Nothing to weight; also avoids dividing by a NULL total further down.
            return prepared_df

        for it in range(self.max_iterations):
            self.log(f"--- Spark IPF iteration {it + 1} ---")
            max_diff = 0.0

            for var in self.variables:
                self.log(f"processing variable: {var}")

                # Weights start at 1.0 and are rescaled to sum to `total_units` after every
                # adjustment, so the current total is total_units (up to float rounding);
                # no separate aggregation job is needed to find it.
                actual = (
                    prepared_df.groupBy(var)
                    .agg(f.sum("weight").alias("w_sum"))
                    .withColumn("actual_share", f.col("w_sum") / total_units)
                )

                actual_local = {
                    r[var]: r["actual_share"]
                    for r in actual.collect()
                }

                # Only the driver needs the targets, so read them directly instead of via a broadcast.
                target_map = self.get_target_map(var)

                # Categories without a target (including NULL) count as on-target and keep factor 1.
                for key, actual_val in actual_local.items():
                    target_val = target_map.get(key, actual_val)
                    max_diff = max(max_diff, abs(actual_val - target_val))

                factor_map = {
                    k: (target_map.get(k, v) / v) if v > 0 else 1.0
                    for k, v in actual_local.items()
                }

                # NULL never matches `==`, so NULL categories get an explicit factor of 1.0.
                factor_expr = f.when(f.col(var).isNull(), f.lit(1.0))
                for k, v in factor_map.items():
                    factor_expr = factor_expr.when(f.col(var) == f.lit(k), f.lit(v))
                factor_expr = factor_expr.otherwise(f.lit(1.0))

                prepared_df = prepared_df.withColumn(
                    "weight",
                    f.col("weight") * factor_expr
                )

                # Rescale so the weights sum to the number of rows again.
                new_total = prepared_df.agg(f.sum("weight")).first()[0]
                prepared_df = prepared_df.withColumn("weight", f.col("weight") / new_total * total_units)

            self.log(f"iteration max_diff={max_diff}")

            # Truncate lineage so later iterations don't re-plan every earlier withColumn.
            prepared_df = prepared_df.checkpoint(eager=True)

            if max_diff <= self.tolerance:
                self.log("Converged.")
                break

        return prepared_df

In [0]:
# Load the three TPC-DS source tables. No ordering is applied before `limit`, so which
# CUSTOMER_COUNT customer rows are kept is not guaranteed.
customer, demo, addr = (
    spark.table(f"samples.tpcds_sf1000.{name}")
    for name in ("customer", "customer_demographics", "customer_address")
)
customer = customer.limit(CUSTOMER_COUNT)

# Attach each customer's current demographics and address (inner joins).
df = customer.join(
    demo, customer.c_current_cdemo_sk == demo.cd_demo_sk
).join(
    addr, customer.c_current_addr_sk == addr.ca_address_sk
)

rim = RimWeightSparkTPCDS(spark)
weighted = rim.run(df)

print(weighted.count())
display(weighted.limit(10))

# Polars RIM

In [None]:
# Copy the three TPC-DS source tables to /tmp/tpcds_sf1000/<table> as Delta; presumably these
# are the files the Polars cells read back through /dbfs/tmp/tpcds_sf1000/... — verify paths.
# NOTE(review): overwriting a Delta table keeps the previous version's data files on disk until
# VACUUM, so readers that glob the raw *.parquet files can pick up stale rows after a re-run.
for _table in ("customer", "customer_demographics", "customer_address"):
    (
        spark.table(f"samples.tpcds_sf1000.{_table}")
        .write.format("delta")
        .mode("overwrite")
        .save(f"/tmp/tpcds_sf1000/{_table}")
    )

In [0]:
class RimWeightPolarsTPCDS(RimWeightTPCDSBase):
    """RIM weighting implemented entirely with in-memory Polars DataFrames."""

    def __init__(self, unique_col="c_customer_sk", **kwargs):
        """
        Args:
            unique_col: customer identifier column carried to the output.
            **kwargs: forwarded to RimWeightTPCDSBase (max_iterations, tolerance).
        """
        super().__init__(unique_col, **kwargs)

    def prepare(self, raw_df: PolarsDataFrame) -> PolarsDataFrame:
        """Bucket birth year and country, and start every row at weight 1.0.

        Returns:
            Columns: unique_col, cd_gender, birth_bucket, country_bucket, weight.

        Raises:
            ValueError: if a column needed for bucketing or for the output is missing.
        """
        self.log("Polars: bucketization + select")

        # Validate everything (including the id column) up front so bad input fails with one clear message.
        needed = ["cd_gender", "c_birth_year", "ca_country", self.unique_col]
        missing = [c for c in needed if c not in raw_df.columns]
        if missing:
            raise ValueError(f"Missing required columns: {missing}")

        # NOTE(review): a NULL c_birth_year matches no `when` branch and falls through to
        # "1997_plus"; a NULL ca_country becomes "Other". Confirm this is intended.
        raw_df = raw_df.with_columns([
            pl.when(pl.col("c_birth_year") <= 1945).then(pl.lit("1900_1945"))
            .when(pl.col("c_birth_year") <= 1964).then(pl.lit("1946_1964"))
            .when(pl.col("c_birth_year") <= 1980).then(pl.lit("1965_1980"))
            .when(pl.col("c_birth_year") <= 1996).then(pl.lit("1981_1996"))
            .otherwise(pl.lit("1997_plus")).alias("birth_bucket"),
            pl.when(pl.col("ca_country").is_in(["United States", "Canada", "United Kingdom"]))
            .then(pl.col("ca_country"))
            .otherwise(pl.lit("Other")).alias("country_bucket"),
            pl.lit(1.0).alias("weight")
        ])

        return raw_df.select(
            self.unique_col,
            "cd_gender",
            "birth_bucket",
            "country_bucket",
            "weight",
        )

    def run_ipf(self, raw_df: PolarsDataFrame) -> PolarsDataFrame:
        """Rake `weight` so each variable's weighted shares approach its targets.

        Each iteration processes the variables in order: compute every category's
        weighted share, multiply weights by target/actual, then rescale so the
        weights sum to the row count. Stops after `max_iterations` or once the
        largest pre-adjustment gap in an iteration is <= `tolerance`.
        """
        prepared_df = self.prepare(raw_df)
        if prepared_df.height == 0:
            # Nothing to weight (and every share below would be 0/0).
            return prepared_df

        for it in range(self.max_iterations):
            self.log(f"--- Polars IPF iteration {it + 1} ---")

            max_diff = 0.0
            total_w = prepared_df["weight"].sum()

            for var in self.variables:
                self.log(f"processing variable: {var}")

                mp = self.get_target_map(var)
                target_df = pl.DataFrame({
                    var: list(mp.keys()),
                    "target_share": list(mp.values())
                })

                actual = (
                    prepared_df.group_by(var)
                    .agg(pl.sum("weight").alias("w_sum"))
                    .with_columns((pl.col("w_sum") / total_w).alias("actual_share"))
                )

                # Categories without a target keep their current share. The fill needs its own
                # with_columns: expressions in a single call all read the *input* frame, so
                # "diff" would otherwise see the unfilled (null) target_share.
                adj = (
                    actual.join(target_df, on=var, how="left")
                    .with_columns(pl.col("target_share").fill_null(pl.col("actual_share")))
                    .with_columns((pl.col("target_share") - pl.col("actual_share")).abs().alias("diff"))
                )

                # Compare the unrounded gap: rounding to 4 dp let gaps up to ~1.5e-4 pass a 1e-4 tolerance.
                diff = adj["diff"].max()
                if diff is not None:
                    max_diff = max(max_diff, diff)

                factor_df = adj.with_columns([
                    pl.when((pl.col("actual_share") > 0) & (pl.col("target_share") > 0))
                    .then(pl.col("target_share") / pl.col("actual_share"))
                    .otherwise(1.0)
                    .alias("factor")
                ]).select(var, "factor")

                # Null join keys never match, so rows with a NULL category get a null factor;
                # treat it as 1.0 (keep their weight), matching the Spark implementation.
                prepared_df = (
                    prepared_df.join(factor_df, on=var, how="left")
                    .with_columns((pl.col("weight") * pl.col("factor").fill_null(1.0)).alias("weight"))
                    .drop("factor")
                )

                # Rescale so the weights sum to the number of rows again.
                new_total = prepared_df["weight"].sum()
                prepared_df = prepared_df.with_columns((pl.col("weight") / new_total * prepared_df.height).alias("weight"))

            self.log(f"iteration max_diff={max_diff}")
            if max_diff <= self.tolerance:
                self.log("Converged.")
                break

        return prepared_df

In [0]:
# NOTE(review): these globs read a Delta table's data files directly, bypassing its transaction
# log. After the export cell is re-run with mode("overwrite"), files from the previous version
# stay on disk until VACUUM and would be read too (duplicate rows). Reading through the log,
# e.g. pl.read_delta (requires the `deltalake` package), avoids that.
customer = pl.read_parquet("/dbfs/tmp/tpcds_sf1000/customer/*.parquet").limit(CUSTOMER_COUNT)
demo = pl.read_parquet("/dbfs/tmp/tpcds_sf1000/customer_demographics/*.parquet")
addr = pl.read_parquet("/dbfs/tmp/tpcds_sf1000/customer_address/*.parquet")

df = (
    customer
    .join(demo, left_on="c_current_cdemo_sk", right_on="cd_demo_sk", how="inner")
    .join(addr, left_on="c_current_addr_sk", right_on="ca_address_sk", how="inner")
)

rim = RimWeightPolarsTPCDS()
result = rim.run(df)

# Polars' DataFrame.count() returns per-column non-null counts (a DataFrame), not a row count
# like Spark's count(); `height` is the number of rows.
print(result.height)
print(result)

# Hybrid RIM

In [0]:
class RimWeightHybridTPCDS(RimWeightTPCDSBase):
    """RIM weighting that buckets in Spark, rakes in Polars, and returns a Spark DataFrame.

    The narrow prepared Spark frame is collected to the driver as Arrow, raked in
    memory with Polars, then converted back with `createDataFrame`.
    """

    def __init__(self, spark: SparkSession, unique_col="c_customer_sk", **kwargs):
        """
        Args:
            spark: session used to turn the Polars result back into a Spark DataFrame.
            unique_col: customer identifier column carried to the output.
            **kwargs: forwarded to RimWeightTPCDSBase (max_iterations, tolerance).
        """
        super().__init__(unique_col, **kwargs)
        self._spark = spark

    @staticmethod
    def spark_to_polars(df_spark: SparkDataFrame) -> PolarsDataFrame:
        """Collect `df_spark` to the driver as Arrow record batches and wrap them in Polars.

        NOTE(review): uses the private `_collect_as_arrow` API, and `pa.Table.from_batches`
        needs at least one batch when no schema is given — confirm behaviour for empty input.
        """
        tbl = pa.Table.from_batches(df_spark._collect_as_arrow())
        return pl.from_arrow(tbl)

    def polars_to_spark(self, df_polars: PolarsDataFrame) -> SparkDataFrame:
        """Convert a Polars frame to Spark via an intermediate pandas DataFrame."""
        return self._spark.createDataFrame(df_polars.to_pandas())

    def prepare(self, raw_df: SparkDataFrame) -> SparkDataFrame:
        """Bucket birth year and country in Spark (no weight column; run_ipf adds it).

        Returns:
            Columns: unique_col, cd_gender, birth_bucket, country_bucket.
        """
        self.log("Hybrid: Spark bucketization")

        # NOTE(review): a NULL c_birth_year satisfies none of the `when` conditions and falls
        # through to "1997_plus"; a NULL ca_country becomes "Other". Confirm this is intended.
        raw_df = raw_df.withColumn(
            "birth_bucket",
            f.when(f.col("c_birth_year") <= 1945, "1900_1945")
            .when(f.col("c_birth_year") <= 1964, "1946_1964")
            .when(f.col("c_birth_year") <= 1980, "1965_1980")
            .when(f.col("c_birth_year") <= 1996, "1981_1996")
            .otherwise("1997_plus")
        )

        raw_df = raw_df.withColumn(
            "country_bucket",
            f.when(
                f.col("ca_country").isin("United States", "Canada", "United Kingdom"),
                f.col("ca_country")
            ).otherwise("Other")
        )

        return raw_df.select(
            self.unique_col,
            "cd_gender",
            "birth_bucket",
            "country_bucket"
        )

    def run_ipf(self, raw_df: PolarsDataFrame) -> PolarsDataFrame:
        """Add weight 1.0 per row, then rake it in Polars toward the target shares.

        Each iteration processes the variables in order: compute every category's
        weighted share, multiply weights by target/actual, then rescale so the
        weights sum to the row count. Stops after `max_iterations` or once the
        largest pre-adjustment gap in an iteration is <= `tolerance`.
        """
        self.log("Hybrid: Polars IPF start")
        df_pl = raw_df.with_columns(pl.lit(1.0).alias("weight"))
        if df_pl.height == 0:
            # Nothing to weight (and every share below would be 0/0).
            return df_pl

        for it in range(self.max_iterations):
            self.log(f"--- Polars IPF iteration {it + 1} ---")
            max_diff = 0.0
            total_w = df_pl["weight"].sum()

            for var in self.variables:
                self.log(f"processing var: {var}")

                mp = self.get_target_map(var)
                target_df = pl.DataFrame({
                    var: list(mp.keys()),
                    "target_share": list(mp.values())
                })

                actual = (
                    df_pl.group_by(var)
                    .agg(pl.sum("weight").alias("w_sum"))
                    .with_columns((pl.col("w_sum") / total_w).alias("actual_share"))
                )

                # Categories without a target keep their current share. The fill needs its own
                # with_columns: expressions in a single call all read the *input* frame, so
                # "diff" would otherwise see the unfilled (null) target_share.
                adj = (
                    actual.join(target_df, on=var, how="left")
                    .with_columns(pl.col("target_share").fill_null(pl.col("actual_share")))
                    .with_columns((pl.col("target_share") - pl.col("actual_share")).abs().alias("diff"))
                )

                diff = adj["diff"].max()
                if diff is not None:
                    max_diff = max(max_diff, diff)

                factor_df = adj.with_columns([
                    pl.when((pl.col("actual_share") > 0) & (pl.col("target_share") > 0))
                    .then(pl.col("target_share") / pl.col("actual_share"))
                    .otherwise(1.0)
                    .alias("factor")
                ]).select(var, "factor")

                # Null join keys never match, so rows with a NULL category get a null factor;
                # treat it as 1.0 (keep their weight), matching the Spark implementation.
                df_pl = (
                    df_pl.join(factor_df, on=var, how="left")
                    .with_columns((pl.col("weight") * pl.col("factor").fill_null(1.0)).alias("weight"))
                    .drop("factor")
                )

                # Rescale so the weights sum to the number of rows again.
                new_total = df_pl["weight"].sum()
                df_pl = df_pl.with_columns((pl.col("weight") / new_total * df_pl.height).alias("weight"))

            self.log(f"iteration max_diff={max_diff}")
            if max_diff <= self.tolerance:
                self.log("Polars converged.")
                break

        return df_pl

    def run(self, df_spark: SparkDataFrame) -> SparkDataFrame:
        """Bucket in Spark, rake in Polars, and return the weighted rows as a Spark DataFrame."""
        df_prep = self.prepare(df_spark)
        df_pl = self.spark_to_polars(df_prep)
        # Rake through the base-class driver (logs start/finish, then calls self.run_ipf).
        # Calling self.run here would re-enter this Spark-facing method with a Polars frame.
        df_out = super().run(df_pl)
        return self.polars_to_spark(df_out)

In [0]:
# Load the three TPC-DS source tables. No ordering is applied before `limit`, so which
# CUSTOMER_COUNT customer rows are kept is not guaranteed.
customer, demo, addr = (
    spark.table(f"samples.tpcds_sf1000.{name}")
    for name in ("customer", "customer_demographics", "customer_address")
)
customer = customer.limit(CUSTOMER_COUNT)

# Attach each customer's current demographics and address (inner joins).
df = customer.join(
    demo, customer.c_current_cdemo_sk == demo.cd_demo_sk
).join(
    addr, customer.c_current_addr_sk == addr.ca_address_sk
)

rim = RimWeightHybridTPCDS(spark)
result = rim.run(df)
print(result.count())
display(result.limit(10))