### Machine Learning Pipeline
The pipeline will involve the following models:
1. Decision Tree
2. Logistic Regression
3. Random Forest
4. Gradient-boosted trees
5. Linear Support Vector Machines

### How will this work?
Here's our recipe for success:
- 1st: load the data
- 2nd: split the data (80/20 approach)
- 3rd: get our feature columns and vectorize
- 4th: instantiate Models
- 5th: build and run the pipeline
- 6th: apply metrics (we use accuracy, precision, recall and f1-score)
- 7th: plot the confusion matrix
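
To make the metrics in step 6 concrete, here is a minimal sketch of how accuracy, precision, recall and f1-score follow from the four cells of a binary confusion matrix. The function name `classification_metrics` and the counts are hypothetical and chosen only for illustration; they are not results from this notebook. The example also shows why accuracy alone can look very good on an imbalanced dataset while recall on the rare class stays modest.

```python
def classification_metrics(tp, fp, fn, tn):
    """Return accuracy, precision, recall and F1 for the positive class."""
    total = tp + fp + fn + tn
    accuracy = (tp + tn) / total
    # Guard against division by zero when a class is never predicted or never present
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if (precision + recall) else 0.0)
    return {"accuracy": accuracy, "precision": precision,
            "recall": recall, "f1": f1}


# Hypothetical counts for an imbalanced problem (100 positives in 10,000 rows)
scores = classification_metrics(tp=60, fp=20, fn=40, tn=9880)
for name, value in scores.items():
    print(f"{name}: {value:.3f}")
```

With these made-up counts, accuracy is above 0.99 even though 40 of the 100 positive rows are missed, which is why the pipeline reports precision, recall and f1-score alongside accuracy.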

Documentation: 
- https://spark.apache.org/docs/latest/ml-pipeline.html
- https://spark.apache.org/docs/latest/ml-classification-regression.html#decision-tree-classifier
- https://www.v7labs.com/blog/f1-score-guide
- https://developers.google.com/machine-learning/crash-course/classification/precision-and-recall
- https://datascience-enthusiast.com/Python/PySpark_ML_with_Text_part1.html
- https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.mllib.evaluation.MulticlassMetrics.html?highlight=confusion#pyspark.mllib.evaluation.MulticlassMetrics.confusionMatrix
- https://www.sparkitecture.io/machine-learning/model-evaluation
- https://towardsdatascience.com/understanding-confusion-matrix-a9ad42dcfd62

Note: we are using an already scaled and normalized dataset, without duplicates, which was processed in the EDA stage of the project.
- Dataset: dfscaled.csv

### Let's start with the imbalanced dataset


In [1]:
# imports and configure spark session

# Import libraries
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter
import datetime
from scipy.stats import norm
from scipy import stats

# Import PySpark libraries
from pyspark.sql import Window
import pyspark.sql.types as t
import pyspark.sql.functions as f
from pyspark.sql import SparkSession
from pyspark.ml.feature import StandardScaler
from pyspark.ml import Pipeline, PipelineModel
from pyspark.ml.classification import LogisticRegression, DecisionTreeClassifier, RandomForestClassifier, GBTClassifier, LinearSVC, NaiveBayes
from pyspark.ml.feature import StringIndexer, VectorIndexer, IndexToString, VectorAssembler
from pyspark.ml.evaluation import MulticlassClassificationEvaluator, BinaryClassificationEvaluator
from pyspark.mllib.evaluation import MulticlassMetrics 
from pyspark.mllib.linalg import Matrix

import warnings
warnings.filterwarnings('ignore')

spark = SparkSession.builder \
    .appName("Read CSV Files - Sonae") \
    .master("local[*]") \
    .config("spark.driver.bindAddress", '127.0.0.1') \
    .getOrCreate()

sc = spark.sparkContext

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


23/05/10 21:38:51 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
23/05/10 21:38:53 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


In [2]:
# Load file
df = spark.read.csv('datasets/creditcard.csv', header=True, inferSchema=True, sep=",")
# Print Schema
df.printSchema()


root
 |-- Time: double (nullable = true)
 |-- V1: double (nullable = true)
 |-- V2: double (nullable = true)
 |-- V3: double (nullable = true)
 |-- V4: double (nullable = true)
 |-- V5: double (nullable = true)
 |-- V6: double (nullable = true)
 |-- V7: double (nullable = true)
 |-- V8: double (nullable = true)
 |-- V9: double (nullable = true)
 |-- V10: double (nullable = true)
 |-- V11: double (nullable = true)
 |-- V12: double (nullable = true)
 |-- V13: double (nullable = true)
 |-- V14: double (nullable = true)
 |-- V15: double (nullable = true)
 |-- V16: double (nullable = true)
 |-- V17: double (nullable = true)
 |-- V18: double (nullable = true)
 |-- V19: double (nullable = true)
 |-- V20: double (nullable = true)
 |-- V21: double (nullable = true)
 |-- V22: double (nullable = true)
 |-- V23: double (nullable = true)
 |-- V24: double (nullable = true)
 |-- V25: double (nullable = true)
 |-- V26: double (nullable = true)
 |-- V27: double (nullable = true)
 |-- V28: double (nulla

                                                                                

In [3]:
# Drop duplicates
df = df.dropDuplicates()
print("Distinct count: "+str(df.count()))

23/05/10 21:39:35 WARN package: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.




Distinct count: 283726


                                                                                

## Normalization

In [4]:
# UDF for converting column type from vector to double type
unlist = f.udf(lambda x: round(float(list(x)[0]),3), t.DoubleType())

In [5]:
# Iterating over columns to be scaled
for i in ['Time', 'V1','V2','V3','V4','V5','V6','V7','V8','V9','V10','V11','V12','V13','V14','V15','V16','V17','V18','V19','V20','V21','V22','V23','V24','V25','V26','V27','V28','Amount']:
    # VectorAssembler Transformation - Converting column to vector type
    assembler = VectorAssembler(inputCols=[i],outputCol=i+"_Vect")

    # StandardScaler Transformation
    scaler = StandardScaler(inputCol=i+"_Vect", outputCol=i+"_Scaled", withStd=True, withMean=True)

    # Pipeline of VectorAssembler and StandardScaler
    pipeline = Pipeline(stages=[assembler, scaler])

    # Fitting pipeline on dataframe
    df = pipeline.fit(df).transform(df).withColumn(i+"_Scaled", unlist(i+"_Scaled")).drop(i+"_Vect")
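
As a rough illustration of what each `StandardScaler(withMean=True, withStd=True)` stage does to a single column, the sketch below standardizes a short list of values in plain Python. It assumes the scaler divides by the sample standard deviation (the `n - 1` form) and mirrors the 3-decimal rounding applied by the `unlist` UDF. The helper name `standardize` and the sample values are hypothetical.

```python
import statistics


def standardize(values):
    """Center values on their mean and divide by the sample standard deviation."""
    mean = statistics.mean(values)
    std = statistics.stdev(values)  # sample standard deviation (n - 1 denominator)
    # Round to 3 decimals, as the notebook's unlist UDF does
    return [round((v - mean) / std, 3) for v in values]


# Hypothetical column values, for illustration only
amounts = [10.0, 20.0, 30.0, 40.0, 50.0]
scaled = standardize(amounts)
print(scaled)
```

After this transformation the column has mean 0 and unit standard deviation, so features measured on very different scales (for example `Time` and `Amount`) become comparable.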

                                                                                

In [6]:
dfscaled = df.drop('Time', 'V1','V2','V3','V4','V5','V6','V7','V8','V9','V10','V11','V12','V13','V14','V15','V16','V17','V18','V19','V20','V21','V22','V23','V24','V25','V26','V27','V28','Amount', 'time_udf')

In [7]:
dfscaled.printSchema()

root
 |-- Class: integer (nullable = true)
 |-- Time_Scaled: double (nullable = true)
 |-- V1_Scaled: double (nullable = true)
 |-- V2_Scaled: double (nullable = true)
 |-- V3_Scaled: double (nullable = true)
 |-- V4_Scaled: double (nullable = true)
 |-- V5_Scaled: double (nullable = true)
 |-- V6_Scaled: double (nullable = true)
 |-- V7_Scaled: double (nullable = true)
 |-- V8_Scaled: double (nullable = true)
 |-- V9_Scaled: double (nullable = true)
 |-- V10_Scaled: double (nullable = true)
 |-- V11_Scaled: double (nullable = true)
 |-- V12_Scaled: double (nullable = true)
 |-- V13_Scaled: double (nullable = true)
 |-- V14_Scaled: double (nullable = true)
 |-- V15_Scaled: double (nullable = true)
 |-- V16_Scaled: double (nullable = true)
 |-- V17_Scaled: double (nullable = true)
 |-- V18_Scaled: double (nullable = true)
 |-- V19_Scaled: double (nullable = true)
 |-- V20_Scaled: double (nullable = true)
 |-- V21_Scaled: double (nullable = true)
 |-- V22_Scaled: double (nullable = true)

In [8]:
dfscaled.show()


+-----+-----------+---------+---------+---------+---------+---------+---------+---------+---------+---------+----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+-------------+
|Class|Time_Scaled|V1_Scaled|V2_Scaled|V3_Scaled|V4_Scaled|V5_Scaled|V6_Scaled|V7_Scaled|V8_Scaled|V9_Scaled|V10_Scaled|V11_Scaled|V12_Scaled|V13_Scaled|V14_Scaled|V15_Scaled|V16_Scaled|V17_Scaled|V18_Scaled|V19_Scaled|V20_Scaled|V21_Scaled|V22_Scaled|V23_Scaled|V24_Scaled|V25_Scaled|V26_Scaled|V27_Scaled|V28_Scaled|Amount_Scaled|
+-----+-----------+---------+---------+---------+---------+---------+---------+---------+---------+---------+----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+-------------+
|

                                                                                

In [9]:
# check the shape of the Spark DataFrame
print(f"Number of columns: {len(dfscaled.columns)}")
print(f"Number of Records: {dfscaled.count()}")

Number of columns: 31




Number of Records: 283726


                                                                                

In [12]:
# build a data split: 80/20
train, test = dfscaled.randomSplit(weights=[0.8, 0.2], seed=42)
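
A weighted random split of this kind assigns each row to a partition independently, so the resulting sizes are close to 80/20 but not exact, and the seed makes the assignment repeatable. The plain-Python sketch below illustrates the idea under that assumption; `random_split` is a hypothetical helper, not Spark's implementation.

```python
import random


def random_split(rows, weights, seed):
    """Assign each row to one partition at random, in proportion to the weights."""
    rng = random.Random(seed)
    total = sum(weights)
    # Cumulative, normalized boundaries, e.g. [0.8, 1.0] for weights [0.8, 0.2]
    bounds, acc = [], 0.0
    for w in weights:
        acc += w / total
        bounds.append(acc)
    parts = [[] for _ in weights]
    for row in rows:
        u = rng.random()
        for i, bound in enumerate(bounds):
            # The last partition catches any floating-point leftover
            if u < bound or i == len(bounds) - 1:
                parts[i].append(row)
                break
    return parts


rows = list(range(1000))
train_rows, test_rows = random_split(rows, weights=[0.8, 0.2], seed=42)
print(len(train_rows), len(test_rows))
```

Because the sizes are only approximate, counting the rows of each split afterwards is a reasonable sanity check.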


In [11]:
print('Train shape: ', (train.count(), len(train.columns)))
print('Test shape: ', (test.count(), len(test.columns)))

Py4JJavaError: An error occurred while calling o2693.count.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 199.0 failed 1 times, most recent failure: Lost task 0.0 in stage 199.0 (TID 602) (rhaydrick.home executor driver): org.apache.spark.memory.SparkOutOfMemoryError: Unable to acquire 260 bytes of memory, got 0
	at org.apache.spark.memory.MemoryConsumer.throwOom(MemoryConsumer.java:158)
	at org.apache.spark.memory.MemoryConsumer.allocatePage(MemoryConsumer.java:118)
	at org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter.acquireNewPageIfNecessary(UnsafeExternalSorter.java:431)
	at org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter.allocateMemoryForRecordIfNecessary(UnsafeExternalSorter.java:450)
	at org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter.insertRecord(UnsafeExternalSorter.java:485)
	at org.apache.spark.sql.execution.UnsafeExternalRowSorter.insertRow(UnsafeExternalRowSorter.java:138)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage32.sort_addToSorter_0$(Unknown Source)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage32.hashAgg_doAggregateWithoutKey_0$(Unknown Source)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage32.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:760)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at org.apache.spark.shuffle.sort.BypassMergeSortShuffleWriter.write(BypassMergeSortShuffleWriter.java:140)
	at org.apache.spark.shuffle.ShuffleWriteProcessor.write(ShuffleWriteProcessor.scala:59)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:99)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:52)
	at org.apache.spark.scheduler.Task.run(Task.scala:136)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:548)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1504)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:551)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1128)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:628)
	at java.base/java.lang.Thread.run(Thread.java:829)

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2672)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2608)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2607)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2607)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1182)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1182)
	at scala.Option.foreach(Option.scala:407)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1182)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2860)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2802)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2791)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)


In [13]:
# get feature columns names
feature_columns = [col for col in train.columns if col!= 'Class']
#print(feature_columns)
#print(len(feature_columns))

In [14]:
# vectorize
vectorizer = VectorAssembler(inputCols=feature_columns, outputCol="features")
train_vec = vectorizer.transform(train)
test_vec = vectorizer.transform(test)

#### Scores on the TRAINING AND TEST SET:

In [15]:
from pyspark.sql.types import StringType, BooleanType, IntegerType, FloatType, DateType, DoubleType
from sklearn.metrics import ConfusionMatrixDisplay
np.set_printoptions(suppress=True)

# instantiate Models

# logistic regression
lr = LogisticRegression(
    featuresCol='features',
    labelCol='Class',
    predictionCol='Class_Prediction',
    maxIter=10,
    regParam=0.3,
    elasticNetParam=0.8
)

# decision tree
dt = DecisionTreeClassifier(featuresCol='features',
    labelCol='Class',
    predictionCol='Class_Prediction'
)

# random forest
rf = RandomForestClassifier(
    featuresCol='features',
    labelCol='Class',
    predictionCol='Class_Prediction'
)

# gradient-boosted trees
gbt = GBTClassifier(
    featuresCol='features',
    labelCol='Class',
    predictionCol='Class_Prediction'
)

# linear support vector machines
lsvc = LinearSVC(
    featuresCol='features',
    labelCol='Class',
    predictionCol='Class_Prediction'
)

# naive bayes
#nb = NaiveBayes(
    #featuresCol='features',
    #labelCol='Class',
    #predictionCol='Class_Prediction',
    #smoothing=1.0, 
    #modelType="multinomial"
#)




# create list of models
list_of_models = [lr, dt, rf, gbt, lsvc]
list_of_model_names = ['Logistic Regression', 'Decision Tree', 'Random Forest', 'Gradient-Boosted Tree', 'Linear Support Vector Machines']

# go through list
for model, model_name in zip(list_of_models, list_of_model_names):

    # print current model
    print('Current model: ', model_name)

    # create a pipeline object
    pipeline = Pipeline(stages=[model])

    # fit pipeline
    pipeline_model = pipeline.fit(train_vec)

    # get scores on the training set
    train_pred = pipeline_model.transform(train_vec)

    # get scores on the test set
    test_pred = pipeline_model.transform(test_vec)

    # get accuracy on train and test set
    accuracy_evaluator = MulticlassClassificationEvaluator(predictionCol='Class_Prediction', labelCol='Class', metricName='accuracy')
    accuracy_score_train = accuracy_evaluator.evaluate(train_pred)
    accuracy_score_test = accuracy_evaluator.evaluate(test_pred)
    print('Accuracy on Train: ', accuracy_score_train)
    print('Accuracy on Test: ', accuracy_score_test)

    # get precision on train and test set
    # (precisionByLabel scores the label set by metricLabel, which defaults to 0.0, i.e. class 0)
    precision_evaluator = MulticlassClassificationEvaluator(predictionCol='Class_Prediction', labelCol='Class', metricName='precisionByLabel')
    precision_score_train = precision_evaluator.evaluate(train_pred)
    precision_score_test = precision_evaluator.evaluate(test_pred)
    print('Precision on Train: ', precision_score_train)
    print('Precision on Test: ', precision_score_test)

    # get recall on train and test set
    # (recallByLabel scores the label set by metricLabel, which defaults to 0.0, i.e. class 0)
    recall_evaluator = MulticlassClassificationEvaluator(predictionCol='Class_Prediction', labelCol='Class', metricName='recallByLabel')
    recall_score_train = recall_evaluator.evaluate(train_pred)
    recall_score_test = recall_evaluator.evaluate(test_pred)
    print('Recall on Train: ', recall_score_train)
    print('Recall on Test: ', recall_score_test)

    # get f1-score on train and test set
    # ('f1' is the per-label F1 averaged over labels, weighted by label frequency)
    f1_evaluator = MulticlassClassificationEvaluator(predictionCol='Class_Prediction', labelCol='Class', metricName='f1')
    f1_score_train = f1_evaluator.evaluate(train_pred)
    f1_score_test = f1_evaluator.evaluate(test_pred)
    print('F1-score on Train: ', f1_score_train)
    print('F1-score on Test: ', f1_score_test)

    # get confusion matrix on train set
    preds_and_labels_train = train_pred.withColumn("Class_Prediction", train_pred["Class_Prediction"].cast(DoubleType())).withColumn("Class", train_pred["Class"].cast(DoubleType()))
    preds_and_labels_train = preds_and_labels_train.select(['Class_Prediction', 'Class'])
    metrics_train = MulticlassMetrics(preds_and_labels_train.rdd)
    cm_arr_train = metrics_train.confusionMatrix().toArray().astype(float)
    cm_disp_train = ConfusionMatrixDisplay(confusion_matrix=cm_arr_train)
    print('Confusion Matrix on Train set:')
    cm_disp_train.plot()
    plt.show()

    # get confusion matrix on test set
    preds_and_labels_test = test_pred.withColumn("Class_Prediction", test_pred["Class_Prediction"].cast(DoubleType())).withColumn("Class", test_pred["Class"].cast(DoubleType()))
    preds_and_labels_test = preds_and_labels_test.select(['Class_Prediction', 'Class'])
    metrics_test = MulticlassMetrics(preds_and_labels_test.rdd)
    cm_arr_test = metrics_test.confusionMatrix().toArray().astype(float)
    cm_disp_test = ConfusionMatrixDisplay(confusion_matrix=cm_arr_test)
    print('Confusion Matrix on Test set:')
    cm_disp_test.plot()
    plt.show()
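
For reference, the confusion matrix plotted above can be thought of as a simple tally over (prediction, label) pairs. The sketch below builds one in plain Python, with rows indexed by the actual label and columns by the predicted label. The helper `confusion_matrix` and the sample pairs are hypothetical and are not output from this notebook.

```python
def confusion_matrix(pairs, n_classes=2):
    """Count (prediction, label) pairs; rows are actual labels, columns are predictions."""
    matrix = [[0] * n_classes for _ in range(n_classes)]
    for prediction, label in pairs:
        matrix[int(label)][int(prediction)] += 1
    return matrix


# Hypothetical (prediction, label) pairs, cast to float as in the cell above
pairs = [(0.0, 0.0), (0.0, 0.0), (1.0, 0.0),
         (0.0, 1.0), (1.0, 1.0), (1.0, 1.0)]
cm = confusion_matrix(pairs)
for row in cm:
    print(row)
```

In this layout `cm[1][1]` counts true positives, `cm[0][1]` false positives, `cm[1][0]` false negatives and `cm[0][0]` true negatives for class 1.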


    

Current model:  Logistic Regression


Py4JJavaError: An error occurred while calling o2783.fit.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 4 in stage 203.0 failed 1 times, most recent failure: Lost task 4.0 in stage 203.0 (TID 631) (rhaydrick.home executor driver): org.apache.spark.memory.SparkOutOfMemoryError: Unable to acquire 260 bytes of memory, got 0
	at org.apache.spark.memory.MemoryConsumer.throwOom(MemoryConsumer.java:158)
	at org.apache.spark.memory.MemoryConsumer.allocatePage(MemoryConsumer.java:118)
	at org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter.acquireNewPageIfNecessary(UnsafeExternalSorter.java:431)
	at org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter.allocateMemoryForRecordIfNecessary(UnsafeExternalSorter.java:450)
	at org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter.insertRecord(UnsafeExternalSorter.java:485)
	at org.apache.spark.sql.execution.UnsafeExternalRowSorter.insertRow(UnsafeExternalRowSorter.java:138)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage32.sort_addToSorter_0$(Unknown Source)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage32.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:760)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator.foreach(Iterator.scala:943)
	at scala.collection.Iterator.foreach$(Iterator.scala:943)
	at scala.collection.AbstractIterator.foreach(Iterator.scala:1431)
	at scala.collection.TraversableOnce.foldLeft(TraversableOnce.scala:199)
	at scala.collection.TraversableOnce.foldLeft$(TraversableOnce.scala:192)
	at scala.collection.AbstractIterator.foldLeft(Iterator.scala:1431)
	at scala.collection.TraversableOnce.aggregate(TraversableOnce.scala:260)
	at scala.collection.TraversableOnce.aggregate$(TraversableOnce.scala:260)
	at scala.collection.AbstractIterator.aggregate(Iterator.scala:1431)
	at org.apache.spark.rdd.RDD.$anonfun$treeAggregate$4(RDD.scala:1236)
	at org.apache.spark.rdd.RDD.$anonfun$treeAggregate$6(RDD.scala:1237)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2(RDD.scala:855)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2$adapted(RDD.scala:855)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:329)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:329)
	at org.apache.spark.shuffle.ShuffleWriteProcessor.write(ShuffleWriteProcessor.scala:59)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:99)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:52)
	at org.apache.spark.scheduler.Task.run(Task.scala:136)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:548)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1504)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:551)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1128)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:628)
	at java.base/java.lang.Thread.run(Thread.java:829)

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2672)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2608)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2607)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2607)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1182)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1182)
	at scala.Option.foreach(Option.scala:407)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1182)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2860)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2802)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2791)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:952)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2238)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2333)
	at org.apache.spark.rdd.RDD.$anonfun$fold$1(RDD.scala:1174)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:406)
	at org.apache.spark.rdd.RDD.fold(RDD.scala:1168)
	at org.apache.spark.rdd.RDD.$anonfun$treeAggregate$2(RDD.scala:1267)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:406)
	at org.apache.spark.rdd.RDD.treeAggregate(RDD.scala:1228)
	at org.apache.spark.rdd.RDD.$anonfun$treeAggregate$1(RDD.scala:1214)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:406)
	at org.apache.spark.rdd.RDD.treeAggregate(RDD.scala:1214)
	at org.apache.spark.ml.stat.Summarizer$.getClassificationSummarizers(Summarizer.scala:233)
	at org.apache.spark.ml.classification.LogisticRegression.$anonfun$train$1(LogisticRegression.scala:512)
	at org.apache.spark.ml.util.Instrumentation$.$anonfun$instrumented$1(Instrumentation.scala:191)
	at scala.util.Try$.apply(Try.scala:213)
	at org.apache.spark.ml.util.Instrumentation$.instrumented(Instrumentation.scala:191)
	at org.apache.spark.ml.classification.LogisticRegression.train(LogisticRegression.scala:496)
	at org.apache.spark.ml.classification.LogisticRegression.train(LogisticRegression.scala:286)
	at org.apache.spark.ml.Predictor.fit(Predictor.scala:151)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.base/java.lang.reflect.Method.invoke(Method.java:566)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
	at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
	at java.base/java.lang.Thread.run(Thread.java:829)
Caused by: org.apache.spark.memory.SparkOutOfMemoryError: Unable to acquire 260 bytes of memory, got 0
	at org.apache.spark.memory.MemoryConsumer.throwOom(MemoryConsumer.java:158)
	at org.apache.spark.memory.MemoryConsumer.allocatePage(MemoryConsumer.java:118)
	at org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter.acquireNewPageIfNecessary(UnsafeExternalSorter.java:431)
	at org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter.allocateMemoryForRecordIfNecessary(UnsafeExternalSorter.java:450)
	at org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter.insertRecord(UnsafeExternalSorter.java:485)
	at org.apache.spark.sql.execution.UnsafeExternalRowSorter.insertRow(UnsafeExternalRowSorter.java:138)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage32.sort_addToSorter_0$(Unknown Source)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage32.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:760)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator.foreach(Iterator.scala:943)
	at scala.collection.Iterator.foreach$(Iterator.scala:943)
	at scala.collection.AbstractIterator.foreach(Iterator.scala:1431)
	at scala.collection.TraversableOnce.foldLeft(TraversableOnce.scala:199)
	at scala.collection.TraversableOnce.foldLeft$(TraversableOnce.scala:192)
	at scala.collection.AbstractIterator.foldLeft(Iterator.scala:1431)
	at scala.collection.TraversableOnce.aggregate(TraversableOnce.scala:260)
	at scala.collection.TraversableOnce.aggregate$(TraversableOnce.scala:260)
	at scala.collection.AbstractIterator.aggregate(Iterator.scala:1431)
	at org.apache.spark.rdd.RDD.$anonfun$treeAggregate$4(RDD.scala:1236)
	at org.apache.spark.rdd.RDD.$anonfun$treeAggregate$6(RDD.scala:1237)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2(RDD.scala:855)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2$adapted(RDD.scala:855)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:329)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:329)
	at org.apache.spark.shuffle.ShuffleWriteProcessor.write(ShuffleWriteProcessor.scala:59)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:99)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:52)
	at org.apache.spark.scheduler.Task.run(Task.scala:136)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:548)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1504)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:551)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1128)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:628)
	... 1 more


#### Conclusions: Imbalanced dataset
As expected for an imbalanced dataset with a binary target, every metric looks excellent. In this particular case, 99% of the records have a value of '0 - not fraud', so a model can score highly by predicting 0 most of the time. Since the train and test scores have similar values, we can infer that overfitting is not occurring.
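
To make the effect of the class imbalance concrete, here is a minimal sketch in plain Python (not part of the pipeline). The label counts are hypothetical and only mirror the roughly 99/1 split described above; the "model" is a stand-in that always predicts 0.

```python
# Hypothetical labels: 990 non-fraud (0) and 10 fraud (1) records.
labels = [0] * 990 + [1] * 10

# A trivial stand-in "model" that always predicts the majority class.
predictions = [0] * len(labels)

# Accuracy: share of records where the prediction matches the label.
correct = sum(1 for y, p in zip(labels, predictions) if y == p)
accuracy = correct / len(labels)

# Recall for the fraud class: share of fraud records that were detected.
true_positives = sum(1 for y, p in zip(labels, predictions) if y == 1 and p == 1)
fraud_recall = true_positives / sum(labels)

print(f"accuracy     = {accuracy:.2f}")      # 0.99
print(f"fraud recall = {fraud_recall:.2f}")  # 0.00
```

The accuracy is high even though no fraud record is detected, which is why accuracy alone is not a reliable indicator on this kind of dataset.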

### Working with the balanced dataset

First we have to balance the dataset. Here we use a random undersampling approach.
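
As an illustration of the idea, the sketch below performs random undersampling on a small list of dictionaries in plain Python. The helper name `undersample`, the toy rows and the `Amount` field are hypothetical; only the `Class` column name comes from the dataset. The notebook applies the same idea to Spark DataFrames.

```python
import random


def undersample(rows, label_key="Class", seed=42):
    """Randomly drop majority-class rows until both classes have the same size."""
    rng = random.Random(seed)
    positives = [r for r in rows if r[label_key] == 1]
    negatives = [r for r in rows if r[label_key] == 0]

    # The shorter list is the minority class, the longer one the majority class.
    minority, majority = sorted([positives, negatives], key=len)

    # Draw, without replacement, as many majority rows as there are minority rows.
    sampled_majority = rng.sample(majority, k=len(minority))

    # Combine both classes and shuffle so they are interleaved.
    balanced = minority + sampled_majority
    rng.shuffle(balanced)
    return balanced


# Toy data: 5 fraud rows and 95 non-fraud rows.
rows = [{"Amount": i, "Class": 1 if i < 5 else 0} for i in range(100)]
balanced = undersample(rows)
print(len(balanced), sum(r["Class"] for r in balanced))
```

Only the training split should be undersampled; the test split keeps its original class distribution so that the scores reflect realistic conditions.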

In [None]:
# build a data split: 80/20
balanced_train, balanced_test = dfscaled.randomSplit(weights=[0.8, 0.2], seed=42)
print('Train shape: ', (balanced_train.count(), len(balanced_train.columns)))
print('Test shape: ', (balanced_test.count(), len(balanced_test.columns)))

In [None]:
# balancing the dataset

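# select fraud and non-fraud transactions; randomly undersample the non-fraud
# transactions down to the number of fraud transactions
# (limit() alone would keep the first rows instead of a random sample)
fraud_data = balanced_train.filter(f.col('Class') == 1)
fraud_count = fraud_data.count()
non_fraud_data = balanced_train.filter(f.col('Class') == 0).orderBy(f.rand(seed=42)).limit(fraud_count)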

# Combine fraud and non-fraud transactions and shuffle the data
balanced_data_train = fraud_data.union(non_fraud_data).orderBy(f.rand())

# Show 5 rows of the shuffled, balanced data
balanced_data_train.show(5)

In [None]:
# get feature column names (every column except the target)
balanced_feature_columns = [col for col in balanced_data_train.columns if col != 'Class']
print(balanced_feature_columns)
print(len(balanced_feature_columns))

In [None]:
# vectorize
balanced_vectorizer = VectorAssembler(inputCols=balanced_feature_columns, outputCol="features")
balanced_train_vec = balanced_vectorizer.transform(balanced_data_train)
balanced_test_vec = balanced_vectorizer.transform(balanced_test)

In [None]:
# instantiate Models

# logistic regression
balanced_lr = LogisticRegression(
    featuresCol='features',
    labelCol='Class',
    predictionCol='Class_Prediction',
    maxIter=10,
    regParam=0.3,
    elasticNetParam=0.8
)

# decision tree
balanced_dt = DecisionTreeClassifier(
    featuresCol='features',
    labelCol='Class',
    predictionCol='Class_Prediction'
)

# random forest
balanced_rf = RandomForestClassifier(
    featuresCol='features',
    labelCol='Class',
    predictionCol='Class_Prediction'
)

# gradient-boosted trees
balanced_gbt = GBTClassifier(
    featuresCol='features',
    labelCol='Class',
    predictionCol='Class_Prediction'
)

# linear support vector machines
balanced_lsvc = LinearSVC(
    featuresCol='features',
    labelCol='Class',
    predictionCol='Class_Prediction'
)

# naive bayes
#balanced_nb = NaiveBayes(
    #featuresCol='features',
    #labelCol='Class',
    #predictionCol='Class_Prediction',
    #smoothing=1.0, 
    #modelType="multinomial"
#)




# create list of models
balanced_list_of_models = [balanced_lr, balanced_dt, balanced_rf, balanced_gbt, balanced_lsvc]
balanced_list_of_model_names = ['Logistic Regression', 'Decision Tree', 'Random Forest', 'Gradient-Boosted Tree', 'Linear Support Vector Machines']

# go through list
for balanced_model, balanced_model_name in zip(balanced_list_of_models, balanced_list_of_model_names):

    # print current model
    print('Current model: ', balanced_model_name)

    # create a pipeline object
    balanced_pipeline = Pipeline(stages=[balanced_model])

    # fit pipeline
    balanced_pipeline_model = balanced_pipeline.fit(balanced_train_vec)

    # get scores on the training set
    balanced_train_pred = balanced_pipeline_model.transform(balanced_train_vec)

    # get scores on the test set
    balanced_test_pred = balanced_pipeline_model.transform(balanced_test_vec)

    # get accuracy on train and test set
    balanced_accuracy_evaluator = MulticlassClassificationEvaluator(predictionCol='Class_Prediction', labelCol='Class', metricName='accuracy')
    balanced_accuracy_score_train = balanced_accuracy_evaluator.evaluate(balanced_train_pred)
    balanced_accuracy_score_test = balanced_accuracy_evaluator.evaluate(balanced_test_pred)
    print('Accuracy on Train: ', balanced_accuracy_score_train)
    print('Accuracy on Test: ', balanced_accuracy_score_test)

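    # get precision on train and test set
    # note: precisionByLabel reports the label given by metricLabel, which defaults to 0.0
    # (class 0, not fraud); pass metricLabel=1.0 to score the fraud class instead
    balanced_precision_evaluator = MulticlassClassificationEvaluator(predictionCol='Class_Prediction', labelCol='Class', metricName='precisionByLabel')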
    balanced_precision_score_train = balanced_precision_evaluator.evaluate(balanced_train_pred)
    balanced_precision_score_test = balanced_precision_evaluator.evaluate(balanced_test_pred)
    print('Precision on Train: ', balanced_precision_score_train)
    print('Precision on Test: ', balanced_precision_score_test)

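    # get recall on train and test set
    # note: recallByLabel reports the label given by metricLabel, which defaults to 0.0
    # (class 0, not fraud); pass metricLabel=1.0 to score the fraud class instead
    balanced_recall_evaluator = MulticlassClassificationEvaluator(predictionCol='Class_Prediction', labelCol='Class', metricName='recallByLabel')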
    balanced_recall_score_train = balanced_recall_evaluator.evaluate(balanced_train_pred)
    balanced_recall_score_test = balanced_recall_evaluator.evaluate(balanced_test_pred)
    print('Recall on Train: ', balanced_recall_score_train)
    print('Recall on Test: ', balanced_recall_score_test)

    # get f1-score on train and test set
    # note: metricName='f1' is the F1-score weighted across both labels
    balanced_f1_evaluator = MulticlassClassificationEvaluator(predictionCol='Class_Prediction', labelCol='Class', metricName='f1')
    balanced_f1_score_train = balanced_f1_evaluator.evaluate(balanced_train_pred)
    balanced_f1_score_test = balanced_f1_evaluator.evaluate(balanced_test_pred)
    print('F1-score on Train: ', balanced_f1_score_train)
    print('F1-score on Test: ', balanced_f1_score_test)

    # get confusion matrix on train set
    balanced_preds_and_labels_train = balanced_train_pred.withColumn("Class_Prediction", balanced_train_pred["Class_Prediction"].cast(DoubleType())).withColumn("Class", balanced_train_pred["Class"].cast(DoubleType()))
    balanced_preds_and_labels_train = balanced_preds_and_labels_train.select(['Class_Prediction', 'Class'])
    balanced_metrics_train = MulticlassMetrics(balanced_preds_and_labels_train.rdd)
    balanced_cm_arr_train = balanced_metrics_train.confusionMatrix().toArray().astype(float)
    balanced_cm_disp_train = ConfusionMatrixDisplay(confusion_matrix=balanced_cm_arr_train)
    print('Confusion Matrix on Train set:')
    balanced_cm_disp_train.plot()
    plt.show()

    # get confusion matrix on test set
    balanced_preds_and_labels_test = balanced_test_pred.withColumn("Class_Prediction", balanced_test_pred["Class_Prediction"].cast(DoubleType())).withColumn("Class", balanced_test_pred["Class"].cast(DoubleType()))
    balanced_preds_and_labels_test = balanced_preds_and_labels_test.select(['Class_Prediction', 'Class'])
    balanced_metrics_test = MulticlassMetrics(balanced_preds_and_labels_test.rdd)
    balanced_cm_arr_test = balanced_metrics_test.confusionMatrix().toArray().astype(float)
    balanced_cm_disp_test = ConfusionMatrixDisplay(confusion_matrix=balanced_cm_arr_test)
    print('Confusion Matrix on Test set:')
    balanced_cm_disp_test.plot()
    plt.show()


#### Conclusions: Balanced dataset
In terms of performance, Logistic Regression obtained the worst results, with a precision of 0.74 on the test set. The remaining models obtained better results, with values ranging from 0.88 to 1. Overall, the metric values were lower when the ML pipeline was applied to the balanced dataset.
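
For reference, the sketch below shows how accuracy, precision, recall and F1-score follow from the four cells of a binary confusion matrix. The helper `binary_metrics` and the counts passed to it are hypothetical and are not results from this notebook; the formulas are the standard definitions, computed here for the positive (fraud) class.

```python
def binary_metrics(tp, fp, fn, tn):
    """Compute accuracy, precision, recall and F1 for the positive class."""
    total = tp + fp + fn + tn
    accuracy = (tp + tn) / total
    # Precision: of everything flagged as fraud, how much really was fraud.
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    # Recall: of all real fraud cases, how many were flagged.
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    # F1: harmonic mean of precision and recall.
    f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 0.0
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}


# Hypothetical counts for an imbalanced test set (illustration only).
metrics = binary_metrics(tp=90, fp=30, fn=10, tn=9870)
for name, value in metrics.items():
    print(f"{name:9s} = {value:.3f}")
# accuracy  = 0.996
# precision = 0.750
# recall    = 0.900
# f1        = 0.818
```

With these illustrative counts, accuracy stays close to 1 while precision is noticeably lower: on an imbalanced test set, a small share of false positives among the many non-fraud records can outnumber the true positives.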