<a href="https://colab.research.google.com/github/asupraja3/spark-mlops-lab/blob/main/pyspark.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install PySpark
!pip install pyspark



**1. Start the Session and Set Config**
* In most environments (such as Databricks or the `pyspark` shell), a `spark` session object is already available. If it is not, create one:


In [None]:
from pyspark.sql import SparkSession
# Standard way to get/create a SparkSession
spark = SparkSession.builder.appName("Pyspark Practice").getOrCreate()
# Use the legacy (Spark 2.x) date/time parser so older, more lenient format patterns keep working
spark.sql("SET spark.sql.legacy.timeParserPolicy=LEGACY")

DataFrame[key: string, value: string]

**2. DataFrame Creation Examples**

In [None]:
# From a Python list
data = [("Alice", 1, "2024-01-01"), ("Bob", 2, "2024-01-02")]
df_data = spark.createDataFrame(data, ["Name", "ID", "Date"])

# From a JSON file
# df_json = spark.read.json("/temp.JSON")

# From a CSV file
# df_csv = spark.read.csv("/temp.csv", header=True, inferSchema=True)

# From a Parquet file
# df_parquet = spark.read.parquet("/temp.pq")

**II. Schema: The Blueprint of Your Data**
* A schema defines the column names and their data types (e.g., String, Integer, Date). Spark can infer the schema, or you can define it explicitly.
* inferSchema=True: Spark scans the data to guess the types, which costs an extra pass over the input. Convenient for small datasets or exploration.
* StructType and StructField: Define the schema explicitly. Reliable for production ETL because it avoids misinterpretation, such as treating a column of numbers as strings.
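
To make the risk of guessed types concrete, here is a small plain-Python sketch. The `guess_type` helper is hypothetical and much simpler than Spark's real inference; it only illustrates how a column of digit strings, such as ZIP codes, can be guessed as integers and lose its leading zeros.

```python
# Hypothetical, simplified stand-in for schema inference (not Spark's algorithm).
def guess_type(values):
    """Guess a column type from string values: try int, then float, else str."""
    for caster in (int, float):
        try:
            for v in values:
                caster(v)
            return caster
        except ValueError:
            continue
    return str

zip_codes = ["02134", "10001", "94105"]

guessed = guess_type(zip_codes)               # digits only, so the guess is int
as_guessed = [guessed(v) for v in zip_codes]  # leading zero is lost
as_declared = [str(v) for v in zip_codes]     # an explicit string type keeps it

print(guessed.__name__, as_guessed[0], as_declared[0])
```

Declaring the column as a string up front, as an explicit schema does, avoids this kind of silent change.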



In [None]:
from pyspark.sql.types import StructType, StructField, StringType, IntegerType
# 1. Sample data
sample_data = [
    ("Alice", 25, "New York"),
    ("Bob", 30, "London"),
    ("Charlie", 25, "New York")
]

# 2. Explicitly define the schema (StructType)
defined_schema = StructType([StructField("name", StringType(), True),
                StructField("mployee_age", IntegerType(), True),
                StructField("city", StringType(), True)])

# 3. Create the DataFrame using the explicit schema
df = spark.createDataFrame(data=sample_data, schema=defined_schema)

df.printSchema()
df.show()

root
 |-- name: string (nullable = true)
 |-- mployee_age: integer (nullable = true)
 |-- city: string (nullable = true)

+-------+-----------+--------+
|   name|mployee_age|    city|
+-------+-----------+--------+
|  Alice|         25|New York|
|    Bob|         30|  London|
|Charlie|         25|New York|
+-------+-----------+--------+



**III. Core DataFrame Operations**  
These are the fundamental ways to manipulate columns and rows.

**1. Column Operations (Shape & Content)**  
Used to select, add, rename, or remove columns.

In [None]:
from pyspark.sql.functions import lit
# a. SELECT (Selecting and Renaming Columns)
df_select = df.select("city", df["name"].alias("emp_name"))
df_select.show(3)
# b. withColumn (Adding or Modifying a Column)
df1 = df.withColumn("city_NY", df.city == "New York")  # Boolean flag
df2 = df1.withColumn("const", lit(100))  # Add a column with a constant value
df2.show()
# c. drop (Removing Columns)
df3 = df2.drop("const")
df3.show()

+--------+--------+
|    city|emp_name|
+--------+--------+
|New York|   Alice|
|  London|     Bob|
|New York| Charlie|
+--------+--------+

+-------+-----------+--------+-------+-----+
|   name|mployee_age|    city|city_NY|const|
+-------+-----------+--------+-------+-----+
|  Alice|         25|New York|   true|  100|
|    Bob|         30|  London|  false|  100|
|Charlie|         25|New York|   true|  100|
+-------+-----------+--------+-------+-----+

+-------+-----------+--------+-------+
|   name|mployee_age|    city|city_NY|
+-------+-----------+--------+-------+
|  Alice|         25|New York|   true|
|    Bob|         30|  London|  false|
|Charlie|         25|New York|   true|
+-------+-----------+--------+-------+



In [None]:
df.select("name","mployee_age").show()

+-------+-----------+
|   name|mployee_age|
+-------+-----------+
|  Alice|         25|
|    Bob|         30|
|Charlie|         25|
+-------+-----------+



**2. Filtering Operations (Selecting Rows)**  
Used to keep only the rows that match a specified condition. filter() and where() are identical in Spark.

In [None]:
# Select rows where age is less than 30
df.filter(df.mployee_age < 30).show()

# Use 'where' for multi-condition logic (identical to filter)
df2 = df.where((df.mployee_age < 30) & (df['mployee_age'] >= 25))  # remember to wrap each condition in parentheses
df2.show()

# isin (Checking if a value is in a list)
cities = ['New York']
df.filter(df.city.isin(cities)).show()

+-------+-----------+--------+
|   name|mployee_age|    city|
+-------+-----------+--------+
|  Alice|         25|New York|
|Charlie|         25|New York|
+-------+-----------+--------+

+-------+-----------+--------+
|   name|mployee_age|    city|
+-------+-----------+--------+
|  Alice|         25|New York|
|Charlie|         25|New York|
+-------+-----------+--------+

+-------+-----------+--------+
|   name|mployee_age|    city|
+-------+-----------+--------+
|  Alice|         25|New York|
|Charlie|         25|New York|
+-------+-----------+--------+
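
The multi-condition filter above wraps each comparison in parentheses. The reason is Python operator precedence: `&` binds tighter than `<` and `>=`. The sketch below uses a plain integer in place of a column value (an assumption for illustration) to show how the expression is parsed without parentheses.

```python
# A plain Python integer stands in for a column value here.
age = 25

# Without parentheses, & is applied first, so this parses as the chained
# comparison: age < (30 & age) >= 25
without_parens = age < 30 & age >= 25

# With parentheses, each comparison is evaluated first, then combined.
with_parens = (age < 30) & (age >= 25)

print(without_parens, with_parens)
```

With Spark `Column` objects, the unparenthesized form typically fails with an error rather than returning a wrong answer, but the cause is the same.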



**3. Sorting Operations (Ordering Rows)**  
Used to order the final result set.

In [None]:
from pyspark.sql.functions import col

# a. orderBy (Primary method for sorting)
df.orderBy(col("city").asc(), col('mployee_age').desc()).show()  # note: asc() and desc() are Column methods, so they need parentheses
# b. sort (Alias for orderBy)
df.sort("name").show()

+-------+-----------+--------+
|   name|mployee_age|    city|
+-------+-----------+--------+
|    Bob|         30|  London|
|  Alice|         25|New York|
|Charlie|         25|New York|
+-------+-----------+--------+

+-------+-----------+--------+
|   name|mployee_age|    city|
+-------+-----------+--------+
|  Alice|         25|New York|
|    Bob|         30|  London|
|Charlie|         25|New York|
+-------+-----------+--------+



In [None]:
df.sort(df.name.asc()).show()

+-------+-----------+--------+
|   name|mployee_age|    city|
+-------+-----------+--------+
|  Alice|         25|New York|
|    Bob|         30|  London|
|Charlie|         25|New York|
+-------+-----------+--------+



**I. Aggregations (Summarizing Data) 📊**  
Aggregations summarize data across groups.  
**groupBy:** Groups rows based on one or more columns (e.g., grouping sales by region).  
**agg:** Applies summary functions (like sum, avg, count) to the grouped data.  
**pivot:** Turns the distinct values of one column into separate output columns (like a cross-tabulation table).  
**rollup:** Creates subtotals along a column hierarchy (e.g., total sales per Year and Quarter, then per Year, then a grand total).  
**cube:** Creates subtotals for every combination of the grouping columns, regardless of hierarchy (e.g., sales by Region, by Product, by Region-Product combination, plus a grand total).
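
The difference between rollup and cube comes down to which grouping sets are computed. The following plain-Python sketch (no Spark needed) reproduces that idea on four sample rows; the helper names `grouped_sums`, `rollup`, and `cube` are illustrative, and `None` plays the role of the NULL that marks a subtotal row.

```python
from itertools import combinations

rows = [("A", "X", 10), ("A", "Y", 20), ("B", "X", 15), ("B", "Y", 5)]

def grouped_sums(rows, keep):
    """Sum values grouped by the key positions in `keep`; other keys become None."""
    totals = {}
    for key1, key2, value in rows:
        keys = (key1, key2)
        group = tuple(k if i in keep else None for i, k in enumerate(keys))
        totals[group] = totals.get(group, 0) + value
    return totals

def rollup(rows):
    # Hierarchical grouping sets: (key1, key2), (key1), ()
    result = {}
    for keep in [(0, 1), (0,), ()]:
        result.update(grouped_sums(rows, keep))
    return result

def cube(rows):
    # Every subset of the grouping columns: (), (key1), (key2), (key1, key2)
    result = {}
    for size in range(3):
        for keep in combinations((0, 1), size):
            result.update(grouped_sums(rows, keep))
    return result

print(len(rollup(rows)), len(cube(rows)))
```

The cube result contains everything the rollup does, plus the per-`key2` subtotals that the hierarchy skips.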

In [None]:
from pyspark.sql import functions as F  # namespaced import avoids shadowing Python's built-in sum

# Sample DataFrame (DF)
data = [("A", "X", 10), ("A", "Y", 20), ("B", "X", 15), ("B", "Y", 5)]
df = spark.createDataFrame(data, ["key1", "key2", "value"])

# Simple GroupBy and Agg
df.groupBy('key1').agg(F.count('*').alias("total count"), F.sum('value').alias('total value')).show()

# Pivot Example (Move key2 values into columns)
df.groupBy('key1').pivot('key2').agg(F.sum('value')).show()

# Rollup Example
df.rollup('key1', 'key2').agg(F.sum('value')).sort("key1", "key2").show()

# Cube Example
df.cube('key1', 'key2').agg(F.sum('value')).sort("key1", "key2").show()

+----+-----------+-----------+
|key1|total count|total value|
+----+-----------+-----------+
|   A|          2|         30|
|   B|          2|         20|
+----+-----------+-----------+

+----+---+---+
|key1|  X|  Y|
+----+---+---+
|   B| 15|  5|
|   A| 10| 20|
+----+---+---+

+----+----+----------+
|key1|key2|sum(value)|
+----+----+----------+
|NULL|NULL|        50|
|   A|NULL|        30|
|   A|   X|        10|
|   A|   Y|        20|
|   B|NULL|        20|
|   B|   X|        15|
|   B|   Y|         5|
+----+----+----------+

+----+----+----------+
|key1|key2|sum(value)|
+----+----+----------+
|NULL|NULL|        50|
|NULL|   X|        25|
|NULL|   Y|        25|
|   A|NULL|        30|
|   A|   X|        10|
|   A|   Y|        20|
|   B|NULL|        20|
|   B|   X|        15|
|   B|   Y|         5|
+----+----+----------+



**II. Joins (Combining Data) 🔗**  
Joins combine rows from two DataFrames based on a shared key.  
### Spark Join Types

- **Type: inner**  
  **What it Does:** Keeps only rows that have matches in both DataFrames. (Default.)  
  **Result Size:** Smallest  

- **Type: left**  
  **What it Does:** Keeps all rows from the left DF; includes matching rows from the right (filling with nulls if no match).  
  **Result Size:** Size of Left DF (or larger)  

- **Type: right**  
  **What it Does:** Keeps all rows from the right DF; includes matching rows from the left (filling with nulls if no match).  
  **Result Size:** Size of Right DF (or larger)  

- **Type: outer**  
  **What it Does:** Keeps all rows from both DFs; fills missing columns with nulls.  
  **Result Size:** Largest  

- **Type: left_semi**  
  **What it Does:** Keeps only the matching rows from the left DF. Does not include any columns from the right DF.  
  **Result Size:** Size of Left DF (or smaller)  

- **Type: left_anti**  
  **What it Does:** Keeps only the rows from the left DF that do not have a match in the right DF (The unmatched rows).  
  **Result Size:** Size of Left DF (or smaller)  
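
As a rough mental model of these join types, here is a plain-Python sketch on small employee and department rows. It assumes the right side has unique keys (so a dict lookup is enough) and, as in SQL, treats a `None` key as matching nothing; the variable names are illustrative only.

```python
employees = [("Alice", 101, 201), ("Bob", 102, 201), ("Charlie", 103, 202), ("David", 104, None)]
departments = {201: "Sales", 203: "Marketing"}  # DeptID -> DeptName

def has_match(dept_id):
    # As in SQL, a null (None) key never matches anything.
    return dept_id is not None and dept_id in departments

# inner: only rows with a match, with columns from both sides
inner = [(name, departments[d]) for name, _, d in employees if has_match(d)]

# left: every left row, with None where the right side has no match
left = [(name, departments[d] if has_match(d) else None) for name, _, d in employees]

# left_semi: matching left rows, left columns only
left_semi = [name for name, _, d in employees if has_match(d)]

# left_anti: left rows without a match
left_anti = [name for name, _, d in employees if not has_match(d)]

print(inner)
print(left_anti)
```

When the right side contains duplicate keys, inner and left joins can return more rows than the left input, which is why the sizes above are stated as "or larger".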


**Self-Joins ➕**  
A Self-Join is simply a standard join (usually inner or left) where you join a DataFrame to itself. You do this to compare rows within the same table, often using aliases to distinguish the two sides.

*Example Use:* Finding employees who report to the same manager, or finding pairs of products with the same price.

In [None]:
# Sample DataFrames
employee_data = [("Alice", 101, 201), ("Bob", 102, 201), ("Charlie", 103, 202), ("David", 104, None)]
dept_data = [(201, "Sales"), (203, "Marketing")]

df_emp = spark.createDataFrame(employee_data, ["Name", "EmpID", "DeptID"])
df_dept = spark.createDataFrame(dept_data, ["DeptID", "DeptName"])

# Inner Join
df_emp.join(df_dept, 'DeptID', 'inner').show()
# Left Anti Join (Employees without a matching department)
df_emp.join(df_dept, 'DeptID', 'left_anti').show()

+------+-----+-----+--------+
|DeptID| Name|EmpID|DeptName|
+------+-----+-----+--------+
|   201|Alice|  101|   Sales|
|   201|  Bob|  102|   Sales|
+------+-----+-----+--------+

+------+-------+-----+
|DeptID|   Name|EmpID|
+------+-------+-----+
|   202|Charlie|  103|
|  NULL|  David|  104|
+------+-------+-----+



In [None]:
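from pyspark.sql.functions import col

# Self-Join Example: pair each employee with peers in the same department
df_emp_join = df_emp.alias("emp")
df_peer = df_emp.alias("peer")

# Refer to columns through the alias names; df_emp_join.DeptID and
# df_peer.DeptID would both resolve to the same underlying column.
df_self_join = df_emp_join.join(
    df_peer,
    col("emp.DeptID") == col("peer.DeptID"),
    "inner"
).select(
    col("emp.Name").alias("Employee"),
    col("peer.Name").alias("Peer")
).filter("Employee != Peer")  # Remove self-matches
print("--- Self Join (Employees with the same DeptID/Peer) ---")
df_self_join.show()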

### **Spark SQL 📝**  
Spark SQL allows you to use familiar SQL syntax directly on your DataFrames.

**createOrReplaceTempView(name):** This method registers a DataFrame as a temporary SQL view scoped to the current SparkSession. It's the bridge between the DataFrame API and the SQL API.

**spark.sql("..."):** This executes a SQL query string against the defined views and returns the result as a new DataFrame.

In [None]:
# Create a temporary view from the employee DataFrame
df_emp.createOrReplaceTempView('Employee')

# Run a SQL query against the view
spark.sql("SELECT Name FROM Employee WHERE EmpID = 103").show()  # EmpID is numeric, so compare against a number

+-------+
|   Name|
+-------+
|Charlie|
+-------+



**Caching Strategies 💾**

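Caching speeds up workloads that reuse the same data: Spark stores the intermediate result of an RDD/DataFrame in memory or on disk so that later actions read it back instead of recomputing it. Fault tolerance comes from lineage, not from caching; lost cached partitions are simply recomputed.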
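
The effect of caching on a lazily evaluated dataset can be imitated in a few lines of plain Python. `LazyDataset` below is a hypothetical toy class, not a Spark API; it only counts how often the underlying computation is re-run with and without a cache.

```python
class LazyDataset:
    """Toy stand-in for a lazily evaluated DataFrame (hypothetical, not a Spark API)."""

    def __init__(self, compute):
        self._compute = compute   # function that rebuilds the data from scratch
        self._cached = None
        self._use_cache = False
        self.compute_calls = 0    # how many times the computation was re-run

    def cache(self):
        self._use_cache = True    # lazy: nothing is stored until the next action
        return self

    def _materialize(self):
        if self._use_cache and self._cached is not None:
            return self._cached
        self.compute_calls += 1
        data = self._compute()
        if self._use_cache:
            self._cached = data
        return data

    def count(self):
        # An "action": forces the data to be produced
        return len(self._materialize())

uncached = LazyDataset(lambda: [x * x for x in range(1000)])
for _ in range(3):
    uncached.count()

cached = LazyDataset(lambda: [x * x for x in range(1000)]).cache()
for _ in range(3):
    cached.count()

print(uncached.compute_calls, cached.compute_calls)
```

In this toy model, marking the dataset as cached stores nothing by itself; the first action fills the cache and later actions reuse it.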


In [None]:
from pyspark.storagelevel import StorageLevel

# Cache the employee DataFrame. For DataFrames the default level spills to
# disk (MEMORY_AND_DISK); MEMORY_ONLY is the default for RDDs.
df_emp.cache()
# Alternative with an explicit level: df_emp.persist(StorageLevel.MEMORY_ONLY)

# cache() is lazy, so run an action to materialize and store the data now
df_emp.count()

print("DataFrame 'df_emp' is now cached.")

**Window Functions (rank, dense_rank, lead, lag, rowsBetween, rangeBetween)**

In [None]:
from pyspark.sql import SparkSession
from pyspark.sql.window import Window
from pyspark.sql.functions import rank, dense_rank, lag, sum, desc, col

# Sample data: (category, sales, day)
df = spark.createDataFrame([
    ("A", 200, 1), ("A", 200, 2), ("A", 100, 3), ("A", 300, 4), ("B", 50, 4), ("B", 150, 5)
], ["category", "sales", "day"])

# Ranking window: Partition by category, ordered by sales descending
rank_window = Window.partitionBy("category").orderBy(desc("sales"))

# Time-series/Rolling window: Partition by category, ordered by day
ts_window = Window.partitionBy("category").orderBy("day")

# Rolling sum: the current row plus the 2 rows before it
rolling_window = ts_window.rowsBetween(-2, 0)

result = df.withColumn("rank", rank().over(rank_window)) \
           .withColumn("dense_rank", dense_rank().over(rank_window)) \
           .withColumn("prev_sales", lag(col("sales"), 1).over(ts_window)) \
           .withColumn("roll_sum_sales", sum("sales").over(rolling_window))

print("--- Window Functions Output ---")
result.show()

--- Window Functions Output ---
+--------+-----+---+----+----------+----------+--------------+
|category|sales|day|rank|dense_rank|prev_sales|roll_sum_sales|
+--------+-----+---+----+----------+----------+--------------+
|       A|  200|  1|   2|         2|      NULL|           200|
|       A|  200|  2|   2|         2|       200|           400|
|       A|  100|  3|   4|         3|       200|           500|
|       A|  300|  4|   1|         1|       100|           600|
|       B|   50|  4|   2|         2|      NULL|            50|
|       B|  150|  5|   1|         1|        50|           200|
+--------+-----+---+----+----------+----------+--------------+



First, let's create a sample DataFrame representing a few credit trades made by different traders during the day.

In [None]:
from pyspark.sql import SparkSession, Window
from pyspark.sql import functions as F

# Initialize Spark Session
spark = SparkSession.builder.appName("CitadelWindowFunctions").getOrCreate()

# Sample trade data for Credit Default Swaps (CDS)
trade_data = [
    ("trader_A", "2025-10-06 09:01:00", 5000000),
    ("trader_B", "2025-10-06 09:02:00", 7000000),
    ("trader_A", "2025-10-06 09:03:00", 8000000), # Trader A's biggest trade
    ("trader_C", "2025-10-06 09:04:00", 6000000),
    ("trader_B", "2025-10-06 09:05:00", 7000000), # Tied with their first trade
    ("trader_A", "2025-10-06 09:06:00", 2000000),
]
columns = ["trader_id", "trade_time", "notional_usd"]
trades_df = spark.createDataFrame(trade_data, columns)

1. **Ranking Functions: rank(), dense_rank(), row_number()**

These functions assign a rank to each row within a partition based on some ordering.

**Real-time Scenario:** A trading desk manager at Citadel wants to quickly identify the top 3 largest trades (by notional value) for each trader to review their daily performance and risk exposure.
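
The three ranking functions differ only in how they treat ties. A minimal plain-Python sketch of that behavior, for a single partition ordered by value descending, might look like this (`rank_rows` is an illustrative helper, not part of PySpark):

```python
def rank_rows(values):
    """Return (value, rank, dense_rank, row_number) for values sorted descending."""
    ordered = sorted(values, reverse=True)
    out = []
    rank = dense = 0
    previous = object()  # sentinel that equals no real value
    for row_number, value in enumerate(ordered, start=1):
        if value != previous:
            rank = row_number  # rank jumps to the row position after a tie
            dense += 1         # dense_rank increases by exactly one
            previous = value
        out.append((value, rank, dense, row_number))
    return out

for row in rank_rows([200, 200, 100, 300]):
    print(row)
```

Tied values share the same `rank` and `dense_rank`, while `row_number` always stays unique; which tied row gets the lower row number is arbitrary unless the ordering has a tie-breaker.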

In [None]:
# Define a window partitioned by trader and ordered by trade size (descending)
window_spec_rank = Window.partitionBy("trader_id").orderBy(F.desc("notional_usd"))
# Apply ranking functions
df1 = trades_df.withColumn("rank", F.rank().over(window_spec_rank)) \
               .withColumn("dense_rank", F.dense_rank().over(window_spec_rank)) \
               .withColumn("row_num", F.row_number().over(window_spec_rank))
df1.show()

+---------+-------------------+------------+----+----------+-------+
|trader_id|         trade_time|notional_usd|rank|dense_rank|row_num|
+---------+-------------------+------------+----+----------+-------+
| trader_A|2025-10-06 09:03:00|     8000000|   1|         1|      1|
| trader_A|2025-10-06 09:01:00|     5000000|   2|         2|      2|
| trader_A|2025-10-06 09:06:00|     2000000|   3|         3|      3|
| trader_B|2025-10-06 09:02:00|     7000000|   1|         1|      1|
| trader_B|2025-10-06 09:05:00|     7000000|   1|         1|      2|
| trader_C|2025-10-06 09:04:00|     6000000|   1|         1|      1|
+---------+-------------------+------------+----+----------+-------+



2. **Analytic Functions: lag() and lead()**

These functions access data from a previous (lag) or subsequent (lead) row within the same partition.

**Real-time Scenario:** For a high-frequency credit trading algorithm, you need to calculate the change in notional value between a trader's consecutive trades. This helps in analyzing trading patterns or identifying sudden, large changes in position size that might trigger a risk alert.
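
To show what lag and lead return, and how the change between consecutive trades follows from them, here is a plain-Python sketch for one trader's notionals in time order. The `lag` and `lead` helpers are illustrative re-implementations, not the PySpark functions.

```python
def lag(values, offset=1):
    """Shift values down by `offset` rows, padding the start with None."""
    return [None] * min(offset, len(values)) + values[: max(len(values) - offset, 0)]

def lead(values, offset=1):
    """Shift values up by `offset` rows, padding the end with None."""
    return values[offset:] + [None] * min(offset, len(values))

# One trader's notionals, ordered by trade time
notionals = [5_000_000, 8_000_000, 2_000_000]

previous = lag(notionals)
following = lead(notionals)

# Change versus the previous trade; undefined (None) for the first trade
change = [None if p is None else n - p for n, p in zip(notionals, previous)]

print(previous)
print(following)
print(change)
```

The first row has no predecessor, so both its lag value and its change are null; downstream logic such as a risk alert needs to handle that case explicitly.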

In [None]:
# Define a window partitioned by trader and ordered by time
window_spec_time = Window.partitionBy("trader_id").orderBy("trade_time")
# Apply lag to fetch each trader's previous trade notional
df3 = trades_df.withColumn("lag", F.lag("notional_usd", 1).over(window_spec_time))
df3.show()

+---------+-------------------+------------+-------+
|trader_id|         trade_time|notional_usd|    lag|
+---------+-------------------+------------+-------+
| trader_A|2025-10-06 09:01:00|     5000000|   NULL|
| trader_A|2025-10-06 09:03:00|     8000000|5000000|
| trader_A|2025-10-06 09:06:00|     2000000|8000000|
| trader_B|2025-10-06 09:02:00|     7000000|   NULL|
| trader_B|2025-10-06 09:05:00|     7000000|7000000|
| trader_C|2025-10-06 09:04:00|     6000000|   NULL|
+---------+-------------------+------------+-------+



3. **Aggregate Functions Over a Window: sum()**  
You can use standard aggregate functions like sum(), avg(), max(), etc., over a window to create running totals or moving averages.

**Real-time Scenario:** A risk management system at Citadel needs to calculate the cumulative trading volume (running total) for each trader throughout the day. This is critical for monitoring intraday credit limits and ensuring no single trader exceeds their authorized exposure.
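
A running total and a rolling sum differ only in the window frame. The plain-Python sketch below mirrors the two frames for a single partition that is already sorted; the function names are illustrative.

```python
from itertools import accumulate

def running_total(values):
    # Frame: from the first row of the partition up to the current row
    return list(accumulate(values))

def rolling_sum(values, preceding=2):
    # Frame: `preceding` rows before the current row, up to the current row
    return [sum(values[max(i - preceding, 0): i + 1]) for i in range(len(values))]

trader_a = [5_000_000, 8_000_000, 2_000_000]  # notionals in trade-time order
print(running_total(trader_a))
print(rolling_sum([200, 200, 100, 300]))
```

The first frame corresponds to `rowsBetween(Window.unboundedPreceding, Window.currentRow)` and the second to `rowsBetween(-2, 0)`.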

In [None]:
# Define a window that includes all rows from the start up to the current row for each trader
window_spec_running_total = Window.partitionBy("trader_id").orderBy("trade_time").rowsBetween(Window.unboundedPreceding, Window.currentRow)

# Calculate the running total of notional value
df4 = trades_df.withColumn("sum", F.sum("notional_usd").over(window_spec_running_total))
df4.show()

+---------+-------------------+------------+--------+
|trader_id|         trade_time|notional_usd|     sum|
+---------+-------------------+------------+--------+
| trader_A|2025-10-06 09:01:00|     5000000| 5000000|
| trader_A|2025-10-06 09:03:00|     8000000|13000000|
| trader_A|2025-10-06 09:06:00|     2000000|15000000|
| trader_B|2025-10-06 09:02:00|     7000000| 7000000|
| trader_B|2025-10-06 09:05:00|     7000000|14000000|
| trader_C|2025-10-06 09:04:00|     6000000| 6000000|
+---------+-------------------+------------+--------+



**Python UDF/Pandas UDF**

In [None]:
import pandas as pd
from pyspark.sql.functions import udf, pandas_udf, col
from pyspark.sql.types import DoubleType

df = spark.createDataFrame([(10,), (20,), (30,)], ["feature_a"])

# Standard Python UDF (Slow: Row-at-a-time)
def python_udf_logic(a):
    return a * 1.5 + 2

python_udf = udf(python_udf_logic, DoubleType())

# Pandas UDF (Fast: Batch/Vectorized processing via Arrow)
@pandas_udf(DoubleType())
def pandas_udf_logic(a: pd.Series) -> pd.Series:
    # Vectorized operation is done here
    return a * 1.5 + 2

print("--- UDF Comparison Output ---")
result = df.withColumn("python_udf_result", python_udf(col("feature_a"))) \
           .withColumn("pandas_udf_result", pandas_udf_logic(col("feature_a")))
result.show()

--- UDF Comparison Output ---
+---------+-----------------+-----------------+
|feature_a|python_udf_result|pandas_udf_result|
+---------+-----------------+-----------------+
|       10|             17.0|             17.0|
|       20|             32.0|             32.0|
|       30|             47.0|             47.0|
+---------+-----------------+-----------------+



**Complex Types (arrays/maps/structs) with explode, posexplode**

In [None]:
from pyspark.sql import SparkSession, Window
from pyspark.sql.functions import explode, posexplode, concat, lit, row_number, count


spark = SparkSession.builder.appName("pyspark1").getOrCreate()

data = [("User1", ["tag1", "tag2", "tag3"]),
        ("User2", ["tagA"])]
column_name = ["user", "tags"]
df = spark.createDataFrame(data,column_name)

df1 = df.select(df.user, explode(df.tags).alias('Tag'))
df1.show()

df2 = df.select(df.user, posexplode(df.tags))\
                   .withColumnRenamed("pos", "tag_index") \
                   .withColumnRenamed("col", "single_tag")
df2.show()

+-----+----+
| user| Tag|
+-----+----+
|User1|tag1|
|User1|tag2|
|User1|tag3|
|User2|tagA|
+-----+----+

+-----+---------+----------+
| user|tag_index|single_tag|
+-----+---------+----------+
|User1|        0|      tag1|
|User1|        1|      tag2|
|User1|        2|      tag3|
|User2|        0|      tagA|
+-----+---------+----------+
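
Conceptually, explode and posexplode are nested loops over each row's array. A plain-Python equivalent of the two calls above, written as list comprehensions for illustration:

```python
data = [("User1", ["tag1", "tag2", "tag3"]), ("User2", ["tagA"])]

# explode: one output row per array element
exploded = [(user, tag) for user, tags in data for tag in tags]

# posexplode: the same rows, plus each element's zero-based position
pos_exploded = [(user, pos, tag) for user, tags in data for pos, tag in enumerate(tags)]

print(exploded)
print(pos_exploded)
```

As with the comprehension, a row whose array is empty produces no output rows from `explode`; `explode_outer` keeps such rows with a null instead.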



In [None]:
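# Append a global row number to the user name.
# A window without partitionBy moves all rows to one partition, so use it only on small data.
window1 = Window.orderBy("user")
# lit() returns a Column unchanged, so df3 and df4 are equivalent
df3 = df1.withColumn("litcol", concat(df1.user, lit(row_number().over(window1))))
# Alternative (needs an import; IDs are unique but not consecutive):
# df3 = df1.withColumn("litcol", concat(lit("user"), monotonically_increasing_id()))
df3.show()
# concat() implicitly casts the integer row number to a string
df4 = df1.withColumn("litcol", concat(df1.user, row_number().over(window1)))
df4.show()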

+-----+----+------+
| user| Tag|litcol|
+-----+----+------+
|User1|tag1|User11|
|User1|tag2|User12|
|User1|tag3|User13|
|User2|tagA|User24|
+-----+----+------+

+-----+----+------+
| user| Tag|litcol|
+-----+----+------+
|User1|tag1|User11|
|User1|tag2|User12|
|User1|tag3|User13|
|User2|tagA|User24|
+-----+----+------+



**Caching in PySpark**

In [None]:
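import time
from pyspark.sql import functions as F

# Assumed setup: the original pipeline is not shown in this notebook, so
# ITERATIONS and processed_features below are illustrative stand-ins.
ITERATIONS = 3
processed_features = spark.range(1_000_000).withColumn("feature", F.sqrt(F.col("id")))

# --- CRITICAL STEP: Cache the result of the expensive work ---
processed_features.cache()

print("--- Running Scenario B: WITH CACHING ---")
start_time = time.time()

for i in range(ITERATIONS):
    # Action: The first count() triggers the caching.
    # Subsequent count() calls read from the cache.
    n_records = processed_features.count()  # named n_records to avoid shadowing the imported count()
    print(f"Iteration {i+1} (Cached): Counted {n_records} records.")

end_time = time.time()
print(f"Total time WITH Caching: {end_time - start_time:.4f} seconds.")

# Clean up memory after the iterative process is done
processed_features.unpersist()
spark.stop()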

## **Pytest and SparkTestingBase in the Code**

**A. Isolating Transformation Logic (The Code)**

In [None]:
from pyspark.sql import DataFrame, Window, functions as F

# This is the core 'unit' of business logic you must test
def calculate_spread_change_signal(df: DataFrame) -> DataFrame:
    """Calculates the 5-day change in the CDS spread and flags a trade signal."""

    # 1. Feature Engineering (The business logic)
    df = df.withColumn(
        "prev_5d_spread",
        F.lag(F.col("cds_spread"), 5).over(Window.partitionBy("ticker").orderBy("trade_date"))
    )

    # 2. Transformation
    df = df.withColumn(
        "spread_change",
        F.col("cds_spread") - F.col("prev_5d_spread")
    )

    # 3. Signal Generation
    df = df.withColumn(
        "is_buy_signal",
        F.when(F.col("spread_change") < -0.05, 1).otherwise(0) # Logic: Buy if spread tightens significantly
    )
    return df.drop("prev_5d_spread")

**B. The Pytest/Testing Code**

Step 1: The Spark Fixture (Setup)  
You use a pytest fixture (usually in a file named conftest.py) to create a single, lightweight local Spark Session that is reused across all tests, so the startup cost is paid only once per test run.

In [None]:
# conftest.py
import pytest
from pyspark.sql import SparkSession

@pytest.fixture(scope="session")
def spark_session():
    # Setup Spark in local mode for fast, single-machine testing
    spark = SparkSession.builder.master("local[*]").appName("test_session").getOrCreate()
    yield spark # Provide the session to tests
    spark.stop() # Teardown after all tests finish

Step 2: The Unit Test (Act & Assert)

A test function is written that uses the spark_session fixture and compares the output DataFrame against a manually defined expected DataFrame.

In [None]:
# test_credit_logic.py
from pyspark.testing import assertDataFrameEqual  # ships with PySpark 3.5+; chispa or spark-testing-base offer similar helpers
from your_module import calculate_spread_change_signal

def test_spread_tightening_triggers_buy_signal(spark_session):
    # ARRANGE: Create Canned Input Data (The "Unit" Input)
    # Six rows are needed because lag(5) looks back five rows.
    input_data = [
        ("CDE", "2025-10-01", 1.50),
        ("CDE", "2025-10-02", 1.45),
        ("CDE", "2025-10-03", 1.40),
        ("CDE", "2025-10-04", 1.35),
        ("CDE", "2025-10-05", 1.30),  # extra row added so the last row has a 5-row lookback
        ("CDE", "2025-10-06", 1.20),  # Spread tightens by 0.30 versus 2025-10-01
    ]
    input_df = spark_session.createDataFrame(input_data, ["ticker", "trade_date", "cds_spread"])

    # ACT: Run the function under test
    actual_df = calculate_spread_change_signal(input_df)

    # ASSERT: Define Expected Output Data
    # The first five rows have no 5-row lookback: change is null, so signal=0
    expected_data = [
        ("CDE", "2025-10-01", 1.50, None, 0),
        ("CDE", "2025-10-02", 1.45, None, 0),
        ("CDE", "2025-10-03", 1.40, None, 0),
        ("CDE", "2025-10-04", 1.35, None, 0),
        ("CDE", "2025-10-05", 1.30, None, 0),
        ("CDE", "2025-10-06", 1.20, -0.30, 1),  # Signal=1 because 1.20 - 1.50 = -0.30 < -0.05
    ]
    expected_schema = "ticker string, trade_date string, cds_spread double, spread_change double, is_buy_signal int"
    expected_df = spark_session.createDataFrame(expected_data, expected_schema)

    # Final ASSERTION: compares schema and rows, ignoring row order and
    # allowing a small floating-point tolerance
    assertDataFrameEqual(actual_df, expected_df)