# Real-time inference

This document explains each section of `src/real_time_inference.py`. The module consumes candles from Kafka, keeps a sliding window of 21×30m candles, generates features, runs the stacked ML models, and outputs trading signals. It also backfills the previous prediction’s actual label when the next candle arrives.

## Imports and shared utils

The module imports the Spark session, the RF/GBT/LR model classes, `VectorAssembler`, and `deque` for the sliding window. The try/except around the `shared_utils` import lets the module work whether it is run from the project root or from `src/`.

In [1]:
import os
from collections import deque
from typing import Dict, List, Optional
import logging

from pyspark.sql import SparkSession
from pyspark.sql.functions import col
from pyspark.ml.classification import (
    RandomForestClassificationModel,
    GBTClassificationModel,
    LogisticRegressionModel)
from pyspark.ml.feature import VectorAssembler
from pyspark.sql.types import DoubleType

try:
    from shared_utils import generate_features, aggregate_candles, extract_prob_udf
except ImportError:
    from src.shared_utils import generate_features, aggregate_candles, extract_prob_udf

## Logging and constants

Logging is configured for the module. Constants define the candle timeframe (30m), lookback length (20), and window size (21 = lookback + 1). The extra candle is needed because features use lagged values.

In [2]:
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

CANDLE_MINUTES = 30
LOOKBACK = 20
WINDOW_SIZE = LOOKBACK + 1
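The real features are built in `generate_features`, which is not shown here. As a rough illustration of why the window holds `LOOKBACK + 1` candles, assume (hypothetically) that the features include close-to-close returns. Each return needs the previous close, so 20 returns need 21 closes. `lagged_returns` is an illustrative helper, not part of the module.

```python
from typing import List

LOOKBACK = 20
WINDOW_SIZE = LOOKBACK + 1


def lagged_returns(closes: List[float], lookback: int = LOOKBACK) -> List[float]:
    """Return the last `lookback` simple returns (hypothetical feature, for illustration)."""
    if len(closes) < lookback + 1:
        raise ValueError(f"need {lookback + 1} closes, got {len(closes)}")
    window = closes[-(lookback + 1):]
    # Each return compares a close with the close one step earlier,
    # so the first close in the window only serves as a reference point.
    return [(curr - prev) / prev for prev, curr in zip(window[:-1], window[1:])]


closes = [100.0 + i for i in range(WINDOW_SIZE)]  # 21 made-up closes
print(len(lagged_returns(closes)))  # 20
```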

## Class `CandlesProcessor`

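The constructor initializes Spark and a minute buffer. The buffer is used only for 1m aggregation, so it stays empty when 30m candles are fed. It also creates a deque capped at 21 candles for the 30m window, stores the predictions output path, and initializes `pending_prediction` for backfilling. It then loads the four trained models (RF, GBT, LR, and the meta-learner) from `model_dir`. Note that `predictions_path` is created on disk but is not used by the methods shown below. Predictions are logged to `data/predictions.csv` instead.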

In [3]:
class CandlesProcessor:
    def __init__(self, model_dir: str = "models", predictions_dir: str = "data/live_predictions_parquet"):
        self.spark = self._initialize_spark()

        self.minute_buffer: List[Dict] = []
        self.window_30m = deque(maxlen=WINDOW_SIZE)

        self.predictions_path = predictions_dir
        self.pending_prediction: Optional[Dict] = None

        logger.info("Loading trained models")
        self.rf_model = RandomForestClassificationModel.load(
            os.path.join(model_dir, "stacked_rf_model")
        )
        self.gbt_model = GBTClassificationModel.load(
            os.path.join(model_dir, "stacked_gbt_model")
        )
        self.lr_model = LogisticRegressionModel.load(
            os.path.join(model_dir, "stacked_lr_model")
        )
        self.meta_model = LogisticRegressionModel.load(
            os.path.join(model_dir, "stacked_meta_model")
        )
        logger.info("All models loaded successfully")

        os.makedirs(predictions_dir, exist_ok=True)

---

## `_initialize_spark`

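Builds the Spark session for real-time inference with the following settings:

- adaptive query execution enabled;
- 8 shuffle partitions;
- 2g of driver and executor memory;
- whole-stage code generation disabled, with `spark.sql.codegen.factoryMode` set to `NO_CODEGEN`, to avoid compatibility issues.

`getOrCreate()` returns the running session if one already exists. In that case, the app name and the memory settings cannot change it. The standalone entry point at the end of this document creates its own session before the processor is built.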

In [4]:
def _initialize_spark(self) -> SparkSession:
        return (SparkSession.builder
            .appName("CryptoTradingSignals_RealTime")
            .config("spark.sql.adaptive.enabled", "true")
            .config("spark.sql.shuffle.partitions", "8")
            .config("spark.driver.memory", "2g")
            .config("spark.executor.memory", "2g")
            .config("spark.sql.codegen.wholeStage", "false")
            .config("spark.sql.codegen.factoryMode", "NO_CODEGEN")
            .getOrCreate())

---

## `add_1minute_candle` and `add_candle`

`add_candle` is the entry point for every incoming candle. It reads the candle's `interval` (defaulting to `'1m'`) and sends 30m candles straight to `add_30minute_candle`. Anything else goes to `add_1minute_candle`, which appends the candle to the minute buffer. Once 30 candles have been collected, the buffer is aggregated into a single 30m candle, appended to the rolling window, and cleared.

In [6]:
def add_1minute_candle(self, candle: Dict):
        self.minute_buffer.append(candle)

        # Aggregate once a full 30m period of 1m candles is buffered
        if len(self.minute_buffer) >= CANDLE_MINUTES:
            self._process_30minute_candle()
            self.minute_buffer.clear()

def add_candle(self, candle: Dict):
        # Route by interval: 30m candles go straight to the window, anything else is buffered as 1m
        interval = candle.get('interval', '1m')

        if interval == '30m':
            self.add_30minute_candle(candle)
        else:
            self.add_1minute_candle(candle)
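The aggregation step can be sketched in plain Python. This sketch assumes the standard OHLCV rules: first open, highest high, lowest low, last close, and summed volumes and trade counts. `aggregate_candles` in `shared_utils` is not shown here and may differ; for example, it may group by time bucket. `aggregate_minutes` is a hypothetical name.

```python
from typing import Dict, List

CANDLE_MINUTES = 30


def aggregate_minutes(candles: List[Dict]) -> Dict:
    """Collapse consecutive 1m candles into one candle using standard OHLCV rules."""
    if not candles:
        raise ValueError("no candles to aggregate")
    ordered = sorted(candles, key=lambda c: c["open_time"])  # guard against out-of-order arrival
    return {
        "open_time": ordered[0]["open_time"],
        "symbol": ordered[0]["symbol"],
        "open": ordered[0]["open"],
        "high": max(c["high"] for c in ordered),
        "low": min(c["low"] for c in ordered),
        "close": ordered[-1]["close"],
        "volume": sum(c["volume"] for c in ordered),
        "number_of_trades": sum(c["number_of_trades"] for c in ordered),
        "taker_buy_quote_asset_volume": sum(c["taker_buy_quote_asset_volume"] for c in ordered),
        "interval": f"{CANDLE_MINUTES}m",
    }


BASE = 1_699_999_200_000  # made-up open_time in ms, aligned to a 30m boundary
minute_candles = [
    {
        "open_time": BASE + i * 60_000,
        "symbol": "BTCUSDT",
        "open": 100.0 + i,
        "high": 101.0 + i,
        "low": 99.0 + i,
        "close": 100.5 + i,
        "volume": 1.0,
        "number_of_trades": 10,
        "taker_buy_quote_asset_volume": 50.0,
    }
    for i in range(CANDLE_MINUTES)
]
bar = aggregate_minutes(minute_candles)
print(bar["open"], bar["high"], bar["low"], bar["close"], bar["volume"])
# 100.0 130.0 99.0 129.5 30.0
```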

---

## `add_30minute_candle`

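Handles 30m candles, whether they are bootstrapped from history or arrive live. The method converts the candle dict into a Spark `Row` so that it matches the rows appended by the aggregation path. It then appends the row to the 30m deque, logs it, and triggers a prediction once the window holds 21 candles. The deque was created with `maxlen=WINDOW_SIZE`, so appending to a full window drops the oldest candle (first in, first out). As a result, after the warm-up every new candle triggers a prediction.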

In [None]:
def add_30minute_candle(self, candle: Dict):
        from pyspark.sql import Row

        candle_row = Row(
            open_time=candle['open_time'],
            symbol=candle['symbol'],
            open=candle['open'],
            high=candle['high'],
            low=candle['low'],
            close=candle['close'],
            volume=candle['volume'],
            number_of_trades=candle['number_of_trades'],
            taker_buy_quote_asset_volume=candle['taker_buy_quote_asset_volume']
        )

        self.window_30m.append(candle_row)

        logger.info(
            f"Added 30m candle: {candle['symbol']} "
            f"@ {candle['open_time']} | Close: {candle['close']:.2f} "
            f"(Window: {len(self.window_30m)}/{WINDOW_SIZE})"
        )

        if len(self.window_30m) == WINDOW_SIZE:
            self._predict()
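The warm-up and eviction behavior of `deque(maxlen=WINDOW_SIZE)` can be checked in isolation. In this sketch, integers stand in for candle rows, and a hypothetical `on_candle` function stands in for the append-then-maybe-predict logic.

```python
from collections import deque

WINDOW_SIZE = 21
window = deque(maxlen=WINDOW_SIZE)
predict_calls = []  # (oldest, newest) candle id seen at each prediction


def on_candle(candle_id: int) -> None:
    """Mirror add_30minute_candle: append, then predict once the window is full."""
    window.append(candle_id)  # a full deque silently drops its oldest item
    if len(window) == WINDOW_SIZE:
        predict_calls.append((window[0], window[-1]))


for candle_id in range(25):
    on_candle(candle_id)

print(predict_calls)
# [(0, 20), (1, 21), (2, 22), (3, 23), (4, 24)]
```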

---

## `_process_30minute_candle`

Builds a DataFrame from the 1m buffer and aggregates it to 30m with `aggregate_candles`. It takes the first resulting row, since the buffer is expected to produce exactly one 30m candle. It then appends that row to the window and runs a prediction if the window is full. This path is used only when 1m candles are fed.

In [None]:
def _process_30minute_candle(self):
        df_1min = self.spark.createDataFrame(self.minute_buffer)
        df_30m = aggregate_candles(df_1min)
        new_30m_row = df_30m.collect()[0]
        self.window_30m.append(new_30m_row)

        logger.info(
            f"Aggregated 30m candle: {new_30m_row['symbol']} "
            f"@ {new_30m_row['open_time']} | Close: {new_30m_row['close']:.2f}"
        )

        if len(self.window_30m) == WINDOW_SIZE:
            self._predict()
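`collect()[0]` is only safe if the 30 buffered 1m candles fall into a single 30m bucket. Because the buffer is flushed by count, a stream that starts mid-period produces a buffer that straddles two buckets. The sketch below checks for this. It assumes that `open_time` is a Unix timestamp in milliseconds, as the `/ 1000` conversion in `_log_prediction` suggests, and that buckets start on multiples of 30 minutes. The helper names are hypothetical.

```python
from typing import Dict, List

BUCKET_MS = 30 * 60 * 1000  # 30 minutes in milliseconds


def bucket_start(open_time_ms: int) -> int:
    """Floor a millisecond timestamp to the start of its 30m bucket."""
    return open_time_ms - open_time_ms % BUCKET_MS


def single_bucket(buffer: List[Dict]) -> bool:
    """True if every buffered 1m candle belongs to the same 30m bucket."""
    return len({bucket_start(c["open_time"]) for c in buffer}) == 1


aligned = [{"open_time": 1_699_999_200_000 + i * 60_000} for i in range(30)]
shifted = [{"open_time": 1_699_999_200_000 + (i + 5) * 60_000} for i in range(30)]
print(single_bucket(aligned), single_bucket(shifted))  # True False
```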

---

## `_predict` — Core prediction pipeline

Turns the 21-candle window into a DataFrame and takes the latest candle. If a `pending_prediction` exists, it first backfills that row's actual label using the current close. It then generates features with `generate_label=False`, because no future data is available at inference time. If no rows survive feature generation, it returns early.

Otherwise, it runs the RF, GBT, and LR base models and extracts `rf_prob`, `gbt_prob`, and `lr_prob` with `extract_prob_udf`. It joins the three on `open_time`, assembles them into `meta_features`, and runs the meta-model to obtain `final_prob` and `prediction`. Finally, it determines the trading signal, prints it, and logs the prediction to CSV. `_log_prediction` also stores the prediction in `pending_prediction` for the next backfill.

In [None]:
def _predict(self):
        df_window = self.spark.createDataFrame(list(self.window_30m))
        current_candle = self.window_30m[-1]

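        if self.pending_prediction is not None:
            self._backfill_actual_label(
                self.pending_prediction['timestamp'],
                self.pending_prediction['close'],
                float(current_candle['close'])
            )
            # Clear it so an early return below cannot backfill the same row twice
            self.pending_prediction = None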

        df_features, _ = generate_features(
            df_window,
            dataset_name="LIVE_STREAM",
            generate_label=False
        )

        if df_features.count() == 0:
            logger.warning("No features generated - skipping prediction")
            return

        rf_res = self.rf_model.transform(df_features).withColumn(
            "rf_prob", extract_prob_udf(col("probability"))
        )
        gbt_res = self.gbt_model.transform(df_features).withColumn(
            "gbt_prob", extract_prob_udf(col("probability"))
        )
        lr_res = self.lr_model.transform(df_features).withColumn(
            "lr_prob", extract_prob_udf(col("probability"))
        )

        meta_df = (rf_res.select("open_time", "symbol", "close", "rf_prob")
            .join(gbt_res.select("open_time", "gbt_prob"), "open_time")
            .join(lr_res.select("open_time", "lr_prob"), "open_time"))

        meta_assembler = VectorAssembler(
            inputCols=["rf_prob", "gbt_prob", "lr_prob"],
            outputCol="meta_features"
        )
        meta_df = meta_assembler.transform(meta_df)

        final_prediction = self.meta_model.transform(meta_df)
        final_prediction = final_prediction.withColumn(
            "final_prob", extract_prob_udf(col("probability"))
        )

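        # Take the newest candle's row explicitly; joins do not preserve row order
        result = final_prediction.orderBy(col("open_time").desc()).first()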
        prob_up = float(result['final_prob'])
        prediction = float(result['prediction'])

        signal, confidence = self._determine_signal(prob_up, prediction)

        self._print_trade_signal(result['symbol'], signal, prob_up, confidence)

        self._log_prediction(result, prob_up, prediction, current_candle)

        logger.info(
            f"Prediction: {result['symbol']} | Prob UP: {prob_up:.4f} | "
            f"Signal: {signal} | Confidence: {confidence}"
        )
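Numerically, the meta-model step is a logistic regression over the three base probabilities. The sketch below illustrates it with made-up coefficients; the real ones are stored in `stacked_meta_model`. It assumes that `extract_prob_udf` returns P(class = 1), i.e. UP, from Spark's `[P(0), P(1)]` probability vector. It also uses a 0.5 decision threshold, which is Spark's default for `LogisticRegression`.

```python
import math
from typing import Sequence

# Hypothetical meta-model parameters, for illustration only.
META_WEIGHTS = (1.5, 2.0, 1.0)  # applied to rf_prob, gbt_prob, lr_prob
META_INTERCEPT = -2.25


def extract_up_prob(probability: Sequence[float]) -> float:
    """Take P(class=1) from a [P(0), P(1)] probability vector."""
    return float(probability[1])


def meta_predict(rf_prob: float, gbt_prob: float, lr_prob: float, threshold: float = 0.5):
    """Logistic regression on the stacked base-model probabilities."""
    z = META_INTERCEPT + sum(w * p for w, p in zip(META_WEIGHTS, (rf_prob, gbt_prob, lr_prob)))
    final_prob = 1.0 / (1.0 + math.exp(-z))
    prediction = 1.0 if final_prob > threshold else 0.0  # class 1 only above the threshold
    return final_prob, prediction


rf_prob = extract_up_prob([0.4, 0.6])
gbt_prob = extract_up_prob([0.3, 0.7])
lr_prob = extract_up_prob([0.45, 0.55])
final_prob, prediction = meta_predict(rf_prob, gbt_prob, lr_prob)
print(round(final_prob, 4), prediction)  # 0.6457 1.0
```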

---

## `_backfill_actual_label`

When a new candle arrives, the outcome of the previous prediction becomes known: did the close rise compared with the close at prediction time? This method sets `actual_label` to 1 if the current close is higher than the previous close, and to 0 otherwise, so an unchanged close counts as DOWN. It compares that label with `pending_prediction['prediction']` and writes `actual_label` and `is_correct` into the matching row of `data/predictions.csv`. A CSV row cannot be updated in place, so the method reads the whole file, updates that row, and rewrites the file.

In [None]:
def _backfill_actual_label(self, prev_timestamp: int, prev_close: float, current_close: float):
        import csv

        try:
            actual_label = 1 if current_close > prev_close else 0

            predicted = self.pending_prediction.get('prediction', None)
            row_num = self.pending_prediction.get('row_number', None)

            if predicted is not None and row_num is not None:
                is_correct = 1 if predicted == actual_label else 0

                predicted_str = "UP" if predicted == 1 else "DOWN"
                actual_str = "UP" if actual_label == 1 else "DOWN"

                if is_correct == 1:
                    logger.info(f"Previous prediction correct: Predicted {predicted_str}, Actual {actual_str}")
                else:
                    logger.info(f"Previous prediction wrong: Predicted {predicted_str}, Actual {actual_str}")

                log_file = "data/predictions.csv"

                with open(log_file, 'r') as f:
                    rows = list(csv.reader(f))

                if row_num < len(rows):
                    rows[row_num][6] = str(actual_label)
                    rows[row_num][7] = str(is_correct)

                with open(log_file, 'w', newline='') as f:
                    writer = csv.writer(f)
                    writer.writerows(rows)

                logger.info(f"Backfilled row {row_num} in CSV")

        except Exception as e:
            logger.error(f"Error backfilling: {e}")
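The labeling and the row arithmetic can be exercised without Spark. The sketch below uses an in-memory CSV. After a data row is appended, the number of data rows equals that row's index in a `csv.reader` list that still contains the header, which is why `row_number` can index `rows` directly. Column positions 6 and 7 follow the header written by `_log_prediction`. The helper names and the sample row are hypothetical.

```python
import csv
import io

# Column order written by _log_prediction
HEADER = ['timestamp', 'datetime', 'symbol', 'close',
          'final_prob', 'prediction', 'actual_label', 'is_correct']


def count_data_rows(text: str) -> int:
    """Same idea as _count_csv_rows: total lines minus the header."""
    return sum(1 for _ in io.StringIO(text)) - 1


def backfill(text: str, row_num: int, predicted: int,
             prev_close: float, current_close: float) -> str:
    """Fill actual_label and is_correct for one row; return the rewritten CSV text."""
    actual_label = 1 if current_close > prev_close else 0  # an unchanged close counts as DOWN
    is_correct = 1 if predicted == actual_label else 0
    rows = list(csv.reader(io.StringIO(text)))  # rows[0] is the header
    if row_num < len(rows):
        rows[row_num][6] = str(actual_label)
        rows[row_num][7] = str(is_correct)
    out = io.StringIO()
    csv.writer(out).writerows(rows)
    return out.getvalue()


# One made-up prediction row, written the way _log_prediction writes it
buf = io.StringIO()
writer = csv.writer(buf)
writer.writerow(HEADER)
writer.writerow([1699999200000, '2023-11-14 22:00:00', 'BTCUSDT', 100.0, 0.72, 1, '', ''])
text = buf.getvalue()

row_num = count_data_rows(text)  # 1 data row, which is also index 1 in csv.reader output
updated = backfill(text, row_num, predicted=1, prev_close=100.0, current_close=101.5)
print(list(csv.reader(io.StringIO(updated)))[row_num][6:])  # ['1', '1']
```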

---

## `_determine_signal`

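Maps the meta-model's UP probability and predicted class to a human-readable signal and confidence level:

- If `prob_up` is ≥ 0.7 or ≤ 0.3, confidence is VERY HIGH.
- Otherwise, if it is ≥ 0.6 or ≤ 0.4, confidence is HIGH.
- Anything strictly between 0.4 and 0.6 gets Low confidence and a WAIT signal.

In the first two cases, the direction comes from the predicted class: 1 maps to BUY (Long) and 0 to SELL (Short).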

In [None]:
def _determine_signal(self, prob_up: float, prediction: float) -> tuple:
        if prob_up >= 0.7 or prob_up <= 0.3:
            confidence = "VERY HIGH"
            signal = "BUY (Long)" if prediction == 1.0 else "SELL (Short)"
        elif prob_up >= 0.6 or prob_up <= 0.4:
            confidence = "HIGH"
            signal = "BUY (Long)" if prediction == 1.0 else "SELL (Short)"
        else:
            confidence = "Low"
            signal = "WAIT"

        return signal, confidence
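Both bands use inclusive comparisons, so boundary values land in the stronger band. The quick check below uses a standalone copy of the method (with `self` dropped) and a few sample probabilities.

```python
def determine_signal(prob_up: float, prediction: float) -> tuple:
    """Standalone copy of CandlesProcessor._determine_signal."""
    if prob_up >= 0.7 or prob_up <= 0.3:
        confidence = "VERY HIGH"
        signal = "BUY (Long)" if prediction == 1.0 else "SELL (Short)"
    elif prob_up >= 0.6 or prob_up <= 0.4:
        confidence = "HIGH"
        signal = "BUY (Long)" if prediction == 1.0 else "SELL (Short)"
    else:
        confidence = "Low"
        signal = "WAIT"
    return signal, confidence


for prob, pred in [(0.70, 1.0), (0.60, 1.0), (0.55, 1.0), (0.40, 0.0), (0.30, 0.0)]:
    print(prob, *determine_signal(prob, pred))
```

In the BUY/SELL bands, the direction comes from `prediction`, not from `prob_up` itself. The two agree as long as `final_prob` is the meta-model's P(UP) and the model uses its default 0.5 threshold.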

---

## `_log_prediction` and `_count_csv_rows`

`_log_prediction` ensures that `data/predictions.csv` exists, creating it with a header row on first use. For each prediction, it appends one row containing the timestamp, datetime string, symbol, close, final_prob, and prediction. The `actual_label` and `is_correct` columns are left empty for `_backfill_actual_label` to fill in later.

It then sets `pending_prediction` with the timestamp, close, prediction, and row number, so the next prediction can update that row. The row number comes from `_count_csv_rows`, which returns the number of data rows (total lines minus the header). Right after the append, that count is also the new row's index in a `csv.reader` list that includes the header.

In [None]:
def _log_prediction(self, result, prob_up: float, prediction: float, current_candle):
        import csv
        from datetime import datetime

        log_file = "data/predictions.csv"

        if not os.path.exists(log_file):
            os.makedirs("data", exist_ok=True)
            with open(log_file, 'w', newline='') as f:
                writer = csv.writer(f)
                writer.writerow([
                    'timestamp', 'datetime', 'symbol', 'close',
                    'final_prob', 'prediction', 'actual_label', 'is_correct'
                ])

        timestamp = result['open_time']
        dt_str = datetime.fromtimestamp(timestamp / 1000).strftime('%Y-%m-%d %H:%M:%S')

        with open(log_file, 'a', newline='') as f:
            writer = csv.writer(f)
            writer.writerow([
                timestamp,
                dt_str,
                result['symbol'],
                float(result['close']),
                float(prob_up),
                int(prediction),
                '',
                ''
            ])

        logger.info(f"Logged prediction to {log_file}")

        self.pending_prediction = {
            'timestamp': result['open_time'],
            'close': float(current_candle['close']),
            'prediction': int(prediction),
            'row_number': self._count_csv_rows(log_file)
        }

def _count_csv_rows(self, filepath):
        # Data rows only: total lines minus the header
        with open(filepath, 'r') as f:
            return sum(1 for _ in f) - 1
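`open_time` is in milliseconds, hence the division by 1000. Called without a `tz` argument, `datetime.fromtimestamp` renders local time, so the `datetime` column depends on the host's timezone. Below is a minimal timezone-explicit variant. Choosing UTC is an assumption here, not what the module currently does, and `format_open_time` is a hypothetical name.

```python
from datetime import datetime, timezone


def format_open_time(open_time_ms: int) -> str:
    """Render a millisecond Unix timestamp as a UTC string."""
    return datetime.fromtimestamp(open_time_ms / 1000, tz=timezone.utc).strftime('%Y-%m-%d %H:%M:%S')


print(format_open_time(1_700_000_000_000))  # 2023-11-14 22:13:20
```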

---

## `_print_trade_signal`

Prints the symbol, action (BUY/SELL/WAIT), probability, and confidence. Uses ANSI escape codes: green for BUY, red for SELL, yellow for WAIT.

In [None]:
def _print_trade_signal(self, symbol: str, signal: str, prob: float, confidence: str):
        if "BUY" in signal:
            color = "\033[92m"
        elif "SELL" in signal:
            color = "\033[91m"
        else:
            color = "\033[93m"
        reset = "\033[0m"

        print(f"Signal: {symbol}")
        print(f"Action: {color}{signal}{reset}")
        print(f"Probability: {prob:.4f}")
        print(f"Confidence Level: {confidence}")

---

## `run_stack_inference`

Creates a single `CandlesProcessor` and defines `batch_function`. For each micro-batch, `batch_function` collects the rows and passes each candle to `processor.add_candle` as a dict. `foreachBatch` functions run on the driver, so the processor's in-memory window persists across micro-batches, though not across restarts. The streaming query is triggered every 10 seconds and checkpoints to a fixed directory (`./binance_kline_chk`). The function returns the started query so the caller can await termination.

In [None]:
def run_stack_inference(parsed_streaming_df):
    processor = CandlesProcessor()

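    def batch_function(batch_df, batch_id):
        # Runs on the driver, so `processor` keeps its window between micro-batches.
        # Sort by open_time: rows from different Kafka partitions arrive in no global order.
        records = batch_df.orderBy("open_time").collect()
        if not records:
            return

        for row in records:
            processor.add_candle(row.asDict())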

    query = (parsed_streaming_df.writeStream
        .foreachBatch(batch_function)
        .option("checkpointLocation", "./binance_kline_chk")
        .trigger(processingTime="10 seconds")
        .start())

    logger.info("Stream started, waiting for candles")
    logger.info("Note: First prediction will have no backfilled prediction")
    logger.info("Subsequent predictions will show if previous prediction was correct")
    return query
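The processor assumes that candles reach `add_candle` in chronological order. If the topic has more than one partition, however, the rows collected from a micro-batch come back partition by partition rather than in time order. The pure-Python sketch below shows the effect on the window. Plain dicts stand in for Spark rows, and `feed` is a hypothetical stand-in for `batch_function`.

```python
from collections import deque
from typing import Dict, Iterable, List


def feed(batch: Iterable[Dict], window_size: int = 3, sort_by_time: bool = True) -> List[int]:
    """Push one micro-batch into a fresh window and return the open_times it holds."""
    window = deque(maxlen=window_size)
    rows = sorted(batch, key=lambda r: r["open_time"]) if sort_by_time else list(batch)
    for row in rows:
        window.append(row["open_time"])
    return list(window)


# Rows as collect() might return them: partition 0 (open_time 1, 3), then partition 1 (2, 4)
batch = [{"open_time": 1}, {"open_time": 3}, {"open_time": 2}, {"open_time": 4}]
print(feed(batch, sort_by_time=False))  # [3, 2, 4]  (window out of order)
print(feed(batch))                      # [2, 3, 4]
```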

---

## `if __name__ == "__main__"` — Standalone execution

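When the file is run directly, it builds a Spark session, defines the schema for the Kafka JSON payload, and reads the `binance_kline` topic starting from the earliest offset. It then parses each message value with `from_json` and runs `run_stack_inference` on the parsed stream. `startingOffsets` only applies to a fresh query: once `./binance_kline_chk` holds a checkpoint, a restart resumes from the checkpointed offsets. Reading from Kafka also requires the `spark-sql-kafka-0-10` connector on the classpath, for example via `spark-submit --packages`.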

In [None]:
if __name__ == "__main__":
    from pyspark.sql import SparkSession
    from pyspark.sql.types import StructType, StructField, StringType, DoubleType, LongType, IntegerType

    spark = SparkSession.builder.appName("CryptoSignals").getOrCreate()

    schema = StructType([
        StructField("open_time", LongType(), True),
        StructField("open", DoubleType(), True),
        StructField("high", DoubleType(), True),
        StructField("low", DoubleType(), True),
        StructField("close", DoubleType(), True),
        StructField("volume", DoubleType(), True),
        StructField("quote_asset_volume", DoubleType(), True),
        StructField("number_of_trades", IntegerType(), True),
        StructField("taker_buy_base_asset_volume", DoubleType(), True),
        StructField("taker_buy_quote_asset_volume", DoubleType(), True),
        StructField("symbol", StringType(), True),
        StructField("interval", StringType(), True),
        StructField("ingested_at", StringType(), True)
    ])

    df_stream = (spark.readStream
        .format("kafka")
        .option("kafka.bootstrap.servers", "localhost:9092")
        .option("subscribe", "binance_kline")
        .option("startingOffsets", "earliest")
        .load())

    from pyspark.sql.functions import from_json

    df_parsed = df_stream.select(
        from_json(col("value").cast("string"), schema).alias("data")
    ).select("data.*")

    query = run_stack_inference(df_parsed)

    try:
        query.awaitTermination()
    except KeyboardInterrupt:
        query.stop()
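For reference, a message on `binance_kline` is expected to look roughly like the JSON below. The field names follow the schema above, but every value is made up. Parsing the message with the standard library confirms that it covers every field in the schema. Any field missing from a real message would come back as null after `from_json`.

```python
import json

SCHEMA_FIELDS = [
    "open_time", "open", "high", "low", "close", "volume",
    "quote_asset_volume", "number_of_trades", "taker_buy_base_asset_volume",
    "taker_buy_quote_asset_volume", "symbol", "interval", "ingested_at",
]

# Hypothetical message value; numbers and timestamps are illustrative only.
message = """{
  "open_time": 1699999200000, "open": 36500.0, "high": 36620.5, "low": 36480.1,
  "close": 36590.2, "volume": 812.4, "quote_asset_volume": 29712345.6,
  "number_of_trades": 24310, "taker_buy_base_asset_volume": 401.7,
  "taker_buy_quote_asset_volume": 14689012.3, "symbol": "BTCUSDT",
  "interval": "30m", "ingested_at": "2023-11-14T22:30:01Z"
}"""

record = json.loads(message)
missing = [field for field in SCHEMA_FIELDS if field not in record]
print(missing)  # []
```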