In [None]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import sum, count, col, when

spark = SparkSession.builder.appName("PersistExample").getOrCreate()


# Cache example: build one aggregate and reuse it for several analyses.
df = spark.read.parquet("transactions.parquet")

# Per-customer spend and transaction count, keeping only customers whose
# total spend exceeds 1000.
# NOTE: `sum` and `count` here are the pyspark.sql.functions versions imported
# at the top of this file; that import shadows Python's builtin `sum`.
expensive_df = df.groupBy("customer_id").agg(
    sum("amount").alias("total_spent"),
    count("*").alias("transaction_count"),
).filter(col("total_spent") > 1000)

# Mark the aggregate for caching because two separate actions below read it.
# cache() is lazy: nothing is stored until the first action runs.
expensive_df.cache()

try:
    # Both derived frames are built from the cached aggregate.
    high_value_customers = expensive_df.filter(col("total_spent") > 10000)

    # Segment boundaries (checked top-down, first match wins):
    #   total_spent > 5000 -> "Premium"
    #   total_spent > 2000 -> "Gold"
    #   otherwise          -> "Silver" (i.e. 1000 < total_spent <= 2000,
    #                         given the filter applied above)
    segment_col = (
        when(col("total_spent") > 5000, "Premium")
        .when(col("total_spent") > 2000, "Gold")
        .otherwise("Silver")
    )
    customer_segments = expensive_df.withColumn("segment", segment_col)

    # Show results (these are the actions that actually trigger execution).
    high_value_customers.show()
    customer_segments.groupBy("segment").count().show()
finally:
    # Clean up: always release the cached data, even if an action above
    # raised, so a failed run does not leave it pinned in the session.
    expensive_df.unpersist()