In [None]:
!apt-get install openjdk-8-jdk-headless -qq > /dev/null

!wget -q https://downloads.apache.org/spark/spark-3.5.7/spark-3.5.7-bin-hadoop3.tgz

!tar xf spark-3.5.7-bin-hadoop3.tgz

import os
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-8-openjdk-amd64"
os.environ["SPARK_HOME"] = "/content/spark-3.5.7-bin-hadoop3"

!pip install -q findspark
import findspark
findspark.init()

from pyspark.sql import SparkSession
spark = SparkSession.builder\
        .master("local[5]")\
        .appName("Colab")\
        .config('spark.ui.port', '4050')\
        .getOrCreate()

---
**SparkConf | SparkSession | SparkContext**

---

In [None]:
##############################################################################
################################# SPARK CONF #################################
##############################################################################

# Example conf/spark.conf (INI format, one section per environment), from
# https://github.com/KaranBulani/spark_practise/blob/main/SBDL/conf/spark.conf
#
# [LOCAL]
# spark.app.name = sbdl-local
# spark.jars.packages = org.apache.spark:spark-sql-kafka-0-10_2.12:3.3.0
# [QA]
# spark.app.name = sbdl-qa
# spark.jars.packages = org.apache.spark:spark-sql-kafka-0-10_2.12:3.3.0
# spark.executor.cores = 5
# spark.executor.memory = 10GB
# spark.executor.memoryOverhead = 1GB
# spark.executor.instances = 20
# spark.sql.shuffle.partitions = 800
# [PROD]
# spark.app.name = sbdl
# spark.jars.packages = org.apache.spark:spark-sql-kafka-0-10_2.12:3.3.0
# spark.executor.cores = 5
# spark.executor.memory = 10GB
# spark.executor.memoryOverhead = 1GB
# spark.executor.instances = 20
# spark.sql.shuffle.partitions = 800
'''
Other imp configs:
spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles")

spark.task.maxFailures      # number of failures of any single task before Spark gives up on the job (default 4, i.e. up to 3 retries)

spark.sql.shuffle.partitions = 200 (default)
# Dataframe API & Spark SQL - how many shuffle partitions Spark creates when doing wide transformations

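There is no "spark.shuffle.partitions" property for the RDD API (and sc.getConf() returns a copy, so calling .set() on it has no effect on a running context).
# RDD shuffles such as reduceByKey and join take an optional numPartitions argument, e.g. rdd.reduceByKey(f, 50);
# when it is omitted they use spark.default.parallelism (if set), otherwise the largest partition count among the parent RDDs.
spark.shuffle.io.maxRetries   # number of times Spark will retry I/O operations during shuffle.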

spark.default.parallelism (RDD API) - on a cluster it defaults to the total number of executor cores: with 4 executors of 5 cores each, your cluster has 20 cores → spark.default.parallelism = 20.

# Case 1: RDD ReduceByKey (No partitions specified)
# rdd = sc.textFile("data.txt")
# wordcounts = rdd.map(lambda x: (x, 1)).reduceByKey(lambda x, y: x + y)
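# This uses spark.default.parallelism (when set) as the number of shuffle output partitions. It is also the default partition count for sc.parallelize(); narrow (non-shuffle) operations such as map/filter keep their parent's partition count.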

# Case 2: DataFrame GroupBy
# df.groupBy("id").count()
# This uses: spark.sql.shuffle.partitions as it causes shuffle

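# driver config, usually passed to spark-submit (in client mode the driver JVM has already started when your code runs, so e.g. driver memory cannot be changed from inside the application)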
spark.driver.cores = 2                      # cores to allocate for your driver
spark.driver.memory = 4G                    # JVM memory to allocate for your driver
spark.driver.memoryOverhead = 1G            # memory for non-JVM off heap operation (Used by PySpark Driver)

# spark.driver.memoryOverhead calculation can be done using below formula
spark.driver.memoryOverhead = max(spark.driver.memory * spark.driver.memoryOverheadFactor, 384 MB)
                            = max(4G * 0.10, 384)
                            = 409.6 MB
But we would do collect etc so we manually increased spark.driver.memoryOverhead to 1G

# executor config, usually added spark.conf
spark.executor.cores = 5                    # cpu cores for each executor, can improve parallelism and throughput.
spark.executor.memory = 10GB                # JVM memory to allocate for each executor, Low executor memory typically causes out-of-memory (OOM) errors
spark.executor.memoryOverhead = 1GB         # non-JVM memory for each executor
spark.executor.instances = 20               # how many executors
spark.sql.shuffle.partitions = 800

# ratio which decides how spark.executor.memory is divided
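spark.memory.fraction = 0.6                 # fraction of (JVM heap - 300MB reserved) shared by execution and storage memory
spark.memory.storageFraction = 0.5          # fraction of that unified region whose cached blocks are protected from eviction by execution
spark.python.worker.memory                  # memory per Python worker process used during aggregation before spilling to disk (default 512m); unrelated to Py4J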

# ratio which decides how spark.executor.memoryOverhead is divided
# spark.executor.memoryOverheadFactor % used to calculate memoryOverhead
spark.executor.memoryOverhead = max(spark.executor.memory * spark.executor.memoryOverheadFactor, 384 MB)
                              = max(10G * 0.10, 384)
                              = 1G

# spark.executor.pyspark.memory - maximum for Python worker (Default 0)
# spark.memory.offHeap.enabled = true - Disabled by default
# spark.memory.offHeap.size : Memory reserved for off-heap allocations (Default 0)

# ANALYZE TABLE is a SQL command used to compute table statistics, which the Catalyst optimizer uses to choose the most efficient execution plan.
spark.sql("ANALYZE TABLE sales COMPUTE STATISTICS FOR COLUMNS product_id, price")
spark.sql("ANALYZE TABLE sales COMPUTE STATISTICS")

Adaptive Query Execution (AQE)
spark.sql.adaptive.enabled = true (default)
  # Dynamically coalesces shuffle partitions: Reduces partitions based on data size (avoids scheduling empty "created due to spark.sql.shuffle.partitions" or under utilized tasks)
  # Dynamically switches join strategies: Converts sort-merge to broadcast join if small (shuffle stays, sort gets eliminated)
  # Dynamically optimizes skew joins: Handle data skew during a join. Splits skewed partitions # Salting also does this but AQE does post shuffle whereas we implement salting pre shuffle
  # Benefits: Better performance without manual tuning, adapts to runtime statistics

spark.sql.adaptive.coalescePartitions.initialPartitionNum       # Starts with a set value; if not explicitly configured, it defaults to spark.sql.shuffle.partitions
spark.sql.adaptive.coalescePartitions.minPartitionNum           # lower bound after coalescing; defaults to spark.default.parallelism if not set.
# There isn't a maxPartitionNum because AQE coalescing only reduces shuffle partitions; it never increases them. (Skew-join handling can split skewed partitions, but that is a separate feature.)
spark.sql.adaptive.advisoryPartitionSizeInBytes                 # (default: 64MB) ideal size for partitions during coalescing and splitting.
spark.sql.adaptive.coalescePartitions.enabled = false           # (default: true) AQE will combine small partitions; setting it to false disables this feature.

spark.sql.autoBroadcastJoinThreshold = 10MB
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1") # disables
spark.sql.adaptive.localShuffleReader.enabled=true  (default)
# Reduces network traffic by reading shuffle data locally

spark.sql.adaptive.skewJoin.enabled = true                                      # Enable skew join optimization
spark.sql.adaptive.skewJoin.skewedPartitionFactor (default - 5)                 # if its size exceeds five times the median partition size.
spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes (default - 256MB)   # partition must exceed this size threshold to be considered skewed

spark.sql.adaptive.maxShuffledHashJoinLocalMapThreshold                         # AQE converts a sort-merge join to a shuffled hash join when all post-shuffle partitions are smaller than the threshold

Dynamic Partition Pruning
spark.sql.optimizer.dynamicPartitionPruning.enabled (default True)

Speculative Execution (Mitigate slow-running tasks: Launch duplicate "speculative" tasks on different nodes; first successful task is retained, others killed.)
spark.speculation (default: false)

Parameter | Default | Description
spark.speculation.interval | 100 ms | Frequency of checking  for slow tasks.
spark.speculation.multiplier | 1.5 | How many times slower than the median a task must be to be considered "slow". If the median task takes 5 s, a task still running after 7.5 s is considered slow.
spark.speculation.quantile | 0.75 / 75% | Fraction of tasks that must finish before speculation begins for a stage.
spark.speculation.minTaskRuntime | 100 ms | Minimum amount of time a task should run before it's eligible for speculation (avoids very short tasks).
spark.speculation.task.duration.threshold | (none) | Task duration after which the scheduler tries to speculate the task; only applies when the stage has no more tasks than the slots on a single executor.

Spark Scheduling Overview

Across applications – how the cluster manager shares resources among multiple Spark apps.
  Static Allocation (first-come, first reserve) (default)
  Dynamic Allocation (the application requests executors when tasks are pending and automatically releases idle executors back to the cluster; the Spark driver decides when, and the cluster manager (e.g. the YARN RM) only grants/reclaims the containers)

spark.dynamicAllocation.enabled = true                        # disabled by default
spark.dynamicAllocation.shuffleTracking.enabled = true        # enable shuffle tracking to prevent removing executors that still hold shuffle data.
spark.shuffle.service.enabled = true                          # Or set this to true. purpose of the external shuffle service is to allow executors to be removed without deleting shuffle files written by them.
spark.dynamicAllocation.executorIdleTimeout = 60s             # If an executor is not running any task for 60s, Spark Application will release executor back to cluster manager
spark.dynamicAllocation.schedulerBacklogTimeout = 1s          # requests new executors if any tasks are pending for more than 1 second

Within an application – how a single Spark app schedules the multiple actions/jobs it submits (e.g. from different threads).
Sequential Execution (FIFO): By default, jobs are scheduled first-in-first-out (Job 1 → Job 2 → …): the earlier job gets priority on all resources, even if the jobs are independent of each other.
FAIR Scheduler: Tasks assigned in round-robin across all active jobs.
                Set it when building the session (it is read when the SparkContext starts): SparkSession.builder.config("spark.scheduler.mode", "FAIR")

SPARK HISTORY SERVER (web UI that lets you view past Spark applications, Unlike the Spark UI which disappears after the application ends)
spark.history.fs.logDirectory = hdfs:///spark-logs                  # for the History Server's storage (usually this & below are same)

After enabling the Spark History Server, these two configuration properties are essential to configure your application to store event logs
spark.eventLog.enabled  = true                                                         # turns on event logging
spark.eventLog.dir      = hdfs:///spark-logs                                           # where to store the logs.


remote/ -- remote is used internally to establish connection with Spark Connect

'''
import configparser
from pyspark import SparkConf
def get_spark_conf(env: str, conf_path: str = "conf/spark.conf") -> "SparkConf":
    """
    Load Spark configuration for one environment from an INI-style conf file.

    Each section of the file (e.g. [LOCAL], [QA], [PROD]) maps Spark property
    names to values. Every key/value of the ``env`` section (plus any [DEFAULT]
    entries) is copied into a new SparkConf.

    Args:
        env: Section name to load, e.g. "LOCAL", "QA" or "PROD".
        conf_path: Path of the conf file; defaults to "conf/spark.conf",
            resolved relative to the current working directory.

    Returns:
        A SparkConf populated with the section's properties.

    Raises:
        configparser.NoSectionError: if ``env`` is not a section of the file.
            This is also what a missing file produces, because
            ConfigParser.read() silently skips files it cannot open.
    """
    spark_conf = SparkConf()
    config = configparser.ConfigParser()
    # ConfigParser lower-cases option names by default, which would turn
    # "spark.executor.memoryOverhead" into "spark.executor.memoryoverhead".
    # Spark property names are case-sensitive, so keep keys exactly as written.
    config.optionxform = str
    config.read(conf_path)

    for key, val in config.items(env):
        spark_conf.set(key, val)
    return spark_conf

##############################################################################
################################ SPARK SESSION ###############################
##############################################################################
"""
SparkSession acts as the driver. Entry point every program has it.
Unifier for: SparkContext for core, SQLContext for SQL, HiveContext for Hive, StreamingContext for streaming.

🔍 What SparkSession fully unifies

✔ Spark SQL
* Running SQL queries
* Creating temp views
* Catalog access
* Table metadata
* Parsing SQL automatically into DataFrame operations
Example:
spark.sql("SELECT * FROM customers WHERE age > 25")

✔ Hive Integration
When enabled on the builder (before the session is created), SparkSession becomes Hive-aware:
spark = SparkSession.builder.enableHiveSupport().getOrCreate()
spark.sql("SHOW DATABASES")

✔ Catalog Access
access to internal metadata:
spark.catalog.listTables()
spark.catalog.listDatabases()
spark.catalog.listFunctions()

✔ Data Sources (Unified DataFrameReader & Writer & Transformer)
SparkSession reads everything through one API:
spark.read.csv()
spark.read.json()
spark.read.parquet()
spark.read.format("delta").load()
spark.read.jdbc()

You can create, transform, filter, group, join DataFrames:
df = spark.read.csv("path.csv", header=True, inferSchema=True)
df.groupBy("country").count().show()

Same for writing.

✔ Underlying SparkContext
SparkSession contains the original heart of Spark:
sc = spark.sparkContext
So it still supports RDD API:
rdd = sc.parallelize([1,2,3])

"""
from pyspark.sql import SparkSession

def get_spark_session(env):
    """
    Create (or reuse) a Hive-enabled SparkSession configured for ``env``.

    Args:
        env: Environment name. "LOCAL" builds a 2-thread local session with
            automatic broadcast joins and AQE disabled. Any other value relies
            on spark-submit and the [env] section of the conf file for the
            master and resources.

    Returns:
        The SparkSession from getOrCreate(). An already-active session is
        returned if one exists.
    """
    # get_spark_conf (defined above) loads the [env] section of conf/spark.conf.
    # builder.config(conf=...) is applied after appName(), so a spark.app.name
    # entry in the conf file takes precedence over "PySpark App".
    builder = SparkSession.builder
    if env == "LOCAL":
        builder = (
            builder
            .appName("PySpark App")
            .config(conf=get_spark_conf(env))
            .config('spark.sql.autoBroadcastJoinThreshold', -1)   # disable automatic broadcast joins
            .config('spark.sql.adaptive.enabled', 'false')        # disable AQE
            # Pass JVM options to the Spark driver.
            # NOTE(review): Spark 3.3+ logs with Log4j 2, which reads
            # -Dlog4j.configurationFile / log4j2.properties; confirm this
            # Log4j 1 style flag still takes effect on your Spark version.
            .config('spark.driver.extraJavaOptions', '-Dlog4j.configuration=file:log4j.properties')
            .master("local[2]")
        )
    else:
        builder = builder.config(conf=get_spark_conf(env))
    return builder.enableHiveSupport().getOrCreate()
'''
The configuration of the SparkSession can be changed afterwards
spark.conf.set("spark.sql.shuffle.partitions", 3)

spark.conf.get("key") method retrieves the value of a configuration property in a running PySpark session.
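spark.conf.getAll() is not available on the PySpark 3.x RuntimeConfig; use spark.sparkContext.getConf().getAll() to list all configuration properties (as (key, value) pairs) of a running PySpark session.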
'''
##############################################################################

# Entry point as written in the referenced project. `Utils` (presumably the module holding
# get_spark_session) and `job_run_env` (an environment name such as "LOCAL", "QA" or "PROD")
# come from that project's main script. They are not defined in this notebook.
spark = Utils.get_spark_session(job_run_env)

'''
Normal transformation
'''
# Stop the session once all work is done. After this, `spark` (and its SparkContext) cannot run
# jobs, so later cells that use `spark` need a new session.
spark.stop()
'''
This does:
  Releases cluster resources
  Closes executors
  Cleans shuffle files
  Stops SparkContext
  Ends all SQL/Hive sessions
'''

##############################################################################
################################ SPARK CONTEXT ###############################
##############################################################################

"""
THEORY:
DataFrame API is based on SparkSession, while the RDD API is based on SparkContext.

	from pyspark import SparkConf, SparkContext
	#Some logic where conf = SparkConf ()....
	sc = SparkContext (conf=conf)

To use the RDD API, you need a SparkContext, which can be created with a SparkConf passed as parameter. However in newer Spark versions,
  ‚óè instead of creating it separately , you can access the SparkContext from the SparkSession.
  ‚óè SparkSession in newer version , is a higher-level abstraction that internally uses SparkContext .
  ‚óè You can directly access SparkContext via SparkSession. Access SparkContext from SparkSession (SparkSession.sparkContext).

#############################################################################

USAGE 1:
  Use SparkSession.sparkContext.getConf().getAll() to retrieve all configurations.
  Used while creating logger https://github.com/KaranBulani/spark_practise/blob/main/09-RowDemo/lib/logger.py#L5
"""
import sys
from collections import namedtuple

from pyspark import SparkContext
# OR (Spark 2.0+): don't build a SparkContext yourself; take it from the session as spark.sparkContext.
# You import the class and then use the attribute; "import SparkSession.sparkContext" is not valid Python.
from pyspark.sql import SparkSession

# NOTE(review): assumed to mirror the referenced project's record type. CSV columns 1..4 are read as
# (Age, Gender, Country, State); confirm this against the data file.
SurveyRecord = namedtuple("SurveyRecord", ["Age", "Gender", "Country", "State"])

#USAGE 2:
# Assumes `spark` is a live (not yet stopped) SparkSession.
sc = spark.sparkContext
# NOTE(review): sys.argv[1] is the input path when run through spark-submit; inside a notebook kernel
# argv holds kernel arguments, so pass a real path there instead.
linesRDD = sc.textFile(sys.argv[1]) # returns an RDD where each record is one line of the text file

# implement select(), filter(), groupBy() of dataframe with basic RDD functions
partitionedRDD = linesRDD.repartition(2)
colsRDD = partitionedRDD.map(lambda line: line.replace('"', '').split(","))
# NOTE(review): if the file has a header row, int(cols[1]) raises on it; filter the header out first.
selectRDD = colsRDD.map(lambda cols: SurveyRecord(int(cols[1]), cols[2], cols[3], cols[4]))
filteredRDD = selectRDD.filter(lambda r: r.Age < 40)
kvRDD = filteredRDD.map(lambda r: (r.Country, 1))
countRDD = kvRDD.reduceByKey(lambda v1, v2: v1 + v2)

colsList = countRDD.collect()  # action: brings the (Country, count) pairs back to the driver

"""
USAGE 3 OUTDATED:
    my_rows = [Row("123","04/05/2020"), Row("124","4/5/2020"), Row("125", "04/5/2020"), Row("126", "4/05/2020")]
    my_rdd = spark.sparkContext.parallelize(my_rows, 2)
    my_df = spark.createDataFrame(my_rdd, my_schema)

Convert local list → RDD
Split the RDD into 2 partitions
Let Spark process those partitions in parallel

RATHER DO UPDATED:
my_df = spark.createDataFrame(my_rows, my_schema).repartition(2)

#############################################################################

USAGE 4:
  when using accumulator/ broadcast variables

"""

---
**Schema | Reading Dataframe | Creating DataFrame using list, dict, pandas df**

---



---

In [None]:
##############################################################################
################################ SPARK SCHEMA ################################
##############################################################################

# Spark SQL Types
# Data types for schema definition
from pyspark.sql.types import (
    StructType, StructField, StringType, IntegerType,
    LongType, FloatType, DoubleType, BooleanType,
    DateType, TimestampType, ArrayType, MapType, DecimalType
)

# Method 1: Using StructType
# StructType(fields=None) - Struct type, consisting of a list of StructField
# StructField(name, dataType, nullable=True, metadata=None)
flight_schema_struct = StructType([
    StructField("FL_DATE", DateType(), True),
    StructField("OP_CARRIER", StringType(), True),
    StructField("OP_CARRIER_FL_NUM", IntegerType(), True),
    StructField("ORIGIN", StringType(), True),
    StructField("ORIGIN_CITY_NAME", StringType(), True),
    StructField("DEST", StringType(), True),
    StructField("DEST_CITY_NAME", StringType(), True),
    StructField("CRS_DEP_TIME", IntegerType(), True),
    StructField("DEP_TIME", IntegerType(), True),
    StructField("WHEELS_ON", IntegerType(), True),
    StructField("TAXI_IN", IntegerType(), True),
    StructField("CRS_ARR_TIME", IntegerType(), True),
    StructField("ARR_TIME", IntegerType(), True),
    StructField("CANCELLED", IntegerType(), True),
    StructField("DISTANCE", IntegerType(), True)
])

# Adding a column: StructType.add() appends to the schema IN PLACE (and returns it), and
# StructType(fields) stores the given list without copying it. Calling add() on
# flight_schema_struct would therefore grow it to 16 fields and break the 15-column CSV reader
# and createDataFrame examples below. Build the extended schema from a new list instead.
flight_schema_extended = StructType(
    flight_schema_struct.fields + [StructField("new_column", StringType(), True)]
)

# A StructType can be indexed by field name. DataFrame.schema returns a StructType, so the same
# lookup works on a DataFrame, e.g. df.schema["FL_DATE"].dataType
flight_schema_struct["FL_DATE"].dataType

# Method 2: Using DDL String
# The same 15 columns as flight_schema_struct, written as "name TYPE" pairs.
# The JSON reader in read_data_examples uses this name.
flight_schema_ddl = """
    FL_DATE DATE, OP_CARRIER STRING, OP_CARRIER_FL_NUM INT, ORIGIN STRING,
    ORIGIN_CITY_NAME STRING, DEST STRING, DEST_CITY_NAME STRING,
    CRS_DEP_TIME INT, DEP_TIME INT, WHEELS_ON INT, TAXI_IN INT,
    CRS_ARR_TIME INT, ARR_TIME INT, CANCELLED INT, DISTANCE INT
"""

##############################################################################
################################# SPARK.READ #################################
##############################################################################

# DataFrameReader is returned from spark.read
# https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrameReader.html#pyspark.sql.DataFrameReader

# If it sounds like a Spark DataFrame feature → camelCase.
# If it sounds like a basic file-read/write behavior → lowercase.

# wholetext option to read each file as one row and captures its path with input_file_name()
# from pyspark.sql.functions import input_file_name
# df = (spark.read
#           .format("text")
#           .option("wholetext", "true")
#           .load("s3a://bucket/dir/*")
#           .select(
#               input_file_name().alias("path"),
#               col("value").alias("content")
#           ))

def read_data_examples(spark: "SparkSession"):
    """
    Show the main DataFrameReader patterns: CSV, JSON, Parquet and BigQuery.

    Args:
        spark: an active SparkSession.

    Returns:
        Tuple (csv_df, json_df, parquet_df, bq_df).
    """
    # CSV with an explicit schema.
    # (With .option("inferSchema", "true") instead, Spark runs a small extra
    # stage over the data to infer the column types.)
    csv_df = (
        spark.read
        .format("csv")
        .option("header", "true")
        .option("delimiter", "\t")        # custom delimiter: tab-separated input
        .schema(flight_schema_struct)     # or .option("inferSchema", "true") if confident about the data
        .option("mode", "FAILFAST")       # read mode is an option; on the writer, mode() is a method instead
        .option("dateFormat", "M/d/y")
        .load("data/flight*.csv")
    )
    # mode options:
    # PERMISSIVE (default): malformed fields become null; the raw record is kept in the column named
    #   by columnNameOfCorruptRecord (default _corrupt_record) when the schema contains that column
    # DROPMALFORMED: drops malformed records and loads only valid ones
    # FAILFAST: raises an exception and halts execution on the first malformed record

    # JSON with a DDL-string schema
    json_df = (
        spark.read
        .format("json")
        .schema(flight_schema_ddl)
        .option("dateFormat", "M/d/y")
        .load("data/flight*.json")
    )

    # Parquet: the schema comes from the files themselves
    parquet_df = (
        spark.read
        .format("parquet")
        .load("data/flight*.parquet")
    )

    # BigQuery (requires the spark-bigquery connector on the classpath)
    bq_df = (
        spark.read
        .format("bigquery")
        .option("filter", "_PARTITIONDATE > '2019-01-01'")
        .load("bigquery-public-data.samples.shakespeare")
    )

    return csv_df, json_df, parquet_df, bq_df

# Alternative shorthand syntax: the format-specific reader methods take options as keyword arguments
csv_df2 = spark.read.csv("data/*.csv", header=True, inferSchema=True, sep="|")
json_df2 = spark.read.json("data/*.json")
parquet_df2 = spark.read.option("mergeSchema", "true").parquet("/mnt/data/parquet_dir") # A directory of Parquet files whose schemas differ but are compatible; mergeSchema combines their columns into one DataFrame schema
# mergeSchema only merges compatible schemas; conflicting types for the same column still fail.

# to read Parquet files with nested folders:
spark.read.parquet("root/2024/*/*")
# Wildcard (*) usage:
# Each * matches exactly one path level (directory or file name).
# "root/2024/*/*" therefore matches two levels below 2024 (e.g., month and day).
# Spark supports glob wildcards in paths when reading data.

# Read from Hive table
hive_df = spark.read.table("database.table_name")

# Read with custom delimiter
tsv_df = spark.read \
    .option("header", "true") \
    .option("delimiter", "\t") \
    .csv("data/file.tsv")
# csv() (Comma Separated Values) can also read TSV (Tab Separated Values) by setting the delimiter.

##############################################################################
############################ SPARK.createDataFrame ###########################
##############################################################################

def create_dataframe_examples(spark: SparkSession):
    """
    Build small flight DataFrames from local Python data in five different ways.

    Returns, in this order:
        1. DataFrame from a list of tuples with flight_schema_struct
        2. DataFrame from a list of dicts (column names and types inferred)
        3. DataFrame from a pandas DataFrame
        4. DataFrame from an RDD of the same tuples with flight_schema_struct
        5. Empty DataFrame that only carries flight_schema_struct
    """
    from datetime import date

    # 1) Tuples + explicit schema: each tuple lists the values in the schema's column order.
    flight_rows = [
        (date(2024, 1, 1), "AA", 100, "JFK", "New York, NY", "LAX", "Los Angeles, CA",
         800, 805, 1130, 15, 1145, 1145, 0, 2475),
        (date(2024, 1, 2), "DL", 200, "ATL", "Atlanta, GA", "ORD", "Chicago, IL",
         900, 910, 1045, 10, 1100, 1055, 0, 606),
        (date(2024, 1, 3), "UA", 300, "SFO", "San Francisco, CA", "SEA", "Seattle, WA",
         1200, None, None, None, 1400, None, 1, 679),
    ]
    from_tuples = spark.createDataFrame(flight_rows, schema=flight_schema_struct)

    # 2) Dicts: keys become column names and the types are inferred from the values.
    flight_dicts = [
        {"FL_DATE": date(2024, 1, 1), "OP_CARRIER": "AA", "OP_CARRIER_FL_NUM": 100,
         "ORIGIN": "JFK", "DEST": "LAX", "DISTANCE": 2475},
        {"FL_DATE": date(2024, 1, 2), "OP_CARRIER": "DL", "OP_CARRIER_FL_NUM": 200,
         "ORIGIN": "ATL", "DEST": "ORD", "DISTANCE": 606},
    ]
    from_dicts = spark.createDataFrame(flight_dicts)

    # 3) pandas DataFrame. The key difference is the evaluation model:
    #    Spark DataFrames are lazy (a plan is built and only executed when an action runs);
    #    pandas DataFrames are eager and produce results immediately, which helps when
    #    prototyping and debugging on small data.
    import pandas as pd

    flight_columns = {
        "FL_DATE": [date(2024, 1, 1), date(2024, 1, 2)],
        "OP_CARRIER": ["AA", "DL"],
        "OP_CARRIER_FL_NUM": [100, 200],
        "ORIGIN": ["JFK", "ATL"],
        "DEST": ["LAX", "ORD"],
        "DISTANCE": [2475, 606],
    }
    from_pandas = spark.createDataFrame(pd.DataFrame(flight_columns))

    # 4) RDD of the same tuples + explicit schema.
    flights_rdd = spark.sparkContext.parallelize(flight_rows)
    from_rdd = spark.createDataFrame(flights_rdd, schema=flight_schema_struct)

    # 5) No rows at all: only the schema.
    no_flights = spark.createDataFrame([], schema=flight_schema_struct)

    return from_tuples, from_dicts, from_pandas, from_rdd, no_flights


---
**Dataframe Write | Coalesce, Repartition (Transformation)**

---

In [None]:
##############################################################################
################################ SPARK.WRITE #################################
##############################################################################

# DataFrameWriter is returned from spark.write
# https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrameWriter.html

'''
If the setting is generic across all sources/sinks ‚Üí it has a direct function.
If Spark needs the setting to decide HOW to perform read/write ‚Üí it's a direct function (e.g., format, schema, mode)

If the setting is source-specific or sink-specific ‚Üí it goes inside .option()
If Spark needs the setting to configure a particular format ‚Üí it's an option (e.g., header, delimiter, multiline, bootstrap.servers)

| Method                                          | When used          | Why it's not an option                         |
| ----------------------------------------------- | ------------------ | ---------------------------------------------- |
|  .format("csv")                                 | Select data source | Core API — choosing the backend implementation |
|  .mode("overwrite")                             | Save mode          | Applies to *all* writes                        |
|  .schema(mySchema)                              | Define schema      | Works for any structured source                |
|  .load()  /  .save()                            | Trigger IO         | Core action/method                             |
|  .option("path", ...)  (write only alternative) | Path shortcut      | Special exception                              |
'''

# Partitioned write with max records per file
(
    df.write
    .format("json")                        # built-in: CSV, JSON, Parquet (default), ORC, text (Avro via the spark-avro module); external: JDBC, Cassandra, MongoDB, Kafka, Delta Lake
    .mode("overwrite")                     # write modes are set with .mode(), NOT via .option()
    .option("path", "output/json/")        # minimum requirement: the target location (can also be passed to save())
    .partitionBy("OP_CARRIER", "ORIGIN")   # one sub-directory per (OP_CARRIER, ORIGIN) value. Without partitionBy each DataFrame partition writes one file, regardless of executors
    .option("maxRecordsPerFile", 10000)    # limits the number of records per output file
    .save()                                # action: triggers the write
)

# Write modes (set with .mode()):
# - "error" / "errorifexists": throws an error if data already exists at the target (default)
# - "append": adds new files next to the existing data
# - "overwrite": deletes the existing data, then writes the new data
# - "ignore": does nothing if data already exists at the target

# DataFrameWriter does not support any ordering options.
# Spark writes data in parallel, and ordering requires a global shuffle, which is not guaranteed
# to survive when writing multiple files.
# For a globally sorted dataset, sort before writing, e.g.
#   df.orderBy("partition_columns_used_during_write", "some_other_column")
# where some_other_column is not one of the partitionBy columns.
# But Spark cannot guarantee a file-level sort. Even if you sort before writing, Spark still
# writes in parallel, so:
#   - files can contain sorted chunks,
#   - but file boundaries are not guaranteed,
#   - and ordering within a file may break depending on the shuffle partitions.
# If you must have a fully sorted single-file output, disable the parallelism:
(
    df.orderBy("some_column")
    .coalesce(1)
    .write
    .format("json")
    .mode("overwrite")
    .save("output/json/")
)

# DataFrameWriter.text(): the DataFrame must have exactly one string column; each row becomes one line in the output file.


# coalesce(n) merges existing partitions without a full shuffle, reducing the number of output files.
# If n is larger than the current partition count, the partition count stays unchanged.
# Risk with coalesce(1): if the single partition does not fit in executor memory, Spark spills to disk.
# If that still isn't enough, the task fails with OOM / GC errors, and after the retries the job fails.
df.coalesce(1).write.csv("output/single_file.csv")

# repartition(numPartitions, *cols) returns a new DataFrame, hash partitioned by the given expressions (full shuffle).
# Repartitioning before a write increases parallelism and also controls the number of output files.
#   numPartitions - an int target partition count, or a Column (then used as the first partitioning column).
#                   If omitted, spark.sql.shuffle.partitions is used.
#   cols - str or Column partitioning columns.
# repartition(n) without columns distributes rows blindly (not by column value); if n > number of rows, some partitions are empty.
# All three examples write to the same path, so mode("overwrite") is required: with the default
# "errorifexists" mode the second and third writes would fail because the directory already exists.
df.repartition(10).write.mode("overwrite").parquet("output/repartitioned")
df.repartition("category").write.mode("overwrite").parquet("output/repartitioned")    # partition count = spark.sql.shuffle.partitions (AQE may coalesce it)
df.repartition(8, "category").write.mode("overwrite").parquet("output/repartitioned") # exactly 8 partitions even if there are more categories (several categories share a partition)

# Check partitions
df.rdd.getNumPartitions()

# Write with compression
(
    df.write
    .option("compression", "snappy")
    .parquet("output/compressed")  # action
)
# Acceptable compression values include: none, uncompressed, snappy, gzip, lzo, brotli, lz4, lz4_raw, zstd. Note that brotli requires BrotliCodec to be installed.

# Indirect BigQuery write: the data is first written to GCS and then loaded into BigQuery.
# A GCS bucket must be configured as the temporary data location.
(
    df.write
    .format("bigquery")
    .option("temporaryGcsBucket", "some-bucket")
    .save("dataset.table")  # action
)


---
**Dataframe Write DB | BucketBy, sortBy**

---

In [None]:
spark.sql("CREATE DATABASE IF NOT EXISTS AIRLINE_DB")
spark.catalog.setCurrentDatabase("AIRLINE_DB")
# spark.sql("USE AIRLINE_DB")

# bucketBy & sortBy are only supported for table writes, i.e. together with saveAsTable();
# path-based save()/parquet() writes reject them.
# Also keep .enableHiveSupport() on the SparkSession builder so the table (and its bucketing
# metadata) is stored in the Hive metastore and survives the session.
# Bucket + sort improves joins: when both sides are bucketed the same way on the join key the
# shuffle is eliminated, and bucketed datasets can be joined many times without reshuffling.

# Bucketing: best for optimizing joins and reducing shuffles on keys such as customer_id, because
# records with the same key land in the same bucket.
# Partitioning: best for filter-based pruning (e.g., year, month, region).
# Too fine-grained partitioning (like date) creates many small files and metadata overhead.

# Layout: one directory per partition value (date). Each partition directory contains bucketed
# files: rows are hashed on customer_id into N buckets (without partitionBy: just N buckets).
# Align N buckets with N node cluster.


# The rest of the DataFrame write API is the same.
(
    df.write
    .mode("overwrite")
    .partitionBy("date")
    .bucketBy(32, "customer_id")
    .sortBy("customer_id")
    .saveAsTable("sales_fact")  # table name
)

# Without a path this is created as a managed table; once a path is given it becomes an external table.
# Managed table → Spark controls both data + metadata → no path specified
# External table → Spark controls only metadata → path is specified
# Dropping an unmanaged (external) table removes no data, but you can no longer refer to that data by the table name.


print(spark.catalog.listTables("AIRLINE_DB"))
# https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.Catalog.html
# The catalog is an abstraction over the metadata of your data: databases, tables, functions and views.
spark.catalog.cacheTable("sales_fact")    # cache the table written above (in the current database AIRLINE_DB)
spark.catalog.uncacheTable("sales_fact")




NameError: name 'spark' is not defined

---
**Spark SQL**

---

In [None]:
# Create (or replace) a temporary view: accessible only within the current SparkSession
df.createOrReplaceTempView("sales_view")
result = spark.sql("""
    SELECT Country, SUM(Quantity) as total_qty
    FROM sales_view
    WHERE InvoiceDate >= '2010-01-01'
    GROUP BY Country
    ORDER BY total_qty DESC
""")
# Spark SQL does not guarantee the order in which subexpressions are evaluated.

# Both calls resolve temporary views as well as catalog tables, and they are equivalent.
spark.read.table("sales_view")
spark.table("sales_view")

# Global temp view: shared by all SparkSessions of the same application, registered in the
# system database `global_temp`.
# createOrReplace... keeps the cell re-runnable; createGlobalTempView raises if the view already exists.
df.createOrReplaceGlobalTempView("global_sales")
spark.sql("SELECT * FROM global_temp.global_sales").show()
spark.read.table("global_temp.global_sales")
spark.table("global_temp.global_sales")

# You have a Parquet dataset at path /mnt/data/some_file.parquet. Without creating a table, query it directly:
# SELECT * FROM parquet.`/mnt/data/some_file.parquet`;

# The data source format goes in the USING clause of CREATE TABLE:
# CREATE TABLE my_table (id INT, name STRING, salary DOUBLE) USING PARQUET;

---
**Join** + **Broadcast**

---

In [None]:
'''
https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrame.join.html#pyspark.sql.DataFrame.join
DataFrame.join(other, on=None, how=None)
| Spark Join Type  | All Aliases (identical)            | 														                           	Meaning 						        							                        |
| ---------------- | ---------------------------------- | ------------------------------------------------------------------------------------------------------------------------- |
|   inner(default) | inner                              | Returns only rows that have matching keys in both DataFrames. 							                            							  	|
|   cross          | cross                              | Cartesian product → every row in A combined with every row in B.                                                          |
|   full outer     | full, outer, fullouter, full_outer | Returns all rows from both DataFrames. Where keys don't match, unmatched columns become NULL.                             |
|   left outer     | left, leftouter, left_outer        | Returns all rows from the left DataFrame and matching rows from the right. Where a left row has no match, the right-side columns → NULL. |
|   right outer    | right, rightouter, right_outer     | Opposite of left join. Returns all rows from right, and matching from left.												                        |
|   left semi      | semi, leftsemi, left_semi          | Returns only left-side rows that have at least one match on the right. Does not include any columns from right DataFrame. If there is no matching row on the right, Spark removes that left-row from the result. |
|   left anti      | anti, leftanti, left_anti          | Returns only left-side rows where there is NO match on right. keep NON-matches															              |
'''
##############################################################################
#################################### JOIN ####################################
##############################################################################
# Join on an explicit column expression: both prod_id columns are kept in the result.
join_expr = order_df.prod_id == product_df.prod_id
joined_df = order_df.join(product_df, join_expr, "inner")

# Alternative join syntax
joined_df = order_df.join(product_df, "prod_id", "inner") # "prod_id" must exist with the same name in both DataFrames; the result keeps a single prod_id column
joined_df = order_df.join(product_df, ["prod_id", "category"], "inner") # join on several same-named columns

# Left Semi Join (returns only left side rows where a match exists; no right-side columns)
semi_join_df = order_df.join(product_df, join_expr, "left_semi")

# Left Anti Join (returns only left side rows where no match exists)
anti_join_df = order_df.join(product_df, join_expr, "left_anti")

# Cross Join (Cartesian product)
df_cross = df1.join(df2, how="cross") # df1 / df2 are placeholders; they are not defined in this notebook
cross_join_df = order_df.crossJoin(product_df)

from pyspark.sql.functions import expr

# Complex Example
def join_examples(order_df: "DataFrame", product_df: "DataFrame"):
    """Show an inner join and a null-filling left join between orders and products.

    The product-side ``qty`` column is renamed to ``reorder_qty`` first so that
    the ``qty`` selected after the join is unambiguous.  After each join the
    right-hand ``prod_id`` is dropped so ``select("prod_id", ...)`` resolves to a
    single column.

    Args:
        order_df: left side of both joins (orders).
        product_df: right side of both joins (products).
        After the rename, the two inputs together must provide order_id, prod_id,
        prod_name, unit_price, list_price and qty.

    Returns:
        None; both results are printed with ``show()``.
    """
    product_renamed = product_df.withColumnRenamed("qty", "reorder_qty")
    # Build the condition from the DataFrame that is actually being joined.
    join_expr = order_df.prod_id == product_renamed.prod_id

    # Inner Join: only orders whose product exists.
    # Parentheses instead of "\" continuations, because a "\" may not be followed
    # by a comment on the same line (that was a SyntaxError).
    (
        order_df.join(product_renamed, join_expr, "inner")
        # Both sides carry prod_id; drop the right-hand copy to avoid ambiguity.
        .drop(product_renamed.prod_id)
        .select("order_id", "prod_id", "prod_name", "unit_price", "list_price", "qty")
        .show()
    )

    # Left Join with coalesce + sort: keep every order, even without a product.
    (
        order_df.join(product_renamed, join_expr, "left")
        .drop(product_renamed.prod_id)
        .select("order_id", "prod_id", "prod_name", "unit_price", "list_price", "qty")
        # Unmatched orders have NULL product columns; fall back to order-side values.
        .withColumn("prod_name", expr("coalesce(prod_name, prod_id)"))
        .withColumn("list_price", expr("coalesce(list_price, unit_price)"))
        .sort("order_id")
        .show()
    )

##############################################################################
################################# BROADCAST ##################################
##############################################################################

from pyspark.sql.functions import broadcast

# Broadcast Join (for small tables)
# broadcast() is a hint: ship a full copy of the small side to every executor,
# so the large side can be joined where it sits instead of being shuffled.
join_expr = (large_df.id == small_df.id)
broadcast_join_df = large_df.join(other=broadcast(small_df), on=join_expr, how="inner")


"\n| Spark Join Type  | All Aliases (identical)            | \t\t\t\t\t\t\t\t\t\t\t\t\t\t                           \tMeaning \t\t\t\t\t\t        \t\t\t\t\t\t\t                        |\n| ---------------- | ---------------------------------- | ------------------------------------------------------------------------------------------------------------------------- |\n|   inner(default) | inner                              | Returns only rows that have matching keys in both DataFrames. \t\t\t\t\t\t\t                            \t\t\t\t\t\t\t  \t|\n|   cross          | cross                              | Cartesian product ‚Üí every row in A combined with every row in B. \t\t\t\t\t                            \t\t\t\t\t\t\t\t\t  |\n|   full outer     | full, outer, fullouter, full_outer | Returns all rows from both DataFrames. Where keys don't match, unmatched columns become NULL.                             |\n|   left outer     | left, leftouter, left_outer        | Returns all rows fro

---
**Transformations (Narrow/Wide/Window)**

---

In [None]:
##############################################################################
############################ Basic Transformation ############################
##############################################################################

from pyspark.sql.functions import col, expr, lit, concat, asc, desc

# Select columns: by name, via col(), or renamed with alias()
df.select("col1", "col2", "col3")
df.select(col("col1"), col("col2"))
df.select(col("col1").alias("new_name"))
# Note: these also work instead of col("a"):
# df.col1     (attribute access; only for names that are valid Python identifiers)
# df["a"]

# A DataFrame can be aliased too:
# customers = customers.alias('c')
# then columns can be qualified with the alias: col("c.id")

# selectExpr: each argument is a SQL expression string
df.selectExpr(
    "name",
    "age",
    "CASE WHEN age >= 35 THEN 'Senior' ELSE 'Junior' END as age_category"
).show()


# Select with expressions: the Column API and expr() (a SQL string) give the same result
df.select(
    col("name"),
    (col("price") * 1.1).alias("price_with_tax"),
    expr("price * 1.1 as price_with_tax2")
)

# Filter / Where (where is an alias of filter); accepts a SQL string or a boolean Column
df.filter("age > 21")
df.filter(col("age") > 21)
df.filter(col("age").between(18, 30)) # inclusive on both ends
df.where(col("status") == "active")
df.where((col("age") > 21) & (col("country") == "USA")) # & AND - each condition needs its own parentheses
df.where((col("age") > 21) | (col("country") == "USA")) # | OR
df.where(~col("first_name").isin("Jill", "Eva")) # ~ NOT
df.where(col("country").isin("USA","INDIA"))
df.where(col("name").like("J%")) # SQL LIKE pattern
df.where(col("name").rlike("^J.*")) # Regex Like
df.where(col("name").contains("ill")) # '%ill%'
df.where(col("name").startswith("Ji"))
df.where(col("name").endswith("va"))

df.where("CallType is not null")

# Add new column. In withColumn(), the expression you provide must evaluate to exactly one value per row.
# withColumn replaces an existing column that has the same name.
df.withColumn("new_col", lit("constant_value"))
df.withColumn("price_doubled", col("price") * 2)
df.withColumn("full_name", concat(col("first_name"), lit(" "), col("last_name")))
df.withColumns({"age1": lit(10), "age2": lit(20)}) # withColumns(*colsMap): dict of {new column name: Column}

# Rename column(s) - NOTE: the method names are camelCase (df.withColumnRenamed, df.withColumnsRenamed)
df.withColumnRenamed("old_name", "new_name") # withColumnRenamed(existing, new)
df.withColumnsRenamed({"id": "identifier", "name": "full_name"}) # withColumnsRenamed(colsMap)

# Drop columns: drop(*cols) takes column names or Column objects
df.drop("col1", "col2")
df.drop(col("col1"))

# Drop rows: dropna(how='any', thresh=None, subset=None)
# how: 'any' (drop a row if any value is null) or 'all' (drop a row only if every value is null)
# thresh: int - drop rows that have fewer than thresh non-null values. Overrides the how parameter.
# subset: optional list of column names to consider.
df.dropna()               # how='any'
df.dropna(how='all')
df.dropna(thresh=2)       # keeps rows with at least 2 non-null values; rows with 0 or 1 non-null values are dropped
df.dropna(subset=["name", "age"]).show()          # keeps only rows where both name and age are not null

# df.na returns DataFrameNaFunctions
# DataFrameNaFunctions.drop(how='any', thresh=None, subset=None) returns a new DataFrame omitting rows with null or NaN values.
# DataFrame.dropna() and DataFrameNaFunctions.drop() are aliases of each other.
df.na.drop(how="any")
df.na.fill(50) # a numeric value fills nulls only in numeric columns (likewise a string value fills only string columns)
df.na.fill({'age': 50, 'name': 'unknown'}) # per-column fill values
df.na.fill(50, subset=["age"])

# fillna(value, subset=None) - same as df.na.fill
# value: value to replace null values with (int/float/str/bool, or a dict of column -> value).
# subset: list of column names to consider.
df.fillna("N/A", subset=["name", "city"])
df.fillna("default_value") # only fills columns whose data type matches the value (here: string columns)

# Drop duplicates: dropDuplicates(subset=None), subset = list of column names
df.dropDuplicates() # drop_duplicates is the same method
df.dropDuplicates(["col1", "col2"]) # keeps one row per (col1, col2) combination
df.distinct() # same as dropDuplicates() with no subset

# Sort / OrderBy (orderBy is an alias of sort)
df.sort("col1", "col2")
df.sort(col("col1").asc(), col("col2").desc())
df.orderBy("count", ascending=False)

df.sort("x", desc("y"))
df.sort(asc("x"), desc("y"))
df.orderBy(col("x").asc(), col("y").desc())
df.orderBy(["x", "y"], ascending=[True, False])
# Explicit null placement. asc_nulls_first etc. live in pyspark.sql.functions
# (they are not in this cell's import line), or use the Column methods df.a.asc_nulls_last() etc.
'''
| Function                  | Sort Direction | Null Position | Usage Example                       |
| ------------------------- | -------------- | ------------- | ----------------------------------- |
|   asc_nulls_first(col)    | Ascending      | Nulls first   | df.orderBy(asc_nulls_first("age"))  |
|   asc_nulls_last(col)     | Ascending      | Nulls last    | df.orderBy(asc_nulls_last("age")), df.orderBy(df.a.asc_nulls_last()), df.sort(df.a.asc_nulls_last()), df.sort(asc_nulls_last("a"))  |
|   desc_nulls_first(col)   | Descending     | Nulls first   | df.orderBy(desc_nulls_first("age")) |
|   desc_nulls_last(col)    | Descending     | Nulls last    | df.orderBy(desc_nulls_last("age"))  |
'''

# Limit: the first n rows as a new DataFrame (a transformation, unlike take/head)
df.limit(10)

# DataFrame.sample(withReplacement=None, fraction=None, seed=None)
# Returns a sampled subset of this DataFrame - handy for creating smaller portions of a large dataset for testing, development, or analysis.
# fraction : float, optional - probability of including each row, in [0.0, 1.0]. It is a probability, not an exact count: fraction=0.1 means each row has a 10% chance to be included.
# withReplacement : may the same row be picked more than once?
# False (default) - each selected row appears at most once, like drawing an item without putting it back.
# True - rows can be selected multiple times, like picking an item, putting it back, and picking again.
# seed : fixes the randomness so the result is repeatable
df.sample(fraction=0.1, seed=42)
df.sample(withReplacement=False, fraction=0.2)

# randomSplit(weights, seed=None) randomly splits this DataFrame with the provided weights (normalized if they don't sum to 1).
# Split DataFrame into training (~70%) and testing (~30%) - the split sizes are approximate
train_df, test_df = df.randomSplit([0.7, 0.3], seed=42)

# ============================================================================
# COLUMN OPERATIONS
# ============================================================================

# df.columns is an ordinary Python list of column-name strings, so it can be
# filtered like any list. Keep every column except "age", in the original order.
# (The loop variable is not called `col` to avoid confusion with the imported col().)
selected_cols = [name for name in df.columns if name != "age"]
df.select(selected_cols).show()
# +-----+-----+
# | name|state|
# +-----+-----+
# |  Tom|   CA|
# |Alice|   NY|
# |  Bob|   TX|
# +-----+-----+

# Conditional expressions: when(...).when(...).otherwise(...) works like SQL CASE WHEN
from pyspark.sql.functions import when
df.withColumn("category",
    when(col("age") < 18, "minor")
    .when((col("age") >= 18) & (col("age") < 65), "adult")
    .otherwise("senior")
)

# Null handling
from pyspark.sql.functions import coalesce
# coalesce(*cols) returns the first non-null value among its arguments.
# Related SQL functions (usable e.g. through expr()):
# ifnull(a, b) and nvl(a, b) return b when a is NULL, otherwise a - they are functionally identical.
# nullif(a, b) does not replace NULLs; it returns NULL when a equals b, otherwise a.
# nvl2(a, b, c) does not compare values; it returns b when a is NOT NULL and c when a IS NULL.

df.withColumn("col_filled", coalesce(col("col1"), col("col2"), lit("default")))
df.filter(col("col1").isNull())       # isNull()/isNotNull() are Column methods, so nothing
df.filter(col("col1").isNotNull())    # beyond col() has to be imported for them
df.na.drop()  # Drop rows with any null
df.na.drop(subset=["col1", "col2"])  # Drop rows with null in specific columns
df.na.fill(0)  # Fill nulls with 0 (numeric columns only)
df.na.fill({"col1": 0, "col2": "unknown"})  # Fill specific columns

# String operations
from pyspark.sql.functions import upper, lower, trim, length, substring, concat_ws, split
df.withColumn("upper_name", upper(col("name")))
df.withColumn("lower_name", lower(col("name")))
df.withColumn("trimmed", trim(col("name")))
df.withColumn("name_length", length(col("name"))) # length of the string value
df.withColumn("first_3_chars", substring(col("name"), 1, 3)) # substring(str, pos, len): pos is 1-based
df.withColumn("concatenated", concat_ws("-", col("col1"), col("col2"))) # joins the values with the "-" separator

# split(str, pattern, limit=-1)
# str : Column or column name to split
# pattern : string holding a Java regular expression (so regex characters such as "." or "|" must be escaped)
# limit : int that controls the number of times the pattern is applied.
#   limit > 0 : the resulting array's length will not be more than limit, and the array's last entry contains all input beyond the last matched pattern.
#   limit <= 0 (default): the pattern is applied as many times as possible, and the resulting array can be of any size.
df.withColumn("split_array", split(col("address"), ","))
df.withColumn("first_element", split(col("address"), ",").getItem(0)) # getItem is 0-based

# Date operations
# Pattern letters follow Spark's datetime patterns: https://spark.apache.org/docs/latest/sql-ref-datetime-pattern.html
from pyspark.sql.functions import to_date, to_timestamp, year, month, dayofmonth, weekofyear, typeof, date_format, date_add, date_sub, from_unixtime, unix_timestamp, to_utc_timestamp, from_utc_timestamp, convert_timezone
df.withColumn("date_col", to_date(col("string_date"), "dd-MM-yyyy")) # parses a string into a DateType using the given pattern
df.withColumn("timestamp_col", to_timestamp(col("string_timestamp"), "yyyy-MM-dd HH:mm:ss"))
df.withColumn("year", year(col("date_col")))
df.withColumn("month", month(col("date_col")))
df.withColumn("day", dayofmonth(col("date_col")))
df.withColumn("week_number", weekofyear(col("date_col")))
df.select(typeof('dt'), date_format('dt', 'yy--MM--dd')) # date | 15--04--08 # date_format renders a date/timestamp/string as a string in the given pattern
df.withColumn("date_plus_7", date_add(df.dt, 7))   # date_add(start, days): the date `days` days after start
df.withColumn("date_minus_1", date_sub(df.dt, 1))  # date_sub(start, days): the date `days` days before start (a negative value adds days)
# from_unixtime(timestamp, format='yyyy-MM-dd HH:mm:ss'): epoch seconds -> formatted string in the session time zone.
# unix_timestamp(timestamp, format='yyyy-MM-dd HH:mm:ss'): the inverse, string/timestamp -> epoch seconds.
# (The example values below assume session time zone America/Los_Angeles.)
df.withColumn("timestamp_from_epoch", from_unixtime('unix_time')) # 1428476400 -> 2015-04-08 00:00:00
df.withColumn("unix_time", unix_timestamp('timestamp_from_epoch')) # 2015-04-08 00:00:00 -> 1428476400
df.withColumn("utc_timestamp", to_utc_timestamp(col("datetime_utc"), "Asia/Tokyo")) # treats the timestamp as Tokyo local time and converts it to UTC
df.withColumn("Tokyo_timestamp", from_utc_timestamp(col("datetime_utc"), "Asia/Tokyo")) # converts a UTC timestamp into Tokyo local time
# convert_timezone(sourceTz, targetTz, sourceTs): converts the timestamp without time zone sourceTs
# from the sourceTz time zone to targetTz. Both zones are passed as Columns, hence lit().
df.withColumn('HK_2_LA_timestamp', convert_timezone(lit('Asia/Hong_Kong'), lit('America/Los_Angeles'), df.ts))

# Cast / Type conversion: cast() accepts either a type-name string or a DataType instance
from pyspark.sql.types import IntegerType, DoubleType
df.withColumn("col_as_int", col("col1").cast("int"))
df.withColumn("col_as_int2", col("col1").cast(IntegerType())) # same as cast("int")
df.withColumn("col_as_double", col("col1").cast(DoubleType()))

# Array and Map operations NOTE ALL ARE LOWERCASE INSIDE FUNCTIONS
from pyspark.sql.functions import flatten, array, size, struct, sort_array, collect_list, explode, array_union, array_append, create_map, map_concat, map_keys, array_contains

# flatten creates a single array from an array of arrays. If nesting is deeper than two levels, only one level of nesting is removed.
df.select(flatten(df.data)).show() # [None, [4, 5]] -> NULL   ||   [[1, 2, 3], [4, 5], [6]] -> [1, 2, 3, 4, 5, 6]
df.withColumn("array_col", array(col("col1"), col("col2"), col("col3"))) # [col1_val, col2_val, col3_val] - access: df.select(col("array_col")[0], col("array_col")[1]).show()
df.withColumn("size", size(col("array_col"))) # 3
df.withColumn("struct_col", struct(col("col1"), col("col2"), col("col3"))) # {col1_val, col2_val, col3_val} - access: df.select("struct_col.col1", "struct_col.col2").show()
df.withColumn("address", struct(col("address.city"), col("address.zip").alias("zipcode"))) # rebuilds the struct with only these fields, renaming zip to zipcode

# Struct column - three equivalent ways to read a field
df.select(col("address.city")).show()
df.select("address.city").show()
df.select(col("address")["city"]).show()

# Adding a field to an existing struct:
# WRONG: df.withColumn("struct_col.col4", lit(4)) does NOT touch struct_col; withColumn takes the
#        name literally and adds a new top-level column named "struct_col.col4".
# RIGHT (Spark 3.1+): Column.withField adds (or replaces) a field inside the struct.
df.withColumn("struct_col", col("struct_col").withField("col4", lit(4)))
# Alternative: rebuild the struct from all its existing fields plus the new one.
df.withColumn("address", struct(col("address.*"), lit("country").alias("country")))
# For a map column use map_concat, for an array column use array_append (both shown below).

# sort_array(col, asc=True) - Sorts the input array in asc or desc order according to the natural ordering of the array elements.
# Null elements are placed at the beginning of the returned array in ascending order, or at the end in descending order.
# collect_list(col) - Collects the values from a column into a list, keeping duplicates.
df = spark.createDataFrame([(2,), (5,), (5,)], ('age',))
df.select(sort_array(collect_list('age'), asc=False).alias('sorted_list')).show()   # [5, 5, 2]

# explode on an ARRAY column yields a single output column, so withColumn works and
# supplies the name itself (no .alias() needed).
df.withColumn("exploded", explode(col("array_col")))
'''
explode takes a column containing an array or map and turns each element into a separate row.
  *  If the column is an array, you get one row per element (ONE output column).
  *  If the column is a map, you get one row per key-value pair (TWO output columns: key, value).
  *  Rows whose array/map is NULL or empty produce no output rows (use explode_outer to keep them).

# ARRAY
data = [
    (1, ["apple", "banana", "orange"]),
    (2, ["kiwi"]),
    (3, [])
]
df = spark.createDataFrame(data, ["id", "fruits"])
df.show(truncate=False)
    +---+-----------------------+
    |id |fruits                 |
    +---+-----------------------+
    |1  |[apple, banana, orange]|
    |2  |[kiwi]                 |
    |3  |[]                     |
    +---+-----------------------+

df2 = df.withColumn("fruit", explode("fruits"))
df2.show(truncate=False)
    +---+-----------------------+------+
    |id |fruits                 |fruit |
    +---+-----------------------+------+
    |1  |[apple, banana, orange]|apple |
    |1  |[apple, banana, orange]|banana|
    |1  |[apple, banana, orange]|orange|
    |2  |[kiwi]                 |kiwi  |
    +---+-----------------------+------+
    (id 3 disappears because its array is empty)

# MAP
data = [
    (1, {"a": 10, "b": 20}),
    (2, {"x": 5})
]
df = spark.createDataFrame(data, ["id", "scores"])
# explode on a map produces TWO columns, so use select with two aliases
# (withColumn accepts exactly one output column):
df2 = df.select("id", explode("scores").alias("key", "value"))
df2.show()
    +---+---+-----+
    | id|key|value|
    +---+---+-----+
    |  1|  a|   10|
    |  1|  b|   20|
    |  2|  x|    5|
    +---+---+-----+
# df.withColumn("kv", explode("scores")) fails: the generator outputs 2 columns but only 1 name is given.

posexplode also gives the position (index); it outputs two columns (pos, col), so use select as well:
# df.select("id", posexplode("fruits").alias("pos", "fruit")).show()

explode_outer / posexplode_outer keep rows whose array/map is NULL or empty (the new column(s) are NULL):
# df.select("id", explode_outer("fruits").alias("fruit")).show()   -> id 3 is kept with fruit = NULL
'''
# Index into an array column with [] / getItem (0-based)
df.withColumn("first_element", col("array_col")[0])

# array_union(col1, col2) - returns a new array containing the union of elements in col1 and col2, without duplicates.
df.select(array_union(df.c1, df.c2)).show()
# array_append(col, value) - returns the array with value added at the end.
df_with_new_fruit = df.withColumn("fruits", array_append(df["fruits"], "kiwi"))
'''
+---+---------------------+           +---+-------------------------+
|id |fruits               |           |id |fruits                   |
+---+---------------------+           +---+-------------------------+
|1  |[apple, banana]      |           |1  |[apple, banana, kiwi]    |
|2  |[orange, grape]      |           |2  |[orange, grape, kiwi]    |
+---+---------------------+           +---+-------------------------+
'''
# element_at(col, extraction)
# Returns the element of the array at the given 1-based index (unlike [] / getItem, which are 0-based).
# If the index is 0, Spark throws an error. If the index is < 0, elements are accessed from the last to the first.
from pyspark.sql.functions import element_at  # not part of the array/map import line above
df.select(element_at(col("numbers"), 1).alias("first_element")).show()

# create_map(*cols): creates a new map column from an even number of input columns (key1, value1, key2, value2, ...).
# map_concat(*cols): returns the union of all given maps.
# NOTE: if the same key appears in more than one map, Spark raises an error by default
# (spark.sql.mapKeyDedupPolicy=EXCEPTION); set it to LAST_WIN to keep the last value instead.
df.withColumn("attributes", map_concat(col("attributes"), create_map(lit("C"), lit(3))))
'''
+---+----------------+            +---+------------------------+
| id|      attributes|            |id |attributes              |
+---+----------------+            +---+------------------------+
|  1|{A -> 1, B -> 2}|            |1  |{A -> 1, B -> 2, C -> 3}|
|  2|{A -> 2, B -> 3}|            |2  |{A -> 2, B -> 3, C -> 3}|
|  3|{A -> 3, B -> 4}|            |3  |{A -> 3, B -> 4, C -> 3}|
+---+----------------+            +---+------------------------+
'''
# map_keys returns an unordered array containing the keys of the map.
df.select(map_keys(col("attributes")).alias("keys")).show()

# array_contains checks whether a value exists in an array column
df.filter(array_contains(col("skills"), "Python"))


# ============================================================================
# AGGREGATIONS
# ============================================================================

from pyspark.sql.functions import count, sum, avg, min, max, countDistinct, round, expr, approx_count_distinct, percentile_approx
# NOTE: this import shadows Python's built-in sum/min/max/round for the rest of the notebook.

# Simple aggregations over the whole DataFrame: WIDE TRANSFORMATION
df.select(
    count("*").alias("total_count"),
    sum("Quantity").alias("total_quantity"),
    avg("UnitPrice").alias("avg_price"),
    min("UnitPrice").alias("min_price"),
    max("UnitPrice").alias("max_price"),
    countDistinct("InvoiceNo").alias("distinct_invoices"),
    approx_count_distinct("InvoiceNo").alias("approx_distinct_invoices"),
    # percentile_approx(col, percentage, accuracy=10000): approximate 95th percentile of UnitPrice
    percentile_approx("UnitPrice", 0.95, 10000).alias("p95_unit_price"),
).show()

# Aggregate on the entire DataFrame without groups (shorthand for df.groupBy().agg())
product_df.agg(max("price").alias("max_price"), sum("price")).show()
product_df.agg({"age": "max"}).show()  # dict form: {column name: aggregate function name}

# GroupedData.sum("salary") already returns a DataFrame (with a "sum(salary)" column),
# and a DataFrame has no .avg() method, so df.groupBy(...).sum(...).avg(...) fails.
# Request several aggregates for the same groups in a single agg() call instead:
df.groupBy("department").agg(sum("salary"), avg("salary"))

# GroupBy : WIDE TRANSFORMATION
# groupBy() returns a GroupedData object, not a DataFrame; calling agg()/count()/sum()/... on it returns a DataFrame.
# GroupBy with aggregations
summary_df = df.groupBy("Country", "InvoiceNo") \
    .agg(
        sum("Quantity").alias("TotalQuantity"),
        round(sum(expr("Quantity * UnitPrice")), 2).alias("InvoiceValue"),
        count("*").alias("num_items"),
        avg("UnitPrice").alias("avg_price")
    )

df.groupBy().sum().collect()[0][0]
# groupBy() with no columns creates a global aggregation (the whole DataFrame is one group).
# .sum() with no arguments sums every numeric column.
# .collect() returns a list of Row objects (here a single Row).
# [0][0] extracts the first column of that Row, i.e. the sum of the first numeric column.

# Multiple aggregations defined separately, then passed together to agg()
num_invoices = countDistinct("InvoiceNo").alias("NumInvoices")
total_quantity = sum("Quantity").alias("TotalQuantity")
invoice_value = expr("round(sum(Quantity * UnitPrice), 2) as InvoiceValue")

# to_date / weekofyear come from the date-operations import earlier in this cell
agg_df = df \
    .withColumn("InvoiceDate", to_date(col("InvoiceDate"), "dd-MM-yyyy H.mm")) \
    .where("year(InvoiceDate) == 2010") \
    .withColumn("WeekNumber", weekofyear(col("InvoiceDate"))) \
    .groupBy("Country", "WeekNumber") \
    .agg(num_invoices, total_quantity, invoice_value)

# Count by group (shorthand): adds a "count" column
df.groupBy("CallType").count().orderBy("count", ascending=False)

# Pivot: one output column per distinct "Year" value
df.groupBy("Country").pivot("Year").sum("Revenue")

# ============================================================================
# WINDOW FUNCTIONS (Value per window)
# ============================================================================
# Unlike groupBy, window functions keep every input row and attach a per-row result.

from pyspark.sql import Window
from pyspark.sql.functions import sum, avg, row_number, rank, dense_rank, lag, lead

# Window.partitionBy : WIDE TRANSFORMATION (rows with the same partition key are brought together)
# Define window specifications
window_spec = Window.partitionBy("Country").orderBy("WeekNumber")

# Running total: frame from the first row of the partition up to the current row
running_total_window = Window \
    .partitionBy("Country") \
    .orderBy("WeekNumber") \
    .rowsBetween(Window.unboundedPreceding, Window.currentRow)

# Unbounded window (entire partition; no ordering)
unbounded_window = Window.partitionBy("Country")

# Range between (value-based): bounds are offsets on the ORDER BY value, not row counts.
# -7..0 covers the current "date" value and up to 7 units before it (inclusive at both ends).
# NOTE(review): if "date" is a DateType column, Spark may reject these integer offsets;
# ordering by a numeric day number is a common workaround - verify on your data.
range_window = Window \
    .partitionBy("product") \
    .orderBy("date") \
    .rangeBetween(-7, 0)  # the current day plus the 7 preceding days

# Apply window functions
df_with_windows = df.withColumn("RunningTotal", sum("InvoiceValue").over(running_total_window))

df_with_windows = df.withColumn("row_number", row_number().over(window_spec)) # 1, 2, 3, ... (ties get different numbers)

df_with_windows = df.withColumn("rank", rank().over(window_spec)) # ties share a rank and leave gaps (1, 1, 3)

df_with_windows = df.withColumn("dense_rank", dense_rank().over(window_spec)) # ties share a rank, no gaps (1, 1, 2)

df_with_windows = df.withColumn("lag_value", lag("value", 1).over(window_spec)) # value from the previous row (null for the first row)

df_with_windows = df.withColumn("lead_value", lead("value", 1).over(window_spec)) # value from the next row (null for the last row)

df_with_windows = df.withColumn("cumulative_sum", sum("amount").over(running_total_window))

df_with_windows = df.withColumn("avg_per_country", avg("amount").over(unbounded_window))
# Same result without a window function:
# Step 1: aggregate per country
# country_avg_df = df.groupBy("Country").agg(avg("amount").alias("avg_per_country"))
# Step 2: join back to original DataFrame
# df_without_window = df.join(country_avg_df, on="Country", how="left")

# ============================================================================
# UNION AND INTERSECTION
# ============================================================================

# union/unionByName just concatenate the partitions of both inputs (no shuffle -> narrow);
# intersect/subtract/exceptAll must compare rows across both inputs, which needs a shuffle (wide).
# Union (combines DataFrames, keeps duplicates like SQL UNION ALL; columns are matched by POSITION)
union_df = df1.union(df2)
union_df = df1.unionAll(df2)  # alias of union() in PySpark 3.x (older docs marked it deprecated); prefer union()

# Union by name (matches columns by name, not position)
union_df = df1.unionByName(df2)
union_df = df1.unionByName(df2, allowMissingColumns=True) # columns missing on one side are filled with null

# Intersection (common rows); duplicates are removed. To preserve duplicates use intersectAll().
intersection_df = df1.intersect(df2)

# Except / Subtract (rows in df1 but not in df2)
except_df = df1.subtract(df2) # SQL EXCEPT (DISTINCT) - set(df1) - set(df2)
except_df = df1.exceptAll(df2)  # Keeps duplicates # SQL EXCEPT ALL - if a row appears n times in df1 and m times in df2, the output contains it max(n - m, 0) times.



SyntaxError: invalid syntax (ipython-input-3366183906.py, line 123)

---
**Action**

---

In [None]:
# collect(): returns all the rows of the DataFrame to the driver as a list of Row objects. Useful for debugging, but use with caution: it may cause out-of-memory errors on the driver if the DataFrame is large.
rows = df.collect()

# count(): returns the number of rows in the DataFrame.
row_count = df.count()

# first(): returns the first row of the DataFrame as a Row (None if the DataFrame is empty).
first_row = df.first()

# head(n=None): returns the first n rows if available (at most n rows).
# Only use it when the result is expected to be small, as all the returned data is loaded into the driver's memory.
# n is optional: head() with no argument returns a single Row (same as first()),
# while head(n) with an integer n always returns a list of Row objects - even for n = 1.
head_rows = df.head(5)

# take(n): returns the first n rows of the DataFrame. Always returns a list of Row objects.
take_rows = df.take(10)

# df.show(n=20, truncate=True, vertical=False) https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrame.show.html
# n: number of rows to show.
# truncate: bool or int, default True | True truncates strings longer than 20 chars. An int greater than one truncates long strings to that length and right-aligns the cells.
# vertical: bool, default False | True prints output rows vertically (one line per column value).
df.show()
df.show(20, truncate=False)
df.show(n=10, vertical=True)

# foreach(func): applies func to each Row of the DataFrame. Useful for per-row side effects such as
# saving the row to an external database or updating an accumulator.
# func runs on the executors, not in the driver, so print() output may not show up in the notebook.
df = spark.createDataFrame( [(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"] )
def func(person):
    print(person.name)
df.foreach(func)

# reduce(func): an RDD action (a DataFrame has no reduce(), hence the .rdd below) that folds all rows
# into a single value by applying func, which takes two rows and returns a single row:
# ((((row1 + row2) + row3) + row4) ... )
# func must be associative and commutative, because partitions are reduced separately and then combined.
from pyspark.sql import Row  # Row was used below without being imported

df = spark.createDataFrame([
    (1, "A"),
    (2, "B"),
    (3, "C")
], ["value", "label"])

def add_rows(r1, r2):
    """Combine two rows into a new Row whose `value` is the sum of both `value` fields."""
    return Row(value=r1.value + r2.value)

result = df.select("value").rdd.reduce(add_rows)
print(result) # Row(value=6)

# toLocalIterator(): returns an iterator over all rows of this DataFrame. The iterator consumes as much
# driver memory as the largest partition in this DataFrame.
# With prefetchPartitions=True it may consume up to the memory of the 2 largest partitions.
df = spark.createDataFrame( [(14, "Tom"), (23, "Alice"), (16, "Bob")], ["age", "name"] )
list(df.toLocalIterator())
# [Row(age=14, name='Tom'), Row(age=23, name='Alice'), Row(age=16, name='Bob')]

# toPandas(): collects the entire DataFrame to the driver as a pandas DataFrame -> heavy and risky for large data.
pandas_df = df.toPandas()

---
**Utility Methods**

---

In [None]:
# ============================================================================
# 15. CACHING AND PERSISTENCE
# ============================================================================

# Cache (memory first, spilling to disk)
df.cache() # shortcut for df.persist(StorageLevel.MEMORY_AND_DISK_DESER)
df.persist()  # Same as cache()
# Note: caching is lazy - an action such as df.count(), which touches every partition, is needed to fill the cache.
# df.take(n) only computes and caches the partition(s) needed to fetch n rows -> NOT a full cache.

# Persist with storage level
from pyspark import StorageLevel
# pyspark.StorageLevel(useDisk, useMemory, useOffHeap, deserialized, replication=1)
# StorageLevel.MEMORY_AND_DISK_2 is the same as MEMORY_AND_DISK but replicates each partition to two cluster nodes.
# NOTE: the four calls below are alternatives, not a sequence. Persisting a DataFrame that is already
# cached does not change its storage level (Spark only logs a warning) - call df.unpersist() first.
df.persist(StorageLevel.MEMORY_AND_DISK_DESER) # default
df.persist(StorageLevel.MEMORY_AND_DISK)
df.persist(StorageLevel.DISK_ONLY)
df.persist(StorageLevel.MEMORY_ONLY)
'''
| StorageLevel              | useDisk | useMemory | useOffHeap | deserialized | replication |
| ------------------------- | :-----: | :-------: | :--------: | :----------: | :---------: |
|   NONE                    |  False  |   False   |    False   |     False    |      0      |
|   DISK_ONLY               |   True  |   False   |    False   |     False    |      1      |
|   DISK_ONLY_2             |   True  |   False   |    False   |     False    |      2      |
|   DISK_ONLY_3             |   True  |   False   |    False   |     False    |      3      |
|   MEMORY_ONLY             |  False  |    True   |    False   |     False    |      1      |
|   MEMORY_ONLY_2           |  False  |    True   |    False   |     False    |      2      |
|   MEMORY_AND_DISK         |   True  |    True   |    False   |     False    |      1      |
|   MEMORY_AND_DISK_2       |   True  |    True   |    False   |     False    |      2      |
|   MEMORY_AND_DISK_DESER   |   True  |    True   |    False   |     True     |      1      |
|   OFF_HEAP                |  True   |    True   |    True    |     False    |      1      |
'''
# NOTE(review): PySpark appears to define StorageLevel.NONE without an explicit replication,
# i.e. the default replication=1 rather than the 0 shown above - verify with StorageLevel.NONE.replication.

# Unpersist: removes the cached data
df.unpersist()

# Check if cached: PySpark sets this flag in cache()/persist() and clears it in unpersist().
# df.storageLevel shows the actual storage level.
df.is_cached

# ============================================================================
# 16. OTHER UTILITY FUNCTIONS
# ============================================================================

# printSchema(): prints the schema in tree format. Optionally takes how many levels to print if the schema is nested.
df.printSchema()
# root
#  |-- DEST_COUNTRY_NAME: string (nullable = true)
#  |-- ORIGIN_COUNTRY_NAME: string (nullable = true)
#  |-- count: long (nullable = false)

# Compact one-line representation of the schema
df.schema.simpleString()
# struct<DEST_COUNTRY_NAME:string,ORIGIN_COUNTRY_NAME:string,count:bigint>

# df.schema returns the schema of this DataFrame as a pyspark.sql.types.StructType.
schema = df.schema
# StructType([StructField('age', LongType(), True), StructField('name', StringType(), True)])

# df.columns retrieves the names of all columns as a list, in the same order as in the DataFrame.
columns = df.columns
# ['age', 'name', 'state']

# df.dtypes returns all column names and their data types as a list of (name, type) tuples.
dtypes = df.dtypes
#[('age', 'bigint'), ('name', 'string')]

# DataFrame.describe(*cols) computes basic statistics for numeric and string columns:
# count, mean, stddev, min and max. If no columns are given, it computes statistics for all numeric or string columns.
# Returns a new DataFrame that describes (provides statistics for) the given DataFrame.
df.describe(['age', 'weight', 'height']).show()
# +-------+----+------------------+-----------------+
# |summary| age|            weight|           height|
# +-------+----+------------------+-----------------+
# |  count|   3|                 3|                3|
# |   mean|12.0| 40.73333333333333|            145.0|
# | stddev| 1.0|3.1722757341273704|4.763402145525822|
# |    min|  11|              37.8|            142.2|
# |    max|  13|              44.1|            150.5|
# +-------+----+------------------+-----------------+

# DataFrame.summary(*statistics) computes the specified statistics for numeric and string columns.
# Available statistics: count, mean, stddev, min, max, and arbitrary approximate percentiles given as a percentage (e.g. "75%").
df.select("age", "weight", "height").summary("count", "min", "25%", "75%", "max").show()
# +-------+---+------+------+
# |summary|age|weight|height|
# +-------+---+------+------+
# |  count|  3|     3|     3|
# |    min| 11|  37.8| 142.2|
# |    25%| 11|  37.8| 142.2|
# |    75%| 13|  44.1| 150.5|
# |    max| 13|  44.1| 150.5|
# +-------+---+------+------+


# DataFrame.explain(extended=None, mode=None) prints the (logical and physical) plans to the console for debugging.
#
# extended (bool, default = False)
#   False -> prints only the physical plan
#   If a string is passed, it's treated as the mode
#
# mode (str) - output format:
#   simple -> physical plan only
#   extended -> logical + physical plans
#   codegen -> physical plan + generated code (if available)
#   cost -> logical plan + statistics (if available)
#   formatted -> physical plan outline + detailed node info
df.explain(extended=True)

# The internal Java DataFrame backing your PySpark DataFrame
#   _jdf = Java DataFrame
#   Exposes Spark's JVM internals via Py4J
#   Not part of the public API (meant for debugging / internals)
df._jdf.queryExecution().toString()

# BOTH WILL RESULT IN

# == Parsed Logical Plan ==
# ...
# == Analyzed Logical Plan ==
# ...
# == Optimized Logical Plan ==
# ...
# == Physical Plan ==
# ...


---
**SALTING** (Joins + Aggregation https://chatgpt.com/c/6947f284-197c-8321-a6a2-41eefded3822)

---

In [None]:
from pyspark.sql import functions as F
from pyspark.sql.window import Window
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType
import random

"""
Salting is a technique to handle data skew by:
1. Adding a random "salt" value to skewed keys
2. Distributing hot keys across multiple partitions
3. Processing in parallel instead of single partition
4. Aggregating results after processing

USES:
- Handling skewed joins
- Fixing aggregation bottlenecks
- Improving shuffle performance
- Balancing partition sizes
"""
# ============================================================================
# 1. DATA SKEW - THE PROBLEM
# ============================================================================

# Example: a deliberately skewed dataset of 10,000 rows.
# 8,000 rows (80%) all share the single key "hot_key" ...
skewed_data = [("hot_key", f"value_{i}", i) for i in range(8000)]
# ... while the remaining 2,000 rows each get their own key ("key_0" .. "key_1999").
skewed_data.extend((f"key_{i}", f"value_{i}", i) for i in range(2000))

skewed_df = spark.createDataFrame(skewed_data, ["key", "value", "amount"])
# Initial partition count (local collection input) comes from spark.default.parallelism;
# after any wide transformation the count follows spark.sql.shuffle.partitions.

# ============================================================================
# 2. SALTING TECHNIQUE #1: Simple Random Salting
# ============================================================================
"""
Use Case: When you have skewed aggregations
Adds random suffix to distribute keys across partitions
"""

# rand(seed=None) - Generates a random column with independent and identically distributed (i.i.d.) samples uniformly distributed in [0.0, 1.0).
# seed - Seed value for the random generator.

def add_salt(df, key_col, salt_range=10, seed=None):
    """Append a random integer ``salt`` column and a ``<key_col>_salted`` key.

    Args:
        df: input DataFrame.
        key_col: name of the (possibly skewed) key column.
        salt_range: number of distinct salt values; salts are drawn uniformly
            from ``[0, salt_range)``.
        seed: optional seed for ``F.rand`` so the salt assignment is
            reproducible across runs (None = unseeded, the previous behavior).

    Returns:
        ``df`` with two extra columns: ``salt`` (int) and ``<key_col>_salted``
        (string ``"<key>_<salt>"``; null when the key itself is null).
    """
    return df.withColumn(
        "salt",
        # F.rand(...) already yields a Column, so no F.lit wrapper is needed.
        F.floor(F.rand(seed) * salt_range).cast("int"),
    ).withColumn(
        f"{key_col}_salted",
        # concat implicitly casts the int salt to string.
        F.concat(F.col(key_col), F.lit("_"), F.col("salt")),
    )

# Apply salting: "hot_key" is spread over up to 10 salted sub-keys
salted_df = add_salt(skewed_df, "key", salt_range=10)

# Stage 1 - partial aggregation on the salted key (hot key now split across tasks)
result = (
    salted_df
    .groupBy("key_salted")
    .agg(F.sum("amount").alias("total_amount"))
)

# Stage 2 - strip the trailing "_<salt>" suffix and combine the partial sums per key
final_result = (
    result
    .withColumn("original_key", F.regexp_replace(F.col("key_salted"), "_\\d+$", ""))
    .groupBy("original_key")
    .agg(F.sum("total_amount").alias("final_total"))
)

final_result.show()

# ============================================================================
# 3. SALTING TECHNIQUE #2: Salted Join
# ============================================================================
"""
Use Case: When joining tables with skewed keys
Expands smaller table and salts larger table for balanced join
"""

def salted_join(large_df, small_df, join_key, salt_range=10):
    """Inner-join a skewed ``large_df`` with ``small_df`` on ``join_key`` using salting.

    Steps:
    1. Replicate every small-side row once per salt value ``0 .. salt_range-1``.
    2. Give every large-side row one random salt in ``[0, salt_range)`` (via ``add_salt``).
    3. Equi-join on ``(join_key, salt)``: each large row matches exactly one replica,
       so the result equals a plain inner join on ``join_key``, but a hot key is
       spread over up to ``salt_range`` shuffle partitions.
    4. Drop the helper columns.

    Joining on column *names* makes Spark keep a single copy of ``join_key`` and
    ``salt``; the previous string-key join kept both sides' ``join_key`` columns,
    leaving duplicate (ambiguous) columns in the result.

    Returns:
        ``join_key`` followed by the remaining columns of ``large_df`` and ``small_df``.
    """
    # Step 1: explode small table (one copy per salt value)
    small_exploded = small_df.withColumn(
        "salt",
        F.explode(F.array([F.lit(i) for i in range(salt_range)])),
    )

    # Step 2: salt large table (adds "salt" and "<join_key>_salted")
    large_salted = add_salt(large_df, join_key, salt_range)

    # Step 3: join on (key, salt) by name -> no duplicated join columns
    joined = large_salted.join(small_exploded, on=[join_key, "salt"], how="inner")

    # Step 4: clean up helper columns
    return joined.drop("salt", f"{join_key}_salted")

# Example: Salted Join
orders = spark.createDataFrame([
    ("hot_key", "order1", 100),
    ("hot_key", "order2", 200),
    ("hot_key", "order3", 150),
    ("key_1", "order4", 300),
], ["customer_id", "order_id", "amount"])

customers = spark.createDataFrame([
    ("hot_key", "VIP Customer"),
    ("key_1", "Regular Customer"),
], ["customer_id", "customer_type"])

# Expected: the 4 orders, each with its customer_type, and a single customer_id column.
salted_join_result = salted_join(orders, customers, "customer_id", salt_range=3)
salted_join_result.show()

# ============================================================================
# 4. SALTING TECHNIQUE #3: Adaptive Salting
# ============================================================================
"""
Use Case: When only certain keys are skewed
Selectively salt only the hot keys instead of entire dataset
"""
def adaptive_salting(df, key_col, threshold=1000, salt_range=10, value_col="amount", seed=None):
    """Sum ``value_col`` per ``key_col``, salting only keys with more than ``threshold`` rows.

    1. Identify hot keys (row count > ``threshold``).
    2. Split the data into hot and cold rows.
    3. Hot rows: two-stage aggregation, first per (key, salt), then per key.
    4. Cold rows: plain aggregation per key.
    5. Union both results.

    Args:
        df: input DataFrame.
        key_col: grouping key column.
        threshold: a key is "hot" when it has strictly more rows than this.
        salt_range: number of salt buckets used for hot keys.
        value_col: numeric column to sum (previously hard-coded to "amount").
        seed: optional seed for ``F.rand`` (reproducible salt assignment).

    Returns:
        DataFrame with columns ``[key_col, value_col]``. ``key_col`` keeps its
        original data type.
    """
    # Identify hot keys
    key_counts = df.groupBy(key_col).count()
    hot_keys = key_counts.filter(F.col("count") > threshold).select(key_col)

    # Separate hot and cold data
    hot_data = df.join(hot_keys, key_col, "inner")
    cold_data = df.join(hot_keys, key_col, "left_anti")

    # Salt hot data and aggregate over (key, salt) first. Grouping on the
    # original key column plus the salt, instead of a concatenated string key
    # parsed back with a regex, keeps key_col's data type, so non-string keys
    # line up with cold_result in the union below.
    hot_salted = hot_data.withColumn("salt", F.floor(F.rand(seed) * salt_range).cast("int"))
    hot_result = (
        hot_salted
        .groupBy(key_col, "salt")
        .agg(F.sum(value_col).alias("partial_total"))
        .groupBy(key_col)
        .agg(F.sum("partial_total").alias(value_col))
    )

    # Process cold data
    cold_result = cold_data.groupBy(key_col).agg(F.sum(value_col).alias(value_col))

    # Union by column name so the result never depends on column order.
    return hot_result.unionByName(cold_result)

adaptive_result = adaptive_salting(skewed_df, "key", threshold=100, salt_range=5)
adaptive_result.show()

# ============================================================================
# 5. SALTING TECHNIQUE #4: Two-Stage Aggregation (Same as Simple Random Salting up above)
# ============================================================================
"""
Use Case: Heavy aggregations with multiple group-by columns
Reduces data before final aggregation
"""
def two_stage_aggregation(df, group_cols, agg_col, salt_range=10, seed=None):
    """Sum and count ``agg_col`` per ``group_cols`` with a salted two-stage aggregation.

    Stage 1: add a random ``salt`` in ``[0, salt_range)`` and aggregate per
    (group_cols, salt). Rows of a hot group are then spread over up to
    ``salt_range`` shuffle keys.
    Stage 2: combine the partial results per ``group_cols`` (salt dropped).

    Args:
        df: input DataFrame (an existing ``salt`` column would be overwritten).
        group_cols: a single column name or a list/tuple of column names.
        agg_col: numeric column to sum; the count ignores nulls.
        salt_range: number of salt buckets.
        seed: optional seed for ``F.rand`` (reproducible salt assignment).

    Returns:
        DataFrame with ``group_cols`` + ``total_sum`` + ``total_count``.
    """
    # Accept a single name or any sequence; `group_cols + ["salt"]` previously
    # raised TypeError for a str or tuple.
    group_cols = [group_cols] if isinstance(group_cols, str) else list(group_cols)

    # Stage 1: Add salt and partial aggregate
    salted = df.withColumn("salt", F.floor(F.rand(seed) * salt_range).cast("int"))

    partial_agg = salted.groupBy(*group_cols, "salt").agg(
        F.sum(agg_col).alias("partial_sum"),
        F.count(agg_col).alias("partial_count")
    )

    # Stage 2: Final aggregation
    final_agg = partial_agg.groupBy(*group_cols).agg(
        F.sum("partial_sum").alias("total_sum"),
        F.sum("partial_count").alias("total_count")
    )

    return final_agg

two_stage_result = two_stage_aggregation(skewed_df, ["key"], "amount", salt_range=5)
two_stage_result.show()

# ============================================================================
# 6. MONITORING & DEBUGGING SKEW
# ============================================================================
"""
Detecting Data Skew:
-------------------
"""

def analyze_partition_distribution(df, key_col, num_partitions=None):
    """Print and return record counts per partition and per key, to spot skew.

    Args:
        df: DataFrame to inspect.
        key_col: column whose value distribution is inspected.
        num_partitions: if given, ``df`` is first hash-repartitioned on
            ``key_col`` into this many partitions. This mimics the shuffle a
            join/groupBy on ``key_col`` performs, which is where key skew shows
            up. Without it, the counts reflect ``df``'s current partitioning
            (e.g. an evenly sliced local collection).

    Returns:
        ``(partition_stats, key_stats)`` DataFrames.
    """
    source = df if num_partitions is None else df.repartition(num_partitions, key_col)

    # Tag every row with the id of the partition it lives in
    with_partition = source.withColumn("partition_id", F.spark_partition_id())

    # Count records per partition
    partition_stats = (
        with_partition
        .groupBy("partition_id")
        .agg(F.count("*").alias("record_count"))
        .orderBy("record_count", ascending=False)
    )

    print("\n=== Partition Distribution ===")
    partition_stats.show()

    # Key distribution (independent of partitioning)
    key_stats = (
        df.groupBy(key_col)
        .agg(F.count("*").alias("count"))
        .orderBy("count", ascending=False)
    )

    print("\n=== Top Keys by Count ===")
    key_stats.show(10)

    return partition_stats, key_stats

# Analyze skew
# Current layout of skewed_df (a local list sliced evenly, so partitions look balanced):
analyze_partition_distribution(skewed_df, "key")
# Layout after a shuffle on "key" (what a groupBy/join on key produces): all
# "hot_key" rows hash to one partition, which holds at least 8,000 of the 10,000 rows.
analyze_partition_distribution(skewed_df, "key", num_partitions=10)

"""
SALTING BEST PRACTICES:
1. When to Use Salting:
   - Joins with skewed keys (>5% data in single key)
   - Heavy aggregations on hot keys
   - Tasks taking 3x+ longer than median
2. Salt Range Selection:
   - Small datasets: 5-10 salts
   - Medium datasets: 10-50 salts
   - Large datasets: 50-100 salts
   - Rule: sqrt(num_executors) * 2
3. Trade-offs:
   - Pros: Better parallelism, prevents OOM
   - Cons: Extra shuffle, increased data size
4. Alternatives to Consider:
   - Broadcast joins (for small tables)
   - Bucketing (for repeated joins)
   - Repartitioning with custom partitioner
   - AQE (Adaptive Query Execution) in Spark 3.0+
5. Monitoring:
   - Use Spark UI to identify stragglers
   - Check partition sizes in shuffle
   - Monitor spill metrics
"""

---
**COMMON PATTERNS AND BEST PRACTICES**

---

In [None]:
from pyspark.sql.functions import col, count, when

# 1. Check data quality: number of NULLs per column
#    (isNull() does not flag NaN in float/double columns; combine with isnan() for those)
df.select([count(when(col(c).isNull(), c)).alias(c) for c in df.columns]).show()

# 2. Find duplicates (rows identical across ALL columns, with how often they occur)
df.groupBy(df.columns).count().filter(col("count") > 1).show()

# 3. Sample data for testing (~1% of rows: fraction is a per-row probability, not an
#    exact size). A fixed seed keeps the sample identical across notebook re-runs.
sample_df = df.sample(fraction=0.01, seed=42)

# 4. Rename all columns (lowercase)
df_clean = df.select([col(c).alias(c.lower()) for c in df.columns])

# 5. Drop columns with all nulls (count(c) counts non-null values only)
non_null_counts = df.select([count(c).alias(c) for c in df.columns]).collect()[0]
cols_to_keep = [c for c in df.columns if non_null_counts[c] > 0]
df_filtered = df.select(cols_to_keep)

# DataFrame.stat Returns a DataFrameStatFunctions for statistic functions.

# approxQuantile(col, probabilities, relativeError)
# Calculates the approximate quantiles of numerical columns of a DataFrame.
df.stat.approxQuantile("price", [0.25, 0.5, 0.75], 0.01)

# corr(col1, col2[, method])
# Calculates the correlation of two columns of a DataFrame as a double value
# (only method="pearson" is supported).
df.stat.corr("col1", "col2")

# cov(col1, col2)
# Calculate the sample covariance for the given columns, specified by their names, as a double value.
df.stat.cov("col1", "col2")

# crosstab(col1, col2)
# Computes a pair-wise frequency table of the given columns.
df.stat.crosstab("category", "status").show()

# freqItems(cols[, support])
# Finding frequent items for columns, possibly with false positives.
df.stat.freqItems(["col1", "col2"]).show()

# sampleBy(col, fractions[, seed])
# Returns a stratified sample without replacement based on the fraction given on each stratum.
# Strata missing from `fractions` are sampled with fraction 0 (i.e. dropped).
df.stat.sampleBy("category", fractions={"A": 0.1, "B": 0.5}, seed=42)

---
**UDF**

---

In [None]:
# ============================================================================
# UDFs (User Defined Functions)
# ============================================================================
from pyspark.sql.functions import udf, expr

# Define UDF
def categorize_age(age):
    """Bucket an age into "minor" (< 18), "adult" (18 to < 65) or "senior" (>= 65).

    Returns None for a null age. Spark passes SQL NULLs to Python UDFs as None,
    and in Python 3 ``None < 18`` raises TypeError. That error would fail the
    whole Spark task instead of producing a null category.
    """
    if age is None:
        return None
    if age < 18:
        return "minor"
    if age < 65:
        return "adult"
    return "senior"

# Wrap as a DataFrame UDF - the Python function is serialized and shipped to the executors.
# Used for DataFrame API transformations. Does not create a catalog entry: it has local (Python variable) scope and cannot be used in SQL.
# udf(f=None, returnType=StringType(), *, useArrow=None)
# NOTE(review): assumes the current `df` has `name` and `age` columns - confirm against the cell that defines it.
categorize_udf = udf(categorize_age, StringType())

# Use UDF
df_with_category = df.withColumn("category", categorize_udf(col("age")))

# Register UDF for SQL - registers a temporary function in the session, so it can be used in any SQL query (or expr()) of this SparkSession; neither method persists the UDF beyond the current SparkSession.
# register(name, f, returnType=None)
spark.udf.register("categorize_age_sql", categorize_age, StringType())
df.createOrReplaceTempView("people")

spark.sql("SELECT name, categorize_age_sql(age) as category FROM people").show()
# A SQL-registered UDF can also be called from the DataFrame API through a SQL expression string.
df_with_category = df.withColumn("category", expr("categorize_age_sql(age)"))

# Pandas UDF (vectorized, more efficient)
from pyspark.sql.functions import pandas_udf
import pandas as pd

# Series to Series
@pandas_udf(DoubleType())
def multiply_by_two(col: pd.Series) -> pd.Series:
    """Series-to-Series pandas UDF: receives one batch of values and returns it doubled."""
    doubled = col * 2
    return doubled

df.withColumn("doubled", multiply_by_two(col("value"))).show()

from typing import Iterator

# Iterator of Series to Iterator of Series (for batching)
# The type hints are required. pandas_udf infers the UDF kind from them; without
# hints it falls back to a plain Series -> Series UDF, and this generator body
# would then fail at runtime.
@pandas_udf("double")
def multiply_batch(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
    """Iterator-of-Series pandas UDF: doubles every Arrow batch of a task.

    Expensive one-time setup (e.g. loading a model) can be placed before the
    loop so it runs once per task instead of once per batch.
    """
    for batch in iterator:
        yield batch * 2

# Usage: df.withColumn("doubled", multiply_batch(col("value")))

# Series to Scalar (aggregate)
@pandas_udf("double")
def mean_udf(s: pd.Series) -> float:
    """Grouped-aggregate pandas UDF: collapses one group's values to their mean."""
    group_mean = s.mean()
    return group_mean

df.groupBy("group").agg(mean_udf(col("value")))

'''
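Types of Pandas UDFs (the kind is inferred from the Python type hints):
1. Series to Series: Element-wise transformations
2. Iterator of Series to Iterator of Series: Batch processing (expensive setup runs once per task)
3. Iterator of Multiple Series to Iterator of Series: Multiple input columns (Iterator[Tuple[pd.Series, ...]])
4. Series to Scalar: Aggregations (groupBy().agg() or over a window)
5. Grouped Map: Operates on entire groups - in Spark 3.x this is df.groupBy(...).applyInPandas(func, schema), not a pandas_udf type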
'''

---
**PANDAS API**

---

In [None]:
# Pandas-on-Spark is lazy once data is in Spark, but importing from Pandas is eager because the data already exists in memory.
# Creating Pandas on Spark DataFrames:

import pandas as pd
import pyspark.pandas as ps

# From Spark DataFrame (lazy).
# DataFrame.pandas_api() is the current API; to_pandas_on_spark() is deprecated since
# Spark 3.3. Pass index_col="..." to use existing column(s) as the index instead of
# having a default index attached.
# NOTE(review): `spark_df` stands for any Spark DataFrame; it is not defined in this cell.
psdf = spark_df.pandas_api() # Lazy

# From Pandas
pdf = pd.DataFrame({'a': [1, 2, 3]})
psdf = ps.from_pandas(pdf) # Spark job runs right away to parallelize the Pandas data, Subsequent operations on psdf are lazy, just like Spark

# Direct creation Eager
psdf = ps.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]})

# Read files
psdf = ps.read_csv("path/to/file.csv")

---
**Accumulator**

---

In [None]:
# ============================================================================
# ACCUMULATORS IN DATAFRAME UDFs
# ============================================================================
'''
ACCUMULATORS:
  • Write-only (from the tasks' point of view) variables for aggregating metrics across executors, typically updated per row.
  • Use for: counters, data quality tracking, error monitoring
  • Only the driver can read the final value.

  • Be careful with transformations: they are lazy and may run more than once (task retries, lineage recomputation, stage re-execution), so accumulator updates made inside them can be applied more than once.
  • Spark guarantees each task's update is applied only once ONLY for updates made inside actions: there, restarted or failed task attempts do not update the value, and only the successful attempt counts.
  • Use accumulators in actions, or with cache/persist, to avoid double-counting.

  • UI Visibility (Just name them):
    • Scala: Named accumulators visible in Spark UI.
    • PySpark: Unnamed; not displayed in UI.
  • Types: PySpark supports int, float and complex out of the box (Scala/Java: Long, Double, collections); custom accumulators are possible (advanced, out of scope).
    You can define your own custom accumulator class by extending org.apache.spark.util.AccumulatorV2 (V2) in Java or Scala or pyspark.AccumulatorParam in Python.
'''
from pyspark.sql.functions import col, udf
from pyspark.sql.types import DoubleType, StringType

# Base DataFrame used for all versions
df = spark.createDataFrame([
    (10, 2), (20, 0), (15, 3), (8, 0), (12, 4)
], ["num", "denom"])

sc = spark.sparkContext

# The division UDFs below are declared with DoubleType because they return Python
# floats. With StringType, Spark would convert every result to text ("5.0") and
# "result" would not be a numeric column.

# ================ VERSION A: UNSAFE (Accumulator inside transformation, NO CACHE) ================
error_counter = sc.accumulator(0)

def safe_divide_udf_A(n, d):
    """Return n / d; on a zero denominator count the error and return None."""
    if d == 0:
        error_counter.add(1)        # Accumulator inside TRANSFORMATION → may run multiple times
        return None
    return n / d

safe_divide_A = udf(safe_divide_udf_A, DoubleType())

df_A = df.withColumn("result", safe_divide_A(col("num"), col("denom")))
df_A.show()                         # Action triggers evaluation
print(f"Version A errors counted: {error_counter.value}")  # May be double-counted depending on Spark plan

# ================ VERSION B: SAFE (Accumulator used inside an ACTION) ================
error_counter = sc.accumulator(0)

def count_errors(row):
    """Count rows with a zero denominator (runs inside the foreach action)."""
    if row.denom == 0:
        error_counter.add(1)        # Inside an action: each successful task's update is applied once

# Action executes the function exactly once per row
df.foreach(count_errors)

print(f"Version B errors counted: {error_counter.value}")

# Compute results using a pure UDF (no side effects)
safe_divide_B = udf(lambda n, d: None if d == 0 else n/d, DoubleType())
df_B = df.withColumn("result", safe_divide_B(col("num"), col("denom")))
df_B.show()


# ================ VERSION C: SAFE (Accumulator inside transformation, BUT cached before action) ================
error_counter = sc.accumulator(0)

def safe_divide_udf_C(n, d):
    """Same as version A; correctness relies on df_C being cached below."""
    if d == 0:
        error_counter.add(1)
        return None
    return n / d

safe_divide_C = udf(safe_divide_udf_C, DoubleType())

df_C = df.withColumn("result", safe_divide_C(col("num"), col("denom")))

# Key step to avoid multiple evaluations. Caching is best effort: if cached blocks
# are lost and recomputed, the UDF (and the accumulator update) runs again.
df_C.cache()

df_C.count()             # Materializes the cache -> UDF runs once per row
df_C.show()              # Served from the cache, UDF not re-run

print(f"Version C errors counted: {error_counter.value}")

---
**Broadcast Variable**

---

In [None]:
# ============================================================================
# BROADCAST IN DATAFRAME TRANSFORMATIONS (Broadcast Join are also under this)
# ============================================================================
'''
BROADCAST VARIABLES:
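  • Read-only variables shipped to each executor once and cached there (stored in serialized form, deserialized for use), instead of being sent with every task
  • Use for: lookup tables, configuration, filtering lists
  • Significantly reduces network overhead for large shared data
  • Remember to unpersist() when done to free executor memory (destroy() also removes it on the driver and makes it unusable)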
'''
from pyspark.sql.functions import col, udf
from pyspark.sql.types import DoubleType, StringType

sc = spark.sparkContext
price_tiers = {
    "economy": 1.0, "standard": 1.5,
    "premium": 2.0, "luxury": 3.0
}
bc_tiers = sc.broadcast(price_tiers)

def calculate_price_udf(tier, base_price):
    """Return base_price scaled by the broadcast multiplier for tier (1.0 if unknown).

    The lookup goes through bc_tiers.value, so each executor reads its local
    broadcast copy instead of receiving the dict with every task.
    A null base_price yields null.
    """
    if base_price is None:
        return None
    multiplier = bc_tiers.value.get(tier, 1.0)
    # float(): a DoubleType UDF turns non-float Python results (e.g. int) into null.
    return float(base_price * multiplier)

# DoubleType, not StringType: the function returns a number.
calc_price = udf(calculate_price_udf, DoubleType())

products = spark.createDataFrame([
      ("Product A", "economy", 100),
      ("Product B", "premium", 150),
      ("Product C", "luxury", 200),
      ("Product D", "standard", 120)
  ], ["name", "tier", "base_price"])

products_priced = products.withColumn(
    "final_price",
    calc_price(col("tier"), col("base_price"))
)
products_priced.show()

# Release the executor-side copies once the lookup is no longer needed.
# unpersist() (unlike destroy()) keeps bc_tiers usable: it is re-broadcast if used again.
bc_tiers.unpersist()


---
**HINTS** (Check execution plan with df.explain() to verify hints)

---

In [None]:
# ============================================================================
# PARTITIONING STRATEGY HINTS
# ============================================================================

# COALESCE - Reduces partitions to a specified number (merges partitions, no full shuffle).
result = spark.sql("""
	SELECT /*+ COALESCE(10) */
	CASE WHEN score > 80 THEN 'A' ELSE 'B' END AS grade,
	student_id
	FROM large_table
""")
# In Spark SQL, query hints must appear immediately after the SELECT keyword.

# DataFrame API: .hint() attaches the hint to the logical plan built so far; it is resolved at analysis time.
df_result = df.hint("coalesce", 10)
# Or use direct method:
df_result = df.coalesce(10)

# REPARTITION - Repartition your dataframe to the specified number of partitions (full shuffle).
# When several partitioning hints are given in SQL, all are inserted into the plan,
# but the LEFTMOST one is the outermost and is picked (here COALESCE(10)).
result = spark.sql("""
  SELECT /*+ COALESCE(10), REPARTITION(20) */
  CASE WHEN score > 80 THEN 'A' ELSE 'B' END AS grade,
  student_id
  FROM large_table
""")

from pyspark.sql import functions as F
# DataFrame equivalent of the SQL above. Each .hint() wraps the plan built so far,
# so the LAST hint in the chain is the outermost. REPARTITION first, COALESCE last
# therefore mirrors "COALESCE(10), REPARTITION(20)".
result_df = (
    df
    .hint("REPARTITION", 20)
    .hint("COALESCE", 10)
    .select(
        F.when(F.col("score") > 80, F.lit("A"))
         .otherwise(F.lit("B"))
         .alias("grade"),
        F.col("student_id")
    )
)

# REPARTITION_BY_RANGE - Similar to REPARTITION, but range-partitions on the given columns
# (boundaries computed from a sample of the data) instead of hash-partitioning.
result = spark.sql("""
  SELECT /*+ REPARTITION_BY_RANGE(50, dept, country) */ *
  FROM employees
""")

# DataFrame API: use the direct methods. In PySpark 3.5, passing column names to
# df.hint(...) as strings does not work: they reach the JVM as string literals, and
# the REPARTITION / REPARTITION_BY_RANGE / REBALANCE hints only accept column
# references as partitioning expressions (analysis error).
df_result = df.repartitionByRange(50, "dept", "country")
# Hash-partitioned counterpart (same as the SQL hint REPARTITION(50, dept, country)):
df_result = df.repartition(50, "dept", "country")

# REBALANCE - Used to rebalance the query result output partitions, so that every partition is of a reasonable size.
# Only takes effect with Adaptive Query Execution enabled (spark.sql.adaptive.enabled,
# default true since Spark 3.2); the hint is ignored otherwise.
spark.sql("SELECT /*+ REBALANCE */ * FROM t")
spark.sql("SELECT /*+ REBALANCE(3) */ * FROM t")
spark.sql("SELECT /*+ REBALANCE(c) */ * FROM t")
spark.sql("SELECT /*+ REBALANCE(3, c) */ * FROM t")

df_rebalanced = df.hint("REBALANCE") # Rebalance; AQE decides the final partition sizes
df_rebalanced = df.hint("REBALANCE", 3) # Rebalance to 3 partitions
# Column-based rebalancing (REBALANCE(c), REBALANCE(3, c)): use the SQL form above.
# df.hint("REBALANCE", "c") fails for the same string-literal reason as REPARTITION.

# ============================================================================
# JOIN STRATEGY HINTS
# ============================================================================
# When different join strategy hints are specified on both sides of a join, Spark prioritizes hints in the following order: BROADCAST over MERGE over SHUFFLE_HASH over SHUFFLE_REPLICATE_NL.
# When both sides are specified with the BROADCAST hint or the SHUFFLE_HASH hint, Spark will pick the build side based on the join type and the sizes of the relations.
# The DataFrame examples below assign to new names instead of reassigning `df`, so
# running this cell does not clobber the `df` used by other cells.


# BROADCAST / BROADCASTJOIN / MAPJOIN - Forces broadcast join (hash join)
result = spark.sql("""
  SELECT /*+ BROADCAST(small_table) */ *
  FROM large_table
  JOIN small_table ON large_table.id = small_table.id
""")
broadcast_joined = large_df.join(small_df.hint("broadcast"), large_df.id == small_df.id, "inner")
# If both sides of the join have the broadcast hints, the one with the smaller size (based on stats) will be broadcast.

# MERGE/ SHUFFLE_MERGE/ MERGEJOIN - sort-merge join, the usual choice for two large datasets. Both DataFrames are shuffled and sorted on the join key.
spark.sql("SELECT /*+ SHUFFLE_MERGE(t1) */ * FROM t1 INNER JOIN t2 ON t1.key = t2.key")
spark.sql("SELECT /*+ MERGEJOIN(t2) */ * FROM t1 INNER JOIN t2 ON t1.key = t2.key")
spark.sql("SELECT /*+ MERGE(t1) */ * FROM t1 INNER JOIN t2 ON t1.key = t2.key")

merge_joined_1 = t1.hint("shuffle_merge").join(t2, on="key", how="inner")
merge_joined_2 = t2.hint("mergejoin").join(t1, on="key", how="inner")
merge_joined_3 = t1.hint("merge").join(t2, on="key", how="inner")

# SHUFFLE_HASH - Both DataFrames are shuffled on the join key (like merge join). Instead of sorting, Spark builds a hash map of the smaller side per partition. Probes the hash map for matching rows.
result = spark.sql("""
  SELECT /*+ SHUFFLE_HASH(df1) */ *
  FROM df1
  JOIN df2 ON df1.key = df2.key
""")
df_result = df1.hint("shuffle_hash").join(
    df2,
    df1.key == df2.key
)

# SHUFFLE_REPLICATE_NL - shuffle-and-replicate nested loop join (a cartesian-product join):
# every partition of one side is paired with every partition of the other side, and rows
# are matched with a nested loop. Only valid for inner-like joins.
# Typical when no equi-join key exists, e.g. a non-equi condition df1.value > df2.value
spark.sql("""
  SELECT /*+ SHUFFLE_REPLICATE_NL(df1) */ *
  FROM df1
  JOIN df2 ON df1.col > df2.col
""")
df_result = df1.hint("shuffle_replicate_nl").join( df2, df1.col > df2.col)


---
**Integration with external API**

---

In [None]:
# ============================================================================
# 16. WORKING WITH EXTERNAL APIS
# ============================================================================

import requests
from datetime import datetime
from itertools import product
from pyspark.sql.functions import broadcast

def get_conversion_rates(date, pairs=None):
    """Fetch conversion rates for currency pairs from the bank API (best effort).

    Args:
        date: value sent as the ``date`` query parameter (e.g. "2024-01-31").
        pairs: iterable of ``(src, target)`` currency codes. When None, the
            currency list is fetched and every ordered pair of distinct
            currencies is requested.

    Returns:
        List of ``(date, src, target, rate)`` tuples with ``rate`` as a float.
        Failed or rate-less requests are printed and skipped. An empty list is
        returned if the currency list cannot be fetched or is empty.
    """
    base_url = "https://api.bank.com/exchange-rate"
    currencies_url = "https://api.bank.com/currencies"

    # One Session for all calls: reuses the HTTP connection instead of opening one per pair.
    with requests.Session() as session:
        # If pairs not provided, get all currency combinations
        if pairs is None:
            try:
                resp = session.get(currencies_url, timeout=10)
                resp.raise_for_status()
                currencies = resp.json().get("currencies", [])
            except Exception as e:
                print(f"Failed to fetch currency list: {e}")
                return []
            if not currencies:
                return []
            # Create all combinations (excluding same currency)
            pairs = [(src, tgt) for src, tgt in product(currencies, currencies) if src != tgt]

        # Fetch rates for all pairs
        rates = []
        for src, tgt in pairs:
            params = {"date": date, "src": src, "target": tgt}
            try:
                resp = session.get(base_url, params=params, timeout=10)
                resp.raise_for_status()
                rate = resp.json().get("rate")
                if rate is None:
                    # Report instead of silently dropping the pair.
                    print(f"No rate returned for {src}->{tgt}")
                    continue
                # float() gives a consistent type for the Spark schema and
                # raises (caught below) on a non-numeric payload.
                rates.append((date, src, tgt, float(rate)))
            except Exception as e:
                print(f"Failed for {src}->{tgt}: {e}")

    return rates


from pyspark.sql.functions import col, when
from pyspark.sql.types import DoubleType, StringType, StructField, StructType

# Create DataFrame from API data
today = datetime.now().strftime("%Y-%m-%d")
pairs = [("USD", "INR"), ("EUR", "INR"), ("GBP", "INR")]
rates_data = get_conversion_rates(today, pairs)

columns = ["date", "src_currency", "target_currency", "rate"]
# Explicit schema: createDataFrame cannot infer a schema from an empty list (e.g. when
# every API call failed), and inference would depend on the JSON payload's types.
rates_schema = StructType(
    [StructField(name, StringType(), True) for name in columns[:3]]
    + [StructField(columns[3], DoubleType(), True)]
)
rates_df = spark.createDataFrame(
    [(d, src, tgt, float(r)) for d, src, tgt, r in rates_data], rates_schema
)

# Use broadcast for small lookup tables
transactions_df = spark.read.parquet("transactions/")
joined_df = transactions_df.join(
    broadcast(rates_df),
    transactions_df.currency == rates_df.src_currency,
    "left"
)

converted_df = joined_df.withColumn(
    "amount_in_inr",
    # Transactions already in INR have no *->INR row in rates_df, so the left join
    # leaves `rate` null for them; their amount is already in INR.
    when(col("currency") == "INR", col("amount")).otherwise(col("amount") * col("rate"))
).select("txn_id", "currency", "amount", "rate", "amount_in_inr")

---
**Spark Streaming**

---

In [None]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *

from pyspark.sql import functions as F

# Initialize Spark Session
# NOTE: a session was already created at the top of the notebook, so getOrCreate()
# returns that session. The app name is not changed, and runtime SQL configs such as
# spark.sql.shuffle.partitions are applied to it. Static configs (e.g.
# spark.jars.packages) cannot be added at this point.
spark = SparkSession.builder \
    .appName("StreamingExample") \
    .config("spark.sql.shuffle.partitions", 4) \
    .getOrCreate()

# Define Schema of the JSON payload carried in each Kafka message value
schema = StructType([
    StructField("timestamp", TimestampType(), True),
    StructField("user_id", StringType(), True),
    StructField("event_type", StringType(), True),
    StructField("value", DoubleType(), True)
])

# Read Stream from Kafka
# Requires the Kafka connector (org.apache.spark:spark-sql-kafka-0-10_<scala>:<spark version>)
# on the classpath, e.g. spark.jars.packages set when the session is first built.
# The session created at the top of this notebook does not set it.
raw_stream = spark.readStream \
    .format("kafka") \
    .option("kafka.bootstrap.servers", "localhost:9092") \
    .option("subscribe", "events") \
    .option("startingOffsets", "latest") \
    .load()
# Rate limiting: for the Kafka source use maxOffsetsPerTrigger (max offsets per micro-batch).
# maxFilesPerTrigger is an option of file-based sources, not of Kafka.

# Parse and Transform: Kafka `value` is binary -> cast to string -> parse the JSON with `schema`
parsed_stream = raw_stream \
    .selectExpr("CAST(value AS STRING)") \
    .select(F.from_json(F.col("value"), schema).alias("data")) \
    .select("data.*")

# filter() / where() can be used to filter rows based on conditions. Filtering on aggregations is not supported, but row-level filtering like this is supported.

# df.filter(F.count("x") > 30)
# Aggregations with a condition like count(x) > 30 require grouping and stateful processing. These are supported only with groupBy + agg and proper output modes, not directly inside a simple filter like this.

# Batch-style actions such as show(), collect(), count() and foreach() raise an error on a
# streaming DataFrame; write to a sink (e.g. format("console")) via writeStream instead.
# limit() is a transformation, not an action; on a stream it is rejected in "update" output mode.

# Apply Watermark and Window Aggregation
# 5-minute windows sliding every 1 minute. Events arriving more than 15 minutes behind the
# max event time seen so far may be dropped, which lets Spark clean up old window state.
# F.count / F.avg / F.max are spelled out: the bare names from `import *` shadow
# Python builtins (e.g. max), which is easy to misread.
windowed_stream = parsed_stream \
    .withWatermark("timestamp", "15 minutes") \
    .groupBy(
        F.window(F.col("timestamp"), "5 minutes", "1 minute"),
        F.col("event_type")
    ).agg(
        F.count("*").alias("event_count"),
        F.avg("value").alias("avg_value"),
        F.max("value").alias("max_value")
    )
'''
A watermark sets an upper bound on how late a record (or, with dropDuplicates, a duplicate) can arrive, allowing Spark to safely remove old state and control memory usage.
• Watermarks handle late-arriving data in event-time processing
• Define how long to wait for late data before finalizing results
• Based on event time, not processing time
• Allows state cleanup and bounded memory usage
• Trade-off between completeness and latency
• Essential for window operations with late data

from pyspark.sql.functions import window, col

# Basic Watermark
df_with_watermark = df \
    .withWatermark("timestamp", "10 minutes")

# Watermark with Window Aggregation
windowed_counts = df \
    .withWatermark("timestamp", "10 minutes") \
    .groupBy(
        window(col("timestamp"), "5 minutes"),
        col("user_id")
    ) \
    .count()

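# Multiple Watermarks in Joins (stream-stream join)
# Watermark the event-time columns that appear in the join condition. The two sides need
# distinct column names (or aliases) so the condition is unambiguous. By default Spark uses
# the minimum of the watermarks (spark.sql.streaming.multipleWatermarkPolicy=min).
df1_watermarked = df1.withWatermark("timestamp1", "5 minutes")
df2_watermarked = df2.withWatermark("timestamp2", "10 minutes")

joined = df1_watermarked.join(
    df2_watermarked,
    expr("""
        user_id1 = user_id2 AND
        timestamp1 >= timestamp2 AND
        timestamp1 <= timestamp2 + interval 1 hour
    """)
)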

# Watermark with Append Mode
query = windowed_counts.writeStream \
    .outputMode("append") \
    .option("checkpointLocation", "/checkpoint") \
    .format("console") \
    .start()

# Understanding Watermark Calculation
# Watermark = max_event_time - watermark_delay
# Data with event_time < watermark is considered late
'''
# Write Stream with Checkpointing
# File sinks such as parquet support only the "append" output mode ("update" is rejected
# when the query starts). With the watermark defined on windowed_stream, append mode
# writes each window once, after the watermark passes the window's end, so rows appear
# roughly 15 minutes (the watermark delay) after their window closes.
query = windowed_stream.writeStream \
    .outputMode("append") \
    .format("parquet") \
    .option("path", "/output/events") \
    .option("checkpointLocation", "/checkpoint/events") \
    .trigger(processingTime='30 seconds') \
    .start()

'''
Append mode (default) - This is the default mode, where only the new rows added to the Result Table since the last trigger will be outputted to the sink. This is supported for only those queries where rows added to the Result Table are never going to change. Hence, this mode guarantees that each row will be output only once (assuming fault-tolerant sink). For example, queries with only select, where, map, flatMap, filter, join, etc. will support Append mode.
Complete mode - The whole Result Table will be outputted to the sink after every trigger. This is supported for aggregation queries.
Update mode - (Available since Spark 2.1.1) Only the rows in the Result Table that were updated since the last trigger will be outputted to the sink. More information to be added in future releases.

# Append Mode (default) - Only new rows added since last trigger
query = df.writeStream \
    .outputMode("append") \
    .format("console") \
    .start()

# Complete Mode - Entire result table output every trigger
agg_df = df.groupBy("category").count()

query = agg_df.writeStream \
    .outputMode("complete") \
    .format("console") \
    .start()

# Update Mode - Only rows updated since last trigger
windowed = df \
    .withWatermark("timestamp", "10 minutes") \
    .groupBy(
        window(col("timestamp"), "5 minutes"),
        col("category")
    ).count()

query = windowed.writeStream \
    .outputMode("update") \
    .format("console") \
    .start()

# Mode Compatibility Examples
# Append: No aggregation, watermarked aggregation
# Complete: All aggregations (memory intensive)
# Update: Aggregations with state

# Example: Aggregation without Watermark
# Must use Complete or Update mode
df.groupBy("user").count() \
    .writeStream \
    .outputMode("complete") \
    .format("console") \
    .start()

---------------------------------------------------------------------------------------------------------

Checkpointing provides fault tolerance and, together with a replayable source and an idempotent/transactional sink, end-to-end exactly-once semantics
• Stores metadata about streaming query progress
• Includes offsets, state information, and query configuration
• Required for stateful operations and production deployments
• Enables restart from last checkpoint after failures
• Checkpoint location must be reliable (HDFS, S3, etc.)

# Basic Checkpointing
query = df.writeStream \
    .format("parquet") \
    .option("path", "/output/path") \
    .option("checkpointLocation", "/checkpoint/path") \
    .start()

# Multiple Queries with Different Checkpoints
query1 = df1.writeStream \
    .option("checkpointLocation", "/checkpoints/query1") \
    .format("parquet") \
    .start("/output/query1")

query2 = df2.writeStream \
    .option("checkpointLocation", "/checkpoints/query2") \
    .format("parquet") \
    .start("/output/query2")

# Checkpoint with State Store
from pyspark.sql.functions import window, col

windowed = df.groupBy(
    window(col("timestamp"), "10 minutes")
).count()

query = windowed.writeStream \
    .outputMode("update") \
    .option("checkpointLocation", "/checkpoint/windowed") \
    .format("console") \
    .start()

---------------------------------------------------------------------------------------------------------

Triggers define when the streaming query should process data
• Default: micro-batch processing as data arrives (the next batch starts as soon as the previous one finishes)
• Types: ProcessingTime, Once (deprecated since Spark 3.4 in favor of AvailableNow), Continuous, AvailableNow
• Affects latency vs throughput tradeoff
• Continuous mode offers low latency (experimental)

# Default Trigger (as fast as possible)
query = df.writeStream \
    .format("console") \
    .start()

# Fixed Interval Trigger
query = df.writeStream \
    .trigger(processingTime='10 seconds') \
    .format("console") \
    .start()

# One-time Trigger (batch-like; deprecated since Spark 3.4 - prefer availableNow=True)
query = df.writeStream \
    .trigger(once=True) \
    .format("console") \
    .start()

# Available Now (process all available data)
query = df.writeStream \
    .trigger(availableNow=True) \
    .format("console") \
    .start()

# Continuous Trigger (low latency)
query = df.writeStream \
    .trigger(continuous='1 second') \
    .format("console") \
    .start()
'''

# Monitor Query
# Inspect the query BEFORE blocking. awaitTermination() without a timeout only returns
# when the query stops or fails, so prints placed after it never run while the stream is alive.
print(f"Query ID: {query.id}")
print(f"Status: {query.status}")
print(f"Recent Progress: {query.recentProgress}")

# Block this cell until the query terminates (pass a timeout in seconds to return earlier).
query.awaitTermination()