### Estimate Partition Count for File Read

In [None]:
# Create the Spark session that every later cell reads its configuration from.
# NOTE(review): with master "local[*]" the whole application runs in one local
# JVM, so the spark.executor.* settings below presumably have no effect and the
# default parallelism follows the number of local cores — confirm for your setup.
from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    .appName("Factor of cores")
    .master("local[*]")
    .config("spark.executor.instances", "4")
    .config("spark.executor.cores", "2")
    .config("spark.executor.memory", "1G")
    .config("spark.driver.memory", "4G")
    .getOrCreate()
)
spark

In [None]:
# Byte-size Spark settings come back as strings such as "134217728b" (the
# default form), but an overridden value may use any Spark byte string, e.g.
# "128m" or "1g". Parse every form rather than only stripping a trailing "b".
_SIZE_UNITS = {
    "": 1, "b": 1,
    "k": 1024, "kb": 1024,
    "m": 1024 ** 2, "mb": 1024 ** 2,
    "g": 1024 ** 3, "gb": 1024 ** 3,
    "t": 1024 ** 4, "tb": 1024 ** 4,
    "p": 1024 ** 5, "pb": 1024 ** 5,
}


def _conf_bytes(key):
    """Return the byte-size Spark setting *key* as an int number of bytes.

    Raises ValueError if the stored value is not an integer with an optional
    b/k/m/g/t/p unit suffix (optionally followed by "b").
    """
    raw = spark.conf.get(key)
    text = str(raw).strip().lower()
    digits = text.rstrip("bkmgtp")
    unit = text[len(digits):]
    if not digits.isdigit() or unit not in _SIZE_UNITS:
        raise ValueError(f"Unrecognised byte size for {key}: {raw!r}")
    return int(digits) * _SIZE_UNITS[unit]


# Check the default partition size
partition_size = _conf_bytes("spark.sql.files.maxPartitionBytes")
print(f"Partition Size: {partition_size} in bytes and {partition_size / 1024 / 1024} in MB")

# Check the default open Cost in Bytes
open_cost_size = _conf_bytes("spark.sql.files.openCostInBytes")
print(f"Open Cost Size: {open_cost_size} in bytes and {open_cost_size / 1024 / 1024} in MB")

# Default parallelism
parallelism = int(spark.sparkContext.defaultParallelism)
print(f"Default Parallelism: {parallelism}")

In [None]:
# Input size on disk: `total_files` files of roughly `average_file_size` bytes.
# NOTE(review): the "Validation 1" cell describes the same ./Input/sample_data.csv
# as 2647733632 bytes; confirm which figure matches the real file.
average_file_size = 2898932284
total_files = 1

total_file_size = total_files * average_file_size
print(f"Total File size on disk: {total_file_size} in bytes and {total_file_size / 1024 /1024} in MB")

In [None]:
# Spark plans splits against the on-disk bytes plus one "open cost" per file,
# so pad the total by open_cost_size for every file being read.
padded_file_size = total_files * open_cost_size + total_file_size
print(f"Total padded file size: {padded_file_size} in bytes and {padded_file_size / 1024 /1024} in MB")

In [None]:
# Number of Bytes per Core: the padded total spread evenly over the parallelism.
# Spark performs this division on Long values, so use floor division to mirror
# its arithmetic exactly instead of producing a fractional byte count.
bytes_per_core = padded_file_size // parallelism
print(f"Bytes per Core: {bytes_per_core} in bytes and {bytes_per_core / 1024 /1024} in MB")

In [None]:
# Max split size: the per-core share, never below one file-open cost and never
# above spark.sql.files.maxPartitionBytes.
max_bytes_per_split = min(
    partition_size,                       # upper bound
    max(open_cost_size, bytes_per_core),  # lower bound is the open cost
)
print(f"Max bytes per Partition: {max_bytes_per_split} in bytes and {max_bytes_per_split / 1024 /1024} in MB")

In [None]:
# Total number of Partitions: a partition count is a whole number, so round the
# padded size divided by the split size *up* (exact integer ceiling division).
# NOTE(review): for many small files Spark packs whole files into splits, so the
# real count may still differ somewhat from this estimate.
num_of_partitions = int(-(-padded_file_size // max_bytes_per_split))
print(f"Approx number of partitions: {num_of_partitions}")

In [None]:
# Read the CSV (first line is the header) and show how many partitions Spark
# actually planned for it, to compare with the estimate above.
df_1 = spark.read.csv("./Input/sample_data.csv", header=True)
print(f"Number of Partition -> {df_1.rdd.getNumPartitions()}")

In [None]:
# Lets pack all code in single function
def num_partitions(file_size, num_of_files, spark):
    # Check the default partition size
    partition_size = int(spark.conf.get("spark.sql.files.maxPartitionBytes").replace("b",""))
    # Check the default open Cost in Bytes
    open_cost_size = int(spark.conf.get("spark.sql.files.openCostInBytes").replace("b",""))
    # Default parallelism
    parallelism = int(spark.sparkContext.defaultParallelism)
    # Total Actual File Size in Bytes
    total_file_size = file_size * num_of_files
    # Padded file size for Spark read
    padded_file_size = total_file_size + (num_of_files * open_cost_size)
    # Number of Bytes per Core
    bytes_per_core = padded_file_size / parallelism
    # Max Split Bytes
    max_bytes_per_split = min(partition_size, max(open_cost_size, bytes_per_core))
    # Total number of Partitions
    num_of_partitions = padded_file_size / max_bytes_per_split
    
    return num_of_partitions

In [None]:
# Validation 1: one large CSV file (2647733632 bytes, roughly 2.5 GB).
# First compute the estimate with num_partitions(), then read the file and
# compare it with the partition count Spark actually planned.
estimated_num_partition = num_partitions(file_size=2647733632, num_of_files=1, spark=spark)
print(f"Estimated number of partitions = {estimated_num_partition}")

df_1 = spark.read.csv("./Input/sample_data.csv", header=True)
print(f"Number of Partition -> {df_1.rdd.getNumPartitions()}")


In [None]:
# Validation 2: many tiny files (4000 files of 255196 bytes each).
# First compute the estimate with num_partitions(), then read the parquet
# directory and compare it with the partition count Spark actually planned.
# NOTE(review): these files were previously described as ~7777 bytes (7.7 KB)
# each, which disagrees with the 255196 bytes (~249 KB) used here — confirm
# the per-file size against the dataset.
estimated_num_partition = num_partitions(file_size=255196, num_of_files=4000, spark=spark)
print(f"Estimated number of partitions = {estimated_num_partition}")

df_2 = spark.read.parquet("./Input/sample_data_parguet/")
print(f"Number of Partition -> {df_2.rdd.getNumPartitions()}")

In [None]:
# Shut down the Spark session created at the top of the notebook.
spark.stop()