In [10]:
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.feature import OneHotEncoder, StringIndexer, VectorAssembler
from pyspark.ml import Pipeline
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, split, when
import importlib
import numpy as np

In [2]:
from modules.load import read_file

In [3]:
spark = SparkSession.builder.appName("Test").getOrCreate()

# Data Preparation

In [4]:
df = read_file(spark, "data/transactions.ndjson").drop("accountOpenDate", "currentExpDate", "transactionDateTime", "accountNumber", 
                  "dateOfLastAddressChange", "acqCountry",
                  "transactionType", "enteredCVV", "merchantCategoryCode","customerId")

In [5]:
# Clean up the "merchantName" column
df = df.withColumn(
    "merchantName",
    when(col("merchantName").contains(" #"), split(col("merchantName"), " #").getItem(0))
                            .when(col("merchantName").contains("Blue Mountain"), "Blue Mountain")
                            .when(col("merchantName").contains("ethanallen.com"), "Ethan Allen")
                            .when(col("merchantName").contains("pottery-barn.com"), "Pottery Barn")
                            .when(col("merchantName").contains("westelm.com"), "West Elm")
                            .when(col("merchantName").contains("williamssonoma"), "Williams Sonoma")
                            .otherwise(col("merchantName"))
)
df = df.withColumn("isFraud", col("isFraud").cast("double"))

In [6]:
string_columns = ["cardCVV", "cardLast4Digits", "merchantName",
                  "merchantCountryCode", "posConditionCode", "posEntryMode"]
# Ignored columns:
# "accountNumber", "customerId"
#                  "merchantName", "dateOfLastAddressChange", "acqCountry",
#                  "transactionType", "enteredCVV", "merchantCategoryCode",

In [7]:
df = df.drop(*np.array(df.columns)[["Encoded" in col for col in df.columns]].tolist())

In [8]:
df.columns

['availableMoney',
 'cardCVV',
 'cardLast4Digits',
 'cardPresent',
 'creditLimit',
 'currentBalance',
 'expirationDateKeyInMatch',
 'isFraud',
 'merchantCountryCode',
 'merchantName',
 'posConditionCode',
 'posEntryMode',
 'transactionAmount',
 'creditLimitIndexed',
 'merchantNameIndexed',
 'acqCountryIndexed',
 'merchantCountryCodeIndexed',
 'posEntryModeIndexed',
 'posConditionCodeIndexed',
 'merchantCategoryCodeIndexed',
 'transactionTypeIndexed']

# Decision Tree

In [None]:
for col in td.columns:
    count = td.select(col).distinct().count()
    print(col,", ", count)

In [29]:
assembler = VectorAssembler(inputCols=["transactionAmount"], outputCol="feature")
df = assembler.transform(df)

In [31]:
dtc = DecisionTreeClassifier(featuresCol="feature", labelCol="isFraud")
model = dtc.fit(df)

In [50]:
model.transform(df).show()

+--------------+-------+---------------+-----------+-----------+--------------+------------------------+-------+-------------------+--------------------+----------------+------------+-----------------+------------------+-------------------+-----------------+--------------------------+-------------------+-----------------------+---------------------------+----------------------+--------+------------------+--------------------+----------+
|availableMoney|cardCVV|cardLast4Digits|cardPresent|creditLimit|currentBalance|expirationDateKeyInMatch|isFraud|merchantCountryCode|        merchantName|posConditionCode|posEntryMode|transactionAmount|creditLimitIndexed|merchantNameIndexed|acqCountryIndexed|merchantCountryCodeIndexed|posEntryModeIndexed|posConditionCodeIndexed|merchantCategoryCodeIndexed|transactionTypeIndexed| feature|     rawPrediction|         probability|prediction|
+--------------+-------+---------------+-----------+-----------+--------------+------------------------+-------+------