# Spark SQL Tutorial
- Author: Akira Takihara Wang (https://github.com/akiratwang)
- Tutorial Up-to-Date as of: April 2021  
- Usage: For MAST30034 students only  


In [1]:
from pyspark import SparkContext
from pyspark.sql import SparkSession

# Start the spark context
# (pass a SparkConf via `conf=` only if your environment defines one, e.g. `swan_spark_conf`)
sc = SparkContext.getOrCreate()

# create a spark session (which will run spark jobs)
spark = SparkSession.builder.getOrCreate()

# apply settings to session
spark.conf.set('spark.sql.repl.eagerEval.enabled', True)
spark.conf.set('spark.sql.execution.arrow.pyspark.enabled', True)

Previously, we saw how to create a `schema` using PySpark code. If you are more familiar with SQL, then you may want to use a Data Definition Language (DDL) string to create a schema. 

Syntax:
```python
ddl = """
`COLUMN 1` DATATYPE1, `COLUMN 2` DATATYPE2, ...
"""
```

Although it may make sense to read the datetime columns as type `TIMESTAMP`, their format may be inconsistent. As an example, we will assume the worst case, read them as `STRING`, and fix them "manually" with a function.

In [11]:
# create a DDL
schema = """
`VendorID` INT, `tpep_pickup_datetime` STRING, `tpep_dropoff_datetime` STRING,
`passenger_count` INT, `trip_distance` DOUBLE, `pickup_longitude` DOUBLE, `pickup_latitude` DOUBLE,
`RateCodeID` INT, `store_and_fwd_flag` STRING, `dropoff_longitude` DOUBLE, `dropoff_latitude` DOUBLE,
`payment_type` INT, `fare_amount` DOUBLE, `extra` DOUBLE, `mta_tax` DOUBLE, `tip_amount` DOUBLE,
`tolls_amount` DOUBLE, `improvement_surcharge` DOUBLE, `total_amount` DOUBLE
"""

As with a `StructType()` schema, you can simply pass the DDL string through the `schema` parameter.

In [12]:
sdf = spark.read.csv('../data/sample.csv', header=True, schema=schema)
sdf.printSchema()

root
 |-- VendorID: integer (nullable = true)
 |-- tpep_pickup_datetime: string (nullable = true)
 |-- tpep_dropoff_datetime: string (nullable = true)
 |-- passenger_count: integer (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- pickup_longitude: double (nullable = true)
 |-- pickup_latitude: double (nullable = true)
 |-- RateCodeID: integer (nullable = true)
 |-- store_and_fwd_flag: string (nullable = true)
 |-- dropoff_longitude: double (nullable = true)
 |-- dropoff_latitude: double (nullable = true)
 |-- payment_type: integer (nullable = true)
 |-- fare_amount: double (nullable = true)
 |-- extra: double (nullable = true)
 |-- mta_tax: double (nullable = true)
 |-- tip_amount: double (nullable = true)
 |-- tolls_amount: double (nullable = true)
 |-- improvement_surcharge: double (nullable = true)
 |-- total_amount: double (nullable = true)



- As you can see, this achieves the same result as using `StructType()`, but may be easier or more difficult depending on the number of columns you have. 
- My personal preference would be using `StructType()` for this dataset, as you can use a loop or comprehension to simplify assigning the data types.
- However, depending on your role or project scope, you may already have an existing DDL you can use!
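
The same shortcut also works for a DDL string. The block below is a minimal sketch, not part of the original tutorial: `column_names`, `int_cols`, `str_cols` and `spark_type` are hypothetical names, and the column list is copied from the schema above. It keeps the columns in file order, on the assumption that the schema is applied to the CSV by position.

```python
# Column names in the same order as they appear in the CSV file
column_names = [
    "VendorID", "tpep_pickup_datetime", "tpep_dropoff_datetime",
    "passenger_count", "trip_distance", "pickup_longitude", "pickup_latitude",
    "RateCodeID", "store_and_fwd_flag", "dropoff_longitude", "dropoff_latitude",
    "payment_type", "fare_amount", "extra", "mta_tax", "tip_amount",
    "tolls_amount", "improvement_surcharge", "total_amount",
]

# Only the exceptions are listed; every other column is treated as DOUBLE
int_cols = {"VendorID", "passenger_count", "RateCodeID", "payment_type"}
str_cols = {"tpep_pickup_datetime", "tpep_dropoff_datetime", "store_and_fwd_flag"}


def spark_type(name):
    """Return the Spark SQL type name assumed for a given column."""
    if name in int_cols:
        return "INT"
    if name in str_cols:
        return "STRING"
    return "DOUBLE"


# Join "`name` TYPE" pairs into a single DDL string
ddl = ", ".join(f"`{name}` {spark_type(name)}" for name in column_names)
print(ddl)
```

The resulting `ddl` string can be passed to the `schema` parameter in the same way as the hand-written one.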

In [13]:
sdf.limit(5)

VendorID,tpep_pickup_datetime,tpep_dropoff_datetime,passenger_count,trip_distance,pickup_longitude,pickup_latitude,RateCodeID,store_and_fwd_flag,dropoff_longitude,dropoff_latitude,payment_type,fare_amount,extra,mta_tax,tip_amount,tolls_amount,improvement_surcharge,total_amount
2,1/12/15 0:00,1/12/15 0:05,5,0.96,-73.97994232,40.76538086,1,N,-73.96630859,40.76308823,1,5.5,0.5,0.5,1.0,0.0,0.3,7.8
2,1/12/15 0:00,1/12/15 0:00,2,2.69,-73.97233582,40.76237869,1,N,-73.99362946,40.74599838,1,21.5,0.0,0.5,3.34,0.0,0.3,25.64
2,1/12/15 0:00,1/12/15 0:00,1,2.62,-73.96884918,40.76453018,1,N,-73.97454834,40.79164124,1,17.0,0.0,0.5,3.56,0.0,0.3,21.36
1,1/12/15 0:00,1/12/15 0:05,1,1.2,-73.99393463,40.74168396,1,N,-73.99766541,40.74746704,1,6.5,0.5,0.5,0.2,0.0,0.3,8.0
1,1/12/15 0:00,1/12/15 0:09,2,3.0,-73.98892212,40.72698975,1,N,-73.97559357,40.6968689,2,11.0,0.5,0.5,0.0,0.0,0.3,12.3


Now let's convert our datetimes to `TIMESTAMP` (our current dataset stores them as strings such as `1/12/15 0:05`, i.e. day/month/two-digit year).

In [5]:
import pyspark.sql.functions as F

from pyspark.sql.functions import col
from pyspark.sql.types import *

In [6]:
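# create UDF
from datetime import datetime

@F.udf("timestamp")
def format_dtime(dtime):
    # nulls arrive in a Python UDF as None, so pass them through
    if dtime is None:
        return None
    # expected input looks like "1/12/15 0:05" (d/m/yy h:mm)
    date, time = dtime.split()
    # map the iterable to integer
    d, m, y = map(int, date.split('/'))
    # year is abbreviated to two digits, so offset it by 2000
    # (prefixing "20" as a string would turn "05" into 205)
    y += 2000
    h, mins = map(int, time.split(':'))
    return datetime(y, m, d, h, mins)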
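
Since we are assuming the format may be inconsistent, it can help to test the parsing logic in plain Python before wrapping it in a UDF. The following is a minimal sketch rather than the tutorial's method: `parse_dtime` is a hypothetical helper that uses `datetime.strptime` and returns `None` for missing or malformed values instead of raising an error.

```python
from datetime import datetime
from typing import Optional


def parse_dtime(dtime: Optional[str]) -> Optional[datetime]:
    """Parse strings such as '1/12/15 0:05' (day/month/two-digit year)."""
    if dtime is None:
        return None
    try:
        # %y maps two-digit years 00-68 to 2000-2068
        return datetime.strptime(dtime.strip(), "%d/%m/%y %H:%M")
    except ValueError:
        # malformed input becomes a null rather than failing the whole job
        return None


print(parse_dtime("1/12/15 0:05"))  # 2015-12-01 00:05:00
print(parse_dtime("not a date"))    # None
```

A function like this could be wrapped with `F.udf("timestamp")` in the same way as `format_dtime`.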

In [7]:
sdf.withColumn("tpep_pickup_datetime", format_dtime('tpep_pickup_datetime')) \
    .withColumn("tpep_dropoff_datetime", format_dtime('tpep_dropoff_datetime')) \
    .limit(5)

VendorID,tpep_pickup_datetime,tpep_dropoff_datetime,passenger_count,trip_distance,pickup_longitude,pickup_latitude,RateCodeID,store_and_fwd_flag,dropoff_longitude,dropoff_latitude,payment_type,fare_amount,extra,mta_tax,tip_amount,tolls_amount,improvement_surcharge,total_amount
2,2015-12-01 00:00:00,2015-12-01 00:05:00,5,0.96,-73.97994232,40.76538086,1,N,-73.96630859,40.76308823,1,5.5,0.5,0.5,1.0,0.0,0.3,7.8
2,2015-12-01 00:00:00,2015-12-01 00:00:00,2,2.69,-73.97233582,40.76237869,1,N,-73.99362946,40.74599838,1,21.5,0.0,0.5,3.34,0.0,0.3,25.64
2,2015-12-01 00:00:00,2015-12-01 00:00:00,1,2.62,-73.96884918,40.76453018,1,N,-73.97454834,40.79164124,1,17.0,0.0,0.5,3.56,0.0,0.3,21.36
1,2015-12-01 00:00:00,2015-12-01 00:05:00,1,1.2,-73.99393463,40.74168396,1,N,-73.99766541,40.74746704,1,6.5,0.5,0.5,0.2,0.0,0.3,8.0
1,2015-12-01 00:00:00,2015-12-01 00:09:00,2,3.0,-73.98892212,40.72698975,1,N,-73.97559357,40.6968689,2,11.0,0.5,0.5,0.0,0.0,0.3,12.3


- The conversion looks correct, so let's keep it.
- Remember, Spark DataFrames are immutable: `withColumn` returns a new DataFrame, so we need to reassign the `sdf` variable.

In [14]:
sdf = sdf.withColumn("tpep_pickup_datetime", format_dtime('tpep_pickup_datetime')) \
    .withColumn("tpep_dropoff_datetime", format_dtime('tpep_dropoff_datetime'))

In [15]:
sdf.limit(5)

VendorID,tpep_pickup_datetime,tpep_dropoff_datetime,passenger_count,trip_distance,pickup_longitude,pickup_latitude,RateCodeID,store_and_fwd_flag,dropoff_longitude,dropoff_latitude,payment_type,fare_amount,extra,mta_tax,tip_amount,tolls_amount,improvement_surcharge,total_amount
2,2015-12-01 00:00:00,2015-12-01 00:05:00,5,0.96,-73.97994232,40.76538086,1,N,-73.96630859,40.76308823,1,5.5,0.5,0.5,1.0,0.0,0.3,7.8
2,2015-12-01 00:00:00,2015-12-01 00:00:00,2,2.69,-73.97233582,40.76237869,1,N,-73.99362946,40.74599838,1,21.5,0.0,0.5,3.34,0.0,0.3,25.64
2,2015-12-01 00:00:00,2015-12-01 00:00:00,1,2.62,-73.96884918,40.76453018,1,N,-73.97454834,40.79164124,1,17.0,0.0,0.5,3.56,0.0,0.3,21.36
1,2015-12-01 00:00:00,2015-12-01 00:05:00,1,1.2,-73.99393463,40.74168396,1,N,-73.99766541,40.74746704,1,6.5,0.5,0.5,0.2,0.0,0.3,8.0
1,2015-12-01 00:00:00,2015-12-01 00:09:00,2,3.0,-73.98892212,40.72698975,1,N,-73.97559357,40.6968689,2,11.0,0.5,0.5,0.0,0.0,0.3,12.3


## Creating a SQL Table from an existing Spark DataFrame
The easiest method is to use `sdf.createOrReplaceTempView(TABLE_NAME)`, which registers the DataFrame as a temporary view that Spark SQL can query.


In [16]:
sdf.createOrReplaceTempView("taxi_data")

Select all columns from our table, where:
- the Vendor is `VeriFone Inc.` (`VendorID = 2`);
- we have at least 1 passenger;
- and a trip distance of at least 1 mile.

In [18]:
sql_query = """
SELECT * 
FROM taxi_data
WHERE VendorID = 2
    AND passenger_count >= 1
    AND trip_distance >= 1
LIMIT 5;
"""

spark.sql(sql_query)

VendorID,tpep_pickup_datetime,tpep_dropoff_datetime,passenger_count,trip_distance,pickup_longitude,pickup_latitude,RateCodeID,store_and_fwd_flag,dropoff_longitude,dropoff_latitude,payment_type,fare_amount,extra,mta_tax,tip_amount,tolls_amount,improvement_surcharge,total_amount
2,2015-12-01 00:00:00,2015-12-01 00:00:00,2,2.69,-73.97233582,40.76237869,1,N,-73.99362946,40.74599838,1,21.5,0.0,0.5,3.34,0.0,0.3,25.64
2,2015-12-01 00:00:00,2015-12-01 00:00:00,1,2.62,-73.96884918,40.76453018,1,N,-73.97454834,40.79164124,1,17.0,0.0,0.5,3.56,0.0,0.3,21.36
2,2015-12-01 00:00:00,2015-12-01 00:08:00,2,1.91,-73.99420929,40.74610138,1,N,-74.00424957,40.72180939,1,8.0,0.5,0.5,1.86,0.0,0.3,11.16
2,2015-12-01 00:00:00,2015-12-01 00:17:00,1,4.5,-74.00675964,40.7189064,1,N,-73.98969269,40.77285385,1,16.5,0.5,0.5,3.56,0.0,0.3,21.36
2,2015-12-01 00:00:00,2015-12-01 00:10:00,2,1.42,-73.99963379,40.73477173,1,N,-73.98906708,40.72312164,1,8.5,0.5,0.5,2.45,0.0,0.3,12.25


Below is the equivalent query using the PySpark DataFrame API. As you can see, it is arguably _less readable_. A Spark SQL query is a plain string, so the same query can often be run directly against a database if need be, which gives you a more consistent way of testing queries.

In [22]:
sdf.select(sdf.columns) \
    .filter((col('VendorID') == 2) & (col('passenger_count') >= 1) & (col('trip_distance') >= 1)) \
    .limit(5)

VendorID,tpep_pickup_datetime,tpep_dropoff_datetime,passenger_count,trip_distance,pickup_longitude,pickup_latitude,RateCodeID,store_and_fwd_flag,dropoff_longitude,dropoff_latitude,payment_type,fare_amount,extra,mta_tax,tip_amount,tolls_amount,improvement_surcharge,total_amount
2,2015-12-01 00:00:00,2015-12-01 00:00:00,2,2.69,-73.97233582,40.76237869,1,N,-73.99362946,40.74599838,1,21.5,0.0,0.5,3.34,0.0,0.3,25.64
2,2015-12-01 00:00:00,2015-12-01 00:00:00,1,2.62,-73.96884918,40.76453018,1,N,-73.97454834,40.79164124,1,17.0,0.0,0.5,3.56,0.0,0.3,21.36
2,2015-12-01 00:00:00,2015-12-01 00:08:00,2,1.91,-73.99420929,40.74610138,1,N,-74.00424957,40.72180939,1,8.0,0.5,0.5,1.86,0.0,0.3,11.16
2,2015-12-01 00:00:00,2015-12-01 00:17:00,1,4.5,-74.00675964,40.7189064,1,N,-73.98969269,40.77285385,1,16.5,0.5,0.5,3.56,0.0,0.3,21.36
2,2015-12-01 00:00:00,2015-12-01 00:10:00,2,1.42,-73.99963379,40.73477173,1,N,-73.98906708,40.72312164,1,8.5,0.5,0.5,2.45,0.0,0.3,12.25
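
To illustrate the point about running the same query string elsewhere, here is a minimal sketch that executes it against an in-memory SQLite database from Python's standard library. This is an assumption-laden stand-in: SQLite's dialect differs from Spark SQL, and it only works here because the query uses standard syntax. The three-column table and `rows` are a cut-down copy of the sample rows shown earlier.

```python
import sqlite3

sql_query = """
SELECT *
FROM taxi_data
WHERE VendorID = 2
    AND passenger_count >= 1
    AND trip_distance >= 1
LIMIT 5;
"""

# (VendorID, passenger_count, trip_distance) from the five sample rows above
rows = [(2, 5, 0.96), (2, 2, 2.69), (2, 1, 2.62), (1, 1, 1.2), (1, 2, 3.0)]

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE taxi_data (VendorID INT, passenger_count INT, trip_distance DOUBLE)"
)
conn.executemany("INSERT INTO taxi_data VALUES (?, ?, ?)", rows)

# Only the VendorID = 2 rows with a distance of at least 1 mile remain
result = conn.execute(sql_query).fetchall()
print(result)
conn.close()
```

Only the two `VendorID = 2` trips of at least 1 mile are returned, which matches the first two rows of the Spark output.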


To list all metadata:
- Databases: `spark.catalog.listDatabases()`
- Tables: `spark.catalog.listTables()`
- Columns of a table: `spark.catalog.listColumns(TABLE_NAME)`

In [23]:
spark.catalog.listDatabases()

[Database(name='default', description='default database', locationUri='file:/mnt/c/users/akira/documents/github/MAST30034_Python/advanced_tutorials/spark-warehouse')]

In [24]:
spark.catalog.listTables()

[Table(name='taxi_data', database=None, description=None, tableType='TEMPORARY', isTemporary=True)]

## Creating SQL views directly from files
- If you don't have a Spark DataFrame yet, you can still read the data in directly from a file using Spark SQL.

Syntax:
```python
q = """
CREATE OR REPLACE TEMPORARY VIEW TABLE_NAME
USING parquet
OPTIONS (path "PARQUET_FPATH")
"""
```

- `CREATE OR REPLACE TEMPORARY VIEW TABLE_NAME` is the SQL counterpart of `sdf.createOrReplaceTempView(TABLE_NAME)`.
- `USING` denotes the file type (e.g. `parquet` or `csv`).
- `OPTIONS` passes options to the data source, such as the file path we wish to read from.
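
The three parts above can be assembled with a small helper. This is a hedged sketch: `make_temp_view_query` is a hypothetical function, not part of Spark, and it assumes trusted inputs because it does not escape quotes. It only builds the query string, which you would then pass to `spark.sql(...)`.

```python
def make_temp_view_query(view_name: str, file_format: str, path: str, **options: str) -> str:
    """Build a CREATE OR REPLACE TEMPORARY VIEW statement for a file source."""
    # Every option is rendered as: key "value"
    all_options = {"path": path, **options}
    rendered = ", ".join(f'{key} "{value}"' for key, value in all_options.items())
    return (
        f"CREATE OR REPLACE TEMPORARY VIEW {view_name}\n"
        f"USING {file_format}\n"
        f"OPTIONS ({rendered})"
    )


# A CSV source with a header row; the view name `taxi_csv` is illustrative
print(make_temp_view_query("taxi_csv", "csv", "../data/sample.csv", header="true"))
```

For the CSV case, extra options such as `header` travel alongside `path` in the same `OPTIONS` clause.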

In [25]:
import os

sql_query = f"""
CREATE OR REPLACE TEMPORARY VIEW aggregation_parquet
USING parquet
OPTIONS (path 
    "{os.path.dirname(os.getcwd())}/data/aggregated_results.parquet/")
"""

spark.sql(sql_query)

In [26]:
spark.sql("SELECT * FROM aggregation_parquet")

passenger_count,avg_trip_amount
0,13.033572474377738
1,14.667256653975397
2,15.37395457318419
3,14.876514480790858
4,14.968314051202968
5,14.864738578983076
6,14.564280093219365
7,16.295
8,58.05
9,21.076666666666668
