# Spark Accumulators

Spark accumulators are *distributed counters* which allow you to increment a global counter from within Python UDFs: tasks running on the executors can only add to an accumulator, while only the driver can read its value. This is useful for counting certain events or cases which are not directly part of your data processing. A typical example is counting broken records.
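
To make the idea of a distributed counter concrete, the following pure-Python sketch models how accumulator updates are combined. It is not PySpark's implementation, and the names `LocalAccumulator` and `run_job` are hypothetical. Each "task" adds to its own zero-initialized copy, and the driver merges the task results once a task has finished. In PySpark itself the counter is created with `spark.sparkContext.accumulator(0)`, incremented with `.add()` and read on the driver via `.value`, as done in section 3.

```python
# Pure-Python model of accumulator semantics (NOT the PySpark implementation).
# Hypothetical names: LocalAccumulator, run_job.


class LocalAccumulator:
    """Per-task copy of a counter: tasks can only add to it."""

    def __init__(self):
        self.local_value = 0

    def add(self, amount):
        self.local_value += amount


def run_job(partitions, is_broken):
    """Count records matching is_broken, running one 'task' per partition."""
    driver_value = 0  # corresponds to accumulator.value on the driver
    for partition in partitions:
        task_copy = LocalAccumulator()  # every task starts from zero
        for record in partition:
            if is_broken(record):
                task_copy.add(1)
        # When the task finishes, its local update is merged on the driver
        driver_value += task_copy.local_value
    return driver_value


# Three partitions; empty lines stand in for broken records
partitions = [["ok", "", "ok"], ["", "ok"], ["ok"]]
print(run_job(partitions, is_broken=lambda record: record == ""))  # 2
```

The point of the model is that no task ever sees the global total; only the driver does, after merging.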

## Weather Data Example
We will use the weather measurement data again as an example. Instead of using Spark functions to extract the measurement information, we will write a Python UDF. Although this is not required in our example, the approach can be useful in other scenarios: the weather data set itself contains more information at non-fixed positions, which cannot be extracted with simple Spark/SQL string functions.

This example shows how to use accumulators to count records. With other data sets this might be used to count broken records (the weather data does not contain any broken records, though).

# 1 Load Data

As we will not use the previous extraction, we simply load a single year as text data. In the next section we will apply a Python UDF to extract the desired information.

```python
storageLocation = "s3://dimajix-training/data/weather"
```

```python
from pyspark.sql.functions import lit, udf

# Read one year of raw weather records (one text line per row) and tag them with the year
raw_weather = spark.read.text(storageLocation + "/2003").withColumn("year", lit(2003))
raw_weather.limit(5).toPandas()
```

|   | value | year |
|---|-------|------|
| 0 | 0494703160256242003010100003+55200-162717SY-MT... | 2003 |
| 1 | 0228703160256242003010100174+55200-162730FM-16... | 2003 |
| 2 | 044070316025624200301010053C+55200-162717FM-15... | 2003 |
| 3 | 0071703160256242003010101009+55200-162717NSRDB... | 2003 |
| 4 | 042770316025624200301010153C+55200-162717FM-15... | 2003 |


# 2 Extract Weather

We will now create and test a simple Python UDF for extracting the weather data. In the next section we will extend that function to count invalid USAF and WBAN codes. But one step at a time.

## 2.1 Define Python UDF

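```python
from pyspark.sql.functions import udf
from pyspark.sql.types import FloatType, IntegerType, StringType, StructField, StructType

# Schema of the struct returned by the UDF
schema = StructType([
    StructField("usaf", StringType()),
    StructField("wban", StringType()),
    StructField("air_temperature", FloatType()),
    StructField("air_temperature_qual", IntegerType()),
])

@udf(schema)
def extract_weather(row):
    # Fixed-width record: every field sits at a known character position
    usaf = row[4:10]
    wban = row[10:15]
    air_temperature = float(row[87:92]) / 10  # stored in tenths of a degree
    air_temperature_qual = int(row[92])
    return (usaf, wban, air_temperature, air_temperature_qual)
```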
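
Because the UDF body is plain Python, its slicing logic can be checked without a Spark cluster. The sketch below copies the body into a regular function (`extract_weather_plain`, a hypothetical name) and feeds it a *synthetic* record in which only the character positions read by the UDF are filled in. The field values are modeled on the first row shown above; the record itself is made up.

```python
def extract_weather_plain(row):
    """Same logic as the extract_weather UDF, without the @udf decorator."""
    usaf = row[4:10]
    wban = row[10:15]
    air_temperature = float(row[87:92]) / 10  # e.g. "-0006" -> -0.6
    air_temperature_qual = int(row[92])
    return (usaf, wban, air_temperature, air_temperature_qual)


def make_synthetic_record(usaf, wban, temperature, quality):
    """Build a fake fixed-width line with the fields at the UDF's positions."""
    assert len(usaf) == 6 and len(wban) == 5
    assert len(temperature) == 5 and len(quality) == 1
    # positions 0-3 filler, 4-9 USAF, 10-14 WBAN, 15-86 filler, 87-91 temperature, 92 quality
    return "0" * 4 + usaf + wban + "0" * (87 - 15) + temperature + quality


record = make_synthetic_record("703160", "25624", "-0006", "5")
print(extract_weather_plain(record))  # ('703160', '25624', -0.6, 5)
```

In the Spark output below, the same temperature appears as `-0.6000000238418579` because the schema declares `air_temperature` as `FloatType`, a 32-bit float, which is widened to a 64-bit Python float when the result is collected. The value `999.9` with quality `9` corresponds to the raw value `+9999`, which the ISD format uses to mark a missing air temperature.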
    

## 2.2 Use Python UDF

Now we can apply the Python UDF `extract_weather` to process our data.

```python
result = raw_weather.select(extract_weather(raw_weather["value"]).alias("measurement"))
result.limit(5).toPandas()
```

|   | measurement |
|---|-------------|
| 0 | (703160, 25624, -0.6000000238418579, 5) |
| 1 | (703160, 25624, -2.0, 1) |
| 2 | (703160, 25624, -2.799999952316284, 5) |
| 3 | (703160, 25624, 999.9000244140625, 9) |
| 4 | (703160, 25624, -2.200000047683716, 5) |


### Inspect Schema

Since the UDF returns a struct with multiple fields, the result now has a nested schema.

```python
result.printSchema()
```

```
root
 |-- measurement: struct (nullable = true)
 |    |-- usaf: string (nullable = true)
 |    |-- wban: string (nullable = true)
 |    |-- air_temperature: float (nullable = true)
 |    |-- air_temperature_qual: integer (nullable = true)
```



# 3 Use Accumulators

As mentioned above, we want to extend the Python UDF to count certain important events. For example, you might be interested in how many records are broken (there are none in our data set). We chose a different example instead: we want to count the number of invalid USAF and WBAN codes.

## 3.1 Create Accumulator

```python
records_processed = spark.sparkContext.accumulator(0)
invalid_usaf = spark.sparkContext.accumulator(0)
invalid_wban = spark.sparkContext.accumulator(0)
```

## 3.2 Increment accumulators

Now we need to adapt our Python UDF to increment the accumulators on specific events: `records_processed` is incremented for every record, while `invalid_usaf` and `invalid_wban` are incremented whenever we encounter an invalid USAF or WBAN code, respectively.

```python
from pyspark.sql.functions import udf
from pyspark.sql.types import FloatType, IntegerType, StringType, StructField, StructType

schema = StructType([
    StructField("usaf", StringType()),
    StructField("wban", StringType()),
    StructField("air_temperature", FloatType()),
    StructField("air_temperature_qual", IntegerType()),
])

@udf(schema)
def extract_weather(row):
    usaf = row[4:10]
    wban = row[10:15]
    air_temperature = float(row[87:92]) / 10
    air_temperature_qual = int(row[92])

    # Increment accumulators (runs on the executors, merged on the driver)
    records_processed.add(1)
    if usaf == '999999':  # all-9 code is treated as invalid
        invalid_usaf.add(1)
    if wban == '99999':   # all-9 code is treated as invalid
        invalid_wban.add(1)

    return (usaf, wban, air_temperature, air_temperature_qual)
```
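
The introduction named broken records as the typical use case. The weather data has none, but a variant of the UDF body could count them as sketched below. This is an illustration, not code from the original notebook: `extract_weather_counted` and `FakeAccumulator` are hypothetical, and the accumulators are passed as arguments only so the logic can be tested without Spark. Records that are too short or contain non-numeric fields are counted as broken and mapped to `None`. Inside a real `@udf(schema)` function you would instead reference driver-created accumulators directly, as in the cell above; a Python UDF that returns `None` yields a null value in the result.

```python
class FakeAccumulator:
    """Hypothetical stand-in with the add()/value interface used above."""

    def __init__(self):
        self.value = 0

    def add(self, amount):
        self.value += amount


def extract_weather_counted(row, processed, broken, invalid_usaf, invalid_wban):
    processed.add(1)
    try:
        usaf = row[4:10]
        wban = row[10:15]
        air_temperature = float(row[87:92]) / 10
        air_temperature_qual = int(row[92])
    except (IndexError, ValueError):
        # Line too short (row[92] missing, empty slices) or non-numeric fields
        broken.add(1)
        return None
    if usaf == "999999":
        invalid_usaf.add(1)
    if wban == "99999":
        invalid_wban.add(1)
    return (usaf, wban, air_temperature, air_temperature_qual)


counters = [FakeAccumulator() for _ in range(4)]
good = "0000" + "999999" + "25624" + "0" * 72 + "-0006" + "5"
print(extract_weather_counted(good, *counters))         # ('999999', '25624', -0.6, 5)
print(extract_weather_counted("too short", *counters))  # None
```

Returning `None` instead of raising keeps a single bad line from failing the whole Spark job, while the accumulator still tells you how many such lines were seen.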

## 3.3 Execute Query

Now we can use the modified UDF and check whether the accumulators are incremented.

```python
result = raw_weather.select(extract_weather(raw_weather["value"]).alias("measurement"))
result.limit(5).toPandas()
```

|   | measurement |
|---|-------------|
| 0 | (703160, 25624, -0.6000000238418579, 5) |
| 1 | (703160, 25624, -2.0, 1) |
| 2 | (703160, 25624, -2.799999952316284, 5) |
| 3 | (703160, 25624, 999.9000244140625, 9) |
| 4 | (703160, 25624, -2.200000047683716, 5) |


### Inspect Counters

```python
print("records_processed=" + str(records_processed.value))
print("invalid_usaf=" + str(invalid_usaf.value))
print("invalid_wban=" + str(invalid_wban.value))
```

```
records_processed=0
invalid_usaf=0
invalid_wban=0
```


Surprisingly, the counters have not been incremented, even though the UDF was evaluated at least for the five rows shown above. We will try `count()` instead.

### First Run

Now let's try to execute the UDF for every record. The method `count()` should do the job.

```python
result.count()
```

```
1798753
```

```python
print("records_processed=" + str(records_processed.value))
print("invalid_usaf=" + str(invalid_usaf.value))
print("invalid_wban=" + str(invalid_wban.value))
```

```
records_processed=0
invalid_usaf=0
invalid_wban=0
```


### Second Run

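That didn't work either, because Spark is too clever: the optimizer recognizes that the result of the UDF is not needed for counting rows and removes the UDF call from the execution plan. So let's force the execution by adding a filter condition which requires the UDF to be evaluated.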

```python
result.filter(result["measurement.wban"] != '123').count()
```

```
1798753
```

```python
print("records_processed=" + str(records_processed.value))
print("invalid_usaf=" + str(invalid_usaf.value))
print("invalid_wban=" + str(invalid_wban.value))
```

```
records_processed=1798753
invalid_usaf=140706
invalid_wban=753818
```


### Third Run

Accumulators are not reset automatically between query executions, so running the same query again adds to the previous values.

```python
result.filter(result["measurement.wban"] != '123').count()

print("records_processed=" + str(records_processed.value))
print("invalid_usaf=" + str(invalid_usaf.value))
print("invalid_wban=" + str(invalid_wban.value))
```

```
records_processed=3597506
invalid_usaf=281412
invalid_wban=1507636
```


### Reset Counter

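You can also reset a counter on the driver by simply assigning it a new value. Here we only reset `records_processed`, so the other two counters continue to accumulate.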

```python
records_processed.value = 0
```

```python
result.filter(result["measurement.wban"] != '123').count()

print("records_processed=" + str(records_processed.value))
print("invalid_usaf=" + str(invalid_usaf.value))
print("invalid_wban=" + str(invalid_wban.value))
```

```
records_processed=1798753
invalid_usaf=422118
invalid_wban=2261454
```
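
The numbers above can be checked with a little arithmetic: `records_processed` was reset before the last run and therefore holds a single run's worth (1798753), while the two other counters have accumulated three runs: 3 × 140706 = 422118 and 3 × 753818 = 2261454. The sketch below replays this with a hypothetical `FakeAccumulator` stand-in and adds a small helper, `reset_accumulators` (also hypothetical), that resets several counters at once. Since the helper only assigns to `.value`, it should also work with real PySpark accumulators, but only on the driver.

```python
class FakeAccumulator:
    """Hypothetical stand-in exposing add() and a settable value."""

    def __init__(self):
        self.value = 0

    def add(self, amount):
        self.value += amount


def reset_accumulators(*accumulators):
    """Set all given accumulators back to zero (driver side only)."""
    for acc in accumulators:
        acc.value = 0


# Increments per full evaluation of the UDF, taken from the "Second Run" output
PER_RUN = {"records_processed": 1798753, "invalid_usaf": 140706, "invalid_wban": 753818}
accs = {name: FakeAccumulator() for name in PER_RUN}


def simulate_run():
    for name, amount in PER_RUN.items():
        accs[name].add(amount)


simulate_run()                       # second run
simulate_run()                       # third run
accs["records_processed"].value = 0  # the only counter reset in the notebook
simulate_run()                       # run after the reset
print({name: acc.value for name, acc in accs.items()})
# {'records_processed': 1798753, 'invalid_usaf': 422118, 'invalid_wban': 2261454}

reset_accumulators(*accs.values())   # now every counter is back to zero
```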


# 4 Afterthought

It is important to understand that Spark accumulators count how often a specific code path in our Python UDF was executed, not how many records with a certain property exist in the data. Spark might evaluate certain code paths multiple times, for example when tasks are re-executed after a node failure, or when the execution plan executes a certain step multiple times. As the runs above show, every action also evaluates the UDF again. Therefore accumulators cannot, and should not, be used for generating statistics over the data itself. But they can be used to understand which code paths have been executed more often than others.
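
The Spark programming guide states the guarantee precisely: for accumulator updates performed inside actions, each task's update is applied only once, whereas updates made in transformations (a UDF inside `select` or `filter` is one) may be applied more than once if tasks or job stages are re-executed. The pure-Python model below is not Spark code, and `count_invalid_usaf` is a hypothetical name. It only shows the effect: every completed execution of a partition adds its update, so recomputing one partition inflates the counter while the data itself is unchanged.

```python
def count_invalid_usaf(partitions, recomputed_partitions):
    """Return (accumulator value, true number of invalid USAF codes).

    recomputed_partitions lists partition indexes that are executed a second
    time, e.g. because their output was lost (hypothetical failure model).
    """
    accumulator = 0  # value seen on the driver after merging task updates
    executions = list(range(len(partitions))) + list(recomputed_partitions)
    for index in executions:
        # One task execution: count its events and merge the update
        accumulator += sum(1 for usaf in partitions[index] if usaf == "999999")
    true_count = sum(1 for part in partitions for usaf in part if usaf == "999999")
    return accumulator, true_count


partitions = [["703160", "999999"], ["999999", "999999"], ["703160"]]
print(count_invalid_usaf(partitions, recomputed_partitions=[]))   # (3, 3)
print(count_invalid_usaf(partitions, recomputed_partitions=[1]))  # (5, 3)
```

The true count is a property of the data; the accumulator value is a property of how often the code ran, which is exactly the distinction made above.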