In [11]:
from pyspark.sql import Window
from pyspark.sql.functions import col, avg, lag
from spark_session import get_spark_session

In [12]:
# Obtain the Spark session shared by the rest of the notebook from the
# project's get_spark_session helper; "Test Jupyter" is the name passed to it.
spark_session = get_spark_session("Test Jupyter")

In [13]:
# Load data.csv through the session's read_dataframe helper.
# NOTE(review): header/inferSchema presumably mirror Spark's CSV reader options
# (first row as column names, column types inferred) -- confirm against the helper.
df = spark_session.read_dataframe(
    "data.csv",
    header=True,
    inferSchema=True,
)

# Preview the loaded rows (columns: year, quarter, revenue_rate).
df.show()

+----+-------+------------+
|year|quarter|revenue_rate|
+----+-------+------------+
|2023|     Q1|        0.15|
|2023|     Q2|        0.17|
|2023|     Q3|        0.19|
|2023|     Q4|        0.22|
|2024|     Q1|        0.23|
|2024|     Q2|        0.21|
|2024|     Q3|        0.24|
+----+-------+------------+



In [14]:
# Per-year window: rows are grouped by year and sorted by the quarter label.
# Every metric below is therefore computed independently inside each year.
w = Window.partitionBy("year").orderBy("quarter")

df = (
    df
    # 1. rate_change: current quarter's revenue_rate minus the previous
    #    quarter's within the same year. The first quarter of each year has
    #    no previous row in its partition, so lag() yields NULL there.
    #    NOTE(review): this means Q1 is never compared with the prior year's
    #    Q4 -- confirm that is the intended behaviour.
    .withColumn(
        "rate_change",
        col("revenue_rate") - lag("revenue_rate", 1).over(w),
    )
    # 2. cumulative_avg: running mean of revenue_rate from the first quarter
    #    of the year up to and including the current row.
    .withColumn(
        "cumulative_avg",
        avg("revenue_rate").over(
            w.rowsBetween(Window.unboundedPreceding, Window.currentRow)
        ),
    )
    # 3. moving_avg: mean over a two-row frame (previous row and current row).
    #    On the first row of a year the frame contains only that row.
    .withColumn(
        "moving_avg",
        avg("revenue_rate").over(w.rowsBetween(-1, Window.currentRow)),
    )
)

# Show the frame with the three derived columns appended.
df.show()

+----+-------+------------+--------------------+-------------------+-------------------+
|year|quarter|revenue_rate|         rate_change|     cumulative_avg|         moving_avg|
+----+-------+------------+--------------------+-------------------+-------------------+
|2023|     Q1|        0.15|                NULL|               0.15|               0.15|
|2023|     Q2|        0.17|0.020000000000000018|               0.16|               0.16|
|2023|     Q3|        0.19| 0.01999999999999999|               0.17|               0.18|
|2023|     Q4|        0.22|                0.03|             0.1825|0.20500000000000002|
|2024|     Q1|        0.23|                NULL|               0.23|               0.23|
|2024|     Q2|        0.21|-0.02000000000000...|               0.22|               0.22|
|2024|     Q3|        0.24|                0.03|0.22666666666666666|0.22499999999999998|
+----+-------+------------+--------------------+-------------------+-------------------+



In [15]:
# Store output: persist the enriched DataFrame (original columns plus
# rate_change, cumulative_avg and moving_avg) in CSV format.
# NOTE(review): write_dataframe is not part of the stock PySpark DataFrame API;
# presumably it is provided by the project's spark_session module. No destination
# path is passed here, so the output location is decided by that helper -- confirm.
df.write_dataframe(format="csv")

In [16]:
# Stop the Spark session created earlier in the notebook now that the output
# has been stored.
spark_session.stop()

END