In [1]:
# Load the python-dotenv IPython extension and read variables from a .env file
# into os.environ, so later cells can fetch S3/Postgres settings via os.getenv.
%load_ext dotenv
%dotenv

In [2]:
import logging
import os
from typing import Optional

from pyspark.sql import SparkSession

# Notebook logger; basicConfig routes INFO-and-above records to the root
# handler (it is a no-op if the root logger already has handlers).
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


class SimpleSparkSession:
    """Lazily build (and cache) a SparkSession for Jupyter notebooks.

    The first call to ``build_session``/``get_session`` applies these settings
    to ``SparkSession.builder``:

    * base options (Arrow-enabled pandas conversion, JVM security-manager flag),
    * optional Hive support,
    * an optional extra JDBC driver class path (``jdbc_driver_path``),
    * Maven ``packages`` (``spark.jars.packages``),
    * S3A credentials / endpoint / path-style options when both S3 keys are set,
    * finally ``spark_config``.

    ``spark_config`` is applied *last*, so its entries override anything this
    class derives itself. For example, pass
    ``{"spark.hadoop.fs.s3a.connection.ssl.enabled": "true"}`` to keep TLS on
    while still using path-style access.

    ``s3_bucket_name`` and ``postgres_config`` are only stored on the instance;
    ``build_session`` does not read them.
    """

    def __init__(
            self,
            app_name: str = "Jupyter Spark Session",
            master: str = "local[*]",
            spark_config: Optional[dict] = None,
            enable_hive_support: bool = False,
            # S3 configuration
            s3_bucket_name: Optional[str] = None,
            s3_endpoint: Optional[str] = None,
            s3_access_key: Optional[str] = None,
            s3_secret_key: Optional[str] = None,
            s3_region: str = "us-east-1",
            s3_path_style_access: bool = True,
            # PostgreSQL configuration
            postgres_config: Optional[dict] = None,
            # Package configuration
            packages: Optional[list] = None
    ):
        """Store the settings; no Spark work happens until ``build_session``.

        Args:
            app_name: Spark application name.
            master: Spark master URL.
            spark_config: Extra ``key -> value`` builder options; these win
                over every option derived from the other arguments.
            enable_hive_support: Call ``builder.enableHiveSupport()``.
            s3_bucket_name: Stored for the caller's convenience only.
            s3_endpoint: Custom S3A endpoint (non-AWS stores).
            s3_access_key: S3A access key; S3 options are only set when this
                and ``s3_secret_key`` are both truthy.
            s3_secret_key: S3A secret key.
            s3_region: Sent as ``fs.s3a.endpoint.region`` when an endpoint is set.
            s3_path_style_access: Enable path-style access (also disables
                S3A TLS and multi-object delete unless overridden).
            postgres_config: Stored for the caller's convenience only.
            packages: Maven coordinates joined into ``spark.jars.packages``.
        """
        self.app_name = app_name
        self.master = master
        self.spark_config = spark_config or {}
        self.enable_hive_support = enable_hive_support

        # S3 config
        self.s3_bucket_name = s3_bucket_name
        self.s3_endpoint = s3_endpoint
        self.s3_access_key = s3_access_key
        self.s3_secret_key = s3_secret_key
        self.s3_region = s3_region
        self.s3_path_style_access = s3_path_style_access

        # PostgreSQL config (not consumed by build_session)
        self.postgres_config = postgres_config
        # Never set by this class; assign a local jar path before building to
        # put it on the driver/executor class path.
        self.jdbc_driver_path: Optional[str] = None

        # Packages
        self.packages = packages or []

        # Cached session, created on first build_session() call.
        self._session = None

    def _s3_config(self) -> dict:
        """Return S3A builder options, or ``{}`` unless both S3 keys are set."""
        if not (self.s3_access_key and self.s3_secret_key):
            if self.s3_access_key or self.s3_secret_key:
                # Half-configured credentials are almost certainly a mistake
                # (e.g. a missing env var); say so instead of silently skipping.
                logger.warning("Only one of s3_access_key / s3_secret_key is set; skipping S3A configuration")
            return {}

        options = {
            "spark.hadoop.fs.s3a.access.key": self.s3_access_key,
            "spark.hadoop.fs.s3a.secret.key": self.s3_secret_key,
            "spark.hadoop.fs.s3a.aws.credentials.provider":
                "org.apache.hadoop.fs.s3a.SimpleAWSCredentialsProvider",
        }

        # Config for non-AWS S3
        if self.s3_endpoint:
            options["spark.hadoop.fs.s3a.endpoint"] = self.s3_endpoint
            options["spark.hadoop.fs.s3a.endpoint.region"] = self.s3_region

        # Path style access for non-AWS implementations
        if self.s3_path_style_access:
            options.update({
                "spark.hadoop.fs.s3a.path.style.access": "true",
                # Also turns TLS off; override via spark_config for HTTPS stores.
                "spark.hadoop.fs.s3a.connection.ssl.enabled": "false",
                "spark.hadoop.fs.s3a.impl": "org.apache.hadoop.fs.s3a.S3AFileSystem",
                "spark.hadoop.fs.s3a.multiobjectdelete.enable": "false",
            })
        return options

    def _collect_config(self) -> dict:
        """Assemble every builder option; later entries override earlier ones."""
        options = {
            "spark.sql.execution.arrow.pyspark.enabled": "true",
            # NOTE(review): presumably needed so libraries that install a JVM
            # SecurityManager keep working on newer JDKs -- confirm.
            "spark.driver.extraJavaOptions": "-Djava.security.manager=allow",
            "spark.executor.extraJavaOptions": "-Djava.security.manager=allow",
        }

        if self.jdbc_driver_path:
            options["spark.driver.extraClassPath"] = self.jdbc_driver_path
            options["spark.executor.extraClassPath"] = self.jdbc_driver_path

        if self.packages:
            options["spark.jars.packages"] = ",".join(self.packages)

        options.update(self._s3_config())

        # Explicit user options go last so they can override anything above.
        options.update(self.spark_config)
        return options

    def build_session(self):
        """Return the cached SparkSession, building it on the first call.

        Uses ``builder.getOrCreate()``; if another SparkSession already exists
        in this process Spark may reuse it, in which case some options here
        may not take effect.
        """
        if self._session is not None:
            return self._session

        builder = SparkSession.builder.appName(self.app_name).master(self.master)

        # Add Hive support if requested
        if self.enable_hive_support:
            builder = builder.enableHiveSupport()

        for key, value in self._collect_config().items():
            builder = builder.config(key, value)

        logger.info("Building Spark session with app name: %s, master: %s", self.app_name, self.master)
        self._session = builder.getOrCreate()
        return self._session

    def get_session(self):
        """Get the current SparkSession or create a new one (alias of build_session)."""
        return self.build_session()

    def stop_session(self):
        """Stop the cached Spark session if one exists and forget it."""
        if self._session is not None:
            self._session.stop()
            self._session = None
            logger.info("Spark session stopped")

In [3]:
# Notebook-wide SparkSession: Postgres JDBC driver plus the hadoop-aws / AWS SDK
# jars for S3A, with credentials pulled from the environment (.env).
spark = (
    SimpleSparkSession(
        app_name="Data Analysis Notebook",
        enable_hive_support=False,
        packages=[
            "org.postgresql:postgresql:42.5.4",
            "org.apache.hadoop:hadoop-aws:3.3.4",
            "com.amazonaws:aws-java-sdk-bundle:1.12.426",
        ],
        # NOTE(review): region "garage" suggests a Garage S3-compatible server -- confirm.
        s3_bucket_name="traffy-troffi",
        s3_endpoint=os.getenv("S3_ENDPOINT"),
        s3_access_key=os.getenv("S3_ACCESS_KEY"),
        s3_secret_key=os.getenv("S3_SECRET_KEY"),
        s3_region="garage",
        s3_path_style_access=True,
        postgres_config=dict(
            user=os.getenv("POSTGRES_USER"),
            password=os.getenv("POSTGRES_PASSWORD"),
            driver="org.postgresql.Driver",
            currentSchema="public",
        ),
    )
    .get_session()
)

INFO:__main__:Building Spark session with app name: Data Analysis Notebook, master: local[*]
25/05/08 19:59:23 WARN Utils: Your hostname, PatrickChoDevMacbook.local resolves to a loopback address: 127.0.0.1; using 10.201.246.45 instead (on interface en0)
25/05/08 19:59:23 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Ivy Default Cache set to: /Users/patrick/.ivy2/cache
The jars for the packages stored in: /Users/patrick/.ivy2/jars
org.postgresql#postgresql added as a dependency
org.apache.hadoop#hadoop-aws added as a dependency
com.amazonaws#aws-java-sdk-bundle added as a dependency
:: resolving dependencies :: org.apache.spark#spark-submit-parent-e61a6dca-8a5a-4030-91ac-77a21cb09f00;1.0
	confs: [default]
	found org.postgresql#postgresql;42.5.4 in central
	found org.checkerframework#checker-qual;3.5.0 in central
	found org.apache.hadoop#hadoop-aws;3.3.4 in central
	found org.wildfly.openssl#wildfly-openssl;1.0.7.Final in central
	found com.amazonaws#aws-java-sdk

:: loading settings :: url = jar:file:/Users/patrick/Desktop/Workspace/Projects/traffy-troffi/.venv/lib/python3.12/site-packages/pyspark/jars/ivy-2.5.1.jar!/org/apache/ivy/core/settings/ivysettings.xml


25/05/08 19:59:24 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


In [4]:
# Load the complaints table over JDBC. Credentials come from the environment
# (loaded by %dotenv) rather than being hard-coded in the notebook; a missing
# POSTGRES_PASSWORD raises KeyError instead of sending an empty/None password.
df = spark.read.jdbc(
    table="traffy_fondue",
    url="jdbc:postgresql://localhost:5432/traffy-troffi",
    properties={
        "user": os.getenv("POSTGRES_USER", "postgres"),
        "password": os.environ["POSTGRES_PASSWORD"],
        "driver": "org.postgresql.Driver",
        "currentSchema": "public",
    },
)
df.printSchema()

root
 |-- ticket_id: string (nullable = true)
 |-- complaint: string (nullable = true)
 |-- timestamp: timestamp (nullable = true)
 |-- image: string (nullable = true)
 |-- image_after: string (nullable = true)
 |-- latitude: double (nullable = true)
 |-- longitude: double (nullable = true)
 |-- district: string (nullable = true)
 |-- subdistrict: string (nullable = true)
 |-- categories: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- categories_idx: array (nullable = true)
 |    |-- element: float (containsNull = true)



In [5]:
import pyspark.sql.functions as F

In [6]:
# Preview road ("ถนน") complaints in Ratchathewi ("ราชเทวี") district.
(
    df
    .filter(F.col("district") == 'ราชเทวี')
    .filter(F.array_contains('categories', 'ถนน'))
    .show()
)

+-----------+--------------------+--------------------+--------------------+--------------------+--------+---------+--------+-----------+--------------------+-----------------+
|  ticket_id|           complaint|           timestamp|               image|         image_after|latitude|longitude|district|subdistrict|          categories|   categories_idx|
+-----------+--------------------+--------------------+--------------------+--------------------+--------+---------+--------+-----------+--------------------+-----------------+
|2024-UPZB93|รบกวนช่วยทำพื้นถน...|2024-09-09 11:44:...|https://storage.g...|https://storage.g...| 13.7548|100.54201| ราชเทวี|   ถนนพญาไท|               [ถนน]|            [9.0]|
|2024-KTGPQY|บริเวณนี้มีการเปิ...|2024-05-07 14:25:...|https://storage.g...|https://storage.g...|13.75871|100.54239| ราชเทวี|   ถนนพญาไท|               [ถนน]|            [9.0]|
|2024-K3PRE6|จุดนี้มืดและเปลี่...|2024-11-01 05:02:...|https://storage.g...|https://storage.g...|13.75672|100.52418

In [7]:
# Random ~10% of Ratchathewi road complaints. The fixed seed keeps the shown
# rows identical across Restart & Run All.
df.filter(
    (F.col("district") == 'ราชเทวี') & F.array_contains('categories', 'ถนน')
).sample(fraction=0.1, seed=42).show()

+-----------+--------------------+--------------------+--------------------+--------------------+--------+---------+--------+-----------+--------------------+--------------------+
|  ticket_id|           complaint|           timestamp|               image|         image_after|latitude|longitude|district|subdistrict|          categories|      categories_idx|
+-----------+--------------------+--------------------+--------------------+--------------------+--------+---------+--------+-----------+--------------------+--------------------+
|2024-628FCL|*ถนน พระราม6 มุ่ง...|2024-07-24 04:24:...|https://storage.g...|https://storage.g...|13.76074|100.52516| ราชเทวี|  ทุ่งพญาไท|               [ถนน]|               [9.0]|
|2024-DPZYG9|หลุมบ่อกลางถนนบริ...|2024-07-09 05:53:...|https://storage.g...|https://storage.g...|13.75215|100.52589| ราชเทวี|ถนนเพชรบุรี|               [ถนน]|               [9.0]|
|2024-N8B26R|ถนนชำรุดค่ะ ทำให้...|2024-05-23 07:32:...|https://storage.g...|https://storage.g...|13.

In [8]:
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Multi-label complaint-type classifier (presumably a Hugging Face Hub repo id).
checkpoint = "phor2547/final-dsde-type"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

model = AutoModelForSequenceClassification.from_pretrained(
    checkpoint, problem_type="multi_label_classification"
)
model = model.to(device)

Using device: cpu


In [9]:
import numpy as np
import torch
from pyspark.sql import functions as F

# Exploratory check: score a small, reproducible sample of Ratchathewi road
# complaints and list the predicted label indices for each complaint.
THRESHOLD = 0.5   # a label is predicted when sigmoid(logit) >= THRESHOLD
BATCH_SIZE = 32   # complaints per forward pass, bounds padding/memory

sample_complaints = [
    row.complaint
    for row in df.filter(
        (F.col("district") == 'ราชเทวี')
        & F.array_contains(F.col("categories"), 'ถนน')
        & F.col("complaint").isNotNull()  # the tokenizer cannot handle None
    ).sample(fraction=0.01, seed=42).select("complaint").collect()
]

all_predictions = []
with torch.no_grad():
    # Batched so a larger sample cannot blow up one giant padded tensor;
    # an empty sample simply yields no batches (tokenizer([]) would raise).
    for start in range(0, len(sample_complaints), BATCH_SIZE):
        inputs = tokenizer(
            sample_complaints[start:start + BATCH_SIZE],
            padding=True,
            truncation=True,
            return_tensors="pt",
        ).to(device)
        probs = torch.sigmoid(model(**inputs).logits).cpu().numpy()
        # Indices of the labels above threshold, one list per complaint.
        all_predictions.extend(np.flatnonzero(row).tolist() for row in probs >= THRESHOLD)

all_predictions

[[9],
 [9],
 [9],
 [5, 9],
 [5, 9],
 [9, 23],
 [7, 9, 10],
 [2, 9],
 [2, 7, 9],
 [5],
 [9],
 [2, 9],
 [9],
 [7, 9],
 [5, 9, 21, 23],
 [9],
 [5, 9],
 [9],
 [9],
 [9, 15],
 [2, 7, 9],
 [2, 7, 9],
 [9],
 [9],
 [9, 19],
 [9],
 [9],
 [9],
 [6, 9, 10],
 [5, 9]]

In [10]:
import numpy as np
import torch
from pyspark.sql import functions as F
from pyspark.sql.types import ArrayType, IntegerType

# Store the complaint rows to maintain order and for later joining
filtered_rows = df.filter(
    (F.col("district") == 'ราชเทวี') &
    (F.array_contains(F.col("categories"), 'ถนน'))
).sample(0.01)

# Keep track of row identifiers (assuming there's an ID column; if not, we'll create one)
row_ids = [row.ticket_id for row in
           filtered_rows.select("ticket_id").collect()]  # Adjust if your ID column has a different name

complaints = [row.complaint for row in filtered_rows.select("complaint").collect()]

inputs = tokenizer(
    complaints,
    padding=True,
    truncation=True,
    return_tensors="pt"
).to(device)

encoded_predictions = []
with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs.logits
    probs = torch.sigmoid(logits).cpu().numpy()
    preds = (probs >= 0.5).astype(int)

    # Convert each prediction row to encoded format
    for pred_row in preds:
        # Get indices where value is 1
        positive_indices = np.where(pred_row == 1)[0]
        encoded_predictions.append(positive_indices.tolist())

# Create a DataFrame with the predictions
predictions_df = spark.createDataFrame(
    [(id, pred) for id, pred in zip(row_ids, encoded_predictions)],
    ["ticket_id", "encoded_predictions"]
)

# Join the predictions with the original filtered DataFrame
result_df = filtered_rows.join(
    predictions_df,
    on="ticket_id",
    how="left"
)

In [11]:
# Show the first rows of the scored sample (PySpark defaults spelled out).
result_df.show(n=20, truncate=True)

+-----------+--------------------+--------------------+--------------------+--------------------+--------+---------+--------+-----------+--------------------+--------------------+-------------------+
|  ticket_id|           complaint|           timestamp|               image|         image_after|latitude|longitude|district|subdistrict|          categories|      categories_idx|encoded_predictions|
+-----------+--------------------+--------------------+--------------------+--------------------+--------+---------+--------+-----------+--------------------+--------------------+-------------------+
|2022-KW4PD8|แท๊กซี่จอดตรงป้าย...|2022-08-14 15:13:...|https://storage.g...|https://storage.g...|13.76683|100.52798| ราชเทวี|  ทุ่งพญาไท|               [ถนน]|               [9.0]|                [9]|
|2024-PCR67G|ด้านหน้ารร สุโกศล...|2024-01-10 05:47:...|https://storage.g...|https://storage.g...|13.75796|100.53594| ราชเทวี|   ถนนพญาไท|               [ถนน]|               [9.0]|                [9]|


In [12]:
# Shut down the SparkSession created in the session-setup cell; re-run that
# cell before using `spark` again.
spark.stop()