In [1]:
from pyspark.sql import SparkSession
import time

# Create (or reuse) the Spark session for this experiment, connecting to the
# standalone master at spark://ecnu01:7077.
spark = (
    SparkSession.builder
    .appName("StageDataTransferExperiment")
    .master("spark://ecnu01:7077")
    # Executor / driver sizing.
    .config("spark.executor.memory", "14g")
    .config("spark.executor.cores", "16")
    .config("spark.executor.instances", "2")
    .config("spark.driver.memory", "4g")
    # Web UI: retain up to 100 stages/jobs of history and serve on port 8000.
    .config("spark.ui.retainedStages", "100")
    .config("spark.ui.retainedJobs", "100")
    .config("spark.ui.port", "8000")
    .getOrCreate()
)

# SparkContext handle used for the RDD API in the rest of the script.
sc = spark.sparkContext

# Build a small synthetic key/value data set: NUM_RECORDS pairs of the form
# (i % NUM_KEYS, i), i.e. 100,000 records spread over 10 distinct keys.
NUM_RECORDS = 100_000   # total number of (key, value) pairs
NUM_KEYS = 10           # number of distinct keys
NUM_PARTITIONS = 4      # initial partition count of the base RDD

data = [(i % NUM_KEYS, i) for i in range(NUM_RECORDS)]
rdd = sc.parallelize(data, NUM_PARTITIONS)

# ============ Intra-stage data transfer (narrow dependencies) ============
print("\n--- Stage 内部数据传输测试（窄依赖操作） ---")
time.sleep(5)  # pause so the Spark UI has time to come up
start_time = time.time()

# Only narrow-dependency transformations are used here (map, filter, union,
# flatMap, sample, mapValues), so no shuffle boundary is introduced.
# First branch: same values, keys multiplied by 3.
rdd_narrow0 = rdd.map(lambda kv: (kv[0] * 3, kv[1]))

# Second branch, merged with the first one and then transformed further.
rdd_narrow = (
    rdd.map(lambda kv: (kv[0], kv[1] * 2))
    .filter(lambda kv: kv[1] % 2 == 0)
    .union(rdd_narrow0)
    .flatMap(lambda kv: [(kv[0], kv[1] * 2)])
    .sample(False, 0.5)  # without replacement, fraction 0.5
    .mapValues(lambda v: v * 2)
)

# Transformations are lazy; count() is the action that actually runs the job.
rdd_narrow.count()

end_time = time.time()
# Uncomment to inspect the lineage:
# print(rdd_narrow.toDebugString())
print(f"Execution time for narrow dependencies: {end_time - start_time:.2f} seconds")
print("Observe Spark UI: Only 1 Stage is created, no Shuffle occurs.")

# ============ Inter-stage data transfer (wide dependencies) ============
print("\n--- Stage 之间数据传输测试（宽依赖操作） ---")
time.sleep(5)  # pause so the Spark UI can refresh
start_time = time.time()

# The SparkContext is stopped in `finally` so that the executors are released
# even when one of the jobs below fails; otherwise they would stay allocated
# on the cluster until the driver process exits.
try:
    # Wide-dependency operations (groupByKey / reduceByKey) require a shuffle
    # and therefore split the job into multiple stages.
    rdd_wide = rdd.groupByKey()
    rdd_wide1 = rdd.reduceByKey(lambda x, y: x + y)

    # Actions: each collect() runs its own job.
    rdd_wide.mapValues(len).collect()
    rdd_wide1.collect()

    # Optional: join is another wide dependency that triggers a shuffle.
    # rdd_wide2 = sc.parallelize([(i % 10, i) for i in range(100000)])
    # rdd_wide3 = rdd.join(rdd_wide2)

    # Optional: cogroup is a wide dependency as well.
    # rdd_wide4 = rdd.cogroup(rdd_wide2)
    # rdd_wide2.collect()
    # rdd_wide3.collect()
    # rdd_wide4.collect()

    end_time = time.time()
    # Uncomment to inspect the lineage:
    # print(rdd_wide.toDebugString())
    print(f"Execution time for wide dependencies: {end_time - start_time:.2f} seconds")
    print("Observe Spark UI: Multiple Stages are created, Shuffle occurs.")
finally:
    sc.stop()



--- Stage 内部数据传输测试（窄依赖操作） ---
Execution time for narrow dependencies: 1.03 seconds
Observe Spark UI: Only 1 Stage is created, no Shuffle occurs.

--- Stage 之间数据传输测试（宽依赖操作） ---
Execution time for wide dependencies: 0.90 seconds
Observe Spark UI: Multiple Stages are created, Shuffle occurs.
