In [2]:
import os
from getpass import getpass

from pyspark.sql import SparkSession
from pyspark.sql import Window
from pyspark.sql import functions as F
from pyspark.sql.functions import *
from pyspark.sql.types import *

In [3]:
# Start (or reuse) a Spark session with the spark-excel connector on the
# classpath, so .xlsx files can be read via the "com.crealytics.spark.excel" source.
spark = (
    SparkSession.builder
    .appName("Cummulative Sum")
    .config("spark.jars.packages", "com.crealytics:spark-excel_2.12:0.13.5")
    .getOrCreate()
)

# Read the workbook: the first row is the header and column types are inferred
# (the schema printed below shows every numeric column arriving as double).
data = (
    spark.read
    .format("com.crealytics.spark.excel")
    .option("header", "true")
    .option("inferSchema", "true")
    .load("./dataset/cumsum.xlsx")
)

In [7]:
# Show the schema produced by inferSchema (column names, types, nullability)
# before aggregating.
data.printSchema()

root
 |-- month_bs_id: double (nullable = true)
 |-- account_id: double (nullable = true)
 |-- txn_amt: double (nullable = true)
 |-- txn_cnt: double (nullable = true)



In [16]:
# Monthly totals per account: one row per (account_id, month_bs_id).
# F.sum is used explicitly; the star import of pyspark.sql.functions shadows
# Python's builtin sum, which makes a bare `sum(...)` easy to misread.
df = (
    data
    .groupBy("account_id", "month_bs_id")
    .agg(
        F.sum("txn_amt").alias("total_txn_amt"),
        F.sum("txn_cnt").alias("total_txn_cnt"),
    )
)

# Sort only for display: the window in the next step re-partitions by
# account_id, so an ordering baked into `df` would not survive there and
# would only add an extra sort.
df.orderBy(F.asc("account_id"), F.asc("month_bs_id")).show()

+----------+-----------+-------------+-------------+
|account_id|month_bs_id|total_txn_amt|total_txn_cnt|
+----------+-----------+-------------+-------------+
|      15.0|     9811.0|       3000.0|          6.0|
|      15.0|     9812.0|       2000.0|          4.0|
|      15.0|     9813.0|        500.0|          1.0|
|      15.0|     9814.0|        500.0|          1.0|
|      44.0|     9811.0|       8000.0|          8.0|
|      44.0|     9812.0|       1000.0|          1.0|
|      44.0|     9813.0|       1000.0|          1.0|
|      44.0|     9814.0|       1000.0|          1.0|
|      44.0|     9815.0|       1000.0|          1.0|
|     325.0|     9811.0|     192000.0|         16.0|
|     325.0|     9812.0|       2000.0|          1.0|
|     325.0|     9813.0|      42000.0|          5.0|
|     325.0|     9814.0|       2000.0|          1.0|
|     325.0|     9815.0|       2000.0|          1.0|
|     335.0|     9811.0|     240000.0|          8.0|
|     335.0|     9812.0|      30000.0|        

In [19]:
# Running totals per account in month order. The frame is spelled out as
# ROWS UNBOUNDED PRECEDING .. CURRENT ROW. With only orderBy, Spark defaults to
# a RANGE frame, which would merge rows sharing a month_bs_id into one total.
# Months are unique per account after the groupBy above, but the explicit
# frame keeps the running-sum semantics correct if that ever changes.
window_spec = (
    Window.partitionBy("account_id")
    .orderBy("month_bs_id")
    .rowsBetween(Window.unboundedPreceding, Window.currentRow)
)

df = (
    df
    .withColumn("cumsum_amt", F.sum("total_txn_amt").over(window_spec))
    .withColumn("cumsum_cnt", F.sum("total_txn_cnt").over(window_spec))
)

# The window shuffles by account_id, so the row order of `df` is not
# guaranteed; sort explicitly for a stable display.
df.orderBy("account_id", "month_bs_id").show()

+----------+-----------+-------------+-------------+----------+----------+
|account_id|month_bs_id|total_txn_amt|total_txn_cnt|cumsum_amt|cumsum_cnt|
+----------+-----------+-------------+-------------+----------+----------+
|      15.0|     9811.0|       3000.0|          6.0|    3000.0|       6.0|
|      15.0|     9812.0|       2000.0|          4.0|    5000.0|      10.0|
|      15.0|     9813.0|        500.0|          1.0|    5500.0|      11.0|
|      15.0|     9814.0|        500.0|          1.0|    6000.0|      12.0|
|      44.0|     9811.0|       8000.0|          8.0|    8000.0|       8.0|
|      44.0|     9812.0|       1000.0|          1.0|    9000.0|       9.0|
|      44.0|     9813.0|       1000.0|          1.0|   10000.0|      10.0|
|      44.0|     9814.0|       1000.0|          1.0|   11000.0|      11.0|
|      44.0|     9815.0|       1000.0|          1.0|   12000.0|      12.0|
|     325.0|     9811.0|     192000.0|         16.0|  192000.0|      16.0|
|     325.0|     9812.0| 

In [4]:
from pyspark.sql import SparkSession

def df_table(df, schema_name, table_name, mode="overwrite", host="localhost",
             user=None, password=None):
    '''
    Load a Spark DataFrame into a MySQL table over JDBC.

    Credentials are not hard-coded. They are taken from the explicit
    arguments, otherwise from the MYSQL_USER / MYSQL_PASSWORD environment
    variables, otherwise (password only) from an interactive prompt.

    Parameters
    ----------
    df : pyspark.sql.DataFrame
        Data to write.
    schema_name : str
        MySQL database (schema) to connect to.
    table_name : str
        Target table name.
    mode : str, default "overwrite"
        Spark save mode passed to ``DataFrameWriter.jdbc``
        (e.g. "overwrite", "append", "ignore", "error").
    host : str, default "localhost"
        MySQL server host, optionally as ``host:port``.
    user : str, optional
        MySQL user; defaults to $MYSQL_USER, then "root".
    password : str, optional
        MySQL password; defaults to $MYSQL_PASSWORD, then a getpass prompt.
    '''
    if user is None:
        user = os.environ.get("MYSQL_USER", "root")
    if password is None:
        password = os.environ.get("MYSQL_PASSWORD")
    if password is None:
        # Prompt rather than embedding the secret in the notebook source/outputs.
        password = getpass("MySQL password: ")

    url = f"jdbc:mysql://{host}/{schema_name}"
    properties = {
        "user": user,
        "password": password,
        "driver": "com.mysql.cj.jdbc.Driver",
    }

    df.write.jdbc(url=url, table=table_name, mode=mode, properties=properties)


In [5]:
# NOTE(review): this writes the raw Excel rows (`data`) to mydb.cumsum, not the
# `df` frame carrying cumsum_amt / cumsum_cnt computed earlier — confirm which
# one is intended for this table.
df_table(df=data, schema_name="mydb", table_name="cumsum")