In [7]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import max, date_sub
from pyspark.sql.functions import expr

In [2]:
spark = (
    SparkSession.builder
    .appName("M5-Split")
    .master("local[2]")
    .config("spark.driver.memory", "6g")
    .config("spark.executor.memory", "6g")
    .config("spark.sql.shuffle.partitions", "32")
    .getOrCreate()
)

In [3]:
df = spark.read.parquet("../dataset/cooked/2")


In [4]:
df.printSchema()
df.select("date").show(3)


root
 |-- date: date (nullable = true)
 |-- item_id: string (nullable = true)
 |-- store_id: string (nullable = true)
 |-- sales: integer (nullable = true)
 |-- price: double (nullable = true)
 |-- event: string (nullable = true)
 |-- weekday: string (nullable = true)
 |-- lag_7: integer (nullable = true)
 |-- lag_14: integer (nullable = true)
 |-- lag_28: integer (nullable = true)
 |-- rolling_mean_7: double (nullable = true)
 |-- rolling_mean_28: double (nullable = true)
 |-- price_change: double (nullable = true)
 |-- dayofweek: integer (nullable = true)
 |-- weekofyear: integer (nullable = true)
 |-- month: integer (nullable = true)
 |-- is_weekend: integer (nullable = true)
 |-- event_idx: double (nullable = true)

+----------+
|      date|
+----------+
|2014-05-04|
|2011-04-24|
|2015-02-07|
+----------+
only showing top 3 rows



In [9]:
cutoff_date = df.selectExpr("date_sub(max(date), 28)").collect()[0][0]

In [10]:
train_df = df.filter(df.date < cutoff_date)
test_df  = df.filter(df.date >= cutoff_date)


In [11]:
train_df.selectExpr("min(date)", "max(date)").show()
test_df.selectExpr("min(date)", "max(date)").show()


+----------+----------+
| min(date)| max(date)|
+----------+----------+
|2011-02-26|2016-03-26|
+----------+----------+

+----------+----------+
| min(date)| max(date)|
+----------+----------+
|2016-03-27|2016-04-24|
+----------+----------+



In [12]:
train_df.rdd.getNumPartitions()
test_df.rdd.getNumPartitions()


6

In [None]:
(
    train_df
    .repartition(4)
    .write
    .mode("overwrite")
    .parquet("../dataset/cooked/3/train")
)

(
    test_df
    .repartition(4)
    .write
    .mode("overwrite")
    .parquet("../dataset/cooked/3/test")
)
