In [0]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, sum as _sum, to_date, count as _count, year, month

In [0]:
# ----------------------------
# 1. Read Data into a DataFrame for Processing
# ----------------------------

df_transactions = spark.table('bank_transactions')
df_branches = spark.table('branches')

print("=== Raw Transactions ===")
df_transactions.show(5)

print("=== Branch Info ===")
df_branches.show(5)

=== Raw Transactions ===
+--------------+----------+----------+--------+----------------+-------+--------+------------+---------+------------------+
|transaction_id|      date|account_id|  branch|transaction_type| amount|currency|counterparty|   status|      product_type|
+--------------+----------+----------+--------+----------------+-------+--------+------------+---------+------------------+
|          1001|2026-01-01|   ACC1001|New York|         Deposit| 500000|     USD|           -|Completed|   Commercial Loan|
|          1002|2026-01-01|   ACC2001|  London|      Withdrawal| 200000|     GBP|           -|Completed|Investment Account|
|          1003|2026-01-02|   ACC1001|New York|        Transfer| 100000|     USD|     ACC3001|  Pending|   Commercial Loan|
|          1004|2026-01-03|   ACC3001|   Tokyo|         Deposit|3000000|     JPY|           -|Completed|Investment Account|
|          1005|2026-01-03|   ACC2001|  London|        Transfer| 150000|     GBP|     ACC1001|Completed|Inv

In [0]:
# ----------------------------
# 2. Filter completed transactions
# ----------------------------
df_transactions = df_transactions.filter(col("status") == "Completed")

In [0]:
df_transactions.show()

+--------------+----------+----------+--------+----------------+-------+--------+------------+---------+------------------+
|transaction_id|      date|account_id|  branch|transaction_type| amount|currency|counterparty|   status|      product_type|
+--------------+----------+----------+--------+----------------+-------+--------+------------+---------+------------------+
|          1001|2026-01-01|   ACC1001|New York|         Deposit| 500000|     USD|           -|Completed|   Commercial Loan|
|          1002|2026-01-01|   ACC2001|  London|      Withdrawal| 200000|     GBP|           -|Completed|Investment Account|
|          1004|2026-01-03|   ACC3001|   Tokyo|         Deposit|3000000|     JPY|           -|Completed|Investment Account|
|          1005|2026-01-03|   ACC2001|  London|        Transfer| 150000|     GBP|     ACC1001|Completed|Investment Account|
+--------------+----------+----------+--------+----------------+-------+--------+------------+---------+------------------+



In [0]:
# ----------------------------
# 3. Rename date to txn date
# ----------------------------
df_transactions_date = df_transactions.withColumnRenamed("date", "txn_date")

In [0]:
df_transactions_date.show()

+--------------+----------+----------+--------+----------------+-------+--------+------------+---------+------------------+
|transaction_id|  txn_date|account_id|  branch|transaction_type| amount|currency|counterparty|   status|      product_type|
+--------------+----------+----------+--------+----------------+-------+--------+------------+---------+------------------+
|          1001|2026-01-01|   ACC1001|New York|         Deposit| 500000|     USD|           -|Completed|   Commercial Loan|
|          1002|2026-01-01|   ACC2001|  London|      Withdrawal| 200000|     GBP|           -|Completed|Investment Account|
|          1004|2026-01-03|   ACC3001|   Tokyo|         Deposit|3000000|     JPY|           -|Completed|Investment Account|
|          1005|2026-01-03|   ACC2001|  London|        Transfer| 150000|     GBP|     ACC1001|Completed|Investment Account|
+--------------+----------+----------+--------+----------------+-------+--------+------------+---------+------------------+



In [0]:
# ----------------------------
# 4. Join with branch info
# ----------------------------
df_joined = df_transactions_date.join(
    df_branches,
    df_transactions["branch"] == df_branches["branch_name"],
    "inner"
)

In [0]:
# Drop duplicate columns
df_joined = df_joined.drop("branch")
df_joined.show()

+--------------+----------+----------+----------------+-------+--------+------------+---------+------------------+---------+-----------+--------+-------+------------+--------------+----------+-----------+----------+
|transaction_id|  txn_date|account_id|transaction_type| amount|currency|counterparty|   status|      product_type|branch_id|branch_name|    city|country|      region|       manager| open_date|branch_type|assets_usd|
+--------------+----------+----------+----------------+-------+--------+------------+---------+------------------+---------+-----------+--------+-------+------------+--------------+----------+-----------+----------+
|          1001|2026-01-01|   ACC1001|         Deposit| 500000|     USD|           -|Completed|   Commercial Loan|     B001|   New York|New York|    USA|  North East| Alice Johnson|2005-03-15| Commercial|   1200000|
|          1005|2026-01-03|   ACC2001|        Transfer| 150000|     GBP|     ACC1001|Completed|Investment Account|     B003|     London|

In [0]:
#5. Data Quality Checks, segregation of bad data and reporting and alerting if any bad data encountered
# Valid records
df_valid = df_joined.filter(
    (col("branch_name").isNotNull()) &
    (col("txn_date").isNotNull()) &
    (col("amount") > 0)
)

# Invalid records (quarantine)
df_invalid = df_joined.filter(
    (col("branch_name").isNull()) |
    (col("txn_date").isNull()) |
    (col("amount") <= 0)
)

if df_invalid.count() > 0:
    print("⚠️ Warning: Some records were quarantined due to invalid data")
    #TODO - Send alert to New Relic or other monitoring system
    df_invalid.write.mode("overwrite").saveAsTable("invalid_transactions_table")


In [0]:
# ----------------------------
# 6. Aggregate totals per branch & product
# ----------------------------
df_aggregated = df_valid.groupBy("branch_name", "product_type") \
                         .agg(_sum(col("amount").cast("double")).alias("total_amount"))

print("=== Aggregated Totals ===")
df_aggregated.show(5)

=== Aggregated Totals ===
+-----------+------------------+------------+
|branch_name|      product_type|total_amount|
+-----------+------------------+------------+
|      Tokyo|Investment Account|   3000000.0|
|   New York|   Commercial Loan|    500000.0|
|     London|Investment Account|    350000.0|
+-----------+------------------+------------+



In [0]:
# Save as a Databricks table
df_aggregated.write.mode("overwrite").option("mergeSchema", "true").saveAsTable("processed_transactions_table")


In [0]:
spark.sql("SELECT * FROM processed_transactions_table LIMIT 10").show()

+-----------+------------------+------------+--------------+--------+----------+----------------+------+--------+------------+------+---------+----+-------+------+-------+---------+-----------+----------+
|branch_name|      product_type|total_amount|transaction_id|txn_date|account_id|transaction_type|amount|currency|counterparty|status|branch_id|city|country|region|manager|open_date|branch_type|assets_usd|
+-----------+------------------+------------+--------------+--------+----------+----------------+------+--------+------------+------+---------+----+-------+------+-------+---------+-----------+----------+
|      Tokyo|Investment Account|   3000000.0|          NULL|    NULL|      NULL|            NULL|  NULL|    NULL|        NULL|  NULL|     NULL|NULL|   NULL|  NULL|   NULL|     NULL|       NULL|      NULL|
|   New York|   Commercial Loan|    500000.0|          NULL|    NULL|      NULL|            NULL|  NULL|    NULL|        NULL|  NULL|     NULL|NULL|   NULL|  NULL|   NULL|     NULL

In [0]:
# ================================
# Scaling for Large Datasets - Note this can be done before aggregation if data is very large(millions of rows per txn_day)
# ================================

# Repartition the DataFrame for better parallelism
# (useful when dataset grows to millions of rows)
df_repart = df_valid.repartition("txn_date", "branch_name")

In [0]:
df_aggregated_repart = df_repart.groupBy("txn_date","branch_name", "product_type") \
                         .agg(_sum(col("amount").cast("double")).alias("total_amount"))
df_aggregated_repart.show()
df_aggregated_repart.printSchema()

+----------+-----------+------------------+------------+
|  txn_date|branch_name|      product_type|total_amount|
+----------+-----------+------------------+------------+
|2026-01-03|     London|Investment Account|    150000.0|
|2026-01-01|   New York|   Commercial Loan|    500000.0|
|2026-01-03|      Tokyo|Investment Account|   3000000.0|
|2026-01-01|     London|Investment Account|    200000.0|
+----------+-----------+------------------+------------+

root
 |-- txn_date: date (nullable = true)
 |-- branch_name: string (nullable = true)
 |-- product_type: string (nullable = true)
 |-- total_amount: double (nullable = true)



In [0]:
# Write the table as a managed, partitioned table
# Partitioned by txn_date and branch_name for efficient queries
df_aggregated_repart.write.mode("overwrite").option("mergeSchema", "true").saveAsTable("processed_transactions_repart_table")

In [0]:
spark.sql("SELECT * FROM processed_transactions_repart_table LIMIT 10").show()

+----------+-----------+------------------+------------+
|  txn_date|branch_name|      product_type|total_amount|
+----------+-----------+------------------+------------+
|2026-01-03|     London|Investment Account|    150000.0|
|2026-01-01|   New York|   Commercial Loan|    500000.0|
|2026-01-03|      Tokyo|Investment Account|   3000000.0|
|2026-01-01|     London|Investment Account|    200000.0|
+----------+-----------+------------------+------------+



In [0]:
#Daily roll ups by txn_date, product_type
df_daily = df_repart.groupBy("txn_date", "product_type") \
                         .agg(_sum(col("amount").cast("double")).alias("total_amount"),
                              _count(col("transaction_id")).alias("txn_count")
                         )
df_daily.show()

df_daily_rollup = df_daily.orderBy("txn_date", "product_type")
df_daily_rollup.write.mode("overwrite") \
    .option("overwriteSchema", "true") \
    .saveAsTable("transactions_daily_rollups")

+----------+------------------+------------+---------+
|  txn_date|      product_type|total_amount|txn_count|
+----------+------------------+------------+---------+
|2026-01-03|Investment Account|   3150000.0|        2|
|2026-01-01|   Commercial Loan|    500000.0|        1|
|2026-01-01|Investment Account|    200000.0|        1|
+----------+------------------+------------+---------+



In [0]:
#Monthly Rollups
df_monthly_rollup = df_repart.withColumn("year", year("txn_date")) \
                             .withColumn("month", month("txn_date")) 
df_monthly_rollup = df_monthly_rollup.groupBy("year", "month", "product_type") \
                         .agg(_sum(col("amount").cast("double")).alias("total_amount"),
                              _count(col("transaction_id")).alias("txn_count")
                         ).orderBy("year", "month", "product_type")
df_monthly_rollup.show()
df_monthly_rollup.write.mode("overwrite").option("overwriteSchema", "true").saveAsTable("transactions_monthly_rollups")

+----+-----+------------------+------------+---------+
|year|month|      product_type|total_amount|txn_count|
+----+-----+------------------+------------+---------+
|2026|    1|   Commercial Loan|    500000.0|        1|
|2026|    1|Investment Account|   3350000.0|        3|
+----+-----+------------------+------------+---------+

