# Shared utils 

This document explains each section of `src/shared_utils.py`. The module provides the functions shared by the training notebook (`semi-final.ipynb`) and the live-prediction script (`real_time_inference.py`): candle aggregation, feature generation, and a UDF that extracts the UP-class probability from Spark ML output.

## Imports

The module imports PySpark SQL (`Window`, column functions, types), the ML feature transformer `VectorAssembler`, and `functools.reduce` for building sum expressions.
`max`, `min`, `sum`, and `abs` are imported as `_max`, `_min`, `_sum`, and `sabs` so that they do not shadow the Python built-ins. `pow` is imported without an alias, so inside this module it does shadow the built-in `pow`.

```python
from pyspark.sql import Window
from pyspark.sql.functions import (
    col, lag, lead, when, unix_timestamp, floor, lit,
    first, last, max as _max, min as _min, sum as _sum,
    pow, sqrt, avg, abs as sabs, greatest, udf
)
from pyspark.sql.types import DoubleType
from pyspark.ml.feature import VectorAssembler
from functools import reduce
```

## `extract_prob_udf`

Spark ML classification models output a `probability` vector; for binary classification, element 1 is the probability of the positive class (label 1, i.e. UP). This UDF takes that vector and returns element 1 as a Double, or 0.5 if the vector is null. At inference time it is used to produce rf_prob, gbt_prob, lr_prob, and final_prob.

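```python
# Probability of class 1 (UP) from a Spark ML probability vector; 0.5 (neutral) when it is null.
extract_prob_udf = udf(lambda v: float(v[1]) if v is not None else 0.5, DoubleType())
```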
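To check the UDF's behaviour without a Spark session, here is a plain-Python sketch of the same lambda. It assumes the vector can be indexed like a list (Spark's `DenseVector` supports `v[1]`); `extract_up_prob` is a hypothetical name used only for illustration.

```python
from typing import Optional, Sequence


def extract_up_prob(v: Optional[Sequence[float]]) -> float:
    """Same logic as extract_prob_udf: index 1 is the positive (UP) class; null -> 0.5."""
    # 0.5 is the neutral "no information" probability for a binary classifier.
    return float(v[1]) if v is not None else 0.5


print(extract_up_prob([0.3, 0.7]))  # 0.7
print(extract_up_prob(None))        # 0.5
```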

## `aggregate_candles` — Timestamp handling

Aggregates small-timeframe candles (e.g. 1m) into larger buckets of `candle_minutes` minutes (30 by default). It first needs a single time representation: training data often has open_time as TimestampType, while the stream sends Unix milliseconds. The function therefore checks the column type and adds open_time_ms, either by converting the timestamp (whole seconds × 1000) or by copying the column as-is.

```python
def aggregate_candles(df, candle_minutes=30):
    from pyspark.sql.types import TimestampType

    open_time_type = df.schema["open_time"].dataType

    if isinstance(open_time_type, TimestampType):
        # Training data: timestamp -> epoch seconds (unix_timestamp) -> milliseconds.
        df = df.withColumn("open_time_ms", unix_timestamp(col("open_time")) * 1000)
    else:
        # Streaming data: open_time is assumed to already be Unix milliseconds.
        df = df.withColumn("open_time_ms", col("open_time"))
```
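Both branches should yield the same millisecond value for the same instant. Below is a plain-Python sketch of the conversion using the hypothetical helper `to_open_time_ms`; like `unix_timestamp`, it keeps whole seconds only, so any sub-second part of a timestamp is discarded. It assumes the non-timestamp input is already an integer in Unix milliseconds, as described for the streaming source.

```python
from datetime import datetime, timezone
from typing import Union


def to_open_time_ms(open_time: Union[datetime, int]) -> int:
    """Epoch milliseconds for either a timezone-aware datetime or an epoch-ms integer."""
    if isinstance(open_time, datetime):
        # unix_timestamp() yields whole seconds, so truncate before scaling to ms.
        return int(open_time.timestamp()) * 1000
    # Assumed: streaming values are already Unix milliseconds.
    return int(open_time)


ts = datetime(2024, 1, 1, tzinfo=timezone.utc)
print(to_open_time_ms(ts))             # 1704067200000
print(to_open_time_ms(1704067200000))  # 1704067200000
```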

---

## `aggregate_candles` — Bucketing and aggregation

Time is bucketed by integer division by the interval length in milliseconds, then multiplied back so that each bucket value is the interval start. Rows are grouped by symbol and bucket and aggregated: the open of the earliest candle, max high, min low, the close of the latest candle, and sums of volume, number_of_trades, and taker_buy_quote_asset_volume. The bucket column is renamed back to open_time for downstream use (it now holds epoch milliseconds, whatever the input type was), and the result is ordered by symbol and open_time.

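```python
    # struct() is used below for an order-safe choice of open and close.
    from pyspark.sql.functions import struct

    bucket_ms = candle_minutes * 60 * 1000

    # Integer-divide by the interval length, then multiply back to get the bucket start.
    df = df.withColumn(
        "time_bucket",
        floor(col("open_time_ms") / lit(bucket_ms)) * lit(bucket_ms)
    )

    # first()/last() depend on row order inside each group, which Spark does not guarantee
    # after the groupBy shuffle. min/max of a struct ordered by open_time_ms picks the open
    # of the earliest candle and the close of the latest one deterministically.
    df = df.groupBy("symbol", "time_bucket").agg(
        _min(struct("open_time_ms", "open")).getField("open").alias("open"),
        _max("high").alias("high"),
        _min("low").alias("low"),
        _max(struct("open_time_ms", "close")).getField("close").alias("close"),
        _sum("volume").alias("volume"),
        _sum("number_of_trades").alias("number_of_trades"),
        _sum("taker_buy_quote_asset_volume").alias("taker_buy_quote_asset_volume")
    ).withColumnRenamed("time_bucket", "open_time").orderBy("symbol", "open_time")

    return df
```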
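The bucketing and aggregation can be checked by hand with a plain-Python sketch for one symbol. The helpers `bucket_start` and `aggregate` and the sample candles are hypothetical; only volume is summed here, and number_of_trades and taker_buy_quote_asset_volume would be summed the same way. Sorting each bucket by `open_time_ms` makes the open and close choice independent of input order, which is why the sample rows are deliberately out of order.

```python
from collections import defaultdict


def bucket_start(open_time_ms: int, candle_minutes: int = 30) -> int:
    """Start (epoch ms) of the candle_minutes interval containing open_time_ms."""
    interval_ms = candle_minutes * 60 * 1000
    return (open_time_ms // interval_ms) * interval_ms


def aggregate(candles, candle_minutes=30):
    """Group small candles (dicts) into buckets and combine them OHLCV-style."""
    groups = defaultdict(list)
    for c in candles:
        groups[bucket_start(c["open_time_ms"], candle_minutes)].append(c)
    out = []
    for start in sorted(groups):
        rows = sorted(groups[start], key=lambda c: c["open_time_ms"])
        out.append({
            "open_time": start,
            "open": rows[0]["open"],      # open of the earliest candle
            "high": max(c["high"] for c in rows),
            "low": min(c["low"] for c in rows),
            "close": rows[-1]["close"],   # close of the latest candle
            "volume": sum(c["volume"] for c in rows),
        })
    return out


minute = 60_000
candles = [
    {"open_time_ms": 1 * minute, "open": 11, "high": 13, "low": 10, "close": 12, "volume": 2},
    {"open_time_ms": 0 * minute, "open": 10, "high": 12, "low": 9, "close": 11, "volume": 1},
    {"open_time_ms": 30 * minute, "open": 12, "high": 14, "low": 11, "close": 13, "volume": 3},
]
for bar in aggregate(candles):
    print(bar)
# {'open_time': 0, 'open': 10, 'high': 13, 'low': 9, 'close': 12, 'volume': 3}
# {'open_time': 1800000, 'open': 12, 'high': 14, 'low': 11, 'close': 13, 'volume': 3}
```

With 30-minute buckets, the candles at minute 0 and minute 1 share the bucket starting at 0, while the candle at minute 30 starts the next bucket at 1800000 ms.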

## `generate_features` — Windows and lag config

Feature generation defines two window specs, both partitioned by symbol and ordered by open_time; they are identical, and `window_symbol` is only used for the previous close in the true range. The lag configuration specifies which columns get which lags: high gets lags 1–6, close 1–`lookback`, open 3, 5, and 6, and number_of_trades and taker_buy_quote_asset_volume 1–3. These lags are the base for many derived features.

```python
def generate_features(df, dataset_name="", lookback=20, generate_label=True):
    # lookback must be at least 20: the SMA-20 features below read close_lag1..close_lag20.
    print(f"Start feature generation for {dataset_name}")

    # Both specs are identical; window_symbol is used for the true-range lag further down.
    window_spec = Window.partitionBy("symbol").orderBy("open_time")
    window_symbol = Window.partitionBy("symbol").orderBy("open_time")

    # Column -> lag periods; each entry becomes a "<column>_lag<k>" column.
    lag_config = {
        "high": [1, 2, 3, 4, 5, 6],
        "close": list(range(1, lookback + 1)),
        "open": [3, 5, 6],
        "number_of_trades": [1, 2, 3],
        "taker_buy_quote_asset_volume": [1, 2, 3]
    }

    for column, lags in lag_config.items():
        for lag_period in lags:
            df = df.withColumn(
                f"{column}_lag{lag_period}",
                lag(col(column), lag_period).over(window_spec)
            )
```
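A plain-Python sketch of what this loop produces for a single symbol: the generated column names and the semantics of `lag` over time-ordered rows. The helpers `lag_column_names` and `lag_values` are hypothetical. With the default `lookback=20` the loop creates 35 lag columns, and the first `k` rows of every `*_lag{k}` column are null.

```python
from typing import Dict, List, Optional


def lag_column_names(lookback: int = 20) -> List[str]:
    """Column names produced by the lag_config loop in generate_features."""
    lag_config: Dict[str, List[int]] = {
        "high": [1, 2, 3, 4, 5, 6],
        "close": list(range(1, lookback + 1)),
        "open": [3, 5, 6],
        "number_of_trades": [1, 2, 3],
        "taker_buy_quote_asset_volume": [1, 2, 3],
    }
    return [f"{c}_lag{k}" for c, lags in lag_config.items() for k in lags]


def lag_values(values: List[float], k: int) -> List[Optional[float]]:
    """Like lag(col, k) over one symbol's ordered rows: the value k rows earlier, else None."""
    return [values[i - k] if i >= k else None for i in range(len(values))]


print(len(lag_column_names()))          # 35
print(lag_values([1.0, 2.0, 3.0], 1))   # [None, 1.0, 2.0]
```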

---

## `generate_features` — Label generation

For training, each row needs a label: UP (1) if the next close is greater than the current close, DOWN (0) otherwise, so a flat next candle counts as DOWN. `lead(close, 1)` puts the next close in a helper column, which is dropped once the label is assigned. At inference time there is no future data, so a dummy label (0) is added instead; the model expects a label column but does not use it for prediction. This also ensures that no future values are read during streaming, which avoids data leakage.

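```python
    if generate_label:
        # Training: compare each close with the next candle's close.
        df = df.withColumn("close_next", lead(col("close"), 1).over(window_spec))
        # No otherwise(): the last candle of each symbol has close_next = NULL, so its label
        # stays NULL and the later dropna() removes it instead of labelling it DOWN.
        df = df.withColumn("label",
            when(col("close_next") > col("close"), 1)
            .when(col("close_next") <= col("close"), 0)
        )
        df = df.drop("close_next")
    else:
        # Inference: no future candle is available, so add a placeholder label.
        df = df.withColumn("label", lit(0))
```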
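A plain-Python sketch of the labelling rule for one symbol's closes in time order; the helper `make_labels` is hypothetical. A flat candle (next close equal to the current one) counts as DOWN, and in this sketch the final candle, which has no next close, gets `None` rather than a class.

```python
from typing import List, Optional


def make_labels(closes: List[float]) -> List[Optional[int]]:
    """1 (UP) if the next close is higher, 0 (DOWN) otherwise; None when there is no next close."""
    labels: List[Optional[int]] = []
    for i, close in enumerate(closes):
        close_next = closes[i + 1] if i + 1 < len(closes) else None  # lead(close, 1)
        if close_next is None:
            labels.append(None)
        else:
            labels.append(1 if close_next > close else 0)
    return labels


print(make_labels([100.0, 101.0, 101.0, 99.0]))  # [1, 0, 0, None]
```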

---

## `generate_features` — Candlestick and SMA features

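Adds the candlestick shape:

- body = close − open
- range = high − low
- upper_wick = high − max(open, close)
- lower_wick = min(open, close) − low

Then computes SMA 5/10/20 from the corresponding close lags, so each SMA averages the previous N closes and excludes the current one, and expresses the price deviation from each SMA as a ratio (price_to_sma5/10/20), set to 0 when the SMA is 0.
sma_5 and sma_10 are dropped right after their ratios are computed to keep the schema small; sma_20 is kept because bb_position still needs it.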

```python
    # Candlestick shape of the current candle.
    df = df.withColumn("body", col("close") - col("open"))
    df = df.withColumn("range", col("high") - col("low"))
    df = df.withColumn(
        "upper_wick",
        col("high") - when(col("close") > col("open"), col("close")).otherwise(col("open"))
    )
    df = df.withColumn(
        "lower_wick",
        when(col("close") < col("open"), col("close")).otherwise(col("open")) - col("low")
    )

    # SMAs over the previous N closes (close_lag1..close_lagN), expressed as price ratios.
    df = df.withColumn("sma_5",
        (col("close_lag1") + col("close_lag2") + col("close_lag3") + col("close_lag4") + col("close_lag5")) / 5)
    df = df.withColumn("price_to_sma5",
        when(col("sma_5") != 0, (col("close") - col("sma_5")) / col("sma_5")).otherwise(0))
    df = df.drop("sma_5")

    close_lags_10 = [col(f"close_lag{i}") for i in range(1, 11)]
    df = df.withColumn("sma_10", reduce(lambda a, b: a + b, close_lags_10) / lit(10))
    df = df.withColumn("price_to_sma10",
        when(col("sma_10") != 0, (col("close") - col("sma_10")) / col("sma_10")).otherwise(0)
    )
    df = df.drop("sma_10")

    # sma_20 is not dropped here: bb_position below still needs it.
    close_lags_20 = [col(f"close_lag{i}") for i in range(1, 21)]
    df = df.withColumn("sma_20", reduce(lambda a, b: a + b, close_lags_20) / lit(20))
    df = df.withColumn("price_to_sma20",
        when(col("sma_20") != 0, (col("close") - col("sma_20")) / col("sma_20")).otherwise(0)
    )
```
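For a single row, the candle shape and the SMA ratios reduce to simple arithmetic. A plain-Python sketch with hypothetical helpers `candle_shape` and `price_to_sma`; `previous_closes` stands for close_lag1..close_lagN, so the current close is not part of the average.

```python
from typing import Dict, List


def candle_shape(o: float, h: float, l: float, c: float) -> Dict[str, float]:
    """Body, range and wicks of one candle."""
    return {
        "body": c - o,
        "range": h - l,
        "upper_wick": h - max(o, c),
        "lower_wick": min(o, c) - l,
    }


def price_to_sma(close: float, previous_closes: List[float]) -> float:
    """(close - SMA) / SMA, where the SMA averages the previous closes; 0 if the SMA is 0."""
    sma = sum(previous_closes) / len(previous_closes)
    return (close - sma) / sma if sma != 0 else 0.0


print(candle_shape(100, 110, 95, 105))
# {'body': 5, 'range': 15, 'upper_wick': 5, 'lower_wick': 5}
print(round(price_to_sma(105, [104, 103, 102, 101, 100]), 4))  # 0.0294
```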

---

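## `generate_features` — Momentum, volatility, Bollinger, ATR

These more complex features showed the highest feature importance during training.

- 5-period price momentum: the relative change from close_lag5.
- Volatility: a simple measure, the root mean square of the last three close-to-close returns (close_lag4 → close_lag3 → close_lag2 → close_lag1).
- Bollinger-style position: (close − sma_20) / (2 × volatility). The numerator is in price units while volatility is a dimensionless return, so unlike the standard Bollinger %b this value scales with the price level. volatility and sma_20 are dropped afterwards.
- True range: max(high − low, |high − prev_close|, |low − prev_close|), where prev_close is the previous row's close.
- ATR: the average true range over the previous 10 rows (rowsBetween(-10, -1)), which excludes the current candle. true_range is dropped once atr_10 is computed.

Finally, dropna() removes every row that still contains a null; for each symbol this removes at least the first `lookback` rows, whose close lags are null.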

```python
    # 5-period relative price change.
    df = df.withColumn("price_momentum",
        when(col("close_lag5") != 0, (col("close") - col("close_lag5")) / col("close_lag5"))
        .otherwise(0)
    )

    # Root mean square of the three most recent close-to-close returns.
    df = df.withColumn("volatility",
        sqrt((
            pow((col("close_lag1") - col("close_lag2")) / col("close_lag2"), 2) +
            pow((col("close_lag2") - col("close_lag3")) / col("close_lag3"), 2) +
            pow((col("close_lag3") - col("close_lag4")) / col("close_lag4"), 2)
        ) / 3)
    )

    # Bollinger-style position of the close relative to sma_20.
    df = df.withColumn("bb_position",
        when(col("volatility") != 0,
            (col("close") - col("sma_20")) / (2 * col("volatility"))
        ).otherwise(0)
    )
    df = df.drop("volatility", "sma_20")

    # True range; greatest() skips nulls, so the first row of each symbol uses high - low.
    df = df.withColumn("true_range",
        greatest(
            col("high") - col("low"),
            sabs(col("high") - lag(col("close"), 1).over(window_symbol)),
            sabs(col("low") - lag(col("close"), 1).over(window_symbol))
        )
    )
    # ATR over the 10 previous rows; the current row is excluded.
    window_tr = Window.partitionBy("symbol").orderBy("open_time").rowsBetween(-10, -1)
    df = df.withColumn("atr_10", avg(col("true_range")).over(window_tr))
    df = df.drop("true_range")

    # Drop rows with any remaining null (e.g. the first rows of each symbol, whose lags are null).
    df = df.dropna()
```
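A plain-Python sketch of these formulas for hand-checking values. The helper names are hypothetical, column values are passed in directly (c1 = close_lag1, c2 = close_lag2, and so on), and `tr` is one symbol's time-ordered true-range series; it is not a replacement for the window-based Spark code.

```python
import math
from typing import List, Optional


def price_momentum(close: float, close_lag5: float) -> float:
    return (close - close_lag5) / close_lag5 if close_lag5 != 0 else 0.0


def volatility(c1: float, c2: float, c3: float, c4: float) -> float:
    """Root mean square of the returns c4->c3, c3->c2, c2->c1."""
    returns = [(c1 - c2) / c2, (c2 - c3) / c3, (c3 - c4) / c4]
    return math.sqrt(sum(r * r for r in returns) / 3)


def bb_position(close: float, sma_20: float, vol: float) -> float:
    return (close - sma_20) / (2 * vol) if vol != 0 else 0.0


def true_range(high: float, low: float, prev_close: Optional[float]) -> float:
    # Mirrors greatest() skipping nulls: without a previous close, fall back to high - low.
    candidates = [high - low]
    if prev_close is not None:
        candidates += [abs(high - prev_close), abs(low - prev_close)]
    return max(candidates)


def atr_10(tr: List[float], i: int) -> Optional[float]:
    """Mean true range of rows i-10..i-1 (rowsBetween(-10, -1)); None for an empty frame."""
    window = tr[max(0, i - 10):i]
    return sum(window) / len(window) if window else None


print(price_momentum(110, 100))                   # 0.1
print(true_range(110, 100, 115))                  # 15
print(round(volatility(108.9, 99, 110, 100), 6))  # 0.1
print(atr_10(list(range(1, 13)), 11))             # 6.5
```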

---

## `generate_features` — Feature selection and vector assembly

Only a subset of the computed columns is used by the model.
SELECTED_FEATURES lists them: price_to_sma5/10/20, price_momentum, body/range/wicks, selected lags of high, close, open, number_of_trades, and taker_buy_quote_asset_volume, bb_position, and atr_10.
We keep these plus the system columns (symbol, open_time, label, close) and drop the rest to save memory, then use VectorAssembler with handleInvalid="skip" to build the feature vector; rows with null or NaN feature values are filtered out instead of raising an error.
The resulting DataFrame has the columns symbol, features, label, open_time, and close.
The function returns this DataFrame together with the SELECTED_FEATURES list.

```python
    # Features consumed by the model; the list order defines the positions in the vector.
    SELECTED_FEATURES = [
        "price_to_sma5", "price_to_sma10", "price_to_sma20",
        "price_momentum",
        "body", "range", "upper_wick", "lower_wick",
        "high_lag1", "high_lag2", "high_lag3", "high_lag4", "high_lag5", "high_lag6",
        "close_lag3", "close_lag4",
        "open_lag3", "open_lag5", "open_lag6",
        "number_of_trades_lag1", "number_of_trades_lag2", "number_of_trades_lag3",
        "taker_buy_quote_asset_volume_lag1",
        "taker_buy_quote_asset_volume_lag2",
        "taker_buy_quote_asset_volume_lag3",
        "bb_position", "atr_10"
    ]

    # Keep the selected features plus the columns needed downstream; drop everything else.
    system_cols = ["symbol", "open_time", "label", "close"]
    keep_cols = set(SELECTED_FEATURES + system_cols)
    drop_cols = [c for c in df.columns if c not in keep_cols]
    if drop_cols:
        df = df.drop(*drop_cols)

    # handleInvalid="skip" filters out rows whose selected features contain null or NaN.
    assembler = VectorAssembler(
        inputCols=SELECTED_FEATURES,
        outputCol="features",
        handleInvalid="skip"
    )
    df = assembler.transform(df)

    result_df = df.select("symbol", "features", "label", "open_time", "close")

    return result_df, SELECTED_FEATURES
```
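A plain-Python sketch of the selection and assembly step. The `assemble` helper and the three-feature list are hypothetical stand-ins for the real SELECTED_FEATURES. Values are taken in the exact order of the feature list, and any row containing a null or NaN is dropped, mirroring `handleInvalid="skip"`. The order matters because the trained model reads features by position, so training and inference must build the vector from the same list.

```python
import math
from typing import Dict, List, Optional


def assemble(rows: List[Dict[str, Optional[float]]], features: List[str]) -> List[List[float]]:
    """Dense feature vectors in `features` order; rows with a null/NaN feature are skipped."""
    vectors = []
    for row in rows:
        values = [row[name] for name in features]  # a missing column raises, as in Spark
        if any(v is None or (isinstance(v, float) and math.isnan(v)) for v in values):
            continue  # invalid row: dropped, not imputed
        vectors.append([float(v) for v in values])
    return vectors


features = ["body", "range", "atr_10"]  # short stand-in for SELECTED_FEATURES
rows = [
    {"body": 5.0, "range": 15.0, "atr_10": 6.5, "close": 105.0},
    {"body": 1.0, "range": 2.0, "atr_10": None},
    {"body": 2.0, "range": float("nan"), "atr_10": 3.0},
]
print(assemble(rows, features))  # [[5.0, 15.0, 6.5]]
```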