In [1]:
!pip install pyspark



In [2]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, to_date, to_timestamp
from pyspark.sql import functions

In [None]:
# Start (or reuse) the Spark session used throughout the notebook.
my_spark = (
    SparkSession.builder
    .appName('SalesForecast')
    .getOrCreate()
)

# Read the retail data: first row is the header, column types are inferred by Spark.
# NOTE(review): the pandas cell further down reads 'Online Retail.csv'; confirm that
# 'Online Retail' (no extension) is really the intended path here.
sales_data = my_spark.read.csv('Online Retail', inferSchema=True, header=True)
sales_data.show()

In [None]:
# Reduce InvoiceDate to a plain date (timestamp -> date) so rows can be grouped and
# filtered per day. The result must overwrite 'InvoiceDate' itself: the later cells
# group and filter on 'InvoiceDate', whereas the previous target name 'InvoiceData'
# was a typo that created an unused column and left 'InvoiceDate' unconverted.
# NOTE(review): assumes the raw values parse with Spark's default timestamp format — confirm.
sales_data = sales_data.withColumn('InvoiceDate', to_date(to_timestamp(col('InvoiceDate'))))

In [None]:
# Aggregate to one row per country / product / invoice date (plus the calendar columns),
# summing quantity and unit price. With no aliases given, Spark names the outputs
# 'sum(Quantity)' and 'sum(UnitPrice)'.
daily_sales_data = (
    sales_data
    .groupby('Country', 'StockCode', 'InvoiceDate', 'Year', 'Month', 'Week', 'DayofWeek')
    .agg(
        functions.sum('Quantity'),
        functions.sum('UnitPrice'),  # was `function.sum`, an undefined name (NameError)
    )
)

In [None]:
import pandas as pd

# Load the CSV into pandas for quick exploration.
df = pd.read_csv('Online Retail.csv')

# `info` is a method: the bare attribute `df.info` only echoed the bound method
# instead of printing the column/dtype summary.
df.info()

# Rows per year. Kept as the last expression of the cell so the notebook displays it
# (previously the result was computed mid-cell and discarded).
df['Year'].value_counts()

In [None]:
# Last date (inclusive) that belongs to the training set; later dates are held out for testing.
split_date_train_test: str = "2011-06-30"

In [None]:
# Training set: every row dated on or before the split date.
# The cutoff variable is `split_date_train_test`; the previously used
# `split_data_train_test` is not defined anywhere and raised a NameError.
train_data = sales_data.filter(col('InvoiceDate') <= split_date_train_test)

In [None]:
# Test set: every row dated after the split date.
# Two fixes: the result is stored in `test_data` (it previously overwrote `train_data`,
# leaving `test_data` undefined for the prediction step), and the cutoff variable name
# is spelled `split_date_train_test` as defined above.
test_data = sales_data.filter(col('InvoiceDate') > split_date_train_test)

In [None]:
from pyspark.ml.feature import StringIndexer, VectorAssembler

In [None]:
# Encode the categorical columns as numeric indices. handleInvalid='keep' puts labels
# that were not seen while fitting into an extra index instead of raising at transform time.
country_indexer = StringIndexer(inputCol='Country', outputCol='CountryIndexer').setHandleInvalid('keep')

# The feature list ('StockCodeIndexer') and the pipeline stages (`stock_code_indexer`)
# both rely on a StockCode indexer, which was never defined.
stock_code_indexer = StringIndexer(inputCol='StockCode', outputCol='StockCodeIndexer').setHandleInvalid('keep')

In [None]:
# Input columns for the VectorAssembler, in the order they appear in the feature vector.
features_cols = [
    'CountryIndexer',
    'StockCodeIndexer',
    'Month',
    'Year',
    'DayofWeek',
    'Day',
    'Week',
]

In [None]:
# Combine the individual feature columns into the single vector column 'Features'.
assembler = VectorAssembler(
    outputCol='Features',
    inputCols=features_cols,
)

In [None]:
from pyspark.ml import Pipeline  # fixed misspelling: `Pipline` does not exist in pyspark.ml
from pyspark.ml.regression import RandomForestRegressor

In [None]:
# Random forest regressor that predicts 'Quantity' from the assembled 'Features' vector.
# NOTE(review): maxBins=4000 is presumably raised so the indexed categorical columns fit
# (maxBins must be at least the number of categories of any categorical feature) — confirm
# against the actual number of distinct stock codes.
rf = RandomForestRegressor(
    featuresCol='Features',
    labelCol='Quantity',
    maxBins=4000,
    seed=42,  # explicit seed so repeated runs of the notebook are reproducible
)

In [None]:
# Chain preprocessing and model: the two indexers, then the feature assembler,
# then the random forest regressor.
pipeline = Pipeline(stages=[country_indexer, stock_code_indexer, assembler, rf])

In [None]:
# Fit all pipeline stages on the training rows; the result is the fitted pipeline model.
model = pipeline.fit(dataset=train_data)

In [None]:
# Score the held-out rows and make sure the predictions are doubles.
# `rf` sets no predictionCol, so its output lands in the default 'prediction' column.
# Refer to it by that exact name rather than 'Prediction', which only resolved through
# Spark's case-insensitive column matching (and breaks if case sensitivity is enabled).
test_predictions = (
    model.transform(test_data)
    .withColumn('prediction', col('prediction').cast('double'))
)

In [None]:
from pyspark.ml.evaluation import RegressionEvaluator

In [None]:
# Build the evaluator before using it: `mae` was previously read before anything had
# been assigned to it (NameError), and the imported RegressionEvaluator was never used.
mae_evaluator = RegressionEvaluator(
    labelCol='Quantity',         # same label column the regressor was trained on
    predictionCol='prediction',  # default output column of the regressor
    metricName='mae',            # mean absolute error
)
mae = mae_evaluator.evaluate(test_predictions)
print('Mean absolute error:', mae)

In [None]:
# Narrow the pandas frame to January 2011, then count the rows per day of that month.
df = df.loc[df['Year'].eq(2011) & df['Month'].eq(1)]
df['Day'].value_counts()