In [1]:
# pyspark packages
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType, BooleanType

#other needed packages
import re
import os

# Point PySpark at the Homebrew OpenJDK 17 install.
# NOTE(review): machine-specific absolute path; other machines must change this.
os.environ['JAVA_HOME'] = '/opt/homebrew/opt/openjdk@17'

# Stop a Spark session left over from an earlier run of this cell, if any.
# getActiveSession() returns None when no session exists, whereas
# builder.getOrCreate() would start a driver JVM with default settings only to
# stop it again -- and a driver JVM that is already running does not pick up the
# spark.driver.memory value requested below.
previous_session = SparkSession.getActiveSession()
if previous_session is not None:
    try:
        previous_session.stop()
    except Exception as exc:  # best effort: a failed stop should not block setup
        print(f"Warning: could not stop previous Spark session: {exc}")

# Create new Spark session with more memory
spark = (
    SparkSession.builder
    .appName("stock market preds")
    .config("spark.driver.host", "127.0.0.1")
    .config("spark.driver.memory", "8g")
    .config("spark.sql.shuffle.partitions", "200")
    .master("local[*]")
    .getOrCreate()
)
spark.sparkContext.setLogLevel("ERROR")

Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
26/01/19 15:41:09 WARN Utils: Your hostname, Jeffreys-MacBook-Air.local, resolves to a loopback address: 127.0.0.1; using 10.0.0.17 instead (on interface en0)
26/01/19 15:41:09 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
26/01/19 15:41:09 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [2]:
# functions to import data, run SQL from file and save back to file
# function to import and clean columns
def import_csv_to_table(table_name, file, format_cols=False):
    """Read a CSV file into a Spark DataFrame and register it as a temp SQL view.

    Args:
        table_name: Name of the temporary view to create; an existing view of
            the same name is replaced.
        file: Path to the CSV file. The first row is used as the header and
            column types are inferred.
        format_cols: If True, normalise the column names: drop every character
            that is not an ASCII letter, digit or whitespace, lower-case the
            result, and turn each whitespace character into "_".

    Returns:
        The DataFrame backing the view.
    """
    # '"' is both the quote and the escape character (so "" inside a quoted
    # field is a literal quote), and quoted fields may span several lines.
    df = spark.read.csv(file, header=True, quote="\"",
                        escape="\"", multiLine=True, inferSchema=True)

    if format_cols:
        # Replace every whitespace character, not only " ": with multiLine=True a
        # header cell can contain tabs/newlines, which would otherwise survive.
        cols_formatted = [
            re.sub(r"\s", "_", re.sub(r"[^a-zA-Z0-9\s]", "", col_name).lower())
            for col_name in df.columns
        ]
        df = df.toDF(*cols_formatted)

    df.createOrReplaceTempView(table_name)
    return df

#run a SQL step
def sql_step(file):
    """Load the SQL text stored in `file` (UTF-8) and return spark.sql() of it.

    Nothing is displayed; the caller decides what to do with the DataFrame.
    """
    with open(file, mode='r', encoding='utf-8') as handle:
        query = handle.read()
    return spark.sql(query)

#run SQL and view output inline
def run_sql(file, rowstoshow=20, print_sql=False):
    """Run the SQL stored in a file and display the result inline.

    Args:
        file: Path to a UTF-8 text file containing the SQL to execute.
        rowstoshow: Number of rows passed to DataFrame.show(); values are
            shown untruncated.
        print_sql: If truthy, echo the SQL text before running it.
    """
    with open(file, 'r', encoding='utf-8') as sql_file:
        sql_text = sql_file.read()
    # Echo before executing so the text is still visible when spark.sql()
    # raises (e.g. on a syntax error) -- that is when you most want to see it.
    if print_sql:
        print(sql_text)
    results = spark.sql(sql_text)
    results.show(rowstoshow, truncate=False)

# export data frame to csv
def export_csv(df, output_dir, final_file_name):
    """Write `df` as a single CSV file and give it a readable name.

    Spark writes CSV output as a directory of ``part-*`` files. Coalescing to
    one partition yields a single part file, which is then renamed to
    ``output_dir/final_file_name``.

    WARNING: the write uses mode="overwrite", which replaces the existing
    contents of ``output_dir`` -- point it at a dedicated directory.

    Args:
        df: Spark DataFrame to export.
        output_dir: Directory that Spark writes into.
        final_file_name: Name for the CSV, relative to ``output_dir``. It may
            include sub-directories, which are created as needed.

    Prints "CSV saved as: ..." on success or "Error: Part file not found." if
    no part file was produced.
    """
    df.coalesce(1).write.csv(output_dir, header=True, mode="overwrite")

    # next() with a default keeps the "not found" path well-defined instead of
    # referencing a variable that was never assigned.
    part_file_name = next(
        (name for name in sorted(os.listdir(output_dir))
         if name.startswith("part-") and name.endswith(".csv")),
        None,
    )
    if part_file_name is None:
        print("Error: Part file not found.")
        return

    final_path = os.path.join(output_dir, final_file_name)
    parent_dir = os.path.dirname(final_path)
    if parent_dir:
        # Support names such as "sub/file.csv".
        os.makedirs(parent_dir, exist_ok=True)
    # os.replace also overwrites an existing target on Windows, unlike os.rename.
    os.replace(os.path.join(output_dir, part_file_name), final_path)
    print(f"CSV saved as: {final_file_name}")

In [3]:
# Absolute paths (rooted at the notebook's working directory) avoid OneDrive
# timeout issues with Spark.
base_path = os.getcwd()


def _load_table(folder, table_name):
    """Register <folder>/<table_name>.csv as SQL view `table_name`, keeping column names as-is."""
    return import_csv_to_table(table_name, os.path.join(folder, f"{table_name}.csv"), False)


# Tables under raw_data/
raw_data_path = os.path.join(base_path, "raw_data")
alfred_economic_data = _load_table(raw_data_path, "alfred_economic_data")
balance_sheet = _load_table(raw_data_path, "balance_sheet")
cashflow = _load_table(raw_data_path, "cashflow")
earnings_dates = _load_table(raw_data_path, "earnings_dates")
income_statement = _load_table(raw_data_path, "income_statement")
news_data = _load_table(raw_data_path, "news_data")
stock_data = _load_table(raw_data_path, "stock_data")
tickers = _load_table(raw_data_path, "tickers")

# Tables under processed_data/
processed_data_path = os.path.join(base_path, "processed_data")
finbert_news_classifications = _load_table(processed_data_path, "finbert_news_classifications")
fingpt_sentiment_checkpoint = _load_table(processed_data_path, "fingpt_sentiment_checkpoint")
lstm_predictions = _load_table(processed_data_path, "lstm_predictions")

print("All tables imported and SQL views created successfully!")

All tables imported and SQL views created successfully!


In [None]:
lstm_predictions.show(n=20, truncate=True)  # quick peek; same as show() with its defaults

In [None]:
# Per split: earliest/latest target_date and prediction_date in lstm_predictions.
split_date_ranges = spark.sql("""
    SELECT split, 
           MIN(target_date) as first_target,
           MAX(target_date) as last_target,
           MIN(prediction_date) as first_pred,
           MAX(prediction_date) as last_pred
    FROM lstm_predictions
    GROUP BY split
    ORDER BY split
""")
split_date_ranges.show()

In [4]:
run_sql("sql/xgboost_feature_set.sql", rowstoshow=30, print_sql=False)  # preview 30 rows, SQL not echoed

[Stage 78:>                                                       (0 + 10) / 11]

+------+---------------+--------------------------+-------+--------+--------+--------+---------+----------+--------------+----------------------+---------------+---------------+---------------+----------------+---------------+-------------------+-------------+---------+-------+--------------+------------+------------------+--------------------+-------------------+----------------+---------------------------+------------------------+-----------------+-------------------+-------------------+---------------+--------------------+----------------------+-------------+-----------+-----------------------+-----------------+----------------------+------------------+-----------------------+----------+------------+----------------------+-------------------+--------------------+------------------+------------------+----------------------+----------------------+----------------------+-----------------+----------------------+----------------------------+---------------------------+-------------------

                                                                                

In [None]:
# Export the XGBoost feature set to CSV.
# export_csv writes with mode="overwrite", which replaces the contents of the
# target directory. Write into a dedicated sub-directory: processed_data/ itself
# also holds the CSVs behind the finbert / fingpt / lstm_predictions views.
# Result: processed_data/xgboost_feature_set/xgboost_feature_set.csv
import gc

xgboost_output_dir = os.path.join(processed_data_path, "xgboost_feature_set")
xgboost_features = sql_step("sql/xgboost_feature_set.sql")
export_csv(xgboost_features, xgboost_output_dir, "xgboost_feature_set.csv")

# Stop Spark and drop the Python references so memory can be reclaimed.
spark.stop()
del spark, xgboost_features
del stock_data, income_statement, balance_sheet, cashflow, earnings_dates
del alfred_economic_data, news_data, tickers
del fingpt_sentiment_checkpoint, finbert_news_classifications, lstm_predictions

gc.collect()
print("Spark stopped and memory cleared. Ready for XGBoost model development.")

                                                                                

CSV saved as: xgboost_feature_set.csv
Spark stopped and memory cleared. Ready for XGBoost model development.
