In [3]:
from pyspark.sql import SparkSession, Row
from pyspark.sql.functions import col, from_json, udf, pandas_udf, PandasUDFType
import pyspark.sql.functions as F
from pyspark.sql.window import Window
from pyspark.conf import SparkConf
from datetime import datetime, timezone, timedelta
from pyspark.sql.types import StringType, IntegerType, TimestampType, StructType, StructField, DoubleType, ArrayType
import os
import psycopg2

In [4]:
# Debug switch, read from the APP_DEBUG_MODE environment variable (enabled when unset).
# When true, the Spark session runs on a local master instead of the cluster and every
# sink cell blocks on its own streaming query until it terminates or is interrupted.
DEBUG_MODE = os.getenv('APP_DEBUG_MODE', 'True').lower() in {'1', 't', 'true'}

In [5]:
# PostgreSQL connection settings. Each value is taken from an APP_* environment variable
# and falls back to the literal shown here when the variable is not set.
# NOTE(review): the fallbacks embed a plaintext password in the notebook; prefer supplying
# APP_CONN_STRING / APP_JDBC_PASSWORD from the environment and dropping these defaults.
db_conn_string = os.getenv(
    'APP_CONN_STRING',
    "postgresql://postgres/fastid-report?user=postgres&password=tappino2025",
)
db_jdbc_url = os.getenv('APP_JDBC_URL', "jdbc:postgresql://postgres:5432/fastid-report")
# Properties handed to every DataFrameWriter.jdbc(...) call in this notebook.
db_jdbc_properties = dict(
    user=os.getenv('APP_JDBC_USERNAME', "postgres"),
    password=os.getenv('APP_JDBC_PASSWORD', "tappino2025"),
    driver=os.getenv('APP_JDBC_DRIVER', "org.postgresql.Driver"),
)

In [6]:
        #.master("spark://spark:7077") \
# Build (or reuse) the Spark session for the report job. The master is a local one in
# debug mode and the standalone cluster otherwise; the Kafka source, PostgreSQL JDBC driver
# and Elasticsearch connector are pulled in through spark.jars.packages.
spark = (
    SparkSession.builder
    .appName("FastIDReport")
    .master("local[*]" if DEBUG_MODE else "spark://spark:7077")
    .config("spark.jars.ivy", "/tmp")
    .config("spark.sql.shuffle.partitions", 10)
    .config("spark.local.dir", "/tmp")
    .config(
        'spark.jars.packages',
        ",".join([
            'org.apache.spark:spark-sql-kafka-0-10_2.12:3.4.3',
            'org.postgresql:postgresql:42.7.4',
            'org.elasticsearch:elasticsearch-spark-30_2.12:8.17.1',
        ]),
    )
    .getOrCreate()
)

# Currently disabled builder option:
#   .config("spark.sql.streaming.checkpointLocation", "/tmp")
spark.sparkContext.setLogLevel("INFO")
#spark.sparkContext.setCheckpointDir("/home/jovyan/work/spark-checkpoints")

In [7]:
# Schema used to parse the JSON payload of each Kafka record on the "transactions" topic.
# Every field is nullable; the (name, type) pairs below are listed in schema order.
TRANSACTION_SCHEMA = StructType([
    StructField(field_name, field_type(), True)
    for field_name, field_type in (
        ("ID", IntegerType),
        ("InsertTimestamp", TimestampType),
        ("Log_Id", IntegerType),
        ("InstanceCode", StringType),
        ("InstanceType", StringType),
        ("User_Id", IntegerType),
        ("User_Name", StringType),
        ("Timestamp_creation", TimestampType),
        ("Timestamp_discovery", TimestampType),
        ("Timestamp_profile", TimestampType),
        ("ApplicationId", IntegerType),
        ("Success", StringType),
        ("Type", StringType),
        ("VersionApi", StringType),
        ("ServiceName", StringType),
        ("AttrSet", StringType),
        ("AttrClass", StringType),
        ("AuthnRequest_IssueInstant", TimestampType),
        ("AuthnRequest_Issuer", StringType),
        ("Response_IssueInstant", TimestampType),
        ("Response_Assertion_AuthnContextClassRef", StringType),
        ("Response_Assertion_AttributeStatement", StringType),
        ("Response_Issuer", StringType),
        ("uniqueIdpIdentity", StringType),
        ("Transaction_Id", IntegerType),
    )
])

In [None]:
# Disabled experiment, kept as a bare string so nothing below executes: a one-off *batch*
# read of the "transactions" topic, parsed with TRANSACTION_SCHEMA, registered as a temp
# view and counted with SQL. The streaming read in the next cell is what the job uses.
'''
df_transactions_raw = spark.read \
    .format("kafka") \
    .option("kafka.bootstrap.servers", "kafkaServer:9092") \
    .option("subscribe", "transactions") \
    .load() \
    .selectExpr("CAST(value AS STRING)") \
    .select(from_json(col("value"), TRANSACTION_SCHEMA).alias("data")) \
    .select("data.*")
    
df_transactions_raw.createOrReplaceTempView("transactions")

df_transactions = spark.sql("SELECT COUNT(*) FROM transactions")

df_transactions.show()
'''

In [9]:
# Streaming source: subscribe to the "transactions" Kafka topic.
# NOTE(review): the starting offset is pinned to 724000 on partition 0 — a magic number
# that presumably skips already-processed history; confirm and move it to configuration.
df_transactions_raw = (
    spark.readStream
    .format("kafka")
    .option("kafka.bootstrap.servers", "kafkaServer:9092")
    .option("subscribe", "transactions")
    .option("startingOffsets", '{"transactions":{"0":724000}}')
    .load()
)

In [None]:
# Decode each Kafka record: cast the binary value to a string, parse it as JSON with
# TRANSACTION_SCHEMA and flatten the resulting struct into top-level columns.
_transaction_struct = from_json(col("value"), TRANSACTION_SCHEMA).alias("data")
df_transactions = (
    df_transactions_raw
    .selectExpr("CAST(value AS STRING)")
    .select(_transaction_struct)
    .select("data.*")
)

In [11]:
# Repartition the stream by User_Name before the per-user processing in the cells below.
df_transactions = df_transactions.repartition("User_Name")

In [None]:
# Event time: use Response_IssueInstant when its year is after 2000, otherwise fall back
# to Timestamp_creation (presumably earlier years mark unset/placeholder values — confirm).
_event_time = (
    F.when(F.year("Response_IssueInstant") > 2000, F.col("Response_IssueInstant"))
    .otherwise(F.col("Timestamp_creation"))
)
# First day of the event's month; trunc() yields a date, which is cast to a timestamp below.
_month_start = F.trunc(F.col("Date"), "MM")

df_transactions = (
    df_transactions
    .withColumn("Date", _event_time)
    .withColumn("DateYearMonth", _month_start)
    .withColumn('DateYearMonth', F.col("DateYearMonth").cast(TimestampType()))
    # The watermark is declared on "Date" (an alternative on "DateYearMonth" is disabled).
    .withWatermark("Date", "10 seconds")
    .withColumn("DateYear", F.year("DateYearMonth"))
    .withColumn("DateMonth", F.month("DateYearMonth"))
)

# Clienti

In [70]:
# Distinct customer names seen on the stream, exposed as a single "Cliente" column.
df_clienti = (
    df_transactions
    .select(F.col("User_Name").alias("Cliente"))
    .distinct()
)

#df_clienti.writeStream.outputMode("update").foreachBatch(lambda df, _: save_clienti(df)).start()

In [None]:
def _append_clienti_batch(batch_df, _batch_id):
    """Append one micro-batch of customer names to the "Clienti" table over JDBC."""
    batch_df.write.jdbc(url=db_jdbc_url, table='"Clienti"', mode="append", properties=db_jdbc_properties)


# Start the streaming query that persists newly seen customers.
query = (
    df_clienti.writeStream
    .outputMode("append")
    .foreachBatch(_append_clienti_batch)
    .start()
)

# In debug mode block on this query; Ctrl-C / kernel interrupt stops it.
if DEBUG_MODE:
    try:
        query.awaitTermination()
    except KeyboardInterrupt:
        query.stop()

# IdP

In [10]:
# Distinct (Response_Issuer, prefix) pairs, where prefix is the first four characters of
# uniqueIdpIdentity. Rows whose prefix is "UII_" or "TENV", or that have no
# Response_Issuer, are excluded.
_idp_prefix = F.substring(F.col("uniqueIdpIdentity"), 1, 4)
_is_wanted_prefix = ~F.col("prefix").isin("UII_", "TENV")
_has_issuer = F.col("Response_Issuer").isNotNull()

df_idps = (
    df_transactions
    .withColumn("prefix", _idp_prefix)
    .filter(_is_wanted_prefix & _has_issuer)
    .select("Response_Issuer", "prefix")
    .distinct()
)

In [25]:
def _append_idp_batch(batch_df, _batch_id):
    """Append one micro-batch of identity-provider rows to the "IdP" table over JDBC."""
    batch_df.write.jdbc(url=db_jdbc_url, table='"IdP"', mode="append", properties=db_jdbc_properties)


# Start the streaming query that persists newly seen identity providers.
query = (
    df_idps.writeStream
    .outputMode("append")
    .foreachBatch(_append_idp_batch)
    .start()
)

# In debug mode block on this query; Ctrl-C / kernel interrupt stops it.
if DEBUG_MODE:
    try:
        query.awaitTermination()
    except KeyboardInterrupt:
        query.stop()

# Count by Type

In [None]:
# Checkpoint directory dedicated to this query. It used to be "/tmp/", i.e. the very same
# directory as the interval-count Kafka sink's checkpoint and Spark's local/ivy dir; two
# streaming queries must not share a checkpoint location. The root can be overridden with
# APP_CHECKPOINT_DIR.
# NOTE(review): on the cluster this should presumably point at storage shared by all nodes.
_count_by_type_checkpoint = os.path.join(
    os.environ.get('APP_CHECKPOINT_DIR', "/tmp/spark-checkpoints"),
    "count-by-type",
)

# Running count of transactions per (User_Name, Type), upserted into the Elasticsearch
# resource "count-by-type" using "<User_Name>_<Type>" as the document id.
# NOTE(review): concat() yields NULL when User_Name or Type is NULL, leaving such rows
# without a document id — confirm whether those rows can occur.
query = (
    df_transactions
    .groupBy("User_Name", "Type")
    .agg(F.count("*").alias("Count"))
    .withColumn("Id", F.concat(F.col("User_Name"), F.lit("_"), F.col("Type")))
    .writeStream
    .outputMode("update")
    .option("checkpointLocation", _count_by_type_checkpoint)
    .option("es.nodes", "elasticsearch")
    .option("es.mapping.id", "Id")
    .option("es.port", "9200")
    .format("es")
    .start("count-by-type")
)

# In debug mode block on this query; Ctrl-C / kernel interrupt stops it.
if DEBUG_MODE:
    try:
        query.awaitTermination()
    except KeyboardInterrupt:
        query.stop()

# Interval Count

In [None]:
# Transactions per user per 5-minute event-time window; "Date" is the window start.
_five_minute_window = F.window("Date", "5 minute")

df_interval_count = (
    df_transactions
    .groupBy("User_Name", _five_minute_window)
    .agg(F.count("*").alias("TransactionsCount"))
    .select(F.col("window.start").alias("Date"), "User_Name", "TransactionsCount")
)

# df_interval_count.printSchema()

In [None]:
def append_interval_count(df):
    """Append one micro-batch of per-window counts to the "TransactionsCount" table.

    Args:
        df: the micro-batch DataFrame handed over by ``foreachBatch``.
    """
    batch_writer = df.write
    batch_writer.jdbc(
        url=db_jdbc_url,
        table='"TransactionsCount"',
        mode="append",
        properties=db_jdbc_properties,
    )

    # Disabled earlier draft, kept for reference: collect the batch on the driver and
    # upsert it through psycopg2 instead of appending over JDBC.
    #
    #   transactions_df = df.collect()
    #   transactions = [tuple(row) for row in transactions_df]
    #
    #   connection = psycopg2.connect(db_conn_string)
    #   cursor = connection.cursor()
    #   insert_query = """
    #       INSERT INTO "DailyTransactions" ("User_Name", "Date", "TransactionsCount")
    #       VALUES %s
    #       ON CONFLICT ("User_Name", "Date")
    #       DO UPDATE SET "TransactionsCount" = EXCLUDED."TransactionsCount";
    #   """
    #
    #   execute_values(cursor, insert_query, transactions)
    #
    #   connection.commit()

# Persist every finalized 5-minute window through append_interval_count.
query = (
    df_interval_count.writeStream
    .outputMode("append")
    .foreachBatch(lambda batch_df, _batch_id: append_interval_count(batch_df))
    .start()
)

# In debug mode block on this query; Ctrl-C / kernel interrupt stops it.
if DEBUG_MODE:
    try:
        query.awaitTermination()
    except KeyboardInterrupt:
        query.stop()

In [None]:
# Checkpoint directory dedicated to this query. It used to be "/tmp", i.e. the same
# directory as the count-by-type sink's checkpoint ("/tmp/") and Spark's local/ivy dir;
# two streaming queries must not share a checkpoint location. The root can be overridden
# with APP_CHECKPOINT_DIR.
# NOTE(review): on the cluster this should presumably point at storage shared by all nodes.
_interval_count_checkpoint = os.path.join(
    os.environ.get('APP_CHECKPOINT_DIR', "/tmp/spark-checkpoints"),
    "customer-transactions-count",
)

# Publish every finalized 5-minute window to the "customer-transactions-count" Kafka topic:
# the key is "<User_Name>_<Date>", the value is the row serialized as JSON.
query = (
    df_interval_count
    .withColumn("key", F.concat(F.col("User_Name"), F.lit("_"), F.col("Date")))
    .select(F.col("key"), F.to_json(F.struct("Date", "User_Name", "TransactionsCount")).alias("value"))
    .writeStream
    .outputMode("append")
    .format("kafka")
    .option("kafka.bootstrap.servers", "kafkaServer:9092")
    .option("checkpointLocation", _interval_count_checkpoint)
    .option("topic", "customer-transactions-count")
    .start()
)

# In debug mode block on this query; Ctrl-C / kernel interrupt stops it.
if DEBUG_MODE:
    try:
        query.awaitTermination()
    except KeyboardInterrupt:
        query.stop()

# Online Training

In [10]:
# Streaming source for the online-training stage: read the "customer-transactions-count"
# topic from its earliest available offset.
df_transactions_count_raw = (
    spark.readStream
    .format("kafka")
    .option("kafka.bootstrap.servers", "kafkaServer:9092")
    .option("subscribe", "customer-transactions-count")
    .option("startingOffsets", "earliest")
    .load()
)

In [11]:
# Decode the JSON value of each record into Date / User_Name / TransactionsCount columns.
_count_struct = from_json(
    col("value"),
    "Date TIMESTAMP, User_Name STRING, TransactionsCount INT",
).alias("data")

df_transactions_count = (
    df_transactions_count_raw
    .selectExpr("CAST(value AS STRING)")
    .select(_count_struct)
    .select("data.*")
)

In [None]:
# Disabled debugging aid, kept as a bare string so nothing below executes: dump the parsed
# count stream to the console in update mode and block until the query terminates.
'''
df_transactions_count.writeStream \
    .format("console") \
    .outputMode("update") \
    .start() \
    .awaitTermination()
#    .trigger(processingTime='1 hour') \
'''

In [None]:
# Show the schema of the parsed count stream that feeds the online-training query.
# (This used to reference `df_transactions_daily`, a name that is never defined in the
# notebook and therefore raised NameError on a fresh Run All.)
df_transactions_count.printSchema()

In [None]:
from pyspark.sql.streaming.state import GroupStateTimeout
from sklearn.preprocessing import PolynomialFeatures
from sklearn.ensemble import RandomForestRegressor
from sklearn.pipeline import Pipeline
import numpy as np
import pandas as pd

def update_user_state(key, data, state):
    """Per-user stateful callback passed to ``applyInPandasWithState``.

    For every incoming row, a model is re-trained on the user's stored history (once that
    history holds more than one entry) and used to predict the count at the row's date and,
    when available, at the previously processed date. The row is then added to the history,
    which is trimmed to the 30 days preceding the row's date.

    Args:
        key: grouping key of the current group (not used here).
        data: iterable of pandas DataFrames with ``User_Name``, ``Date`` and
            ``TransactionsCount`` columns.
        state: group-state handle exposing ``hasTimedOut``, ``exists``, ``get``,
            ``update`` and ``remove``. The stored value is a pair
            ``({date: count}, last_date)``.

    Yields:
        One DataFrame per input batch with the columns ``User_Name``, ``Date``,
        ``PredictedTransactionsCount`` and ``IsAfterRetrain``; it is empty while the
        history is too short to train on.
    """
    output_columns = ["User_Name", "Date", "PredictedTransactionsCount", "IsAfterRetrain"]

    def train_model_and_predict(transactions_count, toPredictDates):
        """Fit a fresh model on ``transactions_count`` and predict for ``toPredictDates``."""

        def extract_features(dates):
            # Calendar features derived from each date: epoch seconds, weekday, weekend flag.
            return pd.DataFrame({
                "timestamp": [int(datetime.timestamp(x)) for x in dates],
                "weekday": [x.weekday() for x in dates],
                "is_weekend": [1 if x.weekday() >= 5 else 0 for x in dates]
            })

        X = extract_features(transactions_count["Date"])
        y = transactions_count["TransactionsCount"].values

        model = Pipeline([
            ("poly", PolynomialFeatures(degree=2)),
            ("rf", RandomForestRegressor(n_estimators=100, random_state=42))
        ])
        model.fit(X, y)

        X_new = extract_features(toPredictDates)
        return model.predict(X_new)

    if state.hasTimedOut:
        # The query is configured without a timeout, so this branch is not expected to run.
        state.remove()
        return

    for batch in data:

        if state.exists:
            (past_transactions, last_date) = state.get
            print("got state")
            # The stored history is a {date: count} mapping. DataFrame.from_dict() on such
            # a dict of scalars raises ValueError (and would not produce these two columns
            # anyway), so rebuild the frame column by column.
            past_transactions = past_transactions or {}
            df_past_transactions = pd.DataFrame({
                "Date": list(past_transactions.keys()),
                "TransactionsCount": list(past_transactions.values()),
            })
        else:
            df_past_transactions = pd.DataFrame([], columns=["Date", "TransactionsCount"])
            print("brand new state")
            last_date = None

        print("df_past_transactions", df_past_transactions)

        # Collect output rows in a list and build the frame once per batch, instead of
        # growing a DataFrame with pd.concat inside the loop.
        records = []
        for index, row in batch.iterrows():

            if len(df_past_transactions) > 1:
                predicted_counts = train_model_and_predict(
                    df_past_transactions,
                    [row["Date"]] if last_date is None else [row["Date"], last_date])

                # The regressor returns floats while the query declares this column as
                # INT, so round to the nearest whole count explicitly.
                records.append({
                    "User_Name": row["User_Name"],
                    "Date": row["Date"],
                    "PredictedTransactionsCount": int(round(predicted_counts[0])),
                    "IsAfterRetrain": False
                })
                if last_date is not None:
                    records.append({
                        "User_Name": row["User_Name"],
                        "Date": last_date,
                        "PredictedTransactionsCount": int(round(predicted_counts[1])),
                        "IsAfterRetrain": True
                    })

            new_entry = pd.DataFrame.from_records([{
                "Date": row["Date"],
                "TransactionsCount": row["TransactionsCount"]
            }])
            if df_past_transactions.empty:
                df_past_transactions = new_entry
            else:
                df_past_transactions = pd.concat([df_past_transactions, new_entry], ignore_index=True)

            # Keep only the 30 days leading up to the row that was just processed.
            df_past_transactions = df_past_transactions[
                df_past_transactions["Date"] > row["Date"] - timedelta(days=30)]

            last_date = row["Date"] # KISS

        # Persist the whole retained history plus the last processed date.
        history = pd.Series(
            df_past_transactions['TransactionsCount'].values,
            index=df_past_transactions['Date']).to_dict()
        state.update((history, last_date))

        yield pd.DataFrame(records, columns=output_columns)

def _append_predictions_batch(batch_df, _batch_id):
    """Append one micro-batch of predictions to the "PredictedTransactionsCount" table."""
    batch_df.write.jdbc(url=db_jdbc_url, table='"PredictedTransactionsCount"', mode="append", properties=db_jdbc_properties)


# Per-user stateful prediction: each user's rows go through update_user_state, which keeps
# the user's recent history in group state (no timeout) and emits predicted counts.
_predicted_counts = (
    df_transactions_count
    .withWatermark("Date", "10 seconds")
    .groupBy("User_Name")
    .applyInPandasWithState(
        update_user_state,
        "User_Name STRING, Date TIMESTAMP, PredictedTransactionsCount INT, IsAfterRetrain BOOLEAN",
        "TransactionsCount MAP<TIMESTAMP, INT>, LastDate TIMESTAMP",
        "update",
        GroupStateTimeout.NoTimeout
    )
    .withColumn("created_at", F.current_timestamp())
)

query = (
    _predicted_counts.writeStream
    .outputMode("update")
    .foreachBatch(_append_predictions_batch)
    .start()
)
# Disabled alternative sink for debugging:
#   .writeStream.outputMode("update").format("console").start()

# In debug mode block on this query; Ctrl-C / kernel interrupt stops it.
if DEBUG_MODE:
    try:
        query.awaitTermination()
    except KeyboardInterrupt:
        query.stop()

In [None]:
# Outside debug mode none of the cells above blocks on its own query, so keep the driver
# alive here until any of the started streaming queries terminates.
if not DEBUG_MODE:
    spark.streams.awaitAnyTermination()