# Spark Examples

In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import avg, col, when

In [2]:
# Reuse the active Spark session if there is one, otherwise start one named "demo".
spark = (
    SparkSession.builder
    .appName("demo")
    .getOrCreate()
)

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
23/12/13 08:20:09 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


## Simple Spark DataFrame example

In [5]:
# Small in-memory DataFrame of people (name, age) used by the examples below.
df = spark.createDataFrame(
    data=[("sue", 32), ("li", 3), ("bob", 75), ("heo", 13)],
    schema=["first_name", "age"],
)

In [6]:
# Render the people DataFrame as a text table (n=20 is show()'s default row limit).
df.show(n=20)

[Stage 0:>                                                          (0 + 1) / 1]

+----------+---+
|first_name|age|
+----------+---+
|       sue| 32|
|        li|  3|
|       bob| 75|
|       heo| 13|
+----------+---+



                                                                                

In [31]:
# Bucket each person into a life stage. Anything not matched by the first two
# conditions (including a null age) falls through to "adult".
df1 = df.withColumn(
    "life_stage",
    when(col("age") < 13, "child")
    .when((col("age") >= 13) & (col("age") <= 19), "teenager")
    .otherwise("adult"),
)

In [32]:
# Show the DataFrame with the derived life_stage column.
df1.show(n=20)

+----------+---+----------+
|first_name|age|life_stage|
+----------+---+----------+
|       sue| 32|     adult|
|        li|  3|     child|
|       bob| 75|     adult|
|       heo| 13|  teenager|
+----------+---+----------+



In [33]:
# df itself is unchanged: withColumn returned a new DataFrame (df1) instead of mutating df.
df.show(n=20)

+----------+---+
|first_name|age|
+----------+---+
|       sue| 32|
|        li|  3|
|       bob| 75|
|       heo| 13|
+----------+---+



In [65]:
def with_life_stage(df):
    """Return ``df`` with a ``life_stage`` column derived from its ``age`` column.

    Buckets: ``age < 13`` -> "child", ``13 <= age <= 19`` -> "teenager",
    ``age > 19`` -> "adult". A null ``age`` matches no branch, so that row's
    ``life_stage`` is null.
    """
    # Previously written as `col("age") >= 13 | col("age") < 20`. In Python `|`
    # binds tighter than comparisons, so that became a chained comparison that
    # calls bool() on a Column and raises ValueError. `|` was also the wrong
    # operator (OR instead of AND), and `> 20` left age 20 unclassified.
    return df.withColumn(
        "life_stage",
        when(col("age") < 13, "child")
        .when(col("age").between(13, 19), "teenager")
        .when(col("age") > 19, "adult"),
    )

In [35]:
# Apply with_life_stage via DataFrame.transform (keeps transformations chainable).
df.transform(with_life_stage).show(n=20)

+----------+---+----------+
|first_name|age|life_stage|
+----------+---+----------+
|       sue| 32|     adult|
|        li|  3|     child|
|       bob| 75|     adult|
|       heo| 13|  teenager|
+----------+---+----------+



In [37]:
# Keep only teenagers and adults (filter is an alias of where; isin accepts varargs).
df1.filter(col("life_stage").isin("teenager", "adult")).show()

+----------+---+----------+
|first_name|age|life_stage|
+----------+---+----------+
|       sue| 32|     adult|
|       bob| 75|     adult|
|       heo| 13|  teenager|
+----------+---+----------+



In [40]:
# Average age across all rows, as a one-row aggregate.
df1.agg(avg("age")).show()

+--------+
|avg(age)|
+--------+
|   30.75|
+--------+



In [41]:
# Average age per life stage (age is the only numeric column, so this matches avg()).
df1.groupBy("life_stage").agg(avg("age")).show()

+----------+--------+
|life_stage|avg(age)|
+----------+--------+
|     adult|    53.5|
|     child|     3.0|
|  teenager|    13.0|
+----------+--------+



In [46]:
# Same overall average written as SQL; {df1} is filled in from the df1= keyword.
spark.sql(
    "select avg(age) from {df1}",
    df1=df1,
).show()

+--------+
|avg(age)|
+--------+
|   30.75|
+--------+



In [47]:
# Per-life-stage average age, written as SQL against the df1 DataFrame.
spark.sql(
    "select life_stage, avg(age) from {df1} group by life_stage",
    df1=df1,
).show()

+----------+--------+
|life_stage|avg(age)|
+----------+--------+
|     adult|    53.5|
|     child|     3.0|
|  teenager|    13.0|
+----------+--------+



In [48]:
# Persist df1 as the table "some_people". Spark's default save mode raises when
# the table already exists, which broke re-running this notebook against the
# same warehouse. "overwrite" replaces the table, so later cells (including the
# INSERT of 'frank') always start from the same rows.
df1.write.mode("overwrite").saveAsTable("some_people")

                                                                                

In [50]:
# Read the saved table back through SQL.
spark.sql(
    "select * from some_people"
).show()

+----------+---+----------+
|first_name|age|life_stage|
+----------+---+----------+
|       heo| 13|  teenager|
|       sue| 32|     adult|
|       bob| 75|     adult|
|        li|  3|     child|
+----------+---+----------+



In [51]:
# Append one row with SQL; the statement returns an empty DataFrame (shown as DataFrame[]).
spark.sql(
    "INSERT INTO some_people VALUES ('frank', 4, 'child')"
)

DataFrame[]

In [52]:
# Re-read the table to see the inserted row.
spark.sql(
    "select * from some_people"
).show()

+----------+---+----------+
|first_name|age|life_stage|
+----------+---+----------+
|       heo| 13|  teenager|
|     frank|  4|     child|
|       sue| 32|     adult|
|       bob| 75|     adult|
|        li|  3|     child|
+----------+---+----------+



In [54]:
# SQL filter on the saved table: teenagers only.
spark.sql(
    "select * from some_people where life_stage='teenager'"
).show()

+----------+---+----------+
|first_name|age|life_stage|
+----------+---+----------+
|       heo| 13|  teenager|
+----------+---+----------+



## Spark RDD

In [57]:
# Load the text file as an RDD of lines via the SparkContext behind the session.
text_file = spark.sparkContext.textFile(
    "some_words.txt"
)

In [58]:
# Word count over the RDD of lines: tokenize, pair each word with 1, sum per word.
# str.split() with no argument splits on runs of any whitespace and drops empty
# strings. split(" ") produced a bogus '' "word" for blank lines or repeated spaces.
counts = (
    text_file.flatMap(lambda line: line.split())
    .map(lambda word: (word, 1))
    .reduceByKey(lambda a, b: a + b)
)

In [60]:
# Bring the (word, count) pairs back to the driver as a Python list.
# NOTE(review): pair order follows RDD partitioning, not any sort.
counts.collect()

[('these', 2),
 ('are', 2),
 ('more', 1),
 ('in', 1),
 ('words', 3),
 ('english', 1)]

## Spark unit testing

In [53]:
def with_life_stage(df):
    """Add a ``life_stage`` column to ``df`` based on its ``age`` column.

    ``age`` below 13 maps to "child", 13 through 19 inclusive to "teenager",
    and above 19 to "adult". Rows whose ``age`` is null match no condition
    and get a null ``life_stage``.
    """
    age = col("age")
    is_child = age < 13
    is_teenager = (age >= 13) & (age <= 19)
    is_adult = age > 19
    return df.withColumn(
        "life_stage",
        when(is_child, "child").when(is_teenager, "teenager").when(is_adult, "adult"),
    )

In [None]:
def with_life_stage(df):
    """Return ``df`` with a ``life_stage`` column, computed in Spark SQL.

    SQL counterpart of the DataFrame-API version: ``age < 13`` -> 'child',
    ``age`` between 13 and 19 inclusive -> 'teenager', ``age > 19`` ->
    'adult'. A null ``age`` matches no WHEN branch and there is no ELSE, so
    ``life_stage`` is null for that row.
    """
    # The previous query did `select *, life_stage`, which cannot resolve when
    # the input (e.g. the age/expected test frame) has no life_stage column.
    # Derive the column with CASE instead.
    return spark.sql(
        """
        SELECT
            *,
            CASE
                WHEN age < 13 THEN 'child'
                WHEN age BETWEEN 13 AND 19 THEN 'teenager'
                WHEN age > 19 THEN 'adult'
            END AS life_stage
        FROM {df}
        """,
        df=df,
    )

In [54]:
# Test fixture: each row pairs an input age with the expected life stage;
# the last row checks that a null age yields a null life stage.
df = spark.createDataFrame(
    data=[(3, "child"), (75, "adult"), (19, "teenager"), (None, None)],
    schema=["age", "expected"],
)

In [55]:
import chispa

# Raise if the computed life_stage differs from the expected column in any row.
chispa.assert_column_equality(
    df.transform(with_life_stage),
    "life_stage",
    "expected",
)

In [56]:
# Eyeball the fixture next to the computed column.
df.transform(with_life_stage).show(n=20)

+----+--------+----------+
| age|expected|life_stage|
+----+--------+----------+
|   3|   child|     child|
|  75|   adult|     adult|
|  19|teenager|  teenager|
|null|    null|      null|
+----+--------+----------+

