Supervised Learning with pre-labeled datasets

In [1]:
pip install pyspark nltk findspark



[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.2[0m[39;49m -> [0m[32;49m24.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49m/opt/homebrew/Cellar/jupyterlab/4.2.5_1/libexec/bin/python -m pip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [2]:
import findspark
import pyspark
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, trim
from pyspark.sql.types import StructType, StructField, StringType, IntegerType
from pyspark.ml.feature import Tokenizer, StopWordsRemover, HashingTF, IDF
from pyspark.ml import Pipeline

findspark.init()
import os
os.environ['JAVA_OPTS'] = '-Djava.security.manager=allow'
# Step 1: Initialize SparkSession
spark = SparkSession.builder.config("spark.driver.host", "localhost").appName("YTSentAnal2").getOrCreate()

# Step 2: Define Dataset Paths
file_paths = {
    "LoganPaul": "LoganPaul.csv",
    "OKGO": "OKGO.csv",
    "RoyalWedding": "RoyalWedding.csv",
    "TaylorSwift": "TaylorSwift.csv",
    "Trump": "trump.csv",
}

24/11/29 19:03:46 WARN Utils: Your hostname, Kofis-MacBook-Air-2.local resolves to a loopback address: 127.0.0.1; using 192.168.1.65 instead (on interface en0)
24/11/29 19:03:46 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/11/29 19:03:47 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
----------------------------------------
Exception occurred during processing of request from ('127.0.0.1', 49661)
Traceback (most recent call last):
  File "/opt/homebrew/Cellar/python@3.12/3.12.7_1/Frameworks/Python.framework/Versions/3.12/lib/python3.12/socketserver.py", line 318, in _handle_request_noblock
    self.process_request(request, client_address)
  File "/opt/homebrew/Cellar/python@3.12/3.12.7_1/Frameworks/Python.framework/Versions/3.12/lib/python3.12/socketserv

In [3]:
# Load and basic cleaning of data
def load_and_clean(file_path, delimiter=","):
    schema = StructType([
        StructField("label", IntegerType(), True),
        StructField("text", StringType(), True)
    ])
    df = spark.read.option("header", "false").option("sep", delimiter).schema(schema).csv(file_path)
    return df.filter((col("text").isNotNull()) & (col("label").isNotNull()))

datasets = [load_and_clean(file_path, delimiter=";" if name == "OKGO" else ",") for name, file_path in file_paths.items()]

# Combine all datasets into one DataFrame
combined_df = spark.createDataFrame([], schema=datasets[0].schema)
for df in datasets:
    combined_df = combined_df.union(df)

# Verify label column distribution
combined_df.groupBy("label").count().show()

# Check schema to confirm proper column names
combined_df.printSchema()

                                                                                

+-----+-----+
|label|count|
+-----+-----+
|   -1|  780|
|    1|  818|
|    0| 1238|
+-----+-----+

root
 |-- label: integer (nullable = true)
 |-- text: string (nullable = true)



### Text Normalization with NLTK

This section tackles more preprocessing of text data, transforming raw text into meaningful tokens. Leveraging NLTK, we remove noise such as special characters, digits, and common stopwords. Additionally, we apply lemmatization to reduce words to their base form, ensuring consistency in representation. The use of PySpark's UDF (User-Defined Functions) enables seamless integration of Python-based transformations into the distributed processing pipeline.

In [4]:
import re
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, StringType
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
import nltk

# Download NLTK data
nltk.download("wordnet")
nltk.download("omw-1.4")
nltk.download("stopwords")

# Initialize tools
lemmatizer = WordNetLemmatizer()
stop_words = set(stopwords.words("english"))

# Define preprocessing UDF
def preprocess_text(tokens):
    processed_tokens = []
    for word in tokens:
        word = word.lower()
        word = re.sub(r'[^\w\s]', '', word)  # Remove special characters
        word = re.sub(r'\d+', '', word)  # Remove digits
        
        if word not in stop_words:  # Remove stopwords
            lemmatized_word = lemmatizer.lemmatize(word)  # Lemmatize word
            processed_tokens.append(lemmatized_word)
    return processed_tokens

preprocess_udf = udf(preprocess_text, ArrayType(StringType()))

# Tokenization
tokenizer = Tokenizer(inputCol="text", outputCol="tokens")
tokenized_df = tokenizer.transform(combined_df)

# Apply UDF for normalization and lemmatization
normalized_df = tokenized_df.withColumn(
    "filtered_tokens", preprocess_udf(col("tokens"))
)

# Filter out empty tokenized rows
normalized_df = normalized_df.filter(col("filtered_tokens").isNotNull())

# Select only required columns
refined_preprocessed_df = normalized_df.select("filtered_tokens", "label")

# Verify the updated normalization
refined_preprocessed_df.printSchema()
refined_preprocessed_df.show(5, truncate=False)

[nltk_data] Downloading package wordnet to
[nltk_data]     /Users/kofifrempong/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package omw-1.4 to
[nltk_data]     /Users/kofifrempong/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!
[nltk_data] Downloading package stopwords to
[nltk_data]     /Users/kofifrempong/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


root
 |-- filtered_tokens: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- label: integer (nullable = true)





+--------------------------------------------------------------------------------------------------------------------------------------+-----+
|filtered_tokens                                                                                                                       |label|
+--------------------------------------------------------------------------------------------------------------------------------------+-----+
|[wow, heard, guy, easily, insecure, douche, ever, seen, youtube, clearly, mental, issue, need, evaluated, give, guy, help, need, asap]|1    |
|[japanese, trying, respectful, lo, gan, logan, care, wtf]                                                                             |-1   |
|[prick]                                                                                                                               |-1   |
|[think, weed, cry]                                                                                                                    |-1   |

                                                                                

### Balancing the Dataset with Class Weights

Since we know these datasets exhibit imbalanced class distributions (-1:780, 0:818, 1: 1238), which can bias the model. We can address this by calculating class weights proportional to their frequencies. These weights are added to the dataset, ensuring fairer representation of each sentiment class during model training. 


In [5]:
# Class Weighting
from pyspark.sql.functions import when

# Add weights to the dataset
class_weights = {
    -1: 1.0 / 780,  
    0: 1.0 / 818,
    1: 1.0 / 1238
}

# Add weights to the dataset
balanced_training_data = refined_preprocessed_df.withColumn(
    "weight",
    when(col("label") == -1, class_weights[-1])
    .when(col("label") == 0, class_weights[0])
    .when(col("label") == 1, class_weights[1])
)

# Verify the added weights
balanced_training_data.groupBy("label").agg({"weight": "avg"}).show()



+-----+--------------------+
|label|         avg(weight)|
+-----+--------------------+
|   -1|0.001282051282051...|
|    1|8.077544426494289E-4|
|    0|0.001222493887530...|
+-----+--------------------+



                                                                                

### Feature Extraction with CountVectorizer and IDF
Next, we transition from raw tokens to numerical feature vectors. Using `CountVectorizer`, we create a vocabulary-based representation of the text, capturing the frequency of terms. The `IDF` (Inverse Document Frequency) scaling step enhances the representation by down-weighting common terms, emphasizing words more indicative of the sentiment. These stages are encapsulated in a pipeline for streamlined preprocessing.

In [6]:
from pyspark.ml.feature import CountVectorizer, IDF
from pyspark.ml import Pipeline

count_vectorizer = CountVectorizer(inputCol="filtered_tokens", outputCol="raw_features", vocabSize=100000)

# Apply IDF for scaling feature vectors
idf = IDF(inputCol="raw_features", outputCol="features")

# Create a feature extraction pipeline
feature_pipeline = Pipeline(stages=[count_vectorizer, idf])

# Fit and transform the pipeline on the weighted dataset
feature_model = feature_pipeline.fit(balanced_training_data)
featured_df = feature_model.transform(balanced_training_data)

# Select necessary columns for training
final_training_data = featured_df.select("features", "label", "weight")

# Verify the resulting dataset
final_training_data.printSchema()
final_training_data.show(5, truncate=False)

                                                                                

root
 |-- features: vector (nullable = true)
 |-- label: integer (nullable = true)
 |-- weight: double (nullable = true)





+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+-----+---------------------+
|features                                                                                                                                                                                                                                                                                                                                                                                        |label|weight               |
+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------

                                                                                

### Label Remapping

Since PySpark’s MLlib LogisticRegression does not support negative label values and since with is a multiclass classification, we need to remap the labels (-1, 0, 1) to non-negative integers (0, 1, 2). This transformation ensures compatibility with Spark's MLlib while preserving the underlying semantics of the labels.

In [7]:
# Label Remapping and Data Splitting

from pyspark.sql.functions import when

# Remap labels: -1 -> 0, 0 -> 1, 1 -> 2 for ML since -1 cannot b used 4 training
final_training_data = final_training_data.withColumn(
    "label",
    when(col("label") == -1, 0)
    .when(col("label") == 0, 1)
    .when(col("label") == 1, 2)
)


### Training the Logistic Regression Model
The logistic regression model, optimized using k-fold cross-validation. A parameter grid is defined to tune regularization strength (`RegParam`) and the ElasticNet mixing parameter. The cross-validation ensures robust model evaluation across multiple splits, providing the best hyperparameters for the task. Finally, the best model is saved for reuse, with its key parameters displayed.

In [22]:

from pyspark.ml.classification import LogisticRegression, RandomForestClassifier
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.sql.functions import monotonically_increasing_id


# Logistic Regression Model
lr = LogisticRegression(featuresCol="features", labelCol="label", weightCol="weight", maxIter=20)

# Random Forest Model
rf = RandomForestClassifier(featuresCol="features", labelCol="label", weightCol="weight", numTrees=50)

# Logistic Regression Parameter Grid
lr_param_grid = ParamGridBuilder() \
    .addGrid(lr.regParam, [0.01, 0.1, 1.0]) \
    .addGrid(lr.elasticNetParam, [0.0, 0.5, 1.0]) \
    .build()

# Random Forest Parameter Grid
rf_param_grid = ParamGridBuilder() \
    .addGrid(rf.numTrees, [10, 50, 100]) \
    .addGrid(rf.maxDepth, [5, 10]) \
    .build()

# Cross-validation for Logistic Regression
lr_cv = CrossValidator(
    estimator=lr,
    estimatorParamMaps=lr_param_grid,
    evaluator=MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction", metricName="accuracy"),
    numFolds=5
)

# Cross-validation for Random Forest
rf_cv = CrossValidator(
    estimator=rf,
    estimatorParamMaps=rf_param_grid,
    evaluator=MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction", metricName="accuracy"),
    numFolds=5
)

from pyspark.sql.functions import col




# Train individual models
lr_model = lr_cv.fit(final_training_data)
rf_model = rf_cv.fit(final_training_data)

lr_model = lr_model.bestModel
rf_model = rf_model.bestModel

lr_predictions = lr_model.transform(final_training_data)
rf_predictions = rf_model.transform(final_training_data)

# Extract predictions from both models
# Rename columns in lr_predictions and rf_predictions to avoid ambiguity
lr_predictions = lr_predictions.withColumn("id", monotonically_increasing_id()) \
    .withColumnRenamed("label", "lr_label") \
    .withColumnRenamed("prediction", "lr_prediction")

rf_predictions = rf_predictions.withColumn("id", monotonically_increasing_id()) \
    .withColumnRenamed("label", "rf_label") \
    .withColumnRenamed("prediction", "rf_prediction")

lr_predictions = lr_predictions.withColumn("lr_prediction", col("lr_prediction").cast("double"))
rf_predictions = rf_predictions.withColumn("rf_prediction", col("rf_prediction").cast("double"))

# Join predictions from both models
combined_predictions = lr_predictions.join(rf_predictions, on="id").select(
    "id", "lr_label", "lr_prediction", "rf_prediction"
)

# Verify which label to use and rename it for clarity
combined_predictions = combined_predictions.withColumnRenamed("lr_label", "label")

# Combine predictions into a single feature vector
assembler = VectorAssembler(
    inputCols=["lr_prediction", "rf_prediction"],
    outputCol="features"
)

ensemble_data = assembler.transform(combined_predictions)

from pyspark.ml.linalg import DenseVector, VectorUDT
from pyspark.sql.functions import udf

# Define the UDF with VectorUDT as the return type
to_dense_udf = udf(lambda v: DenseVector(v.toArray()), VectorUDT())

# Apply the UDF to the column
ensemble_data = ensemble_data.withColumn("features", to_dense_udf(col("features")))





# Meta-classifier (Final Logistic Regression on combined predictions)
ensemble_lr = LogisticRegression(featuresCol="features", labelCol="label", maxIter=20)

# Train meta-classifier
best_model = ensemble_lr.fit(ensemble_data)

# Save the ensemble model
best_model.write().overwrite().save("best_model")

24/11/29 21:57:31 WARN DAGScheduler: Broadcasting large task binary with size 1033.0 KiB
24/11/29 21:58:21 WARN DAGScheduler: Broadcasting large task binary with size 1041.1 KiB
24/11/29 21:59:32 WARN DAGScheduler: Broadcasting large task binary with size 1031.9 KiB
24/11/29 22:00:21 WARN DAGScheduler: Broadcasting large task binary with size 1048.6 KiB
24/11/29 22:01:19 WARN DAGScheduler: Broadcasting large task binary with size 1079.9 KiB
                                                                                

In [23]:
# Evaluate the best model on the entire dataset
test_predictions = best_model.transform(final_training_data)

# Evaluate test accuracy
evaluator = MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction", metricName="accuracy")
test_accuracy = evaluator.evaluate(test_predictions)
print(f"Cross-Validated Test Accuracy: {test_accuracy}")

# Evaluate precision, recall, and F1-score for each class
labels = [0, 1, 2]  
for label in labels:
    precision = evaluator.evaluate(test_predictions, {evaluator.metricName: "precisionByLabel", evaluator.metricLabel: label})
    recall = evaluator.evaluate(test_predictions, {evaluator.metricName: "recallByLabel", evaluator.metricLabel: label})
    f1 = evaluator.evaluate(test_predictions, {evaluator.metricName: "fMeasureByLabel", evaluator.metricLabel: label})
    print(f"Class {label}: Precision = {precision}, Recall = {recall}, F1-Score = {f1}")

24/11/29 22:03:50 ERROR Executor: Exception in task 8.0 in stage 38682.0 (TID 338330)
org.apache.spark.SparkException: [FAILED_EXECUTE_UDF] Failed to execute user defined function (`ProbabilisticClassificationModel$$Lambda$5170/0x0000000801b29840`: (struct<type:tinyint,size:int,indices:array<int>,values:array<double>>) => struct<type:tinyint,size:int,indices:array<int>,values:array<double>>).
	at org.apache.spark.sql.errors.QueryExecutionErrors$.failedExecuteUserDefinedFunctionError(QueryExecutionErrors.scala:198)
	at org.apache.spark.sql.errors.QueryExecutionErrors.failedExecuteUserDefinedFunctionError(QueryExecutionErrors.scala)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage6.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenEvaluatorFactory$WholeStageCodegenPartitionEvaluator$$anon$1.hasNext(WholeStageCodegenEvaluato

Py4JJavaError: An error occurred while calling o228900.evaluate.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 10 in stage 38682.0 failed 1 times, most recent failure: Lost task 10.0 in stage 38682.0 (TID 338332) (192.168.1.65 executor driver): org.apache.spark.SparkException: [FAILED_EXECUTE_UDF] Failed to execute user defined function (`ProbabilisticClassificationModel$$Lambda$5170/0x0000000801b29840`: (struct<type:tinyint,size:int,indices:array<int>,values:array<double>>) => struct<type:tinyint,size:int,indices:array<int>,values:array<double>>).
	at org.apache.spark.sql.errors.QueryExecutionErrors$.failedExecuteUserDefinedFunctionError(QueryExecutionErrors.scala:198)
	at org.apache.spark.sql.errors.QueryExecutionErrors.failedExecuteUserDefinedFunctionError(QueryExecutionErrors.scala)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage12.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenEvaluatorFactory$WholeStageCodegenPartitionEvaluator$$anon$1.hasNext(WholeStageCodegenEvaluatorFactory.scala:43)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:197)
	at org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)
	at org.apache.spark.shuffle.ShuffleWriteProcessor.write(ShuffleWriteProcessor.scala:59)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:104)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:54)
	at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:166)
	at org.apache.spark.scheduler.Task.run(Task.scala:141)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:620)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:64)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:61)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:94)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:623)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1128)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:628)
	at java.base/java.lang.Thread.run(Thread.java:829)
Caused by: java.lang.IllegalArgumentException: requirement failed: The columns of A don't match the number of elements of x. A: 2, x: 4538
	at scala.Predef$.require(Predef.scala:281)
	at org.apache.spark.ml.linalg.BLAS$.gemv(BLAS.scala:643)
	at org.apache.spark.ml.classification.LogisticRegressionModel.$anonfun$margins$1(LogisticRegression.scala:1157)
	at org.apache.spark.ml.classification.LogisticRegressionModel.predictRaw(LogisticRegression.scala:1239)
	at org.apache.spark.ml.classification.LogisticRegressionModel.predictRaw(LogisticRegression.scala:1060)
	at org.apache.spark.ml.classification.ProbabilisticClassificationModel.$anonfun$transform$2(ProbabilisticClassifier.scala:121)
	... 22 more

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2856)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2792)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2791)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2791)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1247)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1247)
	at scala.Option.foreach(Option.scala:407)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1247)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:3060)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2994)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2983)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:989)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2393)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2414)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2433)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2458)
	at org.apache.spark.rdd.RDD.$anonfun$collect$1(RDD.scala:1049)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:410)
	at org.apache.spark.rdd.RDD.collect(RDD.scala:1048)
	at org.apache.spark.rdd.PairRDDFunctions.$anonfun$collectAsMap$1(PairRDDFunctions.scala:738)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:410)
	at org.apache.spark.rdd.PairRDDFunctions.collectAsMap(PairRDDFunctions.scala:737)
	at org.apache.spark.mllib.evaluation.MulticlassMetrics.confusions$lzycompute(MulticlassMetrics.scala:61)
	at org.apache.spark.mllib.evaluation.MulticlassMetrics.confusions(MulticlassMetrics.scala:52)
	at org.apache.spark.mllib.evaluation.MulticlassMetrics.tpByClass$lzycompute(MulticlassMetrics.scala:78)
	at org.apache.spark.mllib.evaluation.MulticlassMetrics.tpByClass(MulticlassMetrics.scala:76)
	at org.apache.spark.mllib.evaluation.MulticlassMetrics.accuracy$lzycompute(MulticlassMetrics.scala:188)
	at org.apache.spark.mllib.evaluation.MulticlassMetrics.accuracy(MulticlassMetrics.scala:188)
	at org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator.evaluate(MulticlassClassificationEvaluator.scala:153)
	at jdk.internal.reflect.GeneratedMethodAccessor414.invoke(Unknown Source)
	at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.base/java.lang.reflect.Method.invoke(Method.java:566)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:374)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
	at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
	at java.base/java.lang.Thread.run(Thread.java:829)
Caused by: org.apache.spark.SparkException: [FAILED_EXECUTE_UDF] Failed to execute user defined function (`ProbabilisticClassificationModel$$Lambda$5170/0x0000000801b29840`: (struct<type:tinyint,size:int,indices:array<int>,values:array<double>>) => struct<type:tinyint,size:int,indices:array<int>,values:array<double>>).
	at org.apache.spark.sql.errors.QueryExecutionErrors$.failedExecuteUserDefinedFunctionError(QueryExecutionErrors.scala:198)
	at org.apache.spark.sql.errors.QueryExecutionErrors.failedExecuteUserDefinedFunctionError(QueryExecutionErrors.scala)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage12.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenEvaluatorFactory$WholeStageCodegenPartitionEvaluator$$anon$1.hasNext(WholeStageCodegenEvaluatorFactory.scala:43)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:197)
	at org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)
	at org.apache.spark.shuffle.ShuffleWriteProcessor.write(ShuffleWriteProcessor.scala:59)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:104)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:54)
	at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:166)
	at org.apache.spark.scheduler.Task.run(Task.scala:141)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:620)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:64)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:61)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:94)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:623)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1128)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:628)
	... 1 more
Caused by: java.lang.IllegalArgumentException: requirement failed: The columns of A don't match the number of elements of x. A: 2, x: 4538
	at scala.Predef$.require(Predef.scala:281)
	at org.apache.spark.ml.linalg.BLAS$.gemv(BLAS.scala:643)
	at org.apache.spark.ml.classification.LogisticRegressionModel.$anonfun$margins$1(LogisticRegression.scala:1157)
	at org.apache.spark.ml.classification.LogisticRegressionModel.predictRaw(LogisticRegression.scala:1239)
	at org.apache.spark.ml.classification.LogisticRegressionModel.predictRaw(LogisticRegression.scala:1060)
	at org.apache.spark.ml.classification.ProbabilisticClassificationModel.$anonfun$transform$2(ProbabilisticClassifier.scala:121)
	... 22 more


In [None]:
from pyspark.sql import functions as F
from pyspark.sql.types import IntegerType

# Create a confusion matrix DataFrame
confusion_matrix_df = (
    test_predictions
    .groupBy("label", "prediction")
    .agg(F.count("*").alias("count"))
    .orderBy("label", "prediction")
)

# Collect the confusion matrix for display
confusion_matrix = confusion_matrix_df.collect()

# Print confusion matrix
print("Confusion Matrix:")
for row in confusion_matrix:
    print(f"Label {int(row['label'])} Predicted as {int(row['prediction'])}: {row['count']}")