In [1]:
from pyspark.ml.feature import StringIndexer, StandardScaler, VectorAssembler
from pyspark.ml import Pipeline
from pyspark.ml.regression import RandomForestRegressor
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.sql import SparkSession
from pyspark.sql.functions import lead
from pyspark.sql.window import Window

In [2]:
# Initialize Spark Session
spark = SparkSession.builder.appName("StockDataModel").getOrCreate()

# Define the path to the cleaned data in GCS
cleaned_data_path = "gs://my-big-data-as/cleaned/*.parquet"

# Load the cleaned data into a Spark DataFrame (using a sample for testing)
df = spark.read.parquet(cleaned_data_path)

# For testing, sample 2.5% of the data
df = df.sample(False, 0.025, seed=42)

# Show the first few rows to understand the structure
df.show(5)

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/12/01 01:41:59 INFO SparkEnv: Registering MapOutputTracker
24/12/01 01:41:59 INFO SparkEnv: Registering BlockManagerMaster
24/12/01 01:41:59 INFO SparkEnv: Registering BlockManagerMasterHeartbeat
24/12/01 01:42:00 INFO SparkEnv: Registering OutputCommitCoordinator
                                                                                

+-------------------+-------------+-----+-----+-----+------+---------+-----+-----+------------------+-------------------+------------------+
|           datetime|ticker_symbol| high|close| open|volume|    obv_0|mom_3|ema_3|bbands_3_upperband|bbands_3_middleband|bbands_3_lowerband|
+-------------------+-------------+-----+-----+-----+------+---------+-----+-----+------------------+-------------------+------------------+
|1998-01-30 00:00:00|       NVR.US|25.75|25.25|25.25| 21800| 145500.0| 1.69|24.76|             26.24|              24.52|              22.8|
|1998-06-09 00:00:00|       NVR.US|32.75|32.56|32.75| 37000|1210101.0| 0.43| 32.4|             32.86|              32.35|             31.85|
|1998-06-15 00:00:00|       NVR.US|33.69|33.25|33.25| 14400|1240801.0| 0.25|33.15|              33.4|              33.17|             32.93|
|1998-10-01 00:00:00|       NVR.US|32.88| 32.5|32.75|  7400|1775902.0|-3.25|33.29|             34.96|              33.31|             31.67|
|1998-10-23 0

In [3]:
# Create a Window specification to calculate the next day's closing price
windowSpec = Window.partitionBy("ticker_symbol").orderBy("datetime")

# Use the LEAD function to look ahead one day and get the closing price
df = df.withColumn("next_day_close", lead("close", 1).over(windowSpec))

# Drop the last row for each group (since it will have NULL for the next day's close)
df = df.dropna(subset=["next_day_close"])

# Ensure columns are in double type for scaling
df = df.withColumn("volume", df.volume.cast('double'))

In [4]:
# List of columns to scale
columns_to_scale = ["high", "open", "volume", "obv_0", "mom_3", "ema_3", 
                    "bbands_3_upperband", "bbands_3_middleband", "bbands_3_lowerband"]

# Assemble columns to scale into one vector in the pipeline (remove earlier manual transformation)
assembler_ = VectorAssembler(inputCols=columns_to_scale, outputCol="columns_to_scale_vector")

# Scale the vector using StandardScaler
scaler = StandardScaler(inputCol="columns_to_scale_vector", outputCol="scaled_vector", withStd=True, withMean=False)

# StringIndexer for ticker_symbol (categorical feature)
indexer = StringIndexer(inputCol="ticker_symbol", outputCol="ticker_index")

# Assemble the final features into a feature vector
final_assembler = VectorAssembler(
    inputCols=["scaled_vector", "ticker_index"],  # Includes scaled features and ticker_index
    outputCol="features"
)

In [5]:
# Create the pipeline with all the stages
pipeline = Pipeline(stages=[assembler_, scaler, indexer, final_assembler])

# Transform dataframe based on pipeline
df_transformed = pipeline.fit(df).transform(df)

# Save transformed feature vectors to trusted folder before model
df_transformed.write.parquet("gs://my-big-data-as/trusted/transformed_feature_vectors")

                                                                                

In [7]:
# Define the Random Forest model
rf = RandomForestRegressor(featuresCol='features', labelCol='next_day_close', maxBins=2048)

# Split the data into training and testing sets 
train_data, test_data = df_transformed.randomSplit([0.8, 0.2], seed=49)

# Cache training data to speed up cross-validation
train_data.cache()

# Set up cross-validation with hyperparameter tuning
paramGrid = ParamGridBuilder() \
    .addGrid(rf.numTrees, [10, 20]) \
    .addGrid(rf.maxDepth, [5, 10]) \
    .build()

# Regression evaluator for RMSE
rmse_evaluator = RegressionEvaluator(labelCol="next_day_close", predictionCol="prediction", metricName="rmse")

# Set up the CrossValidator with RandomForest model, parameter grid, and evaluator
cv = CrossValidator(estimator=rf,
                    estimatorParamMaps=paramGrid,
                    evaluator=rmse_evaluator,  # Evaluator for RMSE
                    numFolds=2)

# Train the model using cross-validation
cvModel = cv.fit(train_data)

# Make predictions on the test set
predictions = cvModel.transform(test_data)

24/12/01 02:36:38 WARN CacheManager: Asked to cache already cached data.
24/12/01 02:45:11 WARN DAGScheduler: Broadcasting large task binary with size 1906.6 KiB
24/12/01 02:50:35 WARN DAGScheduler: Broadcasting large task binary with size 1299.9 KiB
24/12/01 02:54:17 WARN DAGScheduler: Broadcasting large task binary with size 2.3 MiB
24/12/01 02:57:23 WARN DAGScheduler: Broadcasting large task binary with size 1786.4 KiB
24/12/01 03:03:51 WARN DAGScheduler: Broadcasting large task binary with size 1923.3 KiB
24/12/01 03:12:49 WARN DAGScheduler: Broadcasting large task binary with size 2.2 MiB
24/12/01 03:22:13 WARN DAGScheduler: Broadcasting large task binary with size 2.2 MiB
24/12/01 03:25:18 WARN DAGScheduler: Broadcasting large task binary with size 1336.3 KiB
24/12/01 03:31:42 WARN DAGScheduler: Broadcasting large task binary with size 3.0 MiB
24/12/01 03:33:30 WARN DAGScheduler: Broadcasting large task binary with size 1669.0 KiB
24/12/01 03:41:22 WARN DAGScheduler: Broadcasting

24/12/01 04:46:57 WARN DAGScheduler: Broadcasting large task binary with size 2.0 MiB
24/12/01 04:56:14 WARN DAGScheduler: Broadcasting large task binary with size 2.4 MiB
24/12/01 05:01:56 WARN DAGScheduler: Broadcasting large task binary with size 1449.6 KiB
24/12/01 05:05:34 WARN DAGScheduler: Broadcasting large task binary with size 4.2 MiB
24/12/01 05:11:13 WARN DAGScheduler: Broadcasting large task binary with size 1669.6 KiB
24/12/01 05:15:14 WARN DAGScheduler: Broadcasting large task binary with size 3.5 MiB
24/12/01 05:20:28 WARN DAGScheduler: Broadcasting large task binary with size 1661.6 KiB
24/12/01 05:23:47 WARN DAGScheduler: Broadcasting large task binary with size 4.0 MiB
24/12/01 05:29:43 WARN DAGScheduler: Broadcasting large task binary with size 1857.8 KiB
24/12/01 05:32:46 WARN DAGScheduler: Broadcasting large task binary with size 1695.3 KiB
24/12/01 05:38:49 WARN DAGScheduler: Broadcasting large task binary with size 1554.6 KiB
24/12/01 05:41:53 WARN DAGScheduler:

In [8]:
# Evaluate the model using RMSE, MAE, and R2
rmse = rmse_evaluator.evaluate(predictions)
print(f"Root Mean Squared Error (RMSE) = {rmse}")

mae_evaluator = RegressionEvaluator(labelCol="next_day_close", predictionCol="prediction", metricName="mae")
mae = mae_evaluator.evaluate(predictions)
print(f"Mean Absolute Error (MAE) = {mae}")

r2_evaluator = RegressionEvaluator(labelCol="next_day_close", predictionCol="prediction", metricName="r2")
r2 = r2_evaluator.evaluate(predictions)
print(f"R-Squared (R2) = {r2}")

# Get feature importance from the best model
rf_model = cvModel.bestModel  # Best model after cross-validation

# Extract feature importances
feature_importances = rf_model.featureImportances

# Print the feature importances
print("Feature Importances: ")
for feature, importance in zip(df_transformed.columns, feature_importances):
    print(f"{feature}: {importance}")

                                                                                

Root Mean Squared Error (RMSE) = 21.115899648183415


                                                                                

Mean Absolute Error (MAE) = 8.1807257765551


                                                                                

R-Squared (R2) = 0.9835247834224002
Feature Importances: 
datetime: 0.3520112845975745
ticker_symbol: 0.06829989467833736
high: 0.0016100864946969075
close: 0.006082477935932407
open: 0.0008128456852295622
volume: 0.16883962746515507
obv_0: 0.18445926616383362
mom_3: 0.1318217733920169
ema_3: 0.08457997615903919
bbands_3_upperband: 0.0014827674281844242


24/12/01 06:09:12 WARN DAGScheduler: Broadcasting large task binary with size 2.4 MiB

In [10]:
# Save the trained model to a location (optional)
rf_model.save("gs://my-big-data-as/models/stock_model")

# Optionally, save the test predictions to GCS
predictions.select("ticker_symbol", "datetime", "next_day_close", "prediction").write.parquet("gs://my-big-data-as/models/test_predictions.parquet")

# Stop the Spark session
spark.stop()

24/12/01 06:17:30 ERROR Instrumentation: org.apache.spark.SparkException: Job 207 cancelled because SparkContext was shut down
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$cleanUpAfterSchedulerStop$1(DAGScheduler.scala:1253)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$cleanUpAfterSchedulerStop$1$adapted(DAGScheduler.scala:1251)
	at scala.collection.mutable.HashSet.foreach(HashSet.scala:79)
	at org.apache.spark.scheduler.DAGScheduler.cleanUpAfterSchedulerStop(DAGScheduler.scala:1251)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onStop(DAGScheduler.scala:3087)
	at org.apache.spark.util.EventLoop.stop(EventLoop.scala:84)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$stop$3(DAGScheduler.scala:2973)
	at org.apache.spark.util.Utils$.tryLogNonFatalError(Utils.scala:1377)
	at org.apache.spark.scheduler.DAGScheduler.stop(DAGScheduler.scala:2973)
	at org.apache.spark.SparkContext.$anonfun$stop$12(SparkContext.scala:2322)
	at org.apache.spark.util.Utils$.try

This stopped SparkContext was created at:

org.apache.spark.api.java.JavaSparkContext.<init>(JavaSparkContext.scala:58)
java.base/jdk.internal.reflect.NativeConstructorAccessorImpl.newInstance0(Native Method)
java.base/jdk.internal.reflect.NativeConstructorAccessorImpl.newInstance(NativeConstructorAccessorImpl.java:62)
java.base/jdk.internal.reflect.DelegatingConstructorAccessorImpl.newInstance(DelegatingConstructorAccessorImpl.java:45)
java.base/java.lang.reflect.Constructor.newInstance(Constructor.java:490)
py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:247)
py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:374)
py4j.Gateway.invoke(Gateway.java:238)
py4j.commands.ConstructorCommand.invokeConstructor(ConstructorCommand.java:80)
py4j.commands.ConstructorCommand.execute(ConstructorCommand.java:69)
py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
py4j.ClientServerConnection.run(ClientServerConnection.java:106)
java.base/java.lang.Thread.ru

24/12/01 06:17:31 ERROR Instrumentation: java.lang.IllegalStateException: Cannot call methods on a stopped SparkContext.
This stopped SparkContext was created at:

org.apache.spark.api.java.JavaSparkContext.<init>(JavaSparkContext.scala:58)
java.base/jdk.internal.reflect.NativeConstructorAccessorImpl.newInstance0(Native Method)
java.base/jdk.internal.reflect.NativeConstructorAccessorImpl.newInstance(NativeConstructorAccessorImpl.java:62)
java.base/jdk.internal.reflect.DelegatingConstructorAccessorImpl.newInstance(DelegatingConstructorAccessorImpl.java:45)
java.base/java.lang.reflect.Constructor.newInstance(Constructor.java:490)
py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:247)
py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:374)
py4j.Gateway.invoke(Gateway.java:238)
py4j.commands.ConstructorCommand.invokeConstructor(ConstructorCommand.java:80)
py4j.commands.ConstructorCommand.execute(ConstructorCommand.java:69)
py4j.ClientServerConnection.waitForCommands(Client