# 0. Environment preparation


Let's create a SparkSession object, the entry point for all Spark computations. The setup also loads some data and registers it as views.

In [1]:
import env_setup
import pyspark.sql.functions as f

spark = env_setup.getSession(local=True)

Created local SparkSession
Created "sales" view from CSV file
Created "item_prices" view from CSV file


We can assign each part of a query to a variable. The __table()__ method returns a Dataframe, which holds a structured dataset. To see what's inside it, we can print its schema.

In [2]:
sales_df = spark.table("sales")
item_prices_df = spark.table("item_prices")

print("# schema of both tables")
sales_df.printSchema()
item_prices_df.printSchema()


# schema of both tables
root
 |-- shop_id: string (nullable = true)
 |-- item_id: string (nullable = true)
 |-- qty: string (nullable = true)
 |-- transaction_date: string (nullable = true)

root
 |-- item_id: string (nullable = true)
 |-- unit_price: string (nullable = true)



We can also print the first rows with __show()__.

In [3]:
print("# sample results from both tables")
sales_df.show()
item_prices_df.show()

# sample results from both tables
+-------+-------+---+----------------+
|shop_id|item_id|qty|transaction_date|
+-------+-------+---+----------------+
| SHOP_1| ITEM_1|  2|      2018-02-01|
| SHOP_1| ITEM_2|  1|      2018-02-01|
| SHOP_1| ITEM_3|  4|      2018-02-10|
| SHOP_2| ITEM_3|  1|      2018-02-02|
| SHOP_2| ITEM_1|  1|      2018-02-11|
+-------+-------+---+----------------+

+-------+----------+
|item_id|unit_price|
+-------+----------+
| ITEM_1|     100.0|
| ITEM_2|     300.0|
| ITEM_3|      50.0|
+-------+----------+



Dataframe transformations are evaluated lazily. If we chain multiple operations together, they are executed only when we call an action. To see the current *execution plan*, we can call the __explain()__ method. It works even for a simple selection from a view, as here.

In [4]:
print("# query execution plan for this simple select")
sales_df.explain()
item_prices_df.explain()

# query execution plan for this simple select
== Physical Plan ==
*(1) FileScan csv [shop_id#10,item_id#11,qty#12,transaction_date#13] Batched: false, Format: CSV, Location: InMemoryFileIndex[file:/Users/mkromka/dev/trainings/pySpark_workshop/data/sales.csv], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<shop_id:string,item_id:string,qty:string,transaction_date:string>
== Physical Plan ==
*(1) FileScan csv [item_id#28,unit_price#29] Batched: false, Format: CSV, Location: InMemoryFileIndex[file:/Users/mkromka/dev/trainings/pySpark_workshop/data/item_prices.csv], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<item_id:string,unit_price:string>


Each plan has a single step, a CSV FileScan, which is what we expect because each table was read from a CSV file. We can also see the schema along with some other information.
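
To make the idea of lazy evaluation concrete, here is a minimal plain-Python sketch that uses generators as a stand-in for Spark transformations. It does not use Spark at all; the names `scan`, `keep_shop` and `log` are made up for this illustration.

```python
# Plain-Python model of lazy evaluation; no Spark involved.
log = []  # records every row that is actually read


def scan(rows):
    """Stand-in for a file scan: yields rows one by one and logs each read."""
    for row in rows:
        log.append("scan")
        yield row


def keep_shop(rows, shop_id):
    """Stand-in for a filter transformation."""
    for row in rows:
        if row["shop_id"] == shop_id:
            yield row


rows = [
    {"shop_id": "SHOP_1", "item_id": "ITEM_1"},
    {"shop_id": "SHOP_2", "item_id": "ITEM_3"},
]

# "Transformations": only a pipeline is built, nothing is read yet.
pipeline = keep_shop(scan(rows), "SHOP_1")
work_before_action = len(log)

# "Action": consuming the pipeline triggers the whole chain.
result = list(pipeline)
work_after_action = len(log)
```

No row is read while the pipeline is being assembled; the reads happen only when `list()` consumes it, which is roughly how a Spark action such as `show()` triggers the chained transformations.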

# 1. SQL support

Spark SQL's name is rather intuitive - we can use Spark to execute SQL queries on our Dataframes.

In [5]:
sql_df = spark.sql('select item_id, transaction_date from sales where shop_id = "SHOP_1" order by transaction_date desc')
sql_df.show()

+-------+----------------+
|item_id|transaction_date|
+-------+----------------+
| ITEM_3|      2018-02-10|
| ITEM_1|      2018-02-01|
| ITEM_2|      2018-02-01|
+-------+----------------+



### ex1. Use a plain SQL query to select all transactions with quantity between 2 and 4 (inclusive). Show all results.


In [6]:
spark.sql('select * from sales where qty >= 2 and qty <=4').show()

+-------+-------+---+----------------+
|shop_id|item_id|qty|transaction_date|
+-------+-------+---+----------------+
| SHOP_1| ITEM_1|  2|      2018-02-01|
| SHOP_1| ITEM_3|  4|      2018-02-10|
+-------+-------+---+----------------+



### ex2. Print mean unit price for all items

In [7]:
spark.sql('select mean(unit_price) as mean_price from item_prices').show()

+----------+
|mean_price|
+----------+
|     150.0|
+----------+



# 2. Dataframe operations

For analysts, the plain SQL API may be enough, but it is often more convenient to transform our Dataframes with method calls and chaining. Almost every SQL operation has a corresponding method.

In [8]:
string_df = sales_df.select("item_id", "transaction_date")\
    .filter(f.col("shop_id") == "SHOP_1")\
    .orderBy(f.col("transaction_date").desc())
string_df.show()

+-------+----------------+
|item_id|transaction_date|
+-------+----------------+
| ITEM_3|      2018-02-10|
| ITEM_1|      2018-02-01|
| ITEM_2|      2018-02-01|
+-------+----------------+



We can use strings to specify column names, but there is another, more dynamic option: treating columns as fields of the Dataframe object.

In [9]:
field_df = sales_df.select(sales_df.item_id, sales_df.transaction_date)\
    .filter(sales_df.shop_id == "SHOP_1")\
    .orderBy(sales_df.transaction_date.desc())
field_df.show()

+-------+----------------+
|item_id|transaction_date|
+-------+----------------+
| ITEM_3|      2018-02-10|
| ITEM_1|      2018-02-01|
| ITEM_2|      2018-02-01|
+-------+----------------+



The results are the same, and all these queries will in fact be executed in __exactly__ the same way by Spark. The only difference is the query translation step. Once Spark has an execution plan for a query, it treats it the same way regardless of the API that produced it. To verify that, let's look at the physical plans of all three queries.

In [10]:
print("Plain SQL API execution plan")
sql_df.explain()
print("\n Columns as strings approach")
string_df.explain()
print("\n Columns as fields solution")
field_df.explain()

Plain SQL API execution plan
== Physical Plan ==
*(2) Sort [transaction_date#13 DESC NULLS LAST], true, 0
+- Exchange rangepartitioning(transaction_date#13 DESC NULLS LAST, 200)
   +- *(1) Project [item_id#11, transaction_date#13]
      +- *(1) Filter (isnotnull(shop_id#10) && (shop_id#10 = SHOP_1))
         +- *(1) FileScan csv [shop_id#10,item_id#11,transaction_date#13] Batched: false, Format: CSV, Location: InMemoryFileIndex[file:/Users/mkromka/dev/trainings/pySpark_workshop/data/sales.csv], PartitionFilters: [], PushedFilters: [IsNotNull(shop_id), EqualTo(shop_id,SHOP_1)], ReadSchema: struct<shop_id:string,item_id:string,transaction_date:string>

 Columns as strings approach
== Physical Plan ==
*(2) Sort [transaction_date#13 DESC NULLS LAST], true, 0
+- Exchange rangepartitioning(transaction_date#13 DESC NULLS LAST, 200)
   +- *(1) Project [item_id#11, transaction_date#13]
      +- *(1) Filter (isnotnull(shop_id#10) && (shop_id#10 = SHOP_1))
         +- *(1) FileScan csv [shop_id#1

### ex3. Rewrite query from ex1 to dataframe operations

In [11]:
sales_df.filter((sales_df.qty >= 2) & (sales_df.qty <= 4)).show()

+-------+-------+---+----------------+
|shop_id|item_id|qty|transaction_date|
+-------+-------+---+----------------+
| SHOP_1| ITEM_1|  2|      2018-02-01|
| SHOP_1| ITEM_3|  4|      2018-02-10|
+-------+-------+---+----------------+



# 3. Joins

Most complex queries in relational databases require joins. Spark SQL supports them as well.

In [12]:
print("# joins using plain SQL queries")
spark.sql('select * from sales join item_prices on sales.item_id = item_prices.item_id').show()

print("# using Dataframe API - duplicated item_id column!")
sales_df.join(item_prices_df, sales_df.item_id == item_prices_df.item_id, "inner").show()

print("# dropping redundant column")
sales_with_unit_prices_df = sales_df\
    .join(item_prices_df, sales_df.item_id == item_prices_df.item_id)\
    .drop(sales_df.item_id)
    
sales_with_unit_prices_df.show()

# joins using plain SQL queries
+-------+-------+---+----------------+-------+----------+
|shop_id|item_id|qty|transaction_date|item_id|unit_price|
+-------+-------+---+----------------+-------+----------+
| SHOP_1| ITEM_1|  2|      2018-02-01| ITEM_1|     100.0|
| SHOP_1| ITEM_2|  1|      2018-02-01| ITEM_2|     300.0|
| SHOP_1| ITEM_3|  4|      2018-02-10| ITEM_3|      50.0|
| SHOP_2| ITEM_3|  1|      2018-02-02| ITEM_3|      50.0|
| SHOP_2| ITEM_1|  1|      2018-02-11| ITEM_1|     100.0|
+-------+-------+---+----------------+-------+----------+

# using Dataframe API - duplicated item_id column!
+-------+-------+---+----------------+-------+----------+
|shop_id|item_id|qty|transaction_date|item_id|unit_price|
+-------+-------+---+----------------+-------+----------+
| SHOP_1| ITEM_1|  2|      2018-02-01| ITEM_1|     100.0|
| SHOP_1| ITEM_2|  1|      2018-02-01| ITEM_2|     300.0|
| SHOP_1| ITEM_3|  4|      2018-02-10| ITEM_3|      50.0|
| SHOP_2| ITEM_3|  1|      2018-02-02| ITEM_3|

### ex4. Filter out excluded items

Given a Spark Dataframe with a column of items, select all transactions from sales_df that do not contain any of these items.

In [13]:
print("# Dataframe with column of items we would like to exclude")
excluded_items_df = spark.createDataFrame([("ITEM_2",),("ITEM_4",)], ['item'])
excluded_items_df.show()

print("# we could collect the result but the execution plan contains all values, not good for large datasets:")
excluded_items = [x.item for x in excluded_items_df.collect()]
print(excluded_items)
filtered_isin_df = sales_df.filter(~sales_df.item_id.isin(excluded_items))
filtered_isin_df.show()
filtered_isin_df.explain()

print("# using join and filtering - result doesn't contain ITEM_2 and ITEM_4 ")
filtered_left_join = sales_df.join(excluded_items_df, sales_df.item_id == excluded_items_df.item, "left_outer")\
    .filter(f.isnull(excluded_items_df.item))\
    .drop(excluded_items_df.item)
filtered_left_join.show()
filtered_left_join.explain()

print("# better option: anti join")
filtered_left_anti = sales_df.join(excluded_items_df, sales_df.item_id == excluded_items_df.item, "left_anti")
filtered_left_anti.show()
filtered_left_anti.explain()

# Dataframe with column of items we would like to exclude
+------+
|  item|
+------+
|ITEM_2|
|ITEM_4|
+------+

# we could collect the result but the execution plan contains all values, not good for large datasets:
['ITEM_2', 'ITEM_4']
+-------+-------+---+----------------+
|shop_id|item_id|qty|transaction_date|
+-------+-------+---+----------------+
| SHOP_1| ITEM_1|  2|      2018-02-01|
| SHOP_1| ITEM_3|  4|      2018-02-10|
| SHOP_2| ITEM_3|  1|      2018-02-02|
| SHOP_2| ITEM_1|  1|      2018-02-11|
+-------+-------+---+----------------+

== Physical Plan ==
*(1) Filter NOT item_id#11 IN (ITEM_2,ITEM_4)
+- *(1) FileScan csv [shop_id#10,item_id#11,qty#12,transaction_date#13] Batched: false, Format: CSV, Location: InMemoryFileIndex[file:/Users/mkromka/dev/trainings/pySpark_workshop/data/sales.csv], PartitionFilters: [], PushedFilters: [Not(In(item_id, [ITEM_2,ITEM_4]))], ReadSchema: struct<shop_id:string,item_id:string,qty:string,transaction_date:string>
# using join and filtering -
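
The semantics of the anti join used above can be sketched in plain Python: keep every left row whose key has no match on the right. This is only a model of the result, not of how Spark executes the join, and it ignores null handling; `left_anti_join` is a hypothetical helper name.

```python
# Plain-Python model of a left anti join; no Spark involved.
sales = [
    ("SHOP_1", "ITEM_1"),
    ("SHOP_1", "ITEM_2"),
    ("SHOP_1", "ITEM_3"),
    ("SHOP_2", "ITEM_3"),
    ("SHOP_2", "ITEM_1"),
]
excluded = ["ITEM_2", "ITEM_4"]


def left_anti_join(left_rows, right_keys, key):
    """Return the left rows whose key does not appear among right_keys."""
    right = set(right_keys)
    return [row for row in left_rows if key(row) not in right]


# Keep only sales whose item_id (second field) is not excluded.
kept = left_anti_join(sales, excluded, key=lambda row: row[1])
```

ITEM_4 never appears in the sales data, so it simply has no effect, and no columns from the right side end up in the result.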

# 4. Adding columns

We might want to add a column to a Dataframe based on values from other columns. The __withColumn()__ method does just that.

In [14]:
total_sales_df = sales_with_unit_prices_df\
    .withColumn("total_sales", f.col("qty") * f.col("unit_price"))

print("# Added new total_sales column which is a product of unit_price and qty")
total_sales_df.show()

# Added new total_sales column which is a product of unit_price and qty
+-------+---+----------------+-------+----------+-----------+
|shop_id|qty|transaction_date|item_id|unit_price|total_sales|
+-------+---+----------------+-------+----------+-----------+
| SHOP_1|  2|      2018-02-01| ITEM_1|     100.0|      200.0|
| SHOP_1|  1|      2018-02-01| ITEM_2|     300.0|      300.0|
| SHOP_1|  4|      2018-02-10| ITEM_3|      50.0|      200.0|
| SHOP_2|  1|      2018-02-02| ITEM_3|      50.0|       50.0|
| SHOP_2|  1|      2018-02-11| ITEM_1|     100.0|      100.0|
+-------+---+----------------+-------+----------+-----------+



Apart from simple operations, we may use complex predicates when calculating the new value.

In [45]:
print("# Adding price category column based on a condition")
sales_with_transaction_category = total_sales_df\
    .withColumn("transaction_price_category", \
                f.when(f.col("total_sales") > 150, "High")\
                .when(f.col("total_sales") < 60, "Low")\
                .otherwise("Medium"))

sales_with_transaction_category.show()

# Adding price category column based on a condition
+-------+---+----------------+-------+----------+-----------+--------------------------+
|shop_id|qty|transaction_date|item_id|unit_price|total_sales|transaction_price_category|
+-------+---+----------------+-------+----------+-----------+--------------------------+
| SHOP_1|  2|      2018-02-01| ITEM_1|     100.0|      200.0|                      High|
| SHOP_1|  1|      2018-02-01| ITEM_2|     300.0|      300.0|                      High|
| SHOP_1|  4|      2018-02-10| ITEM_3|      50.0|      200.0|                      High|
| SHOP_2|  1|      2018-02-02| ITEM_3|      50.0|       50.0|                       Low|
| SHOP_2|  1|      2018-02-11| ITEM_1|     100.0|      100.0|                    Medium|
+-------+---+----------------+-------+----------+-----------+--------------------------+
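
The chained `when(...).when(...).otherwise(...)` expression behaves like an if/elif/else ladder evaluated for each row: the first matching condition wins. A plain-Python equivalent of the rule above, with the hypothetical name `price_category`, might look like this:

```python
# Plain-Python equivalent of the when/otherwise chain; no Spark involved.
def price_category(total_sales):
    """Return the category for one row; conditions are checked in order."""
    if total_sales > 150:
        return "High"
    if total_sales < 60:
        return "Low"
    return "Medium"  # plays the role of .otherwise("Medium")


# The total_sales values from the table above, in the same order.
categories = [price_category(v) for v in (200.0, 300.0, 200.0, 50.0, 100.0)]
```

One difference worth keeping in mind: in Spark, leaving out `otherwise()` yields null for rows that match no condition, whereas this function always returns a value.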



### ex5. Two-packs of items
We want to create two-packs of items, but the sum of their prices must be lower than 360.
hint: use a cross join and aliases


In [16]:
item_prices_df.alias("items1")\
    .crossJoin(item_prices_df.alias("items2"))\
    .withColumn("price_sum",f.col("items1.unit_price") + f.col("items2.unit_price"))\
    .where((f.col("price_sum") < 360) & (f.col("items1.item_id") != f.col("items2.item_id")))\
    .select("items1.item_id", "items2.item_id", "price_sum")\
    .show()

+-------+-------+---------+
|item_id|item_id|price_sum|
+-------+-------+---------+
| ITEM_1| ITEM_3|    150.0|
| ITEM_2| ITEM_3|    350.0|
| ITEM_3| ITEM_1|    150.0|
| ITEM_3| ITEM_2|    350.0|
+-------+-------+---------+
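
Note that the result lists every pack twice, once in each order. The plain-Python sketch below reproduces the cross join with `itertools.product` and then shows one way to keep each unordered pair only once with `itertools.combinations`. It is a model of the logic rather than Spark code; the price dictionary simply repeats the values from item_prices.

```python
# Plain-Python model of the two-pack cross join; no Spark involved.
from itertools import combinations, product

prices = {"ITEM_1": 100.0, "ITEM_2": 300.0, "ITEM_3": 50.0}

# Cross join with "different item" and "sum below 360" filters:
# every pack shows up twice, once per ordering.
ordered_pairs = [
    (a, b, prices[a] + prices[b])
    for a, b in product(prices, repeat=2)
    if a != b and prices[a] + prices[b] < 360
]

# Unordered pairs: each pack appears exactly once.
unique_pairs = [
    (a, b, prices[a] + prices[b])
    for a, b in combinations(sorted(prices), 2)
    if prices[a] + prices[b] < 360
]
```

In the Dataframe query, comparing the two item_id columns with `<` instead of `!=` should have the same de-duplicating effect.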



# 5. Simple aggregations

We already saw a simple aggregation when calculating the mean of prices. The Dataframe API allows us to do the same.

In [17]:
print("# using alias to have a better column name")
total_sales_df\
    .groupBy("shop_id")\
    .agg(f.sum(total_sales_df.total_sales).alias("sales"))\
    .orderBy(f.col("sales").desc())\
    .show()
    # .orderBy(total_sales_df.sales) won't work as total_sales_df has no sales column (we define it later)

# using alias to have a better column name
+-------+-----+
|shop_id|sales|
+-------+-----+
| SHOP_1|700.0|
| SHOP_2|150.0|
+-------+-----+



### ex6. Aggregating data to lists
Produce a column with a list of all shops where each item was sold. The new column should be named "shops".
hint: check the collect_list function

In [18]:
total_sales_df\
    .groupBy("item_id")\
    .agg(f.collect_list(f.col("shop_id")).alias("shops"))\
    .show()

+-------+----------------+
|item_id|           shops|
+-------+----------------+
| ITEM_3|[SHOP_1, SHOP_2]|
| ITEM_2|        [SHOP_1]|
| ITEM_1|[SHOP_1, SHOP_2]|
+-------+----------------+



# 6. Date handling

Many datasets contain some notion of date or time. Let's see how we can transform it. Our total_sales_df contains a *transaction_date* column, and we are going to extract each part of it with functions from the pyspark.sql.functions module.

In [19]:
print("# extracting multiple elements of date")
total_sales_df\
    .withColumn("year", f.year(f.col("transaction_date")))\
    .withColumn("month", f.month(f.col("transaction_date")))\
    .withColumn("day", f.dayofmonth(f.col("transaction_date")))\
    .withColumn("day_of_year", f.dayofyear(f.col("transaction_date")))\
    .withColumn("day_of_week", f.date_format(f.col("transaction_date"), 'u'))\
    .withColumn("day_of_week_str", f.date_format(f.col("transaction_date"), 'E'))\
    .withColumn("week_of_year", f.weekofyear(f.col("transaction_date")))\
    .show()

# extracting multiple elements of date
+-------+---+----------------+-------+----------+-----------+----+-----+---+-----------+-----------+---------------+------------+
|shop_id|qty|transaction_date|item_id|unit_price|total_sales|year|month|day|day_of_year|day_of_week|day_of_week_str|week_of_year|
+-------+---+----------------+-------+----------+-----------+----+-----+---+-----------+-----------+---------------+------------+
| SHOP_1|  2|      2018-02-01| ITEM_1|     100.0|      200.0|2018|    2|  1|         32|          4|            Thu|           5|
| SHOP_1|  1|      2018-02-01| ITEM_2|     300.0|      300.0|2018|    2|  1|         32|          4|            Thu|           5|
| SHOP_1|  4|      2018-02-10| ITEM_3|      50.0|      200.0|2018|    2| 10|         41|          6|            Sat|           6|
| SHOP_2|  1|      2018-02-02| ITEM_3|      50.0|       50.0|2018|    2|  2|         33|          5|            Fri|           5|
| SHOP_2|  1|      2018-02-11| ITEM_1|     100.0|  

We don't need to define a new column to use the extracted values. Here's an example of aggregating sales by week.

In [20]:
print("# aggregate sales by week")
total_sales_df\
    .groupBy(f.weekofyear(f.col("transaction_date")))\
    .agg(f.sum(f.col("total_sales")))\
    .show()

# aggregate sales by week
+----------------------------+----------------+
|weekofyear(transaction_date)|sum(total_sales)|
+----------------------------+----------------+
|                           6|           300.0|
|                           5|           550.0|
+----------------------------+----------------+



### ex7. Weekly sales aggregation not starting on Monday
For Spark, each week starts on Monday. But what if we want to start aggregation on a different day, for example Sunday?

In [21]:
total_sales_df\
    .groupBy(f.weekofyear(f.date_add(f.col("transaction_date"), 1)).alias("shifted_week_of_year"))\
    .agg(f.sum(f.col("total_sales")))\
    .show()

+--------------------+----------------+
|shifted_week_of_year|sum(total_sales)|
+--------------------+----------------+
|                   6|           200.0|
|                   5|           550.0|
|                   7|           100.0|
+--------------------+----------------+
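
The trick above relies on week numbers following the ISO 8601 convention (weeks start on Monday), so shifting every date forward by one day moves Sunday into the following week. The same arithmetic can be checked in plain Python with `date.isocalendar()`. This is a sketch outside Spark; `sunday_based_week` is a hypothetical helper and the transactions repeat the dates and total_sales values shown earlier.

```python
# Plain-Python check of the "shift by one day" trick; no Spark involved.
from collections import defaultdict
from datetime import date, timedelta

# (transaction_date, total_sales) pairs from total_sales_df
transactions = [
    (date(2018, 2, 1), 200.0),
    (date(2018, 2, 1), 300.0),
    (date(2018, 2, 10), 200.0),
    (date(2018, 2, 2), 50.0),
    (date(2018, 2, 11), 100.0),
]


def sunday_based_week(day):
    """ISO week number of the following day, so that weeks run Sunday to Saturday."""
    return (day + timedelta(days=1)).isocalendar()[1]


weekly_sales = defaultdict(float)
for day, total in transactions:
    weekly_sales[sunday_based_week(day)] += total
```

Sunday 2018-02-11 lands in shifted week 7 while Saturday 2018-02-10 stays in week 6, which matches the three rows in the output above.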



### ex8. (Homework) Sales aggregation with preserving date
Sometimes we want to keep a date for each week (for example its last day) instead of the week number. Try implementing the aggregation above so that, instead of the week number, the result contains the date of the last day of the given week. Try to do it without using a join.

hint: maybe *next_day()* function will be helpful.

In [22]:
total_sales_df\
    .withColumn("aggr_date", f.next_day(f.date_sub(f.col("transaction_date"), 1), "Sat"))\
    .groupBy(f.col("aggr_date"))\
    .agg(f.sum(f.col("total_sales")))\
    .withColumn("day_of_week_string",  f.date_format(f.col("aggr_date"), 'E'))\
    .show()

+----------+----------------+------------------+
| aggr_date|sum(total_sales)|day_of_week_string|
+----------+----------------+------------------+
|2018-02-10|           200.0|               Sat|
|2018-02-17|           100.0|               Sat|
|2018-02-03|           550.0|               Sat|
+----------+----------------+------------------+
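
The reason for subtracting one day first is that `next_day()` returns the first matching weekday strictly after the given date, so a transaction that already falls on a Saturday would otherwise be pushed to the following week. The plain-Python sketch below models that behaviour; it is not Spark code, and `next_day` here is a local re-implementation written for illustration.

```python
# Plain-Python model of next_day(date_sub(d, 1), "Sat"); no Spark involved.
from datetime import date, timedelta

SATURDAY = 5  # date.weekday() numbering: Monday is 0


def next_day(day, weekday):
    """First date strictly later than `day` that falls on `weekday`."""
    days_ahead = (weekday - day.weekday() - 1) % 7 + 1
    return day + timedelta(days=days_ahead)


def week_ending_saturday(day):
    """Saturday that closes the Sunday-to-Saturday week containing `day`."""
    return next_day(day - timedelta(days=1), SATURDAY)
```

For example, `week_ending_saturday(date(2018, 2, 10))` returns 2018-02-10 itself, while calling `next_day` directly on that Saturday would return 2018-02-17.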



# 7. Using results of one query in another

How can we use the results of one query in another? We could use joins, but there are other ways. Let's say we want to add the maximum price of all items to each sales row.

In [23]:
print("# Calculate global max unit_price")
item_prices_df\
    .select(f.max(item_prices_df.unit_price).alias("max_price"))\
    .show()

# Calculate global max unit_price
+---------+
|max_price|
+---------+
|     50.0|
+---------+



But how can we get that value out of the Dataframe? If we only have one row, the *first()* method is enough to return a Row object. If we have more rows, we need to use *collect()*.
Both of these methods are actions, which means they trigger computation. Furthermore, the result of these operations is sent directly to the driver. This can be problematic for large Dataframes.

In [24]:
max_date_row = item_prices_df\
    .select(f.max(item_prices_df.unit_price).alias("max_price"))\
    .first() # first() returns first Row, collect returns list of rows
    #.collect()[0]

print(max_date_row)

Row(max_price='50.0')


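We ended up with a Row object. It has a field for each column in the Dataframe, and we can extract values either by index or by name. Note the quotes in the output: unit_price was read from the CSV file as a string column, so max() compared the values lexicographically and returned '50.0' rather than the numeric maximum 300.0. Casting the column to a numeric type before aggregating avoids this.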

In [25]:
print(max_date_row[0])
print(max_date_row.max_price)
print(max_date_row['max_price'])

max_price = max_date_row.max_price

50.0
50.0
50.0
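
The value 50.0 is the maximum of the unit_price strings, not of the numbers. The effect is easy to reproduce in plain Python, where strings are also compared character by character:

```python
# Plain-Python illustration of string vs numeric maximum; no Spark involved.
unit_prices = ["100.0", "300.0", "50.0"]  # values as read from the CSV file

# Strings compare lexicographically: "5" sorts after "1" and "3".
string_max = max(unit_prices)

# Converting to numbers first gives the maximum we actually want.
numeric_max = max(float(price) for price in unit_prices)
```

In the Dataframe query, casting first, for example `f.max(f.col("unit_price").cast("double"))`, would be expected to return 300.0 instead.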


Now, to include it in each sales row, we need to add a new column holding a literal value using the __lit()__ function.

In [26]:
print("# adding it as a literal (constant) column")
sales_with_max_global_price_df = total_sales_df\
    .withColumn("global_max_price", f.lit(max_price))

sales_with_max_global_price_df.show()

# adding it as a literal (constant) column
+-------+---+----------------+-------+----------+-----------+----------------+
|shop_id|qty|transaction_date|item_id|unit_price|total_sales|global_max_price|
+-------+---+----------------+-------+----------+-----------+----------------+
| SHOP_1|  2|      2018-02-01| ITEM_1|     100.0|      200.0|            50.0|
| SHOP_1|  1|      2018-02-01| ITEM_2|     300.0|      300.0|            50.0|
| SHOP_1|  4|      2018-02-10| ITEM_3|      50.0|      200.0|            50.0|
| SHOP_2|  1|      2018-02-02| ITEM_3|      50.0|       50.0|            50.0|
| SHOP_2|  1|      2018-02-11| ITEM_1|     100.0|      100.0|            50.0|
+-------+---+----------------+-------+----------+-----------+----------------+



### ex8. Adding constant column using cross join
Most of the time we want to avoid unnecessary actions. They break the flow of a query, cannot be optimized, and require sending data over the network to the driver and, in our case, back again. Let's implement the query above using a cross join.

In [27]:
max_price_df = item_prices_df\
    .select(f.max(f.col("unit_price")).alias("max_price"))
    
total_sales_df\
    .crossJoin(f.broadcast(max_price_df))\
    .show()

+-------+---+----------------+-------+----------+-----------+---------+
|shop_id|qty|transaction_date|item_id|unit_price|total_sales|max_price|
+-------+---+----------------+-------+----------+-----------+---------+
| SHOP_1|  2|      2018-02-01| ITEM_1|     100.0|      200.0|     50.0|
| SHOP_1|  1|      2018-02-01| ITEM_2|     300.0|      300.0|     50.0|
| SHOP_1|  4|      2018-02-10| ITEM_3|      50.0|      200.0|     50.0|
| SHOP_2|  1|      2018-02-02| ITEM_3|      50.0|       50.0|     50.0|
| SHOP_2|  1|      2018-02-11| ITEM_1|     100.0|      100.0|     50.0|
+-------+---+----------------+-------+----------+-----------+---------+



You must make sure that the Dataframe used in a cross join has at most a couple of rows; otherwise the number of rows in the result will explode.

# 8. Window functions

To perform partial aggregations while preserving the initial number of rows, we could use joins. Let's try that to get the latest transaction date for each shop.

In [46]:
max_date_by_store_df = total_sales_df\
    .groupBy(f.col("shop_id"))\
    .agg(f.max("transaction_date").alias("max_transaction_date_by_shop")) 
    
total_sales_df.join(max_date_by_store_df, ["shop_id"])\
    .show()

+-------+---+----------------+-------+----------+-----------+----------------------------+
|shop_id|qty|transaction_date|item_id|unit_price|total_sales|max_transaction_date_by_shop|
+-------+---+----------------+-------+----------+-----------+----------------------------+
| SHOP_1|  2|      2018-02-01| ITEM_1|     100.0|      200.0|                  2018-02-10|
| SHOP_1|  1|      2018-02-01| ITEM_2|     300.0|      300.0|                  2018-02-10|
| SHOP_1|  4|      2018-02-10| ITEM_3|      50.0|      200.0|                  2018-02-10|
| SHOP_2|  1|      2018-02-02| ITEM_3|      50.0|       50.0|                  2018-02-11|
| SHOP_2|  1|      2018-02-11| ITEM_1|     100.0|      100.0|                  2018-02-11|
+-------+---+----------------+-------+----------+-----------+----------------------------+



__Side note__: instead of a join condition we passed a list with one column name. This way we avoid duplicating that column in the join result.

Another option is using __window functions__. They have been marked experimental since Spark 1.4, but they are widely used in production.

First, we need to define a Window with a grouping column using the __partitionBy()__ method. Then we create a new column as usual, but after invoking a function we add __.over(window)__ to say that this is not a global operation.

In [29]:
from pyspark.sql import Window

window = Window.partitionBy(f.col("shop_id"))

total_sales_df\
    .withColumn("max_transaction_date_by_shop", f.max(f.col("transaction_date")).over(window)).show()


+-------+---+----------------+-------+----------+-----------+----------------------------+
|shop_id|qty|transaction_date|item_id|unit_price|total_sales|max_transaction_date_by_shop|
+-------+---+----------------+-------+----------+-----------+----------------------------+
| SHOP_2|  1|      2018-02-02| ITEM_3|      50.0|       50.0|                  2018-02-11|
| SHOP_2|  1|      2018-02-11| ITEM_1|     100.0|      100.0|                  2018-02-11|
| SHOP_1|  2|      2018-02-01| ITEM_1|     100.0|      200.0|                  2018-02-10|
| SHOP_1|  1|      2018-02-01| ITEM_2|     300.0|      300.0|                  2018-02-10|
| SHOP_1|  4|      2018-02-10| ITEM_3|      50.0|      200.0|                  2018-02-10|
+-------+---+----------------+-------+----------+-----------+----------------------------+



Within each partition we can specify the order in which the rows are traversed. Let's use that to find ordinals of transactions for each item.

In [30]:
window_by_item_sorted = Window.partitionBy(f.col("item_id")).orderBy(f.col("transaction_date"))

total_sales_df\
    .withColumn("item_transaction_ordinal", f.rank().over(window_by_item_sorted))\
    .show()

+-------+---+----------------+-------+----------+-----------+------------------------+
|shop_id|qty|transaction_date|item_id|unit_price|total_sales|item_transaction_ordinal|
+-------+---+----------------+-------+----------+-----------+------------------------+
| SHOP_2|  1|      2018-02-02| ITEM_3|      50.0|       50.0|                       1|
| SHOP_1|  4|      2018-02-10| ITEM_3|      50.0|      200.0|                       2|
| SHOP_1|  1|      2018-02-01| ITEM_2|     300.0|      300.0|                       1|
| SHOP_1|  2|      2018-02-01| ITEM_1|     100.0|      200.0|                       1|
| SHOP_2|  1|      2018-02-11| ITEM_1|     100.0|      100.0|                       2|
+-------+---+----------------+-------+----------+-----------+------------------------+
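
In this dataset no item has two transactions on the same date, so `rank()` produces consecutive ordinals. With ties it would not: rows sharing the same ordering value share a rank, and the following rank is skipped. The plain-Python sketch below models that difference between ranking and simple row numbering; it is an illustration outside Spark, and both helper functions are hypothetical.

```python
# Plain-Python model of rank vs. row numbering within one partition.
def rank(values):
    """Rank with gaps: 1 + number of values strictly smaller than this one."""
    ordered = sorted(values)
    return [ordered.index(value) + 1 for value in values]


def row_number(values):
    """Consecutive numbering in sort order; ties keep their input order."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    numbers = [0] * len(values)
    for position, i in enumerate(order, start=1):
        numbers[i] = position
    return numbers


# Two transactions on the same date, followed by a later one.
dates = ["2018-02-01", "2018-02-01", "2018-02-10"]
ranks = rank(dates)
row_numbers = row_number(dates)
```

Here `ranks` is `[1, 1, 3]` while `row_numbers` is `[1, 2, 3]`, so if strictly unique ordinals are needed, a row-numbering function is the safer choice.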



We can also specify how large the window should be at each step using the **rowsBetween()** method. We can use it to find the average total sales of the last two transactions in a given shop, ordered by transaction date (like a group-level moving average).

In [31]:
window_by_transaction_date = Window\
    .partitionBy(f.col("shop_id"))\
    .orderBy(f.col("transaction_date"))\
    .rowsBetween(-1,Window.currentRow)

total_sales_df\
    .withColumn("price_moving_average", f.mean(f.col("total_sales")).over(window_by_transaction_date))\
    .orderBy(f.col("shop_id"), f.col("transaction_date"))\
    .show()

+-------+---+----------------+-------+----------+-----------+--------------------+
|shop_id|qty|transaction_date|item_id|unit_price|total_sales|price_moving_average|
+-------+---+----------------+-------+----------+-----------+--------------------+
| SHOP_1|  1|      2018-02-01| ITEM_2|     300.0|      300.0|               250.0|
| SHOP_1|  2|      2018-02-01| ITEM_1|     100.0|      200.0|               200.0|
| SHOP_1|  4|      2018-02-10| ITEM_3|      50.0|      200.0|               250.0|
| SHOP_2|  1|      2018-02-02| ITEM_3|      50.0|       50.0|                50.0|
| SHOP_2|  1|      2018-02-11| ITEM_1|     100.0|      100.0|                75.0|
+-------+---+----------------+-------+----------+-----------+--------------------+
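
The frame `rowsBetween(-1, Window.currentRow)` means "the previous row and the current row" within each partition. The plain-Python sketch below models that computation on the same numbers. It is not Spark code; `moving_average` is a hypothetical helper, and the input order of the two SHOP_1 rows dated 2018-02-01 is an assumption chosen to match the output above.

```python
# Plain-Python model of a two-row moving average per shop; no Spark involved.
from itertools import groupby

# (shop_id, transaction_date, total_sales)
sales = [
    ("SHOP_1", "2018-02-01", 200.0),
    ("SHOP_1", "2018-02-01", 300.0),
    ("SHOP_1", "2018-02-10", 200.0),
    ("SHOP_2", "2018-02-02", 50.0),
    ("SHOP_2", "2018-02-11", 100.0),
]


def moving_average(rows, preceding=1):
    """Append the mean of each row and its `preceding` predecessors per shop."""
    result = []
    # Sort by shop, then date; Python's sort is stable, so ties keep input order.
    ordered = sorted(rows, key=lambda row: (row[0], row[1]))
    for _, group in groupby(ordered, key=lambda row: row[0]):
        partition = list(group)
        for i, row in enumerate(partition):
            frame = partition[max(0, i - preceding): i + 1]
            result.append(row + (sum(r[2] for r in frame) / len(frame),))
    return result


averages = [row[3] for row in moving_average(sales)]
```

Rows that share a transaction_date have no defined relative order inside the window, so with ties like the two SHOP_1 rows the averages can depend on which one Spark happens to process first.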



### ex9. Cumulative moving average of quantities
Find the average quantity of items from the current and all previous transactions for a given item, ordered by transaction date.

In [32]:
unbounded_window_by_transaction_date = Window\
    .partitionBy(f.col("item_id"))\
    .orderBy(f.col("transaction_date"))\
    .rowsBetween(Window.unboundedPreceding,Window.currentRow)
    
total_sales_df\
    .withColumn("average_qty_until_now", f.mean(f.col("qty")).over(unbounded_window_by_transaction_date))\
    .orderBy(f.col("item_id"), f.col("transaction_date"))\
    .show()


+-------+---+----------------+-------+----------+-----------+---------------------+
|shop_id|qty|transaction_date|item_id|unit_price|total_sales|average_qty_until_now|
+-------+---+----------------+-------+----------+-----------+---------------------+
| SHOP_1|  2|      2018-02-01| ITEM_1|     100.0|      200.0|                  2.0|
| SHOP_2|  1|      2018-02-11| ITEM_1|     100.0|      100.0|                  1.5|
| SHOP_1|  1|      2018-02-01| ITEM_2|     300.0|      300.0|                  1.0|
| SHOP_2|  1|      2018-02-02| ITEM_3|      50.0|       50.0|                  1.0|
| SHOP_1|  4|      2018-02-10| ITEM_3|      50.0|      200.0|                  2.5|
+-------+---+----------------+-------+----------+-----------+---------------------+



# 9. Complex aggregations

Sometimes we need more complex aggregations. Let's say we want to analyse weekly sales in each shop. We would like to get a Dataframe with one row per shop and a list of all weekly sales with week and year numbers.
We could try doing that using multiple invocations of the collect_list function.

In [33]:
weekly_sales_by_shop_df = total_sales_df\
    .groupBy("shop_id", f.weekofyear("transaction_date").alias("week"), f.year("transaction_date").alias("year"))\
    .agg(f.sum("total_sales").alias("sales"))

print("# Sales Dataframe with week and year columns")
weekly_sales_by_shop_df.show()

print("# aggregating sales with three collect_list invocations")
shop_sales_weekly_series_df = weekly_sales_by_shop_df\
    .groupBy("shop_id")\
    .agg(f.collect_list("week"), f.collect_list("year"), f.collect_list("sales"))

shop_sales_weekly_series_df.show(truncate=False)

# Sales Dataframe with week and year columns
+-------+----+----+-----+
|shop_id|week|year|sales|
+-------+----+----+-----+
| SHOP_2|   5|2018| 50.0|
| SHOP_1|   5|2018|500.0|
| SHOP_1|   6|2018|200.0|
| SHOP_2|   6|2018|100.0|
+-------+----+----+-----+

# aggregating sales with three collect_list invocations
+-------+------------------+------------------+-------------------+
|shop_id|collect_list(week)|collect_list(year)|collect_list(sales)|
+-------+------------------+------------------+-------------------+
|SHOP_2 |[5, 6]            |[2018, 2018]      |[50.0, 100.0]      |
|SHOP_1 |[5, 6]            |[2018, 2018]      |[500.0, 200.0]     |
+-------+------------------+------------------+-------------------+



Unfortunately, the solution above is not reliable, as the ordering in each collected list may be different. Passing a list of columns to collect_list won't work either, as there is no such API: .agg(f.collect_list(["sales", "week"]))

We can overcome that using the struct function, which combines the values of several columns in each row into a single structure (similar to a dict).

In [34]:
shop_sales_weekly_series_df = weekly_sales_by_shop_df\
    .groupBy("shop_id")\
    .agg(f.collect_list(f.struct(["year", "week", "sales"])).alias("sales_ts"))

shop_sales_weekly_series_df.show(truncate=False)
shop_sales_weekly_series_df.printSchema()

+-------+------------------------------------+
|shop_id|sales_ts                            |
+-------+------------------------------------+
|SHOP_2 |[[2018, 5, 50.0], [2018, 6, 100.0]] |
|SHOP_1 |[[2018, 5, 500.0], [2018, 6, 200.0]]|
+-------+------------------------------------+

root
 |-- shop_id: string (nullable = true)
 |-- sales_ts: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- year: integer (nullable = true)
 |    |    |-- week: integer (nullable = true)
 |    |    |-- sales: double (nullable = true)
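
Why a struct keeps the series consistent can be illustrated without Spark. The sketch below is plain Python, not Spark code: the `WeeklySale` tuple and the `collect_structs` helper are hypothetical stand-ins for `f.struct` and `collect_list`. Because each element carries its own year, week and sales, the fields of one row stay together regardless of the order in which rows arrive.

```python
from collections import defaultdict, namedtuple

# Hypothetical stand-in for the struct built by f.struct(["year", "week", "sales"])
WeeklySale = namedtuple("WeeklySale", ["year", "week", "sales"])

def collect_structs(rows):
    """Group (shop_id, week, year, sales) rows by shop, keeping each row's fields together."""
    series = defaultdict(list)
    for shop_id, week, year, sales in rows:
        series[shop_id].append(WeeklySale(year, week, sales))
    return dict(series)

# Same values as in weekly_sales_by_shop_df
rows = [
    ("SHOP_2", 5, 2018, 50.0),
    ("SHOP_1", 5, 2018, 500.0),
    ("SHOP_1", 6, 2018, 200.0),
    ("SHOP_2", 6, 2018, 100.0),
]

in_order = collect_structs(rows)
reversed_order = collect_structs(reversed(rows))

# The element order differs, but week 5 of SHOP_1 is always paired with 500.0
print(in_order["SHOP_1"])
print(reversed_order["SHOP_1"])
```

With three separate lists there is nothing that ties the n-th week to the n-th sales value; with one list of structs the pairing is part of each element.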



OK, we have a time series for each shop, but what if we want it ordered by date? We could try sorting the dataframe before aggregation. Unfortunately, Spark doesn't guarantee that this ordering is preserved after groupBy, even though the small example below happens to come out in order.

In [35]:
ordered_weekly_sales_df = weekly_sales_by_shop_df\
    .orderBy("shop_id", "year", "week")
  
ordered_weekly_sales_df.show()

wrongly_sorted_series_df = ordered_weekly_sales_df\
    .groupBy("shop_id")\
    .agg(f.collect_list(f.struct(["year", "week", "sales"])).alias("sales_ts"))
    
wrongly_sorted_series_df.show(truncate=False)

+-------+----+----+-----+
|shop_id|week|year|sales|
+-------+----+----+-----+
| SHOP_1|   5|2018|500.0|
| SHOP_1|   6|2018|200.0|
| SHOP_2|   5|2018| 50.0|
| SHOP_2|   6|2018|100.0|
+-------+----+----+-----+

+-------+------------------------------------+
|shop_id|sales_ts                            |
+-------+------------------------------------+
|SHOP_2 |[[2018, 5, 50.0], [2018, 6, 100.0]] |
|SHOP_1 |[[2018, 5, 500.0], [2018, 6, 200.0]]|
+-------+------------------------------------+



# 10. Defining custom UDFs

To solve that issue we need to sort the time series after aggregation. To do that we need to define a custom User Defined Function (UDF).
Such a function is invoked on each row separately. It can take as many columns as we need and can contain any custom Python code, including calls to libraries available in your environment.

Let's create a function that appends a custom prefix to the value from another column.

In [36]:
def my_custom_function(column1):
    return "AFTER_UDF_" + str(column1)

my_custom_udf = f.udf(my_custom_function)
df_after_udf = shop_sales_weekly_series_df.withColumn("sales_ts_after_udf", my_custom_udf(f.col("sales_ts")))
df_after_udf.show()
print("# Schema of the new dataframe")
df_after_udf.printSchema()

+-------+--------------------+--------------------+
|shop_id|            sales_ts|  sales_ts_after_udf|
+-------+--------------------+--------------------+
| SHOP_2|[[2018, 5, 50.0],...|AFTER_UDF_[Row(ye...|
| SHOP_1|[[2018, 5, 500.0]...|AFTER_UDF_[Row(ye...|
+-------+--------------------+--------------------+

# Schema of the new dataframe
root
 |-- shop_id: string (nullable = true)
 |-- sales_ts: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- year: integer (nullable = true)
 |    |    |-- week: integer (nullable = true)
 |    |    |-- sales: double (nullable = true)
 |-- sales_ts_after_udf: string (nullable = true)



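To use a custom UDF in a plain SQL query we need to register it in a sqlContext (since Spark 2.3 registerFunction is deprecated in favour of spark.udf.register, which takes the same arguments).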

In [37]:
from pyspark.sql import SQLContext

sqlContext = SQLContext(spark.sparkContext)
sqlContext.registerFunction("my_udf", my_custom_function)

spark.sql("select my_udf(shop_id) from sales").show()


+----------------+
| my_udf(shop_id)|
+----------------+
|AFTER_UDF_SHOP_1|
|AFTER_UDF_SHOP_1|
|AFTER_UDF_SHOP_1|
|AFTER_UDF_SHOP_2|
|AFTER_UDF_SHOP_2|
+----------------+



### ex10. UDF calculating sales for given transaction by multiplying qty and unit_price  
Create a UDF taking two columns (qty and unit_price) from total_sales_df and returning their product as a new column

In [38]:
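from pyspark.sql.types import DoubleType

def my_custom_multiply(col1, col2):
    # qty and unit_price may still be strings read from CSV, so cast explicitly
    return float(col1) * float(col2)

# declare the return type; without it f.udf defaults to a string column
my_custom_multiply_udf = f.udf(my_custom_multiply, DoubleType())

# call the UDF wrapper here, not the plain Python function
total_sales_df.withColumn("custom_udf_total_sales", my_custom_multiply_udf(f.col("qty"), f.col("unit_price"))).show()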

+-------+---+----------------+-------+----------+-----------+----------------------+
|shop_id|qty|transaction_date|item_id|unit_price|total_sales|custom_udf_total_sales|
+-------+---+----------------+-------+----------+-----------+----------------------+
| SHOP_1|  2|      2018-02-01| ITEM_1|     100.0|      200.0|                 200.0|
| SHOP_1|  1|      2018-02-01| ITEM_2|     300.0|      300.0|                 300.0|
| SHOP_1|  4|      2018-02-10| ITEM_3|      50.0|      200.0|                 200.0|
| SHOP_2|  1|      2018-02-02| ITEM_3|      50.0|       50.0|                  50.0|
| SHOP_2|  1|      2018-02-11| ITEM_1|     100.0|      100.0|                 100.0|
+-------+---+----------------+-------+----------+-----------+----------------------+



What happens if we want to return more than one column from a UDF? Let's try returning a tuple.

In [39]:
def split_shop_id(shop_id):
    s, i = shop_id.split("_")
    return s, int(i) 

split_shop_id_udf = f.udf(split_shop_id)
df_udf_no_schema = shop_sales_weekly_series_df.withColumn("shop_id_splits", split_shop_id_udf(f.col("shop_id")))
print("# Results not as expected - seems like calling toString on Object")
df_udf_no_schema.show(truncate=False)

print("# Actual inferred schema: one string instead of a tuple")
df_udf_no_schema.printSchema()

# Results not as expected - seems like calling toString on Object
+-------+------------------------------------+----------------------------+
|shop_id|sales_ts                            |shop_id_splits              |
+-------+------------------------------------+----------------------------+
|SHOP_2 |[[2018, 5, 50.0], [2018, 6, 100.0]] |[Ljava.lang.Object;@6f5dae41|
|SHOP_1 |[[2018, 5, 500.0], [2018, 6, 200.0]]|[Ljava.lang.Object;@7d6842d4|
+-------+------------------------------------+----------------------------+

# Actual inferred schema: one string instead of a tuple
root
 |-- shop_id: string (nullable = true)
 |-- sales_ts: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- year: integer (nullable = true)
 |    |    |-- week: integer (nullable = true)
 |    |    |-- sales: double (nullable = true)
 |-- shop_id_splits: string (nullable = true)



To avoid that we need to declare a result schema for our UDF. Without one, f.udf assumes the function returns a string.

In [40]:
from pyspark.sql.types import IntegerType, StringType, StructType, StructField

schema = StructType([StructField("s", StringType()), StructField("i", IntegerType())])
udf_with_schema = f.udf(split_shop_id, schema)

df = df_udf_no_schema.withColumn("shop_id_splits_with_schema", udf_with_schema(f.col("shop_id")))
df.show(truncate=False)
print("# New schema is correct as well")
df.printSchema()

+-------+------------------------------------+----------------------------+--------------------------+
|shop_id|sales_ts                            |shop_id_splits              |shop_id_splits_with_schema|
+-------+------------------------------------+----------------------------+--------------------------+
|SHOP_2 |[[2018, 5, 50.0], [2018, 6, 100.0]] |[Ljava.lang.Object;@361a1c0f|[SHOP, 2]                 |
|SHOP_1 |[[2018, 5, 500.0], [2018, 6, 200.0]]|[Ljava.lang.Object;@4f2eaeb |[SHOP, 1]                 |
+-------+------------------------------------+----------------------------+--------------------------+

# New schema is correct as well
root
 |-- shop_id: string (nullable = true)
 |-- sales_ts: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- year: integer (nullable = true)
 |    |    |-- week: integer (nullable = true)
 |    |    |-- sales: double (nullable = true)
 |-- shop_id_splits: string (nullable = true)
 |-- shop_id_splits_with_schema: struct (nullable = true)
 |    |-- s: string (nullable = true)
 |    |-- i: integer (nullable = true)
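
A UDF is called with whatever each row contains, including nulls (passed to Python as `None`) and ids that do not match the expected pattern. As a sketch, a more defensive variant of `split_shop_id` could look like the one below. The name `split_shop_id_safe` and the choice to return `None` for malformed ids are assumptions, not part of the original notebook. It is plain Python, so it can be checked before wrapping it with `f.udf(..., schema)`.

```python
def split_shop_id_safe(shop_id):
    """Split 'SHOP_1' into ('SHOP', 1); return None when the id cannot be parsed."""
    if shop_id is None:
        return None
    # rpartition splits on the last underscore, so 'BIG_SHOP_7' -> ('BIG_SHOP', '_', '7')
    prefix, sep, number = shop_id.rpartition("_")
    if not sep:
        return None
    try:
        return prefix, int(number)
    except ValueError:
        # the part after the underscore is not an integer
        return None

print(split_shop_id_safe("SHOP_2"))
print(split_shop_id_safe(None))
```

Returning `None` from the function should give a null struct in the resulting column instead of failing the whole job on one bad row.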

## Creating multiple columns based on a result from UDF
In the last example we created one column holding multiple values. Now let's extract them into separate columns.

This can be done by selecting the fields of the struct with an asterisk __\*__

In [41]:
df_split_shop_id = df.select("*", "shop_id_splits_with_schema.*").drop("shop_id_splits_with_schema")
df_split_shop_id.show()
print("# Schema was updated and new fields have correct types")
df_split_shop_id.printSchema()

+-------+--------------------+--------------------+----+---+
|shop_id|            sales_ts|      shop_id_splits|   s|  i|
+-------+--------------------+--------------------+----+---+
| SHOP_2|[[2018, 5, 50.0],...|[Ljava.lang.Objec...|SHOP|  2|
| SHOP_1|[[2018, 5, 500.0]...|[Ljava.lang.Objec...|SHOP|  1|
+-------+--------------------+--------------------+----+---+

# Schema was updated and new fields have correct types
root
 |-- shop_id: string (nullable = true)
 |-- sales_ts: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- year: integer (nullable = true)
 |    |    |-- week: integer (nullable = true)
 |    |    |-- sales: double (nullable = true)
 |-- shop_id_splits: string (nullable = true)
 |-- s: string (nullable = true)
 |-- i: integer (nullable = true)



The solution above will invoke the UDF as many times as there are new columns (it's a feature, not a bug! https://issues.apache.org/jira/browse/SPARK-17728).

For costly UDFs (and in PySpark most of them are very costly) there is a workaround: wrap the result of the UDF in a one-element array and explode it.

In [42]:
df_split_shop_id_correct = df_udf_no_schema.withColumn(
    "shop_id_splits_with_schema",
    f.explode(f.array(udf_with_schema(f.col("shop_id")))))

df_split_shop_id_correct = df_split_shop_id_correct \
    .select("*", "shop_id_splits_with_schema.*") \
    .drop("shop_id_splits_with_schema")
df_split_shop_id_correct.show()
print("# Results and schema are the same")
df_split_shop_id_correct.printSchema()


+-------+--------------------+--------------------+----+---+
|shop_id|            sales_ts|      shop_id_splits|   s|  i|
+-------+--------------------+--------------------+----+---+
| SHOP_2|[[2018, 5, 50.0],...|[Ljava.lang.Objec...|SHOP|  2|
| SHOP_1|[[2018, 5, 500.0]...|[Ljava.lang.Objec...|SHOP|  1|
+-------+--------------------+--------------------+----+---+

# Results and schema are the same
root
 |-- shop_id: string (nullable = true)
 |-- sales_ts: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- year: integer (nullable = true)
 |    |    |-- week: integer (nullable = true)
 |    |    |-- sales: double (nullable = true)
 |-- shop_id_splits: string (nullable = true)
 |-- s: string (nullable = true)
 |-- i: integer (nullable = true)
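
The cost difference can be pictured with a plain-Python analogy. This is not Spark code, and `expensive_split` with its call counter is a hypothetical stand-in for a costly UDF: calling the function once per output field repeats the work, while calling it once per row and unpacking the result does not.

```python
calls = {"count": 0}

def expensive_split(shop_id):
    """Stand-in for a costly UDF; counts how often it is invoked."""
    calls["count"] += 1
    s, i = shop_id.split("_")
    return s, int(i)

shop_ids = ["SHOP_1", "SHOP_2"]

# One call per output field, similar to expanding the struct column directly
calls["count"] = 0
per_field = [(expensive_split(x)[0], expensive_split(x)[1]) for x in shop_ids]
calls_per_field = calls["count"]

# One call per row, then unpack, similar to the explode workaround
calls["count"] = 0
once_per_row = [expensive_split(x) for x in shop_ids]
calls_once = calls["count"]

print(calls_per_field, calls_once)
```

Both variants produce the same values; only the number of invocations differs.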



But how do we know that this UDF will be invoked multiple times? Let's take a closer look at the execution plans of both queries.

For the first version we can see:

> +- BatchEvalPython [split_shop_id(shop_id#10), split_shop_id(shop_id#10), split_shop_id(shop_id#10)], [shop_id#10, sales_ts#3899, pythonUDF0#4442, pythonUDF1#4443, pythonUDF2#4444]

which contains multiple pythonUDF references.
                                  
For the updated solution there's only one invocation:

> +- BatchEvalPython [split_shop_id(shop_id#10)], [shop_id#10, sales_ts#3899, shop_id_splits#4155, pythonUDF0#4448]


In [43]:
df_split_shop_id.explain()
print("\n\n\n")
df_split_shop_id_correct.explain()


== Physical Plan ==
*(4) Project [shop_id#10, sales_ts#2191, pythonUDF0#2633 AS shop_id_splits#2399, pythonUDF2#2635.s AS s#2499, pythonUDF2#2635.i AS i#2500]
+- BatchEvalPython [split_shop_id(shop_id#10), split_shop_id(shop_id#10), split_shop_id(shop_id#10)], [shop_id#10, sales_ts#2191, pythonUDF0#2633, pythonUDF1#2634, pythonUDF2#2635]
   +- ObjectHashAggregate(keys=[shop_id#10], functions=[collect_list(named_struct(year, year#2037, week, week#2036, sales, sales#2045), 0, 0)])
      +- Exchange hashpartitioning(shop_id#10, 200)
         +- ObjectHashAggregate(keys=[shop_id#10], functions=[partial_collect_list(named_struct(year, year#2037, week, week#2036, sales, sales#2045), 0, 0)])
            +- *(3) HashAggregate(keys=[shop_id#10, weekofyear(cast(transaction_date#13 as date))#2636, year(cast(transaction_date#13 as date))#2637], functions=[sum(total_sales#358)])
               +- Exchange hashpartitioning(shop_id#10, weekofyear(cast(transaction_date#13 as date))#2636, year(cast(tra

### ex.11 Sort each time series in wrongly_sorted_series_df from the previous exercise in descending order and compare it to the initial ts
tip: use Python's built-in sorted function inside a UDF. The FloatType and ArrayType imports may be useful as well

In [44]:
from pyspark.sql.types import FloatType, ArrayType

def sort_ts(ts):
    # sort by year first, then by week, newest entry first
    s_ts = sorted(ts, key=lambda row: (row.year, row.week), reverse=True)
    return s_ts

sort_ts_udf = f.udf(sort_ts, ArrayType(StructType(
            [StructField("year", IntegerType()),
             StructField("week", IntegerType()),
             StructField("sales", FloatType())])))

sorted_ts_df = wrongly_sorted_series_df.withColumn("sorted_ts", sort_ts_udf(f.col("sales_ts")))

sorted_ts_df.show()


+-------+--------------------+--------------------+
|shop_id|            sales_ts|           sorted_ts|
+-------+--------------------+--------------------+
| SHOP_2|[[2018, 5, 50.0],...|[[2018, 6, 100.0]...|
| SHOP_1|[[2018, 5, 500.0]...|[[2018, 6, 200.0]...|
+-------+--------------------+--------------------+
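
The order of the fields in the sort key matters once a series spans more than one year. The sketch below uses a hypothetical `WeeklySale` namedtuple in place of Spark's `Row` and made-up values that cross a year boundary. It is plain Python and only illustrates the key passed to `sorted`.

```python
from collections import namedtuple

# Hypothetical stand-in for the Row objects a UDF receives
WeeklySale = namedtuple("WeeklySale", ["year", "week", "sales"])

# Made-up series crossing the 2017/2018 boundary
ts = [
    WeeklySale(2018, 1, 30.0),
    WeeklySale(2017, 52, 10.0),
    WeeklySale(2018, 2, 20.0),
]

# Year first, then week: the newest entry comes first
by_year_then_week = sorted(ts, key=lambda row: (row.year, row.week), reverse=True)

# Week first: week 52 of 2017 wrongly ends up ahead of the 2018 entries
by_week_then_year = sorted(ts, key=lambda row: (row.week, row.year), reverse=True)

print([(r.year, r.week) for r in by_year_then_week])
print([(r.year, r.week) for r in by_week_then_year])
```

On the sample data in this notebook both keys give the same result, because all rows belong to 2018; the difference only shows up with multi-year data.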

