## import library

In [1]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
import pandas as pd

from pyspark.ml.feature import StringIndexer
from pyspark.ml import Pipeline 
from pyspark.ml.feature import VectorAssembler 
from pyspark.ml.feature import OneHotEncoder
from pyspark.sql.functions import rand
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [2]:
# Create a spark session (which will run spark jobs)
spark = (
    SparkSession.builder.appName("MAST30034 Project 2")
    .config("spark.sql.repl.eagerEval.enabled", True) 
    .config("spark.sql.parquet.cacheMetadata", "true")
    .config("spark.sql.session.timeZone", "Etc/UTC")
    .config("spark.driver.memory", "2g")
    .config("spark.executer.memory", "4g")
    .getOrCreate()
)

22/09/22 21:43:48 WARN Utils: Your hostname, DESKTOP-80AOBLL resolves to a loopback address: 127.0.1.1; using 172.18.176.175 instead (on interface eth0)
22/09/22 21:43:48 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address


Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


22/09/22 21:43:50 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
22/09/22 21:43:51 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


## preprocessing

### load data

In [3]:
full = spark.read.parquet("../../data/curated/full_data/")

                                                                                

In [4]:
probs_merchant = spark.read.option('header', True).csv('../../data/tables/merchant_fraud_probability.csv')
probs_consumer= spark.read.option('header', True).csv('../../data/tables/consumer_fraud_probability.csv')

### Merge tables

In [5]:
full.printSchema()

root
 |-- user_id: long (nullable = true)
 |-- SA2_code: integer (nullable = true)
 |-- postcode: integer (nullable = true)
 |-- consumer_id: integer (nullable = true)
 |-- state: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- mean_total_income: integer (nullable = true)
 |-- merchant_abn: long (nullable = true)
 |-- dollar_value: double (nullable = true)
 |-- order_id: string (nullable = true)
 |-- order_datetime: date (nullable = true)
 |-- name: string (nullable = true)
 |-- tags: string (nullable = true)
 |-- revenue_level: string (nullable = true)
 |-- take_rate: double (nullable = true)



In [6]:
probs_consumer.printSchema()

root
 |-- user_id: string (nullable = true)
 |-- order_datetime: string (nullable = true)
 |-- fraud_probability: string (nullable = true)



In [7]:
probs_consumer =  probs_consumer.withColumn('user_id', F.col('user_id').cast('long'))\
                                .withColumn('fraud_probability', F.col('fraud_probability').cast('float'))

In [8]:
probs_merchant.printSchema()

root
 |-- merchant_abn: string (nullable = true)
 |-- order_datetime: string (nullable = true)
 |-- fraud_probability: string (nullable = true)



In [9]:
probs_merchant =  probs_merchant.withColumn('merchant_abn', F.col('merchant_abn').cast('long'))\
                                .withColumn('fraud_probability', F.col('fraud_probability').cast('float'))

In [10]:
# merge transaction file with merchants'/consumers' fraud probability based on merchant abn or user id respectively by left join
full = full.join(probs_merchant, on = ['merchant_abn', 'order_datetime'], how = 'left').withColumnRenamed('fraud_probability', 'merchant_prob')
full = full.join(probs_consumer, on = ['user_id', 'order_datetime'], how = 'left').withColumnRenamed('fraud_probability', 'consumer_prob')

In [11]:
# replace all the missing value with 0.01 as default fraud prob
full = full.na.fill(value=0.01, subset=['merchant_prob', 'consumer_prob'])

In [12]:
# We set benchmark as 5%, because we focus on False Positive instead of False Negative
full = full.withColumn('is_fraud', F.when((F.col('merchant_prob') > 5) | (F.col('consumer_prob') > 5), 1).otherwise(0))
full

                                                                                

user_id,order_datetime,merchant_abn,SA2_code,postcode,consumer_id,state,gender,mean_total_income,dollar_value,order_id,name,tags,revenue_level,take_rate,merchant_prob,consumer_prob,is_fraud
44,2021-04-09,10648956813,401041015,5074,564558,SA,Undisclosed,53613,68.17405810943993,4d42fd2c-0823-4af...,Proin Nisl Institute,computer,a,6.66,0.01,0.01,0
44,2021-03-26,16256895427,401041015,5074,564558,SA,Undisclosed,53613,109.29115227377834,20849f6c-9286-4e5...,Tempus Non Founda...,garden supply,a,6.6,0.01,0.01,0
44,2022-03-27,15115332331,401041015,5074,564558,SA,Undisclosed,53613,79.98167044943824,19ae7e83-e673-459...,In Consectetuer Ltd,florists,a,5.64,0.01,0.01,0
44,2022-05-06,20885454195,401041015,5074,564558,SA,Undisclosed,53613,69.20949836618651,a11e7fb4-42d7-4f9...,Pharetra Ut Indus...,cable,b,4.94,0.01,0.01,0
44,2022-05-23,21807339153,401041015,5074,564558,SA,Undisclosed,53613,31.20457577801381,10be6623-5f1b-456...,Praesent Eu LLP,digital goods,c,2.73,0.01,0.01,0
44,2021-09-03,21702179125,401041015,5074,564558,SA,Undisclosed,53613,128.94839096472307,e2eb1a75-836d-43b...,At Auctor Ullamco...,gift,c,2.65,0.01,0.01,0
44,2022-04-29,22961647681,401041015,5074,564558,SA,Undisclosed,53613,22.482302441089782,a262a794-8aab-463...,Vestibulum Ante C...,opticians,c,1.96,0.01,0.01,0
44,2022-06-11,25235376304,401041015,5074,564558,SA,Undisclosed,53613,35.67939574082981,568de521-5f63-445...,Pellentesque Tell...,garden supply,c,2.86,0.01,0.01,0
44,2022-08-28,35014882568,401041015,5074,564558,SA,Undisclosed,53613,43.67388081483236,f3126198-dd72-422...,Faucibus Corporation,digital goods,a,6.91,0.01,0.01,0
44,2022-08-25,36125151647,401041015,5074,564558,SA,Undisclosed,53613,20.96605822155854,ef56a7e3-81c7-4a8...,Sed Nec Corp.,hobby,c,1.83,0.01,0.01,0


In [13]:
print('In {} transactions, {} are detected as fraud'.format(full.count(), full.filter(F.col('is_fraud') == 1).count()))



In 13614854 transactions, 75455 are detected as fraud


                                                                                

In [14]:
full = full.filter(F.col('dollar_value') >= 1).na.drop(subset = 'name')

In [15]:
full.printSchema()

root
 |-- user_id: long (nullable = true)
 |-- order_datetime: date (nullable = true)
 |-- merchant_abn: long (nullable = true)
 |-- SA2_code: integer (nullable = true)
 |-- postcode: integer (nullable = true)
 |-- consumer_id: integer (nullable = true)
 |-- state: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- mean_total_income: integer (nullable = true)
 |-- dollar_value: double (nullable = true)
 |-- order_id: string (nullable = true)
 |-- name: string (nullable = true)
 |-- tags: string (nullable = true)
 |-- revenue_level: string (nullable = true)
 |-- take_rate: double (nullable = true)
 |-- merchant_prob: float (nullable = false)
 |-- consumer_prob: float (nullable = false)
 |-- is_fraud: integer (nullable = false)



In [16]:
full = full.withColumn('take_rate', F.col('take_rate').cast('float'))

## Feature Engineering

### bin numeric features

In [17]:
#value_max = result.select('dollar_value').orderBy(F.col('dollar_value'),  ascending= False).collect()[0][0]

In [18]:
#value_min = result.select('dollar_value').orderBy(F.col('dollar_value')).collect()[0][0]

In [19]:
#print('dollar_value range:[{}, {}]'.format(value_min, value_max))

In [20]:
# bins = []
# i  = 0
# while i <= 70000:
#   bins.append(i)
#   i += 200

In [21]:
#bins.append(float('Inf'))

In [22]:
#bucketizer = Bucketizer(splits=bins, inputCol="dollar_value", outputCol="dollar_value_buckets")
#result = bucketizer.setHandleInvalid("keep").transform(result)

In [23]:
#result = result.withColumn('take_rate', F.col('take_rate').cast('double'))

In [24]:
# take_rate_max = result.select('take_rate').orderBy(F.col('take_rate'),  ascending= False).collect()[0][0]
# take_rate_min = result.select('take_rate').orderBy(F.col('take_rate')).collect()[0][0]

In [25]:
# print('take_rate range:[{}, {}]'.format(take_rate_min, take_rate_max))

In [26]:
# bucketizer = Bucketizer(splits=[0,1,2,3,4,5,6,7,8], inputCol="take_rate", outputCol="take_rate_buckets")
# result = bucketizer.setHandleInvalid("keep").transform(result)

In [27]:
# result.printSchema()

### Index ordinal features

In [17]:
full =  full.withColumn('month', F.month('order_datetime'))

In [18]:
indexed_features = ['revenue_level', 'tags', 'gender']

In [19]:
# We give all values in non-numeric features an index in order to make it ordinal or one-hot encoded
indexers =[]
for col in indexed_features:
  indexers.append(StringIndexer(inputCol=col, outputCol = col+"_index"))

indexers

[StringIndexer_7755e8d09070,
 StringIndexer_1b4cea862f0f,
 StringIndexer_0bd27832b147]

In [20]:
pipeline = Pipeline(stages=indexers)

In [21]:
indexed_result = pipeline.fit(full).transform(full)

                                                                                

### One hot encoding

In [22]:
categorical_features =  ["tags_index", "gender_index","month"]

In [23]:
# one-hot-encoding the numeric indices
ohe = []
for f in categorical_features:
  ohe.append(OneHotEncoder(inputCol=f, outputCol=f+"OHE"))

ohe

[OneHotEncoder_ff2b4c5bb4fb,
 OneHotEncoder_6ad7d6da7862,
 OneHotEncoder_e0e85936a6fc]

In [24]:
pipeline = Pipeline(stages=ohe)

In [25]:
indexed_result = pipeline.fit(indexed_result).transform(indexed_result)

                                                                                

In [26]:
indexed_result.printSchema()

root
 |-- user_id: long (nullable = true)
 |-- order_datetime: date (nullable = true)
 |-- merchant_abn: long (nullable = true)
 |-- SA2_code: integer (nullable = true)
 |-- postcode: integer (nullable = true)
 |-- consumer_id: integer (nullable = true)
 |-- state: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- mean_total_income: integer (nullable = true)
 |-- dollar_value: double (nullable = true)
 |-- order_id: string (nullable = true)
 |-- name: string (nullable = true)
 |-- tags: string (nullable = true)
 |-- revenue_level: string (nullable = true)
 |-- take_rate: float (nullable = true)
 |-- merchant_prob: float (nullable = false)
 |-- consumer_prob: float (nullable = false)
 |-- is_fraud: integer (nullable = false)
 |-- month: integer (nullable = true)
 |-- revenue_level_index: double (nullable = false)
 |-- tags_index: double (nullable = false)
 |-- gender_index: double (nullable = false)
 |-- tags_indexOHE: vector (nullable = true)
 |-- gender_indexOHE: vec

### Feature Selection and Vectorization

In [27]:
feature_selected = ['dollar_value','take_rate','mean_total_income','monthOHE','revenue_level_index','tags_indexOHE','gender_indexOHE']

In [28]:
assembler = VectorAssembler(inputCols=feature_selected ,outputCol='features')

In [29]:
output = assembler.transform(indexed_result)

In [30]:
final_data = output.filter(F.col('order_datetime') < '2022-02-28').select('features','is_fraud')

In [31]:
(final_data.filter(F.col('is_fraud') == 1).count())/(final_data.filter(F.col('is_fraud') == 0).count())

                                                                                

0.009780748354606692

In [32]:
final_data.filter(F.col('is_fraud') == 1).count(), final_data.filter(F.col('is_fraud') == 0).count()

                                                                                

(75108, 7679167)

In [33]:
# imbalanced distribution of two classes. We decide to split the data according to their class and make the distribution balanced
fraud_data = final_data.filter(F.col('is_fraud') == 1)
normal_data = final_data.filter(F.col('is_fraud') == 0).randomSplit([0.01,0.99])[0]

### set random seed?

## Split data

In [34]:
train_fraud,test_fraud = fraud_data.randomSplit([0.7,0.3])
train_normal,test_normal = normal_data.randomSplit([0.7,0.3])

In [35]:
train_data = train_fraud.union(train_normal).orderBy(rand())
test_data = test_fraud.union(test_normal).orderBy(rand())

In [36]:
final_data.count()

                                                                                

7754275

In [37]:
train_data.count(),test_data.count()

                                                                                

(106409, 45855)

## Logistic Regression

In [38]:
lr = LogisticRegression(labelCol='is_fraud')

In [39]:
fitted_model = lr.fit(train_data)
fitted_model.setFeaturesCol("features")
fitted_model.setPredictionCol("prediction")



LogisticRegressionModel: uid=LogisticRegression_b17aca1cc915, numClasses=2, numFeatures=41

### Evaluation

In [40]:
pred_and_labels = fitted_model.evaluate(test_data)



In [41]:
score_and_label = pred_and_labels.predictions.select('prediction', 'is_fraud').withColumnRenamed('is_fraud', 'label')

In [42]:
evaluator = MulticlassClassificationEvaluator()

In [43]:
evaluator.setPredictionCol("prediction")

MulticlassClassificationEvaluator_89dedcff7e02

In [44]:
print("Accuracy: " + str(evaluator.evaluate(score_and_label)))



Accuracy: 0.73644834318918


                                                                                

In [45]:
tp = score_and_label.filter((F.col('prediction') == 1) & (F.col('label') == 1)).count()

                                                                                

In [46]:
fn = score_and_label.filter((F.col('prediction') == 0) & (F.col('label') == 1)).count()

                                                                                

In [47]:
recall = tp/(tp+fn)
print('recall: ' + str(recall))

recall: 0.6848548905623817


## 

## Prediction

In [48]:
train_data = output.filter(F.col('order_datetime') < '2022-02-28')
predict_data = output.filter(F.col('order_datetime') >= '2022-02-28')

In [49]:
fraud_data = train_data.filter(F.col('is_fraud') == 1)
normal_data = train_data.filter(F.col('is_fraud') == 0).randomSplit([0.01,0.99])[0]

In [50]:
train_data = fraud_data.union(normal_data).orderBy(rand())

In [51]:
lr_final = LogisticRegression(labelCol='is_fraud')

In [52]:
fitted_model = lr_final.fit(train_data)



In [53]:
fitted_model.setFeaturesCol('features')
fitted_model.setPredictionCol('is_fraud')

LogisticRegressionModel: uid=LogisticRegression_b6ebe913dd8a, numClasses=2, numFeatures=41

In [54]:
predict_data = predict_data.drop("is_fraud")
predicted = fitted_model.transform(predict_data)

In [55]:
cols = ['user_id','order_datetime','merchant_abn','SA2_code','postcode','consumer_id','state','gender','mean_total_income','dollar_value','order_id','name','tags','revenue_level','take_rate','is_fraud'	]
predicted =  predicted.select(cols)

In [56]:
full = full.drop('merchant_prob', 'consumer_prob', 'month')

In [57]:
predicted.count()

                                                                                

5749867

In [58]:
full = full.filter(F.col('order_datetime') < '2022-02-28')

In [59]:
full = full.union(predicted)

In [60]:
full.printSchema()

root
 |-- user_id: long (nullable = true)
 |-- order_datetime: date (nullable = true)
 |-- merchant_abn: long (nullable = true)
 |-- SA2_code: integer (nullable = true)
 |-- postcode: integer (nullable = true)
 |-- consumer_id: integer (nullable = true)
 |-- state: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- mean_total_income: integer (nullable = true)
 |-- dollar_value: double (nullable = true)
 |-- order_id: string (nullable = true)
 |-- name: string (nullable = true)
 |-- tags: string (nullable = true)
 |-- revenue_level: string (nullable = true)
 |-- take_rate: float (nullable = true)
 |-- is_fraud: double (nullable = false)



In [61]:
full = full.withColumn("is_fraud", F.col("is_fraud").cast("INT"))

In [62]:
full.cache()

                                                                                

user_id,order_datetime,merchant_abn,SA2_code,postcode,consumer_id,state,gender,mean_total_income,dollar_value,order_id,name,tags,revenue_level,take_rate,is_fraud
44,2021-04-09,10648956813,401041015,5074,564558,SA,Undisclosed,53613,68.17405810943993,4d42fd2c-0823-4af...,Proin Nisl Institute,computer,a,6.66,0
44,2021-03-26,16256895427,401041015,5074,564558,SA,Undisclosed,53613,109.29115227377834,20849f6c-9286-4e5...,Tempus Non Founda...,garden supply,a,6.6,0
44,2021-09-03,21702179125,401041015,5074,564558,SA,Undisclosed,53613,128.94839096472307,e2eb1a75-836d-43b...,At Auctor Ullamco...,gift,c,2.65,0
44,2021-09-28,36125151647,401041015,5074,564558,SA,Undisclosed,53613,50.97888666647738,be7da9a9-b31a-4d8...,Sed Nec Corp.,hobby,c,1.83,0
44,2021-10-20,41974958954,401041015,5074,564558,SA,Undisclosed,53613,60.02299315612368,cf13f2cc-5d81-413...,Sed Libero Proin ...,cable,a,5.51,0
44,2021-09-22,41315101616,401041015,5074,564558,SA,Undisclosed,53613,22.01099830027075,784eb5b8-7032-42d...,Malesuada Fames F...,antique,b,5.05,0
44,2021-12-01,49891706470,401041015,5074,564558,SA,Undisclosed,53613,6.1436338479183785,43a40900-af31-421...,Non Vestibulum In...,tent,a,5.8,0
44,2021-11-24,47459279421,401041015,5074,564558,SA,Undisclosed,53613,37.40432919826938,fcba3312-3c72-4e5...,Aliquam Gravida PC,shoe,a,6.86,0
44,2021-07-20,49891706470,401041015,5074,564558,SA,Undisclosed,53613,25.611929292222403,9d3f75e2-f164-42c...,Non Vestibulum In...,tent,a,5.8,0
44,2021-06-29,49891706470,401041015,5074,564558,SA,Undisclosed,53613,9.534770249908876,59ec83dc-bc56-4f0...,Non Vestibulum In...,tent,a,5.8,0


In [63]:
## underlying files have been updated. explicitly invalidate the cache
full.createOrReplaceTempView("full_data_with_lable")

22/09/22 22:12:51 WARN package: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.


In [64]:
query = "REFRESH TABLE full_data_with_lable" 

In [65]:
spark.sql(query)

In [66]:
full.write.format('parquet').mode('overwrite').save("../../data/curated/full_data_with_fraud")

                                                                                

In [67]:
full.unpersist()

                                                                                

user_id,order_datetime,merchant_abn,SA2_code,postcode,consumer_id,state,gender,mean_total_income,dollar_value,order_id,name,tags,revenue_level,take_rate,is_fraud
44,2021-04-09,10648956813,401041015,5074,564558,SA,Undisclosed,53613,68.17405810943993,4d42fd2c-0823-4af...,Proin Nisl Institute,computer,a,6.66,0
44,2021-03-26,16256895427,401041015,5074,564558,SA,Undisclosed,53613,109.29115227377834,20849f6c-9286-4e5...,Tempus Non Founda...,garden supply,a,6.6,0
44,2021-09-03,21702179125,401041015,5074,564558,SA,Undisclosed,53613,128.94839096472307,e2eb1a75-836d-43b...,At Auctor Ullamco...,gift,c,2.65,0
44,2021-09-28,36125151647,401041015,5074,564558,SA,Undisclosed,53613,50.97888666647738,be7da9a9-b31a-4d8...,Sed Nec Corp.,hobby,c,1.83,0
44,2021-10-20,41974958954,401041015,5074,564558,SA,Undisclosed,53613,60.02299315612368,cf13f2cc-5d81-413...,Sed Libero Proin ...,cable,a,5.51,0
44,2021-09-22,41315101616,401041015,5074,564558,SA,Undisclosed,53613,22.01099830027075,784eb5b8-7032-42d...,Malesuada Fames F...,antique,b,5.05,0
44,2021-12-01,49891706470,401041015,5074,564558,SA,Undisclosed,53613,6.1436338479183785,43a40900-af31-421...,Non Vestibulum In...,tent,a,5.8,0
44,2021-11-24,47459279421,401041015,5074,564558,SA,Undisclosed,53613,37.40432919826938,fcba3312-3c72-4e5...,Aliquam Gravida PC,shoe,a,6.86,0
44,2021-07-20,49891706470,401041015,5074,564558,SA,Undisclosed,53613,25.611929292222403,9d3f75e2-f164-42c...,Non Vestibulum In...,tent,a,5.8,0
44,2021-06-29,49891706470,401041015,5074,564558,SA,Undisclosed,53613,9.534770249908876,59ec83dc-bc56-4f0...,Non Vestibulum In...,tent,a,5.8,0
