# In [None]:
import os
# Set before the pyspark import below so it is already present in the process
# environment when pyspark starts its gateway.
# NOTE(review): presumably required by this environment's PySpark/Py4J setup;
# confirm it is still needed and acceptable from a security standpoint.
os.environ["PYSPARK_ALLOW_INSECURE_GATEWAY"] = "1"


from pyspark.sql import SparkSession
from pyspark.sql.functions import from_json, col, window
from pyspark.sql.types import StructType, StringType, IntegerType, TimestampType


# Connection settings for the PostgreSQL sink. The URL and credentials can be
# overridden through environment variables so secrets do not have to live in
# the source; when a variable is unset the original value is used.
DB_URL = os.environ.get(
    "KAFKA_SINK_DB_URL", "jdbc:postgresql://postgres:5432/postgres"
)
DB_USER = os.environ.get("KAFKA_SINK_DB_USER", "myuser")
DB_PASS = os.environ.get("KAFKA_SINK_DB_PASS", "myuserpass")
# Target table for the aggregated rows.
DB_TABLE = "kafka_data_04"
# Filesystem locations handed to the Kafka reader and the streaming writer.
BAD_ROWS_PATH = "/tmp/bad_rows_04"
CHECKPOINT_PATH = "/tmp/checkpoints/kafka-to-pgsql_04"


# Create (or reuse) a SparkSession named "WindowedAggregation" that runs
# locally on all available cores.
# NOTE(review): no connector packages are configured here; the Kafka source
# and the PostgreSQL JDBC driver are presumably supplied by the environment.
spark = (
    SparkSession.builder
    .master("local[*]")
    .appName("WindowedAggregation")
    .getOrCreate()
)

# Schema of the JSON payload: integer id, string name and an event timestamp.
# StructType.add appends the field to the same object, so the schema is
# built up one statement at a time.
schema = StructType()
schema.add("id", IntegerType())
schema.add("name", StringType())
schema.add("timestamp", TimestampType())

# Open a streaming read on the Kafka topic, starting from the latest offsets.
# NOTE(review): "badRecordsPath" is passed to the Kafka source as in the
# original; confirm this source actually honours it, otherwise malformed
# records are not being captured at BAD_ROWS_PATH.
df = (
    spark.readStream
    .format("kafka")
    .options(**{
        "kafka.bootstrap.servers": "kafka:9092",
        "subscribe": "spark-lab4-topic",
        "startingOffsets": "latest",
        "badRecordsPath": BAD_ROWS_PATH,
    })
    .load()
)

# Decode the Kafka "value" field as a string, parse it as JSON against the
# schema, and flatten the resulting struct into top-level columns.
parsed = (
    df.select(
        from_json(col("value").cast("string"), schema).alias("data")
    )
    .select("data.*")
)

# Count rows per 10-second window on the event timestamp, with a 20-second
# watermark, and expose the window bounds as plain columns.
aggregated = (
    parsed
    .withWatermark("timestamp", "20 seconds")
    .groupBy(window("timestamp", "10 seconds"))
    .count()
    .select(
        col("window").getField("start").alias("window_start"),
        col("window").getField("end").alias("window_end"),
        "count",
    )
)

# Write one micro-batch to PostgreSQL
def write_to_postgres(batch_df, batch_id):
    """Append a micro-batch DataFrame to the PostgreSQL table over JDBC.

    Args:
        batch_df: DataFrame holding the rows of the current micro-batch.
        batch_id: Identifier of the micro-batch; accepted but not used.
    """
    # NOTE(review): this is a plain append that ignores batch_id; if a batch
    # is ever re-run, rows may presumably be written twice. Confirm whether
    # that matters for the target table.
    jdbc_options = {
        "url": DB_URL,
        "dbtable": DB_TABLE,
        "user": DB_USER,
        "password": DB_PASS,
        "driver": "org.postgresql.Driver",
    }
    batch_df.write.format("jdbc").options(**jdbc_options).mode("append").save()

# Start the streaming query: append output mode, checkpointing under
# CHECKPOINT_PATH, each micro-batch handed to write_to_postgres.
query = (
    aggregated.writeStream
    .outputMode("append")
    .option("checkpointLocation", CHECKPOINT_PATH)
    .foreachBatch(write_to_postgres)
    .start()
)

# Block the script until the streaming query terminates.
query.awaitTermination()
