In [1]:
from typing import Optional

from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *

# Build (or reuse) a local Spark session; "local[*]" runs Spark in-process.
spark = (
    SparkSession.builder
    .master("local[*]")
    .appName("PySpark Accumulator")
    .getOrCreate()
)

In [2]:
# Sample rows as (state, city, zipcode) tuples. Sunnyvale (9511) and San Jose (951)
# do not have 5-digit zip codes; those are the rows the checks below flag.
data_list = [
    ("California", "Sunnyvale", 9511),
    ("California", "Mountain View", 94111),
    ("California", "Cupertino", 94123),
    ("California", "San Jose", 951),
]

# Columns come out as _1, _2, _3 from the tuples; toDF gives them real names.
df = spark.createDataFrame(data_list).toDF("state", "city", "zipcode")

df.show(5)

+----------+-------------+-------+
|     state|         city|zipcode|
+----------+-------------+-------+
|California|    Sunnyvale|   9511|
|California|Mountain View|  94111|
|California|    Cupertino|  94123|
|California|     San Jose|    951|
+----------+-------------+-------+



In [3]:
# Integer accumulator, starting at 0, that handle_bad_zipcode increments for each rejected zip code.
bad_zipcode = spark.sparkContext.accumulator(value=0)

In [4]:
def handle_bad_zipcode(c: int) -> Optional[int]:
    """Return ``c`` if it is a 5-digit zip code, otherwise count it and return ``None``.

    The value is rendered with ``str`` and must consist of exactly five digits.
    The following are all rejected:
    - short values such as 9511 or 951;
    - ``None`` (rendered as "None");
    - negative numbers such as -1234, whose string form happens to be five
      characters long.

    Note that an int cannot keep a leading zero (02134 is stored as 2134), so
    such zip codes are rejected as well.

    Each rejection adds 1 to the module-level ``bad_zipcode`` accumulator.

    Caveat: this runs as a UDF, i.e. inside a transformation. Spark only
    guarantees exactly-once accumulator updates for code running inside
    actions. The count can therefore be inflated by task retries, and every
    action that evaluates this UDF again (another ``show()``, ``count()``, ...)
    adds to the same total.
    """
    text = str(c)
    if len(text) != 5 or not text.isdigit():
        bad_zipcode.add(1)
        return None
    return c

In [5]:
# Expose the Python function to Spark SQL under the same name so expr(...) can call it;
# results are declared as IntegerType.
spark.udf.register(
    "handle_bad_zipcode",
    handle_bad_zipcode,
    returnType=IntegerType(),
)

<function __main__.handle_bad_zipcode(c: int) -> int>

In [13]:
# Add the UDF result next to the original column. show() is the action that actually
# evaluates the UDF, so running this cell again increments bad_zipcode again.
(
    df.withColumn("corrected_zipcode", expr("handle_bad_zipcode(zipcode)"))
    .show()
)

+----------+-------------+-------+-----------------+
|     state|         city|zipcode|corrected_zipcode|
+----------+-------------+-------+-----------------+
|California|    Sunnyvale|   9511|             NULL|
|California|Mountain View|  94111|            94111|
|California|    Cupertino|  94123|            94123|
|California|     San Jose|    951|             NULL|
+----------+-------------+-------+-----------------+



In [14]:
# The driver-side total accumulates across every evaluation of the UDF so far. It equals
# the number of bad rows only if the UDF has been evaluated over the data exactly once.
print(f"Bad Record Count:{bad_zipcode.value}")

Bad Record Count:2


In [None]:
# Replace `zipcode` in place with its validated value (invalid codes become NULL).
# withColumn on an existing column name replaces that column in its original position.
# That yields the same state/city/zipcode layout as adding a temporary column,
# selecting it, and renaming it back.
# NOTE: this action evaluates the UDF again, so bad_zipcode is incremented a second
# time for the same bad rows.
df.withColumn("zipcode", expr("handle_bad_zipcode(zipcode)")).show()

+----------+-------------+-------+
|     state|         city|zipcode|
+----------+-------------+-------+
|California|    Sunnyvale|   NULL|
|California|Mountain View|  94111|
|California|    Cupertino|  94123|
|California|     San Jose|   NULL|
+----------+-------------+-------+



: 

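### DataFrame Foreach 예제

Spark는 `foreach` 같은 액션 안에서 수행한 accumulator 업데이트에 대해서만 각 태스크의 업데이트가 한 번만 반영되도록 보장합니다. 위의 UDF처럼 트랜스포메이션 안에서 갱신하면, 태스크나 스테이지가 재실행되거나 같은 DataFrame에 액션을 다시 호출할 때 값이 중복 집계될 수 있습니다.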

In [6]:
# Single-column DataFrame holding the integers 1..5, with the column named "value".
data = list(range(1, 6))
df_test = spark.createDataFrame(data, schema="int").toDF("value")

# Integer accumulator, starting at 0, that receives the sum of df_test's values.
accumulator = spark.sparkContext.accumulator(value=0)

def add_to_accumulator(row):
    """Add ``row["value"]`` to the module-level ``accumulator``.

    Meant to be passed to ``DataFrame.foreach``, which calls it once per row.
    ``Accumulator.add`` updates the accumulator object itself, so there is no
    need for ``global accumulator`` and an augmented assignment that rebinds
    the global name.
    """
    accumulator.add(row["value"])

# foreach is an action; once it returns, the per-row additions are visible on the driver.
df_test.foreach(add_to_accumulator)

print("Accumulator value: ", accumulator.value)

Accumulator value:  15


### Zipcode 예제를 DataFrame `foreach`로 해보기

UDF 대신 액션인 `foreach` 안에서 accumulator를 갱신하므로, 태스크가 재시도되더라도 잘못된 zipcode가 중복 집계되지 않습니다.

In [7]:
# Integer accumulator, starting at 0, counting the zip codes find_wrong_zipcode rejects.
accumulator_zipcode = spark.sparkContext.accumulator(value=0)

def find_wrong_zipcode(row):
    """Add 1 to ``accumulator_zipcode`` if ``row['zipcode']`` is not a 5-digit zip code.

    Meant to be passed to ``df.foreach``, which calls it once per row.

    The value is rendered with ``str`` and must consist of exactly five
    digits, so short values (9511, 951), ``None`` and negative numbers such as
    -1234 are all counted. Valid rows leave the accumulator untouched instead
    of adding 0. ``Accumulator.add`` mutates the shared accumulator, so no
    ``global`` declaration is needed.
    """
    zipcode = str(row['zipcode'])
    if len(zipcode) != 5 or not zipcode.isdigit():
        accumulator_zipcode.add(1)

# Counting inside an action (foreach) rather than a UDF: restarted tasks do not add twice.
df.foreach(find_wrong_zipcode)

print("Wrong zipcode: ", accumulator_zipcode.value)

Wrong zipcode:  2
