In [17]:
import os

from pyspark.sql import SparkSession, DataFrame
from pyspark.sql.functions import from_json, col
from pyspark.sql.types import StructType, StructField, DoubleType, StringType

from logs.CustomLogger import CustomLogger

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


We need to set the environment variables before creating the Spark session. We can do this by setting the `PYSPARK_SUBMIT_ARGS` environment variable to include the necessary Kafka package. We can then create a Spark session using the `SparkSession` class.

In [18]:
# Set the necessary Spark environment variables
os.environ['PYSPARK_SUBMIT_ARGS'] = '--packages org.apache.spark:spark-sql-kafka-0-10_2.12:3.4.1 pyspark-shell'

# Setup the Logger

In [19]:
logger = CustomLogger("KafkaConsumer1")

INFO:KafkaConsumer1:Logger is set up


Logger is set up. Check producer.log for logs.


In [20]:
spark = SparkSession.builder \
    .appName("KafkaConsumer1") \
    .config("spark.kafka.consumer.partition.assignment.strategy", "org.apache.kafka.clients.consumer.RoundRobinAssignor") \
    .getOrCreate()

# Create a schema for the data
The data we are going to consume from Kafka is a JSON string that contains the following fields

In [21]:
schema = StructType([
    StructField("Time", DoubleType(), True),
    StructField("V1", DoubleType(), True),
    StructField("V2", DoubleType(), True),
    StructField("V3", DoubleType(), True),
    StructField("V4", DoubleType(), True),
    StructField("V5", DoubleType(), True),
    StructField("V6", DoubleType(), True),
    StructField("V7", DoubleType(), True),
    StructField("V8", DoubleType(), True),
    StructField("V9", DoubleType(), True),
    StructField("V10", DoubleType(), True),
    StructField("V11", DoubleType(), True),
    StructField("V12", DoubleType(), True),
    StructField("V13", DoubleType(), True),
    StructField("V14", DoubleType(), True),
    StructField("V15", DoubleType(), True),
    StructField("V16", DoubleType(), True),
    StructField("V17", DoubleType(), True),
    StructField("V18", DoubleType(), True),
    StructField("V19", DoubleType(), True),
    StructField("V20", DoubleType(), True),
    StructField("V21", DoubleType(), True),
    StructField("V22", DoubleType(), True),
    StructField("V23", DoubleType(), True),
    StructField("V24", DoubleType(), True),
    StructField("V25", DoubleType(), True),
    StructField("V26", DoubleType(), True),
    StructField("V27", DoubleType(), True),
    StructField("V28", DoubleType(), True),
    StructField("Amount", DoubleType(), True),
    StructField("Class", StringType(), True)
])
logger.info(f"Schema created: {schema}")

INFO:KafkaConsumer1:Schema created: StructType([StructField('Time', DoubleType(), True), StructField('V1', DoubleType(), True), StructField('V2', DoubleType(), True), StructField('V3', DoubleType(), True), StructField('V4', DoubleType(), True), StructField('V5', DoubleType(), True), StructField('V6', DoubleType(), True), StructField('V7', DoubleType(), True), StructField('V8', DoubleType(), True), StructField('V9', DoubleType(), True), StructField('V10', DoubleType(), True), StructField('V11', DoubleType(), True), StructField('V12', DoubleType(), True), StructField('V13', DoubleType(), True), StructField('V14', DoubleType(), True), StructField('V15', DoubleType(), True), StructField('V16', DoubleType(), True), StructField('V17', DoubleType(), True), StructField('V18', DoubleType(), True), StructField('V19', DoubleType(), True), StructField('V20', DoubleType(), True), StructField('V21', DoubleType(), True), StructField('V22', DoubleType(), True), StructField('V23', DoubleType(), True), 

# Consume data from Kafka
We can consume data from Kafka using the `readStream` method of the `SparkSession` object. We need to specify the Kafka server and the topic to consume data from. We can then parse the value as JSON (if applicable) and write the parsed data to the console.

In [22]:
df = spark \
    .readStream \
    .format("kafka") \
    .option("kafka.bootstrap.servers", "localhost:9092") \
    .option("subscribe", "distributed_transactions") \
    .option("startingOffsets", "latest") \
    .option("failOnDataLoss", "false") \
    .option("kafka.group.id", "distributed_transactions") \
    .load()

24/12/06 20:42:45 WARN KafkaSourceProvider: Kafka option 'kafka.group.id' has been set on this query, it is
 not recommended to set this option. This option is unsafe to use since multiple concurrent
 queries or sources using the same group id will interfere with each other as they are part
 of the same consumer group. Restarted queries may also suffer interference from the
 previous run having the same group id. The user should have only one query per group id,
 and/or set the option 'kafka.session.timeout.ms' to be very small so that the Kafka
 consumers from the previous query are marked dead by the Kafka group coordinator before the
 restarted query starts running.
    


In [23]:
# spark.stop()

In [24]:
# Parse the value as JSON (if applicable)
parsed_df = df.selectExpr("CAST(value AS STRING)") \
    .select(from_json(col("value"), schema).alias("data")) \
    .select("data.*")

## Getting parsed data from Kafka
In order to use the transactions data in order for our model to run predictions, we need to save the parsed data in a form that our pipeline can understand and use. We can do this by saving the parsed data

In [25]:
from pipeline.pipeline import CustomPipeline

model_path = "./../models/credit_card_fraud_detection_model"
pipeline_path = "./../pipeline/credit_card_fraud_detection_pipeline"
pipeline = CustomPipeline(model_path, pipeline_path)
cols = ['Time', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6', 'V7', 'V8', 'V9', 'V10',
        'V11', 'V12', 'V13', 'V14', 'V15', 'V16', 'V17', 'V18', 'V19', 'V20',
        'V21', 'V22', 'V23', 'V24', 'V25', 'V26', 'V27', 'V28', 'Amount']
pipeline.create_pipeline(cols=cols)
# predictions = []

Loading pipeline model from ./../pipeline/credit_card_fraud_detection_pipeline
Pipeline created.


# Predictions
First, we create an empty dataframe to hold our predictions. We then define a function that processes the records in the parsed data. The function takes a dataframe and an epoch id as arguments. We then collect the rows from the dataframe and check if there are any rows. If there are rows, we transform the dataframe using the pipeline and collect the rows from the predictions. We then print the predictions.

In [26]:
prediction_records_holder = [spark.createDataFrame([], StructType([]))]

In [27]:
from functools import partial

def process_record(df, epoch_id, prediction_records_holder, verbose=False):
    rows = df.collect()
    logger.info(f"Processing {rows} records")
    if len(rows) != 0:
        # Do no do anything if there are no rows
        predictions = pipeline.transform(df)
        logger.info(f"Predictions: {predictions}")

        # Append the predictions to the prediction_records DataFrame
        updated_prediction_records = prediction_records_holder[0].unionByName(predictions, allowMissingColumns=True)
        prediction_records_holder[0] = updated_prediction_records

        # Print or process the predictions if verbose is set to True
        if verbose:
            # Convert the predictions DataFrame to a list of dictionaries
            rows = predictions.collect()
            for row in rows:
                print(row.asDict())

process_record_with_params = partial(process_record, prediction_records_holder=prediction_records_holder, verbose=False)

Set `verbose=True` to print the predictions as they happen

In [28]:
# Modify the query to use foreachBatch
query = parsed_df.writeStream \
    .outputMode("append") \
    .foreachBatch(process_record_with_params) \
    .start()

query.awaitTermination()

24/12/06 20:42:48 WARN ResolveWriteToStream: Temporary checkpoint location created which is deleted normally when the query didn't fail: /private/var/folders/fn/0_yk91x94834b2xtx5nmcsqh0000gn/T/temporary-b9c3ab90-542e-41ba-9ae6-5bcc7475c1ed. If it's required to delete it under any circumstances, please set spark.sql.streaming.forceDeleteTempCheckpointLocation to true. Important to know deleting temp checkpoint folder is best effort.
24/12/06 20:42:48 WARN ResolveWriteToStream: spark.sql.adaptive.enabled is not supported in streaming DataFrames/Datasets and will be disabled.
24/12/06 20:42:48 WARN KafkaSourceProvider: Kafka option 'kafka.group.id' has been set on this query, it is
 not recommended to set this option. This option is unsafe to use since multiple concurrent
 queries or sources using the same group id will interfere with each other as they are part
 of the same consumer group. Restarted queries may also suffer interference from the
 previous run having the same group id. Th

KeyboardInterrupt: 

Now we stop the query to check the predictions

In [29]:
query.stop()

In [30]:
predictions = prediction_records_holder[0]
print(f"Number of predictions: {predictions.count()}")
predictions.show()

24/12/06 20:46:41 WARN DAGScheduler: Broadcasting large task binary with size 1827.2 KiB
                                                                                

Number of predictions: 103


24/12/06 20:46:48 WARN DAGScheduler: Broadcasting large task binary with size 7.5 MiB
24/12/06 20:46:48 WARN DAGScheduler: Broadcasting large task binary with size 7.5 MiB
24/12/06 20:46:49 WARN DAGScheduler: Broadcasting large task binary with size 7.5 MiB
24/12/06 20:46:51 WARN DAGScheduler: Broadcasting large task binary with size 7.5 MiB

+-------+------------------+-------------------+------------------+-------------------+------------------+------------------+--------------------+--------------------+------------------+-------------------+-------------------+------------------+------------------+------------------+-------------------+-------------------+--------------------+-------------------+------------------+--------------------+--------------------+-------------------+-------------------+-------------------+-------------------+-------------------+--------------------+--------------------+------+-----+--------------------+--------------------+--------------------+--------------------+----------+
|   Time|                V1|                 V2|                V3|                 V4|                V5|                V6|                  V7|                  V8|                V9|                V10|                V11|               V12|               V13|               V14|                V15|                V16| 

                                                                                

# Evaluating the model

## Converting the columns _Class_
Since we streamed our data from Kafka, the column _Class_ is of type StringType. We need to convert it to IntegerType in order to evaluate the model

In [15]:
from pyspark.sql.functions import col
from pyspark.sql.types import IntegerType

predictions_updated = predictions.withColumn("Class", col("Class").cast(IntegerType()))
predictions_updated.printSchema()

root
 |-- Time: double (nullable = true)
 |-- V1: double (nullable = true)
 |-- V2: double (nullable = true)
 |-- V3: double (nullable = true)
 |-- V4: double (nullable = true)
 |-- V5: double (nullable = true)
 |-- V6: double (nullable = true)
 |-- V7: double (nullable = true)
 |-- V8: double (nullable = true)
 |-- V9: double (nullable = true)
 |-- V10: double (nullable = true)
 |-- V11: double (nullable = true)
 |-- V12: double (nullable = true)
 |-- V13: double (nullable = true)
 |-- V14: double (nullable = true)
 |-- V15: double (nullable = true)
 |-- V16: double (nullable = true)
 |-- V17: double (nullable = true)
 |-- V18: double (nullable = true)
 |-- V19: double (nullable = true)
 |-- V20: double (nullable = true)
 |-- V21: double (nullable = true)
 |-- V22: double (nullable = true)
 |-- V23: double (nullable = true)
 |-- V24: double (nullable = true)
 |-- V25: double (nullable = true)
 |-- V26: double (nullable = true)
 |-- V27: double (nullable = true)
 |-- V28: double (nulla

In [16]:
from pyspark.ml.evaluation import BinaryClassificationEvaluator

evaluator = BinaryClassificationEvaluator(labelCol="Class")
print(f"Area under ROC: {evaluator.evaluate(predictions_updated)}")
logger.info(f"Area under ROC: {evaluator.evaluate(predictions_updated)}")

24/12/06 20:35:52 WARN DAGScheduler: Broadcasting large task binary with size 5.7 MiB
                                                                                

Area under ROC: 0.986046511627907


24/12/06 20:36:00 WARN DAGScheduler: Broadcasting large task binary with size 5.7 MiB
INFO:KafkaConsumer1:Area under ROC: 0.986046511627907                           
