## import library

In [1]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
import pandas as pd

from pyspark.ml.feature import StringIndexer
from pyspark.ml import Pipeline 
from pyspark.ml.feature import VectorAssembler 
from pyspark.ml.feature import OneHotEncoder
from pyspark.sql.functions import rand
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [2]:
# Create a spark session (which will run spark jobs)
spark = (
    SparkSession.builder.appName("MAST30034 Project 2")
    .config("spark.sql.repl.eagerEval.enabled", True) 
    .config("spark.sql.parquet.cacheMetadata", "true")
    .config("spark.sql.session.timeZone", "Etc/UTC")
    .config("spark.driver.memory", "2g")
    .config("spark.executer.memory", "4g")
    .getOrCreate()
)

## preprocessing

### load data

In [3]:
full = spark.read.parquet("../../data/curated/full_data/")

In [4]:
probs_merchant = spark.read.option('header', True).csv('../../data/tables/merchant_fraud_probability.csv')
probs_consumer= spark.read.option('header', True).csv('../../data/tables/consumer_fraud_probability.csv')

### Merge tables

In [5]:
full.printSchema()

root
 |-- user_id: long (nullable = true)
 |-- SA2_code: integer (nullable = true)
 |-- postcode: integer (nullable = true)
 |-- consumer_id: integer (nullable = true)
 |-- state: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- mean_total_income: integer (nullable = true)
 |-- merchant_abn: long (nullable = true)
 |-- dollar_value: double (nullable = true)
 |-- order_id: string (nullable = true)
 |-- order_datetime: date (nullable = true)
 |-- name: string (nullable = true)
 |-- tags: string (nullable = true)
 |-- revenue_level: string (nullable = true)
 |-- take_rate: double (nullable = true)



In [6]:
probs_consumer.printSchema()

root
 |-- user_id: string (nullable = true)
 |-- order_datetime: string (nullable = true)
 |-- fraud_probability: string (nullable = true)



In [7]:
probs_consumer =  probs_consumer.withColumn('user_id', F.col('user_id').cast('long'))\
                                .withColumn('fraud_probability', F.col('fraud_probability').cast('float'))

In [8]:
probs_merchant.printSchema()

root
 |-- merchant_abn: string (nullable = true)
 |-- order_datetime: string (nullable = true)
 |-- fraud_probability: string (nullable = true)



In [9]:
probs_merchant =  probs_merchant.withColumn('merchant_abn', F.col('merchant_abn').cast('long'))\
                                .withColumn('fraud_probability', F.col('fraud_probability').cast('float'))

In [10]:
# merge transaction file with merchants'/consumers' fraud probability based on merchant abn or user id respectively by left join
full = full.join(probs_merchant, on = ['merchant_abn', 'order_datetime'], how = 'left').withColumnRenamed('fraud_probability', 'merchant_prob')
full = full.join(probs_consumer, on = ['user_id', 'order_datetime'], how = 'left').withColumnRenamed('fraud_probability', 'consumer_prob')

In [11]:
# replace all the missing value with 0.01 as default fraud prob
full = full.na.fill(value=0.01, subset=['merchant_prob', 'consumer_prob'])

In [12]:
# We set benchmark as 5%, because we focus on False Positive instead of False Negative
full = full.withColumn('is_fraud', F.when((F.col('merchant_prob') > 5) | (F.col('consumer_prob') > 5), 1).otherwise(0))
full

user_id,order_datetime,merchant_abn,SA2_code,postcode,consumer_id,state,gender,mean_total_income,dollar_value,order_id,name,tags,revenue_level,take_rate,merchant_prob,consumer_prob,is_fraud
44,2021-04-09,10648956813,401041015,5074,564558,SA,Undisclosed,53613,68.17405810943993,4d42fd2c-0823-4af...,Proin Nisl Institute,computer,a,6.66,0.01,0.01,0
44,2022-06-21,16492082804,401041015,5074,564558,SA,Undisclosed,53613,26.901038254990677,ed201616-286b-461...,Et Malesuada Inst...,shoe,b,3.6,0.01,0.01,0
44,2022-03-27,15115332331,401041015,5074,564558,SA,Undisclosed,53613,79.98167044943824,19ae7e83-e673-459...,In Consectetuer Ltd,florists,a,5.64,0.01,0.01,0
44,2021-11-11,17739089622,401041015,5074,564558,SA,Undisclosed,53613,12.556313709152784,77bba8e7-8124-43e...,Auctor Quis Corp.,cable,b,5.01,0.01,0.01,0
44,2022-05-23,21807339153,401041015,5074,564558,SA,Undisclosed,53613,31.20457577801381,10be6623-5f1b-456...,Praesent Eu LLP,digital goods,c,2.73,0.01,0.01,0
44,2022-06-11,27504885147,401041015,5074,564558,SA,Undisclosed,53613,5.247280898834883,ccea177d-4368-41e...,Enim Ltd,gift,b,4.81,0.01,0.01,0
44,2022-04-29,22961647681,401041015,5074,564558,SA,Undisclosed,53613,22.482302441089782,a262a794-8aab-463...,Vestibulum Ante C...,opticians,c,1.96,0.01,0.01,0
44,2021-03-26,30122382323,401041015,5074,564558,SA,Undisclosed,53613,57.97616021380351,894752b5-f32c-4b5...,Ipsum Company,watch,b,3.36,0.01,0.01,0
44,2022-08-28,35014882568,401041015,5074,564558,SA,Undisclosed,53613,43.67388081483236,f3126198-dd72-422...,Faucibus Corporation,digital goods,a,6.91,0.01,0.01,0
44,2021-07-06,34096466752,401041015,5074,564558,SA,Undisclosed,53613,127.91174243430336,b33fb8e8-faf4-424...,Nullam Enim Ltd,computer,b,3.22,0.01,0.01,0


In [13]:
print('In {} transactions, {} are detected as fraud'.format(full.count(), full.filter(F.col('is_fraud') == 1).count()))

In 13614854 transactions, 75455 are detected as fraud


In [14]:
full = full.filter(F.col('dollar_value') >= 1).na.drop(subset = 'name')

In [15]:
full.printSchema()

root
 |-- user_id: long (nullable = true)
 |-- order_datetime: date (nullable = true)
 |-- merchant_abn: long (nullable = true)
 |-- SA2_code: integer (nullable = true)
 |-- postcode: integer (nullable = true)
 |-- consumer_id: integer (nullable = true)
 |-- state: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- mean_total_income: integer (nullable = true)
 |-- dollar_value: double (nullable = true)
 |-- order_id: string (nullable = true)
 |-- name: string (nullable = true)
 |-- tags: string (nullable = true)
 |-- revenue_level: string (nullable = true)
 |-- take_rate: double (nullable = true)
 |-- merchant_prob: float (nullable = false)
 |-- consumer_prob: float (nullable = false)
 |-- is_fraud: integer (nullable = false)



In [16]:
full = full.withColumn('take_rate', F.col('take_rate').cast('float'))

## Feature Engineering

### bin numeric features

In [17]:
#value_max = result.select('dollar_value').orderBy(F.col('dollar_value'),  ascending= False).collect()[0][0]

In [18]:
#value_min = result.select('dollar_value').orderBy(F.col('dollar_value')).collect()[0][0]

In [19]:
#print('dollar_value range:[{}, {}]'.format(value_min, value_max))

In [20]:
# bins = []
# i  = 0
# while i <= 70000:
#   bins.append(i)
#   i += 200

In [21]:
#bins.append(float('Inf'))

In [22]:
#bucketizer = Bucketizer(splits=bins, inputCol="dollar_value", outputCol="dollar_value_buckets")
#result = bucketizer.setHandleInvalid("keep").transform(result)

In [23]:
#result = result.withColumn('take_rate', F.col('take_rate').cast('double'))

In [24]:
# take_rate_max = result.select('take_rate').orderBy(F.col('take_rate'),  ascending= False).collect()[0][0]
# take_rate_min = result.select('take_rate').orderBy(F.col('take_rate')).collect()[0][0]

In [25]:
# print('take_rate range:[{}, {}]'.format(take_rate_min, take_rate_max))

In [26]:
# bucketizer = Bucketizer(splits=[0,1,2,3,4,5,6,7,8], inputCol="take_rate", outputCol="take_rate_buckets")
# result = bucketizer.setHandleInvalid("keep").transform(result)

In [27]:
# result.printSchema()

### Index ordinal features

In [28]:
full =  full.withColumn('month', F.month('order_datetime'))

In [29]:
indexed_features = ['revenue_level', 'tags', 'gender']

In [30]:
# We give all values in non-numeric features an index in order to make it ordinal or one-hot encoded
indexers =[]
for col in indexed_features:
  indexers.append(StringIndexer(inputCol=col, outputCol = col+"_index"))

indexers

[StringIndexer_cbe3b1955ab1,
 StringIndexer_8131eea67ff7,
 StringIndexer_70674c697405]

In [31]:
pipeline = Pipeline(stages=indexers)

In [32]:
indexed_result = pipeline.fit(full).transform(full)

### One hot encoding

In [33]:
categorical_features =  ["tags_index", "gender_index","month"]

In [34]:
# one-hot-encoding the numeric indices
ohe = []
for f in categorical_features:
  ohe.append(OneHotEncoder(inputCol=f, outputCol=f+"OHE"))

ohe

[OneHotEncoder_e692a7cd0b3d,
 OneHotEncoder_92d11c2b5d99,
 OneHotEncoder_60deea47bbba]

In [35]:
pipeline = Pipeline(stages=ohe)

In [36]:
indexed_result = pipeline.fit(indexed_result).transform(indexed_result)

In [37]:
indexed_result.printSchema()

root
 |-- user_id: long (nullable = true)
 |-- order_datetime: date (nullable = true)
 |-- merchant_abn: long (nullable = true)
 |-- SA2_code: integer (nullable = true)
 |-- postcode: integer (nullable = true)
 |-- consumer_id: integer (nullable = true)
 |-- state: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- mean_total_income: integer (nullable = true)
 |-- dollar_value: double (nullable = true)
 |-- order_id: string (nullable = true)
 |-- name: string (nullable = true)
 |-- tags: string (nullable = true)
 |-- revenue_level: string (nullable = true)
 |-- take_rate: float (nullable = true)
 |-- merchant_prob: float (nullable = false)
 |-- consumer_prob: float (nullable = false)
 |-- is_fraud: integer (nullable = false)
 |-- month: integer (nullable = true)
 |-- revenue_level_index: double (nullable = false)
 |-- tags_index: double (nullable = false)
 |-- gender_index: double (nullable = false)
 |-- tags_indexOHE: vector (nullable = true)
 |-- gender_indexOHE: vec

### Feature Selection and Vectorization

In [38]:
feature_selected = ['dollar_value','take_rate','mean_total_income','monthOHE','revenue_level_index','tags_indexOHE','gender_indexOHE']

In [39]:
assembler = VectorAssembler(inputCols=feature_selected ,outputCol='features')

In [40]:
output = assembler.transform(indexed_result)

In [41]:
final_data = output.filter(F.col('order_datetime') < '2022-02-28').select('features','is_fraud')

In [42]:
(final_data.filter(F.col('is_fraud') == 1).count())/(final_data.filter(F.col('is_fraud') == 0).count())

0.009780748354606692

In [43]:
final_data.filter(F.col('is_fraud') == 1).count(), final_data.filter(F.col('is_fraud') == 0).count()

(75108, 7679167)

In [44]:
# imbalanced distribution of two classes. We decide to split the data according to their class and make the distribution balanced
fraud_data = final_data.filter(F.col('is_fraud') == 1)
normal_data = final_data.filter(F.col('is_fraud') == 0).randomSplit([0.01,0.99])[0]

### set random seed?

## Split data

In [45]:
train_fraud,test_fraud = fraud_data.randomSplit([0.7,0.3])
train_normal,test_normal = normal_data.randomSplit([0.7,0.3])

In [46]:
train_data = train_fraud.union(train_normal).orderBy(rand())
test_data = test_fraud.union(test_normal).orderBy(rand())

In [47]:
final_data.count()

7754275

In [48]:
train_data.count(),test_data.count()

(106503, 45581)

## Logistic Regression

In [49]:
lr = LogisticRegression(labelCol='is_fraud')

In [50]:
fitted_model = lr.fit(train_data)
fitted_model.setFeaturesCol("features")
fitted_model.setPredictionCol("prediction")

LogisticRegressionModel: uid=LogisticRegression_63864f74f7db, numClasses=2, numFeatures=41

### Evaluation

In [51]:
pred_and_labels = fitted_model.evaluate(test_data)

In [52]:
score_and_label = pred_and_labels.predictions.select('prediction', 'is_fraud').withColumnRenamed('is_fraud', 'label')

In [53]:
evaluator = MulticlassClassificationEvaluator()

In [54]:
evaluator.setPredictionCol("prediction")

MulticlassClassificationEvaluator_cd462c6eb4d0

In [55]:
print("Accuracy: " + str(evaluator.evaluate(score_and_label)))

Accuracy: 0.732965062529167


In [56]:
tp = score_and_label.filter((F.col('prediction') == 1) & (F.col('label') == 1)).count()

In [57]:
fn = score_and_label.filter((F.col('prediction') == 0) & (F.col('label') == 1)).count()

In [58]:
recall = tp/(tp+fn)
print('recall: ' + str(recall))

recall: 0.6806498868025037


## 

## Prediction

In [59]:
train_data = output.filter(F.col('order_datetime') < '2022-02-28')
predict_data = output.filter(F.col('order_datetime') >= '2022-02-28')

In [60]:
fraud_data = train_data.filter(F.col('is_fraud') == 1)
normal_data = train_data.filter(F.col('is_fraud') == 0).randomSplit([0.01,0.99])[0]

In [61]:
train_data = fraud_data.union(normal_data).orderBy(rand())

In [62]:
lr_final = LogisticRegression(labelCol='is_fraud')

In [64]:
fitted_model = lr_final.fit(train_data)

In [65]:
fitted_model.setFeaturesCol('features')
fitted_model.setPredictionCol('is_fraud')

LogisticRegressionModel: uid=LogisticRegression_2c1ff4df5e3c, numClasses=2, numFeatures=41

In [66]:
predict_data = predict_data.drop("is_fraud")
predicted = fitted_model.transform(predict_data)

In [67]:
cols = ['user_id','order_datetime','merchant_abn','SA2_code','postcode','consumer_id','state','gender','mean_total_income','dollar_value','order_id','name','tags','revenue_level','take_rate','is_fraud'	]
predicted =  predicted.select(cols)

In [68]:
full = full.drop('merchant_prob', 'consumer_prob', 'month')

In [69]:
predicted.count()

5749867

In [70]:
full = full.filter(F.col('order_datetime') < '2022-02-28')

In [71]:
full = full.union(predicted)

In [72]:
full.printSchema()

root
 |-- user_id: long (nullable = true)
 |-- order_datetime: date (nullable = true)
 |-- merchant_abn: long (nullable = true)
 |-- SA2_code: integer (nullable = true)
 |-- postcode: integer (nullable = true)
 |-- consumer_id: integer (nullable = true)
 |-- state: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- mean_total_income: integer (nullable = true)
 |-- dollar_value: double (nullable = true)
 |-- order_id: string (nullable = true)
 |-- name: string (nullable = true)
 |-- tags: string (nullable = true)
 |-- revenue_level: string (nullable = true)
 |-- take_rate: float (nullable = true)
 |-- is_fraud: double (nullable = false)



In [74]:
full = full.withColumn("is_fraud", F.col("is_fraud").cast("INT"))

In [75]:
full.cache()

user_id,order_datetime,merchant_abn,SA2_code,postcode,consumer_id,state,gender,mean_total_income,dollar_value,order_id,name,tags,revenue_level,take_rate,is_fraud
44,2021-04-09,10648956813,401041015,5074,564558,SA,Undisclosed,53613,68.17405810943993,4d42fd2c-0823-4af...,Proin Nisl Institute,computer,a,6.66,0
44,2021-11-11,17739089622,401041015,5074,564558,SA,Undisclosed,53613,12.556313709152784,77bba8e7-8124-43e...,Auctor Quis Corp.,cable,b,5.01,0
44,2021-03-26,30122382323,401041015,5074,564558,SA,Undisclosed,53613,57.97616021380351,894752b5-f32c-4b5...,Ipsum Company,watch,b,3.36,0
44,2021-07-06,34096466752,401041015,5074,564558,SA,Undisclosed,53613,127.91174243430336,b33fb8e8-faf4-424...,Nullam Enim Ltd,computer,b,3.22,0
44,2021-04-08,38435278995,401041015,5074,564558,SA,Undisclosed,53613,56.22196422567881,f10f5c9e-0844-4e1...,Sed Consequat Corp.,hobby,a,6.17,0
44,2021-10-20,41974958954,401041015,5074,564558,SA,Undisclosed,53613,60.02299315612368,cf13f2cc-5d81-413...,Sed Libero Proin ...,cable,a,5.51,0
44,2021-04-13,49891706470,401041015,5074,564558,SA,Undisclosed,53613,15.2725764827411,e2a272ce-bb6b-43f...,Non Vestibulum In...,tent,a,5.8,0
44,2021-12-01,49891706470,401041015,5074,564558,SA,Undisclosed,53613,6.1436338479183785,43a40900-af31-421...,Non Vestibulum In...,tent,a,5.8,0
44,2022-01-19,49891706470,401041015,5074,564558,SA,Undisclosed,53613,32.38049155067601,fe3929b0-d4c6-408...,Non Vestibulum In...,tent,a,5.8,0
44,2021-04-02,46804135891,401041015,5074,564558,SA,Undisclosed,53613,33.11945261885712,c0e6fffc-3f10-403...,Suspendisse Dui C...,opticians,c,2.93,0


In [76]:
## underlying files have been updated. explicitly invalidate the cache
full.createOrReplaceTempView("full_data_with_lable")

In [77]:
query = "REFRESH TABLE full_data_with_lable" 

In [78]:
spark.sql(query)

In [79]:
full.write.format('parquet').mode('overwrite').save("../../data/curated/full_data_with_fraud")

In [80]:
full.unpersist()

user_id,order_datetime,merchant_abn,SA2_code,postcode,consumer_id,state,gender,mean_total_income,dollar_value,order_id,name,tags,revenue_level,take_rate,is_fraud
44,2021-04-09,10648956813,401041015,5074,564558,SA,Undisclosed,53613,68.17405810943993,4d42fd2c-0823-4af...,Proin Nisl Institute,computer,a,6.66,0
44,2021-11-11,17739089622,401041015,5074,564558,SA,Undisclosed,53613,12.556313709152784,77bba8e7-8124-43e...,Auctor Quis Corp.,cable,b,5.01,0
44,2021-03-26,30122382323,401041015,5074,564558,SA,Undisclosed,53613,57.97616021380351,894752b5-f32c-4b5...,Ipsum Company,watch,b,3.36,0
44,2021-07-06,34096466752,401041015,5074,564558,SA,Undisclosed,53613,127.91174243430336,b33fb8e8-faf4-424...,Nullam Enim Ltd,computer,b,3.22,0
44,2021-04-08,38435278995,401041015,5074,564558,SA,Undisclosed,53613,56.22196422567881,f10f5c9e-0844-4e1...,Sed Consequat Corp.,hobby,a,6.17,0
44,2021-10-20,41974958954,401041015,5074,564558,SA,Undisclosed,53613,60.02299315612368,cf13f2cc-5d81-413...,Sed Libero Proin ...,cable,a,5.51,0
44,2021-04-13,49891706470,401041015,5074,564558,SA,Undisclosed,53613,15.2725764827411,e2a272ce-bb6b-43f...,Non Vestibulum In...,tent,a,5.8,0
44,2021-12-01,49891706470,401041015,5074,564558,SA,Undisclosed,53613,6.1436338479183785,43a40900-af31-421...,Non Vestibulum In...,tent,a,5.8,0
44,2022-01-19,49891706470,401041015,5074,564558,SA,Undisclosed,53613,32.38049155067601,fe3929b0-d4c6-408...,Non Vestibulum In...,tent,a,5.8,0
44,2021-04-02,46804135891,401041015,5074,564558,SA,Undisclosed,53613,33.11945261885712,c0e6fffc-3f10-403...,Suspendisse Dui C...,opticians,c,2.93,0
