In [57]:
# Stop a session left over from an earlier run so the next cell starts clean.
# Guarded: on a fresh kernel `spark` does not exist yet, and an unconditional
# spark.stop() would raise NameError and break "Restart & Run All".
if 'spark' in globals():
    spark.stop()

In [58]:
from pyspark.sql import SparkSession
from pyspark.sql.window import Window
from pyspark.sql.functions import *
from pyspark.sql.types import *
# Qualified alias, imported after the wildcard imports so nothing they export
# can rebind it; lets cells say F.sum instead of a builtin-shadowing bare `sum`.
from pyspark.sql import functions as F

# Create (or reuse) the SparkSession for this notebook via the documented
# `SparkSession.builder` entry point.
# NOTE(review): 'spark.sql.session.logLevel' does not look like a documented
# Spark configuration key — confirm it has any effect; the setLogLevel call
# below is what this cell relies on to quieten the logs.
spark = (
    SparkSession.builder
    .appName('window_funnctions')
    .config('spark.sql.session.logLevel', 'ERROR')
    .getOrCreate()
)

# Only log messages at ERROR level for this SparkContext.
spark.sparkContext.setLogLevel('ERROR')

In [59]:
# Sample sales rows as (category, sale_date, sale_amt) tuples, assembled
# column-by-column; row i is the i-th entry of each column below.
sales_data = list(zip(
    # category
    (1, 2, 2, 1, 2, 1, 2, 2, 1, 2),
    # sale_date (ISO yyyy-MM-dd strings)
    (
        '2024-01-01', '2024-01-02', '2024-01-03', '2024-01-04', '2024-01-05',
        '2024-01-06', '2024-01-07', '2024-01-08', '2024-01-09', '2024-01-10',
    ),
    # sale_amt
    (
        1500.50, 2300.75, 1800.00, 2200.25, 1700.80,
        2100.90, 2500.00, 1900.60, 2400.40, 2000.00,
    ),
))

# Explicit schema for the sample rows; every column is declared non-nullable.
schema = StructType([
    StructField(field_name, field_type, False)
    for field_name, field_type in (
        ('category', IntegerType()),
        ('sale_date', StringType()),
        ('sale_amt', DoubleType()),
    )
])

# Build the DataFrame from the in-memory rows using the explicit schema,
# then print it without truncating cell values.
sales_df = spark.createDataFrame(data=sales_data, schema=schema)
sales_df.show(truncate=False)

+--------+----------+--------+
|category|sale_date |sale_amt|
+--------+----------+--------+
|1       |2024-01-01|1500.5  |
|2       |2024-01-02|2300.75 |
|2       |2024-01-03|1800.0  |
|1       |2024-01-04|2200.25 |
|2       |2024-01-05|1700.8  |
|1       |2024-01-06|2100.9  |
|2       |2024-01-07|2500.0  |
|2       |2024-01-08|1900.6  |
|1       |2024-01-09|2400.4  |
|2       |2024-01-10|2000.0  |
+--------+----------+--------+



In [60]:
# Add a date-typed copy of the string `sale_date` column, parsed with the
# yyyy-MM-dd pattern, so later cells can order by a real date.
sales_df = sales_df.withColumn(
    'sale_date_v2',
    to_date(col('sale_date'), 'yyyy-MM-dd'),
)
sales_df.show()

+--------+----------+--------+------------+
|category| sale_date|sale_amt|sale_date_v2|
+--------+----------+--------+------------+
|       1|2024-01-01|  1500.5|  2024-01-01|
|       2|2024-01-02| 2300.75|  2024-01-02|
|       2|2024-01-03|  1800.0|  2024-01-03|
|       1|2024-01-04| 2200.25|  2024-01-04|
|       2|2024-01-05|  1700.8|  2024-01-05|
|       1|2024-01-06|  2100.9|  2024-01-06|
|       2|2024-01-07|  2500.0|  2024-01-07|
|       2|2024-01-08|  1900.6|  2024-01-08|
|       1|2024-01-09|  2400.4|  2024-01-09|
|       2|2024-01-10|  2000.0|  2024-01-10|
+--------+----------+--------+------------+



In [61]:
# Row-based frame per category, ordered by sale date: the current row plus
# the two rows immediately before it (fewer at the start of a partition).
windowSpec = (
    Window
    .partitionBy('category')
    .orderBy('sale_date_v2')
    .rowsBetween(-2, Window.currentRow)
)

In [62]:
# Rolling total of sale_amt over the frame defined by `windowSpec`.
# The aggregate is referenced as F.sum so this cell does not depend on the
# wildcard import having replaced Python's builtin `sum`, which cannot be
# applied to a column name and has no `.over`.
sales_df.withColumn('rolling_sum', F.sum('sale_amt').over(windowSpec)).show()

+--------+----------+--------+------------+-----------------+
|category| sale_date|sale_amt|sale_date_v2|      rolling_sum|
+--------+----------+--------+------------+-----------------+
|       1|2024-01-01|  1500.5|  2024-01-01|           1500.5|
|       1|2024-01-04| 2200.25|  2024-01-04|          3700.75|
|       1|2024-01-06|  2100.9|  2024-01-06|          5801.65|
|       1|2024-01-09|  2400.4|  2024-01-09|6701.549999999999|
|       2|2024-01-02| 2300.75|  2024-01-02|          2300.75|
|       2|2024-01-03|  1800.0|  2024-01-03|          4100.75|
|       2|2024-01-05|  1700.8|  2024-01-05|          5801.55|
|       2|2024-01-07|  2500.0|  2024-01-07|           6000.8|
|       2|2024-01-08|  1900.6|  2024-01-08|           6101.4|
|       2|2024-01-10|  2000.0|  2024-01-10|           6400.6|
+--------+----------+--------+------------+-----------------+



In [45]:
# Sanity check that the DataFrame has at least one row. head(1) returns a
# list of at most one Row, so an empty DataFrame gives a falsy [] without
# counting every row the way count() does.
assert sales_df.head(1), 'Too Small Dataset'

In [63]:
# Print only the physical plan for sales_df (the non-extended form).
# NOTE(review): this explains sales_df itself, not the windowed rolling_sum
# query, so no Window operator appears in the plan — confirm that is intended.
sales_df.explain(extended=False)

== Physical Plan ==
*(1) Project [category#659, sale_date#660, sale_amt#661, cast(gettimestamp(sale_date#660, yyyy-MM-dd, TimestampType, Some(Asia/Calcutta), false) as date) AS sale_date_v2#678]
+- *(1) Scan ExistingRDD[category#659,sale_date#660,sale_amt#661]


