# 0. Environment preparation


In [None]:
import env_setup
import pyspark.sql.functions as f

# Obtain the Spark session used by every cell below from the project-local
# env_setup helper (presumably a local-mode session given local=True -
# TODO confirm in env_setup).
spark = env_setup.getSession(local=True)

In [None]:
# Load both source tables from the session catalog.
sales_df = spark.table("sales")
item_prices_df = spark.table("item_prices")
source_tables = (sales_df, item_prices_df)

print("# schema of both tables")
for table_df in source_tables:
    table_df.printSchema()

print("# sample results from both tables")
for table_df in source_tables:
    table_df.show()

print("# query execution plan for this simple select")
for table_df in source_tables:
    table_df.explain()

# 1. SQL support

In [None]:
# Keep the statement text in a variable so it reads separately from the call.
shop1_query = 'select item_id,transaction_date from sales where shop_id = "SHOP_1" order by transaction_date desc'
spark.sql(shop1_query).show()

### ex1. Using a plain SQL query, select all transactions with quantity = 1


In [None]:
# TODO

### ex2. get mean unit price for all items

In [None]:
# TODO

# 2. Dataframe operations

In [None]:
# Variant 1: refer to columns by name through f.col.
shop1_transactions_df = (
    sales_df
    .select("item_id", "transaction_date")
    .filter(f.col("shop_id") == "SHOP_1")
    .orderBy(f.col("transaction_date").desc())
)

print("# The same query as above using dataframe api")
shop1_transactions_df.show()

# Variant 2: refer to columns as attributes of the source dataframe.
shop1_by_attribute_df = (
    sales_df
    .select(sales_df.item_id, sales_df.transaction_date)
    .filter(sales_df.shop_id == "SHOP_1")
    .orderBy(sales_df.transaction_date.desc())
)
shop1_by_attribute_df.show()

print("# Execution plan of a more complex query")
shop1_transactions_df.explain()
    

### ex3. rewrite query from ex1 to dataframe operations

In [None]:
# TODO

# 3. Joins

In [None]:
print("# joins using plain SQL queries")
join_query = 'select * from sales join item_prices on sales.item_id = item_prices.item_id'
spark.sql(join_query).show()

print("# using Dataframe API - duplicated item_id column!")
# Both dataframe joins below match rows on the same item_id equality.
same_item = sales_df.item_id == item_prices_df.item_id
sales_df.join(item_prices_df, same_item, "inner").show()

print("# dropping redundant column")
# Drop the sales-side item_id so only one item_id column is left in the result.
sales_with_unit_prices_df = sales_df.join(item_prices_df, same_item).drop(sales_df.item_id)

sales_with_unit_prices_df.show()


### ex4. Filter out excluded items

In [None]:
print("# Dataframe with column of items we would like to exclude")
# One single-element tuple per row; the only column is named "item".
excluded_item_rows = [("ITEM_2",), ("ITEM_4",)]
excluded_items_df = spark.createDataFrame(excluded_item_rows, ['item'])
excluded_items_df.show()

# TODO

# 4. Adding columns

In [None]:
# Value of a single transaction: quantity multiplied by the unit price.
transaction_total = f.col("qty") * f.col("unit_price")
total_sales_df = sales_with_unit_prices_df.withColumn("total_sales", transaction_total)

print("# Added new total_sales column which is a multuply of unit_price and qty")
total_sales_df.show()

print("# Adding price category column based on a condition")
# Conditions are evaluated in order; rows matching neither fall into "Medium".
total_sales_col = f.col("total_sales")
price_category = (
    f.when(total_sales_col > 150, "High")
    .when(total_sales_col < 60, "Low")
    .otherwise("Medium")
)
sales_with_transaction_category = total_sales_df.withColumn("price_category", price_category)

sales_with_transaction_category.show()

### ex5. We want to create two-packs of items, but the price of a two-pack must be lower than 360. Find the pairs of items that qualify.
hint: use a cross join and aliases


In [None]:
# TODO

# 5. Simple aggregations

In [None]:
print("# aggregate sales by shop - ugly column name")
# The grouping is shared by both aggregations below.
sales_by_shop = total_sales_df.groupBy("shop_id")
unnamed_total = f.sum(total_sales_df.total_sales)
# Without an alias the aggregate column is named "sum(total_sales)".
sales_by_shop.agg(unnamed_total).orderBy(f.col("sum(total_sales)")).show()

print("# using alias to have a better column name")
named_total = f.sum(total_sales_df.total_sales).alias("sales")
# The "sales" alias is introduced by this aggregation, so it is looked up by
# name with f.col rather than as an attribute of an earlier dataframe.
sales_by_shop.agg(named_total).orderBy(f.col("sales").desc()).show()
    

### ex6. For each item, produce a list of all shops where it was sold; the new column should be named "shops"
hint: use the collect_list function

In [None]:
# TODO

# 6. Date handling

In [None]:
print("# extracting multiple elements of date")
transaction_date_col = f.col("transaction_date")
# (new column name, expression) pairs, added in this order.
# NOTE(review): the 'u' pattern used for day_of_week is rejected by the Spark 3.x
# datetime formatter - confirm the Spark version this notebook targets.
date_parts = [
    ("year", f.year(transaction_date_col)),
    ("month", f.month(transaction_date_col)),
    ("day", f.dayofmonth(transaction_date_col)),
    ("day_of_year", f.dayofyear(transaction_date_col)),
    ("day_of_week", f.date_format(transaction_date_col, 'u')),
    ("day_of_week_string", f.date_format(transaction_date_col, 'E')),
    ("week_of_year", f.weekofyear(transaction_date_col)),
]
with_date_parts_df = total_sales_df
for part_name, part_expr in date_parts:
    with_date_parts_df = with_date_parts_df.withColumn(part_name, part_expr)
with_date_parts_df.show()


print("# aggregate sales by week")
# Group on the week-of-year number derived from the transaction date.
weekly_totals_df = (
    total_sales_df
    .groupBy(f.weekofyear(transaction_date_col))
    .agg(f.sum(f.col("total_sales")))
)
weekly_totals_df.show()


### ex7. Weekly sales aggregation not starting on Monday

In [None]:
# TODO

# 7. Using results of one query in another

In [None]:
print("# Calculate global max date")
# Single-row dataframe holding the latest transaction date across all sales;
# built once and reused for both the display and the value extraction below.
max_date_df = total_sales_df\
    .select(f.max(f.col("transaction_date")).alias("max_date"))
max_date_df.show()

print("# Let's add it to every column using collect - calling an action")
# 1. using collect/first
# first() is an action that returns the first Row ([0] picks its only column);
# collect() returns a list of Rows, so the equivalent would be .collect()[0][0]
max_date = max_date_df.first()[0]

print(max_date)

print("# adding it as a literal (constant) column")
# Assign the dataframe itself: show() only prints and returns None, so chaining
# it into the assignment would leave the variable holding None.
sales_with_max_global_date_df = total_sales_df\
    .withColumn("global_max_date", f.lit(max_date))
sales_with_max_global_date_df.show()


### ex8. Add the global max date using crossJoin (doesn't require invoking an action such as collect)

In [None]:
# TODO

# 8. Window functions

In [None]:
print("# get max transaction date for each shop using simple aggregations")
max_date_by_store_df = (
    total_sales_df
    .groupBy(f.col("shop_id"))
    .agg(f.max("transaction_date").alias("max_transaction_date_by_shop"))
)

# Joining on a list of column names keeps a single shop_id column in the result.
total_sales_df.join(max_date_by_store_df, ["shop_id"]).show()
print(
    '# careful: "shop_id" in join is not column - just a string. Can be also a list of strings.'
    "There's no need to drop column"
)

print("# another option is to use Windows")
print("# Note: Windows are experimental feature (even though they're available since Spark 1.4)")
from pyspark.sql import Window

# Unordered window: every row sees all rows of its own shop.
window = Window.partitionBy(f.col("shop_id"))
max_date_in_shop = f.max(f.col("transaction_date")).over(window)

total_sales_df.withColumn("max_transaction_date_by_shop", max_date_in_shop).show()

print("# Find ordinals for transactions for each item_id (so the oldest transaction with given item_id should be 1)")
window_by_item_sorted = Window.partitionBy(f.col("item_id")).orderBy(f.col("transaction_date"))
# rank() assigns the same value to rows that tie on transaction_date.
item_ordinal = f.rank().over(window_by_item_sorted)

total_sales_df.withColumn("item_transaction_ordinal", item_ordinal).show()

print("# Find average of prices from last two transactions in given shop ordered by transaction date")
# Frame: the previous row and the current row within each shop, in date order.
window_by_transaction_date = (
    Window
    .partitionBy(f.col("shop_id"))
    .orderBy(f.col("transaction_date"))
    .rowsBetween(-1, Window.currentRow)
)
price_moving_average = f.mean(f.col("total_sales")).over(window_by_transaction_date)

(
    total_sales_df
    .withColumn("price_moving_average", price_moving_average)
    .orderBy(f.col("shop_id"), f.col("transaction_date"))
    .show()
)


### ex9. Find average of prices from current and all previous transactions in given shop ordered by transaction date

In [None]:
# TODO

# 9. Complex aggregations

In [None]:
print(
    "# produce weekly sales: one row per shop and a list of all transactions "
    "with week and year numbers for given store in one column "
)

# Grouping keys derived from the transaction date, named "week" and "year".
week_col = f.weekofyear("transaction_date").alias("week")
year_col = f.year("transaction_date").alias("year")
weekly_sales_by_shop_df = (
    total_sales_df
    .groupBy("shop_id", week_col, year_col)
    .agg(f.sum("total_sales").alias("sales"))
)

print("# adding week and year columnns")
weekly_sales_by_shop_df.show()

print("# aggregating sales with three collect_list invocations")
# One independent collect_list per column, in the order week, year, sales.
separate_lists = [f.collect_list(name) for name in ("week", "year", "sales")]
shop_sales_weekly_series_df = weekly_sales_by_shop_df.groupBy("shop_id").agg(*separate_lists)

shop_sales_weekly_series_df.show(truncate=False)
print("# Solution above won't work as ordering in each column may be different")

# Passing several column names to a single collect_list, e.g.
#     .agg(f.collect_list(["sales", "week"]))
# won't work, can't collect more than one column

print("# Using struct inside collect_list solves the problem")
# Packing the three values into one struct keeps them together per list element.
weekly_point = f.struct(["year", "week", "sales"])
shop_sales_weekly_series_df = (
    weekly_sales_by_shop_df
    .groupBy("shop_id")
    .agg(f.collect_list(weekly_point).alias("sales_ts"))
)

shop_sales_weekly_series_df.show(truncate=False)

print("# What about sorting?")
print("# we could do it before aggregation:")

ordered_weekly_sales_df = weekly_sales_by_shop_df.orderBy("shop_id", "year", "week")

ordered_weekly_sales_df.show()

print("# And then use collect_list aggregation")
wrongly_sorted_series_df = (
    ordered_weekly_sales_df
    .groupBy("shop_id")
    .agg(f.collect_list(weekly_point).alias("sales_ts"))
)

wrongly_sorted_series_df.show(truncate=False)
print("# But it won't work, because collect_list may not preserve ordering!")

print("# We need to sort it for every row - and to do that we need UDFs - User Defined Functions")


# 10. Defining custom UDFs

In [None]:
def my_custom_function(column1):
    """Return the string form of ``column1`` prefixed with ``"AFTER_UDF_"``."""
    # "!s" forces str() on the value, whatever its type.
    return "AFTER_UDF_{!s}".format(column1)

# Wrap the plain Python function as a Spark UDF (no return type is given here).
my_custom_udf = f.udf(my_custom_function)

print("# Adding new column by appending string to another one")
sales_ts_after_udf = my_custom_udf(f.col("sales_ts"))
df_after_udf = shop_sales_weekly_series_df.withColumn("sales_ts_after_udf", sales_ts_after_udf)
df_after_udf.show()

print("# Schema of the new dataframe")
df_after_udf.printSchema()

print("# We can register our UDF in catalog and use it in SQL query")
from pyspark import SparkContext
from pyspark.sql import SQLContext

# Register the Python function under the SQL name "my_udf" through the
# session's UDF registry. This replaces SQLContext.registerFunction, which is
# deprecated and no longer available in Spark 3.x; spark.udf.register works
# with the SparkSession that the rest of this notebook already uses.
spark.udf.register("my_udf", my_custom_function)

# The registered name can now be called from plain SQL.
spark.sql("select my_udf(shop_id) from sales").show()


### ex10. Create your own UDF calculating sales for given transaction by multiplying qty and unit_price  

In [None]:
# TODO

In [None]:
from pyspark.sql.types import IntegerType, StringType, StructType, ArrayType, StructField


print("# Returning more than one value from UDF without providing result schema")
def split_shop_id(shop_id):
    """Split a shop id such as ``"SHOP_1"`` into its name and number parts.

    Args:
        shop_id: String made of two ``_``-separated parts, the second one
            numeric, or ``None``.

    Returns:
        A ``(name, number)`` tuple with the number converted to ``int``, or
        ``None`` when ``shop_id`` is ``None``.

    Raises:
        ValueError: If ``shop_id`` does not split into exactly two parts on
            ``_`` or the second part is not an integer.
    """
    # Spark passes SQL NULL values to Python UDFs as None; return None (NULL)
    # instead of failing on None.split.
    if shop_id is None:
        return None
    s, i = shop_id.split("_")
    return s, int(i) #must be cast to int, otherwise will return null

# UDF created without a declared return type.
split_shop_id_udf = f.udf(split_shop_id)
shop_id_splits = split_shop_id_udf(f.col("shop_id"))
df_udf_no_schema = shop_sales_weekly_series_df.withColumn("shop_id_splits", shop_id_splits)
print("# Results not as expected - seems like calling toString on object")
df_udf_no_schema.show(truncate=False)

print("# Actual inferred schema: one string instead of a tuple")
df_udf_no_schema.printSchema()

print("# Defining correct schema with two fields")
# Struct with a string field "s" and an integer field "i", matching the
# (str, int) tuple returned by split_shop_id.
schema = StructType([
    StructField("s", StringType()),
    StructField("i", IntegerType()),
])
udf_with_schema = f.udf(split_shop_id, schema)

shop_id_splits_with_schema = udf_with_schema(f.col("shop_id"))
df = df_udf_no_schema.withColumn("shop_id_splits_with_schema", shop_id_splits_with_schema)
df.show(truncate=False)
print("# Actual schema is correct as well")
df.printSchema()


## Creating multiple columns based on a result from UDF

In [None]:
print("# Extracting all fields from returned struct can be done using asterisk *")
# Keep every existing column, expand the struct's fields into top-level
# columns, then remove the struct column itself.
df_split_shop_id = (
    df
    .select("*", "shop_id_splits_with_schema.*")
    .drop("shop_id_splits_with_schema")
)
df_split_shop_id.show()
print("# Schema was updated and new fields have correct types")
df_split_shop_id.printSchema()

print(
    "# Solution above will invoke UDF as many times a there are new columns created - "
    "it's a pySpark behaviour https://issues.apache.org/jira/browse/SPARK-17728"
)
print(
    "# for costly UDF (and in pySpark most of them are very costly) we have a workaround "
    "to explode an array with one element - result of the UDF"
)
# Wrap the UDF result in a one-element array and explode it back into a row.
single_call_struct = f.explode(f.array(udf_with_schema(f.col("shop_id"))))

df_split_shop_id_correct = (
    df_udf_no_schema
    .withColumn("shop_id_splits_with_schema", single_call_struct)
    .select("*", "shop_id_splits_with_schema.*")
    .drop("shop_id_splits_with_schema")
)
df_split_shop_id_correct.show()
print("# Results and schema are the same")
df_split_shop_id_correct.printSchema()


In [None]:
# Explanatory notes printed one per line before the two execution plans.
plan_notes = (
    "# Identifying problems with UDFs",
    "# To see why Spark invokes UDF multiple times let's look at query execution plan",
    "# For the first version we can see: ",
    "# +- BatchEvalPython [split_shop_id(shop_id#238), split_shop_id(shop_id#238), split_shop_id(shop_id#238)], [shop_id#238, sales_ts#4212, pythonUDF0#5025, pythonUDF1#5026, pythonUDF2#5027]",
    "# which contains multiple pythonUDF references",
    "",
    "# For the updated solution there's only one invocation: ",
    "# +- BatchEvalPython [split_shop_id(shop_id#238)], [shop_id#238, sales_ts#4212, shop_id_splits#4556, pythonUDF0#5031]",
    "",
)
for note in plan_notes:
    print(note)

# Plan of the select-star version first, then of the explode-based version.
df_split_shop_id.explain()
print("")
df_split_shop_id_correct.explain()


### ex11. Sort each time series from the previous part in descending order and compare it to the initial time series (tip: use the sorted function)

In [None]:
# TODO