In [None]:
# Install statsmodels for this notebook session; it provides the ARIMA model
# imported further down (statsmodels.tsa.arima.model).
# NOTE(review): `sc` is never assigned in this notebook - presumably the
# SparkContext pre-created by the notebook kernel; confirm.
# NOTE(review): the version is unpinned (PyArrow below is pinned) - consider
# pinning a release that ships statsmodels.tsa.arima.model.
sc.install_pypi_package("statsmodels")

In [None]:
# Install PyArrow, pinned to 1.0.0. Arrow execution is switched on later in
# the notebook via spark.sql.execution.arrow.pyspark.enabled.
sc.install_pypi_package("PyArrow==1.0.0")

In [None]:
# Show the packages visible to this session - presumably to confirm that the
# two installs above took effect (verify the output lists both).
sc.list_packages()

In [None]:
from pyspark.sql.types import DateType, FloatType, IntegerType, StructField, StructType
from pyspark.sql import SparkSession
from statsmodels.tsa.arima.model import ARIMA
import logging
import pandas as pd
from pyspark.sql.functions import current_date

In [None]:
import os
import sys

# Point both the PySpark worker and driver interpreter settings at the Python
# executable that is running this kernel.
os.environ.update(
    dict.fromkeys(('PYSPARK_PYTHON', 'PYSPARK_DRIVER_PYTHON'), sys.executable)
)

In [None]:
# Create (or reuse) the Spark session first, so that `spark` is guaranteed to
# be defined before any configuration is applied to it. Previously the conf
# call ran in an earlier cell than the one assigning `spark`, which only works
# when the kernel has already injected a `spark` variable.
spark = SparkSession.builder.master("yarn").getOrCreate()

# Enable Arrow for PySpark so pandas <-> Spark conversions (used by the
# grouped pandas forecast below) go through Arrow.
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")

In [None]:
# Location of the training file in S3 and the schema used to parse it
# (the file itself is read in the next cell).
s3_path = "s3://my-sales-data-storage/Sales.csv"
train_schema = StructType([
    StructField(column_name, column_type)
    for column_name, column_type in (
        ('date', DateType()),
        ('store', IntegerType()),
        ('item', IntegerType()),
        ('sales', IntegerType()),
    )
])

In [None]:
# Load the CSV from S3, treating the first line as a header and applying the
# explicit schema instead of inferring column types.
train = (
    spark.read
    .schema(train_schema)
    .option("header", True)
    .csv(s3_path)
)

In [None]:
# Preview the loaded training data (show() prints the first 20 rows by default).
train.show()

In [None]:
# make the dataframe queriable as a temporary view
train.createOrReplaceTempView('train') 


# Retrieve data for all store-item combinations: one row per
# (store, item, date), with the date exposed as `ds` and the summed sales as `y`.
# NOTE(review): the query result is repartitioned by (store, item) in the next
# cell, so this ORDER BY presumably does not guarantee row order inside each
# group downstream - confirm before relying on it.
sql_statement = '''
  SELECT
    store,
    item,
    CAST(date as date) as ds,
    SUM(sales) as y
  FROM train
  GROUP BY store, item, ds
  ORDER BY store, item, ds
  '''

In [None]:
# Run SQL statement and cache the result
# Execute the query, spread the rows over `sc.defaultParallelism` partitions
# keyed by (store, item), and cache the result for reuse.
store_item_history = (
    spark.sql(sql_statement)
    .repartition(sc.defaultParallelism, ['store', 'item'])
    .cache()
)

In [None]:
# Schema of the frames returned by the per-group forecast function:
# a date, the two integer keys, then the observed and forecast values as floats.
result_schema = StructType(
    [StructField('ds', DateType())]
    + [StructField(key_name, IntegerType()) for key_name in ('store', 'item')]
    + [
        StructField(value_name, FloatType())
        for value_name in ('y', 'yhat', 'yhat_upper', 'yhat_lower')
    ]
)

In [None]:
# Define Function to Train Model & Generate Forecast
def forecast_store_item(history_pd: pd.DataFrame) -> pd.DataFrame:
    """Fit an ARIMA(1, 0, 0) model to one store-item history and forecast 90 days.

    Intended to be called once per (store, item) group via ``applyInPandas``.

    Parameters
    ----------
    history_pd : pd.DataFrame
        History for a single store-item combination with the columns
        ``ds`` (date), ``store``, ``item`` and ``y`` (sales).

    Returns
    -------
    pd.DataFrame
        Columns ``ds, store, item, y, yhat, yhat_upper, yhat_lower``, one row
        per forecast day (90 rows starting the day after the last history date).
        ``y`` is NaN on every row because no actual exists for a future date.
        ``yhat_lower`` / ``yhat_upper`` are the bounds of the model's default
        confidence interval. If no rows remain after dropping missing values,
        an empty frame with the same columns is returned.
    """
    result_columns = ['ds', 'store', 'item', 'y', 'yhat', 'yhat_upper', 'yhat_lower']
    horizon_days = 90

    # remove missing values (more likely at day-store-item level)
    history_pd = history_pd.dropna()

    # Nothing to train on: return an empty, correctly shaped frame instead of
    # failing on the model fit or on `.iloc[0]` below.
    if history_pd.empty:
        return pd.DataFrame(columns=result_columns)

    # The model treats the rows as a time-ordered series, and the incoming
    # group is not guaranteed to be sorted, so order it by date explicitly.
    # Dates are normalised to pandas timestamps so max()/offset arithmetic works.
    history_pd = (
        history_pd
        .assign(ds=pd.to_datetime(history_pd['ds']))
        .sort_values('ds')
        .reset_index(drop=True)
    )

    # TRAIN MODEL
    # --------------------------------------
    # configure the model (y is cast to float; it may arrive as an integer sum)
    model = ARIMA(history_pd['y'].astype(float), order=(1, 0, 0))

    # train the model; `disp` is not a parameter of this ARIMA implementation's
    # fit(), so it is no longer passed
    model_fit = model.fit()
    # --------------------------------------

    # BUILD FORECAST
    # --------------------------------------
    # out-of-sample forecast with its confidence interval
    forecast = model_fit.get_forecast(steps=horizon_days)
    # conf_int() columns are (lower, upper); wrap in a DataFrame so positional
    # access works regardless of the container statsmodels returns
    bounds_pd = pd.DataFrame(forecast.conf_int())
    # --------------------------------------

    # ASSEMBLE EXPECTED RESULT SET
    # --------------------------------------
    # Values are passed as plain arrays so the frame gets a fresh 0..89 index
    # rather than inheriting the forecast's positional index.
    results_pd = pd.DataFrame({
        'ds': pd.date_range(
            start=history_pd['ds'].max() + pd.DateOffset(days=1),
            periods=horizon_days,
        ),
        'yhat': pd.Series(forecast.predicted_mean).to_numpy(),
        'yhat_upper': bounds_pd.iloc[:, 1].to_numpy(),
        'yhat_lower': bounds_pd.iloc[:, 0].to_numpy(),
    })

    # future dates have no observed sales
    results_pd['y'] = float('nan')

    # get store & item from incoming data set
    results_pd['store'] = history_pd['store'].iloc[0]
    results_pd['item'] = history_pd['item'].iloc[0]
    # --------------------------------------

    # return expected dataset
    return results_pd[result_columns]



In [None]:
# Apply Forecast Function to Each Store-Item Combination
# Run the forecast function once per (store, item) group and stamp every
# output row with the date the forecast was produced.
results = store_item_history.groupBy('store', 'item').applyInPandas(
    forecast_store_item, result_schema
).withColumn('training_date', current_date())


In [None]:
# Preview the first 10 forecast rows.
# NOTE(review): `results` is not cached in this notebook, so each action on it
# (this preview, the CSV export, the JDBC write) presumably re-runs the
# per-group model fits - consider caching; confirm.
results.show(10)

In [None]:
# `results` is a Spark DataFrame, which has no `to_csv` method; collect it to
# pandas on the driver first, then write a single local CSV without the index.
results.toPandas().to_csv('result.csv', index=False)

In [None]:
# Write the predictions to Redshift.
# Connection settings are read from the environment so that no credentials are
# stored in the notebook:
#   REDSHIFT_JDBC_URL  - optional, defaults to the workgroup endpoint below
#   REDSHIFT_USER      - optional, defaults to "admin"
#   REDSHIFT_PASSWORD  - required; a missing value raises KeyError here rather
#                        than failing later inside the JDBC driver
jdbc_url = os.environ.get(
    "REDSHIFT_JDBC_URL",
    "jdbc:redshift://default-workgroup.572561648008.ap-south-1.redshift-serverless.amazonaws.com:5439/sales-forecast",
)
jdbc_properties = {
    "user": os.environ.get("REDSHIFT_USER", "admin"),
    "password": os.environ["REDSHIFT_PASSWORD"],
    "driver": "com.amazon.redshift.jdbc.Driver"
}
# Append to the existing `predictions` table rather than replacing it.
results.write \
    .format("jdbc") \
    .option("url", jdbc_url) \
    .option("dbtable", "predictions") \
    .options(**jdbc_properties) \
    .mode("append") \
    .save()

In [None]:
# Stop the Spark session now that the results have been written out.
# Cells that use `spark` cannot be re-run after this without creating a new session.
spark.stop()