In [1]:
# pyspark packages
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType, BooleanType

#other needed packages
import re
import os

# PySpark needs a JVM: point JAVA_HOME at the OpenJDK 17 install before the
# session is created.
# NOTE(review): this is a machine-specific path and it overrides any JAVA_HOME
# already present in the environment — confirm before running elsewhere.
os.environ["JAVA_HOME"] = "/opt/homebrew/opt/openjdk@17"

# Get the active session or create one, with the driver host set to loopback.
spark = (
    SparkSession.builder
    .appName("stock market preds")
    .config("spark.driver.host", "127.0.0.1")
    .getOrCreate()
)

# Only ERROR-level Spark log messages are wanted in the notebook output.
spark.sparkContext.setLogLevel("ERROR")

Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
26/01/14 03:33:12 WARN Utils: Your hostname, Jeffreys-MacBook-Air.local, resolves to a loopback address: 127.0.0.1; using 10.0.0.17 instead (on interface en0)
26/01/14 03:33:12 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
26/01/14 03:33:12 WARN Utils: Your hostname, Jeffreys-MacBook-Air.local, resolves to a loopback address: 127.0.0.1; using 10.0.0.17 instead (on interface en0)
26/01/14 03:33:12 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
Setting default log level to "WARN".
To adjust loggi

In [2]:
# functions to import data, run SQL from file and save back to file
# function to import and clean columns
def import_csv_to_table(table_name, file, format_cols):
    """Read a CSV file with Spark and register it as a temporary SQL view.

    Args:
        table_name: Name of the temp view to create or replace.
        file: Path of the CSV file, handed to ``spark.read.csv`` as-is.
        format_cols: When truthy, every column is renamed: characters that
            are neither ASCII alphanumerics nor whitespace are removed, the
            name is lower-cased and spaces become underscores.

    Returns:
        The loaded DataFrame (with renamed columns when ``format_cols`` is
        truthy); the same frame backs the registered view.
    """
    # First row is the header; a double quote both quotes and escapes, quoted
    # fields may span several lines, and column types are inferred.
    reader_options = dict(header=True, quote="\"", escape="\"",
                          multiLine=True, inferSchema=True)
    frame = spark.read.csv(file, **reader_options)

    if format_cols:
        def tidy(name):
            # Strip punctuation first, then normalise case and spaces.
            stripped = re.sub(r"[^a-zA-Z0-9\s]", "", name)
            return stripped.lower().replace(" ", "_")

        frame = frame.toDF(*map(tidy, frame.columns))

    # Expose the frame to spark.sql under the requested name.
    frame.createOrReplaceTempView(f"{table_name}")
    return frame

#run a SQL step
def sql_step(file):
    """Execute the SQL stored in a file and return the resulting DataFrame.

    Args:
        file: Path of a UTF-8 text file whose whole content is passed to
            ``spark.sql`` as one string.

    Returns:
        Whatever ``spark.sql`` returns for that text.
    """
    with open(file, 'r', encoding='utf-8') as sql_file:
        query = sql_file.read()
    return spark.sql(query)

#run SQL and view output inline
def run_sql(file, rowstoshow, print_sql):
    """Execute the SQL stored in a file and display the result inline.

    Args:
        file: Path of a UTF-8 text file whose whole content is passed to
            ``spark.sql`` as one string.
        rowstoshow: Number of rows handed to ``DataFrame.show``.
        print_sql: When equal to ``True``, the SQL text is printed after the
            query has been submitted and before the rows are shown.

    Returns:
        None; output goes to stdout via ``print`` and ``show``.
    """
    with open(file, 'r', encoding='utf-8') as sql_file:
        query = sql_file.read()

    frame = spark.sql(query)

    if print_sql == True:
        print(query)

    # truncate=False keeps long cell values intact in the printed table.
    frame.show(rowstoshow, truncate=False)

# export data frame to csv
def export_csv(df, output_dir, final_file_name):
    """Write ``df`` as one CSV file named ``final_file_name`` inside ``output_dir``.

    The frame is coalesced to a single partition and written with a header
    row in overwrite mode. The first ``part-*.csv`` file found in
    ``output_dir`` is then renamed to ``final_file_name``.

    Args:
        df: DataFrame to export.
        output_dir: Directory handed to the CSV writer.
        final_file_name: File name (not a path) for the renamed part file.

    Returns:
        None. Prints ``CSV saved as: ...`` on success, or
        ``Error: Part file not found.`` when the writer left no part file.
    """
    df.coalesce(1).write.csv(output_dir, header=True, mode="overwrite")

    # Default to None so that a missing part file reaches the error message
    # below instead of raising UnboundLocalError on an unassigned variable.
    part_file_name = next(
        (name for name in os.listdir(output_dir)
         if name.startswith("part-") and name.endswith(".csv")),
        None,
    )
    if part_file_name is None:
        print("Error: Part file not found.")
        return

    os.rename(os.path.join(output_dir, part_file_name),
              os.path.join(output_dir, final_file_name))
    print(f"CSV saved as: {final_file_name}")

In [3]:
# Load both raw CSVs and register them as the "news" and "stocks" temp views;
# column names are kept exactly as they appear in the files.
news = import_csv_to_table(table_name="news", file="raw_data/news_data.csv", format_cols=False)
stocks = import_csv_to_table(table_name="stocks", file="raw_data/stock_data.csv", format_cols=False)

In [15]:
# Run the feature-prep query, then print up to 300 rows with untruncated cell text.
feature_set = sql_step(file="sql/sentiment_data_prep_v2.sql")
feature_set.show(n=300, truncate=False)

+------+----------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------------------+
|symbol|news_date |article_text                                                                                                                                                                                                                                                                                                                                                                                                                            |percent_daily_price_change|
+------+----------+-------------------------------------