In [None]:
# Before starting, ensure that the kernel is set to 'PySpark'

# To run our notebook on the EMR Cluster, we firstly need to create the SparkSession. 
# Since we're running this in the Pyspark Kernel, the SparkSession is automatically defined as 'spark'.
# It will start the spark session as soon as you run a block of this notebook.

In [None]:
# Now that the SparkSession has been started, we adjust a session-level setting in an attempt to
# slightly optimize its performance.
# You may add/remove/change any configuration settings according to your needs/cluster configuration.

# Number of partitions Spark SQL uses when shuffling data for joins and aggregations.
spark.conf.set("spark.sql.shuffle.partitions", 7500)

# NOTE: "spark.executor.memory" is a launch-time property, so it is not set here. Calling
# spark.conf.set on it after the session is running does not resize the executors (Spark 3.x
# rejects the call with an AnalysisException; earlier versions accept it without effect).
# To size the executors, supply the value when the session is created instead, e.g. with a
# sparkmagic `%%configure -f` cell such as {"conf": {"spark.executor.memory": "2g"}} that is
# run before any other cell of this notebook.

In [None]:
# Normally in python, we are able to install our packages using pip3 or some other package manager. Since we have
# multiple machines in our cluster, we aren't able to use simple pip3 commands inside the command line to install 
# packages efficiently. 
# So, we can use the sparkContext to install our packages on our machines for us.

# This can be done by first importing and creating the sparkContext.
from pyspark.context import SparkContext
sc = spark.sparkContext

# install_pypi_package resolves every requirement against the PyPI package index, so take
# the package names (and any version pins) from: https://pypi.org/
for requirement in ("statsmodels", "pandas", "pyarrow==0.11.0", "s3fs"):
    sc.install_pypi_package(requirement)

In [None]:
# Now that we have all of our packages installed using the install_pypi_package function, we can now import these
# packages into our notebook.
import numpy as np, pandas as pd
import calendar
import ast
import datetime
import logging
import statsmodels.api as sm
import pyarrow as pa
from pyspark.sql.types import *
from pyspark.sql.functions import pandas_udf, PandasUDFType

In [None]:
# The notebook is now setup to run this example.

# To get data from the Redshift database, we can use the spark session to read the data straight into a dataframe.
# This can be done by using the following line for each table:
# dataframe = spark.read.format("jdbc").option("url", "jdbc:redshift://host:port/database").option("dbtable", "table name here").option("user", "username here").option("password", "password here").load()
# You can configure this line with your own parameters.

# For our example, we can use these two lines to read the data from the fcst20.state_to_state table, and 
# the parameters from the fcst20.arima_calib table. (You still need to configure your own Username and password)
# Both tables live in the same Redshift database, so the JDBC url is written once.
redshift_url = "jdbc:redshift://demo.3victorsaws.com:5439/demo"

# Search data from fcst20.state_to_state (fill in your own user and password).
df = (
    spark.read.format("jdbc")
    .option("url", redshift_url)
    .option("dbtable", "fcst20.state_to_state")
    .option("user", "")
    .option("password", "")
    .load()
)

# Model parameters from fcst20.arima_calib (fill in your own user and password).
params_df = (
    spark.read.format("jdbc")
    .option("url", redshift_url)
    .option("dbtable", "fcst20.arima_calib")
    .option("user", "")
    .option("password", "")
    .load()
)

In [None]:
# Before we can start to manipulate the data in the Spark dataframes, we need to declare a schema. This schema will be
# used to tell PySpark how to automatically translate the Pandas dataframe that we return from our functions back into 
# a Spark Dataframe. The schema describes the structure of the Pandas dataframe that we return.


# Output schema of the grouped function: four nullable string columns identifying the
# forecast row, followed by four nullable float columns holding the prediction summary.
schema = (
    StructType()
    .add('origin_state_code', StringType(), True)
    .add('destination_state_code', StringType(), True)
    .add('travel_month', StringType(), True)
    .add('index', StringType(), True)
    .add('mean', FloatType(), True)
    .add('mean_se', FloatType(), True)
    .add('mean_ci_lower', FloatType(), True)
    .add('mean_ci_upper', FloatType(), True)
)

In [None]:
# Now that we have declared what the returned dataframe's schema is, we can create our function. This function will be
# used to run predictions on Grouped Data. Grouped Data, in this notebook is the data in our Spark Dataframe that has
# been split up into smaller Pandas Dataframes depending on the values of origin_state_code, destination_state_code 
# and month_start for each group of rows.
# This means that the data used in each instance of this function will have the same values of each of these variables.

# The function requires a decorator in order to be used with our spark dataframe. 
# The first value in the decorator indicates our schema/format for the outputted pandas dataframe.
# The second value in the decorator tells Pyspark that the outputted/returned Pandas dataframe should be converted
# back into a spark dataframe upon completion. Additionally it also tells PySpark that the function should be run
# on data according to the groupBy() function ran before it. 
@pandas_udf(schema, PandasUDFType.GROUPED_MAP)
def run_sarima_bystate(state_data):
    """Fit a SARIMAX model on one group of rows and forecast from today to month_end.

    Parameters
    ----------
    state_data : pandas.DataFrame
        The rows of a single group. The function reads the columns 'sales_date', 'date',
        'total_searches', 'origin_state_code', 'destination_state_code', 'month_end',
        'order' and 'seasonal_order'. Apart from the dates and 'total_searches', one value
        per column is taken and used for the whole group.

    Returns
    -------
    pandas.DataFrame
        Columns 'origin_state_code', 'destination_state_code', 'travel_month', 'index'
        (the forecast date as 'YYYY-MM-DD'), 'mean', 'mean_se', 'mean_ci_lower' and
        'mean_ci_upper' (90% confidence interval), one row per forecast date.
        If fitting or predicting raises, the error is printed and a frame with the same
        columns but zero rows is returned, so the failing group adds nothing to the output.
    """
    # Column names of the returned frame, in the same order as the declared output schema.
    string_columns = ['origin_state_code', 'destination_state_code', 'travel_month', 'index']
    float_columns = ['mean', 'mean_se', 'mean_ci_lower', 'mean_ci_upper']
    output_columns = string_columns + float_columns

    def parse_order(text):
        # The order tuples arrive as strings such as "(1, 0, 1)". Drop the surrounding
        # brackets and convert every comma-separated item; int() ignores the whitespace
        # around each item, so "(1,0,1)" and "(1, 0, 1)" are both accepted.
        return tuple(int(item) for item in text.strip()[1:-1].split(','))

    # Sort by date so the series is in chronological order before it becomes the index.
    state_data = state_data.sort_values('sales_date', ascending=True)

    # Keep only what the model needs: the date (parsed from its string form) and the value.
    data = pd.DataFrame()
    data['date'] = pd.to_datetime(state_data['date'])
    data['total_searches'] = state_data['total_searches']

    # These columns are read once per group. Use positional access (.iloc) rather than the
    # label 0, which depends on the index the frame happens to arrive with.
    origin_state = state_data['origin_state_code'].iloc[0]
    destination_state = state_data['destination_state_code'].iloc[0]
    month_end = pd.to_datetime(state_data['month_end'].iloc[0]).strftime('%Y-%m-%d')

    # Read the clock once so that 'today' and 'yesterday' are always consecutive days.
    now = datetime.datetime.now()
    # Today is where forecasting starts; yesterday is the last day of the reference range.
    today = now.strftime('%Y-%m-%d')
    yesterday = (now - datetime.timedelta(days=1)).strftime('%Y-%m-%d')

    # SARIMAX needs a continuous run of dates. Build every date from the start date to
    # yesterday, outer-merge it with the data (union of both sets of dates) and fill the
    # days that had no data with 0.
    reference_dates = pd.DataFrame(
        {'date': pd.date_range(start='2020-04-01', end=yesterday, freq='D')}
    )
    data = pd.merge(reference_dates, data, on='date', how='outer')
    data = data.fillna(0)

    # The date becomes the index because SARIMAX is a time-series model.
    data = data.set_index('date')

    # Restore the tuple form of the model orders.
    order = parse_order(state_data['order'].iloc[0])
    seasonal_order = parse_order(state_data['seasonal_order'].iloc[0])

    try:
        # Configure and fit the model, then predict from today until month_end.
        mod_sarimax = sm.tsa.SARIMAX(data, order=order, seasonal_order=seasonal_order)
        smodel = mod_sarimax.fit()
        predict_results = smodel.get_prediction(start=today, end=month_end, freq='D', dynamic=True)

        # summary_frame collects the prediction results; alpha=0.10 asks for a 90% interval.
        predictions = predict_results.summary_frame(alpha=0.10)

        # Shape the result like the declared schema: the forecast date as a string column
        # named 'index', plus the identifying columns of this group.
        predictions['index'] = predictions.index.strftime('%Y-%m-%d')
        predictions['origin_state_code'] = origin_state
        predictions['destination_state_code'] = destination_state
        predictions['travel_month'] = month_end

        # Select the columns explicitly so neither their order nor the name of the
        # prediction index can make the frame disagree with the schema.
        return predictions[output_columns].reset_index(drop=True)

    except Exception as exc:
        # Report the actual error (and which group it belongs to), not the Exception class.
        print('SARIMAX failed for {} -> {} (month_end {}): {!r}'.format(
            origin_state, destination_state, month_end, exc))

    # The model could not be fitted or used: return a frame with the schema's columns and
    # no rows, typed so the float columns convert cleanly.
    empty = pd.DataFrame(columns=output_columns)
    return empty.astype({name: 'float64' for name in float_columns})

In [None]:
# Now that we have our function setup, we can give spark some instructions to lazily execute. 
# However we firstly need to create tables that will be evaluated lazily (only executes instructions when a 
# value is needed/requested/output)

# Register both dataframes as lazily evaluated temporary views so they can be queried with SQL.
df.createOrReplaceTempView('data')
params_df.createOrReplaceTempView('params')

# Join every data row to its parameters row (same origin state, same destination state, and
# dptr_month equal to the month number of month_start). Only rows with a sales_date on or
# after 2020-04-01 are kept, and only where the parameters row has no_order = 'N' and
# no_arima = 'N'.
join_query = (
    "select data.sales_date, date_format(data.sales_date, 'MM/dd/yyyy') as date,  "
    "data.origin_state_code, data.destination_state_code, data.month_start, data.month_end, data.total_searches, "
    "params.dptr_month, params.m_order as order, params.m_seasonal_order as seasonal_order "
    "from data "
    "join params on params.origin_state_code = data.origin_state_code "
    "and params.destination_state_code = data.destination_state_code "
    "and params.dptr_month = CAST(EXTRACT(MONTH FROM data.month_start) AS bigint) "
    "where data.sales_date >= '2020-04-01' "
    "AND params.no_order = 'N' "
    "AND params.no_arima = 'N' "
    "ORDER BY sales_date"
)
flattened = spark.sql(join_query)


In [None]:
# With Spark 3.0.0 (the latest version at the time of writing) the apply function still works, but it is deprecated.
# On version 3.0.0 or later, prefer applyInPandas, which takes the plain (undecorated) Python function and the output schema.
# In order for your Python function to be able to operate on the data,
# PySpark converts the Spark DataFrame into a Pandas DataFrame before your function is called. 
# Likewise, when you return a Pandas DataFrame from your function, PySpark will convert it back into a Spark DataFrame.

# It's worth noting that PySpark will call your function once for each group of data that it has, hence, 
# in the example, there is a groupBy on origin_state_code, destination_state_code, month_start - so that the Python 
# function will operate on a Pandas DataFrame containing all the data for each combination.
# One function call per distinct combination of these three columns.
group_columns = ['origin_state_code', 'destination_state_code', 'month_start']
results = flattened.groupby(*group_columns).apply(run_sarima_bystate)

In [None]:
# We can now select whether we want to enforce partitioning or not.

# Option 1: repartition the results into 4 separate partitions. (Use this to specify how many csv's you want to return.)
# It is left commented out so that it does not replace the default assignment below.

# output = results.repartition(4)


# Option 2 (default): use the results as they are, without defining a partitioning.
output = results 

In [None]:
# Folder counter for the output path. The write cell adds 1 to it before saving, so set this
# to the highest sarimax- folder number that already exists.
count = 1 

In [None]:
# Move on to the next folder number, then write the data out to csv files in that folder.
count += 1
destination = "s3://folder path goes here/sarimax-" + str(count) + "/"
output.write.format("csv").save(destination)