In [1]:
import numpy as np
import pandas as pd
import seaborn as sns

import re
import datetime

from pyspark.sql import SparkSession, DataFrame
from pyspark.sql.functions import *
from pyspark.sql.types import DateType

In [71]:
sp = (
    SparkSession.builder.appName("Model")
    .config("spark.sql.session.timeZone", "+11")
    .config("spark.driver.memory", "6g")
    .config("spark.executor.memory", "8g")
    .getOrCreate()
)

ConnectionRefusedError: [Errno 111] Connection refused

In [3]:
transactions = sp.read.option("inferSchema", True).parquet("../data/processed/transactions")
merchants = sp.read.option("inferSchema", True).parquet("../data/processed/merchants")
customers = sp.read.option("inferSchema", True).parquet("../data/processed/customers")

                                                                                

In [4]:
transactions.show(1)
merchants.show(1)
customers.show(1)

+--------+-------+------------+------------+--------------+-----------+-----------------+-------+----------+-----+---------+
|order_id|user_id|merchant_abn|dollar_value|order_datetime|Natural_var|Potential_Outlier|holiday|dayofmonth|month|dayofweek|
+--------+-------+------------+------------+--------------+-----------+-----------------+-------+----------+-----+---------+
|       3|      3| 60956456424|      136.68|    2021-08-20|          0|                0|      0|        20|    8|        6|
+--------+-------+------------+------------+--------------+-----------+-----------------+-------+----------+-----+---------+
only showing top 1 row

+------------+-----------------+--------------------+---------------+---------------+----------------+-----------------+
|merchant_abn|             name|                tags|avg_monthly_inc|monthly_entropy|postcode_entropy|          revenue|
+------------+-----------------+--------------------+---------------+---------------+----------------+-------

In [5]:
final = transactions.join(merchants, on="merchant_abn").join(customers, on="user_id")
final.show(2)

+-------+------------+--------+------------+--------------+-----------+-----------------+-------+----------+-----+---------+--------------------+--------------------+---------------+---------------+----------------+------------------+-----+--------+------+--------------------------------------------------+------------------------------+-----------------------------+-------------------------------+---------------------+-----------------------+----------------------+------------------------+--------------+----------------+---------------+----------------------------+---------------------------+------------------------+-----------------------+-------------------------------------+---------------------------+-----------------------------+----------------------------+---------------------------------------+-----------------------------+-------------------------------+------------------------------+-----------------------+-------------+---------------+--------------+-------------------------

In [6]:
final.count()

                                                                                

13614648

### Dropping Columns

In [7]:
final.printSchema()

root
 |-- user_id: long (nullable = true)
 |-- merchant_abn: long (nullable = true)
 |-- order_id: long (nullable = true)
 |-- dollar_value: float (nullable = true)
 |-- order_datetime: date (nullable = true)
 |-- Natural_var: integer (nullable = true)
 |-- Potential_Outlier: integer (nullable = true)
 |-- holiday: long (nullable = true)
 |-- dayofmonth: integer (nullable = true)
 |-- month: integer (nullable = true)
 |-- dayofweek: integer (nullable = true)
 |-- name: string (nullable = true)
 |-- tags: string (nullable = true)
 |-- avg_monthly_inc: float (nullable = true)
 |-- monthly_entropy: float (nullable = true)
 |-- postcode_entropy: float (nullable = true)
 |-- revenue: double (nullable = true)
 |-- state: string (nullable = true)
 |-- postcode: long (nullable = true)
 |-- gender: string (nullable = true)
 |-- Number of individuals lodging an income tax return: long (nullable = true)
 |-- Average taxable income or loss: long (nullable = true)
 |-- Median taxable income or loss

In [13]:
final = final.drop("user_id", "merchant_abn", "order_id", "order_datetime", "name")
final.printSchema()

root
 |-- dollar_value: float (nullable = true)
 |-- Natural_var: integer (nullable = true)
 |-- Potential_Outlier: integer (nullable = true)
 |-- holiday: long (nullable = true)
 |-- dayofmonth: integer (nullable = true)
 |-- month: integer (nullable = true)
 |-- dayofweek: integer (nullable = true)
 |-- tags: string (nullable = true)
 |-- avg_monthly_inc: float (nullable = true)
 |-- monthly_entropy: float (nullable = true)
 |-- postcode_entropy: float (nullable = true)
 |-- revenue: double (nullable = true)
 |-- state: string (nullable = true)
 |-- postcode: long (nullable = true)
 |-- gender: string (nullable = true)
 |-- Number of individuals lodging an income tax return: long (nullable = true)
 |-- Average taxable income or loss: long (nullable = true)
 |-- Median taxable income or loss: long (nullable = true)
 |-- Proportion with salary or wages: long (nullable = true)
 |-- Count salary or wages: long (nullable = true)
 |-- Average salary or wages: long (nullable = true)
 |-- Me

### PROCESSING CUSTOMER FRAUD DATA

In [8]:
c_fraud = sp.read.option("inferSchema", True).parquet("../data/curated/customer_fraud")
c_fraud = c_fraud.withColumn("order_datetime", col("order_datetime").cast(DateType()))
c_fraud.show(2)

+-------+--------------+-----------------+
|user_id|order_datetime|fraud_probability|
+-------+--------------+-----------------+
|   6228|    2021-12-19|         97.62981|
|  21419|    2021-12-10|         99.24738|
+-------+--------------+-----------------+
only showing top 2 rows



In [9]:
transactions.show(2)

+--------+-------+------------+------------+--------------+-----------+-----------------+-------+----------+-----+---------+
|order_id|user_id|merchant_abn|dollar_value|order_datetime|Natural_var|Potential_Outlier|holiday|dayofmonth|month|dayofweek|
+--------+-------+------------+------------+--------------+-----------+-----------------+-------+----------+-----+---------+
|       3|      3| 60956456424|      136.68|    2021-08-20|          0|                0|      0|        20|    8|        6|
|       8|  18482| 70501974849|       68.75|    2021-08-20|          0|                0|      0|        20|    8|        6|
+--------+-------+------------+------------+--------------+-----------+-----------------+-------+----------+-----+---------+
only showing top 2 rows



In [10]:
c_fraud_full = transactions.join(c_fraud, on=["user_id", "order_datetime"])
c_fraud_full.show(2)

+-------+--------------+--------+------------+------------+-----------+-----------------+-------+----------+-----+---------+-----------------+
|user_id|order_datetime|order_id|merchant_abn|dollar_value|Natural_var|Potential_Outlier|holiday|dayofmonth|month|dayofweek|fraud_probability|
+-------+--------------+--------+------------+------------+-----------+-----------------+-------+----------+-----+---------+-----------------+
|    448|    2021-08-20|    1005| 94380689142|     6263.03|          0|                0|      0|        20|    8|        6|        14.681704|
|   3116|    2021-08-20|    6989| 22248828825|     3958.86|          0|                0|      0|        20|    8|        6|         8.809071|
+-------+--------------+--------+------------+------------+-----------+-----------------+-------+----------+-----+---------+-----------------+
only showing top 2 rows



In [11]:
c_fraud_full.count()

                                                                                

80560

In [12]:
X = c_fraud_full.join(merchants, on="merchant_abn").join(customers, on="user_id")
X.show(1)

+-------+------------+--------------+--------+------------+-----------+-----------------+-------+----------+-----+---------+-----------------+-----------+--------------------+---------------+---------------+----------------+----------------+-----+--------+------+--------------------------------------------------+------------------------------+-----------------------------+-------------------------------+---------------------+-----------------------+----------------------+------------------------+--------------+----------------+---------------+----------------------------+---------------------------+------------------------+-----------------------+-------------------------------------+---------------------------+-----------------------------+----------------------------+---------------------------------------+-----------------------------+-------------------------------+------------------------------+-----------------------+-------------+---------------+--------------+------------------

In [13]:
X = X.drop("user_id", "merchant_abn", "order_datetime", "order_id", "name")
X.printSchema()

root
 |-- dollar_value: float (nullable = true)
 |-- Natural_var: integer (nullable = true)
 |-- Potential_Outlier: integer (nullable = true)
 |-- holiday: long (nullable = true)
 |-- dayofmonth: integer (nullable = true)
 |-- month: integer (nullable = true)
 |-- dayofweek: integer (nullable = true)
 |-- fraud_probability: float (nullable = true)
 |-- tags: string (nullable = true)
 |-- avg_monthly_inc: float (nullable = true)
 |-- monthly_entropy: float (nullable = true)
 |-- postcode_entropy: float (nullable = true)
 |-- revenue: double (nullable = true)
 |-- state: string (nullable = true)
 |-- postcode: long (nullable = true)
 |-- gender: string (nullable = true)
 |-- Number of individuals lodging an income tax return: long (nullable = true)
 |-- Average taxable income or loss: long (nullable = true)
 |-- Median taxable income or loss: long (nullable = true)
 |-- Proportion with salary or wages: long (nullable = true)
 |-- Count salary or wages: long (nullable = true)
 |-- Average

Categorical
- holiday (done)
- dayofmonth ?
- dayofweek
- month (done)
- tags
- state
- gender
- postcode

In [14]:
from pyspark.ml.feature import StringIndexer, OneHotEncoder

from pyspark.ml import Pipeline
from pyspark.sql.types import FloatType
from pyspark.mllib.evaluation import MulticlassMetrics
from pyspark.ml.regression import GeneralizedLinearRegression, GBTRegressor
from pyspark.ml.classification import MultilayerPerceptronClassifier
from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler, StandardScaler
from pyspark.ml.evaluation import MulticlassClassificationEvaluator, RegressionEvaluator

def category_processing(data: DataFrame, outcome: str):
    categories = [
        "dayofmonth",
        "dayofweek",
        "month",
        "tags",
        "state",
        "gender",
        "postcode"
    ]

    # Pipeline
    indexers = [StringIndexer(inputCol=c, outputCol=c+"_index") for c in categories]
    encoders = [OneHotEncoder(inputCol=c+"_index", outputCol=c+"_encoded") for c in categories]
    transformed = Pipeline(stages=indexers + encoders).fit(data).transform(data)

    for c in categories:
        transformed = transformed.drop(c).drop(c+"_index")
    return transformed

In [16]:
category_processed = category_processing(X, "outcome")
category_processed.show(1)

                                                                                

+------------+-----------+-----------------+-------+-----------------+---------------+---------------+----------------+----------------+--------------------------------------------------+------------------------------+-----------------------------+-------------------------------+---------------------+-----------------------+----------------------+------------------------+--------------+----------------+---------------+----------------------------+---------------------------+------------------------+-----------------------+-------------------------------------+---------------------------+-----------------------------+----------------------------+---------------------------------------+-----------------------------+-------------------------------+------------------------------+-----------------------+-------------+---------------+--------------+----------------------------------+------------------------------------+-----------------------------------+------------------+-----------------+

## TRAIN TEST SPLIT

In [17]:
y = X.select("fraud_probability")
y.count()

                                                                                

71813

In [19]:
y.describe().show()



+-------+------------------+
|summary| fraud_probability|
+-------+------------------+
|  count|             71813|
|   mean|14.717076242080452|
| stddev| 9.404555316490011|
|    min|          8.287144|
|    max|          97.62981|
+-------+------------------+



                                                                                

In [45]:
from pyspark.ml.feature import Bucketizer

buckets = Bucketizer(splits=[0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100], inputCol="fraud_probability", outputCol="fraud_buckets")
X_bucks = buckets.transform(category_processed)

X_bucks.show(1)

+------------+-----------+-----------------+-------+-----------------+---------------+---------------+----------------+----------------+--------------------------------------------------+------------------------------+-----------------------------+-------------------------------+---------------------+-----------------------+----------------------+------------------------+--------------+----------------+---------------+----------------------------+---------------------------+------------------------+-----------------------+-------------------------------------+---------------------------+-----------------------------+----------------------------+---------------------------------------+-----------------------------+-------------------------------+------------------------------+-----------------------+-------------+---------------+--------------+----------------------------------+------------------------------------+-----------------------------------+------------------+-----------------+

In [46]:
X_bucks.groupBy("fraud_buckets").count().orderBy("fraud_buckets").show()

[Stage 350:>                                                        (0 + 8) / 9]

+-------------+-----+
|fraud_buckets|count|
+-------------+-----+
|          0.0|22923|
|          1.0|38611|
|          2.0| 6113|
|          3.0| 1934|
|          4.0|  990|
|          5.0|  576|
|          6.0|  360|
|          7.0|  193|
|          8.0|  102|
|          9.0|   11|
+-------------+-----+



                                                                                

In [47]:
from functools import reduce

X_bucks_9 = X_bucks.filter(X_bucks["fraud_buckets"] == 9.0)
X_bucks_8 = X_bucks.filter(X_bucks["fraud_buckets"] == 8.0)
X_bucks_7 = X_bucks.filter(X_bucks["fraud_buckets"] == 7.0)
X_bucks_6 = X_bucks.filter(X_bucks["fraud_buckets"] == 6.0)
X_bucks_5 = X_bucks.filter(X_bucks["fraud_buckets"] == 5.0)
X_bucks_4 = X_bucks.filter(X_bucks["fraud_buckets"] == 4.0)
X_bucks_3 = X_bucks.filter(X_bucks["fraud_buckets"] == 3.0)

X_bucks_9_over = X_bucks_9.sample(withReplacement=True, fraction=350.0, seed=69)
X_bucks_8_over = X_bucks_8.sample(withReplacement=True, fraction=35.0, seed=69)
X_bucks_7_over = X_bucks_7.sample(withReplacement=True, fraction=20.0, seed=69)
X_bucks_6_over = X_bucks_6.sample(withReplacement=True, fraction=15.0, seed=69)
X_bucks_5_over = X_bucks_5.sample(withReplacement=True, fraction=7.0, seed=69)
X_bucks_4_over = X_bucks_4.sample(withReplacement=True, fraction=4.0, seed=69)
X_bucks_3_over = X_bucks_3.sample(withReplacement=True, fraction=2.0, seed=69)

X_adjusted = reduce(DataFrame.unionAll, [X_bucks_9_over, X_bucks_8_over, X_bucks_7_over, X_bucks_6_over, X_bucks_5_over, X_bucks_4_over, X_bucks_3_over])
X_adjusted = reduce(DataFrame.unionAll, [X_adjusted, X_bucks.filter(X_bucks["fraud_buckets"] == 2.0), X_bucks.filter(X_bucks["fraud_buckets"] == 1.0), X_bucks.filter(X_bucks["fraud_buckets"] == 0.0)])
X_adjusted.count()

                                                                                

96272

In [48]:
X_adjusted.groupBy("fraud_buckets").count().orderBy("fraud_buckets").show()



+-------------+-----+
|fraud_buckets|count|
+-------------+-----+
|          0.0|22923|
|          1.0|38611|
|          2.0| 6113|
|          3.0| 3910|
|          4.0| 3931|
|          5.0| 4025|
|          6.0| 5450|
|          7.0| 3876|
|          8.0| 3583|
|          9.0| 3850|
+-------------+-----+



                                                                                

In [49]:
train, val, test = X_adjusted.randomSplit([0.7, 0.2, 0.1], seed=69)

print(train.count())
print(val.count())
test.count()

22/10/06 00:23:31 WARN DAGScheduler: Broadcasting large task binary with size 2027.9 KiB


                                                                                

67492
22/10/06 00:24:31 WARN DAGScheduler: Broadcasting large task binary with size 2028.0 KiB


                                                                                

19067
22/10/06 00:25:16 WARN DAGScheduler: Broadcasting large task binary with size 2028.0 KiB


                                                                                

9713

In [52]:
print(test.columns)

['dollar_value', 'Natural_var', 'Potential_Outlier', 'holiday', 'fraud_probability', 'avg_monthly_inc', 'monthly_entropy', 'postcode_entropy', 'revenue', 'Number of individuals lodging an income tax return', 'Average taxable income or loss', 'Median taxable income or loss', 'Proportion with salary or wages', 'Count salary or wages', 'Average salary or wages', 'Median salary or wages', 'Proportion with net rent', 'Count net rent', 'Average net rent', 'Median net rent', 'Average total income or loss', 'Median total income or loss', 'Average total deductions', 'Median total deductions', 'Proportion with total business income', 'Count total business income', 'Average total business income', 'Median total business income', 'Proportion with total business expenses', 'Count total business expenses', 'Average total business expenses', 'Median total business expenses', 'Proportion with net tax', 'Count net tax', 'Average net tax', 'Median net tax', 'Count super total accounts balance', 'Average

In [53]:
def process_numerical(data: DataFrame):
    """
    Function to scale and process numerical columns
    """
    # Scaler
    columns = ['dollar_value', 'avg_monthly_inc',
    'monthly_entropy', 'postcode_entropy', 'revenue', 'Number of individuals lodging an income tax return', 
    'Average taxable income or loss', 'Median taxable income or loss', 'Proportion with salary or wages', 'Count salary or wages', 
    'Average salary or wages', 'Median salary or wages', 'Proportion with net rent', 'Count net rent', 'Average net rent', 
    'Median net rent', 'Average total income or loss', 'Median total income or loss', 'Average total deductions', 
    'Median total deductions', 'Proportion with total business income', 'Count total business income', 
    'Average total business income', 'Median total business income', 'Proportion with total business expenses', 
    'Count total business expenses', 'Average total business expenses', 'Median total business expenses', 
    'Proportion with net tax', 'Count net tax', 'Average net tax', 'Median net tax', 'Count super total accounts balance', 
    'Average super total accounts balance', 'Median super total accounts balance']

    va = VectorAssembler(inputCols=columns, outputCol="to_scale")
    sc = StandardScaler(inputCol="to_scale", outputCol="scaled")

    va_data = va.transform(data)
    data = sc.fit(va_data).transform(va_data)
    
    # Drop other columns
    for c in columns:
        data = data.drop(c)
    return data.drop("to_scale")

In [54]:
train_processed = process_numerical(train)
val_processed = process_numerical(val)
test_processed = process_numerical(test)

train_processed.show(1)
val_processed.show(1)
test_processed.show(1)

22/10/06 00:41:35 WARN DAGScheduler: Broadcasting large task binary with size 2.0 MiB




22/10/06 00:41:48 WARN DAGScheduler: Broadcasting large task binary with size 1473.7 KiB


                                                                                

22/10/06 00:41:53 WARN DAGScheduler: Broadcasting large task binary with size 2.0 MiB




22/10/06 00:42:02 WARN DAGScheduler: Broadcasting large task binary with size 1473.7 KiB


                                                                                

22/10/06 00:42:06 WARN DAGScheduler: Broadcasting large task binary with size 2.0 MiB




22/10/06 00:42:15 WARN DAGScheduler: Broadcasting large task binary with size 1473.7 KiB


                                                                                

22/10/06 00:42:19 WARN DAGScheduler: Broadcasting large task binary with size 2.1 MiB
+-----------+-----------------+-------+-----------------+------------------+-----------------+--------------+------------------+-------------+--------------+-------------------+-------------+--------------------+
|Natural_var|Potential_Outlier|holiday|fraud_probability|dayofmonth_encoded|dayofweek_encoded| month_encoded|      tags_encoded|state_encoded|gender_encoded|   postcode_encoded|fraud_buckets|              scaled|
+-----------+-----------------+-------+-----------------+------------------+-----------------+--------------+------------------+-------------+--------------+-------------------+-------------+--------------------+
|          0|                0|      0|         92.99139|   (30,[29],[1.0])|        (6,[],[])|(11,[2],[1.0])|(3207,[183],[1.0])|(7,[1],[1.0])| (2,[0],[1.0])|(3160,[1772],[1.0])|          9.0|[0.00119578017432...|
+-----------+-----------------+-------+-----------------+-----

In [57]:
def vectorize(data: DataFrame, outcome: str):
    """
    Function to vectorize all the processed data
    """
    data = data.withColumnRenamed(outcome, "label")
    return VectorAssembler(
        inputCols= [c for c in data.drop("label").columns],
        outputCol="features"
    ).transform(data)

In [58]:
train_vector = vectorize(train_processed, "fraud_buckets")
val_vector = vectorize(val_processed, "fraud_buckets")
test_vector = vectorize(test_processed, "fraud_buckets")

train_vector.head(1)
val_vector.head(1)
test_vector.head(1)

22/10/06 00:56:04 WARN DAGScheduler: Broadcasting large task binary with size 2.7 MiB
22/10/06 00:56:06 WARN DAGScheduler: Broadcasting large task binary with size 2.7 MiB
22/10/06 00:56:09 WARN DAGScheduler: Broadcasting large task binary with size 2.7 MiB


[Row(Natural_var=0, Potential_Outlier=0, holiday=0, fraud_probability=92.99138641357422, dayofmonth_encoded=SparseVector(30, {29: 1.0}), dayofweek_encoded=SparseVector(6, {}), month_encoded=SparseVector(11, {2: 1.0}), tags_encoded=SparseVector(3207, {183: 1.0}), state_encoded=SparseVector(7, {1: 1.0}), gender_encoded=SparseVector(2, {0: 1.0}), postcode_encoded=SparseVector(3160, {1772: 1.0}), label=9.0, scaled=DenseVector([0.0012, -0.3014, 9.7089, 5.4351, 0.2911, 0.0152, 3.2695, 5.1927, 0.0, 0.0161, 4.9966, 5.6526, 40.2451, 0.0126, -1.8313, -3.0797, 3.1927, 5.5668, 1.1192, 2.7861, 44.084, 0.0097, 2.472, 1.2469, 37.2616, 0.0088, 2.4669, 0.9598, 0.0, 0.0144, 1.8605, 2.9789, 0.0199, 0.8657, 0.5944]), features=SparseVector(6462, {3: 92.9914, 33: 1.0, 42: 1.0, 234: 1.0, 3259: 1.0, 3265: 1.0, 5039: 1.0, 6427: 0.0012, 6428: -0.3014, 6429: 9.7089, 6430: 5.4351, 6431: 0.2911, 6432: 0.0152, 6433: 3.2695, 6434: 5.1927, 6436: 0.0161, 6437: 4.9966, 6438: 5.6526, 6439: 40.2451, 6440: 0.0126, 6441: -

### MODEL

In [59]:
inputCount = 6462                # Seen from sparse vector column
layers = [512, 256, 64, 10]
model = MultilayerPerceptronClassifier(
    labelCol='label',
    featuresCol='features',
    maxIter=50,
    layers=layers,
    blockSize=128,
    seed=669)

In [None]:
model_fit = model.fit(train_vector.select("features", "label"))

In [61]:
train_output = model_fit.transform(train_vector)
val_output = model_fit.transform(val_vector)
test_output = model_fit.transform(test_vector)

In [69]:
train_vector.write.parquet("../models/train_vector", mode="overwrite")
val_vector.write.parquet("../models/val_vector", mode="overwrite")
test_vector.write.parquet("../models/test_vector", mode="overwrite")

ConnectionRefusedError: [Errno 111] Connection refused

In [64]:
model.save("../models/classifier")

                                                                                

In [63]:
metrics = ['weightedPrecision', 'weightedRecall', 'accuracy']
for metric in metrics:
    evaluator = MulticlassClassificationEvaluator(metricName=metric)
    print('Train ' + metric + ' = ' + str(evaluator.evaluate(
        val_output.select("prediction", "label"))))

                                                                                

22/10/06 01:11:58 WARN DAGScheduler: Broadcasting large task binary with size 3.8 MiB


[Stage 575:>                                                       (0 + 8) / 90]

22/10/06 01:11:59 ERROR Executor: Exception in task 7.0 in stage 575.0 (TID 8500)
org.apache.spark.SparkException: Failed to execute user defined function (ProbabilisticClassificationModel$$Lambda$5189/0x0000000801fdc458: (struct<type:tinyint,size:int,indices:array<int>,values:array<double>>) => struct<type:tinyint,size:int,indices:array<int>,values:array<double>>)
	at org.apache.spark.sql.errors.QueryExecutionErrors$.failedExecuteUserDefinedFunctionError(QueryExecutionErrors.scala:177)
	at org.apache.spark.sql.errors.QueryExecutionErrors.failedExecuteUserDefinedFunctionError(QueryExecutionErrors.scala)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage41.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:760)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)


Py4JJavaError: An error occurred while calling o2754.evaluate.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 7 in stage 575.0 failed 1 times, most recent failure: Lost task 7.0 in stage 575.0 (TID 8500) (172.18.71.108 executor driver): org.apache.spark.SparkException: Failed to execute user defined function (ProbabilisticClassificationModel$$Lambda$5189/0x0000000801fdc458: (struct<type:tinyint,size:int,indices:array<int>,values:array<double>>) => struct<type:tinyint,size:int,indices:array<int>,values:array<double>>)
	at org.apache.spark.sql.errors.QueryExecutionErrors$.failedExecuteUserDefinedFunctionError(QueryExecutionErrors.scala:177)
	at org.apache.spark.sql.errors.QueryExecutionErrors.failedExecuteUserDefinedFunctionError(QueryExecutionErrors.scala)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage41.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:760)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:197)
	at org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)
	at org.apache.spark.shuffle.ShuffleWriteProcessor.write(ShuffleWriteProcessor.scala:59)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:99)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:52)
	at org.apache.spark.scheduler.Task.run(Task.scala:136)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:548)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1504)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:551)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1136)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:635)
	at java.base/java.lang.Thread.run(Thread.java:833)
Caused by: java.lang.IllegalArgumentException: requirement failed: A & B Dimension mismatch!
	at scala.Predef$.require(Predef.scala:281)
	at org.apache.spark.ml.ann.BreezeUtil$.dgemm(BreezeUtil.scala:42)
	at org.apache.spark.ml.ann.AffineLayerModel.eval(Layer.scala:164)
	at org.apache.spark.ml.ann.FeedForwardModel.forward(Layer.scala:508)
	at org.apache.spark.ml.ann.FeedForwardModel.predictRaw(Layer.scala:561)
	at org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel.predictRaw(MultilayerPerceptronClassifier.scala:332)
	at org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel.predictRaw(MultilayerPerceptronClassifier.scala:274)
	at org.apache.spark.ml.classification.ProbabilisticClassificationModel.$anonfun$transform$2(ProbabilisticClassifier.scala:121)
	... 19 more

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2672)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2608)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2607)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2607)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1182)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1182)
	at scala.Option.foreach(Option.scala:407)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1182)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2860)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2802)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2791)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:952)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2228)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2249)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2268)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2293)
	at org.apache.spark.rdd.RDD.$anonfun$collect$1(RDD.scala:1021)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:406)
	at org.apache.spark.rdd.RDD.collect(RDD.scala:1020)
	at org.apache.spark.rdd.PairRDDFunctions.$anonfun$collectAsMap$1(PairRDDFunctions.scala:738)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:406)
	at org.apache.spark.rdd.PairRDDFunctions.collectAsMap(PairRDDFunctions.scala:737)
	at org.apache.spark.mllib.evaluation.MulticlassMetrics.confusions$lzycompute(MulticlassMetrics.scala:61)
	at org.apache.spark.mllib.evaluation.MulticlassMetrics.confusions(MulticlassMetrics.scala:52)
	at org.apache.spark.mllib.evaluation.MulticlassMetrics.labelCountByClass$lzycompute(MulticlassMetrics.scala:66)
	at org.apache.spark.mllib.evaluation.MulticlassMetrics.labelCountByClass(MulticlassMetrics.scala:64)
	at org.apache.spark.mllib.evaluation.MulticlassMetrics.weightedPrecision$lzycompute(MulticlassMetrics.scala:218)
	at org.apache.spark.mllib.evaluation.MulticlassMetrics.weightedPrecision(MulticlassMetrics.scala:218)
	at org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator.evaluate(MulticlassClassificationEvaluator.scala:155)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:77)
	at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.base/java.lang.reflect.Method.invoke(Method.java:568)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
	at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
	at java.base/java.lang.Thread.run(Thread.java:833)
Caused by: org.apache.spark.SparkException: Failed to execute user defined function (ProbabilisticClassificationModel$$Lambda$5189/0x0000000801fdc458: (struct<type:tinyint,size:int,indices:array<int>,values:array<double>>) => struct<type:tinyint,size:int,indices:array<int>,values:array<double>>)
	at org.apache.spark.sql.errors.QueryExecutionErrors$.failedExecuteUserDefinedFunctionError(QueryExecutionErrors.scala:177)
	at org.apache.spark.sql.errors.QueryExecutionErrors.failedExecuteUserDefinedFunctionError(QueryExecutionErrors.scala)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage41.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:760)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at org.apache.spark.util.collection.ExternalSorter.insertAll(ExternalSorter.scala:197)
	at org.apache.spark.shuffle.sort.SortShuffleWriter.write(SortShuffleWriter.scala:63)
	at org.apache.spark.shuffle.ShuffleWriteProcessor.write(ShuffleWriteProcessor.scala:59)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:99)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:52)
	at org.apache.spark.scheduler.Task.run(Task.scala:136)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:548)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1504)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:551)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1136)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:635)
	... 1 more
Caused by: java.lang.IllegalArgumentException: requirement failed: A & B Dimension mismatch!
	at scala.Predef$.require(Predef.scala:281)
	at org.apache.spark.ml.ann.BreezeUtil$.dgemm(BreezeUtil.scala:42)
	at org.apache.spark.ml.ann.AffineLayerModel.eval(Layer.scala:164)
	at org.apache.spark.ml.ann.FeedForwardModel.forward(Layer.scala:508)
	at org.apache.spark.ml.ann.FeedForwardModel.predictRaw(Layer.scala:561)
	at org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel.predictRaw(MultilayerPerceptronClassifier.scala:332)
	at org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel.predictRaw(MultilayerPerceptronClassifier.scala:274)
	at org.apache.spark.ml.classification.ProbabilisticClassificationModel.$anonfun$transform$2(ProbabilisticClassifier.scala:121)
	... 19 more
