In [1]:
# pyspark packages
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType, BooleanType

#other needed packages
import re
import os

# Point PySpark at a JDK. The path below is the author's local JDK
# (NOTE(review): presumably Homebrew's openjdk@17 on macOS). Only use it when it
# actually exists, so machines with a different, already-configured JAVA_HOME
# are not broken by a nonexistent path.
DEFAULT_JAVA_HOME = '/opt/homebrew/opt/openjdk@17'
if os.path.isdir(DEFAULT_JAVA_HOME):
    os.environ['JAVA_HOME'] = DEFAULT_JAVA_HOME

spark = (
    SparkSession.builder
    .appName("stock market preds")
    # The hostname resolves to loopback (see the startup warning in the
    # output), so pin the driver host explicitly.
    .config("spark.driver.host", "127.0.0.1")
    .getOrCreate()
)
# Only affects logging from this point on; warnings emitted while the session
# starts up still appear in the cell output.
spark.sparkContext.setLogLevel("ERROR")

Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
26/01/12 19:56:10 WARN Utils: Your hostname, Jeffreys-MacBook-Air.local, resolves to a loopback address: 127.0.0.1; using 10.0.0.17 instead (on interface en0)
26/01/12 19:56:10 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
26/01/12 19:56:10 WARN Utils: Your hostname, Jeffreys-MacBook-Air.local, resolves to a loopback address: 127.0.0.1; using 10.0.0.17 instead (on interface en0)
26/01/12 19:56:10 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
Setting default log level to "WARN".
To adjust loggi

In [2]:
# Helpers: load CSVs as SQL views, run SQL stored in files, and export results to CSV
def import_csv_to_table(table_name, file, format_cols):
    """Load a CSV file into a Spark DataFrame and register it as a temp SQL view.

    Args:
        table_name: Name of the temporary view to create or replace.
        file: Path to the CSV file; its first row is used as the header.
        format_cols: When truthy, normalise column names: drop every character
            that is not an ASCII letter, digit or whitespace, lower-case the
            result, and replace spaces with underscores.

    Returns:
        The loaded DataFrame (with renamed columns if ``format_cols`` is set).
    """
    # Fields are quoted with '"', embedded quotes are escaped as '""', records
    # may span several lines, and column types are inferred from the data.
    reader_options = dict(
        header=True,
        quote="\"",
        escape="\"",
        multiLine=True,
        inferSchema=True,
    )
    frame = spark.read.csv(file, **reader_options)

    if format_cols:
        def _normalise(name):
            kept = re.sub(r"[^a-zA-Z0-9\s]", "", name)
            return kept.lower().replace(" ", "_")

        frame = frame.toDF(*map(_normalise, frame.columns))

    frame.createOrReplaceTempView(f"{table_name}")
    return frame

def sql_step(file):
    """Run the SQL query stored in ``file`` and return the resulting DataFrame.

    Args:
        file: Path to a UTF-8 text file; its whole content is passed to
            ``spark.sql`` as one query string.

    Returns:
        The DataFrame returned by ``spark.sql``.
    """
    with open(file, 'r', encoding='utf-8') as sql_file:
        query = sql_file.read()
    return spark.sql(query)

def run_sql(file, rowstoshow=20, print_sql=False):
    """Run the SQL query stored in ``file`` and display its result inline.

    Args:
        file: Path to a UTF-8 text file; its whole content is passed to
            ``spark.sql`` as one query string.
        rowstoshow: Maximum number of rows passed to ``DataFrame.show``
            (default 20).
        print_sql: If truthy, echo the SQL text before it is executed.

    Returns:
        None; the result is printed via ``DataFrame.show`` with truncation off.
    """
    with open(file, 'r', encoding='utf-8') as sql_file:
        sql_text = sql_file.read()
    # Echo before calling spark.sql: parse/analysis errors are raised by that
    # call, and the query is most useful to see precisely when it fails.
    if print_sql:
        print(sql_text)
    results = spark.sql(sql_text)
    results.show(rowstoshow, truncate=False)

def export_csv(df, output_dir, final_file_name):
    """Write ``df`` as a single CSV file named ``final_file_name`` in ``output_dir``.

    Spark writes a directory of ``part-*`` files. The frame is coalesced to one
    partition so that a single part file is produced, and that file is then
    renamed to ``final_file_name``. ``output_dir`` is overwritten on every call.

    Args:
        df: Spark DataFrame to export.
        output_dir: Directory Spark writes into. Assumed to be a local path,
            because it is listed with ``os.listdir``.
        final_file_name: File name given to the CSV part file.

    Prints a confirmation message, or an error message if no
    ``part-*.csv`` file was found.
    """
    # coalesce(1) funnels all rows through one task -- fine for small outputs.
    df.coalesce(1).write.csv(output_dir, header=True, mode="overwrite")

    # Default to None so the "not found" branch is reachable instead of
    # raising UnboundLocalError when no part file exists.
    part_file_path = next(
        (
            os.path.join(output_dir, name)
            for name in sorted(os.listdir(output_dir))
            if name.startswith("part-") and name.endswith(".csv")
        ),
        None,
    )

    if part_file_path is not None:
        os.rename(part_file_path, os.path.join(output_dir, final_file_name))
        print(f"CSV saved as: {final_file_name}")
    else:
        print("Error: Part file not found.")

In [3]:
# Load the raw inputs and register them as the `news` and `stocks` SQL views,
# keeping the original column names (no normalisation).
news = import_csv_to_table(
    table_name="news", file="raw_data/news_data.csv", format_cols=False
)
stocks = import_csv_to_table(
    table_name="stocks", file="raw_data/stock_data.csv", format_cols=False
)

In [None]:
# Preview the first 20 rows of the raw news table (long values are truncated).
news.show(n=20, truncate=True)

+--------+--------------------+--------------------+-----------------+-------------------+-------------------+--------------------+--------------------+--------+
|      id|            headline|             summary|           author|         created_at|         updated_at|                 url|             symbols|  source|
+--------+--------------------+--------------------+-----------------+-------------------+-------------------+--------------------+--------------------+--------+
|49701666|Evercore ISI Grou...|                NULL|Benzinga Newsdesk|2026-01-05 10:50:55|2026-01-05 10:50:56|https://www.benzi...|                   A|benzinga|
|49391324|Barclays Upgrades...|                NULL|Benzinga Newsdesk|2025-12-15 06:49:28|2025-12-15 06:49:29|https://www.benzi...|                   A|benzinga|
|49342760|What's Driving th...|                    |Benzinga Insights|2025-12-11 11:00:38|2025-12-11 11:00:39|https://www.benzi...|                   A|benzinga|
|49276887|Goldman Sachs Ini.

In [6]:
# Preview the first 20 rows of the raw stock price table.
stocks.show(n=20, truncate=True)

+------+-------------------+------+-------+--------+------+---------+-----------+----------+-----+
|symbol|          timestamp|  open|   high|     low| close|   volume|trade_count|      vwap|index|
+------+-------------------+------+-------+--------+------+---------+-----------+----------+-----+
|     A|2025-01-07 22:00:00|137.68| 137.68|  135.63| 137.0|1684573.0|    19948.0|137.068421| NULL|
|     A|2025-01-09 22:00:00|134.75| 140.14| 134.709|137.47|1369875.0|    25383.0|137.592663| NULL|
|     A|2025-01-12 22:00:00|137.22| 142.82|   137.0|141.95|1561959.0|    28739.0|141.776934| NULL|
|     A|2025-01-13 22:00:00| 142.0| 145.38|  140.15|143.43|2445434.0|    36636.0|143.373405| NULL|
|     A|2025-01-14 22:00:00|144.14|  146.5|  138.68|142.23|2328643.0|    35076.0|142.841548| NULL|
|     A|2025-01-15 22:00:00|142.78| 145.11|  140.43|144.72|1661474.0|    25916.0|143.989713| NULL|
|     A|2025-01-16 22:00:00|145.88| 148.46| 145.195|147.36|3210310.0|    45636.0|147.281027| NULL|
|     A|20

In [9]:
# Build the per-symbol, per-day feature set (daily news text alongside the
# daily percent price change) from the SQL file, then preview it.
feature_set = sql_step(file="sql/sentiment_data_prep.sql")
feature_set.show(n=20, truncate=True)

+------+----------+--------------------+--------------------------+
|symbol| news_date|     daily_news_text|percent_daily_price_change|
+------+----------+--------------------+--------------------------+
|  PRGO|2025-01-06|                    |      0.004008016032064185|
|   TNL|2025-01-06|                    |      0.005322294500295745|
|  RPRX|2025-01-12|News Article Numb...|      0.004347826086956608|
|  LFUS|2025-01-13|                    |      -0.01619148936170...|
|   NSA|2025-01-13|                    |      -0.02918918918918...|
|   USD|2025-01-14|News Article Numb...|      -0.05604253484696085|
|   RYN|2025-01-15|                    |       0.01221374045801528|
|   WFC|2025-01-15|News Article Numb...|       0.01447749407738871|
|  BOKF|2025-01-21|News Article Numb...|      -0.01828510562755...|
|    CF|2025-01-21|News Article Numb...|      0.014919011082693856|
|  STWD|2025-01-21|                    |      -0.00104931794333...|
|   EXP|2025-01-22|News Article Numb...|      7.