In [53]:
from pyspark.sql import SparkSession 
from pyspark.sql.functions import * 
from pyspark.sql.types import *

In [54]:
spark = (SparkSession.builder
         .appName("cache-and-persist")
         .master("spark://spark-master:7077")
         .config("spark.executor.memory", "512m")
         .getOrCreate())

spark.sparkContext.setLogLevel("ERROR")

In [55]:
# Create some sample data frames
# A large data frame with 10 million rows and two columns: id and value
large_df = (spark.range(0, 10000000)
            .withColumn("date", date_sub(current_date(), (rand() * 365).cast("int")))
            .withColumn("ProductId", (rand() * 100).cast("int")))
large_df.show(5)

[Stage 0:>                                                          (0 + 1) / 1]

+---+----------+---------+
| id|      date|ProductId|
+---+----------+---------+
|  0|2023-09-21|       76|
|  1|2022-09-27|       36|
|  2|2023-08-30|       14|
|  3|2023-07-28|        6|
|  4|2023-07-21|       72|
+---+----------+---------+
only showing top 5 rows



                                                                                

In [56]:
# Cache the DataFrame using cache() method
large_df.cache()
# Check the storage level of the cached DataFrame
print(large_df.storageLevel)

Disk Memory Deserialized 1x Replicated


In [57]:
# Persist the DataFrame using persist() method with a different storage level
large_df.persist(StorageLevel.MEMORY_AND_DISK_DESER)
# Check the storage level of the persisted DataFrame
print(large_df.storageLevel)

Disk Memory Deserialized 1x Replicated


In [58]:
results_df = large_df.groupBy("ProductId").agg({"Id": "count"}) 
measure_time(results_df)
# Show the result
results_df.show(5)

                                                                                

Execution time: 9.418130159378052 seconds
+---------+---------+
|ProductId|count(Id)|
+---------+---------+
|       31|   100120|
|       85|   100070|
|       65|   100211|
|       53|    99935|
|       78|    99230|
+---------+---------+
only showing top 5 rows



In [61]:
results_df = large_df.groupBy("ProductId").agg({"Id": "count"}) 
measure_time(results_df)
# Show the result
results_df.show(5)

Execution time: 0.5906562805175781 seconds
+---------+---------+
|ProductId|count(Id)|
+---------+---------+
|       31|   100120|
|       85|   100070|
|       65|   100211|
|       53|    99935|
|       78|    99230|
+---------+---------+
only showing top 5 rows



In [59]:
# Unpersist the DataFrame using unpersist() method
large_df.unpersist()
# Check the storage level of the unpersisted DataFrame
print(large_df.storageLevel)

Serialized 1x Replicated


In [52]:
spark.stop()