In [3]:
import pandas as pd
import numpy as np
import yfinance as yf

In [103]:
obj = yf.Ticker('AIG')
fin = obj.financials
fin.T[['Pretax Income']]

Unnamed: 0,Pretax Income
2024-12-31,3870000000.0
2023-12-31,2867000000.0
2022-12-31,3772000000.0
2021-12-31,13347000000.0
2020-12-31,


In [128]:
from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
        .appName("Commodity+ShortVol ETL")
        .master("local[*]")      # use all CPU cores on your laptop
        .getOrCreate()
)

raw_path = "../data/raw"

shortvol_df = spark.read.parquet(f"{raw_path}/CNMS_shortvol_2020to2025-07-07.parquet")
prices_df   = spark.read.parquet(f"{raw_path}/equity_commodity_data.parquet")
fundamental_df   = spark.read.parquet(f"{raw_path}/equity_fundamental_data.parquet")


print(f"Short‑vol rows : {shortvol_df.count():,}")
print(f"Price rows : {prices_df.count():,}")
print(f"Fundamentals rows : {fundamental_df.count():,}")

Short‑vol rows : 15,327,394
Price rows : 847,530
Fundamentals rows : 2,020


## Fundamentals

In [129]:
from pyspark.sql import functions as F
from pyspark.sql import DataFrame

def add_ratio_pct(
        df: DataFrame,
        numer: str,
        denom: str,
        out_col: str,
        decimals: int = 2,
        make_positive: bool = False
) -> DataFrame:
    """
    Adds a numer/denom ratio (x 100) column to the relevant df

    Args:
        df (DataFrame): Spark DataFrame
        numer (str): Column name for the numerator
        denom (str): Column name for the denominator
        out_col (str): Column name for the new ratio column
        decimals (int): Number of decimal places for rounding
        make_positive (bool): If True, take the absolute value of numerator (e.g. for Capex which is negative)

    Returns:
        DataFrame: Original DataFrame with additional column
    """

    # Turn numerator positive or keep the same
    num_expr = F.abs(F.col(numer)) if make_positive else F.col(numer)
    
    # Caclulate ratio and x 100, rounded
    ratio = num_expr / F.col(denom)
    ratio_pct = F.round(ratio*100, decimals)

    return df.withColumn(f"{out_col} %", ratio_pct)

In [130]:
specs = [
    ("Capex", "Revenue", "Capex/Revenue", True),
    ("OCF", "Revenue", "OCF/Revenue", False),
    ("R&D", "Revenue", "R&D/Revenue", False),
    ("GrossProfit", "Revenue", "GP Margin", False),
    ("OpProfit", "Revenue", "OP Margin", False),
    ("Pretax Income", "Revenue", "Pretax Margin", False)
]

for numer, denom, out, pos in specs:
    fundamental_df = add_ratio_pct(fundamental_df, numer, denom, out, make_positive=pos)

In [131]:
fundamental_df.count()

2020

In [132]:
(fundamental_df
 .fillna({"Capex": 0, "Revenue": -1, "`GP Margin %`": 0, "`OP Margin %`": 0})
 .filter("`Pretax Margin %` > 100")
).show()

+-------------------+----------+----------+-------------+----------+-------+-----------+----------+------+---------------+-------------+-------------+-----------+-----------+---------------+
|         FiscalYear|     Capex|       OCF|Pretax Income|  OpProfit|    R&D|GrossProfit|   Revenue|Ticker|Capex/Revenue %|OCF/Revenue %|R&D/Revenue %|GP Margin %|OP Margin %|Pretax Margin %|
+-------------------+----------+----------+-------------+----------+-------+-----------+----------+------+---------------+-------------+-------------+-----------+-----------+---------------+
|2022-12-31 00:00:00|-4.59773E8|3.117141E9|     4.3806E9|2.169438E9|1.754E7| 3.129256E9|4.182163E9|   PSA|          10.99|        74.53|         0.42|      74.82|      51.87|         104.74|
+-------------------+----------+----------+-------------+----------+-------+-----------+----------+------+---------------+-------------+-------------+-----------+-----------+---------------+



In [105]:
"Revenue IS NULL OR Capex IS NULL OR `GP Margin %` IS NULL OR `OP Margin %` IS NULL"

dropped = fundamental_df.filter(
    "`GP Margin %` IS NULL OR `GP Margin %` < 100"
)
dropped.show()

+-------------------+----------+----------+----------+---------+------------+------------+------+---------------+-------------+-------------+-----------+-----------+
|         FiscalYear|     Capex|       OCF|  OpProfit|      R&D| GrossProfit|     Revenue|Ticker|Capex/Revenue %|OCF/Revenue %|R&D/Revenue %|GP Margin %|OP Margin %|
+-------------------+----------+----------+----------+---------+------------+------------+------+---------------+-------------+-------------+-----------+-----------+
|2024-12-31 00:00:00|  -1.181E9|   1.819E9|   4.822E9|  1.085E9|   1.0128E10|   2.4575E10|   MMM|           4.81|          7.4|         4.42|      41.21|      19.62|
|2023-12-31 00:00:00|  -1.615E9|    6.68E9|-1.0725E10|  1.154E9|     9.627E9|    2.461E10|   MMM|           6.56|        27.14|         4.69|      39.12|     -43.58|
|2022-12-31 00:00:00|  -1.749E9|   5.591E9|   1.916E9|   1.16E9|   1.0308E10|   2.6161E10|   MMM|           6.69|        21.37|         4.43|       39.4|       7.32|
|202

In [None]:
from typing import List

def sanity_filter(df: DataFrame) -> DataFrame:
    """
    Remove rows with impossible values for revenue, capex and margin columns
    """

    # Need to convert NULLs to 0s or -1 otherwise rows are dropped
    return(
        df
        .fillna({"Capex": 0, "Revenue": -1})
        .filter("Revenue > 0")
        .filter("Capex <= 0")
        .filter("`GP Margin %` IS NULL OR `GP Margin %` < 100")
        .filter("`OP Margin %` IS NULL OR `OP Margin %` < 100")
    )

def winsorise(
        df: DataFrame,
        cols: List[str],
        lower: float = 0.01,
        upper: float = 0.99
) -> DataFrame:
    """
    For each column in cols, add:
    - <col>_capped - numeric, capped at [lower, upper] quantiles
    - <col>_was_capped - 1/0 flag

    Returns new DataFrame
    """

    out = df
    for c in cols:
        q_lo, q_hi = out.approxQuantile(c, [lower, upper], 0.001)

        out = (
            out.withColumn(
                f"{c}_capped",
                F.when(F.col(c) < q_lo, q_lo)
                .when(F.col(c) > q_hi, q_hi)
                .otherwise(F.col(c))
            )
            .withColumn(
                f"{c}_was_capped",
                F.when((F.col(c) < q_lo) | (F.col(c) > q_hi), 1).otherwise(0)
            )
        )

    return out

In [126]:
COLUMNS_TO_CAP = [
    "Capex/Revenue %", "OCF/Revenue %",
    "R&D/Revenue %", "GP Margin %", "OP Margin %"
]

new = sanity_filter(fundamental_df)
fundamental_final = winsorise(new, COLUMNS_TO_CAP)

In [127]:
fundamental_final.count()

2011

In [118]:
process_path = "../data/processed"

process_df = spark.read.parquet(f"{process_path}/fundamentals_processed.parquet")

In [120]:
process_df.show()

+-------------------+----------+----------+-------------+----------+---------+------------+------------+------+---------------+-------------+-------------+-----------+-----------+---------------+----------------------+--------------------------+--------------------+------------------------+--------------------+------------------------+------------------+----------------------+------------------+----------------------+----------------------+--------------------------+
|         FiscalYear|     Capex|       OCF|Pretax Income|  OpProfit|      R&D| GrossProfit|     Revenue|Ticker|Capex/Revenue %|OCF/Revenue %|R&D/Revenue %|GP Margin %|OP Margin %|Pretax Margin %|Capex/Revenue %_capped|Capex/Revenue %_was_capped|OCF/Revenue %_capped|OCF/Revenue %_was_capped|R&D/Revenue %_capped|R&D/Revenue %_was_capped|GP Margin %_capped|GP Margin %_was_capped|OP Margin %_capped|OP Margin %_was_capped|Pretax Margin %_capped|Pretax Margin %_was_capped|
+-------------------+----------+----------+-------------

25/07/21 18:14:54 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.


## Price and volume

In [8]:
shortvol_df.show(5, truncate=False)

+----------+------+-----------+-----------------+-----------+
|date      |symbol|shortvolume|shortexemptvolume|totalvolume|
+----------+------+-----------+-----------------+-----------+
|2019-01-02|A     |126407.0   |3075.0           |297263.0   |
|2019-01-02|AA    |465945.0   |0.0              |944291.0   |
|2019-01-02|AAAU  |9909.0     |0.0              |10408.0    |
|2019-01-02|AABA  |168834.0   |0.0              |945871.0   |
|2019-01-02|AAC   |125310.0   |0.0              |194619.0   |
+----------+------+-----------+-----------------+-----------+
only showing top 5 rows



In [9]:
# Normalise column names
shortvol_df = shortvol_df.withColumnRenamed("symbol", "Ticker")
shortvol_df = shortvol_df.withColumnRenamed("date", "Date")
shortvol_df = shortvol_df.withColumn("Ticker", F.upper("Ticker"))

# Merge price and short vol together
prices_short = (
    prices_df.alias("p")
    .join(
        shortvol_df.alias("s"),
        on=["Date", "Ticker"],
        how="left"
    )
    .drop("Close")
)

prices_short.show(5, truncate=False)

                                                                                

+-------------------+------+------------------+-----------------+-----------------+------------------+---------+---------+-----------+-----------------+-----------+
|Date               |Ticker|Adj Close         |High             |Low              |Open              |Volume   |AssetType|shortvolume|shortexemptvolume|totalvolume|
+-------------------+------+------------------+-----------------+-----------------+------------------+---------+---------+-----------+-----------------+-----------+
|2019-01-02 00:00:00|ADBE  |224.57000732421875|226.1699981689453|219.0            |219.91000366210938|2784100.0|Equity   |343055.0   |119.0            |736190.0   |
|2019-01-02 00:00:00|ABT   |61.766048431396484|70.95999908447266|69.06999969482422|70.38999938964844 |8737200.0|Equity   |890617.0   |6912.0           |1829654.0  |
|2019-01-02 00:00:00|A     |62.67951583862305 |66.56999969482422|65.30000305175781|66.5              |2113300.0|Equity   |126407.0   |3075.0           |297263.0   |
|2019-01-0

                                                                                

In [10]:
from pyspark.sql import Window

# Calculate short_vol ratio when volume > 0
prices_short = (
    prices_short
    .withColumn(
        "ShortVolRatio",
        F.when(F.col("totalvolume") > 0, F.col("shortvolume") / F.col("totalvolume"))
        .otherwise(None)
    )
)

# Flag issues with short_vol and prices
prices_short = (
    prices_short
    .withColumn(
        "flag_bad_short_vol",
        (F.col("shortvolume") < 0) | (F.col("ShortVolRatio") > 1)
    )
    .withColumn(
        "flag_bad_price",
        (F.col("Adj Close") <= 0) | (F.col("High") < F.col("Low"))
    )
)

prices_short = prices_short.dropna(how='all', subset=['Adj Close', 'High', 'Low', 'Open', 'Volume'])
prices_short.show(20, truncate=False)



+-------------------+------+------------------+------------------+------------------+------------------+---------+---------+-----------+-----------------+-----------+-------------------+------------------+--------------+
|Date               |Ticker|Adj Close         |High              |Low               |Open              |Volume   |AssetType|shortvolume|shortexemptvolume|totalvolume|ShortVolRatio      |flag_bad_short_vol|flag_bad_price|
+-------------------+------+------------------+------------------+------------------+------------------+---------+---------+-----------+-----------------+-----------+-------------------+------------------+--------------+
|2019-01-02 00:00:00|ALLE  |74.51044464111328 |81.06999969482422 |78.30999755859375 |78.6500015258789  |864300.0 |Equity   |42528.0    |104.0            |146013.0   |0.29126173696862606|false             |false         |
|2019-01-02 00:00:00|AOS   |38.59386444091797 |43.4900016784668  |41.61000061035156 |42.209999084472656|1998500.0|Eq

                                                                                

## Checks

In [11]:
prices_short.filter("flag_bad_price = TRUE").count()

n_true = (
    prices_short
        .agg(F.sum(F.col("flag_bad_short_vol").cast("int")).alias("true_count"))
        .first()["true_count"]
)

print("rows where flag_bad_short_vol is true:", n_true)

[Stage 43:>                                                       (0 + 10) / 11]

rows where flag_bad_short_vol is true: 0


                                                                                

25/07/21 07:27:31 WARN HeartbeatReceiver: Removing executor driver with no recent heartbeats: 909130 ms exceeds timeout 120000 ms
25/07/21 07:27:31 WARN SparkContext: Killing executors is not supported by current scheduler.
25/07/21 07:27:39 ERROR Inbox: Ignoring error
org.apache.spark.SparkException: Exception thrown in awaitResult: 
	at org.apache.spark.util.SparkThreadUtils$.awaitResult(SparkThreadUtils.scala:56)
	at org.apache.spark.util.ThreadUtils$.awaitResult(ThreadUtils.scala:310)
	at org.apache.spark.rpc.RpcTimeout.awaitResult(RpcTimeout.scala:75)
	at org.apache.spark.rpc.RpcEnv.setupEndpointRefByURI(RpcEnv.scala:102)
	at org.apache.spark.rpc.RpcEnv.setupEndpointRef(RpcEnv.scala:110)
	at org.apache.spark.util.RpcUtils$.makeDriverRef(RpcUtils.scala:36)
	at org.apache.spark.storage.BlockManagerMasterEndpoint.driverEndpoint$lzycompute(BlockManagerMasterEndpoint.scala:124)
	at org.apache.spark.storage.BlockManagerMasterEndpoint.org$apache$spark$storage$BlockManagerMasterEndpoint$$