In [1]:
# Notebook setup: load the nb_black extension, locate Spark, import the PySpark
# helpers used below, and obtain a SparkSession.
%load_ext nb_black
import findspark

# Called before the pyspark imports below — presumably so that pyspark can be
# found on this machine (NOTE(review): confirm against the findspark docs).
findspark.init()
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
import pyspark.sql.types as T
import pandas as pd

# Session handle used by every later cell.
spark = SparkSession.builder.getOrCreate()

<IPython.core.display.Javascript object>

In [2]:
# Show the active SparkSession as the cell output.
spark

<IPython.core.display.Javascript object>

In [3]:
# Load the transactions CSV: the first row supplies the column names and the
# column types are inferred from the data.
df = spark.read.options(header=True, inferSchema=True).csv("synth_composite.csv")

<IPython.core.display.Javascript object>

In [4]:
# Print the column names together with the types inferred while reading the CSV.
df.printSchema()

root
 |-- step: integer (nullable = true)
 |-- type: string (nullable = true)
 |-- amount: double (nullable = true)
 |-- nameOrig: string (nullable = true)
 |-- oldbalanceOrg: double (nullable = true)
 |-- newbalanceOrig: double (nullable = true)
 |-- nameDest: string (nullable = true)
 |-- oldbalanceDest: double (nullable = true)
 |-- newbalanceDest: double (nullable = true)
 |-- isFraud: integer (nullable = true)
 |-- isFlaggedFraud: integer (nullable = true)



<IPython.core.display.Javascript object>

In [5]:
# Preview the first three rows of the loaded data.
df.show(3)

+----+--------+------------------+-------------+------------------+------------------+-------------+------------------+------------------+-------+--------------+
|step|    type|            amount|     nameOrig|     oldbalanceOrg|    newbalanceOrig|     nameDest|    oldbalanceDest|    newbalanceDest|isFraud|isFlaggedFraud|
+----+--------+------------------+-------------+------------------+------------------+-------------+------------------+------------------+-------+--------------+
| 211|CASH_OUT|184185.75527735116|160_2696646_0|203108.47605069657|18922.720773345412| 160_901564_0|1019993.7073250777| 1204127.136276725|      0|             0|
| 394|CASH_OUT|237093.19600862195|160_5585569_0| 382707.4642382287|145614.26822960674|160_1586945_0| 587089.2765989383| 824320.2512419948|      0|             0|
| 306| CASH_IN|163600.51879411662|160_4221372_0|  3093.43116922955|               0.0|160_1701821_0|  653453.795387125|490121.90304387547|      0|             0|
+----+--------+-------------

<IPython.core.display.Javascript object>

In [6]:
# Keep only the transaction type, the amount, the originator balances and the
# fraud label; every other column is dropped from `df`.
df = df.select(
    ["type", "amount", "oldbalanceOrg", "newbalanceOrig", "isFraud"]
)

<IPython.core.display.Javascript object>

In [7]:
# Preview the first three rows of `df` as it stands at this point in the notebook.
df.show(3)

+--------+------------------+------------------+------------------+-------+
|    type|            amount|     oldbalanceOrg|    newbalanceOrig|isFraud|
+--------+------------------+------------------+------------------+-------+
|CASH_OUT|184185.75527735116|203108.47605069657|18922.720773345412|      0|
|CASH_OUT|237093.19600862195| 382707.4642382287|145614.26822960674|      0|
| CASH_IN|163600.51879411662|  3093.43116922955|               0.0|      0|
+--------+------------------+------------------+------------------+-------+
only showing top 3 rows



<IPython.core.display.Javascript object>

In [8]:
# 70/30 random split into training and test sets; the seed is fixed so the split
# can be reproduced.
train, test = df.randomSplit(weights=[0.7, 0.3], seed=7)

<IPython.core.display.Javascript object>

In [9]:
# Report how many rows landed in each split (each count() runs a Spark job).
print(f"Train set length {train.count()} records")
print(f"Test set length {test.count()} records")

Train set length 4453091 records
Test set length 1909529 records


<IPython.core.display.Javascript object>

In [10]:
# Preview the first three rows of the training split.
train.show(3)

+-------+------------------+-----------------+------------------+-------+
|   type|            amount|    oldbalanceOrg|    newbalanceOrig|isFraud|
+-------+------------------+-----------------+------------------+-------+
|CASH_IN|1.0234829428471104|20458.86177176414| 20457.83828882129|      0|
|CASH_IN| 2.519244678168343| 3215379.73345976|3215377.2142150817|      0|
|CASH_IN| 6.317780115151176|7730148.645905179| 7730142.328125064|      0|
+-------+------------------+-----------------+------------------+-------+
only showing top 3 rows



<IPython.core.display.Javascript object>

In [11]:
# Preview the first three rows of the test split.
test.show(3)

+-------+------------------+--------------------+-------------------+-------+
|   type|            amount|       oldbalanceOrg|     newbalanceOrig|isFraud|
+-------+------------------+--------------------+-------------------+-------+
|CASH_IN|4.9256702943665065|   874573.2830826555|  874568.3574123612|      0|
|CASH_IN|10.730956338195247|   4178432.363640349| 4178421.6326840105|      0|
|CASH_IN| 13.44188791485196|1.0607694283374796E7|1.060768084148688E7|      0|
+-------+------------------+--------------------+-------------------+-------+
only showing top 3 rows



<IPython.core.display.Javascript object>

In [12]:
# List (column name, Spark type name) pairs; the next cell uses these to separate
# categorical from numeric columns.
train.dtypes

[('type', 'string'),
 ('amount', 'double'),
 ('oldbalanceOrg', 'double'),
 ('newbalanceOrig', 'double'),
 ('isFraud', 'int')]

<IPython.core.display.Javascript object>

In [15]:
# Partition the feature columns by Spark dtype:
#   catCols - string columns, to be indexed and one-hot encoded;
#   numcols - double columns, used as-is.
# The label column must never be treated as a feature. The guard previously
# compared against the misspelt name "iFfraud", so it could not exclude anything;
# it only went unnoticed because "isFraud" is currently read as an int.
catCols = [name for (name, dataType) in train.dtypes if dataType == "string"]
numcols = [
    name
    for (name, dataType) in train.dtypes
    if dataType == "double" and name != "isFraud"
]

<IPython.core.display.Javascript object>

In [16]:
# Show which columns were classified as categorical and which as numeric.
print(catCols)
print(numcols)

['type']
['amount', 'oldbalanceOrg', 'newbalanceOrig']


<IPython.core.display.Javascript object>

In [18]:
# Number of distinct values of the categorical column "type" in the training split.
train.agg(F.countDistinct("type")).show()

+-----------+
|count(type)|
+-----------+
|          5|
+-----------+



<IPython.core.display.Javascript object>

In [28]:
# Number of distinct values of the label column "isFraud" in the training split.
train.agg(F.countDistinct("isFraud")).show()

+--------------+
|count(isFraud)|
+--------------+
|             2|
+--------------+



<IPython.core.display.Javascript object>

In [19]:
# Row count per transaction type in the training split.
train.groupBy("type").count().show()

+--------+-------+
|    type|  count|
+--------+-------+
|TRANSFER| 373016|
| CASH_IN| 978632|
|CASH_OUT|1566095|
| PAYMENT|1506235|
|   DEBIT|  29113|
+--------+-------+



<IPython.core.display.Javascript object>

In [29]:
# Row count per label value in the training split; this is where the class
# balance between fraud and non-fraud rows becomes visible.
train.groupBy("isFraud").count().show()

+-------+-------+
|isFraud|  count|
+-------+-------+
|      0|4447276|
|      1|   5815|
+-------+-------+



<IPython.core.display.Javascript object>

In [30]:
from pyspark.ml.feature import (
    OneHotEncoder,
    StringIndexer,
    VectorAssembler,
)

<IPython.core.display.Javascript object>

In [22]:
# One StringIndexer per categorical column; each writes its output to
# "<column>_StringIndexer" and is configured with handleInvalid="skip".
string_indexer = [
    StringIndexer(
        inputCol=cat_col,
        outputCol=f"{cat_col}_StringIndexer",
        handleInvalid="skip",
    )
    for cat_col in catCols
]

<IPython.core.display.Javascript object>

In [25]:
# Show the list of StringIndexer stages that was just built.
string_indexer

[StringIndexer_f53f70deef4b]

<IPython.core.display.Javascript object>

In [45]:
# A single OneHotEncoder covers every indexed categorical column at once:
# it reads "<column>_StringIndexer" and writes "<column>_OneHotEncoder".
# It is wrapped in a list so it can be appended to the pipeline stages.
one_hot_encoder = [
    OneHotEncoder(
        inputCols=[cat_col + "_StringIndexer" for cat_col in catCols],
        outputCols=[cat_col + "_OneHotEncoder" for cat_col in catCols],
    )
]

<IPython.core.display.Javascript object>

In [46]:
# Show the list holding the OneHotEncoder stage.
one_hot_encoder

[OneHotEncoder_c03ec8e042eb]

<IPython.core.display.Javascript object>

In [47]:
# Columns fed to the VectorAssembler: the numeric columns first (copied so that
# `numcols` itself is not modified), then the one-hot encoded categorical columns.
assemblerInput = list(numcols) + [
    cat_col + "_OneHotEncoder" for cat_col in catCols
]

<IPython.core.display.Javascript object>

In [48]:
# Show the final list of assembler input columns.
assemblerInput

['amount', 'oldbalanceOrg', 'newbalanceOrig', 'type_OneHotEncoder']

<IPython.core.display.Javascript object>

In [49]:
# Combine all assembler input columns into one vector column named
# "VectorAssembler_Features".
vector_assembler = (
    VectorAssembler()
    .setInputCols(assemblerInput)
    .setOutputCol("VectorAssembler_Features")
)

<IPython.core.display.Javascript object>

In [50]:
# Show the VectorAssembler stage.
vector_assembler

VectorAssembler_164dff710f48

<IPython.core.display.Javascript object>

In [51]:
# Pipeline stages in execution order: index the categorical columns, one-hot
# encode the indices, then assemble everything into a single feature vector.
stages = [*string_indexer, *one_hot_encoder, vector_assembler]

<IPython.core.display.Javascript object>

In [52]:
# Show the ordered list of pipeline stages.
stages

[StringIndexer_f53f70deef4b,
 OneHotEncoder_c03ec8e042eb,
 VectorAssembler_164dff710f48]

<IPython.core.display.Javascript object>

In [53]:
%%time
from pyspark.ml import Pipeline
pipeline=Pipeline().setStages(stages)
model=pipeline.fit(train)

pp_df=model.transform(test)

CPU times: total: 0 ns
Wall time: 10.8 s


<IPython.core.display.Javascript object>

In [55]:
# Compare the source columns with the assembled feature vector for the first
# three rows; truncate=False prints the full vector instead of cutting it off.
pp_df.select(
    "type", "amount", "oldbalanceOrg", "newbalanceOrig", "VectorAssembler_Features"
).show(3, truncate=False)

+-------+------------------+--------------------+-------------------+----------------------------------------------------------------------------+
|type   |amount            |oldbalanceOrg       |newbalanceOrig     |VectorAssembler_Features                                                    |
+-------+------------------+--------------------+-------------------+----------------------------------------------------------------------------+
|CASH_IN|4.9256702943665065|874573.2830826555   |874568.3574123612  |[4.9256702943665065,874573.2830826555,874568.3574123612,0.0,0.0,1.0,0.0]    |
|CASH_IN|10.730956338195247|4178432.363640349   |4178421.6326840105 |[10.730956338195247,4178432.363640349,4178421.6326840105,0.0,0.0,1.0,0.0]   |
|CASH_IN|13.44188791485196 |1.0607694283374796E7|1.060768084148688E7|[13.44188791485196,1.0607694283374796E7,1.060768084148688E7,0.0,0.0,1.0,0.0]|
+-------+------------------+--------------------+-------------------+-------------------------------------------------

<IPython.core.display.Javascript object>

In [56]:
from pyspark.ml.classification import LogisticRegression

<IPython.core.display.Javascript object>

In [57]:
# Reduce the preprocessed frame to the two columns the classifier expects,
# renaming them to "features" and "label".
data = pp_df.select(
    pp_df["VectorAssembler_Features"].alias("features"),
    pp_df["isFraud"].alias("label"),
)

<IPython.core.display.Javascript object>

In [62]:
# Number of rows in `data`.
data.count()

1909529

<IPython.core.display.Javascript object>

In [58]:
# Preview the first five feature-vector / label pairs without truncating the vectors.
data.show(5, truncate=False)

+----------------------------------------------------------------------------+-----+
|features                                                                    |label|
+----------------------------------------------------------------------------+-----+
|[4.9256702943665065,874573.2830826555,874568.3574123612,0.0,0.0,1.0,0.0]    |0    |
|[10.730956338195247,4178432.363640349,4178421.6326840105,0.0,0.0,1.0,0.0]   |0    |
|[13.44188791485196,1.0607694283374796E7,1.060768084148688E7,0.0,0.0,1.0,0.0]|0    |
|[15.038989314745036,1112719.4055444363,1112704.3665551215,0.0,0.0,1.0,0.0]  |0    |
|[16.271479745151268,7440869.306727739,7440853.035247994,0.0,0.0,1.0,0.0]    |0    |
+----------------------------------------------------------------------------+-----+
only showing top 5 rows



<IPython.core.display.Javascript object>

In [59]:
%%time
model=LogisticRegression().fit(data)

CPU times: total: 31.2 ms
Wall time: 38.7 s


<IPython.core.display.Javascript object>

In [60]:
# Area under the ROC curve taken from the model's fit-time summary, i.e. measured
# on the rows the model was fitted on rather than by a separate evaluation pass.
model.summary.areaUnderROC

0.9918156925260609

<IPython.core.display.Javascript object>

In [61]:
# Precision/recall pairs from the same fit-time summary; show() prints the first
# 20 rows of the curve.
model.summary.pr.show()

+-------------------+-------------------+
|             recall|          precision|
+-------------------+-------------------+
|                0.0| 0.9025157232704403|
|0.35904920767306087| 0.9025157232704403|
|0.47289407839866554| 0.6483704974271012|
| 0.5371142618849041| 0.5062893081761006|
|  0.579232693911593| 0.4159928122192273|
| 0.6167639699749792|0.35776487663280115|
| 0.6484570475396163| 0.3154798133495638|
| 0.6763969974979149| 0.2833682739343117|
|   0.69557964970809| 0.2558674643350207|
| 0.7139282735613011|0.23407164342357123|
|   0.73185988323603|0.21642619311875694|
|  0.744370308590492|0.20047169811320756|
| 0.7593828190158466| 0.1877513145685122|
| 0.7710592160133445|0.17619592147894034|
| 0.7881567973311092| 0.1674196120116928|
|   0.79232693911593|0.15723270440251572|
| 0.7960800667222686|    0.1482257939281|
| 0.7998331943286072|0.14026619862512799|
| 0.8023352793994996|  0.132973944294699|
| 0.8056713928273561|0.12657232704402516|
+-------------------+-------------

<IPython.core.display.Javascript object>