In [3]:
# Import necessary libraries
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, IntegerType
from pyspark.ml.feature import VectorAssembler as va
from pyspark.ml.regression import LinearRegression as lr
from pyspark.ml.evaluation import RegressionEvaluator as re

# Create a SparkSession
spark_session_1 = SparkSession.builder.appName("TemperaturePrediction").getOrCreate()

# Define the schema for the dataset
schema = StructType(
    [
        StructField("station_id", StringType(), True),
        StructField("date", IntegerType(), True),
        StructField("observation_type", StringType(), True),
        StructField("value", IntegerType(), True),
        StructField("extra_info", StringType(), True),
    ]
)

# Load the dataset
data = spark_session_1.read.csv("./1800.csv", schema=schema)

# Filter the dataset to keep only the TMAX col
tmax_data = data.filter(data.observation_type == "TMAX")

# Prepare features
vector_assembler = va(inputCols=["date", "value"], outputCol="features")
data_with_features = vector_assembler.transform(tmax_data)

# Split the data into training and test sets
train_data, test_data = data_with_features.randomSplit([0.7, 0.3], seed=123)

# Instantiate the linear regression instance
lr = lr(featuresCol="features", labelCol="value")

# Train the model
lr_model = lr.fit(train_data)

# Make predictions on the test data
predictions = lr_model.transform(test_data)

# Evaluate the model
evaluator = re(labelCol="value", predictionCol="prediction", metricName="rmse")
rmse = evaluator.evaluate(predictions)

print("Root Mean Squared Error (RMSE):", rmse)

# Stop the SparkSession
spark_session_1.stop()

24/02/23 14:34:13 WARN Instrumentation: [b70c9fb0] regParam is zero, which might cause numerical instability and overfitting.


Root Mean Squared Error (RMSE): 2.855992731933797e-14
