# Learn and practice pyspark

## Prepare the  environment

Import necessary modules and packages

In [1]:
import os
import requests
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType, StringType, DoubleType, TimestampType
# from pyspark.sql import functions as sf
from pyspark.sql.functions import col, concat, lit, when

Prepare the project's directories

In [2]:
# Prepare directories
os.makedirs("data/", exist_ok=True)

Download a subset of the NYC taxi trips dataset. This subset contains the trips of yellow taxis from march 2025

In [3]:
# Download NYC taxi trips dataset
URL = "https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_2025-03.parquet"
response = requests.get(url=URL, stream=True, timeout=(10, 30))

if os.path.exists("data/yellow_taxi_data_202503.parquet"):
    print("Dataset ready")
else:
    with open("data/yellow_taxi_data_202503.parquet", "wb") as file:
        for chunk in response.iter_content(chunk_size=4092):
            file.write(chunk)
    print("Dataset downloaded")

Dataset ready


# Pyspark basics

Initialize Spark Session, this is the basic entrypoint to interact with Apache Spark

In [4]:
spark = SparkSession.builder.appName("LearnSpark") \
    .config('spark.executor.memory', '4g') \
    .config('spark.executor.cores', '4') \
    .config('spark.sql.shuffle.partitions', '200') \
    .getOrCreate()

25/07/09 20:57:53 WARN Utils: Your hostname, xblade resolves to a loopback address: 127.0.1.1; using 192.168.10.156 instead (on interface enp2s0)
25/07/09 20:57:53 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/07/09 20:57:53 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In the SparkSession builder we can configure the spark deployment and its process like the amount of memory and cores assigned to the executor process. The [configuration docs are Here](!https://spark.apache.org/docs/latest/configuration.html)

## The Dataframe basics

Create a Dataframe from scratch

In [5]:
df_plain = spark.createDataFrame(data=[('Alan', 10), ('Cindy', 12)], schema=('Name', 'Age'))
print(df_plain)

DataFrame[Name: string, Age: bigint]


Create a Dataframe by reading a file, it could be parquet, json, csv, etc

In [6]:
df: DataFrame = spark.read.parquet("data/yellow_taxi_data_202503.parquet")

Show the first rows of a dataframe with Dataframe.show(), the default value is 20

In [7]:
df.show()

+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+------------------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|Airport_fee|cbd_congestion_fee|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+------------------+
|       1| 2025-03-01 00:17:16|  2025-03-01 00:25:52|              1|          0.9|         1|                 N|         140|    

In [8]:
# only the first 5 rows
df.show(5)

+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+------------------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|Airport_fee|cbd_congestion_fee|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+------------------+
|       1| 2025-03-01 00:17:16|  2025-03-01 00:25:52|              1|          0.9|         1|                 N|         140|    

The Dataframe.head() Returns the first n rows like Pandas (1 by default on spark)

In [9]:
df.head(5)

[Row(VendorID=1, tpep_pickup_datetime=datetime.datetime(2025, 3, 1, 0, 17, 16), tpep_dropoff_datetime=datetime.datetime(2025, 3, 1, 0, 25, 52), passenger_count=1, trip_distance=0.9, RatecodeID=1, store_and_fwd_flag='N', PULocationID=140, DOLocationID=236, payment_type=1, fare_amount=7.9, extra=3.5, mta_tax=0.5, tip_amount=2.6, tolls_amount=0.0, improvement_surcharge=1.0, total_amount=15.5, congestion_surcharge=2.5, Airport_fee=0.0, cbd_congestion_fee=0.0),
 Row(VendorID=1, tpep_pickup_datetime=datetime.datetime(2025, 3, 1, 0, 37, 38), tpep_dropoff_datetime=datetime.datetime(2025, 3, 1, 0, 43, 51), passenger_count=1, trip_distance=0.6, RatecodeID=1, store_and_fwd_flag='N', PULocationID=140, DOLocationID=262, payment_type=1, fare_amount=6.5, extra=3.5, mta_tax=0.5, tip_amount=2.3, tolls_amount=0.0, improvement_surcharge=1.0, total_amount=13.8, congestion_surcharge=2.5, Airport_fee=0.0, cbd_congestion_fee=0.0),
 Row(VendorID=2, tpep_pickup_datetime=datetime.datetime(2025, 3, 1, 0, 24, 35)

The DataFrame.take() method works similarly to head()

In [10]:
df.take(5)

[Row(VendorID=1, tpep_pickup_datetime=datetime.datetime(2025, 3, 1, 0, 17, 16), tpep_dropoff_datetime=datetime.datetime(2025, 3, 1, 0, 25, 52), passenger_count=1, trip_distance=0.9, RatecodeID=1, store_and_fwd_flag='N', PULocationID=140, DOLocationID=236, payment_type=1, fare_amount=7.9, extra=3.5, mta_tax=0.5, tip_amount=2.6, tolls_amount=0.0, improvement_surcharge=1.0, total_amount=15.5, congestion_surcharge=2.5, Airport_fee=0.0, cbd_congestion_fee=0.0),
 Row(VendorID=1, tpep_pickup_datetime=datetime.datetime(2025, 3, 1, 0, 37, 38), tpep_dropoff_datetime=datetime.datetime(2025, 3, 1, 0, 43, 51), passenger_count=1, trip_distance=0.6, RatecodeID=1, store_and_fwd_flag='N', PULocationID=140, DOLocationID=262, payment_type=1, fare_amount=6.5, extra=3.5, mta_tax=0.5, tip_amount=2.3, tolls_amount=0.0, improvement_surcharge=1.0, total_amount=13.8, congestion_surcharge=2.5, Airport_fee=0.0, cbd_congestion_fee=0.0),
 Row(VendorID=2, tpep_pickup_datetime=datetime.datetime(2025, 3, 1, 0, 24, 35)

We can create a new Pandas Dataframe based on the original dataframe with Dataframe.toPandas()

In [11]:
pd_df = df_plain.toPandas()
print(type(pd_df))

<class 'pandas.core.frame.DataFrame'>


                                                                                

Show the DataFrame  with Pandas's Dataframe.head()

In [12]:
pd_df.head()

Unnamed: 0,Name,Age
0,Alan,10
1,Cindy,12


Write a dataframe to csv files.
The write.csv() method will create a directory and write multiple csv files(parts of the dataframe) that together, make up the whole dataframe and _SUCCESS files as confirmation.

The .csv() method should be called after .option() and .mode() methods. In this case the header = true is used to write a header with the column names in the .csv files

In [13]:
df.write.option('header', 'true').mode('overwrite').csv('data/yellow_taxi_data_csv')

                                                                                

The df.write() method returns a **DataFrameWriter** object that has many modes:

**append**: Append contents of this DataFrame to existing data.

**overwrite**: Overwrite existing data.

**error or errorifexists**: Throw an exception if data already exists.

**ignore**: Silently ignore this operation if data already exists.

Read the csv data as a new dataframe

In [14]:
df_csv = spark.read.csv('data/yellow_taxi_data_csv')
df_csv.show(5)

+--------+--------------------+--------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+--------------------+------------+--------------------+-----------+------------------+
|     _c0|                 _c1|                 _c2|            _c3|          _c4|       _c5|               _c6|         _c7|         _c8|         _c9|       _c10| _c11|   _c12|      _c13|        _c14|                _c15|        _c16|                _c17|       _c18|              _c19|
+--------+--------------------+--------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+--------------------+------------+--------------------+-----------+------------------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_date...|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationI

The dataframe columns were not inferred, the header = true option must be called. The option inferSchema is useful too.

First let's check the current schema

In [15]:
df_csv.printSchema()

root
 |-- _c0: string (nullable = true)
 |-- _c1: string (nullable = true)
 |-- _c2: string (nullable = true)
 |-- _c3: string (nullable = true)
 |-- _c4: string (nullable = true)
 |-- _c5: string (nullable = true)
 |-- _c6: string (nullable = true)
 |-- _c7: string (nullable = true)
 |-- _c8: string (nullable = true)
 |-- _c9: string (nullable = true)
 |-- _c10: string (nullable = true)
 |-- _c11: string (nullable = true)
 |-- _c12: string (nullable = true)
 |-- _c13: string (nullable = true)
 |-- _c14: string (nullable = true)
 |-- _c15: string (nullable = true)
 |-- _c16: string (nullable = true)
 |-- _c17: string (nullable = true)
 |-- _c18: string (nullable = true)
 |-- _c19: string (nullable = true)



Now let's read again it using both options, passed as a dictionary with the **.options** method, this is different from **.option**

In [16]:
# we can also use the read.csv() method and pass parameter to it
# df_csv = spark.read.csv('data/yellow_taxi_data_csv', header=true, inferschema=true)
df_csv = spark.read.options(header=True, inferschema=True, delimiter=',').csv('data/yellow_taxi_data_csv')

                                                                                

The columns's names are mapped now

In [17]:
df_csv.show(5)

+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+------------------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|Airport_fee|cbd_congestion_fee|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+------------------+
|       1| 2025-03-11 11:00:41|  2025-03-11 11:13:49|              2|          1.6|         1|                 Y|         170|    

And the dataframe has a different(and much better) schema

In [18]:
df_csv.printSchema()

root
 |-- VendorID: integer (nullable = true)
 |-- tpep_pickup_datetime: timestamp (nullable = true)
 |-- tpep_dropoff_datetime: timestamp (nullable = true)
 |-- passenger_count: integer (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- RatecodeID: integer (nullable = true)
 |-- store_and_fwd_flag: string (nullable = true)
 |-- PULocationID: integer (nullable = true)
 |-- DOLocationID: integer (nullable = true)
 |-- payment_type: integer (nullable = true)
 |-- fare_amount: double (nullable = true)
 |-- extra: double (nullable = true)
 |-- mta_tax: double (nullable = true)
 |-- tip_amount: double (nullable = true)
 |-- tolls_amount: double (nullable = true)
 |-- improvement_surcharge: double (nullable = true)
 |-- total_amount: double (nullable = true)
 |-- congestion_surcharge: double (nullable = true)
 |-- Airport_fee: double (nullable = true)
 |-- cbd_congestion_fee: double (nullable = true)



The dataframe.schema Attribute holds the schema

In [19]:
print(df_csv.schema)

StructType([StructField('VendorID', IntegerType(), True), StructField('tpep_pickup_datetime', TimestampType(), True), StructField('tpep_dropoff_datetime', TimestampType(), True), StructField('passenger_count', IntegerType(), True), StructField('trip_distance', DoubleType(), True), StructField('RatecodeID', IntegerType(), True), StructField('store_and_fwd_flag', StringType(), True), StructField('PULocationID', IntegerType(), True), StructField('DOLocationID', IntegerType(), True), StructField('payment_type', IntegerType(), True), StructField('fare_amount', DoubleType(), True), StructField('extra', DoubleType(), True), StructField('mta_tax', DoubleType(), True), StructField('tip_amount', DoubleType(), True), StructField('tolls_amount', DoubleType(), True), StructField('improvement_surcharge', DoubleType(), True), StructField('total_amount', DoubleType(), True), StructField('congestion_surcharge', DoubleType(), True), StructField('Airport_fee', DoubleType(), True), StructField('cbd_conges

We can create a schema definition using the StructType and StructField classes, the structfield constructor parameters are: StructField(name, dataType, nullable=True, metadata=None)

In [20]:
df_schema = StructType([
    StructField("VendorID", IntegerType(), True),
    StructField("tpep_pickup_datetime", TimestampType(), True),
    StructField("tpep_dropoff_datetime", TimestampType(), True),
    StructField("passenger_count", IntegerType(), True),
    StructField("trip_distance", DoubleType(), True),
    StructField("RatecodeID", IntegerType(), True),
    StructField("store_and_fwd_flag", StringType(), True),
    StructField("PULocationID", IntegerType(), True),
    StructField("DOLocationID", IntegerType(), True),
    StructField("payment_type", IntegerType(), True),
    StructField("fare_amount", DoubleType(), True),
    StructField("extra", DoubleType(), True),
    StructField("mta_tax", DoubleType(), True),
    StructField("tip_amount", DoubleType(), True),
    StructField("tolls_amount", DoubleType(), True),
    StructField("improvement_surcharge", DoubleType(), True),
    StructField("total_amount", DoubleType(), True),
    StructField("congestion_surcharge", DoubleType(), True),
    StructField("Airport_fee", DoubleType(), True),
    StructField("cbd_congestion_fee", DoubleType(), True)
])
print(f'Dataframe schema:\n{df_schema}')

Dataframe schema:
StructType([StructField('VendorID', IntegerType(), True), StructField('tpep_pickup_datetime', TimestampType(), True), StructField('tpep_dropoff_datetime', TimestampType(), True), StructField('passenger_count', IntegerType(), True), StructField('trip_distance', DoubleType(), True), StructField('RatecodeID', IntegerType(), True), StructField('store_and_fwd_flag', StringType(), True), StructField('PULocationID', IntegerType(), True), StructField('DOLocationID', IntegerType(), True), StructField('payment_type', IntegerType(), True), StructField('fare_amount', DoubleType(), True), StructField('extra', DoubleType(), True), StructField('mta_tax', DoubleType(), True), StructField('tip_amount', DoubleType(), True), StructField('tolls_amount', DoubleType(), True), StructField('improvement_surcharge', DoubleType(), True), StructField('total_amount', DoubleType(), True), StructField('congestion_surcharge', DoubleType(), True), StructField('Airport_fee', DoubleType(), True), Struc

In [21]:
df_csv = spark.read.schema(schema=df_schema).csv('data/yellow_taxi_data_csv')

In [22]:
df_csv.printSchema()

root
 |-- VendorID: integer (nullable = true)
 |-- tpep_pickup_datetime: timestamp (nullable = true)
 |-- tpep_dropoff_datetime: timestamp (nullable = true)
 |-- passenger_count: integer (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- RatecodeID: integer (nullable = true)
 |-- store_and_fwd_flag: string (nullable = true)
 |-- PULocationID: integer (nullable = true)
 |-- DOLocationID: integer (nullable = true)
 |-- payment_type: integer (nullable = true)
 |-- fare_amount: double (nullable = true)
 |-- extra: double (nullable = true)
 |-- mta_tax: double (nullable = true)
 |-- tip_amount: double (nullable = true)
 |-- tolls_amount: double (nullable = true)
 |-- improvement_surcharge: double (nullable = true)
 |-- total_amount: double (nullable = true)
 |-- congestion_surcharge: double (nullable = true)
 |-- Airport_fee: double (nullable = true)
 |-- cbd_congestion_fee: double (nullable = true)



We can print only the columns of the dataframe with the Dataframe.columns Attribute

In [23]:
print(df.columns)

['VendorID', 'tpep_pickup_datetime', 'tpep_dropoff_datetime', 'passenger_count', 'trip_distance', 'RatecodeID', 'store_and_fwd_flag', 'PULocationID', 'DOLocationID', 'payment_type', 'fare_amount', 'extra', 'mta_tax', 'tip_amount', 'tolls_amount', 'improvement_surcharge', 'total_amount', 'congestion_surcharge', 'Airport_fee', 'cbd_congestion_fee']


The DataFrame.dtypes Attribute can be useful if we only want the columans and their types


In [24]:
print(df.dtypes)

[('VendorID', 'int'), ('tpep_pickup_datetime', 'timestamp_ntz'), ('tpep_dropoff_datetime', 'timestamp_ntz'), ('passenger_count', 'bigint'), ('trip_distance', 'double'), ('RatecodeID', 'bigint'), ('store_and_fwd_flag', 'string'), ('PULocationID', 'int'), ('DOLocationID', 'int'), ('payment_type', 'bigint'), ('fare_amount', 'double'), ('extra', 'double'), ('mta_tax', 'double'), ('tip_amount', 'double'), ('tolls_amount', 'double'), ('improvement_surcharge', 'double'), ('total_amount', 'double'), ('congestion_surcharge', 'double'), ('Airport_fee', 'double'), ('cbd_congestion_fee', 'double')]


With a correct schema, we can use the DataFrame.summary() method to generate a new dataframe with the statistics of said DataFrame. Let's use the original dataframe

In [25]:
df.summary().show()

25/07/09 20:58:07 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
[Stage 18:>                                                         (0 + 1) / 1]

+-------+------------------+------------------+-----------------+------------------+------------------+------------------+-----------------+------------------+------------------+------------------+------------------+------------------+------------------+---------------------+------------------+--------------------+-------------------+------------------+
|summary|          VendorID|   passenger_count|    trip_distance|        RatecodeID|store_and_fwd_flag|      PULocationID|     DOLocationID|      payment_type|       fare_amount|             extra|           mta_tax|        tip_amount|      tolls_amount|improvement_surcharge|      total_amount|congestion_surcharge|        Airport_fee|cbd_congestion_fee|
+-------+------------------+------------------+-----------------+------------------+------------------+------------------+-----------------+------------------+------------------+------------------+------------------+------------------+------------------+---------------------+------------

                                                                                

If we only want the count here's the simple way

In [26]:
print(f'Count: {df.count()}')

Count: 4145257


# Dataframe Operations

## Select operations

Select: Returns a dataframe with only the **selected** columns

In [27]:
df.select('VendorID', 'fare_amount', 'extra', 'tip_amount', 'total_amount').show(5)

+--------+-----------+-----+----------+------------+
|VendorID|fare_amount|extra|tip_amount|total_amount|
+--------+-----------+-----+----------+------------+
|       1|        7.9|  3.5|       2.6|        15.5|
|       1|        6.5|  3.5|       2.3|        13.8|
|       2|       14.9|  1.0|      5.16|       25.81|
|       2|        7.2|  1.0|      2.59|       15.54|
|       1|        8.6| 4.25|      2.85|        17.2|
+--------+-----------+-----+----------+------------+
only showing top 5 rows



We  can use SQL `*` in a select method just like SQL 

In [28]:
df.select('*').show(5)

+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+------------------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|Airport_fee|cbd_congestion_fee|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+------------------+
|       1| 2025-03-01 00:17:16|  2025-03-01 00:25:52|              1|          0.9|         1|                 N|         140|    

We can select columns with Pandas's style too

In [29]:
df['VendorID', 'total_amount'].show(5)

+--------+------------+
|VendorID|total_amount|
+--------+------------+
|       1|        15.5|
|       1|        13.8|
|       2|       25.81|
|       2|       15.54|
|       1|        17.2|
+--------+------------+
only showing top 5 rows



Or an Attribute style

In [30]:
df.select(df.VendorID, df.total_amount).show(5)

+--------+------------+
|VendorID|total_amount|
+--------+------------+
|       1|        15.5|
|       1|        13.8|
|       2|       25.81|
|       2|       15.54|
|       1|        17.2|
+--------+------------+
only showing top 5 rows



Dataframe.limit(): Limits the result count of a Dataframe

In [31]:
df.select(df.VendorID, df.total_amount).limit(10).show()

+--------+------------+
|VendorID|total_amount|
+--------+------------+
|       1|        15.5|
|       1|        13.8|
|       2|       25.81|
|       2|       15.54|
|       1|        17.2|
|       1|        20.8|
|       2|        27.3|
|       2|       16.35|
|       2|        23.1|
|       2|       21.42|
+--------+------------+



DataFrame.collect() returns all the rows of the DataFrame

In [32]:
df.select(df.VendorID, df.total_amount).limit(10).collect()

[Row(VendorID=1, total_amount=15.5),
 Row(VendorID=1, total_amount=13.8),
 Row(VendorID=2, total_amount=25.81),
 Row(VendorID=2, total_amount=15.54),
 Row(VendorID=1, total_amount=17.2),
 Row(VendorID=1, total_amount=20.8),
 Row(VendorID=2, total_amount=27.3),
 Row(VendorID=2, total_amount=16.35),
 Row(VendorID=2, total_amount=23.1),
 Row(VendorID=2, total_amount=21.42)]

We can return a column using the col() function

In [33]:
df.select(col(col='VendorID'), col('total_amount')).show(5)

+--------+------------+
|VendorID|total_amount|
+--------+------------+
|       1|        15.5|
|       1|        13.8|
|       2|       25.81|
|       2|       15.54|
|       1|        17.2|
+--------+------------+
only showing top 5 rows



Dataframe Aliases (another name like in sql), they are useful to refer a column within a dataframe in joins

In [34]:
df_as = df.alias('nyc')

In [35]:
df_as.select(col('nyc.VendorID'), col('nyc.total_amount')).show(5)

+--------+------------+
|VendorID|total_amount|
+--------+------------+
|       1|        15.5|
|       1|        13.8|
|       2|       25.81|
|       2|       15.54|
|       1|        17.2|
+--------+------------+
only showing top 5 rows



Columns Aliases

In [36]:
df.select('VendorID', col('total_amount').alias('TOTAL_ALIAS')).show(5)

+--------+-----------+
|VendorID|TOTAL_ALIAS|
+--------+-----------+
|       1|       15.5|
|       1|       13.8|
|       2|      25.81|
|       2|      15.54|
|       1|       17.2|
+--------+-----------+
only showing top 5 rows



Sorting Data:

We can sort the data with the .sort() method, any number of columns can be used

In [37]:
df.select('VendorID', col('total_amount').alias('TOTAL_ALIAS')) \
    .sort(['TOTAL_ALIAS', 'VendorID'], ascending=[False, True]).show(15)

+--------+-----------+
|VendorID|TOTAL_ALIAS|
+--------+-----------+
|       1|   46269.44|
|       2|    2246.55|
|       2|    1694.31|
|       2|     1249.3|
|       2|    1240.52|
|       2|    1226.25|
|       2|     1189.2|
|       2|    1183.39|
|       2|    1175.52|
|       2|     1171.3|
|       2|     1081.2|
|       2|    1059.21|
|       2|    1015.55|
|       1|     1009.0|
|       2|    1004.25|
+--------+-----------+
only showing top 15 rows



There is also the Dataframe.orderBy() method, it's the same implementation of Dataframe.sort()

In [38]:
df.select(df.VendorID, df.total_amount).orderBy(df.total_amount, ascending=False).show()

+--------+------------+
|VendorID|total_amount|
+--------+------------+
|       1|    46269.44|
|       2|     2246.55|
|       2|     1694.31|
|       2|      1249.3|
|       2|     1240.52|
|       2|     1226.25|
|       2|      1189.2|
|       2|     1183.39|
|       2|     1175.52|
|       2|      1171.3|
|       2|      1081.2|
|       2|     1059.21|
|       2|     1015.55|
|       1|      1009.0|
|       2|     1004.25|
|       2|     1004.25|
|       1|     1003.62|
|       2|      1000.0|
|       2|      992.06|
|       2|      967.09|
+--------+------------+
only showing top 20 rows



We can use the Column.desc() method to sort in descending order, without providing the ascending parameter

In [39]:
df.select(df.VendorID, df.total_amount).orderBy(df.total_amount.desc()).show(15)

+--------+------------+
|VendorID|total_amount|
+--------+------------+
|       1|    46269.44|
|       2|     2246.55|
|       2|     1694.31|
|       2|      1249.3|
|       2|     1240.52|
|       2|     1226.25|
|       2|      1189.2|
|       2|     1183.39|
|       2|     1175.52|
|       2|      1171.3|
|       2|      1081.2|
|       2|     1059.21|
|       2|     1015.55|
|       1|      1009.0|
|       2|     1004.25|
+--------+------------+
only showing top 15 rows



## Filter Operations

We can use the DataFrame.filter() or DataFrame.where() methods to filter the data, .where() is an alias for filter(). This methods are used alongside filter operators to create a column expression that represents a `Filter Condition`.

Some Filter Operators:

| Operation        | Symbol | Example (Expression)                    | Equivalent SQL       |
| ---------------- | ------ | --------------------------------------- | -------------------- |
| Equal to         | `==`   | `col("age") == 30`                      | `age = 30`           |
| Not equal        | `!=`   | `col("age") != 30`                      | `age <> 30`          |
| Greater than     | `>`    | `col("age") > 18`                       | `age > 18`           |
| Less than        | `<`    | `col("age") < 65`                       | `age < 65`           |
| Greater or equal | `>=`   | `col("age") >= 21`                      | `age >= 21`          |
| Less or equal    | `<=`   | `col("age") <= 65`                      | `age <= 65`          |
| And              | `&`    | `(col("age") > 18) & (col("age") < 65)` | `AND`                |
| Or               | `\|`   | `(col("age") > 18) \| (col("age") < 65)`| `OR`                 |
| Not              | `~`    | `~(col("active") == True)`              | `NOT active = true`  |


In [40]:
df.select('VendorID', 'total_amount').where('total_amount > 100').show(15)

+--------+------------+
|VendorID|total_amount|
+--------+------------+
|       2|      112.89|
|       2|      153.85|
|       2|      103.86|
|       2|      100.13|
|       1|      121.75|
|       2|      106.97|
|       1|      100.39|
|       1|      101.89|
|       2|      100.13|
|       2|      148.26|
|       2|      238.89|
|       2|      104.21|
|       1|      100.09|
|       2|      103.86|
|       2|      112.75|
+--------+------------+
only showing top 15 rows



In [41]:
df.select('VendorID', 'total_amount').filter(df.VendorID == 1).show(5)

+--------+------------+
|VendorID|total_amount|
+--------+------------+
|       1|        15.5|
|       1|        13.8|
|       1|        17.2|
|       1|        20.8|
|       1|       83.44|
+--------+------------+
only showing top 5 rows



Multiple Operators

In [42]:
df.select('VendorID', 'payment_type', 'total_amount').filter((df.VendorID == 2) & (df.total_amount > 1000)).show()

+--------+------------+------------+
|VendorID|payment_type|total_amount|
+--------+------------+------------+
|       2|           2|     1694.31|
|       2|           1|     1015.55|
|       2|           2|     1059.21|
|       2|           1|      1171.3|
|       2|           1|      1081.2|
|       2|           1|      1189.2|
|       2|           1|     1175.52|
|       2|           1|     1240.52|
|       2|           2|     1183.39|
|       2|           1|     1004.25|
|       2|           1|     1004.25|
|       2|           1|     1226.25|
|       2|           1|      1249.3|
|       2|           1|     2246.55|
+--------+------------+------------+



In [43]:
df.select('VendorID', 'payment_type', 'total_amount') \
    .filter((df.VendorID == 2) & (df.total_amount > 1000) & (df.payment_type == 1)).show()

+--------+------------+------------+
|VendorID|payment_type|total_amount|
+--------+------------+------------+
|       2|           1|     1015.55|
|       2|           1|      1171.3|
|       2|           1|      1081.2|
|       2|           1|      1189.2|
|       2|           1|     1175.52|
|       2|           1|     1240.52|
|       2|           1|     1004.25|
|       2|           1|     1004.25|
|       2|           1|     1226.25|
|       2|           1|      1249.3|
|       2|           1|     2246.55|
+--------+------------+------------+



In [44]:
df.select('VendorID', 'total_amount') \
    .filter(((df.VendorID == 2) & (df.total_amount > 1000)) | (df.total_amount > 2000)).show()

+--------+------------+
|VendorID|total_amount|
+--------+------------+
|       2|     1694.31|
|       2|     1015.55|
|       2|     1059.21|
|       1|    46269.44|
|       2|      1171.3|
|       2|      1081.2|
|       2|      1189.2|
|       2|     1175.52|
|       2|     1240.52|
|       2|     1183.39|
|       2|     1004.25|
|       2|     1004.25|
|       2|     1226.25|
|       2|      1249.3|
|       2|     2246.55|
+--------+------------+



The Between operator

In [45]:
df.select(df.VendorID, df.total_amount).where(df.total_amount.between(15, 16)).show(15)

+--------+------------+
|VendorID|total_amount|
+--------+------------+
|       1|        15.5|
|       2|       15.54|
|       2|        15.7|
|       2|       15.29|
|       2|       15.93|
|       1|        15.5|
|       1|        15.5|
|       2|       15.54|
|       1|        15.0|
|       2|       15.48|
|       2|       15.75|
|       2|       15.54|
|       2|       15.75|
|       1|        15.5|
|       2|        15.8|
+--------+------------+
only showing top 15 rows



Filtering null values can be done with .isnull(), .isNotNull() and .isNaN() methods

In [46]:
df['VendorID', 'Airport_fee', 'total_amount'].where(col('Airport_fee').isNull()).show(5)

+--------+-----------+------------+
|VendorID|Airport_fee|total_amount|
+--------+-----------+------------+
|       2|       NULL|        4.76|
|       2|       NULL|       18.95|
|       2|       NULL|         2.0|
|       2|       NULL|        43.4|
|       2|       NULL|       39.23|
+--------+-----------+------------+
only showing top 5 rows



Filter text in a column with .contains()

In [47]:
df['VendorID', 'Airport_fee', 'total_amount', 'tpep_pickup_datetime'] \
    .where(col('tpep_pickup_datetime').contains('2025-03-05')).show()

+--------+-----------+------------+--------------------+
|VendorID|Airport_fee|total_amount|tpep_pickup_datetime|
+--------+-----------+------------+--------------------+
|       1|       1.75|       57.25| 2025-03-05 00:17:44|
|       1|        0.0|       16.35| 2025-03-05 00:02:33|
|       2|        0.0|       14.44| 2025-03-05 00:17:56|
|       2|        0.0|       19.25| 2025-03-05 00:39:48|
|       2|        0.0|       17.16| 2025-03-05 00:20:42|
|       2|        0.0|       16.32| 2025-03-05 00:00:03|
|       2|        0.0|        30.6| 2025-03-05 00:13:04|
|       2|        0.0|        -8.0| 2025-03-05 00:45:09|
|       2|        0.0|         8.0| 2025-03-05 00:45:09|
|       2|        0.0|       23.94| 2025-03-05 00:22:02|
|       2|        0.0|       17.22| 2025-03-05 00:29:21|
|       2|       1.75|       71.74| 2025-03-05 00:09:59|
|       2|       1.75|      147.72| 2025-03-05 00:26:52|
|       1|        0.0|       26.45| 2025-03-05 00:32:22|
|       2|        0.0|        1

SQL LIKE pattern filter

In [48]:
df['VendorID', 'Airport_fee', 'total_amount', 'tpep_pickup_datetime'] \
    .where(col('tpep_pickup_datetime').like('%51%')).show()

+--------+-----------+------------+--------------------+
|VendorID|Airport_fee|total_amount|tpep_pickup_datetime|
+--------+-----------+------------+--------------------+
|       2|        0.0|       18.15| 2025-03-01 00:41:51|
|       2|        0.0|       38.99| 2025-03-01 00:51:48|
|       2|        0.0|       18.06| 2025-03-01 00:51:23|
|       2|        0.0|       48.56| 2025-03-01 00:44:51|
|       1|        0.0|       26.45| 2025-03-01 00:16:51|
|       1|        0.0|       18.55| 2025-03-01 00:56:51|
|       1|        0.0|        14.0| 2025-03-01 00:51:04|
|       2|        0.0|       21.42| 2025-03-01 00:51:30|
|       2|        0.0|       16.38| 2025-03-01 00:51:14|
|       2|        0.0|       51.09| 2025-03-01 00:02:51|
|       2|        0.0|       14.95| 2025-03-01 00:22:51|
|       2|        0.0|       25.62| 2025-03-01 00:03:51|
|       2|        0.0|       15.65| 2025-03-01 00:51:35|
|       2|        0.0|       26.46| 2025-03-01 00:51:47|
|       2|        0.0|       21

Case insensitive SQL LIKE pattern filter

In [49]:
df.select(df.VendorID, df.store_and_fwd_flag).where(col('store_and_fwd_flag').like('n')).show(5)
df.select(df.VendorID, df.store_and_fwd_flag).where(col('store_and_fwd_flag').ilike('n')).show(5)

+--------+------------------+
|VendorID|store_and_fwd_flag|
+--------+------------------+
+--------+------------------+

+--------+------------------+
|VendorID|store_and_fwd_flag|
+--------+------------------+
|       1|                 N|
|       1|                 N|
|       2|                 N|
|       2|                 N|
|       1|                 N|
+--------+------------------+
only showing top 5 rows



Column.rlike(): SQL LIKE pattern with REGEX filter

In [50]:
df['VendorID', 'Airport_fee', 'total_amount', 'tpep_pickup_datetime'] \
    .where(col('tpep_pickup_datetime').rlike('.51$')).show()

+--------+-----------+------------+--------------------+
|VendorID|Airport_fee|total_amount|tpep_pickup_datetime|
+--------+-----------+------------+--------------------+
|       2|        0.0|       18.15| 2025-03-01 00:41:51|
|       2|        0.0|       48.56| 2025-03-01 00:44:51|
|       1|        0.0|       26.45| 2025-03-01 00:16:51|
|       1|        0.0|       18.55| 2025-03-01 00:56:51|
|       2|        0.0|       51.09| 2025-03-01 00:02:51|
|       2|        0.0|       14.95| 2025-03-01 00:22:51|
|       2|        0.0|       25.62| 2025-03-01 00:03:51|
|       2|        0.0|       21.75| 2025-03-01 00:50:51|
|       2|        0.0|        35.7| 2025-03-01 00:27:51|
|       2|        0.0|       16.23| 2025-03-01 00:20:51|
|       2|        0.0|       34.02| 2025-03-01 00:11:51|
|       1|        0.0|        15.3| 2025-03-01 00:51:51|
|       2|        0.0|       15.05| 2025-03-01 00:17:51|
|       2|        0.0|        18.9| 2025-03-01 00:35:51|
|       1|        0.0|       17

For advanced Regex operations check: regexp_count(), regexp_extract(), regexp_extract() and more at: https://spark.apache.org/docs/latest/api/python/search.html?q=regex

Startswith, filter strings that starts with the specified string filter without using LIKE or REGEX

In [51]:
df['VendorID', 'Airport_fee', 'total_amount', 'tpep_pickup_datetime'] \
    .where(col('tpep_pickup_datetime').startswith('2025-03-30 23')).show()

+--------+-----------+------------+--------------------+
|VendorID|Airport_fee|total_amount|tpep_pickup_datetime|
+--------+-----------+------------+--------------------+
|       2|        0.0|        27.3| 2025-03-30 23:59:06|
|       2|        0.0|       20.58| 2025-03-30 23:00:05|
|       1|       1.75|        71.0| 2025-03-30 23:18:41|
|       2|       1.75|       53.05| 2025-03-30 23:55:58|
|       1|       1.75|        91.8| 2025-03-30 23:05:01|
|       1|        0.0|       12.25| 2025-03-30 23:48:17|
|       2|        0.0|       26.46| 2025-03-30 23:00:27|
|       2|        0.0|        14.7| 2025-03-30 23:48:30|
|       1|        0.0|        19.7| 2025-03-30 23:33:25|
|       1|        0.0|       14.35| 2025-03-30 23:07:44|
|       1|        0.0|       13.65| 2025-03-30 23:17:18|
|       2|        0.0|       24.15| 2025-03-30 23:22:33|
|       1|        0.0|        11.5| 2025-03-30 23:04:25|
|       2|        0.0|       29.58| 2025-03-30 23:01:48|
|       1|        0.0|       17

Filter unique rows, just like `select distinct`

In [52]:
df.select('VendorID').distinct().show()

+--------+
|VendorID|
+--------+
|       1|
|       7|
|       2|
|       6|
+--------+



## Handling columns

Create columns

Dataframe.withColumn() creates a new a column based on an expression

In [53]:
df.select('VendorID', 'total_amount').withColumn('new_id', df.VendorID * 100).show(5)

+--------+------------+------+
|VendorID|total_amount|new_id|
+--------+------------+------+
|       1|        15.5|   100|
|       1|        13.8|   100|
|       2|       25.81|   200|
|       2|       15.54|   200|
|       1|        17.2|   100|
+--------+------------+------+
only showing top 5 rows



We can create multiple columns in a single operation with Dataframe.withColumns(), it receives a dictionary of names and columns, the lit() function creates a column based on a literal value

In [54]:
df.select('VendorID', 'total_amount') \
    .withColumns({
        'new_id': df.VendorID * 100, 
        'new_total_amount': concat(lit('$   '), df.total_amount)
        }).show(5)

+--------+------------+------+----------------+
|VendorID|total_amount|new_id|new_total_amount|
+--------+------------+------+----------------+
|       1|        15.5|   100|        $   15.5|
|       1|        13.8|   100|        $   13.8|
|       2|       25.81|   200|       $   25.81|
|       2|       15.54|   200|       $   15.54|
|       1|        17.2|   100|        $   17.2|
+--------+------------+------+----------------+
only showing top 5 rows



Column.when(), Functions.when(), Column.otherwise(), Functions.otherwise() are the equivalent to SQL CASE, WHEN and ELSE on pyspark

In [55]:
df.select('VendorID', 
    when(df.VendorID == 1, 'Creative Mobile').when(df.VendorID == 2, 'Curb Mobility, LLC') \
    .when(df.VendorID == 6, 'Myle Technologies').otherwise('Helix') \
    .alias('Case Statement')
    ).distinct().show()

+--------+------------------+
|VendorID|    Case Statement|
+--------+------------------+
|       7|             Helix|
|       1|   Creative Mobile|
|       2|Curb Mobility, LLC|
|       6| Myle Technologies|
+--------+------------------+



Return or change a column based on a condition

In [56]:
df.select(df.VendorID, df.total_amount).withColumn('VendorId',
    when(df.total_amount >= 25, 100).when(df.total_amount < 25, 200).otherwise(df.VendorID)).show()

+--------+------------+
|VendorId|total_amount|
+--------+------------+
|     200|        15.5|
|     200|        13.8|
|     100|       25.81|
|     200|       15.54|
|     200|        17.2|
|     200|        20.8|
|     100|        27.3|
|     200|       16.35|
|     200|        23.1|
|     200|       21.42|
|     200|       14.35|
|     200|       14.25|
|     100|       83.44|
|     100|      112.89|
|     100|       28.14|
|     200|        15.7|
|     200|       18.15|
|     200|       16.35|
|     200|       24.75|
|     200|       12.95|
+--------+------------+
only showing top 20 rows



Dataframes and column .alias() method return a Dataframe with an **alias set** applied

In [57]:
df.select('VendorID', col('total_amount').alias('TOTAL_ALIAS')).show(5)

+--------+-----------+
|VendorID|TOTAL_ALIAS|
+--------+-----------+
|       1|       15.5|
|       1|       13.8|
|       2|      25.81|
|       2|      15.54|
|       1|       17.2|
+--------+-----------+
only showing top 5 rows



In [58]:
df.select('VendorID', col('total_amount').alias('TOTAL_ALIAS'))

DataFrame[VendorID: int, TOTAL_ALIAS: double]

Renaming Columns is different from giving them aliases because if the column to be renamed doesn't exists, it doesn't change anything

Rename a column

In [59]:
df.select('VendorID', 'total_amount').withColumnRenamed('VendorID', 'Id').show(5)

+---+------------+
| Id|total_amount|
+---+------------+
|  1|        15.5|
|  1|        13.8|
|  2|       25.81|
|  2|       15.54|
|  1|        17.2|
+---+------------+
only showing top 5 rows



Rename multiple columns

In [60]:
df.select('VendorID', 'total_amount').withColumnsRenamed({'VendorID': 'Id', 'total_amount': 'amount'}).show(5)

+---+------+
| Id|amount|
+---+------+
|  1|  15.5|
|  1|  13.8|
|  2| 25.81|
|  2| 15.54|
|  1|  17.2|
+---+------+
only showing top 5 rows



Drop columns: The Dataframe.Drop() method returns a new dataframe without the columns specified in it

In [61]:
df.select('VendorID', 'total_amount').drop('total_amount').show(5)

+--------+
|VendorID|
+--------+
|       1|
|       1|
|       2|
|       2|
|       1|
+--------+
only showing top 5 rows



## Aggregations

Group By functions: COUNT, SUM, MIN, MAX, AVG, etc

Dataframe.GroupBy returns a DataFrameGroupBy class that is used to create aggregations

In [62]:
print(df.groupBy('VendorID'))

GroupedData[grouping expressions: [VendorID], value: [VendorID: int, tpep_pickup_datetime: timestamp_ntz ... 18 more fields], type: GroupBy]


Count Aggregation: counts the rows by the groupby class

In [63]:
df.groupBy('VendorID').count().show()

+--------+-------+
|VendorID|  count|
+--------+-------+
|       1| 834394|
|       7|  21481|
|       2|3289048|
|       6|    334|
+--------+-------+



Mean, sum, min, max and avg compute the function for all columns by default

In [64]:
df.groupBy('VendorID').mean().show()



+--------+-------------+--------------------+------------------+------------------+------------------+------------------+------------------+------------------+------------------+-------------------+------------------+-------------------+--------------------------+------------------+-------------------------+-------------------+-----------------------+
|VendorID|avg(VendorID)|avg(passenger_count)|avg(trip_distance)|   avg(RatecodeID)| avg(PULocationID)| avg(DOLocationID)| avg(payment_type)|  avg(fare_amount)|        avg(extra)|       avg(mta_tax)|   avg(tip_amount)|  avg(tolls_amount)|avg(improvement_surcharge)| avg(total_amount)|avg(congestion_surcharge)|   avg(Airport_fee)|avg(cbd_congestion_fee)|
+--------+-------------+--------------------+------------------+------------------+------------------+------------------+------------------+------------------+------------------+-------------------+------------------+-------------------+--------------------------+------------------+---------

                                                                                

The last output showed every column, remember to select first

In [65]:
df.select('VendorID', 'total_amount').groupBy('VendorID').mean('total_amount').show()

+--------+------------------+
|VendorID| avg(total_amount)|
+--------+------------------+
|       1| 26.81411121125252|
|       7|25.361795540244803|
|       2| 26.13239850861697|
|       6|29.504341317365295|
+--------+------------------+



Grouping without .select()

In [66]:
df.groupBy('VendorID').mean('total_amount').show()

+--------+------------------+
|VendorID| avg(total_amount)|
+--------+------------------+
|       1| 26.81411121125252|
|       7|25.361795540244803|
|       2| 26.13239850861697|
|       6|29.504341317365295|
+--------+------------------+



We can compute multiple aggregations across the entire dataframe with the .agg() method, this is also an alias for DataFrame.groupBy()

In [67]:
df.agg({ 'total_amount': 'avg', 'trip_distance': 'min', 'extra': 'max', \
        'VendorID': 'count', 'tolls_amount': 'sum'}).show()

+------------------+-----------------+---------------+----------+------------------+
|min(trip_distance)|sum(tolls_amount)|count(VendorID)|max(extra)| avg(total_amount)|
+------------------+-----------------+---------------+----------+------------------+
|               0.0|1968899.629999516|        4145257|      13.5|26.265898046844185|
+------------------+-----------------+---------------+----------+------------------+



## Join and Union operations

Spark has **join types**, just like SQL, these are the ways that we can **join** (mix and match) the data between DataFrames(tables). Some examples: Inner, left, right, full outer, etc

Spark also has **Join Srategies** these are the ways or algorithms used to shuffle, mix and match the data across nodes and executors in the spark cluster. Examples: Broadcast, sort merge, shuffle, etc; we can read about these in the [performance tuning documentation](!https://spark.apache.org/docs/latest/sql-performance-tuning.html)

We can join dataframes using `DataFrame.join(other, on=None, how=None)`, `inner` is the default value of the `how` parameter, and `other` is the  right side dataframe

In [68]:
# Prepare dataframes
df_a = df.select(df.VendorID, df.trip_distance).limit(1)
df_b = df.select(df.VendorID, df.total_amount).limit(1)
df_join = df_a.join(df_b, 'VendorID', 'inner')
df_join.show()

+--------+-------------+------------+
|VendorID|trip_distance|total_amount|
+--------+-------------+------------+
|       1|          0.9|        15.5|
+--------+-------------+------------+



Union concatenates DataFrames with the same schema, using the column's position, like SQL UNION, but with the caveat that there isn't automatic deduplication of rows, this makes spark's union closer to SQL UNION ALL, the SQL UNION behavior can be archieved with .union().distinct()

In this example both dataframes have integer and double columns, making union possible, the output columns are VendorID and trip_distance because df_a was the main DataFrame

In [69]:
df_union = df_a.union(df_b)
df_union.show()

+--------+-------------+
|VendorID|trip_distance|
+--------+-------------+
|       1|          0.9|
|       1|         15.5|
+--------+-------------+



.unionAll() ALL is an alias to .union() in pyspark

In [70]:
print(f'Normal spark Union (SQL Union All): {df_a.unionAll(df_union).count()}')
print(f'Deduplicated spark Union (Union): {df_a.unionAll(df_union).distinct().count()}')

Normal spark Union (SQL Union All): 3
Deduplicated spark Union (Union): 2


# Spark Catalog

The spark catalog is the metadata store used for every SQL objects in a spark session. By default spark will use the Session Catalog, saved in the spark driver's memory. 

There are several Catalogs that Spark can utilize such as:
- Spark Session Catalog
- Hive Catalog
- Delta Catalog
- Hudi Catalog
- Iceberg Catalog and it's variants:
    - Hadoop
    - Hive
    - Rest
    - Nessie
- Unity Catalog
- JDBC Catalog
- Custom Catalog

let's check some basic catalog operations

List Catalogs

In [71]:
spark.catalog.listCatalogs()

[CatalogMetadata(name='spark_catalog', description=None)]

Set default catalog

In [72]:
spark.catalog.setCurrentCatalog('spark_catalog')

Print the default catalog

In [73]:
spark.catalog.currentCatalog()

'spark_catalog'

List all Databases

In [74]:
spark.catalog.listDatabases()

[Database(name='default', catalog='spark_catalog', description='default database', locationUri='file:/mnt/home/repos/learn_pyspark/spark-warehouse')]

Set default database, just like USE in SQL

In [75]:
spark.catalog.setCurrentDatabase('default')

Check default database in session

In [76]:
spark.catalog.currentDatabase()

'default'

Get a database object

In [77]:
spark.catalog.getDatabase('default')

Database(name='default', catalog='spark_catalog', description='default database', locationUri='file:/mnt/home/repos/learn_pyspark/spark-warehouse')

Check if a table exists

In [78]:
spark.catalog.tableExists('union_example')

False

Create a table, this table is managed by the spark catalog

In [79]:
if not spark.catalog.tableExists('union_example'):
    spark.catalog.createTable('union_example', schema=df_union.schema)
    print('union_example table created')
else:
    print('Table already exists')

union_example table created


Create a table from a dataframe

In [80]:
df_union.write.format('parquet').mode('overwrite').saveAsTable('df_union_table')

Create an external table(unmanaged tables), these are not managed by the spark catalog and we must specify where are stored

In [81]:
# root directory of external_tables
os.makedirs('spark-warehouse/external_tables/union_external_table', exist_ok=True)
spark.catalog.createTable(tableName='union_external_table', 
    path='external_tables/union_external_table', schema=df_union.schema)

DataFrame[VendorID: int, trip_distance: double]

In [82]:
df_union.write.format('parquet').mode('overwrite').saveAsTable('external2', 
    path='external_tables/external2')

Get a table object

In [83]:
spark.catalog.getTable('union_example')

Table(name='union_example', catalog='spark_catalog', namespace=['default'], description=None, tableType='MANAGED', isTemporary=False)

Create a temporal view

In [84]:
df_union.createOrReplaceTempView('union_view')

Drop temporal view

In [85]:
spark.catalog.dropTempView('union_view')

True

List tables


In [86]:
spark.catalog.listTables()

[Table(name='df_union_table', catalog='spark_catalog', namespace=['default'], description=None, tableType='MANAGED', isTemporary=False),
 Table(name='external2', catalog='spark_catalog', namespace=['default'], description=None, tableType='EXTERNAL', isTemporary=False),
 Table(name='union_example', catalog='spark_catalog', namespace=['default'], description=None, tableType='MANAGED', isTemporary=False),
 Table(name='union_external_table', catalog='spark_catalog', namespace=['default'], description=None, tableType='EXTERNAL', isTemporary=False)]

List columns of a table or view

In [87]:
spark.catalog.listColumns('union_example')

[Column(name='VendorID', description=None, dataType='int', nullable=True, isPartition=False, isBucket=False),
 Column(name='trip_distance', description=None, dataType='double', nullable=True, isPartition=False, isBucket=False)]

# Spark SQL

We can execute SQL code in pyspark using spark.sql(), this is the Spark SQL API.

Almost everything in this guide can be done using spark sql instead of the dataframe api

In [88]:
spark.sql('CREATE DATABASE example_db')

DataFrame[]

In [89]:
spark.sql("SHOW DATABASES").show()

+----------+
| namespace|
+----------+
|   default|
|example_db|
+----------+



In [90]:
spark.sql("USE example_db")

DataFrame[]

In [91]:
spark.sql('DROP DATABASE IF EXISTS example_db')
spark.sql('SHOW DATABASES').show()
spark.sql("USE default")

+---------+
|namespace|
+---------+
|  default|
+---------+



DataFrame[]

In [92]:
# Prepare example
df_union.createOrReplaceTempView('union_view')

In [93]:
spark.sql('SELECT * FROM union_view').show()

+--------+-------------+
|VendorID|trip_distance|
+--------+-------------+
|       1|          0.9|
|       1|         15.5|
+--------+-------------+



In [94]:
spark.sql('''
    CREATE OR REPLACE TEMP VIEW union_view2 AS
        SELECT *, CONCAT(trip_distance, ' miles') AS new_distance 
        FROM union_view
''')

DataFrame[]

In [95]:
spark.sql('SELECT * FROM union_view2').show()

+--------+-------------+------------+
|VendorID|trip_distance|new_distance|
+--------+-------------+------------+
|       1|          0.9|   0.9 miles|
|       1|         15.5|  15.5 miles|
+--------+-------------+------------+



In [96]:
spark.sql('DESCRIBE TABLE union_example').show()

+-------------+---------+-------+
|     col_name|data_type|comment|
+-------------+---------+-------+
|     VendorID|      int|   NULL|
|trip_distance|   double|   NULL|
+-------------+---------+-------+



In [97]:
spark.sql('DROP TABLE IF EXISTS union_example')

DataFrame[]

## Explain query Plan

In [98]:
df_union.explain()

== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- Union
   :- GlobalLimit 1, 0
   :  +- Exchange SinglePartition, ENSURE_REQUIREMENTS, [plan_id=2452]
   :     +- LocalLimit 1
   :        +- FileScan parquet [VendorID#4,trip_distance#8] Batched: true, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/mnt/home/repos/learn_pyspark/data/yellow_taxi_data_202503.parquet], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<VendorID:int,trip_distance:double>
   +- GlobalLimit 1, 0
      +- Exchange SinglePartition, ENSURE_REQUIREMENTS, [plan_id=2454]
         +- LocalLimit 1
            +- FileScan parquet [VendorID#4304,total_amount#4320] Batched: true, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/mnt/home/repos/learn_pyspark/data/yellow_taxi_data_202503.parquet], PartitionFilters: [], PushedFilters: [], ReadSchema: struct<VendorID:int,total_amount:double>




In [99]:
spark.sql('EXPLAIN SELECT * FROM union_view').show(truncate=False)

+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|plan                                        

Shut down the current SparkSession with SparkSession.stop()

In [100]:
# spark.stop()