In [0]:
# Notebook widget through which the caller selects the target environment.
# The value is used below as the prefix of the input and output schema names.
# NOTE(review): the default is an empty string, which would produce schema
# names starting with "_silver" / "_gold" -- confirm a value is always supplied.
dbutils.widgets.text(name="env",defaultValue="",label="Enter the environment in lower case")
env = dbutils.widgets.get("env")

from pyspark.sql.functions import (
    lit, avg, col, row_number, when, stddev, lag, sum as spark_sum, max as spark_max
)
from pyspark.sql.window import Window
from datetime import datetime, timedelta
from functools import reduce
from pyspark.sql.functions import date_sub, current_date


# Source table of per-day price aggregates and destination table for the
# computed indicators; both schema names are prefixed with the `env` widget value.
input_table = f"{env}_silver.daily_price_aggregates"
output_table = f"{env}_gold.daily_price_indicators"
# Look-back lengths (in rows per symbol) passed to the MA / EMA / RSI /
# Bollinger functions below. calculate_macd reads the ma_12 and ma_26 columns,
# so 12 and 26 need to stay in this list.
windows = [60,26,12]

def read_prices(input_table):
    """Load roughly the last 90 days of prices and rank each symbol's rows by recency.

    The table is first narrowed with a coarse ``year``/``month`` predicate
    (presumably the partition columns -- verify against the table definition)
    and then filtered exactly on ``date``.

    Args:
        input_table: Name of the table to read. It must expose ``symbol``,
            ``date``, ``year`` and ``month`` columns.

    Returns:
        A DataFrame holding the rows whose ``date`` is on or after the cut-off
        (today minus 90 days), with an extra ``rn`` column that numbers each
        symbol's rows starting at 1 for the most recent ``date``.
    """
    # Read the clock once so every bound is derived from the same instant,
    # even if the job happens to run across midnight.
    today = datetime.today()
    start_date = (today - timedelta(days=90)).date()

    start_year, start_month = start_date.year, start_date.month
    current_year, current_month = today.year, today.month

    if start_year == current_year:
        # Both month bounds must hold at once. OR-ing "month >= start" with
        # "month <= current" for the same year would match every month of that
        # year and make the coarse filter useless.
        month_range = (
            (col("year") == start_year)
            & (col("month") >= start_month)
            & (col("month") <= current_month)
        )
    else:
        # The range crosses a year boundary: tail of the start year, any whole
        # years in between, and the head of the current year.
        month_range = (
            ((col("year") == start_year) & (col("month") >= start_month))
            | ((col("year") > start_year) & (col("year") < current_year))
            | ((col("year") == current_year) & (col("month") <= current_month))
        )

    df_recent = spark.table(input_table).filter(month_range)

    # Exact cut-off; the year/month predicate above only narrows to whole months.
    df_filtered = df_recent.filter(col("date") >= lit(start_date.strftime("%Y-%m-%d")))

    # rn == 1 is the latest row of each symbol; larger rn values are older rows.
    window_spec = Window.partitionBy("symbol").orderBy(col("date").desc())
    df_latest = df_filtered.withColumn("rn", row_number().over(window_spec))

    return df_latest


def moving_average(df_input, df_output, window_sizes):
    """Attach one simple-moving-average column per window size to ``df_output``.

    Args:
        df_input: Price history with ``symbol``, ``avg_price`` and the recency
            rank ``rn`` (1 = most recent row of the symbol).
        df_output: DataFrame keyed by ``symbol`` that receives the new columns.
        window_sizes: Iterable of window lengths; each one yields a column
            named ``ma_<size>``.

    Returns:
        ``df_output`` left-joined with the ``ma_<size>`` columns. A symbol
        whose highest ``rn`` is below a window size gets null for that window.
    """

    def _average_for(window):
        # Symbols whose history is at least `window` rows deep.
        eligible = (
            df_input.groupBy("symbol")
            .agg(spark_max("rn").alias("max_rn"))
            .filter(col("max_rn") >= window)
            .select("symbol")
        )
        # Rows outside the window become null, which avg() ignores.
        price_in_window = when(col("rn") <= window, col("avg_price"))
        return (
            df_input.join(eligible, on="symbol", how="inner")
            .groupBy("symbol")
            .agg(avg(price_in_window).alias(f"ma_{window}"))
        )

    per_window = [_average_for(size) for size in window_sizes]

    if per_window:
        # Outer join so a symbol valid for only some windows is still kept.
        combined = reduce(
            lambda acc, nxt: acc.join(nxt, on="symbol", how="outer"),
            per_window,
        )
    else:
        combined = df_input.select("symbol").distinct()

    return df_output.join(combined, on="symbol", how="left")


def take_latest_row(table_path, granulity):
    """Return the most recent row per symbol from the last five days of a table.

    Args:
        table_path: Name of the table to read; it must have a ``symbol`` column
            and the date column selected by ``granulity``.
        granulity: ``"daily"`` to order by the ``date`` column or ``"hourly"``
            to order by the ``timestamp`` column.

    Returns:
        A DataFrame with one row per symbol: the row with the greatest value
        of the selected date column among rows from the last five days.

    Raises:
        ValueError: If ``granulity`` is neither ``"daily"`` nor ``"hourly"``.
    """
    date_columns = {"daily": "date", "hourly": "timestamp"}
    try:
        datecol = date_columns[granulity]
    except (KeyError, TypeError):
        # Fail fast with a clear message instead of an UnboundLocalError later.
        raise ValueError(
            f"Unsupported granulity {granulity!r}; expected one of {sorted(date_columns)}"
        ) from None

    # Only the last five days are scanned when looking for the latest row.
    five_days_ago = date_sub(current_date(), 5)
    df_recent = spark.read.table(table_path).filter(col(datecol) >= five_days_ago)

    # Rank each symbol's rows newest-first and keep only the top one.
    window_spec = Window.partitionBy("symbol").orderBy(col(datecol).desc())
    df_latest = (
        df_recent
        .withColumn("rn", row_number().over(window_spec))
        .filter(col("rn") == 1)
        .drop("rn")
    )
    return df_latest


def exponential_moving_average(latest_records, df_output, window_sizes):
    """Add one ``ema_<size>`` column per window size to ``df_output``.

    Each EMA is updated incrementally from the previously stored value:
    ``ema = prev_ema * (1 - k) + avg_price * k`` with ``k = 2 / (size + 1)``.
    When no previous value is available for a symbol -- or the previous output
    has no ``ema_<size>`` column at all -- the EMA is seeded with ``ma_<size>``.

    Args:
        latest_records: Previous output, one row per ``symbol``; it may or may
            not contain the ``ema_<size>`` columns.
        df_output: Today's rows keyed by ``symbol``; must contain ``avg_price``
            and ``ma_<size>`` for every requested size.
        window_sizes: Iterable of window lengths.

    Returns:
        ``df_output`` left-joined with the computed ``ema_<size>`` columns.
    """
    ema_cols = [f"ema_{ws}" for ws in window_sizes]
    previous_columns = set(latest_records.columns)
    existing_cols = [c for c in ema_cols if c in previous_columns]

    # Prefix yesterday's values so they cannot collide with today's columns.
    latest_selected = latest_records.select(
        col("symbol"),
        *[col(c).alias(f"prev_{c}") for c in existing_cols],
    )
    joined_df = df_output.join(latest_selected, on="symbol", how="left")

    for ws in window_sizes:
        k = 2 / (ws + 1)
        ema_col = f"ema_{ws}"
        prev_col = f"prev_{ema_col}"
        ma_col = f"ma_{ws}"

        if ema_col in previous_columns:
            ema_value = when(
                col(prev_col).isNull(),  # no history for this symbol: seed with the MA
                col(ma_col),
            ).otherwise(
                col(prev_col) * (1 - k) + col("avg_price") * k
            )
        else:
            # The previous output has no such column, so referencing prev_col
            # would fail analysis; every symbol starts from the MA instead.
            ema_value = col(ma_col)

        joined_df = joined_df.withColumn(ema_col, ema_value)

    ema_today = joined_df.select("symbol", *ema_cols)

    return df_output.join(ema_today, on="symbol", how="left")


def calculate_rsi(df_input, df_output, window_sizes):
    """Add one ``rsi_<size>`` column per window size to ``df_output``.

    For every symbol the day-over-day change of ``avg_price`` is split into
    gains and losses; the sums over the ``size`` most recent rows
    (``rn <= size``) are divided by ``size`` and combined as
    ``100 - 100 / (1 + avg_gain / avg_loss)``. The result is 100 when the
    average loss is zero.

    Args:
        df_input: Price history with ``symbol``, ``date``, ``avg_price`` and
            the recency rank ``rn`` (1 = most recent row of the symbol).
        df_output: DataFrame keyed by ``symbol`` that receives the new columns.
        window_sizes: Iterable of window lengths.

    Returns:
        ``df_output`` left-joined with the ``rsi_<size>`` columns.
    """
    # Order chronologically so lag() yields the PREVIOUS day's price. Over a
    # descending window lag() returns the next day's price instead, which
    # flips the sign of every change and swaps gains with losses.
    chronological = Window.partitionBy("symbol").orderBy(col("date").asc())

    # The delta/gain/loss columns do not depend on the window size, so they are
    # built once. The first row of each symbol has no predecessor; its null
    # delta falls through to the otherwise(0) branches.
    df_with_gains = (
        df_input
        .withColumn("price_diff", col("avg_price") - lag("avg_price", 1).over(chronological))
        .withColumn("gain", when(col("price_diff") > 0, col("price_diff")).otherwise(0))
        .withColumn("loss", when(col("price_diff") < 0, -col("price_diff")).otherwise(0))
    )

    rsi_dfs = []
    for ws in window_sizes:
        # Restrict to the `ws` most recent rows of each symbol.
        df_valid = df_with_gains.filter(col("rn") <= ws)

        # Average gain and loss over the window (always divided by `ws`).
        rsi_df = df_valid.groupBy("symbol").agg(
            (spark_sum("gain") / ws).alias("avg_gain"),
            (spark_sum("loss") / ws).alias("avg_loss"),
        )

        rsi_df = rsi_df.withColumn(
            f"rsi_{ws}",
            when(
                col("avg_loss") == 0, 100.0  # avoids dividing by zero
            ).otherwise(
                100 - (100 / (1 + (col("avg_gain") / col("avg_loss"))))
            ),
        ).select("symbol", f"rsi_{ws}")

        rsi_dfs.append(rsi_df)

    # Join all RSI outputs on symbol.
    if rsi_dfs:
        rsi_all = reduce(lambda left, right: left.join(right, on="symbol", how="outer"), rsi_dfs)
    else:
        rsi_all = df_input.select("symbol").distinct()

    return df_output.join(rsi_all, on="symbol", how="left")
    


def calculate_macd(latest_records, df_output):
    """Add ``ema_12``, ``ema_26``, ``macd``, ``signal`` and ``macd_hist`` to ``df_output``.

    ``ema_12``/``ema_26`` are updated from the previously stored values (seeded
    with ``ma_12``/``ma_26`` when there is none), ``macd`` is their difference,
    ``signal`` is a 9-period EMA of ``macd`` (seeded with ``macd`` itself) and
    ``macd_hist`` is ``macd - signal``.

    Args:
        latest_records: Previous output, one row per ``symbol``; it may or may
            not contain ``ema_12``, ``ema_26`` and ``signal``.
        df_output: Today's rows keyed by ``symbol``; must contain
            ``avg_price``, ``ma_12`` and ``ma_26``.

    Returns:
        ``df_output`` with exactly one copy of each of the five MACD columns.
        Columns of the same name already present in ``df_output`` are replaced.
    """
    # Only these previous values feed the recurrences below.
    carried_over = ["ema_12", "ema_26", "signal"]
    previous_columns = set(latest_records.columns)
    existing_cols = [name for name in carried_over if name in previous_columns]

    # Prefix yesterday's values so they cannot collide with today's columns.
    latest_selected = latest_records.select(
        col("symbol"),
        *[col(c).alias(f"prev_{c}") for c in existing_cols],
    )
    df_joined = df_output.join(latest_selected, on="symbol", how="left")

    def _ema_step(name, period, current, seed):
        # One EMA update for `name`; use `seed` where there is no previous value.
        if name not in previous_columns:
            # Referencing a prev_ column that was never selected would fail analysis.
            return seed
        k = 2 / (period + 1)
        prev = col(f"prev_{name}")
        return when(prev.isNull(), seed).otherwise(prev * (1 - k) + current * k)

    df_joined = df_joined.withColumn(
        "ema_12", _ema_step("ema_12", 12, col("avg_price"), col("ma_12"))
    )
    df_joined = df_joined.withColumn(
        "ema_26", _ema_step("ema_26", 26, col("avg_price"), col("ma_26"))
    )

    # MACD line, its 9-period signal line, and the histogram between them.
    df_joined = df_joined.withColumn("macd", col("ema_12") - col("ema_26"))
    df_joined = df_joined.withColumn(
        "signal", _ema_step("signal", 9, col("macd"), col("macd"))
    )
    df_joined = df_joined.withColumn("macd_hist", col("macd") - col("signal"))

    macd_cols = ["ema_12", "ema_26", "macd", "signal", "macd_hist"]
    macd_df = df_joined.select("symbol", *macd_cols)

    # df_output may already carry some of these columns (e.g. ema_12 / ema_26
    # from an earlier EMA step). Drop them first so the join does not produce
    # two columns with the same name.
    overlapping = [c for c in macd_cols if c in df_output.columns]
    if overlapping:
        df_output = df_output.drop(*overlapping)

    return df_output.join(macd_df, on="symbol", how="left")


def calculate_bollinger_bands(df_input, df_output, window_sizes):
    """Attach upper/lower/width Bollinger-band columns per window size to ``df_output``.

    For each size the mean and standard deviation (Spark's ``stddev``) of
    ``avg_price`` over the ``size`` most recent rows give
    ``bb_upper_<size>`` = mean + 2 * std, ``bb_lower_<size>`` = mean - 2 * std
    and ``bb_width_<size>`` = upper - lower.

    Args:
        df_input: Price history with ``symbol``, ``avg_price`` and the recency
            rank ``rn`` (1 = most recent row of the symbol).
        df_output: DataFrame keyed by ``symbol`` that receives the new columns.
        window_sizes: Iterable of window lengths.

    Returns:
        ``df_output`` left-joined with the band columns. A symbol whose highest
        ``rn`` is below a window size gets nulls for that window.
    """

    def _bands_for(window):
        mean_name = f"bb_ma_{window}"
        std_name = f"bb_std_{window}"
        upper_name = f"bb_upper_{window}"
        lower_name = f"bb_lower_{window}"
        width_name = f"bb_width_{window}"

        # Symbols whose history is at least `window` rows deep.
        eligible = (
            df_input.groupBy("symbol")
            .agg(spark_max("rn").alias("max_rn"))
            .filter(col("max_rn") >= window)
            .select("symbol")
        )

        # Rows outside the window become null and are ignored by the aggregates.
        price_in_window = when(col("rn") <= window, col("avg_price"))
        stats = (
            df_input.join(eligible, on="symbol", how="inner")
            .groupBy("symbol")
            .agg(
                avg(price_in_window).alias(mean_name),
                stddev(price_in_window).alias(std_name),
            )
        )

        return (
            stats
            .withColumn(upper_name, col(mean_name) + 2 * col(std_name))
            .withColumn(lower_name, col(mean_name) - 2 * col(std_name))
            .withColumn(width_name, col(upper_name) - col(lower_name))
            .select("symbol", upper_name, lower_name, width_name)
        )

    per_window = [_bands_for(size) for size in window_sizes]

    if per_window:
        # Outer join so a symbol valid for only some windows is still kept.
        combined = reduce(
            lambda acc, nxt: acc.join(nxt, on="symbol", how="outer"),
            per_window,
        )
    else:
        combined = df_input.select("symbol").distinct()

    return df_output.join(combined, on="symbol", how="left")


def calculate_obv(df_input, df_output):
    """Attach the latest on-balance-volume value per symbol to ``df_output``.

    Walking each symbol's rows in ``date`` order, ``volume`` is added when
    ``avg_price`` rose versus the previous row, subtracted when it fell and
    ignored otherwise; the running total on the most recent row (``rn == 1``)
    is the reported ``obv``. The total only covers the rows in ``df_input``.

    Args:
        df_input: Price history with ``symbol``, ``date``, ``avg_price``,
            ``volume`` and the recency rank ``rn``.
        df_output: DataFrame keyed by ``symbol`` that receives the ``obv`` column.

    Returns:
        ``df_output`` left-joined with the ``obv`` column.
    """
    chronological = Window.partitionBy("symbol").orderBy("date")
    up_to_current_row = chronological.rowsBetween(
        Window.unboundedPreceding, Window.currentRow
    )

    # Change versus the previous row; null on each symbol's first row.
    price_change = col("avg_price") - lag("avg_price", 1).over(chronological)

    # Signed volume contribution; a null or zero change contributes 0.
    signed_volume = (
        when(col("price_diff") > 0, col("volume"))
        .when(col("price_diff") < 0, -col("volume"))
        .otherwise(0)
    )

    latest_obv = (
        df_input
        .withColumn("price_diff", price_change)
        .withColumn("obv_change", signed_volume)
        .withColumn("obv", spark_sum("obv_change").over(up_to_current_row))
        .filter(col("rn") == 1)  # keep only the most recent row per symbol
        .select("symbol", "obv")
    )

    return df_output.join(latest_obv, on="symbol", how="left")


# --- Pipeline driver -------------------------------------------------------
# Load the recent price history (with the per-symbol recency rank `rn`) and
# mark it for caching, since every indicator function below reads it again.
df_filtered = read_prices(input_table)
df_filtered.cache()

# Start the result from the most recent row of every symbol (rn == 1); each
# step below joins additional indicator columns onto it by `symbol`.
df_result = df_filtered.filter(col("rn") == 1).drop("rn")
# Simple moving averages: ma_<size> for every entry in `windows`.
df_result = moving_average(df_filtered, df_result, windows)
# Exponential moving averages, updated from the latest row already stored in
# the output table (must run after moving_average: it reads the ma_<size> columns).
latest_records = take_latest_row(output_table, "daily")
df_result = exponential_moving_average(latest_records, df_result, windows)
# Relative strength index: rsi_<size>.
df_result = calculate_rsi(df_filtered, df_result, windows)
# MACD line, signal line and histogram (reads ma_12 / ma_26 from df_result).
df_result = calculate_macd(latest_records, df_result)
# Bollinger bands: bb_upper_/bb_lower_/bb_width_<size>.
df_result = calculate_bollinger_bands(df_filtered, df_result, windows)
# On-balance volume.
df_result = calculate_obv(df_filtered, df_result)

# Preview the first 100 result rows in the notebook.
# NOTE(review): nothing in this cell writes df_result to `output_table` --
# confirm the write happens elsewhere.
pandas_df = df_result.limit(100).toPandas()
display(pandas_df)



In [0]:
# Preview the first 100 rows of the stored indicators table.
# NOTE(review): the table name is hard-coded to the dev schema and ignores the
# `env` widget -- confirm this is intended.
df = spark.sql("SELECT * FROM hive_metastore.dev_gold.daily_price_indicators")
pandas_df = df.limit(100).toPandas()
display(pandas_df)