In [2]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.functions import sum
from pyspark.sql.window import Window

# Reuse the active Spark session if one exists, otherwise start a new one.
spark = SparkSession.builder.getOrCreate()

# Column names, in the same order as the fields of each tuple in `data`.
columns = ["region", "date", "sales"]

# One row per (region, date) with that day's sales figure.
data = [
    ("North", "2024-01-01", 100),
    ("North", "2024-01-02", 200),
    ("North", "2024-01-03", 300),
    ("South", "2024-01-01", 50),
    ("South", "2024-01-02", 150),
]

df = spark.createDataFrame(data=data, schema=columns)
df.show()


+------+----------+-----+
|region|      date|sales|
+------+----------+-----+
| North|2024-01-01|  100|
| North|2024-01-02|  200|
| North|2024-01-03|  300|
| South|2024-01-01|   50|
| South|2024-01-02|  150|
+------+----------+-----+



In [4]:
# Collapse the frame to one row per region holding the summed sales.
(
    df.groupBy("region")
    .agg({"sales": "sum"})
    .show()
)

+------+----------+
|region|sum(sales)|
+------+----------+
| North|       600|
| South|       200|
+------+----------+



In [5]:
# Window that groups rows by region; no ordering is specified, so an
# aggregate over it covers every row of the region.
window_spec = Window.partitionBy("region")

In [6]:
# Attach each region's total sales to every row. Unlike the groupBy above,
# the window aggregate keeps all input rows and only adds a column.
# The aggregate is referenced as F.sum (module alias) rather than the bare
# name `sum`, which hides Python's builtin sum() once it is imported from
# pyspark.sql.functions.
df.withColumn("region_total", F.sum("sales").over(window_spec)).show()

+------+----------+-----+------------+
|region|      date|sales|region_total|
+------+----------+-----+------------+
| North|2024-01-01|  100|         600|
| North|2024-01-02|  200|         600|
| North|2024-01-03|  300|         600|
| South|2024-01-01|   50|         200|
| South|2024-01-02|  150|         200|
+------+----------+-----+------------+

