In [64]:
from pyspark.sql import SparkSession
from pyspark.sql.types import *
from pyspark.sql.functions import pandas_udf, PandasUDFType, sum, max, col, concat, lit, monotonically_increasing_id
import sys
import os



from pyspark import Row


from datetime import datetime,timedelta

from fbprophet import Prophet
import pandas as pd
import numpy as np

In [65]:
def dfZipWithIndex(df, offset=1, colName="rowId"):
    '''
    Enumerate DataFrame rows in their native order, like ``rdd.zipWithIndex()``,
    but on a DataFrame, preserving the original schema.

    The index column is prepended, so the result schema is
    ``[colName] + df.schema.fields``.

    :param df: source Spark DataFrame
    :param offset: value added to the 0-based index produced by ``zipWithIndex()``
                   (default 1, so the first row gets index 1)
    :param colName: name of the new index column (LongType, nullable)
    :return: a new DataFrame with the index column in front of the original columns
    '''
    new_schema = StructType(
        [StructField(colName, LongType(), True)]   # new field in front
        + df.schema.fields                         # previous schema
    )

    # zipWithIndex yields (Row, index) pairs; flatten each into [index + offset, *row].
    indexed_rdd = df.rdd.zipWithIndex().map(
        lambda row_and_idx: [row_and_idx[1] + offset] + list(row_and_idx[0])
    )

    # Resolve the session explicitly instead of relying on a notebook-global
    # `spark` that may not be defined yet when this cell runs.
    session = SparkSession.builder.getOrCreate()
    return session.createDataFrame(indexed_rdd, new_schema)

In [66]:
# Output schema of the forecasting pandas UDF: one row per forecast date.
schema = StructType(
    [
        StructField("ds", DateType(), nullable=True),      # forecast date
        StructField("yhat", DoubleType(), nullable=True),  # Prophet point forecast
    ]
)

In [67]:
@pandas_udf(schema, PandasUDFType.GROUPED_MAP)
def fit_pandas_udf(df):
    """
    Fit a Prophet model on the training rows of a group and forecast the rest.

    Rows with ``id <= cutoff`` are used for training; rows with ``id > cutoff``
    are forecast.

    Expected input layout (by position, since the date/value column names depend
    on the input file):
        0: ``id``   - row index used for the train/test split
        1: date     - renamed to ``ds``
        2: value    - renamed to ``y`` (may be null for rows to forecast)

    NOTE(review): ``cutoff`` and ``schema`` are read as notebook globals.
    ``cutoff`` must be assigned before the Spark job that uses this UDF is
    executed (e.g. before ``.show()``), otherwise a NameError is raised on the
    executor.

    :param df: pandas DataFrame holding the whole group (train + test rows)
    :return: pandas DataFrame with columns ``schema.fieldNames()`` (``ds``, ``yhat``),
             one row per ``id > cutoff`` row; empty if there is nothing to forecast
    """

    def train_fitted_prophet(df, cutoff):
        names = df.columns

        # Normalise once for both splits: Prophet expects `ds` (datetime) and `y`.
        prepared = (df
                    .rename(columns={names[1]: 'ds', names[2]: 'y'})
                    .assign(ds=lambda x: pd.to_datetime(x['ds']))
                    .sort_values('ds'))

        ts_train = prepared.query('id <= @cutoff')[['ds', 'y']]
        ts_test = prepared.query('id > @cutoff')[['ds']]

        # Nothing to forecast: return an empty frame instead of letting
        # Prophet.predict fail on an empty input.
        if ts_test.empty:
            return pd.DataFrame(columns=schema.fieldNames())

        # NOTE(review): if the input is aggregated to one row per day, daily
        # seasonality carries no identifiable signal; kept to preserve results.
        m = Prophet(yearly_seasonality=True,
                    weekly_seasonality=True,
                    daily_seasonality=True)
        m.fit(ts_train)

        # predict() returns many component columns; keep only the forecast.
        ts_hat = m.predict(ts_test)[['ds', 'yhat']]

        return pd.DataFrame(ts_hat, columns=schema.fieldNames())

    return train_fitted_prophet(df, cutoff)

In [68]:
if __name__ == '__main__':
    # Split / horizon configuration.
    DATA_PATH = 'data_simulation.csv'
    TRAIN_FRACTION = 0.7   # share of the observed (historical) days used for training
    N_FUTURE_DAYS = 27     # days appended after the last observed date to forecast

    spark = (SparkSession
             .builder
             .appName("forecasting")
             .getOrCreate())

    raw = (spark
           .read
           .format("csv")
           .option('header', 'true')
           .option('inferSchema', 'true')
           .load(DATA_PATH))

    # Column 0 is the timestamp ("Datetime"), column 1 the measured value.
    # Use the value column's actual name instead of hard-coding "MW".
    value_col = raw.columns[1]
    quoted_value_col = "`" + value_col.replace("`", "``") + "`"

    # Aggregate to one row per calendar day: LEFT(Datetime, 10) -> 'yyyy-MM-dd'.
    raw.createOrReplaceTempView("data")
    data = (spark.sql(f"SELECT LEFT(Datetime,10) AS Datetime, {quoted_value_col} FROM data")
            .groupBy("Datetime")
            .mean(value_col)
            .sort(col('Datetime')))

    data_length = data.count()
    if data_length == 0:
        raise ValueError(f"no rows loaded from {DATA_PATH}")
    train_size = int(round(TRAIN_FRACTION * data_length))

    # Append the future days to predict; their value column is null.
    last_day = data.tail(1)[0]['Datetime']
    # periods includes last_day itself, which is skipped with [1:].
    future_days = pd.date_range(start=last_day, periods=N_FUTURE_DAYS + 1)[1:]
    sequence_days = list(future_days.strftime("%Y-%m-%d"))
    agg_field = data.schema.fields[1]
    future = (spark.createDataFrame(sequence_days, StringType())
              .toDF("Datetime")
              .withColumn(agg_field.name, lit(None).cast(agg_field.dataType)))

    df = dfZipWithIndex(data.union(future).sort(col('Datetime')), colName="id")

    # Read as a global by fit_pandas_udf: rows with id <= cutoff are training rows.
    cutoff = train_size

    # Apply forecasting on the whole dataset as a single group.
    global_predictions = (df
                          .groupBy()
                          .apply(fit_pandas_udf))



In [69]:
# Trigger the forecasting job and display the first forecast rows.
global_predictions.show(n=20, truncate=True)

Index(['ds', 'y'], dtype='object')                                  (0 + 1) / 1]
Index(['ds'], dtype='object')

Initial log joint probability = -123.798
Iteration  1. Log joint probability =    79.4266. Improved by 203.225.
Iteration  2. Log joint probability =    121.029. Improved by 41.6029.
Iteration  3. Log joint probability =    181.698. Improved by 60.6681.
Iteration  4. Log joint probability =    223.977. Improved by 42.2797.
Iteration  5. Log joint probability =    234.288. Improved by 10.3106.
Iteration  6. Log joint probability =    237.122. Improved by 2.83383.
Iteration  7. Log joint probability =    238.937. Improved by 1.81497.
Iteration  8. Log joint probability =    239.109. Improved by 0.171925.
Iteration  9. Log joint probability =    239.451. Improved by 0.342607.
Iteration 10. Log joint probability =    239.646. Improved by 0.194713.
Iteration 11. Log joint probability =     239.75. Improved by 0.10374.
Iteration 12. Log joint probability =    239.915. Improved by 0

+----------+------------------+
|        ds|              yhat|
+----------+------------------+
|2004-12-10|12906.682380411503|
|2004-12-11|11477.671421249173|
|2004-12-12| 10609.99888017381|
|2004-12-13|11979.518964026262|
|2004-12-14| 11972.48772068324|
|2004-12-15| 12050.61727188952|
|2004-12-16|12395.886262198313|
|2004-12-17|13307.344100721153|
|2004-12-18|13714.029261295087|
|2004-12-19|15374.204216504259|
|2004-12-20|  20050.6489498212|
|2004-12-21| 24206.56112400464|
|2004-12-22|29366.773784125533|
|2004-12-23| 35758.75923743601|
|2004-12-24| 43705.78250977809|
|2004-12-25|52136.454855189855|
|2004-12-26| 62781.02057443888|
|2004-12-27| 77344.24609119965|
|2004-12-28| 92199.57843566514|
|2004-12-29| 108749.3909311865|
+----------+------------------+
only showing top 20 rows



                                                                                