In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *
import math

In [2]:
# Create (or reuse) the notebook's Spark session.
# spark.jars.packages pulls in the Kafka connector needed by the "kafka"
# stream format used further down.
spark = (
    SparkSession.builder
    .appName("KafkaStreamReader")
    .config("spark.streaming.stopGracefullyOnShutdown", True)
    .config('spark.jars.packages', 'org.apache.spark:spark-sql-kafka-0-10_2.12:3.5.0')
    .config("spark.sql.shuffle.partitions", 4)
    .getOrCreate()
)

In [3]:
# Streaming source: subscribe to the credit_card_trans topic, starting from the
# earliest available offset, and do not fail the query when data has been lost.
kafka_df = (
    spark.readStream
    .format("kafka")
    .option("kafka.bootstrap.servers", "kafka_v2:9092")
    .option("subscribe", "credit_card_trans")
    .option("startingOffsets", "earliest")
    .option("failOnDataLoss", "false")
    .load()
)

In [4]:
# Show the columns exposed by the Kafka source before any parsing is applied.
kafka_df.printSchema()

root
 |-- key: binary (nullable = true)
 |-- value: binary (nullable = true)
 |-- topic: string (nullable = true)
 |-- partition: integer (nullable = true)
 |-- offset: long (nullable = true)
 |-- timestamp: timestamp (nullable = true)
 |-- timestampType: integer (nullable = true)



In [5]:
# The Kafka payload arrives as binary; keep only the value column, cast to a
# string, under the name json_str.
value_df = kafka_df.select(
    col("value").cast("string").alias("json_str")
)

In [6]:
# JSON schema of one transaction message. Every field is read as a nullable
# string; conversion to numeric/date types is done after parsing.
# The first field has an empty name and is renamed once the JSON is flattened.
schema = StructType([
    StructField(field_name, StringType(), True)
    for field_name in (
        "",
        "trans_date_trans_time",
        "cc_num",
        "merchant",
        "category",
        "amt",
        "first",
        "last",
        "gender",
        "street",
        "city",
        "state",
        "zip",
        "lat",
        "long",
        "city_pop",
        "job",
        "dob",
        "trans_num",
        "unix_time",
        "merch_lat",
        "merch_long",
        "is_fraud",
        "event_time",
    )
])

In [7]:
# Parse each JSON string against `schema`, then flatten the resulting struct so
# that every field becomes a top-level column.
parsed_df = (
    value_df
    .select(from_json(col("json_str"), schema).alias("data"))
    .select("data.*")
)

In [8]:
# The first schema field has an empty name; rename it to "indx" so the column
# can be referenced by name.
parsed_df = parsed_df.withColumnRenamed("", "indx")

In [9]:
# Convert the string fields parsed from JSON into their working types.
# Each withColumn replaces the column in place, so column order is unchanged.
# (merch_long used to be cast twice; the redundant second cast was removed.)
parsed_df = (
    parsed_df
    # Integer-like fields.
    .withColumn("is_fraud", col("is_fraud").cast("int"))
    .withColumn("unix_time", col("unix_time").cast("long"))
    .withColumn("city_pop", col("city_pop").cast("int"))
    # NOTE(review): an integer zip loses leading zeros (e.g. "02134" -> 2134);
    # confirm this is acceptable for the downstream table.
    .withColumn("zip", col("zip").cast("int"))
    # Coordinates and transaction amount.
    .withColumn("lat", col("lat").cast("double"))
    .withColumn("long", col("long").cast("double"))
    .withColumn("merch_lat", col("merch_lat").cast("double"))
    .withColumn("merch_long", col("merch_long").cast("double"))
    .withColumn("amt", col("amt").cast("double"))
    # Dates and timestamps, parsed with explicit patterns.
    .withColumn("dob", to_date("dob", "yyyy-MM-dd"))
    .withColumn("trans_date_trans_time", to_timestamp("trans_date_trans_time", "yyyy-MM-dd HH:mm:ss"))
    .withColumn("event_time", to_timestamp("event_time", "yyyy-MM-dd HH:mm:ss"))
)

In [10]:
# Age in whole years as of the current date.
# months_between / 12 follows calendar months, so the age changes on the
# birthday itself. Dividing a day count by 365 ignores leap days and, for an
# adult, flips the age several days too early.
parsed_df = (
    parsed_df
    # cast("int") truncates the fractional year.
    .withColumn("age", (months_between(current_date(), col("dob")) / 12).cast("int"))
    # A null dob yields a null age; store 0 instead.
    .fillna({"age": 0})
    # `abs` is the pyspark.sql.functions version (star import), applied to the
    # column so that a dob in the future does not produce a negative age.
    .withColumn("age", abs(col("age")))
)
                    

In [11]:
# Clean the transaction amount: a missing amt becomes 0, and the sign is
# dropped so the column only holds non-negative values.
parsed_df = (
    parsed_df
    .fillna({"amt": 0})
    .withColumn("amt", abs(col("amt")))
)


In [12]:
# Remove the literal "fraud_" prefix when it appears at the start of the
# merchant name; other occurrences are left alone because of the ^ anchor.
parsed_df = parsed_df.withColumn("merchant", regexp_replace("merchant", "^fraud_", ""))

In [13]:
# Expand the gender code: "M" -> "Male", "F" -> "Female".
# There is no otherwise() branch, so any other value (including null) ends up
# as null.
parsed_df = parsed_df.withColumn(
    "gender",
    when(col("gender") == "M", "Male")
    .when(col("gender") == "F", "Female"),
)

In [14]:
# Free-text columns whose leading/trailing whitespace is stripped.
cols_to_trim = [
    "first", "last", "job", "merchant",
    "category", "street", "city", "state",
]

# Replace each listed column with its trimmed version.
for text_col in cols_to_trim:
    parsed_df = parsed_df.withColumn(text_col, trim(col(text_col)))

In [15]:
def haversine(lat, long, merch_lat, merch_long):
    """Return the great-circle distance in kilometres between two points.

    Uses the haversine formula with an Earth radius of 6371 km.

    Args:
        lat: Latitude of the first point, in degrees.
        long: Longitude of the first point, in degrees.
        merch_lat: Latitude of the second point, in degrees.
        merch_long: Longitude of the second point, in degrees.

    Returns:
        The distance in kilometres as a float, or None when any coordinate is
        None (which is how a null column value reaches a Spark UDF).
    """
    # A null coordinate would make math.radians raise TypeError and fail the
    # whole job; propagate the null instead.
    if lat is None or long is None or merch_lat is None or merch_long is None:
        return None

    R = 6371.0  # Earth radius in kilometers

    # Convert degrees to radians
    lat1_rad = math.radians(lat)
    lon1_rad = math.radians(long)
    lat2_rad = math.radians(merch_lat)
    lon2_rad = math.radians(merch_long)

    dlat = lat2_rad - lat1_rad
    dlon = lon2_rad - lon1_rad

    a = math.sin(dlat / 2) ** 2 + \
        math.cos(lat1_rad) * math.cos(lat2_rad) * \
        math.sin(dlon / 2) ** 2

    # Floating-point rounding can push `a` marginally outside [0, 1] (e.g. for
    # near-antipodal points), which would make math.sqrt(1 - a) raise.
    # Plain comparisons are used on purpose: in this notebook min/max are
    # shadowed by the pyspark.sql.functions star import.
    if a > 1.0:
        a = 1.0
    elif a < 0.0:
        a = 0.0

    c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))

    distance = R * c
    return distance


In [16]:
# Expose the Python haversine function to Spark as a UDF returning a double,
# then add distance_km: the distance between (lat, long) and
# (merch_lat, merch_long).
haversine_udf = udf(haversine, DoubleType())
parsed_df = parsed_df.withColumn(
    "distance_km",
    haversine_udf("lat", "long", "merch_lat", "merch_long"),
)

In [17]:
# Derive surrogate keys as SHA-256 hashes of "||"-joined attributes:
#   merchant_id <- merchant name + merchant coordinates
#   customer_id <- first name + last name + date of birth
parsed_df = (
    parsed_df
    .withColumn(
        "merchant_id",
        sha2(
            concat_ws(
                "||",
                col("merchant"),
                col("merch_lat").cast("string"),
                col("merch_long").cast("string"),
            ),
            256,
        ),
    )
    .withColumn(
        "customer_id",
        sha2(
            concat_ws(
                "||",
                col("first"),
                col("last"),
                col("dob").cast("string"),
            ),
            256,
        ),
    )
)
                        

In [18]:
# Bind the fully transformed stream to the name used by the sink cells below.
cleand_transformed_df = parsed_df

In [19]:
# Inspect the final column names and types before the stream is written out.
cleand_transformed_df.printSchema()

root
 |-- indx: string (nullable = true)
 |-- trans_date_trans_time: timestamp (nullable = true)
 |-- cc_num: string (nullable = true)
 |-- merchant: string (nullable = true)
 |-- category: string (nullable = true)
 |-- amt: double (nullable = false)
 |-- first: string (nullable = true)
 |-- last: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- street: string (nullable = true)
 |-- city: string (nullable = true)
 |-- state: string (nullable = true)
 |-- zip: integer (nullable = true)
 |-- lat: double (nullable = true)
 |-- long: double (nullable = true)
 |-- city_pop: integer (nullable = true)
 |-- job: string (nullable = true)
 |-- dob: date (nullable = true)
 |-- trans_num: string (nullable = true)
 |-- unix_time: long (nullable = true)
 |-- merch_lat: double (nullable = true)
 |-- merch_long: double (nullable = true)
 |-- is_fraud: integer (nullable = true)
 |-- event_time: timestamp (nullable = true)
 |-- age: integer (nullable = false)
 |-- distance_km: double 

In [20]:
def write_to_postgres(batch_df, batch_id):
    """Append one micro-batch to the Postgres ``transactions`` table over JDBC.

    Has the ``(batch_df, batch_id)`` signature expected by ``foreachBatch``.

    The JDBC user and password are taken from the ``POSTGRES_USER`` and
    ``POSTGRES_PASSWORD`` environment variables so real credentials do not
    have to live in the notebook. When a variable is not set, the previous
    built-in development value ("spark") is used.

    Args:
        batch_df: Micro-batch DataFrame to persist.
        batch_id: Micro-batch identifier supplied by ``foreachBatch``; unused.
    """
    import os  # local import: keeps this cell independent of the import cell

    jdbc_user = os.environ.get("POSTGRES_USER", "spark")
    jdbc_password = os.environ.get("POSTGRES_PASSWORD", "spark")

    batch_df.write \
        .format("jdbc") \
        .option("url", "jdbc:postgresql://postgres_v2:5432/sparkdb") \
        .option("dbtable", "transactions") \
        .option("user", jdbc_user) \
        .option("password", jdbc_password) \
        .option("driver", "org.postgresql.Driver") \
        .mode("append") \
        .save()

In [21]:
# Sink definition: hand every micro-batch to write_to_postgres in append mode.
# A fixed checkpoint directory lets a restarted query resume from the Kafka
# offsets it has already processed. Without one, Spark uses a throw-away
# checkpoint, so every start re-reads the topic from "earliest" and appends
# the same rows to Postgres again.
# Delete the directory to deliberately replay the topic from the beginning.
stream = (
    cleand_transformed_df.writeStream
    .foreachBatch(write_to_postgres)
    .outputMode("append")
    .option("checkpointLocation", "checkpoint_dir_kafka_postgres")
)

In [22]:
# Start the query described by the current `stream` writer; `q` is the handle
# used to stop it.
q = stream.start()

In [24]:
# Stop the streaming query held in `q`.
q.stop()

In [22]:
# Debug sink: print each micro-batch to the console in append mode.
# This rebinds `stream`, replacing any writer previously stored under that name.
# Optional settings the author left disabled:
#   .option("checkpointLocation", "checkpoint_dir_kafka")
#   .trigger(processingTime="20 seconds")
stream = (
    cleand_transformed_df.writeStream
    .format("console")
    .outputMode("append")
)

In [23]:
# Start the query described by the current `stream` writer; `qu` is the handle
# used to stop it.
qu = stream.start()

In [25]:
# Stop the streaming query held in `qu`.
qu.stop()