In [0]:
%run "/Users/bartoszgardzinski1@gmail.com/DE_Databricks_1stpipe/utils/technical_indicators"

In [0]:
dbutils.widgets.text(name="env",defaultValue="",label="Enter the environment in lower case")
# Normalise the widget value to match what the label asks for (lower case, no
# stray whitespace); every table name below is built as f"{env}_<layer>.<table>".
env = dbutils.widgets.get("env").strip().lower()
if not env:
    # An empty env would silently resolve to "_silver.*" / "_gold.*" schemas.
    raise ValueError("The 'env' widget must be set (e.g. 'dev') before running this notebook")

from pyspark.sql.functions import (
    lit, avg, col, row_number, when, stddev, lag, sum as spark_sum, max as spark_max
)
from pyspark.sql.window import Window
from datetime import datetime, timedelta
from functools import reduce
from pyspark.sql.functions import date_sub, current_date
import time
from delta.tables import DeltaTable


# Source (silver) and target (gold) tables for the selected environment.
input_table = f"{env}_silver.daily_price_aggregates"
output_table = f"{env}_gold.daily_price_indicators"

# Final column order of the gold table; run_all_indicators selects exactly these.
output_columns = [
    # identifiers and daily price/volume aggregates
    'symbol', 'date', 'open', 'day_high', 'day_low', 'avg_price', 'volume',
    # calendar columns (read_prices filters on year/month)
    'year', 'month', 'day',
    # simple and exponential moving averages
    'ma_12', 'ma_26', 'ma_60',
    'ema_12', 'ema_26', 'ema_60',
    # relative strength index
    'rsi_12', 'rsi_26', 'rsi_60',
    # Bollinger bands: upper / lower / width per window
    'bb_upper_12', 'bb_lower_12', 'bb_width_12',
    'bb_upper_26', 'bb_lower_26', 'bb_width_26',
    'bb_upper_60', 'bb_lower_60', 'bb_width_60',
    # MACD line, signal line, histogram; on-balance volume
    'macd', 'signal', 'macd_hist',
    'obv',
]

# Window lengths handed to every windowed indicator helper.
windows = [60, 26, 12]

def readStream_prices(table_name=None):
    """Open a streaming read over the silver daily price aggregates.

    Args:
        table_name: Table to stream from. Defaults to the module-level
            ``input_table`` so the silver table name is defined in one place.

    Returns:
        The streaming DataFrame produced by ``spark.readStream.table``.
    """
    return spark.readStream.table(table_name or input_table)

def _month_pairs(start_date, end_date):
    """Return every (year, month) pair touched by the inclusive range [start_date, end_date]."""
    pairs = []
    current = start_date.replace(day=1)
    while current <= end_date:
        pairs.append((current.year, current.month))
        if current.month == 12:
            current = current.replace(year=current.year + 1, month=1)
        else:
            current = current.replace(month=current.month + 1)
    return pairs


def read_prices(input_table, row_date, granulity="daily"):
    """Load the recent price history used to compute indicators for ``row_date``.

    Reads ``input_table`` for the look-back period ending on ``row_date``
    (90 days for daily data, 5 days for hourly data) and numbers each
    symbol's rows from newest to oldest in an ``rn`` column.

    Args:
        input_table: Name of the table to read with ``spark.table``.
        row_date: Date being processed, as ``datetime.date`` or
            ``datetime.datetime`` (a datetime's time part is dropped).
        granulity: ``"daily"`` (default) or ``"hourly"``; selects the
            ordering column and the look-back length.

    Returns:
        The filtered rows plus an ``rn`` column (1 = most recent row per symbol).

    Raises:
        ValueError: if ``granulity`` is neither ``"daily"`` nor ``"hourly"``.
    """
    if granulity == "daily":
        datecol, nb_days = "date", 90
    elif granulity == "hourly":
        datecol, nb_days = "timestamp", 5
    else:
        raise ValueError(f"Unsupported granulity: {granulity!r} (expected 'daily' or 'hourly')")

    # datetime is a subclass of date, so check for it explicitly; dates
    # collected from a Spark date column are already plain dates.
    if isinstance(row_date, datetime):
        row_date = row_date.date()

    start_date = row_date - timedelta(days=nb_days)
    end_date = row_date

    # Coarse (year, month) filter first -- presumably these are partition
    # columns, so whole months can be skipped (TODO confirm) -- then the
    # exact inclusive date range.
    month_filter = reduce(
        lambda acc, cond: acc | cond,
        [(col("year") == y) & (col("month") == m) for y, m in _month_pairs(start_date, end_date)],
    )

    df_recent = (
        spark.table(input_table)
        .filter(month_filter)
        .filter((col("date") >= lit(start_date)) & (col("date") <= lit(end_date)))
    )

    window_spec = Window.partitionBy("symbol").orderBy(col(datecol).desc())
    return df_recent.withColumn("rn", row_number().over(window_spec))


def take_latest_row(table_path, row_date, granulity, lookback_days=5):
    """Return the newest row per symbol strictly before ``row_date``.

    Scans ``table_path`` for rows in ``[row_date - lookback_days, row_date)``,
    keeps the most recent row per symbol, and prefixes every column except
    ``symbol`` with ``prev_`` so the result can be joined on ``symbol``
    without name clashes.

    Args:
        table_path: Table name read with ``spark.read.table``.
        row_date: ``datetime.date`` or ``datetime.datetime`` (time part dropped).
        granulity: ``"daily"`` orders by ``date``; ``"hourly"`` by ``timestamp``.
        lookback_days: Size of the look-back window in days (default 5).

    Returns:
        One row per symbol; non-``symbol`` columns renamed to ``prev_<name>``.

    Raises:
        ValueError: if ``granulity`` is neither ``"daily"`` nor ``"hourly"``.
    """
    if granulity == "daily":
        datecol = "date"
    elif granulity == "hourly":
        datecol = "timestamp"
    else:
        raise ValueError(f"Unsupported granulity: {granulity!r} (expected 'daily' or 'hourly')")

    if isinstance(row_date, datetime):
        row_date = row_date.date()

    window_start = row_date - timedelta(days=lookback_days)

    # Each comparison must be parenthesised: `&` binds tighter than `>=`/`<`,
    # and the unparenthesised form becomes a chained comparison on Columns.
    df_recent = spark.read.table(table_path).filter(
        (col(datecol) >= lit(window_start)) & (col(datecol) < lit(row_date))
    )

    # Newest row per symbol.
    window_spec = Window.partitionBy("symbol").orderBy(col(datecol).desc())
    df_latest = (
        df_recent
        .withColumn("rn", row_number().over(window_spec))
        .filter(col("rn") == 1)
        .drop("rn")
    )

    # Positional rename in a single projection, preserving column order.
    return df_latest.toDF(
        *[c if c == "symbol" else f"prev_{c}" for c in df_latest.columns]
    )

def run_all_indicators(df_filtered, df_result, windows):
    """Apply every indicator helper and project onto ``output_columns``.

    The helpers called here (moving_average, exponential_moving_average,
    calculate_rsi, calculate_macd, calculate_bollinger_bands, calculate_obv)
    are not defined in this notebook; presumably they come from the ``%run``
    of utils/technical_indicators -- confirm there. Each is assumed to return
    the updated ``df_result``.

    Args:
        df_filtered: Price history DataFrame (e.g. from ``read_prices``);
            passed to the MA, RSI, Bollinger and OBV helpers.
        df_result: DataFrame the indicator columns are built onto.
        windows: Window lengths for the windowed indicators.

    Returns:
        ``df_result`` restricted to ``output_columns``, in that order.
    """
    df_result = moving_average(df_filtered, df_result, windows)
    df_result = exponential_moving_average(df_result, windows)
    df_result = calculate_rsi(df_filtered, df_result, windows)
    df_result = calculate_macd(df_result)
    df_result = calculate_bollinger_bands(df_filtered, df_result, windows)
    df_result = calculate_obv(df_filtered, df_result)

    # Previously the projected frame was computed and discarded (implicit None).
    return df_result.select(*output_columns)




def write_to_delta(input_df, input_table, output_table, checkpoint_path=None):
    """Stream new silver rows into the gold indicator table via MERGE.

    Each micro-batch of ``input_df`` is handled one date at a time, oldest
    first:
    1. read the price history for that date from ``input_table``;
    2. attach the previous gold row per symbol as ``prev_*`` columns;
    3. compute indicators with ``run_all_indicators``;
    4. MERGE the result into ``output_table`` on (symbol, date).

    Args:
        input_df: Streaming DataFrame of silver rows with ``symbol`` and ``date``.
        input_table: Silver table name passed to ``read_prices``.
        output_table: Gold Delta table to upsert into.
        checkpoint_path: Streaming checkpoint location (required).

    Raises:
        ValueError: if ``checkpoint_path`` is empty or missing.
    """
    if not checkpoint_path:
        raise ValueError("checkpoint_path must be provided for the streaming write")

    def upsert_to_delta(batch_df, batch_id):
        # The batch is read once for its dates and once per date below.
        batch_df.persist()
        try:
            # Dates come from the micro-batch itself: the streaming source
            # (input_df) cannot be collected. Oldest first, so each day's
            # prev_* values see the row merged for the day before.
            batch_dates = sorted(
                row["date"]
                for row in batch_df.select("date").where(col("date").isNotNull()).distinct().collect()
            )

            for single_date in batch_dates:
                df_history = read_prices(input_table, single_date)
                df_history.cache()
                try:
                    df_prev = take_latest_row(output_table, single_date, "daily")
                    df_day = (
                        batch_df
                        .filter(col("date") == single_date)
                        .join(df_prev, on="symbol", how="left")
                    )
                    df_result = run_all_indicators(df_history, df_day, windows)

                    (
                        DeltaTable.forName(spark, output_table).alias("target")
                        .merge(
                            df_result.alias("source"),
                            "target.symbol = source.symbol AND target.date = source.date",
                        )
                        .whenMatchedUpdateAll()
                        .whenNotMatchedInsertAll()
                        .execute()
                    )
                finally:
                    # The history is only needed for this date's merge.
                    df_history.unpersist()
        finally:
            batch_df.unpersist()

    query = (
        input_df.writeStream
        .foreachBatch(upsert_to_delta)
        .option("checkpointLocation", checkpoint_path)
        .outputMode("update")
        .trigger(availableNow=True)
        .start()
    )
    query.awaitTermination()


# Checkpoint location for this streaming query (one location per query).
dbutils.widgets.text(name="checkpoint_path", defaultValue="", label="Checkpoint location for the gold indicators stream")
checkpoint_path = dbutils.widgets.get("checkpoint_path")

df_streamed = readStream_prices()
write_to_delta(df_streamed, input_table, output_table, checkpoint_path)






In [0]:
# Preview the gold indicator table for the environment chosen in the `env`
# widget (the table name was previously hard-coded to dev_gold).
df = spark.table(f"hive_metastore.{output_table}")
pandas_df = df.limit(100).toPandas()
display(pandas_df)