# Real-Time Credit Card Fraud Detection

## Basic Imports and Settings

In [None]:
# import modules from pyspark
import pyspark
from pyspark.sql import SparkSession
from pyspark.sql.types import *
from pyspark.sql.functions import *
from pyspark.sql import SQLContext
import pandas as pd
import matplotlib.pyplot as plt
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml import Pipeline

# uncomment the following line if running pyspark from the notebook itself
# spark = SparkSession.builder.enableHiveSupport().getOrCreate()
spark.sparkContext.setLogLevel("ERROR")
sqlContext = SQLContext(spark)

## Loading the Dataset File and Performing Basic Data Type Conversions

Source of the dataset: https://www.kaggle.com/datasets/kartik2112/fraud-detection

In [None]:
# define a reusable schema for the dataset (will be useful for the real-time portion).
# The same schema object is passed to the batch CSV read, the file-based stream
# reader and the Kafka from_csv parsing further down in this notebook.
ccschema = StructType([
    StructField("_c0", IntegerType(), True),  # presumably the unnamed row-index column of the CSV -- TODO confirm
    StructField("trans_date_trans_time", TimestampType(), True),
    StructField("cc_num", StringType(), True),
    StructField("merchant", StringType(), True),
    StructField("category", StringType(), True),
    StructField("amt", DoubleType(), True),
    StructField("first", StringType(), True),
    StructField("last", StringType(), True),
    StructField("gender", StringType(), True),
    StructField("street", StringType(), True),
    StructField("city", StringType(), True),
    StructField("state", StringType(), True),
    StructField("zip", StringType(), True),
    StructField("lat", DoubleType(), True),
    StructField("long", DoubleType(), True),
    StructField("city_pop", DoubleType(), True),
    StructField("job", StringType(), True),
    StructField("dob", DateType(), True),
    StructField("trans_num", StringType(), True),
    StructField("unix_time", StringType(), True),  # read as a string; the EDA histogram converts it with float()
    StructField("merch_lat", DoubleType(), True),
    StructField("merch_long", DoubleType(), True),
    StructField("is_fraud", IntegerType(), True),
])

In [None]:
# Read The data from local folder
cc = (spark.read.csv("fraudTrain.csv", schema=ccschema, header=True))

In [None]:
# Read The data in the context of Databricks
# cc = (spark.read.csv("s3://group9-ml-project/fraudTrain.csv", schema=ccschema, header=True))

In [None]:
# use this code to create some sample data from the main dataset for the purpose
# of loading transactions for the real time portion

# cc.limit(2).write.option("header",False).csv("sample")

## Exploratory Data Analysis

### A Look at the Data and its Basic Statistics

In [None]:
# list the column names of the dataset
cc.columns

In [None]:
# basic statistics
cc.describe().toPandas()

In [None]:
# looking to see if there are null values: count the nulls in every column.
# NOTE: `sum` and `col` here are the pyspark.sql.functions versions pulled in by the
# star import at the top of the notebook, not the Python builtins.
cc.select(*(sum(col(c).isNull().cast("int")).alias(c) for c in cc.columns)).toPandas()

### Visualizing the Data Distribution

In [None]:
# draw a 4x2 grid of 20-bucket histograms, one panel per numeric column
fig, axs = plt.subplots(4, 2, figsize=(15, 20))
fig.suptitle('CC Fraud Data Distribution')

numeric_columns = ['amt', 'city_pop', 'lat', 'long', 'merch_lat', 'merch_long', 'unix_time', 'is_fraud']
for idx, column in enumerate(numeric_columns):
    # Spark computes the bucket edges and counts; matplotlib only draws them
    values = cc.select(column).rdd.flatMap(lambda x: x).map(float)
    bins, counts = values.histogram(20)
    # panels are filled row by row: two columns per row
    panel = axs[idx // 2][idx % 2]
    panel.set_title(column)
    panel.hist(bins[:-1], bins=bins, weights=counts)

plt.show()

In [None]:
cc.select("category").groupby("category").count().toPandas()

In [None]:
cc.select("gender").groupby("gender").count().toPandas()

### Fraud Ratio by Age Group

Is fraud more prevalent in certain age groups than others? The analysis below suggests that seniors aged 80+ are more exposed to fraudulent transactions than other groups. This makes age group a potentially viable feature.

In [None]:
# calculate age group from date of birth. Notice the use of the floor function, which
# means that a person in the 30-40 age group would be deemed belonging to the "30" group.
# Age is measured at the time of the transaction (not at notebook run time), so the
# grouping is reproducible and reflects the cardholder's age when the charge occurred.
cc_age = cc.withColumn("age_group",
                       floor(months_between(col("trans_date_trans_time"),
                                            col("dob")) / 12 / 10) * 10)

In [None]:
# get histogram of fraudulent transactions by age group
bins, counts = (cc_age.where(col("is_fraud")==1.0)
                      .select("age_group").rdd.flatMap(lambda x: x)
                               .map(float).histogram(list(range(0,110,10))))

In [None]:
# get histogram of non-fraudulent transactions by age group
bins2, counts2 = (cc_age.where(col("is_fraud")==0.0)
                      .select("age_group").rdd.flatMap(lambda x: x)
                               .map(float).histogram(list(range(0,110,10))))

In [None]:
# calculate ratios between fraudulent and non-fraudulent transactions
def safediv(arg1, arg2):
    """Return ``arg1 / arg2``, or 0 when ``arg2`` equals zero.

    Lets per-bucket counts be turned into ratios without raising
    ZeroDivisionError for buckets whose denominator is empty.
    """
    if arg2 == 0:
        return 0
    return arg1 / arg2

ratios = list(map(safediv, counts, counts2))
# ratios

In [None]:
import seaborn as sns

plt.figure(figsize = (10,5))
# bins holds the 11 bucket edges (0, 10, ..., 100) while ratios has one value per bucket
# (10 values); zip() stops at the shorter input, so every ratio is paired with the lower
# edge of its age bucket. fraud_ratio is fraudulent count / non-fraudulent count.
hist = pd.DataFrame(zip(bins,ratios), columns=['age_group','fraud_ratio'])
sns.barplot(hist, x="age_group", y="fraud_ratio").set(title='Fraud Ratio by Age Group')
plt.show()

### EDA Preliminary Findings

- The data appears to be clean with no missing values
- Some of the heavily skewed features like amt and city_pop may benefit from logarithmic transformation
- The target class (is_fraud) is heavily imbalanced

### Feature Engineering

### Last 24 Hour Transactions

In [None]:
txByWindow = cc.groupby("cc_num",window(cc.trans_date_trans_time, "1 hour")).count()

In [None]:
pd.DataFrame(txByWindow.take(20), columns=txByWindow.columns)

In [None]:
cc_window = cc.withColumn("window", window(cc.trans_date_trans_time, "1 hour"))

In [None]:
# attach the per-card, per-1-hour-window transaction count to every transaction.
# NOTE: despite its name, "fraud_count" is the number of ALL transactions of the card in
# that window (it is txByWindow's count), not the number of fraudulent ones.
cc_joined = cc_window.join(txByWindow, ['cc_num', 'window'], "outer").withColumnRenamed("count", "fraud_count")

In [None]:
# pd.DataFrame(cc_joined.where(col("is_fraud")==1).take(40), cc_joined.columns)

cc_fraud_counts = cc_joined.where(col("is_fraud")==1).select("fraud_count").groupby("fraud_count").count()

In [None]:
cc_fraud_counts.orderBy("fraud_count").show()

In [None]:
# cc_sl = cc.withColumn('ts_seconds', col("trans_date_trans_time").cast('long'))

In [None]:
# from pyspark.sql.window import Window

#w = (Window()
#     .partitionBy(col("cc_num"))
#     .orderBy('ts_seconds')
#     .rangeBetween(-60*60*24, Window.currentRow)
#     )

#df1 = (cc_sl
#       .withColumn('txns_last24', count("*").over(w))
#       .orderBy(desc("ts_seconds"))
#       )

In [None]:
cc_window = cc.withColumn('window', window(col("trans_date_trans_time"), "24 hours", "1 hour"))

In [None]:
# Group the data by sliding window (24 hours long, advancing every hour) and card number,
# and compute the number of transactions in each group
windowedCounts = cc.groupBy(
    window(cc.trans_date_trans_time, "24 hours", "1 hour"),
    cc.cc_num
).agg(count('cc_num').alias('txns_last24'))

In [None]:
cc_counts = cc_window.join(windowedCounts, ["window", "cc_num"])

In [None]:
pd.DataFrame(cc_counts.take(10), columns=cc_counts.columns)

In [None]:
# pd.DataFrame(df1.where(col("txns_last_hour")>3).take(20), columns=df1.columns)

In [None]:
# df2 = (df1
#       .withColumn('amt_last24', sum("amt").over(w))
#       .orderBy(desc("amt"))
#    )

In [None]:
# pd.DataFrame(df2.where(col("txns_last_hour")>3).take(20), columns=df2.columns)

### Distance from Home

In [None]:
# great-circle distance (in km, rounded to whole kilometres) between the cardholder
# coordinates (lat/long) and the merchant coordinates (merch_lat/merch_long).
# NOTE: this expression is the spherical law of cosines,
#   d = R * acos(sin(lat1)*sin(lat2) + cos(lat1)*cos(lat2)*cos(long1 - long2)),
# not the haversine formula; R = 6371.0 km is used as the Earth radius.
# NOTE(review): for (almost) identical points rounding could push the acos argument
# slightly above 1 and produce NaN -- confirm whether that occurs in this data.
cc_dist = cc.withColumn('dist_kms' , \
            round((acos((sin(radians(col("lat"))) * sin(radians(col("merch_lat")))) + \
                   ((cos(radians(col("lat"))) * cos(radians(col("merch_lat")))) * \
                    (cos(radians(col("long")) - radians(col("merch_long")))))
                       ) * lit(6371.0)), 0))

In [None]:
pd.DataFrame(cc_dist.select(["lat", "long", "merch_lat", "merch_long", "dist_kms"]).take(10), 
             columns=["lat", "long", "merch_lat", "merch_long" , "dist_kms"])

In [None]:
cc_dist.select("dist_kms").describe().toPandas()

In [None]:
# get histogram of fraudulent transactions by distance from home
bins_dist, counts_dist = (cc_dist.where(col("is_fraud")==1.0)
                      .select("dist_kms").rdd.flatMap(lambda x: x)
                               .map(float).histogram(10))

In [None]:
import seaborn as sns

plt.figure(figsize = (8,4))
hist_dist = pd.DataFrame(zip(bins_dist,counts_dist), columns=['distance','count'])
sns.barplot(hist_dist, x="distance", y="count").set(title='Fraud Count by Distance from Home')
plt.show()

## ML Pipeline Setup

In [None]:
# define logarithmic transformer
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType
from pyspark import keyword_only  # Note: use pyspark.ml.util.keyword_only if Spark < 2.0
from pyspark.ml import Transformer
from pyspark.ml.param.shared import HasInputCol, HasOutputCol, Param, Params, TypeConverters
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
 
class LogTransformer(Transformer,               # Base class
                     HasInputCol,               # Sets up an inputCol parameter
                     HasOutputCol,              # Sets up an outputCol parameter
                     DefaultParamsReadable,     # Makes parameters readable from file
                     DefaultParamsWritable      # Makes parameters writable from file
                    ):
    """Pipeline stage that writes ``log(inputCol)`` into ``outputCol``.

    The logarithm is computed with the single-argument ``log`` function from
    pyspark.sql.functions on the column named by ``inputCol``.
    """

    @keyword_only
    def __init__(self, inputCol=None, outputCol=None, append_str=None):
        """
        Constructor: set values for all Param objects.

        :param inputCol: name of the column to take the logarithm of
        :param outputCol: name of the column that receives the result
        :param append_str: unused; accepted only so existing callers keep working
        """
        super().__init__()
        self._setDefault()
        # setParams() only understands inputCol/outputCol, so forwarding append_str
        # would raise TypeError; it is dropped here instead.
        kwargs = {name: value for name, value in self._input_kwargs.items()
                  if name != "append_str"}
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, inputCol=None, outputCol=None):
        """Set the given Params (keyword arguments only) and return this instance."""
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    # Required if you use Spark >= 3.0
    def setInputCol(self, new_inputCol):
        """Set the input column name and return this instance."""
        return self.setParams(inputCol=new_inputCol)

    # Required if you use Spark >= 3.0
    def setOutputCol(self, new_outputCol):
        """Set the output column name and return this instance."""
        return self.setParams(outputCol=new_outputCol)

    def _transform(self, dataset):
        """
        This is the main member function which applies the transform to transform data from the `inputCol` to the `outputCol`

        :param dataset: DataFrame containing the input column
        :return: `dataset` with the output column added (or replaced)
        :raises ValueError: if no input column has been set
        """
        if not self.isSet("inputCol"):
            raise ValueError(
                "No input column set for the "
                "LogTransformer transformer."
            )
        input_column = self.getInputCol()
        output_column = self.getOutputCol()

        return dataset.withColumn(output_column,
                                  log(col(input_column)))

In [None]:
from pyspark.ml.feature import StringIndexer

# define a transformer to convert string categorical features to numeric indices.
# handleInvalid="keep" assigns labels that were not seen during fit() to one extra index
# instead of raising, so the fitted pipeline can still score the later test split and
# live transactions that mention a merchant, city or job absent from the training data.
inputs = ['merchant', 'category', 'gender', 'city', 'state', 'job']
outputs = ['merchant_idx', 'category_idx', 'gender_idx', 'city_idx', 'state_idx', 'job_idx']
stringIndexer = StringIndexer(inputCols=inputs, outputCols=outputs, handleInvalid="keep")


In [None]:
from pyspark.ml.feature import OneHotEncoder

# define a transformer to one-hot encode indexed categorical features
inputs_1hot = ['merchant_idx', 'category_idx', 'city_idx', 'state_idx', 'job_idx', 'hour_of_day']
outputs_1hot = ['merchant_1hot', 'category_1hot', 'city_1hot', 'state_1hot', 'job_1hot', 'hour_of_day_1hot']

oneHotEncoder = OneHotEncoder(inputCols=inputs_1hot, outputCols=outputs_1hot)


In [None]:
from pyspark.ml.feature import VectorAssembler

# columns that make up the model input; the assembler concatenates them, in this
# order, into the single "features" vector column
featureCols = [
    'amt_log',
    'city_pop_log',
    'job_1hot',
    'state_1hot',
    'txns_last_24h',
    'category_1hot',
    'gender_idx',
    'hour_of_day_1hot',
    'dist_kms',
]
assembler = VectorAssembler(inputCols=featureCols, outputCol="features")

# cc_final = assembler.transform(cc_prepped)

In [None]:
amtTransformer = LogTransformer(inputCol="amt", outputCol="amt_log")
cityPopTransformer = LogTransformer(inputCol="city_pop", outputCol="city_pop_log")

In [None]:
from pyspark.ml.feature import SQLTransformer

colTransformer = SQLTransformer(
    statement = """SELECT *, hour(trans_date_trans_time) hour_of_day,
                             log(amt) amt_log,
                             log(city_pop) city_pop_log
                    FROM __THIS__
                """
)

In [None]:
from pyspark.ml.feature import SQLTransformer

addWindow = SQLTransformer(
    statement = """SELECT *, window(trans_date_trans_time, '24 hours', '1 hour') time_window
                    FROM __THIS__
                """
)

In [None]:
myTest = addWindow.transform(cc)
pd.DataFrame(myTest.take(10), columns=myTest.columns)

In [None]:
# for every row, count the rows that share its card number and time window.
# The outer table is aliased (m1) and referenced explicitly inside the subquery:
# unqualified cc_num / time_window there would resolve to the inner table m2, making the
# predicate compare m2 with itself. count(*) is spelled out because a parameterless
# count() is not accepted by every Spark version.
addCount = SQLTransformer(
    statement = """SELECT *, (SELECT count(*) FROM __THIS__ m2
                                where m2.cc_num = m1.cc_num and m2.time_window = m1.time_window) as txns_last24
                   FROM __THIS__ m1
                """
)

In [None]:
myCounts = addCount.transform(myTest)
pd.DataFrame(myCounts.take(10), columns=myCounts.columns)

In [None]:
from pyspark.ml.feature import SQLTransformer

dist = SQLTransformer(
    statement = """SELECT *, 
                    round((acos((sin(radians(lat)) * sin(radians(merch_lat))) + 
                   ((cos(radians(lat)) * cos(radians(merch_lat))) * 
                    (cos(radians(long) - radians(merch_long))))) * 6371.0), 0)
                    as dist_kms                       
                    FROM __THIS__
                """
)

In [None]:
from pyspark.ml.feature import SQLTransformer

# share of fraudulent rows in the full dataset
fraud_rows = cc.filter(col("is_fraud")==1).count()
fraudRatio = fraud_rows / cc.count()

# each class is weighted by the share of the other class: fraud rows get
# (1 - fraudRatio), all remaining rows get fraudRatio
fraud_weight = str(1-fraudRatio)
other_weight = str(fraudRatio)
weightCalc = SQLTransformer(
    statement = "SELECT *, CASE WHEN is_fraud = 1 THEN " + fraud_weight +
                " ELSE " + other_weight + " END AS weight FROM __THIS__"
)

In [None]:
# from pyspark.ml.feature import SQLTransformer

# tsSec = SQLTransformer(
#     statement = """SELECT *, CAST(trans_date_trans_time AS LONG) AS ts_seconds
#                     FROM __THIS__"""
# )

In [None]:
# cc_sec = tsSec.transform(cc)

In [None]:
# pd.DataFrame(cc_sec.take(20), columns=cc_sec.columns)

In [None]:
from pyspark.ml.feature import SQLTransformer

# txLast24 = SQLTransformer(
#    statement = """SELECT *, COUNT(*) OVER (PARTITION BY cc_num ORDER BY ts_seconds 
#                                RANGE BETWEEN 86400 PRECEDING AND CURRENT ROW) AS txns_last_24h
#                    FROM __THIS__"""
#)

# txns_last_24h: for each row, the number of rows with the same cc_num whose
# trans_date_trans_time lies within the 24 hours up to and including that row's own
# timestamp (a RANGE frame ordered by the timestamp column, so the row itself is counted)
txLast24 = SQLTransformer(
    statement = """SELECT *, COUNT(*) OVER (PARTITION BY cc_num ORDER BY trans_date_trans_time 
                                RANGE BETWEEN INTERVAL 24 hours PRECEDING AND CURRENT ROW) AS txns_last_24h
                    FROM __THIS__"""
)

In [None]:
cc_last24 = txLast24.transform(cc)
pd.DataFrame(cc_last24.take(10), columns=cc_last24.columns)

### Training and Test Setup

In [None]:
from pyspark.sql.functions import percent_rank
from pyspark.sql import Window

# as our dataset is a time series, we do not want to randomly split it
# so we will split by using the percent_rank() function: rows are ranked by
# transaction time, rank <= 0.8 goes to training and the rest to test

# NOTE: Window.partitionBy() is called without columns, i.e. one window over all rows.
# NOTE: cc is rebound here and keeps the helper "rank" column; only training/test drop it.
cc = cc.withColumn("rank", percent_rank().over(Window.partitionBy().orderBy("trans_date_trans_time")))

training = cc.where("rank <= .8").drop("rank")
test = cc.where("rank > .8").drop("rank")

# training, test = cc.randomSplit([0.7, 0.3])

# report the size of each split
print(training.count())
print(test.count())

In [None]:
# weighted_cc = weightCalc.transform(cc)

In [None]:
# pd.DataFrame(weighted_cc.take(10), columns=weighted_cc.columns)

In [None]:
# temp_df = colTransformer.transform(training)

In [None]:
# pd.DataFrame(temp_df.take(10), columns=temp_df.columns)

## ML Training and Prediction - RandomForestClassifier

In [None]:
# from pyspark.sql.window import Window

# cc = cc.withColumn('ts_seconds', col("trans_date_trans_time").cast('long'))

# w = (Window()
#     .partitionBy(col("cc_num"))
#     .orderBy('ts_seconds')
#     .rangeBetween(-60*60*24, Window.currentRow)
#     )

#cc = (cc
#       .withColumn('txns_last_24h', count("*").over(w))
#       .orderBy(desc("ts_seconds"))
#       )

In [None]:
# calculate fraudulent transaction ratio
# fraudRatio = cc.filter(col("is_fraud")==1).count() / cc.count()
# fraudRatio

In [None]:
# weighted_cc =  cc.withColumn("weight", when(col("is_fraud")==1.0, 1-fraudRatio).otherwise(fraudRatio))

In [None]:
# weighted_cc =  weighted_cc.withColumn('hour_of_day', hour('trans_date_trans_time'))

In [None]:
# my_cc = weighted_cc.withColumn('hour_of_day2', expr("hour(trans_date_trans_time)"))

In [None]:
# pd.DataFrame(my_cc.where(col("cc_num")=='630423337322').take(10), columns=my_cc.columns)

In [None]:
# pd.DataFrame(training.where(col("txns_last_hour")>5).take(10), columns=training.columns)

In [None]:
# pd.DataFrame(training.where(col("cc_num")==3573030041201292).take(20), columns=training.columns)

In [None]:
# pd.DataFrame(training.where(col("is_fraud")==1.0).groupBy("txns_last_hour").count().head(10), columns=['txns_last_hour', 'count'])

In [None]:
# random forest on the assembled "features" vector, labelled by is_fraud and
# using the per-row weights from the "weight" column
rf = RandomForestClassifier(
    featuresCol="features",
    labelCol="is_fraud",
    weightCol="weight",
    leafCol="leafId",
    numTrees=10,
    maxDepth=5,
    seed=42,
)

# SQL feature stages come first, then categorical encoding, vector assembly and the classifier
pipelineStages = [
    txLast24, colTransformer, dist, weightCalc,
    stringIndexer, oneHotEncoder, assembler, rf,
]
pipeline = Pipeline(stages=pipelineStages)

model = pipeline.fit(training)

In [None]:
# model.stages[-1].featureImportances

In [None]:
preds = model.transform(test) 

In [None]:
def ExtractFeatureImp(featureImp, dataset, featuresCol):
    """Rank the slots of an assembled feature vector by importance.

    Args:
        featureImp: Indexable importance scores, addressed by feature-vector slot index.
        dataset: Object whose ``schema[featuresCol].metadata["ml_attr"]["attrs"]`` maps
            each attribute group to a list of ``{"idx": ..., "name": ...}`` entries.
        featuresCol: Name of the assembled feature-vector column.

    Returns:
        pandas DataFrame with one row per listed slot plus a ``score`` column, sorted
        by descending score. When the metadata lists no attributes an empty frame
        with the columns ``idx``, ``name`` and ``score`` is returned.
    """
    attr_groups = dataset.schema[featuresCol].metadata["ml_attr"]["attrs"]

    # flatten all attribute groups into one list, preserving group order
    entries = []
    for group in attr_groups.values():
        entries.extend(group)

    if not entries:
        # nothing to rank; avoid the KeyError on the missing 'idx' column
        return pd.DataFrame(columns=['idx', 'name', 'score'])

    varlist = pd.DataFrame(entries)
    varlist['score'] = varlist['idx'].apply(lambda x: featureImp[x])
    return varlist.sort_values('score', ascending=False)



In [None]:
features = ExtractFeatureImp(model.stages[-1].featureImportances, preds, "features").head(10)

In [None]:
import seaborn as sns

# Now let's plot and set the xlim to not exceed 100%
plt.figure(figsize=(9, 3))
sns.barplot(x='score', y='name', data=features, orient='h')
plt.title('Feature Importances')
plt.xlabel('Feature Score')
plt.ylabel('Feature')

# Set x-axis limit to not exceed 1
plt.xlim(0, 0.5)
plt.show()

In [None]:
from pyspark.sql.types import DoubleType
from pyspark.sql.functions import udf, col

def extract_prob(v):
    """Return entry 1 of the probability vector ``v`` as a float.

    Returns None when ``v`` is null, has fewer than two entries, or holds a
    value that cannot be converted to float, so that a single bad row yields
    a null instead of failing the whole job.
    """
    try:
        return float(v[1])  # the probability vector is expected to have length 2
    except (TypeError, IndexError, ValueError):
        # TypeError: v is None; IndexError: vector too short; ValueError: not a number
        return None

extract_prob_udf = udf(extract_prob, DoubleType())

df2 = preds.withColumn("prob_1", extract_prob_udf(col("probability")))

pd.DataFrame(df2.where(col("prediction")==1).where(col("is_fraud")==0).take(20), columns=df2.columns)

In [None]:
df3 = df2.withColumn("pred_adj",when(col("prob_1")>.65, 1.0).otherwise(0.0))

In [None]:
pd.DataFrame(df3.head(20), columns=df3.columns)

In [None]:
from pyspark.mllib.evaluation import MulticlassMetrics

# create confusion matrix

preds_float2 = df3 \
    .select("pred_adj", "is_fraud") \
    .withColumn("is_fraud", col("is_fraud").cast(DoubleType())) \
    .orderBy("pred_adj")

cm3 = MulticlassMetrics(preds_float2.rdd.map(tuple))

# print(cm.confusionMatrix().toArray())

#show the confusion matrix as a pandas df for clearer presentation
pd.DataFrame(cm3.confusionMatrix().toArray(),
             columns= ["predicted (0)", "predicted (1)"],
             index= ["actual (0)", "actual (1)"])

In [None]:
# print overall classification stats

precision = cm3.precision(1.0)
recall = cm3.recall(1.0)
f1Score = cm3.fMeasure(1.0)
print("Summary Stats")
print("Precision = %s" % precision)
print("Recall = %s" % recall)
print("F1 Score = %s" % f1Score)

In [None]:
from pyspark.ml.evaluation import BinaryClassificationEvaluator

# Instantiate the evaluator
bce= BinaryClassificationEvaluator(rawPredictionCol= "rawPrediction",
                                   labelCol="is_fraud", 
                                   metricName= "areaUnderPR")
                                   
bce.evaluate(preds)

In [None]:
from pyspark.mllib.evaluation import MulticlassMetrics

# create confusion matrix

preds_float = preds \
    .select("prediction", "is_fraud") \
    .withColumn("is_fraud", col("is_fraud").cast(DoubleType())) \
    .orderBy("prediction")

cm = MulticlassMetrics(preds_float.rdd.map(tuple))

# print(cm.confusionMatrix().toArray())

#show the confusion matrix as a pandas df for clearer presentation
pd.DataFrame(cm.confusionMatrix().toArray(),
             columns= ["predicted (0)", "predicted (1)"],
             index= ["actual (0)", "actual (1)"])

In [None]:
# print overall classification stats

precision = cm.precision(1.0)
recall = cm.recall(1.0)
f1Score = cm.fMeasure(1.0)
print("Summary Stats")
print("Precision = %s" % precision)
print("Recall = %s" % recall)
print("F1 Score = %s" % f1Score)

In [None]:
# source: https://stackoverflow.com/questions/52847408/pyspark-extract-roc-curve

from pyspark.mllib.evaluation import BinaryClassificationMetrics

# Scala version implements .roc() and .pr()
# Python: https://spark.apache.org/docs/latest/api/python/_modules/pyspark/mllib/common.html
# Scala: https://spark.apache.org/docs/latest/api/java/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.html
class CurveMetrics(BinaryClassificationMetrics):
    """BinaryClassificationMetrics extended with access to curve points.

    The curve methods are looked up by name on the wrapped Java model and
    their result is converted into a list of Python float pairs.
    """

    def __init__(self, *args):
        super().__init__(*args)

    def _to_list(self, rdd):
        """Collect a Java RDD of Tuple2 rows into a list of (float, float) pairs."""
        # collect() brings every point to the driver, which may be costly for large
        # datasets (there can be up to one point per distinct score). The rows arrive
        # as scala.Tuple2 objects, so their elements are read via _1() / _2().
        return [(float(pair._1()), float(pair._2())) for pair in rdd.collect()]

    def get_curve(self, method):
        """Call the Java-side curve method called `method` and return its points."""
        java_curve = getattr(self._java_model, method)()
        return self._to_list(java_curve.toJavaRDD())

In [None]:
import matplotlib.pyplot as plt

# Build an RDD of (probability of class 1, label) pairs and fetch the points of the
# precision-recall curve ('pr'); each point is plotted as x = recall, y = precision
preds_list = preds.select('is_fraud','probability').rdd.map(lambda row: (float(row['probability'][1]), float(row['is_fraud'])))
points = CurveMetrics(preds_list).get_curve('pr')

plt.figure()
x_val = [x[0] for x in points]
y_val = [x[1] for x in points]
plt.title("PR Curve")
plt.xlabel("Recall")
plt.ylabel("Precision")
plt.plot(x_val, y_val)
plt.show()

### Results Analysis

With a false negative rate of 100% in the confusion matrix and an AUC score of 0.5, we obviously have work to do! ;-)

## Real Time Prediction - Prototype 1 - File Based Structured Streaming

In [None]:
# Set the following value to True if you want to test drive streaming
enableStreaming = True
# Set to true if running locally, false if run on Databricks
isLocal = True

inputPath = 'events' if isLocal else "/Filestore/events"

In [None]:
if (enableStreaming):

    # Repartition the test data into 100 partitions and write it, with a header row,
    # as CSV files under inputPath; mode("overwrite") replaces whatever is already there.
    # These files are the input of the file-based stream defined below.
    testData = test.repartition(100)
    testData.write.mode("overwrite").format("CSV").option("header",True).save(inputPath)

In [None]:
# establish event stream
if (enableStreaming):
    events = spark.readStream.format("csv") \
                             .option("header",True) \
                             .schema(ccschema) \
                             .option("ignoreLeadingWhiteSpace",True) \
                             .option("mode","dropMalformed") \
                             .option("maxFilesPerTrigger",1) \
                             .load(inputPath)

In [None]:
if (enableStreaming):
    # The fitted pipeline contains txLast24, a COUNT(*) OVER (PARTITION BY ... ORDER BY ...)
    # window. Structured Streaming rejects such non-time-based windows on a streaming
    # DataFrame, so every micro-batch is scored as a regular DataFrame via foreachBatch.
    # NOTE(review): txns_last_24h then only counts transactions inside the same
    # micro-batch -- confirm this is acceptable for the prototype.
    def scoreFileBatch(batchDF, batchId):
        """Score one micro-batch with the fitted pipeline and print the predictions."""
        (model.transform(batchDF)
              .select("is_fraud", "probability", "prediction")
              .show(truncate=False))

    # make predictions as new batches come in.
    # to stop this process, interrupt the kernel with the stop button
    fileQuery = events.writeStream.foreachBatch(scoreFileBatch).start()
    fileQuery.awaitTermination()

## Real Time Prediction - Prototype 2 - Kafka Based Structured Streaming

In [None]:
enableKafka = True

In [None]:
if (enableKafka):
    df = spark.readStream \
        .format("kafka") \
        .option("kafka.bootstrap.servers", "localhost:9092") \
        .option("subscribe", "cc_fraud_topic") \
        .load()

# .option("startingOffsets", "earliest") \

    df.printSchema()

In [None]:
if (enableKafka):
    stringDF = df.selectExpr("CAST(value AS STRING)")

In [None]:
from pyspark.sql.functions import from_csv

if (enableKafka):
    events = stringDF.select(from_csv(col("value"),ccschema.simpleString()).alias("data")).select("data.*")

In [None]:
#ccDF.writeStream \
#      .format("console") \
#      .outputMode("append") \
#      .start() \
#      .awaitTermination()

In [None]:
if (enableKafka):
    # The fitted pipeline contains txLast24, a COUNT(*) OVER (PARTITION BY ... ORDER BY ...)
    # window. Structured Streaming rejects such non-time-based windows on a streaming
    # DataFrame, so every micro-batch is scored as a regular DataFrame via foreachBatch.
    # NOTE(review): txns_last_24h then only counts transactions inside the same
    # micro-batch -- confirm this is acceptable for the prototype.
    def scoreKafkaBatch(batchDF, batchId):
        """Score one micro-batch with the fitted pipeline and print the predictions."""
        (model.transform(batchDF)
              .select("trans_date_trans_time", "cc_num", "amt", "probability", "prediction")
              .show(truncate=False))

    # make predictions as new batches come in.
    # to stop this process, interrupt the kernel with the stop button
    kafkaQuery = events.writeStream.foreachBatch(scoreKafkaBatch).start()
    kafkaQuery.awaitTermination()