In [None]:
!pip install pyspark==3.3.1 py4j==0.10.9.5

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
from pyspark.sql import SparkSession

# Create (or reuse) a Spark session running locally with master "local[*]".
spark = (
    SparkSession.builder
    .master("local[*]")
    .appName('PySpark Accumulator')
    .getOrCreate()
)

In [None]:
from pyspark.sql.functions import *
from pyspark.sql.types import *

In [None]:
# Sample (state, city, zipcode) rows; 9511 and 951 are not five digits long.
data_list = [
    ("California", "Sunnyvale", 9511),
    ("California", "Mountain View", 94111),
    ("California", "Cupertino", 94123),
    ("California", "San Jose", 951),
]

# Build the DataFrame from the tuples, then name its three columns.
df = (
    spark.createDataFrame(data_list)
    .toDF("state", "city", "zipcode")
)

In [None]:
# Print the DataFrame's rows as a text table.
df.show()

+----------+-------------+-------+
|     state|         city|zipcode|
+----------+-------------+-------+
|California|    Sunnyvale|   9511|
|California|Mountain View|  94111|
|California|    Cupertino|  94123|
|California|     San Jose|    951|
+----------+-------------+-------+



In [None]:
# Accumulator starting at 0; handle_bad_zipcode adds 1 to it for every zipcode it rejects.
bad_zipcodes = spark.sparkContext.accumulator(0)

In [None]:
def handle_bad_zipcode(c: int) -> int:
    """Return ``c`` if it is a five-digit zipcode, otherwise count it and return ``None``.

    Args:
        c: The zipcode value to validate. ``None`` is treated as invalid.

    Returns:
        ``c`` unchanged when its string form is exactly five decimal digits;
        ``None`` otherwise.

    Side effects:
        Adds 1 to the ``bad_zipcodes`` accumulator for every rejected value.
        The accumulator is never reset here, so evaluating this function over
        the same rows more than once counts the same bad rows again.
    """
    digits = str(c)
    # isdigit() rejects a leading minus sign, so a value such as -1234 is not
    # mistaken for a valid zipcode just because its text is five characters long.
    if c is None or len(digits) != 5 or not digits.isdigit():
        bad_zipcodes.add(1)
        return None
    return c

In [None]:
# Register handle_bad_zipcode under the SQL name "handle_bad_zipcode" with an
# integer return type, so it can be called from expression strings passed to expr().
spark.udf.register("handle_bad_zipcode", handle_bad_zipcode, IntegerType())

<function __main__.handle_bad_zipcode(c: int) -> int>

In [None]:
# Add a column holding the UDF's result for each row's zipcode and print the table.
(
    df.withColumn("corrected_zipcode", expr("handle_bad_zipcode(zipcode)"))
    .show()
)

+----------+-------------+-------+-----------------+
|     state|         city|zipcode|corrected_zipcode|
+----------+-------------+-------+-----------------+
|California|    Sunnyvale|   9511|             null|
|California|Mountain View|  94111|            94111|
|California|    Cupertino|  94123|            94123|
|California|     San Jose|    951|             null|
+----------+-------------+-------+-----------------+



In [None]:
# Read the accumulator's current total; it reflects every UDF evaluation run so far.
print(f"Bad Record Count:{bad_zipcodes.value}")

Bad Record Count:2


In [None]:
# Swap the raw zipcode column for the validated one and print the result.
# Note: this evaluates the UDF again, so bad_zipcodes receives further additions.
(
    df.withColumn("corrected_zipcode", expr("handle_bad_zipcode(zipcode)"))
    .select("state", "city", "corrected_zipcode")
    .withColumnRenamed("corrected_zipcode", "zipcode")
    .show()
)

+----------+-------------+-------+
|     state|         city|zipcode|
+----------+-------------+-------+
|California|    Sunnyvale|   null|
|California|Mountain View|  94111|
|California|    Cupertino|  94123|
|California|     San Jose|   null|
+----------+-------------+-------+



DataFrame Foreach 예제

In [None]:
# Single-column DataFrame ("value") holding the integers 1 through 5.
data = list(range(1, 6))
df_test = (
    spark.createDataFrame(data, "int")
    .toDF("value")
)

# Accumulator starting at 0; add_to_accumulator adds each row's value to it.
accumulator = spark.sparkContext.accumulator(0)

def add_to_accumulator(row):
    """Add the row's ``"value"`` field to the module-level ``accumulator``.

    Args:
        row: A mapping-like row that can be indexed with ``"value"``.

    Returns:
        None; the only effect is the update of ``accumulator``.
    """
    accumulator.add(row["value"])

# Call add_to_accumulator once for every row of df_test.
df_test.foreach(add_to_accumulator)

# Read the accumulated total.
print("Accumulator value: ", accumulator.value)

Accumulator value:  15


앞서의 Zipcode 예제를 UDF 대신 DataFrame foreach로 해보기

In [None]:
# Accumulator starting at 0; find_wrong_zipcode adds to it for each row it inspects.
accumulator_zipcode = spark.sparkContext.accumulator(0)

def find_wrong_zipcode(row):
    """Count the row in ``accumulator_zipcode`` when its zipcode is not five digits.

    Args:
        row: A mapping-like row that can be indexed with ``"zipcode"``.
            A ``None`` zipcode is treated as invalid.

    Returns:
        None; the only effect is adding 1 to ``accumulator_zipcode`` for an
        invalid zipcode. Valid rows leave the accumulator untouched.
    """
    zipcode = row["zipcode"]
    digits = str(zipcode)
    # isdigit() rejects a leading minus sign, so a value such as -1234 is not
    # accepted just because its text happens to be five characters long.
    if zipcode is None or len(digits) != 5 or not digits.isdigit():
        accumulator_zipcode.add(1)

In [None]:
# Call find_wrong_zipcode once for every row of df.
df.foreach(find_wrong_zipcode)

In [None]:
# Read the number of zipcodes counted as wrong so far.
print("Wrong zipcode: ", accumulator_zipcode.value)

Wrong zipcode:  2
