In [0]:
%run "/Users/bartoszgardzinski1@gmail.com/DE_Databricks_1stpipe/utils/technical_indicators"

In [0]:
dbutils.widgets.text(name="env", defaultValue="", label="Enter the environment in lower case")
# Table names below are built as f"{env}_silver..." / f"{env}_gold...", so normalise
# the widget value (the label asks for lower case) and fail fast when it is missing;
# an empty value would otherwise produce names like "_silver.daily_price_aggregates".
env = dbutils.widgets.get("env").strip().lower()
if not env:
    raise ValueError("The 'env' widget is empty; set it to the target environment (e.g. 'dev').")

from pyspark.sql.functions import (
    lit, avg, col, row_number, when, stddev, lag, sum as spark_sum, max as spark_max
)
from pyspark.sql.window import Window
from datetime import datetime, timedelta
from functools import reduce
from pyspark.sql.functions import date_sub, current_date
import time
from delta.tables import DeltaTable


# Source (silver) and destination (gold) tables for the selected environment.
input_table = f"{env}_silver.daily_price_aggregates"
output_table = f"{env}_gold.daily_price_indicators"

# Look-back periods for the windowed indicators, kept in longest-first order.
windows = [60, 26, 12]

# Columns written to the gold table, in target order: the base price columns,
# one column per indicator and period (periods ascending), the Bollinger band
# triplet per period, then the MACD family and OBV.
output_columns = [
    "symbol", "date", "open", "day_high", "day_low", "avg_price", "volume", "year", "month", "day",
    *(f"{kind}_{n}" for kind in ("ma", "ema", "rsi") for n in sorted(windows)),
    *(f"bb_{band}_{n}" for n in sorted(windows) for band in ("upper", "lower", "width")),
    "macd", "signal", "macd_hist", "obv",
]

def read_prices(input_table, lookback_days=90):
    """Load the recent price history and rank each symbol's rows newest-first.

    The scan is first restricted on the `year`/`month` columns to the months that
    can contain dates in [today - lookback_days, today], and then cut off exactly
    on `date`.

    Args:
        input_table: Name of the source table; must expose `symbol`, `date`,
            `year` and `month` columns.
        lookback_days: Size of the history window in calendar days (default 90).

    Returns:
        DataFrame with every column of `input_table` for rows with
        `date >= today - lookback_days`, plus `rn`: the row number within each
        `symbol`, ordered by `date` descending (`rn == 1` is the newest row).
    """
    # Evaluate "today" once so the month bounds and the date cut-off always agree,
    # even if the cell happens to run across midnight.
    today = datetime.today().date()
    start_date = today - timedelta(days=lookback_days)

    # year/month are presumably partition columns (TODO confirm), so a tight
    # predicate on them lets Spark skip files before the exact date filter.
    if start_date.year == today.year:
        # Same calendar year: one contiguous month range. The three-way OR used
        # below would collapse to "every month of the year" in this case.
        month_filter = (col("year") == today.year) & col("month").between(start_date.month, today.month)
    else:
        month_filter = (
            ((col("year") == start_date.year) & (col("month") >= start_date.month))
            | ((col("year") > start_date.year) & (col("year") < today.year))
            | ((col("year") == today.year) & (col("month") <= today.month))
        )

    df_recent = spark.table(input_table).filter(month_filter)
    df_filtered = df_recent.filter(col("date") >= lit(start_date.strftime("%Y-%m-%d")))

    window_spec = Window.partitionBy("symbol").orderBy(col("date").desc())
    df_ranked = df_filtered.withColumn("rn", row_number().over(window_spec))

    return df_ranked


def take_latest_row(table_path, granulity, lookback_days=5):
    """Return the most recent row per symbol, with non-key columns prefixed `prev_`.

    Args:
        table_path: Name of the table to read (passed to `spark.read.table`).
        granulity: "daily" (rows ordered by `date`) or "hourly" (rows ordered by
            `timestamp`). The parameter name is kept as-is for existing callers.
        lookback_days: Only rows whose date/timestamp column is on or after
            `current_date() - lookback_days` are considered (default 5). Symbols
            with no row in that window are absent from the result.

    Returns:
        DataFrame with one row per symbol found in the window: `symbol`
        unchanged and every other column renamed to `prev_<name>`.

    Raises:
        ValueError: If `granulity` is neither "daily" nor "hourly" (previously
            this surfaced as an UnboundLocalError).
    """
    date_columns = {"daily": "date", "hourly": "timestamp"}
    if granulity not in date_columns:
        raise ValueError(f"granulity must be 'daily' or 'hourly', got {granulity!r}")
    datecol = date_columns[granulity]

    cutoff = date_sub(current_date(), lookback_days)
    df_recent = spark.read.table(table_path).filter(col(datecol) >= cutoff)

    # Newest row first within each symbol; keep only that row.
    window_spec = Window.partitionBy("symbol").orderBy(col(datecol).desc())
    df_latest = (
        df_recent
        .withColumn("rn", row_number().over(window_spec))
        .filter(col("rn") == 1)
        .drop("rn")
    )

    # Rename positionally in one projection instead of one withColumnRenamed
    # call per column; `symbol` stays as the join key.
    renamed = [c if c == "symbol" else f"prev_{c}" for c in df_latest.columns]
    return df_latest.toDF(*renamed)

def write_indicators(df_result, path):
    """Upsert indicator rows into the Delta table `path`, keyed on (symbol, date).

    Rows matching an existing (symbol, date) are overwritten with all source
    columns; all other rows are inserted.

    Args:
        df_result: DataFrame whose columns match the target table.
        path: Name of the target Delta table (resolved with `DeltaTable.forName`).
    """
    # Use the argument, not the module-level `output_table`, so the function
    # writes to the table the caller asked for.
    target_table = DeltaTable.forName(spark, path)

    (
        target_table.alias("target")
        .merge(
            source=df_result.alias("source"),
            condition="target.symbol = source.symbol AND target.date = source.date",
        )
        .whenMatchedUpdateAll()
        .whenNotMatchedInsertAll()
        .execute()
    )



# Load the look-back window once and cache it: several indicator steps read it.
df_filtered = read_prices(input_table)
df_filtered.cache()

try:
    # Newest row per symbol; the indicator columns are added to these rows.
    df_result = df_filtered.filter(col("rn") == 1).drop("rn")
    # Attach the latest row already in the gold table as prev_* columns.
    latest_records = take_latest_row(output_table, "daily")
    df_result = df_result.join(latest_records, on="symbol", how="left")

    # (label, step) pairs applied in order; each step maps df_result to the next df_result.
    indicator_steps = [
        ("moving_average", lambda df: moving_average(df_filtered, df, windows)),
        ("exponential_moving_average", lambda df: exponential_moving_average(df, windows)),
        ("calculate_rsi", lambda df: calculate_rsi(df_filtered, df, windows)),
        ("calculate_macd", lambda df: calculate_macd(df)),
        ("calculate_bollinger_bands", lambda df: calculate_bollinger_bands(df_filtered, df, windows)),
        ("calculate_obv", lambda df: calculate_obv(df_filtered, df)),
    ]

    for label, step in indicator_steps:
        started = time.time()
        df_result = step(df_result)
        # count() forces execution so there is something to time. NOTE: df_result
        # itself is not cached, so each count() re-runs the lineage of the earlier
        # steps as well; the printed times are cumulative, not per step.
        df_result.count()
        print(f"{label} took {time.time() - started:.3f} seconds")

    df_result = df_result.select(*output_columns)

    write_indicators(df_result, output_table)
finally:
    # Release the cached history even when a step fails.
    df_filtered.unpersist()




In [0]:
# Preview the table this run merged into. Reading `output_table` makes the preview
# follow the `env` widget instead of always showing dev_gold.
# NOTE(review): the original query pinned the `hive_metastore` catalog; `output_table`
# resolves against the session's current catalog, exactly as the merge above does.
df = spark.table(output_table)
pandas_df = df.limit(100).toPandas()
display(pandas_df)