In [2]:
# Omit this step if you already have SparkContext available
import findspark
findspark.init()

In [33]:
import logging
from pyspark.sql import DataFrame
import pyspark.sql.functions as sf
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from pyspark.sql.types import IntegerType


def generate_random_uniform_df(nrows, ncols, seed=1):
    df = spark.range(nrows).select(sf.col("id"))
    df = df.select('*', *(F.rand(seed).alias("_"+str(target)) for target in range(ncols)))
    return df.drop("id")

In [15]:
spark = SparkSession.builder.getOrCreate()

In [16]:
spark

# Extract Exact size of a DataFrame

In [17]:
df_small = generate_random_uniform_df(10000, 10)

In [20]:
df_small.show(n=1, vertical=True)

-RECORD 0-----------------
 _0  | 0.6363787615254752 
 _1  | 0.6363787615254752 
 _2  | 0.6363787615254752 
 _3  | 0.6363787615254752 
 _4  | 0.6363787615254752 
 _5  | 0.6363787615254752 
 _6  | 0.6363787615254752 
 _7  | 0.6363787615254752 
 _8  | 0.6363787615254752 
 _9  | 0.6363787615254752 
only showing top 1 row



In [21]:
df_small.rdd.getNumPartitions()

8

In [47]:
def df_size_in_bytes_exact(df: DataFrame):
    """
    Calculates the exact size in memory of a DataFrame by caching it and accessing the optimized plan

    NOTE: BE CAREFUL WITH THIS FUNCTION BECAUSE IT WILL CACHE ALL THE DATAFRAME!!! IF YOUR DATAFRAME IS
    TOO BIG USE `estimate_df_size_in_bytes`!!

    :param df: A pyspark DataFrame
    :return: The exact size in bytes
    """
    df = df.cache().select(
        df.columns
    )  # Just force the Spark planner to add the Cache op to the plan
    logging.info(f"Number of rows in the input DataFrame: {df.count()}")
    size_in_bytes = df._jdf.queryExecution().optimizedPlan().stats().sizeInBytes()
    df.unpersist(blocking=True)
    return size_in_bytes

def df_size_in_bytes_approximate(df: DataFrame, sample_perc: float = 0.05):
    """
    This method takes a sample of the input DataFrame (`sample_perc`) and applies `df_size_in_bytes_exact`
    method to it. After it calculates the exact size of the sample, it extrapolates the total size.

    :param df: A pyspark DataFrame
    :param sample_perc: The percentage of the DataFrame to sample. By default, a 5 %
    :return: The estimated size in bytes
    """
    sample_size_in_bytes = df_size_in_bytes_exact(df.sample(sample_perc))
    return sample_size_in_bytes / sample_perc


def add_partition_id_column(df: DataFrame):
    return df.withColumn("partition_id", sf.spark_partition_id())


def get_partition_count(df: DataFrame) -> DataFrame:
    """
    Gets the number of registers per partition. This method is useful if we are trying to determine if some
    partition is skewed.

    :return: A DataFrame containing `partition_id` and `count` columns
    """
    return add_partition_id_column(df).groupBy("partition_id").count()


def add_salt_column(df: DataFrame, skew_factor: int):
    """
    Adds a salt column to a DataFrame. We will be using this salt column when we are trying to perform
    join, groupBy, etc. operations into a skewed DataFrame. The idea is to add a random column and use
    the original keys + this salted key to perform the operations, so that we can avoid data skewness and
    possibly, OOM errors.

    :param df: A Pyspark DataFrame
    :param skew_factor: The skew factor. For example, if we set this value to 3, then the salted column will
        be populated by the elements 0, 1 and 2, extracted from a uniform probability distribution.
    :return: The original DataFrame with a `salt_id` column.
    """
    return df.withColumn("salt_id", (sf.rand() * skew_factor).cast(IntegerType()))

In [23]:
df_size_in_bytes_exact(df_small)

                                                                                

800000

In [26]:
df_size_in_bytes_approximate(df_small)

812800.0

In [29]:
get_partition_count(df_small).show()

+------------+-----+
|partition_id|count|
+------------+-----+
|           0| 1250|
|           1| 1250|
|           2| 1250|
|           3| 1250|
|           4| 1250|
|           5| 1250|
|           6| 1250|
|           7| 1250|
+------------+-----+



In [48]:
df_small_with_salt = add_salt_column(df_small, skew_factor=2)

In [49]:
df_small_with_salt.select("salt_id").show(n=20)

+-------+
|salt_id|
+-------+
|      1|
|      0|
|      1|
|      1|
|      0|
|      1|
|      0|
|      0|
|      0|
|      1|
|      0|
|      1|
|      0|
|      0|
|      0|
|      1|
|      0|
|      1|
|      1|
|      1|
+-------+
only showing top 20 rows



In [50]:
df_small_with_salt_and_partition_id = add_salt_column(add_partition_id_column(df_small), skew_factor=2)

In [51]:
df_small_with_salt_and_partition_id.select("partition_id", "salt_id").show(n=10)

+------------+-------+
|partition_id|salt_id|
+------------+-------+
|           0|      0|
|           0|      0|
|           0|      1|
|           0|      0|
|           0|      0|
|           0|      0|
|           0|      0|
|           0|      0|
|           0|      0|
|           0|      1|
+------------+-------+
only showing top 10 rows



In [53]:
df_small_with_salt_and_partition_id.groupBy("partition_id", "salt_id").count().show(n=100)

+------------+-------+-----+
|partition_id|salt_id|count|
+------------+-------+-----+
|           0|      0|  638|
|           0|      1|  612|
|           1|      0|  586|
|           1|      1|  664|
|           2|      1|  635|
|           2|      0|  615|
|           3|      1|  603|
|           3|      0|  647|
|           4|      0|  653|
|           4|      1|  597|
|           5|      0|  618|
|           5|      1|  632|
|           6|      1|  600|
|           6|      0|  650|
|           7|      1|  637|
|           7|      0|  613|
+------------+-------+-----+



In [54]:
get_partition_count(df_small_with_salt_and_partition_id).show()

+------------+-----+
|partition_id|count|
+------------+-----+
|           0| 1250|
|           1| 1250|
|           2| 1250|
|           3| 1250|
|           4| 1250|
|           5| 1250|
|           6| 1250|
|           7| 1250|
+------------+-----+

