In [1]:
# Load the dotenv IPython extension, then run it so later cells can read settings
# via os.getenv (presumably from a local .env file — confirm its location).
%load_ext dotenv
%dotenv

In [2]:
import logging
import os
import re
from typing import Optional

from pyspark.sql import SparkSession

# Configure root logging at INFO level and create the logger used by the
# session helper defined in this cell.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class SimpleSparkSession:
    """Simple Spark session builder for Jupyter notebooks.

    The session is created lazily by ``build_session()`` / ``get_session()``
    and cached on the instance until ``stop_session()`` is called.

    Options are applied in this order, later ones overriding earlier ones for
    the same key: built-in defaults, Hive support, JDBC driver classpath,
    ``spark_config``, ``packages``, S3 settings.
    """

    def __init__(
            self,
            app_name: str = "Jupyter Spark Session",
            master: str = "local[*]",
            spark_config: Optional[dict] = None,
            enable_hive_support: bool = False,
            # S3 configuration
            s3_bucket_name: Optional[str] = None,
            s3_endpoint: Optional[str] = None,
            s3_access_key: Optional[str] = None,
            s3_secret_key: Optional[str] = None,
            s3_region: str = "us-east-1",
            s3_path_style_access: bool = True,
            # PostgreSQL configuration
            postgres_config: Optional[dict] = None,
            # Package configuration
            packages: Optional[list] = None,
            # JDBC driver classpath (added last to stay backward-compatible)
            jdbc_driver_path: Optional[str] = None
    ):
        """Store the settings; no Spark session is created here.

        Args:
            app_name: Spark application name.
            master: Spark master URL.
            spark_config: Extra ``key -> value`` Spark options.
            enable_hive_support: Call ``enableHiveSupport()`` on the builder.
            s3_bucket_name: Stored on the instance; not read by this class.
            s3_endpoint: Custom S3 endpoint; only applied when both S3 keys
                are given.
            s3_access_key: S3 access key.
            s3_secret_key: S3 secret key.
            s3_region: Region sent with a custom endpoint.
            s3_path_style_access: Enable the path-style S3A option group
                (only applied when both S3 keys are given).
            postgres_config: Stored on the instance; not read by this class.
            packages: Maven coordinates joined into ``spark.jars.packages``.
            jdbc_driver_path: Value for the driver and executor
                ``extraClassPath`` options. Previously this attribute was
                always ``None``, so the classpath branch could never run.
        """
        self.app_name = app_name
        self.master = master
        self.spark_config = spark_config or {}
        self.enable_hive_support = enable_hive_support

        # S3 config
        self.s3_bucket_name = s3_bucket_name
        self.s3_endpoint = s3_endpoint
        self.s3_access_key = s3_access_key
        self.s3_secret_key = s3_secret_key
        self.s3_region = s3_region
        self.s3_path_style_access = s3_path_style_access

        # PostgreSQL config
        self.postgres_config = postgres_config
        self.jdbc_driver_path: Optional[str] = jdbc_driver_path

        # Packages
        self.packages = packages or []

        # Cached session; None until build_session() runs.
        self._session = None

    def _s3_options(self) -> dict:
        """Return the ``spark.hadoop.fs.s3a.*`` options, or ``{}`` without both keys."""
        if not (self.s3_access_key and self.s3_secret_key):
            return {}

        options = {
            "spark.hadoop.fs.s3a.access.key": self.s3_access_key,
            "spark.hadoop.fs.s3a.secret.key": self.s3_secret_key,
            "spark.hadoop.fs.s3a.aws.credentials.provider":
                "org.apache.hadoop.fs.s3a.SimpleAWSCredentialsProvider",
        }

        # Config for non-AWS S3
        if self.s3_endpoint:
            options["spark.hadoop.fs.s3a.endpoint"] = self.s3_endpoint
            options["spark.hadoop.fs.s3a.endpoint.region"] = self.s3_region

        # Path style access for non-AWS implementations; this group also turns
        # off SSL and multi-object delete.
        if self.s3_path_style_access:
            options["spark.hadoop.fs.s3a.path.style.access"] = "true"
            options["spark.hadoop.fs.s3a.connection.ssl.enabled"] = "false"
            options["spark.hadoop.fs.s3a.impl"] = "org.apache.hadoop.fs.s3a.S3AFileSystem"
            options["spark.hadoop.fs.s3a.multiobjectdelete.enable"] = "false"

        return options

    def build_session(self) -> "SparkSession":
        """Build the SparkSession on first call and return the cached one afterwards."""
        if self._session is not None:
            return self._session

        # Built-in defaults come first so that spark_config can override them.
        builder = (
            SparkSession.builder
            .appName(self.app_name)
            .master(self.master)
            .config("spark.sql.execution.arrow.pyspark.enabled", "true")
            .config("spark.driver.extraJavaOptions", "-Djava.security.manager=allow")
            .config("spark.executor.extraJavaOptions", "-Djava.security.manager=allow")
        )

        # Add Hive support if requested
        if self.enable_hive_support:
            builder = builder.enableHiveSupport()

        # Merge the remaining options in precedence order; a later update wins
        # for a duplicate key, exactly as successive builder.config() calls would.
        options = {}
        if self.jdbc_driver_path:
            options["spark.driver.extraClassPath"] = self.jdbc_driver_path
            options["spark.executor.extraClassPath"] = self.jdbc_driver_path
        options.update(self.spark_config)
        if self.packages:
            options["spark.jars.packages"] = ",".join(self.packages)
        options.update(self._s3_options())

        for key, value in options.items():
            builder = builder.config(key, value)

        # Build the session
        logger.info(
            "Building Spark session with app name: %s, master: %s",
            self.app_name, self.master
        )
        self._session = builder.getOrCreate()

        return self._session

    def get_session(self) -> "SparkSession":
        """Get the current SparkSession or create a new one"""
        return self.build_session()

    def stop_session(self) -> None:
        """Stop the current Spark session if it exists and clear the cache."""
        if self._session is not None:
            self._session.stop()
            self._session = None
            logger.info("Spark session stopped")

In [21]:
# Notebook-wide Spark session. S3 and PostgreSQL credentials are read from
# environment variables rather than written into the notebook.
spark = SimpleSparkSession(
    app_name="Create table notebook",
    enable_hive_support=False,
    # JDBC driver plus the S3A connector and its AWS SDK bundle
    packages=[
        "org.postgresql:postgresql:42.5.4",
        "org.apache.hadoop:hadoop-aws:3.3.4",
        "com.amazonaws:aws-java-sdk-bundle:1.12.426",
    ],
    # S3-compatible object storage
    s3_bucket_name="traffy-troffi",
    s3_endpoint=os.getenv("S3_ENDPOINT"),
    s3_region="garage",
    s3_access_key=os.getenv("S3_ACCESS_KEY"),
    s3_secret_key=os.getenv("S3_SECRET_KEY"),
    s3_path_style_access=True,
    # PostgreSQL connection properties
    postgres_config=dict(
        user=os.getenv("POSTGRES_USER"),
        password=os.getenv("POSTGRES_PASSWORD"),
        driver="org.postgresql.Driver",
        currentSchema="public",
    ),
).get_session()

INFO:__main__:Building Spark session with app name: Create table notebook, master: local[*]


In [32]:
# DDL script for the public.traffy_fondue table: table definition, ownership,
# secondary indexes and per-column statistics targets. Statements are separated
# by semicolons; every CREATE uses IF NOT EXISTS, the ALTER statements do not.
# NOTE(review): idx_traffy_fondue_ticket_id looks redundant with the PRIMARY KEY
# on ticket_id, which presumably already has its own unique index — confirm
# before dropping it.
create_table_query = """
                     CREATE TABLE IF NOT EXISTS public.traffy_fondue
                     (
                         ticket_id      VARCHAR(20)      NOT NULL PRIMARY KEY,
                         complaint      TEXT             NOT NULL,
                         timestamp      TIMESTAMP        NOT NULL,
                         image          TEXT             NOT NULL,
                         image_after    TEXT             NOT NULL,
                         latitude       DOUBLE PRECISION NOT NULL,
                         longitude      DOUBLE PRECISION NOT NULL,
                         district       VARCHAR(128)     NOT NULL,
                         subdistrict    VARCHAR(128)     NOT NULL,
                         categories     VARCHAR(128)[]   NOT NULL,
                         categories_idx REAL[]           NOT NULL
                     );

                     ALTER TABLE public.traffy_fondue
                         OWNER TO postgres;

                     CREATE UNIQUE INDEX IF NOT EXISTS idx_traffy_fondue_ticket_id ON public.traffy_fondue (ticket_id);

                     CREATE INDEX IF NOT EXISTS idx_traffy_fondue_district ON public.traffy_fondue (district);

                     CREATE INDEX IF NOT EXISTS idx_traffy_fondue_subdistrict ON public.traffy_fondue (subdistrict);

                     CREATE INDEX IF NOT EXISTS idx_traffy_fondue_district_subdistrict ON public.traffy_fondue (district, subdistrict);

                     CREATE INDEX IF NOT EXISTS idx_traffy_fondue_timestamp ON public.traffy_fondue (timestamp);

                     CREATE INDEX IF NOT EXISTS idx_traffy_fondue_categories ON public.traffy_fondue USING GIN (categories);

                     CREATE INDEX IF NOT EXISTS idx_traffy_fondue_time_district ON public.traffy_fondue (timestamp, district);
                     CREATE INDEX IF NOT EXISTS idx_traffy_fondue_time_district_subdistrict ON public.traffy_fondue (timestamp, district, subdistrict);

                     ALTER TABLE public.traffy_fondue
                         ALTER COLUMN district SET STATISTICS 1000;
                     ALTER TABLE public.traffy_fondue
                         ALTER COLUMN subdistrict SET STATISTICS 1000;
                     """

In [42]:
def create_postgres_table(
        jdbc_url="jdbc:postgresql://localhost:5432/traffy-troffi",
        ddl_script=None,
        user=None,
        password=None
):
    """Run a semicolon-separated DDL script against PostgreSQL over JDBC.

    The connection is opened through the JVM of the notebook-level ``spark``
    session. Each non-empty statement is printed and then executed in order.
    The statement and the connection are always closed, even when a statement
    fails.

    Args:
        jdbc_url: JDBC URL of the target database.
        ddl_script: SQL text to run; defaults to the notebook-level
            ``create_table_query``.
        user: Database user; defaults to ``$POSTGRES_USER``, then ``"postgres"``.
        password: Database password; defaults to ``$POSTGRES_PASSWORD``.

    Raises:
        RuntimeError: If no password is passed and ``POSTGRES_PASSWORD`` is unset.
    """
    if ddl_script is None:
        ddl_script = create_table_query
    if user is None:
        user = os.getenv("POSTGRES_USER", "postgres")
    if password is None:
        # Read the secret from the environment instead of hard-coding it here.
        password = os.getenv("POSTGRES_PASSWORD")
    if password is None:
        raise RuntimeError(
            "POSTGRES_PASSWORD is not set; define it in the environment "
            "or pass password= explicitly"
        )

    # Split on semicolons followed by an even number of single quotes up to the
    # end of the script, i.e. semicolons outside a quoted SQL string literal.
    fragments = re.split(r';(?=(?:[^\']*\'[^\']*\')*[^\']*$)', ddl_script)
    sql_statements = [fragment.strip() for fragment in fragments if fragment.strip()]

    connection = spark._jvm.java.sql.DriverManager.getConnection(
        jdbc_url, user, password)
    try:
        statement = connection.createStatement()
        try:
            # Execute each statement
            for sql in sql_statements:
                print(f"Executing: {sql}")
                statement.execute(sql)
        finally:
            statement.close()
    finally:
        # Release the JDBC connection whether or not the DDL succeeded.
        connection.close()

    print("All DDL statements executed successfully")

In [43]:
# Apply the DDL script to PostgreSQL; each statement is echoed before it runs.
create_postgres_table()

Executing: CREATE TABLE IF NOT EXISTS public.traffy_fondue
                     (
                         ticket_id      VARCHAR(20)      NOT NULL PRIMARY KEY,
                         complaint      TEXT             NOT NULL,
                         timestamp      TIMESTAMP        NOT NULL,
                         image          TEXT             NOT NULL,
                         image_after    TEXT             NOT NULL,
                         latitude       DOUBLE PRECISION NOT NULL,
                         longitude      DOUBLE PRECISION NOT NULL,
                         district       VARCHAR(128)     NOT NULL,
                         subdistrict    VARCHAR(128)     NOT NULL,
                         categories     VARCHAR(128)[]   NOT NULL,
                         categories_idx REAL[]           NOT NULL
                     )
Executing: ALTER TABLE public.traffy_fondue
                         OWNER TO postgres
Executing: CREATE UNIQUE INDEX IF NOT EXISTS idx_traffy_fond