## import library

In [1]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
import pandas as pd

from pyspark.ml.feature import StringIndexer
from pyspark.ml import Pipeline 
from pyspark.ml.feature import VectorAssembler 
from pyspark.ml.feature import OneHotEncoder
from pyspark.sql.functions import rand
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [120]:
# Create a spark session (which will run spark jobs)
spark = (
    SparkSession.builder.appName("MAST30034 Project 2")
    .config("spark.sql.repl.eagerEval.enabled", True) 
    .config("spark.sql.parquet.cacheMetadata", "true")
    .config("spark.sql.session.timeZone", "Etc/UTC")
    .config("spark.driver.memory", "2g")
    .config("spark.executer.memory", "4g")
    .getOrCreate()
)

## preprocessing

### load data

In [141]:
full = spark.read.parquet("../../data/curated/full_data/")

                                                                                

In [4]:
probs_merchant = spark.read.option('header', True).csv('../../data/tables/merchant_fraud_probability.csv')
probs_consumer= spark.read.option('header', True).csv('../../data/tables/consumer_fraud_probability.csv')

### Merge tables

In [5]:
full.printSchema()

root
 |-- user_id: long (nullable = true)
 |-- SA2_code: integer (nullable = true)
 |-- postcode: integer (nullable = true)
 |-- consumer_id: integer (nullable = true)
 |-- state: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- mean_total_income: integer (nullable = true)
 |-- merchant_abn: long (nullable = true)
 |-- dollar_value: double (nullable = true)
 |-- order_id: string (nullable = true)
 |-- order_datetime: date (nullable = true)
 |-- name: string (nullable = true)
 |-- tags: string (nullable = true)
 |-- revenue_level: string (nullable = true)
 |-- take_rate: double (nullable = true)



In [6]:
probs_consumer.printSchema()

root
 |-- user_id: string (nullable = true)
 |-- order_datetime: string (nullable = true)
 |-- fraud_probability: string (nullable = true)



In [7]:
probs_consumer =  probs_consumer.withColumn('user_id', F.col('user_id').cast('long'))\
                                .withColumn('fraud_probability', F.col('fraud_probability').cast('float'))

In [8]:
probs_merchant.printSchema()

root
 |-- merchant_abn: string (nullable = true)
 |-- order_datetime: string (nullable = true)
 |-- fraud_probability: string (nullable = true)



In [9]:
probs_merchant =  probs_merchant.withColumn('merchant_abn', F.col('merchant_abn').cast('long'))\
                                .withColumn('fraud_probability', F.col('fraud_probability').cast('float'))

In [142]:
# merge transaction file with merchants'/consumers' fraud probability based on merchant abn or user id respectively by left join
full = full.join(probs_merchant, on = ['merchant_abn', 'order_datetime'], how = 'left').withColumnRenamed('fraud_probability', 'merchant_prob')
full = full.join(probs_consumer, on = ['user_id', 'order_datetime'], how = 'left').withColumnRenamed('fraud_probability', 'consumer_prob')

In [143]:
# replace all the missing value with 0.01 as default fraud prob
full = full.na.fill(value=0.01, subset=['merchant_prob', 'consumer_prob'])

In [144]:
# We set benchmark as 5%, because we focus on False Positive instead of False Negative
full = full.withColumn('is_fraud', F.when((F.col('merchant_prob') > 5) | (F.col('consumer_prob') > 5), 1).otherwise(0))
full

user_id,order_datetime,merchant_abn,SA2_code,postcode,consumer_id,state,gender,mean_total_income,dollar_value,order_id,name,tags,revenue_level,take_rate,merchant_prob,consumer_prob,is_fraud
183,2022-01-31,38700038932,210051245,3043,133032,VIC,Female,58060,1115.1879869740476,d4432c96-0627-403...,Etiam Bibendum In...,tent,a,6.31,0.01,0.01,0
183,2022-08-30,29216160692,210051245,3043,133032,VIC,Female,58060,41.89601231453352,9971cdea-3d42-4e2...,Class Aptent LLC,garden supply,a,5.6,0.01,0.01,0
183,2022-09-15,10648956813,210051245,3043,133032,VIC,Female,58060,89.7561614941009,a6e6a33b-a7bc-402...,Proin Nisl Institute,computer,a,6.66,0.01,0.01,0
183,2022-01-29,68874243493,210051245,3043,133032,VIC,Female,58060,2456.739311532021,71d23bb1-bcb3-4ee...,Nullam Suscipit C...,telecom,c,2.52,0.01,9.307967,1
183,2021-07-08,90568944804,210051245,3043,133032,VIC,Female,58060,370.1733463859578,ba106433-7c52-49c...,Diam Eu Dolor LLC,tent,b,4.1,0.01,0.01,0
183,2021-03-10,90568944804,210051245,3043,133032,VIC,Female,58060,1374.0095428112525,de4e4ff5-f906-401...,Diam Eu Dolor LLC,tent,b,4.1,0.01,0.01,0
183,2022-07-18,62191208634,210051245,3043,133032,VIC,Female,58060,51.46301423810628,6d863060-aea1-45b...,Cursus Non Egesta...,furniture,c,2.17,0.01,0.01,0
183,2022-06-13,62191208634,210051245,3043,133032,VIC,Female,58060,24.48723535765797,42422047-e97c-401...,Cursus Non Egesta...,furniture,c,2.17,0.01,0.01,0
183,2021-09-27,15253672771,210051245,3043,133032,VIC,Female,58060,49.42210604823397,13fb71a4-a8cf-4a7...,Nascetur Ridiculu...,gift,b,3.26,0.01,0.01,0
183,2021-03-13,43321695302,210051245,3043,133032,VIC,Female,58060,25.146636745770262,bcf0651b-aeb1-43c...,Ante Limited,health,b,4.51,0.01,0.01,0


In [13]:
print('In {} transactions, {} are detected as fraud'.format(full.count(), full.filter(F.col('is_fraud') == 1).count()))

[Stage 15:>                                                       (0 + 12) / 12]

In 13614854 transactions, 75455 are detected as fraud


                                                                                

In [145]:
full = full.filter(F.col('dollar_value') >= 1).na.drop(subset = 'name')

In [15]:
full.printSchema()

root
 |-- user_id: long (nullable = true)
 |-- order_datetime: date (nullable = true)
 |-- merchant_abn: long (nullable = true)
 |-- SA2_code: integer (nullable = true)
 |-- postcode: integer (nullable = true)
 |-- consumer_id: integer (nullable = true)
 |-- state: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- mean_total_income: integer (nullable = true)
 |-- dollar_value: double (nullable = true)
 |-- order_id: string (nullable = true)
 |-- name: string (nullable = true)
 |-- tags: string (nullable = true)
 |-- revenue_level: string (nullable = true)
 |-- take_rate: double (nullable = true)
 |-- merchant_prob: float (nullable = false)
 |-- consumer_prob: float (nullable = false)
 |-- is_fraud: integer (nullable = false)



In [146]:
full = full.withColumn('take_rate', F.col('take_rate').cast('float'))

## Feature Engineering

### bin numeric features

In [17]:
#value_max = result.select('dollar_value').orderBy(F.col('dollar_value'),  ascending= False).collect()[0][0]

In [18]:
#value_min = result.select('dollar_value').orderBy(F.col('dollar_value')).collect()[0][0]

In [19]:
#print('dollar_value range:[{}, {}]'.format(value_min, value_max))

In [20]:
# bins = []
# i  = 0
# while i <= 70000:
#   bins.append(i)
#   i += 200

In [21]:
#bins.append(float('Inf'))

In [22]:
#bucketizer = Bucketizer(splits=bins, inputCol="dollar_value", outputCol="dollar_value_buckets")
#result = bucketizer.setHandleInvalid("keep").transform(result)

In [23]:
#result = result.withColumn('take_rate', F.col('take_rate').cast('double'))

In [24]:
# take_rate_max = result.select('take_rate').orderBy(F.col('take_rate'),  ascending= False).collect()[0][0]
# take_rate_min = result.select('take_rate').orderBy(F.col('take_rate')).collect()[0][0]

In [25]:
# print('take_rate range:[{}, {}]'.format(take_rate_min, take_rate_max))

In [26]:
# bucketizer = Bucketizer(splits=[0,1,2,3,4,5,6,7,8], inputCol="take_rate", outputCol="take_rate_buckets")
# result = bucketizer.setHandleInvalid("keep").transform(result)

In [27]:
# result.printSchema()

### Index ordinal features

In [147]:
full =  full.withColumn('month', F.month('order_datetime'))

In [29]:
indexed_features = ['revenue_level', 'tags', 'gender']

In [30]:
# We give all values in non-numeric features an index in order to make it ordinal or one-hot encoded
indexers =[]
for col in indexed_features:
  indexers.append(StringIndexer(inputCol=col, outputCol = col+"_index"))

indexers

[StringIndexer_cc3e79d29b85,
 StringIndexer_3d5eeabc05ec,
 StringIndexer_f8ab18b2597e]

In [31]:
pipeline = Pipeline(stages=indexers)

In [32]:
indexed_result = pipeline.fit(full).transform(full)

                                                                                

### One hot encoding

In [33]:
categorical_features =  ["tags_index", "gender_index","month"]

In [34]:
# one-hot-encoding the numeric indices
ohe = []
for f in categorical_features:
  ohe.append(OneHotEncoder(inputCol=f, outputCol=f+"OHE"))

ohe

[OneHotEncoder_8614ad60253a,
 OneHotEncoder_ff7769d88295,
 OneHotEncoder_05ca5783abb7]

In [35]:
pipeline = Pipeline(stages=ohe)

In [36]:
indexed_result = pipeline.fit(indexed_result).transform(indexed_result)

                                                                                

In [37]:
indexed_result.printSchema()

root
 |-- user_id: long (nullable = true)
 |-- order_datetime: date (nullable = true)
 |-- merchant_abn: long (nullable = true)
 |-- SA2_code: integer (nullable = true)
 |-- postcode: integer (nullable = true)
 |-- consumer_id: integer (nullable = true)
 |-- state: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- mean_total_income: integer (nullable = true)
 |-- dollar_value: double (nullable = true)
 |-- order_id: string (nullable = true)
 |-- name: string (nullable = true)
 |-- tags: string (nullable = true)
 |-- revenue_level: string (nullable = true)
 |-- take_rate: float (nullable = true)
 |-- merchant_prob: float (nullable = false)
 |-- consumer_prob: float (nullable = false)
 |-- is_fraud: integer (nullable = false)
 |-- month: integer (nullable = true)
 |-- revenue_level_index: double (nullable = false)
 |-- tags_index: double (nullable = false)
 |-- gender_index: double (nullable = false)
 |-- tags_indexOHE: vector (nullable = true)
 |-- gender_indexOHE: vec

### Feature Selection and Vectorization

In [38]:
feature_selected = ['dollar_value','take_rate','mean_total_income','monthOHE','revenue_level_index','tags_indexOHE','gender_indexOHE']

In [39]:
assembler = VectorAssembler(inputCols=feature_selected ,outputCol='features')

In [40]:
output = assembler.transform(indexed_result)

In [41]:
final_data = output.filter(F.col('order_datetime') < '2022-02-28').select('features','is_fraud')

In [42]:
(final_data.filter(F.col('is_fraud') == 1).count())/(final_data.filter(F.col('is_fraud') == 0).count())

                                                                                

0.009780748354606692

In [43]:
final_data.filter(F.col('is_fraud') == 1).count(), final_data.filter(F.col('is_fraud') == 0).count()

                                                                                

(75108, 7679167)

In [44]:
# imbalanced distribution of two classes. We decide to split the data according to their class and make the distribution balanced
fraud_data = final_data.filter(F.col('is_fraud') == 1)
normal_data = final_data.filter(F.col('is_fraud') == 0).randomSplit([0.01,0.99])[0]

## Split data

In [45]:
train_fraud,test_fraud = fraud_data.randomSplit([0.7,0.3])
train_normal,test_normal = normal_data.randomSplit([0.7,0.3])

In [46]:
train_data = train_fraud.union(train_normal).orderBy(rand())
test_data = test_fraud.union(test_normal).orderBy(rand())

In [47]:
final_data.count()

                                                                                

7754275

In [48]:
train_data.count(),test_data.count()

                                                                                

(106183, 45245)

## Logistic Regression

In [49]:
lr = LogisticRegression(labelCol='is_fraud')

In [50]:
fitted_model = lr.fit(train_data)
fitted_model.setFeaturesCol("features")
fitted_model.setPredictionCol("prediction")

22/09/20 17:46:20 WARN BLAS: Failed to load implementation from: com.github.fommil.netlib.NativeSystemBLAS
22/09/20 17:46:20 WARN BLAS: Failed to load implementation from: com.github.fommil.netlib.NativeRefBLAS
                                                                                

LogisticRegressionModel: uid=LogisticRegression_1e777df2ea40, numClasses=2, numFeatures=41

### Evaluation

In [51]:
pred_and_labels = fitted_model.evaluate(test_data)

                                                                                

In [52]:
score_and_label = pred_and_labels.predictions.select('prediction', 'is_fraud').withColumnRenamed('is_fraud', 'label')

In [53]:
evaluator = MulticlassClassificationEvaluator()

In [54]:
evaluator.setPredictionCol("prediction")

MulticlassClassificationEvaluator_d4811fb883e0

In [55]:
print("Accuracy: " + str(evaluator.evaluate(score_and_label)))



Accuracy: 0.737546934063668


                                                                                

In [56]:
tp = score_and_label.filter((F.col('prediction') == 1) & (F.col('label') == 1)).count()

                                                                                

In [57]:
fn = score_and_label.filter((F.col('prediction') == 0) & (F.col('label') == 1)).count()

                                                                                

In [58]:
recall = tp/(tp+fn)
print('recall: ' + str(recall))

recall: 0.6872505493519889


## 

## Prediction

In [59]:
train_data = output.filter(F.col('order_datetime') < '2022-02-28')
predict_data = output.filter(F.col('order_datetime') >= '2022-02-28')

In [60]:
fraud_data = train_data.filter(F.col('is_fraud') == 1)
normal_data = train_data.filter(F.col('is_fraud') == 0).randomSplit([0.01,0.99])[0]

In [61]:
train_data = fraud_data.union(normal_data).orderBy(rand())

In [62]:
lr_final = LogisticRegression(labelCol='is_fraud')

In [64]:
fitted_model = lr_final.fit(train_data)

                                                                                

In [65]:
fitted_model.setFeaturesCol('features')
fitted_model.setPredictionCol('is_fraud')

LogisticRegressionModel: uid=LogisticRegression_6c00fda8f4a3, numClasses=2, numFeatures=41

In [66]:
predict_data = predict_data.drop("is_fraud")
predicted = fitted_model.transform(predict_data)

In [67]:
cols = ['user_id','order_datetime','merchant_abn','SA2_code','postcode','consumer_id','state','gender','mean_total_income','dollar_value','order_id','name','tags','revenue_level','take_rate','is_fraud'	]
predicted =  predicted.select(cols)

In [148]:
full = full.drop('merchant_prob', 'consumer_prob', 'month')

In [69]:
predicted.count()

5749867

In [149]:
full = full.filter(F.col('order_datetime') < '2022-02-28')

In [150]:
full = full.union(predicted)

In [151]:
full.printSchema()

root
 |-- user_id: long (nullable = true)
 |-- order_datetime: date (nullable = true)
 |-- merchant_abn: long (nullable = true)
 |-- SA2_code: integer (nullable = true)
 |-- postcode: integer (nullable = true)
 |-- consumer_id: integer (nullable = true)
 |-- state: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- mean_total_income: integer (nullable = true)
 |-- dollar_value: double (nullable = true)
 |-- order_id: string (nullable = true)
 |-- name: string (nullable = true)
 |-- tags: string (nullable = true)
 |-- revenue_level: string (nullable = true)
 |-- take_rate: float (nullable = true)
 |-- is_fraud: double (nullable = false)



In [152]:
full.cache()

                                                                                

user_id,order_datetime,merchant_abn,SA2_code,postcode,consumer_id,state,gender,mean_total_income,dollar_value,order_id,name,tags,revenue_level,take_rate,is_fraud
183,2022-01-31,38700038932,210051245,3043,133032,VIC,Female,58060,1115.1879869740476,d4432c96-0627-403...,Etiam Bibendum In...,tent,a,6.31,0.0
183,2022-01-29,68874243493,210051245,3043,133032,VIC,Female,58060,2456.739311532021,71d23bb1-bcb3-4ee...,Nullam Suscipit C...,telecom,c,2.52,1.0
183,2021-07-08,90568944804,210051245,3043,133032,VIC,Female,58060,370.1733463859578,ba106433-7c52-49c...,Diam Eu Dolor LLC,tent,b,4.1,0.0
183,2021-03-10,90568944804,210051245,3043,133032,VIC,Female,58060,1374.0095428112525,de4e4ff5-f906-401...,Diam Eu Dolor LLC,tent,b,4.1,0.0
183,2021-09-27,15253672771,210051245,3043,133032,VIC,Female,58060,49.42210604823397,13fb71a4-a8cf-4a7...,Nascetur Ridiculu...,gift,b,3.26,0.0
183,2021-03-13,43321695302,210051245,3043,133032,VIC,Female,58060,25.146636745770262,bcf0651b-aeb1-43c...,Ante Limited,health,b,4.51,0.0
183,2021-06-15,21532935983,210051245,3043,133032,VIC,Female,58060,22.840601049120835,4224f811-80d2-43e...,Eleifend Nec Inco...,cable,a,5.58,0.0
183,2021-12-29,70172340121,210051245,3043,133032,VIC,Female,58060,89.97820881579668,38f2deb0-6e4c-489...,Justo Eu Incorpor...,gift,b,3.29,0.0
183,2021-11-12,49891706470,210051245,3043,133032,VIC,Female,58060,46.8170732849845,d0de8b61-e555-441...,Non Vestibulum In...,tent,a,5.8,0.0
183,2021-10-01,49891706470,210051245,3043,133032,VIC,Female,58060,50.57450884244343,49f6dc67-1c48-47e...,Non Vestibulum In...,tent,a,5.8,0.0


In [153]:
full.show()

+-------+--------------+------------+---------+--------+-----------+-----+------+-----------------+------------------+--------------------+--------------------+-------------+-------------+---------+--------+
|user_id|order_datetime|merchant_abn| SA2_code|postcode|consumer_id|state|gender|mean_total_income|      dollar_value|            order_id|                name|         tags|revenue_level|take_rate|is_fraud|
+-------+--------------+------------+---------+--------+-----------+-----+------+-----------------+------------------+--------------------+--------------------+-------------+-------------+---------+--------+
|    183|    2022-01-31| 38700038932|210051245|    3043|     133032|  VIC|Female|            58060|1115.1879869740476|d4432c96-0627-403...|Etiam Bibendum In...|         tent|            a|     6.31|     0.0|
|    183|    2022-01-29| 68874243493|210051245|    3043|     133032|  VIC|Female|            58060|2456.7393115320206|71d23bb1-bcb3-4ee...|Nullam Suscipit C...|      te

In [155]:
## underlying files have been updated. explicitly invalidate the cache
full.createOrReplaceTempView("full_data_with_lable")

22/09/20 19:24:44 WARN package: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.


In [156]:
query = "REFRESH TABLE full_data_with_lable" 

In [157]:
spark.sql(query)

In [159]:
full.write.format('parquet').mode('overwrite').save("../../data/curated/full_data")

22/09/20 19:39:11 WARN MemoryStore: Not enough space to cache rdd_2157_20 in memory! (computed 12.3 MiB so far)
22/09/20 19:39:11 WARN MemoryStore: Not enough space to cache rdd_2157_21 in memory! (computed 12.3 MiB so far)
                                                                                

In [164]:
full.unpersist()

False