# 0. Environment preparation


In [1]:
# NOTE(review): env_setup is kept first in case it prepares the environment pyspark relies on.
import env_setup
import pyspark.sql.functions as f
from pyspark.sql.types import DoubleType

# Obtain the SparkSession from the workshop helper module. According to the
# messages printed below, this creates a local session and registers the
# "sales" and "item_prices" views from CSV files.
spark = env_setup.getSession(local=True)

Created local SparkSession
Created "sales" view from CSV file
Created "item_prices" view from CSV file


In [2]:
# Look the registered views up by name so they can be used through the DataFrame API.
sales_df = spark.table("sales")
item_prices_df = spark.table("item_prices")
both_tables = (sales_df, item_prices_df)

print("# schema of both tables")
for table_df in both_tables:
    table_df.printSchema()

print("# sample results from both tables")
for table_df in both_tables:
    table_df.show()

print("# query execution plan for this simple select")
for table_df in both_tables:
    table_df.explain()

# schema of both tables
root
 |-- shop_id: string (nullable = true)
 |-- item_id: string (nullable = true)
 |-- qty: string (nullable = true)
 |-- transaction_date: string (nullable = true)

root
 |-- item_id: string (nullable = true)
 |-- unit_price: string (nullable = true)

# sample results from both tables
+-------+-------+---+----------------+
|shop_id|item_id|qty|transaction_date|
+-------+-------+---+----------------+
| SHOP_1| ITEM_1|  2|      2018-02-01|
| SHOP_1| ITEM_2|  1|      2018-02-01|
| SHOP_1| ITEM_3|  4|      2018-02-10|
| SHOP_2| ITEM_3|  1|      2018-02-02|
| SHOP_2| ITEM_1|  1|      2018-02-11|
+-------+-------+---+----------------+

+-------+----------+
|item_id|unit_price|
+-------+----------+
| ITEM_1|     100.0|
| ITEM_2|     300.0|
| ITEM_3|      50.0|
+-------+----------+

# query execution plan for this simple select
== Physical Plan ==
*FileScan csv [shop_id#12,item_id#13,qty#14,transaction_date#15] Batched: false, Format: CSV, Location: InMemoryFileIndex[

# 1. SQL support

In [3]:
# Items sold in SHOP_1, most recent transaction first.
shop1_items_query = 'select item_id,transaction_date from sales where shop_id = "SHOP_1" order by transaction_date desc'
spark.sql(shop1_items_query).show()

+-------+----------------+
|item_id|transaction_date|
+-------+----------------+
| ITEM_3|      2018-02-10|
| ITEM_1|      2018-02-01|
| ITEM_2|      2018-02-01|
+-------+----------------+



### ex1. Using plain SQL query select all transactions with quantity = 1


In [4]:
# Transactions in which exactly one unit was sold.
single_unit_query = 'select * from sales where qty = 1'
spark.sql(single_unit_query).show()

+-------+-------+---+----------------+
|shop_id|item_id|qty|transaction_date|
+-------+-------+---+----------------+
| SHOP_1| ITEM_2|  1|      2018-02-01|
| SHOP_2| ITEM_3|  1|      2018-02-02|
| SHOP_2| ITEM_1|  1|      2018-02-11|
+-------+-------+---+----------------+



### ex2. get mean unit price for all items

In [5]:
# Mean unit price over the whole item_prices view.
mean_price_query = 'select mean(unit_price) from item_prices'
spark.sql(mean_price_query).show()

+-------------------------------+
|avg(CAST(unit_price AS DOUBLE))|
+-------------------------------+
|                          150.0|
+-------------------------------+



# 2. Dataframe operations

In [6]:
# Variant 1: columns are referenced by name through f.col.
shop1_transactions_df = (
    sales_df
    .select("item_id", "transaction_date")
    .filter(f.col("shop_id") == "SHOP_1")
    .orderBy(f.col("transaction_date").desc())
)

print("# The same query as above using dataframe api")
shop1_transactions_df.show()

# Variant 2: columns are referenced as attributes of the source DataFrame.
(
    sales_df
    .select(sales_df.item_id, sales_df.transaction_date)
    .filter(sales_df.shop_id == "SHOP_1")
    .orderBy(sales_df.transaction_date.desc())
    .show()
)

print("# Execution plan of a more complex query")
shop1_transactions_df.explain()
    

# The same query as above using dataframe api
+-------+----------------+
|item_id|transaction_date|
+-------+----------------+
| ITEM_3|      2018-02-10|
| ITEM_1|      2018-02-01|
| ITEM_2|      2018-02-01|
+-------+----------------+

+-------+----------------+
|item_id|transaction_date|
+-------+----------------+
| ITEM_3|      2018-02-10|
| ITEM_1|      2018-02-01|
| ITEM_2|      2018-02-01|
+-------+----------------+

# Execution plan of a more complex query
== Physical Plan ==
*Sort [transaction_date#15 DESC NULLS LAST], true, 0
+- Exchange rangepartitioning(transaction_date#15 DESC NULLS LAST, 200)
   +- *Project [item_id#13, transaction_date#15]
      +- *Filter (isnotnull(shop_id#12) && (shop_id#12 = SHOP_1))
         +- *FileScan csv [shop_id#12,item_id#13,transaction_date#15] Batched: false, Format: CSV, Location: InMemoryFileIndex[file:/Users/mkromka/dev/trainings/pySpark_workshop/data/sales.csv], PartitionFilters: [], PushedFilters: [IsNotNull(shop_id), EqualTo(shop_id,SHOP

### ex3. rewrite query from ex1 to dataframe operations

In [7]:
# DataFrame-API counterpart of the ex1 SQL query (where is an alias of filter).
sold_single_unit = f.col("qty") == 1
sales_df.where(sold_single_unit).show()

+-------+-------+---+----------------+
|shop_id|item_id|qty|transaction_date|
+-------+-------+---+----------------+
| SHOP_1| ITEM_2|  1|      2018-02-01|
| SHOP_2| ITEM_3|  1|      2018-02-02|
| SHOP_2| ITEM_1|  1|      2018-02-11|
+-------+-------+---+----------------+



# 3. Joins

In [8]:
# The same equi-join condition is used by both DataFrame-API joins below.
join_sql = 'select * from sales join item_prices on sales.item_id = item_prices.item_id'
same_item = sales_df.item_id == item_prices_df.item_id

print("# joins using plain SQL queries")
spark.sql(join_sql).show()

print("# using Dataframe API - duplicated item_id column!")
sales_df.join(item_prices_df, same_item, "inner").show()

print("# dropping redundant column")
# Dropping through the DataFrame attribute removes only the sales-side item_id.
sales_with_unit_prices_df = (
    sales_df
    .join(item_prices_df, same_item)
    .drop(sales_df.item_id)
)
sales_with_unit_prices_df.show()


# joins using plain SQL queries
+-------+-------+---+----------------+-------+----------+
|shop_id|item_id|qty|transaction_date|item_id|unit_price|
+-------+-------+---+----------------+-------+----------+
| SHOP_1| ITEM_1|  2|      2018-02-01| ITEM_1|     100.0|
| SHOP_1| ITEM_2|  1|      2018-02-01| ITEM_2|     300.0|
| SHOP_1| ITEM_3|  4|      2018-02-10| ITEM_3|      50.0|
| SHOP_2| ITEM_3|  1|      2018-02-02| ITEM_3|      50.0|
| SHOP_2| ITEM_1|  1|      2018-02-11| ITEM_1|     100.0|
+-------+-------+---+----------------+-------+----------+

# using Dataframe API - duplicated item_id column!
+-------+-------+---+----------------+-------+----------+
|shop_id|item_id|qty|transaction_date|item_id|unit_price|
+-------+-------+---+----------------+-------+----------+
| SHOP_1| ITEM_1|  2|      2018-02-01| ITEM_1|     100.0|
| SHOP_1| ITEM_2|  1|      2018-02-01| ITEM_2|     300.0|
| SHOP_1| ITEM_3|  4|      2018-02-10| ITEM_3|      50.0|
| SHOP_2| ITEM_3|  1|      2018-02-02| ITEM_3|

### ex4. Filter out excluded items

In [9]:
print("# Dataframe with column of items we would like to exclude")
excluded_items_df = spark.createDataFrame([("ITEM_2",), ("ITEM_4",)], ['item'])
excluded_items_df.show()

# A sale matches an excluded item when its item_id equals the "item" value.
is_excluded = sales_df.item_id == excluded_items_df.item

print("# using join and filtering - result doesn't contain ITEM_2 and ITEM_4 ")
# After a left outer join, sales without a match carry a null "item".
(
    sales_df
    .join(excluded_items_df, is_excluded, "left_outer")
    .filter(excluded_items_df.item.isNull())
    .drop(excluded_items_df.item)
    .show()
)

print("# better option: anti join")
# left_anti keeps only the left-side rows that have no match on the right.
sales_df.join(excluded_items_df, is_excluded, "left_anti").show()


# Dataframe with column of items we would like to exclude
+------+
|  item|
+------+
|ITEM_2|
|ITEM_4|
+------+

# using join and filtering - result doesn't contain ITEM_2 and ITEM_4 
+-------+-------+---+----------------+
|shop_id|item_id|qty|transaction_date|
+-------+-------+---+----------------+
| SHOP_1| ITEM_3|  4|      2018-02-10|
| SHOP_2| ITEM_3|  1|      2018-02-02|
| SHOP_1| ITEM_1|  2|      2018-02-01|
| SHOP_2| ITEM_1|  1|      2018-02-11|
+-------+-------+---+----------------+

# better option: anti join
+-------+-------+---+----------------+
|shop_id|item_id|qty|transaction_date|
+-------+-------+---+----------------+
| SHOP_1| ITEM_3|  4|      2018-02-10|
| SHOP_2| ITEM_3|  1|      2018-02-02|
| SHOP_1| ITEM_1|  2|      2018-02-01|
| SHOP_2| ITEM_1|  1|      2018-02-11|
+-------+-------+---+----------------+



# 4. Adding columns

In [10]:
# Revenue of a single transaction: quantity sold times the price of one unit.
total_sales_df = sales_with_unit_prices_df.withColumn(
    "total_sales", f.col("qty") * f.col("unit_price")
)

print("# Added new total_sales column which is a multuply of unit_price and qty")
total_sales_df.show()

print("# Adding price category column based on a condition")
# Conditions are checked in order; rows matching neither fall back to "Medium".
price_category_col = (
    f.when(f.col("total_sales") > 150, "High")
    .when(f.col("total_sales") < 60, "Low")
    .otherwise("Medium")
)
sales_with_transaction_category = total_sales_df.withColumn("price_category", price_category_col)

sales_with_transaction_category.show()

# Added new total_sales column which is a multuply of unit_price and qty
+-------+---+----------------+-------+----------+-----------+
|shop_id|qty|transaction_date|item_id|unit_price|total_sales|
+-------+---+----------------+-------+----------+-----------+
| SHOP_1|  2|      2018-02-01| ITEM_1|     100.0|      200.0|
| SHOP_1|  1|      2018-02-01| ITEM_2|     300.0|      300.0|
| SHOP_1|  4|      2018-02-10| ITEM_3|      50.0|      200.0|
| SHOP_2|  1|      2018-02-02| ITEM_3|      50.0|       50.0|
| SHOP_2|  1|      2018-02-11| ITEM_1|     100.0|      100.0|
+-------+---+----------------+-------+----------+-----------+

# Adding price category column based on a condition
+-------+---+----------------+-------+----------+-----------+--------------+
|shop_id|qty|transaction_date|item_id|unit_price|total_sales|price_category|
+-------+---+----------------+-------+----------+-----------+--------------+
| SHOP_1|  2|      2018-02-01| ITEM_1|     100.0|      200.0|          High|
| SHOP_1

### ex5. We want to create two-packs of items whose combined price is lower than 360; find all such pairs of distinct items.
hint: use a cross join and aliases


In [11]:
# Two aliased copies of the price list make both sides of the self cross join addressable.
first_item = item_prices_df.alias("items1")
second_item = item_prices_df.alias("items2")

pack_price = f.col("items1.unit_price") + f.col("items2.unit_price")
# Keep pairs under the price limit that are made of two different items.
is_valid_pack = (f.col("price_sum") < 360) & (f.col("items1.item_id") != f.col("items2.item_id"))

(
    first_item
    .crossJoin(second_item)
    .withColumn("price_sum", pack_price)
    .where(is_valid_pack)
    .select("items1.item_id", "items2.item_id", "price_sum")
    .show()
)

+-------+-------+---------+
|item_id|item_id|price_sum|
+-------+-------+---------+
| ITEM_1| ITEM_3|    150.0|
| ITEM_2| ITEM_3|    350.0|
| ITEM_3| ITEM_1|    150.0|
| ITEM_3| ITEM_2|    350.0|
+-------+-------+---------+



# 5. Simple aggregations

In [12]:
# Both aggregations group the same DataFrame by shop.
sales_grouped_by_shop = total_sales_df.groupBy("shop_id")

print("# aggregate sales by shop - ugly column name")
# Without an alias the aggregate column is named "sum(total_sales)".
(
    sales_grouped_by_shop
    .agg(f.sum(total_sales_df.total_sales))
    .orderBy(f.col("sum(total_sales)"))
    .show()
)

print("# using alias to have a better column name")
# The "sales" column exists only after the aggregation, so it is referenced by
# name through f.col; total_sales_df itself has no such attribute to order by.
(
    sales_grouped_by_shop
    .agg(f.sum(total_sales_df.total_sales).alias("sales"))
    .orderBy(f.col("sales").desc())
    .show()
)
    

# aggregate sales by shop - ugly column name
+-------+----------------+
|shop_id|sum(total_sales)|
+-------+----------------+
| SHOP_2|           150.0|
| SHOP_1|           700.0|
+-------+----------------+

# using alias to have a better column name
+-------+-----+
|shop_id|sales|
+-------+-----+
| SHOP_1|700.0|
| SHOP_2|150.0|
+-------+-----+



### ex6. produce a list of all shops where each item was sold, new column should be named "shops"
hint: collect_list function

In [13]:
# Gather the shop of every transaction of an item into one list column.
shops_per_item = f.collect_list(f.col("shop_id")).alias("shops")
total_sales_df.groupBy("item_id").agg(shops_per_item).show()

+-------+----------------+
|item_id|           shops|
+-------+----------------+
| ITEM_3|[SHOP_1, SHOP_2]|
| ITEM_2|        [SHOP_1]|
| ITEM_1|[SHOP_1, SHOP_2]|
+-------+----------------+



# 6. Date handling

In [14]:
transaction_date = f.col("transaction_date")

print("# extracting multiple elements of date")
# (new column name, expression) pairs, added in this order.
date_parts = [
    ("year", f.year(transaction_date)),
    ("month", f.month(transaction_date)),
    ("day", f.dayofmonth(transaction_date)),
    ("day_of_year", f.dayofyear(transaction_date)),
    ("day_of_week", f.date_format(transaction_date, 'u')),
    ("day_of_week_string", f.date_format(transaction_date, 'E')),
    ("week_of_year", f.weekofyear(transaction_date)),
]
sales_with_date_parts_df = total_sales_df
for part_name, part_expr in date_parts:
    sales_with_date_parts_df = sales_with_date_parts_df.withColumn(part_name, part_expr)
sales_with_date_parts_df.show()

print("# aggregate sales by week")
(
    total_sales_df
    .groupBy(f.weekofyear(transaction_date))
    .agg(f.sum(f.col("total_sales")))
    .show()
)


# extracting multiple elements of date
+-------+---+----------------+-------+----------+-----------+----+-----+---+-----------+-----------+------------------+------------+
|shop_id|qty|transaction_date|item_id|unit_price|total_sales|year|month|day|day_of_year|day_of_week|day_of_week_string|week_of_year|
+-------+---+----------------+-------+----------+-----------+----+-----+---+-----------+-----------+------------------+------------+
| SHOP_1|  2|      2018-02-01| ITEM_1|     100.0|      200.0|2018|    2|  1|         32|          4|               Thu|           5|
| SHOP_1|  1|      2018-02-01| ITEM_2|     300.0|      300.0|2018|    2|  1|         32|          4|               Thu|           5|
| SHOP_1|  4|      2018-02-10| ITEM_3|      50.0|      200.0|2018|    2| 10|         41|          6|               Sat|           6|
| SHOP_2|  1|      2018-02-02| ITEM_3|      50.0|       50.0|2018|    2|  2|         33|          5|               Fri|           5|
| SHOP_2|  1|      2018-02-11|

### ex7. Weekly sales aggregation not starting on Monday

In [15]:
# Variant 1: move every date one day forward before taking the week number.
# The resulting week numbers are shifted by one day and therefore not the real
# week-of-year values, which is fine when only the grouping/ordering matters.
(
    total_sales_df
    .withColumn("transaction_date_moved", f.date_add(f.col("transaction_date"), 1))
    .groupBy(f.weekofyear(f.col("transaction_date_moved")))
    .agg(f.sum(f.col("total_sales")))
    .show()
)

# Variant 2: label each transaction with the Saturday that closes its week.
# next_day looks strictly after its argument, so the date is moved one day
# back first; a transaction made on a Saturday then keeps its own date.
week_end = f.next_day(f.date_sub(f.col("transaction_date"), 1), "Sat")
(
    total_sales_df
    .withColumn("aggr_date", week_end)
    .groupBy(f.col("aggr_date"))
    .agg(f.sum(f.col("total_sales")))
    .withColumn("day_of_week_string", f.date_format(f.col("aggr_date"), 'E'))
    .show()
)

+----------------------------------+----------------+
|weekofyear(transaction_date_moved)|sum(total_sales)|
+----------------------------------+----------------+
|                                 6|           200.0|
|                                 5|           550.0|
|                                 7|           100.0|
+----------------------------------+----------------+

+----------+----------------+------------------+
| aggr_date|sum(total_sales)|day_of_week_string|
+----------+----------------+------------------+
|2018-02-10|           200.0|               Sat|
|2018-02-17|           100.0|               Sat|
|2018-02-03|           550.0|               Sat|
+----------+----------------+------------------+



# 7. Using results of one query in another

In [16]:
print("# Calculate global max date")
total_sales_df\
    .select(f.max(f.col("transaction_date")).alias("max_date"))\
    .show()

print("# Let's add it to every column using collect - calling an action")
# first() is an action: it executes the query and returns the first Row, from
# which the single value is taken by position. collect() would return the
# list of all Rows instead (.collect()[0][0]).
max_date = total_sales_df\
    .select(f.max(f.col("transaction_date")).alias("max_date"))\
    .first()[0]

print(max_date)

print("# adding it as a literal (constant) column")
# show() returns None, so it must not be part of the assigned expression;
# otherwise the variable would hold None instead of the DataFrame.
sales_with_max_global_date_df = total_sales_df\
    .withColumn("global_max_date", f.lit(max_date))
sales_with_max_global_date_df.show()


# Calculate global max date
+----------+
|  max_date|
+----------+
|2018-02-11|
+----------+

# Let's add it to every column using collect - calling an action
2018-02-11
# adding it as a literal (constant) column
+-------+---+----------------+-------+----------+-----------+---------------+
|shop_id|qty|transaction_date|item_id|unit_price|total_sales|global_max_date|
+-------+---+----------------+-------+----------+-----------+---------------+
| SHOP_1|  2|      2018-02-01| ITEM_1|     100.0|      200.0|     2018-02-11|
| SHOP_1|  1|      2018-02-01| ITEM_2|     300.0|      300.0|     2018-02-11|
| SHOP_1|  4|      2018-02-10| ITEM_3|      50.0|      200.0|     2018-02-11|
| SHOP_2|  1|      2018-02-02| ITEM_3|      50.0|       50.0|     2018-02-11|
| SHOP_2|  1|      2018-02-11| ITEM_1|     100.0|      100.0|     2018-02-11|
+-------+---+----------------+-------+----------+-----------+---------------+



### ex8. Add the global max date to every row using crossJoin (doesn't require invoking an action such as collect)

In [17]:
# Single-row DataFrame holding the latest transaction date; nothing is
# executed here because no action is called.
max_date_df = total_sales_df\
    .select(f.max(f.col("transaction_date")).alias("max_date"))

# The cross join appends max_date to every sales row; broadcast() marks the
# small side as suitable for a broadcast join.
# show() returns None, so the DataFrame is assigned first and displayed after.
sales_with_max_global_date_cross_join_df = total_sales_df\
    .crossJoin(f.broadcast(max_date_df))
sales_with_max_global_date_cross_join_df.show()
print("# make sure DF inside cross join has only one element, if not then we'll have too many rows")


+-------+---+----------------+-------+----------+-----------+----------+
|shop_id|qty|transaction_date|item_id|unit_price|total_sales|  max_date|
+-------+---+----------------+-------+----------+-----------+----------+
| SHOP_1|  2|      2018-02-01| ITEM_1|     100.0|      200.0|2018-02-11|
| SHOP_1|  1|      2018-02-01| ITEM_2|     300.0|      300.0|2018-02-11|
| SHOP_1|  4|      2018-02-10| ITEM_3|      50.0|      200.0|2018-02-11|
| SHOP_2|  1|      2018-02-02| ITEM_3|      50.0|       50.0|2018-02-11|
| SHOP_2|  1|      2018-02-11| ITEM_1|     100.0|      100.0|2018-02-11|
+-------+---+----------------+-------+----------+-----------+----------+

# make sure DF inside cross join has only one element, if not then we'll have too many rows


# 8. Window functions

In [18]:
print("# get max transaction date for each shop using simple aggregations")
max_date_by_store_df = (
    total_sales_df
    .groupBy(f.col("shop_id"))
    .agg(f.max("transaction_date").alias("max_transaction_date_by_shop"))
)

# Joining on a list of column names keeps a single shop_id column in the result.
total_sales_df.join(max_date_by_store_df, ["shop_id"]).show()
print('# careful: "shop_id" in join is not column - just a string. Can be also a list of strings.\
There\'s no need to drop column')

print("# another option is to use Windows")
print("# Note: Windows are experimental feature (even though they're available since Spark 1.4)")
from pyspark.sql import Window

# Unordered window: the aggregate is computed over all rows of the same shop.
window = Window.partitionBy(f.col("shop_id"))
max_date_in_shop = f.max(f.col("transaction_date")).over(window)

total_sales_df.withColumn("max_transaction_date_by_shop", max_date_in_shop).show()

print("# Find ordinals for transactions for each item_id (so the oldest transaction with given item_id should be 1)")
window_by_item_sorted = Window.partitionBy(f.col("item_id")).orderBy(f.col("transaction_date"))
# rank() gives transactions with the same date the same ordinal.
item_ordinal = f.rank().over(window_by_item_sorted)

total_sales_df.withColumn("item_transaction_ordinal", item_ordinal).show()

print("# Find average of prices from last two transactions in given shop ordered by transaction date")
# Frame of two rows: the previous one (-1) and the current one.
window_by_transaction_date = (
    Window
    .partitionBy(f.col("shop_id"))
    .orderBy(f.col("transaction_date"))
    .rowsBetween(-1, Window.currentRow)
)
moving_average = f.mean(f.col("total_sales")).over(window_by_transaction_date)

(
    total_sales_df
    .withColumn("price_moving_average", moving_average)
    .orderBy(f.col("shop_id"), f.col("transaction_date"))
    .show()
)


# get max transaction date for each shop using simple aggregations
+-------+---+----------------+-------+----------+-----------+----------------------------+
|shop_id|qty|transaction_date|item_id|unit_price|total_sales|max_transaction_date_by_shop|
+-------+---+----------------+-------+----------+-----------+----------------------------+
| SHOP_1|  2|      2018-02-01| ITEM_1|     100.0|      200.0|                  2018-02-10|
| SHOP_1|  1|      2018-02-01| ITEM_2|     300.0|      300.0|                  2018-02-10|
| SHOP_1|  4|      2018-02-10| ITEM_3|      50.0|      200.0|                  2018-02-10|
| SHOP_2|  1|      2018-02-02| ITEM_3|      50.0|       50.0|                  2018-02-11|
| SHOP_2|  1|      2018-02-11| ITEM_1|     100.0|      100.0|                  2018-02-11|
+-------+---+----------------+-------+----------+-----------+----------------------------+

# careful: "shop_id" in join is not column - just a string. Can be also a list of strings.There's no need to drop

### ex9. Find average of prices from current and all previous transactions in given shop ordered by transaction date

In [19]:
# Frame that grows from the first row of the shop up to the current row.
unbounded_window_by_transaction_date = (
    Window
    .partitionBy(f.col("shop_id"))
    .orderBy(f.col("transaction_date"))
    .rowsBetween(Window.unboundedPreceding, Window.currentRow)
)
average_until_now = f.mean(f.col("total_sales")).over(unbounded_window_by_transaction_date)

(
    total_sales_df
    .withColumn("average_price_until_now", average_until_now)
    .orderBy(f.col("shop_id"), f.col("transaction_date"))
    .show()
)


+-------+---+----------------+-------+----------+-----------+-----------------------+
|shop_id|qty|transaction_date|item_id|unit_price|total_sales|average_price_until_now|
+-------+---+----------------+-------+----------+-----------+-----------------------+
| SHOP_1|  2|      2018-02-01| ITEM_1|     100.0|      200.0|                  200.0|
| SHOP_1|  1|      2018-02-01| ITEM_2|     300.0|      300.0|                  250.0|
| SHOP_1|  4|      2018-02-10| ITEM_3|      50.0|      200.0|     233.33333333333334|
| SHOP_2|  1|      2018-02-02| ITEM_3|      50.0|       50.0|                   50.0|
| SHOP_2|  1|      2018-02-11| ITEM_1|     100.0|      100.0|                   75.0|
+-------+---+----------------+-------+----------+-----------+-----------------------+



# 9. Complex aggregations

In [20]:
print("# produce weekly sales: one row per shop and a list of all transactions \
with week and year numbers for given store in one column ")

# One row per (shop, week, year) with the summed sales of that week.
week_number = f.weekofyear("transaction_date").alias("week")
year_number = f.year("transaction_date").alias("year")
weekly_sales_by_shop_df = (
    total_sales_df
    .groupBy("shop_id", week_number, year_number)
    .agg(f.sum("total_sales").alias("sales"))
)

print("# adding week and year columnns")
weekly_sales_by_shop_df.show()

print("# aggregating sales with three collect_list invocations")
shop_sales_weekly_series_df = weekly_sales_by_shop_df.groupBy("shop_id").agg(
    f.collect_list("week"),
    f.collect_list("year"),
    f.collect_list("sales"),
)

shop_sales_weekly_series_df.show(truncate=False)
print("# Solution above won't work as ordering in each column may be different")

# Collecting several columns with one call is not possible either; this fails:
#     weekly_sales_by_shop_df.groupBy("shop_id").agg(f.collect_list(["sales", "week"]))

print("# Using struct inside collect_list solves the problem")
# A struct keeps year, week and sales of one data point together in one list element.
weekly_points = f.collect_list(f.struct(["year", "week", "sales"])).alias("sales_ts")
shop_sales_weekly_series_df = weekly_sales_by_shop_df.groupBy("shop_id").agg(weekly_points)

shop_sales_weekly_series_df.show(truncate=False)

print("# What about sorting?")
print("# we could do it before aggregation:")

ordered_weekly_sales_df = weekly_sales_by_shop_df.orderBy("shop_id", "year", "week")

ordered_weekly_sales_df.show()

print("# And then use collect_list aggregation")
wrongly_sorted_series_df = ordered_weekly_sales_df.groupBy("shop_id").agg(weekly_points)

wrongly_sorted_series_df.show(truncate=False)
print("# But it won't work, because collect_list may not preserve ordering!")

print("# We need to sort it for every row - and to do that we need UDFs - User Defined Functions")


# produce weekly sales: one row per shop and a list of all transactions with week and year numbers for given store in one column 
# adding week and year columnns
+-------+----+----+-----+
|shop_id|week|year|sales|
+-------+----+----+-----+
| SHOP_2|   5|2018| 50.0|
| SHOP_1|   5|2018|500.0|
| SHOP_1|   6|2018|200.0|
| SHOP_2|   6|2018|100.0|
+-------+----+----+-----+

# aggregating sales with three collect_list invocations
+-------+------------------+------------------+-------------------+
|shop_id|collect_list(week)|collect_list(year)|collect_list(sales)|
+-------+------------------+------------------+-------------------+
|SHOP_2 |[5, 6]            |[2018, 2018]      |[50.0, 100.0]      |
|SHOP_1 |[5, 6]            |[2018, 2018]      |[500.0, 200.0]     |
+-------+------------------+------------------+-------------------+

# Solution above won't work as ordering in each column may be different
# Using struct inside collect_list solves the problem
+-------+-----------------------------

# 10. Defining custom UDFs

In [21]:
def my_custom_function(column1):
    """Return the string form of ``column1`` prefixed with ``AFTER_UDF_``."""
    prefix = "AFTER_UDF_"
    return "".join([prefix, str(column1)])

# Wrap the plain Python function as a UDF; without an explicit return type
# the resulting column is a string (see the printed schema).
my_custom_udf = f.udf(my_custom_function)
print("# Adding new column by appending string to another one")
df_after_udf = shop_sales_weekly_series_df.withColumn("sales_ts_after_udf", my_custom_udf(f.col("sales_ts")))
df_after_udf.show()
print("# Schema of the new dataframe")
df_after_udf.printSchema()

print("# We can register our UDF in catalog and use it in SQL query")
# NOTE(review): these imports are no longer needed by this cell; kept in case
# other cells rely on the names - remove if unused.
from pyspark import SparkContext
from pyspark.sql import SQLContext

# Register through the active session instead of building a separate
# SQLContext and calling its deprecated registerFunction.
spark.udf.register("my_udf", my_custom_function)

spark.sql("select my_udf(shop_id) from sales").show()


# Adding new column by appending string to another one
+-------+--------------------+--------------------+
|shop_id|            sales_ts|  sales_ts_after_udf|
+-------+--------------------+--------------------+
| SHOP_2|[[2018,5,50.0], [...|AFTER_UDF_[Row(ye...|
| SHOP_1|[[2018,5,500.0], ...|AFTER_UDF_[Row(ye...|
+-------+--------------------+--------------------+

# Schema of the new dataframe
root
 |-- shop_id: string (nullable = true)
 |-- sales_ts: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- year: integer (nullable = true)
 |    |    |-- week: integer (nullable = true)
 |    |    |-- sales: double (nullable = true)
 |-- sales_ts_after_udf: string (nullable = true)

# We can register our UDF in catalog and use it in SQL query
+----------------+
| my_udf(shop_id)|
+----------------+
|AFTER_UDF_SHOP_1|
|AFTER_UDF_SHOP_1|
|AFTER_UDF_SHOP_1|
|AFTER_UDF_SHOP_2|
|AFTER_UDF_SHOP_2|
+----------------+



### ex10. Create your own UDF calculating sales for given transaction by multiplying qty and unit_price  

In [23]:
def my_custom_multiply(col1, col2):
    """Multiply two values received by the UDF for a single row.

    The qty and unit_price columns are strings in this dataset, so both
    values are converted to float before multiplying.

    Returns:
        The product as a float, or None when either input is None.
    """
    if col1 is None or col2 is None:
        return None
    return float(col1) * float(col2)

# Declare the return type explicitly; without it the UDF result would be a string column.
my_custom_multiply_udf = f.udf(my_custom_multiply, DoubleType())

# Apply the UDF itself. Calling the plain Python function on Column objects
# would just build the built-in qty * unit_price expression and never run a UDF.
total_sales_df.withColumn(
    "custom_udf_total_sales",
    my_custom_multiply_udf(f.col("qty"), f.col("unit_price"))
).show()

+-------+---+----------------+-------+----------+-----------+----------------------+
|shop_id|qty|transaction_date|item_id|unit_price|total_sales|custom_udf_total_sales|
+-------+---+----------------+-------+----------+-----------+----------------------+
| SHOP_1|  2|      2018-02-01| ITEM_1|     100.0|      200.0|                 200.0|
| SHOP_1|  1|      2018-02-01| ITEM_2|     300.0|      300.0|                 300.0|
| SHOP_1|  4|      2018-02-10| ITEM_3|      50.0|      200.0|                 200.0|
| SHOP_2|  1|      2018-02-02| ITEM_3|      50.0|       50.0|                  50.0|
| SHOP_2|  1|      2018-02-11| ITEM_1|     100.0|      100.0|                 100.0|
+-------+---+----------------+-------+----------+-----------+----------------------+



In [24]:
from pyspark.sql.types import IntegerType, StringType, StructType, ArrayType, StructField


print("# Returning more than one value from UDF without providing result schema")
def split_shop_id(shop_id):
    """Split a shop id such as "SHOP_1" into its name and numeric suffix.

    The id is split on the last underscore only, so names that themselves
    contain underscores are supported.

    Returns:
        A (name, number) tuple, or None when shop_id is None.
    """
    if shop_id is None:
        return None
    name, number = shop_id.rsplit("_", 1)
    # The number must be a real int: a string would come out as null when the
    # UDF declares an integer field for it.
    return name, int(number)

# Without a declared return type the tuple returned by the function ends up in a string column.
split_shop_id_udf = f.udf(split_shop_id)
shop_id_col = f.col("shop_id")
df_udf_no_schema = shop_sales_weekly_series_df.withColumn("shop_id_splits", split_shop_id_udf(shop_id_col))
print("# Results not as expected - seems like calling toString on object")
df_udf_no_schema.show(truncate=False)

print("# Actual inferred schema: one string instead of a tuple")
df_udf_no_schema.printSchema()

print("# Defining correct schema with two fields")
# Struct with a string field "s" and an integer field "i", matching the returned tuple.
schema = StructType().add("s", StringType()).add("i", IntegerType())
udf_with_schema = f.udf(split_shop_id, returnType=schema)

df = df_udf_no_schema.withColumn("shop_id_splits_with_schema", udf_with_schema(shop_id_col))
df.show(truncate=False)
print("# Actual schema is correct as well")
df.printSchema()


# Returning more than one value from UDF without providing result schema
# Results not as expected - seems like calling toString on object
+-------+--------------------------------+----------------------------+
|shop_id|sales_ts                        |shop_id_splits              |
+-------+--------------------------------+----------------------------+
|SHOP_2 |[[2018,5,50.0], [2018,6,100.0]] |[Ljava.lang.Object;@2c6b5f2b|
|SHOP_1 |[[2018,5,500.0], [2018,6,200.0]]|[Ljava.lang.Object;@ea02eb8 |
+-------+--------------------------------+----------------------------+

# Actual inferred schema: one string instead of a tuple
root
 |-- shop_id: string (nullable = true)
 |-- sales_ts: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- year: integer (nullable = true)
 |    |    |-- week: integer (nullable = true)
 |    |    |-- sales: double (nullable = true)
 |-- shop_id_splits: string (nullable = true)

# Defining correct schema with two fields
+-------+--

## Creating multiple columns based on a result from UDF

In [25]:
print("# Extracting all fields from returned struct can be done using asterisk *")
# "<struct column>.*" expands every field of the struct into a top-level column.
df_split_shop_id = (
    df
    .select("*", "shop_id_splits_with_schema.*")
    .drop("shop_id_splits_with_schema")
)
df_split_shop_id.show()
print("# Schema was updated and new fields have correct types")
df_split_shop_id.printSchema()

print("# Solution above will invoke UDF as many times a there are new columns created - \
it's a pySpark behaviour https://issues.apache.org/jira/browse/SPARK-17728")
print("# for costly UDF (and in pySpark most of them are very costly) we have a workaround \
to explode an array with one element - result of the UDF")
# Wrap the UDF result in a one-element array and explode it straight away.
single_udf_result = f.explode(f.array(udf_with_schema(f.col("shop_id"))))
df_split_shop_id_correct = (
    df_udf_no_schema
    .withColumn("shop_id_splits_with_schema", single_udf_result)
    .select("*", "shop_id_splits_with_schema.*")
    .drop("shop_id_splits_with_schema")
)
df_split_shop_id_correct.show()
print("# Results and schema are the same")
df_split_shop_id_correct.printSchema()


# Extracting all fields from returned struct can be done using asterisk *
+-------+--------------------+--------------------+----+---+
|shop_id|            sales_ts|      shop_id_splits|   s|  i|
+-------+--------------------+--------------------+----+---+
| SHOP_2|[[2018,5,50.0], [...|[Ljava.lang.Objec...|SHOP|  2|
| SHOP_1|[[2018,5,500.0], ...|[Ljava.lang.Objec...|SHOP|  1|
+-------+--------------------+--------------------+----+---+

# Schema was updated and new fields have correct types
root
 |-- shop_id: string (nullable = true)
 |-- sales_ts: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- year: integer (nullable = true)
 |    |    |-- week: integer (nullable = true)
 |    |    |-- sales: double (nullable = true)
 |-- shop_id_splits: string (nullable = true)
 |-- s: string (nullable = true)
 |-- i: integer (nullable = true)

# Solution above will invoke UDF as many times a there are new columns created - it's a pySpark behaviour https://issu

In [26]:
# Commentary printed before the plans, one entry per output line.
plan_notes = (
    "# Identifying problems with UDFs",
    "# To see why Spark invokes UDF multiple times let's look at query execution plan",
    "# For the first version we can see: ",
    "# +- BatchEvalPython [split_shop_id(shop_id#238), split_shop_id(shop_id#238), split_shop_id(shop_id#238)], [shop_id#238, sales_ts#4212, pythonUDF0#5025, pythonUDF1#5026, pythonUDF2#5027]",
    "# which contains multiple pythonUDF references",
    "",
    "# For the updated solution there's only one invocation: ",
    "# +- BatchEvalPython [split_shop_id(shop_id#238)], [shop_id#238, sales_ts#4212, shop_id_splits#4556, pythonUDF0#5031]",
    "",
)
for note in plan_notes:
    print(note)

# Plan of the first version, then of the explode-based version, separated by an empty line.
df_split_shop_id.explain()
print("")
df_split_shop_id_correct.explain()


# Identifying problems with UDFs
# To see why Spark invokes UDF multiple times let's look at query execution plan
# For the first version we can see: 
# +- BatchEvalPython [split_shop_id(shop_id#238), split_shop_id(shop_id#238), split_shop_id(shop_id#238)], [shop_id#238, sales_ts#4212, pythonUDF0#5025, pythonUDF1#5026, pythonUDF2#5027]
# which contains multiple pythonUDF references

# For the updated solution there's only one invocation: 
# +- BatchEvalPython [split_shop_id(shop_id#238)], [shop_id#238, sales_ts#4212, shop_id_splits#4556, pythonUDF0#5031]

== Physical Plan ==
*Project [shop_id#12, sales_ts#2049, pythonUDF0#2441 AS shop_id_splits#2230, pythonUDF2#2443.s AS s#2319, pythonUDF2#2443.i AS i#2320]
+- BatchEvalPython [split_shop_id(shop_id#12), split_shop_id(shop_id#12), split_shop_id(shop_id#12)], [shop_id#12, sales_ts#2049, pythonUDF0#2441, pythonUDF1#2442, pythonUDF2#2443]
   +- ObjectHashAggregate(keys=[shop_id#12], functions=[collect_list(named_struct(year, year#1904, wee

### ex11. Sort each time series from the previous part in descending order and compare it to the initial ts (tip: use the built-in `sorted` function)

In [27]:
from pyspark.sql.types import FloatType, ArrayType

def sort_ts(ts):
    """Return the data points of a time series ordered from newest to oldest.

    Args:
        ts: iterable of rows exposing ``year`` and ``week`` attributes.

    Returns:
        A new list sorted by year, then week, both descending. Rows with the
        same (year, week) keep their original relative order.
    """
    # Year must take precedence over week; otherwise week 52 of an earlier
    # year would be placed ahead of week 1 of a later year.
    return sorted(ts, key=lambda row: (row.year, row.week), reverse=True)

# Return schema of the UDF: an array of (year, week, sales) structs. sales is
# declared as double to match the sales_ts input column; declaring it as float
# would narrow the values to single precision.
sort_ts_udf = f.udf(sort_ts, ArrayType(StructType(
            [StructField("year", IntegerType()),
             StructField("week", IntegerType()),
             StructField("sales", DoubleType())])))

# Keep the original series next to the sorted one so they can be compared.
sorted_ts_df = wrongly_sorted_series_df.withColumn("sorted_ts", sort_ts_udf(f.col("sales_ts")))

sorted_ts_df.show()


+-------+--------------------+--------------------+
|shop_id|            sales_ts|           sorted_ts|
+-------+--------------------+--------------------+
| SHOP_2|[[2018,5,50.0], [...|[[2018,6,100.0], ...|
| SHOP_1|[[2018,5,500.0], ...|[[2018,6,200.0], ...|
+-------+--------------------+--------------------+

