In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *
import math

In [None]:
# Create (or reuse) the SparkSession for this notebook. The Kafka connector is
# requested through spark.jars.packages and shuffle partitions are set to 4.
spark = (
    SparkSession.builder
    .appName("KafkaStreamReader")
    .config("spark.streaming.stopGracefullyOnShutdown", True)
    .config('spark.jars.packages', 'org.apache.spark:spark-sql-kafka-0-10_2.12:3.5.0')
    .config("spark.sql.shuffle.partitions", 4)
    .getOrCreate()
)

In [None]:
# Streaming source: subscribe to the credit_card_trans topic, starting from
# the earliest available offsets, without failing the query on data loss.
kafka_df = (
    spark.readStream
    .format("kafka")
    .option("kafka.bootstrap.servers", "kafka_v2:9092")
    .option("subscribe", "credit_card_trans")
    .option("startingOffsets", "earliest")
    .option("failOnDataLoss", "false")
    .load()
)

In [None]:
# Print the schema of the raw Kafka source DataFrame for inspection.
kafka_df.printSchema()

In [None]:
# The Kafka message value is binary, so cast it to a string column named
# json_str; that column is what gets parsed as JSON afterwards.
value_df = kafka_df.selectExpr("CAST(value AS STRING) as json_str")

In [None]:
# Field names of the JSON payload, in the order the schema declares them.
# Every field is read as a nullable string; typed casts are applied later.
# The first field name is the empty string and is renamed to "indx" later on.
schema_field_names = [
    "",
    "trans_date_trans_time",
    "cc_num",
    "merchant",
    "category",
    "amt",
    "first",
    "last",
    "gender",
    "street",
    "city",
    "state",
    "zip",
    "lat",
    "long",
    "city_pop",
    "job",
    "dob",
    "trans_num",
    "unix_time",
    "merch_lat",
    "merch_long",
    "is_fraud",
    "event_time",
]

schema = StructType(
    [StructField(field_name, StringType(), True) for field_name in schema_field_names]
)

In [None]:
# Parse each JSON string against `schema`, then flatten the resulting struct
# so every field becomes a top-level column.
parsed_payload = from_json(col("json_str"), schema).alias("data")
parsed_df = value_df.select(parsed_payload).select("data.*")

In [None]:
# Rename the empty-named first column to "indx", then overwrite it with the
# SHA-256 hex digest of indx and event_time joined by "||".
indx_digest = sha2(
    concat_ws("||", col("indx").cast("string"), col("event_time").cast("string")),
    256,
)
parsed_df = (
    parsed_df
    .withColumnRenamed("", "indx")
    .withColumn("indx", indx_digest)
)

In [None]:
# Convert the all-string columns produced by from_json into typed columns.
# Plain casts are table-driven so that each column is listed exactly once.
# NOTE(review): casting zip to int drops leading zeros (e.g. "02139" -> 2139);
# confirm that the target table really expects an integer here.
column_casts = {
    "is_fraud": "int",
    "merch_long": "double",
    "merch_lat": "double",
    "unix_time": "long",
    "city_pop": "int",
    "long": "double",
    "lat": "double",
    "amt": "double",
    "zip": "int",
}
for cast_column, target_type in column_casts.items():
    parsed_df = parsed_df.withColumn(cast_column, col(cast_column).cast(target_type))

# Date/timestamp columns need an explicit pattern rather than a plain cast.
parsed_df = (
    parsed_df
    .withColumn("dob", to_date("dob", "yyyy-MM-dd"))
    .withColumn("trans_date_trans_time", to_timestamp("trans_date_trans_time", "yyyy-MM-dd HH:mm:ss"))
    .withColumn("event_time", to_timestamp("event_time", "yyyy-MM-dd HH:mm:ss"))
)

In [None]:
# Derive the customer's age in whole years from dob.
# months_between / 12 counts calendar years, so the result does not drift
# around birthdays the way a fixed 365-day year does (leap days accumulate:
# dob 2000-01-01 evaluated on 2024-12-27 is 9127 days, and 9127 / 365 = 25
# even though the customer is still 24).
# The cast to int truncates toward zero; rows whose dob is null get age 0, and
# abs() keeps the value non-negative if dob lies in the future.
age_in_years = (months_between(current_date(), col("dob")) / 12).cast("int")
parsed_df = (
    parsed_df
    .withColumn("age", age_in_years)
    .fillna({"age": 0})
    .withColumn("age", abs(col("age")))
)
                    

In [None]:
# Replace missing transaction amounts with 0, then store the absolute value.
# Note: abs here is the Spark column function brought in by the star import.
amt_filled_df = parsed_df.fillna({"amt": 0})
parsed_df = amt_filled_df.withColumn("amt", abs(col("amt")))


In [None]:
# Drop a leading "fraud_" prefix from the merchant name, if present.
merchant_without_prefix = regexp_replace("merchant", "^fraud_", "")
parsed_df = parsed_df.withColumn("merchant", merchant_without_prefix)

In [None]:
# Expand the single-letter gender codes to full words. There is no otherwise()
# branch, so any value other than "M" or "F" becomes null.
gender_label = (
    when(col("gender") == "M", "Male")
    .when(col("gender") == "F", "Female")
)
parsed_df = parsed_df.withColumn("gender", gender_label)

In [None]:
# Strip leading and trailing whitespace from the free-text columns.
cols_to_trim = [
    "first",
    "last",
    "job",
    "merchant",
    "category",
    "street",
    "city",
    "state",
]

for text_column in cols_to_trim:
    parsed_df = parsed_df.withColumn(text_column, trim(col(text_column)))

In [None]:
def haversine(lat, long, merch_lat, merch_long):
    """Return the great-circle distance in kilometres between two points.

    Uses the haversine formula with an Earth radius of 6371 km.

    Args:
        lat: Latitude of the first point, in degrees.
        long: Longitude of the first point, in degrees.
        merch_lat: Latitude of the second point, in degrees.
        merch_long: Longitude of the second point, in degrees.

    Returns:
        The distance in kilometres as a float, or None when any coordinate is
        None (so a row with a missing coordinate yields a null distance
        instead of raising inside the UDF and failing the query).
    """
    # Explicit `is None` checks: several Python builtins are shadowed in this
    # notebook by the pyspark.sql.functions star import, so avoid them here.
    if lat is None or long is None or merch_lat is None or merch_long is None:
        return None

    R = 6371.0  # Earth radius in kilometers

    # Convert degrees to radians
    lat1_rad = math.radians(lat)
    lon1_rad = math.radians(long)
    lat2_rad = math.radians(merch_lat)
    lon2_rad = math.radians(merch_long)

    dlat = lat2_rad - lat1_rad
    dlon = lon2_rad - lon1_rad

    a = math.sin(dlat / 2) ** 2 + \
        math.cos(lat1_rad) * math.cos(lat2_rad) * \
        math.sin(dlon / 2) ** 2

    # Rounding error (or out-of-range latitudes) can push `a` marginally
    # outside [0, 1], which would make math.sqrt raise ValueError; clamp it.
    if a < 0.0:
        a = 0.0
    elif a > 1.0:
        a = 1.0

    c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))

    distance = R * c
    return distance


In [None]:
# Wrap haversine as a Spark UDF returning a double, and use it to add the
# customer-to-merchant distance as the distance_km column.
haversine_udf = udf(haversine, DoubleType())
distance_km_expr = haversine_udf("lat", "long", "merch_lat", "merch_long")
parsed_df = parsed_df.withColumn("distance_km", distance_km_expr)

In [None]:
# Surrogate keys as SHA-256 hex digests:
#   merchant_id - digest of the merchant name
#   customer_id - digest of first name, last name and dob joined by "||"
merchant_id_expr = sha2(concat_ws("||", col("merchant")), 256)
customer_id_expr = sha2(
    concat_ws("||", col("first"), col("last"), col("dob").cast("string")),
    256,
)
parsed_df = (
    parsed_df
    .withColumn("merchant_id", merchant_id_expr)
    .withColumn("customer_id", customer_id_expr)
)
                        

In [None]:
# Final name for the cleaned and enriched stream; this is the DataFrame that
# is inspected and written to the sink below.
cleand_transformed_df = parsed_df

In [None]:
# Print the final schema to check column names and types before writing.
cleand_transformed_df.printSchema()

In [None]:
def write_to_postgres(batch_df, batch_id):
    """Append one micro-batch to the Postgres ``transactions`` table via JDBC.

    Intended to be passed to ``writeStream.foreachBatch``. Connection settings
    can be overridden with the POSTGRES_JDBC_URL, POSTGRES_USER and
    POSTGRES_PASSWORD environment variables; when a variable is unset the
    local-development default is used, so existing setups keep working.

    NOTE(review): a plain append is not idempotent, so rows may be written
    twice if a micro-batch is retried; batch_id is available for
    de-duplication should that matter.

    Args:
        batch_df: DataFrame holding the rows of the current micro-batch.
        batch_id: Micro-batch identifier supplied by foreachBatch (unused).
    """
    import os  # local import so this cell does not depend on the import cell

    jdbc_url = os.environ.get(
        "POSTGRES_JDBC_URL", "jdbc:postgresql://postgres_v2:5432/sparkdb"
    )
    jdbc_user = os.environ.get("POSTGRES_USER", "spark")
    jdbc_password = os.environ.get("POSTGRES_PASSWORD", "spark")

    (
        batch_df.write
        .format("jdbc")
        .option("url", jdbc_url)
        .option("dbtable", "transactions")
        .option("user", jdbc_user)
        .option("password", jdbc_password)
        .option("driver", "org.postgresql.Driver")
        .mode("append")
        .save()
    )

In [None]:
# Start the streaming query that hands every micro-batch to write_to_postgres.
# The checkpoint location lets Spark persist the query's progress (including
# the Kafka offsets already processed), so a restarted query resumes where it
# stopped instead of re-reading the topic from the "earliest" offsets and
# appending the same rows to Postgres again.
stream = (
    cleand_transformed_df.writeStream
    .foreachBatch(write_to_postgres)
    .outputMode("append")
    .option("checkpointLocation", "checkpoint_dir_kafka")
    .start()
)
try:
    # Block this cell until the query terminates or fails.
    stream.awaitTermination()
except KeyboardInterrupt:
    # Interrupting the kernel lands here; stop the query explicitly.
    print("Gracefully stopping the stream...")
    stream.stop()

In [None]:
# stream = cleand_transformed_df.writeStream \
#                 .format("console") \
#                 .outputMode("append") 
#                 # .option("checkpointLocation", "checkpoint_dir_kafka")
#                 # .trigger(processingTime="20 seconds")
# qu = stream.start()
# qu.stop()