# Learn and practice PySpark

## Prepare the environment

Import necessary modules and packages

In [1]:
import os
import requests
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType, StringType, DoubleType, TimestampType
from pyspark.sql.functions import col

Prepare the project's directories

In [2]:
# Prepare directories
os.makedirs("data/", exist_ok=True)

Download a subset of the NYC taxi trips dataset. This subset contains the yellow taxi trips from March 2025.

In [3]:
# Download the NYC taxi trips dataset only if it is not already on disk
URL = "https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_2025-03.parquet"
DATA_PATH = "data/yellow_taxi_data_202503.parquet"

if os.path.exists(DATA_PATH):
    print("Dataset ready")
else:
    # Send the request only when needed, and fail early on an HTTP error
    with requests.get(url=URL, stream=True, timeout=(10, 30)) as response:
        response.raise_for_status()
        with open(DATA_PATH, "wb") as file:
            for chunk in response.iter_content(chunk_size=4096):
                file.write(chunk)
    print("Dataset downloaded")

Dataset ready
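
The download step can also be wrapped in a small reusable helper. The sketch below is only one possible shape: the names `download_if_missing` and `fetch_taxi_chunks` are hypothetical, and writing to a temporary `.part` file first is an added assumption so that an interrupted download never leaves a half-written file at the final path.

```python
import os
import tempfile
from typing import Callable, Iterable


def download_if_missing(path: str, fetch_chunks: Callable[[], Iterable[bytes]]) -> bool:
    """Write the fetched chunks to `path` unless it already exists.

    Returns True when a file was written, False when it was already there.
    """
    if os.path.exists(path):
        return False  # nothing to do, and fetch_chunks() is never called

    folder = os.path.dirname(path) or "."
    os.makedirs(folder, exist_ok=True)

    # Write to a temporary file, then rename it into place
    fd, tmp_path = tempfile.mkstemp(dir=folder, suffix=".part")
    try:
        with os.fdopen(fd, "wb") as tmp:
            for chunk in fetch_chunks():
                tmp.write(chunk)
        os.replace(tmp_path, path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    return True


def fetch_taxi_chunks() -> Iterable[bytes]:
    """Hypothetical fetcher for the dataset used in this notebook."""
    import requests  # imported lazily so the helper above has no third-party dependency

    url = "https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_2025-03.parquet"
    response = requests.get(url, stream=True, timeout=(10, 30))
    response.raise_for_status()
    return response.iter_content(chunk_size=8192)
```

Calling `download_if_missing("data/yellow_taxi_data_202503.parquet", fetch_taxi_chunks)` would then send a request only when the file is missing.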


# PySpark basics

Initialize a SparkSession, the basic entry point for interacting with Apache Spark.

In [4]:
spark = SparkSession.builder.appName("LearnSpark") \
    .config('spark.executor.memory', '4g') \
    .config('spark.executor.cores', '4') \
    .config('spark.sql.shuffle.partitions', '200') \
    .getOrCreate()

25/06/28 20:43:55 WARN Utils: Your hostname, xblade resolves to a loopback address: 127.0.1.1; using 192.168.10.156 instead (on interface enp2s0)
25/06/28 20:43:55 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/06/28 20:43:55 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In the SparkSession builder we can configure the Spark deployment and its processes, such as the amount of memory and the number of cores assigned to each executor process. See the [configuration docs](https://spark.apache.org/docs/latest/configuration.html) for the full list of settings.

## DataFrame basics

Create a Dataframe from scratch

In [5]:
df_plain = spark.createDataFrame(data=[('Alan', 10), ('Cindy', 12)], schema=('Name', 'Age'))
print(df_plain)

DataFrame[Name: string, Age: bigint]


Create a DataFrame by reading a file. The file can be Parquet, JSON, CSV, etc.

In [6]:
df: DataFrame = spark.read.parquet("data/yellow_taxi_data_202503.parquet")

Show the first rows of a DataFrame (20 by default)

In [7]:
df.show()

+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+------------------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|Airport_fee|cbd_congestion_fee|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+------------------+
|       1| 2025-03-01 00:17:16|  2025-03-01 00:25:52|              1|          0.9|         1|                 N|         140|    

In [8]:
# only the first 5 rows
df.show(5)

+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+------------------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|Airport_fee|cbd_congestion_fee|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+------------------+
|       1| 2025-03-01 00:17:16|  2025-03-01 00:25:52|              1|          0.9|         1|                 N|         140|    

DataFrame.head(n) returns the first n rows as a list of Row objects, similar to pandas' head(). In Spark n defaults to 1, and head() without an argument returns a single Row.

In [9]:
df.head(5)

[Row(VendorID=1, tpep_pickup_datetime=datetime.datetime(2025, 3, 1, 0, 17, 16), tpep_dropoff_datetime=datetime.datetime(2025, 3, 1, 0, 25, 52), passenger_count=1, trip_distance=0.9, RatecodeID=1, store_and_fwd_flag='N', PULocationID=140, DOLocationID=236, payment_type=1, fare_amount=7.9, extra=3.5, mta_tax=0.5, tip_amount=2.6, tolls_amount=0.0, improvement_surcharge=1.0, total_amount=15.5, congestion_surcharge=2.5, Airport_fee=0.0, cbd_congestion_fee=0.0),
 Row(VendorID=1, tpep_pickup_datetime=datetime.datetime(2025, 3, 1, 0, 37, 38), tpep_dropoff_datetime=datetime.datetime(2025, 3, 1, 0, 43, 51), passenger_count=1, trip_distance=0.6, RatecodeID=1, store_and_fwd_flag='N', PULocationID=140, DOLocationID=262, payment_type=1, fare_amount=6.5, extra=3.5, mta_tax=0.5, tip_amount=2.3, tolls_amount=0.0, improvement_surcharge=1.0, total_amount=13.8, congestion_surcharge=2.5, Airport_fee=0.0, cbd_congestion_fee=0.0),
 Row(VendorID=2, tpep_pickup_datetime=datetime.datetime(2025, 3, 1, 0, 24, 35)

The DataFrame.take(n) method works like head(n) and always returns a list

In [10]:
df.take(5)

[Row(VendorID=1, tpep_pickup_datetime=datetime.datetime(2025, 3, 1, 0, 17, 16), tpep_dropoff_datetime=datetime.datetime(2025, 3, 1, 0, 25, 52), passenger_count=1, trip_distance=0.9, RatecodeID=1, store_and_fwd_flag='N', PULocationID=140, DOLocationID=236, payment_type=1, fare_amount=7.9, extra=3.5, mta_tax=0.5, tip_amount=2.6, tolls_amount=0.0, improvement_surcharge=1.0, total_amount=15.5, congestion_surcharge=2.5, Airport_fee=0.0, cbd_congestion_fee=0.0),
 Row(VendorID=1, tpep_pickup_datetime=datetime.datetime(2025, 3, 1, 0, 37, 38), tpep_dropoff_datetime=datetime.datetime(2025, 3, 1, 0, 43, 51), passenger_count=1, trip_distance=0.6, RatecodeID=1, store_and_fwd_flag='N', PULocationID=140, DOLocationID=262, payment_type=1, fare_amount=6.5, extra=3.5, mta_tax=0.5, tip_amount=2.3, tolls_amount=0.0, improvement_surcharge=1.0, total_amount=13.8, congestion_surcharge=2.5, Airport_fee=0.0, cbd_congestion_fee=0.0),
 Row(VendorID=2, tpep_pickup_datetime=datetime.datetime(2025, 3, 1, 0, 24, 35)

We can create a new pandas DataFrame from a Spark DataFrame with DataFrame.toPandas(). It collects every row to the driver, so use it only on small DataFrames.

In [11]:
pd_df = df_plain.toPandas()
print(type(pd_df))

<class 'pandas.core.frame.DataFrame'>


                                                                                

Show the DataFrame with pandas' DataFrame.head()

In [12]:
pd_df.head()

|     | Name  | Age |
| --- | ----- | --- |
| 0   | Alan  | 10  |
| 1   | Cindy | 12  |


Write a DataFrame to CSV files.
The write.csv() method creates a directory and writes multiple CSV files (one per partition of the DataFrame) that together make up the whole DataFrame, plus a _SUCCESS file as confirmation.

The .csv() method triggers the write, so it must be called after the .option() and .mode() methods. Here the option header = true writes a header line with the column names in each CSV file.

In [13]:
df.write.option('header', 'true').mode('overwrite').csv('data/yellow_taxi_data_csv')

                                                                                

The df.write attribute returns a **DataFrameWriter** object, whose .mode() method accepts the following save modes:

**append**: Append contents of this DataFrame to existing data.

**overwrite**: Overwrite existing data.

**error or errorifexists** (default): Throw an exception if data already exists.

**ignore**: Silently ignore this operation if data already exists.

Read the csv data as a new dataframe

In [14]:
df_csv = spark.read.csv('data/yellow_taxi_data_csv')
df_csv.show()

+--------+--------------------+--------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+--------------------+------------+--------------------+-----------+------------------+
|     _c0|                 _c1|                 _c2|            _c3|          _c4|       _c5|               _c6|         _c7|         _c8|         _c9|       _c10| _c11|   _c12|      _c13|        _c14|                _c15|        _c16|                _c17|       _c18|              _c19|
+--------+--------------------+--------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+--------------------+------------+--------------------+-----------+------------------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_date...|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationI

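The column names were not picked up and every column was read as a string: the header option must be set to true when reading. The inferSchema option is useful too, since it makes Spark scan the data to infer the column types.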

First let's check the current schema

In [15]:
df_csv.printSchema()

root
 |-- _c0: string (nullable = true)
 |-- _c1: string (nullable = true)
 |-- _c2: string (nullable = true)
 |-- _c3: string (nullable = true)
 |-- _c4: string (nullable = true)
 |-- _c5: string (nullable = true)
 |-- _c6: string (nullable = true)
 |-- _c7: string (nullable = true)
 |-- _c8: string (nullable = true)
 |-- _c9: string (nullable = true)
 |-- _c10: string (nullable = true)
 |-- _c11: string (nullable = true)
 |-- _c12: string (nullable = true)
 |-- _c13: string (nullable = true)
 |-- _c14: string (nullable = true)
 |-- _c15: string (nullable = true)
 |-- _c16: string (nullable = true)
 |-- _c17: string (nullable = true)
 |-- _c18: string (nullable = true)
 |-- _c19: string (nullable = true)



Now let's read it again using both options, passed as keyword arguments to the **.options** method, which sets several options at once (unlike **.option**, which sets a single one)

In [16]:
# we can also call read.csv() directly and pass the options as parameters
# df_csv = spark.read.csv('data/yellow_taxi_data_csv', header=True, inferSchema=True)
df_csv = spark.read.options(header=True, inferSchema=True, delimiter=',').csv('data/yellow_taxi_data_csv')

                                                                                

The column names are now read from the header

In [17]:
df_csv.show(5)

+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+------------------+
|VendorID|tpep_pickup_datetime|tpep_dropoff_datetime|passenger_count|trip_distance|RatecodeID|store_and_fwd_flag|PULocationID|DOLocationID|payment_type|fare_amount|extra|mta_tax|tip_amount|tolls_amount|improvement_surcharge|total_amount|congestion_surcharge|Airport_fee|cbd_congestion_fee|
+--------+--------------------+---------------------+---------------+-------------+----------+------------------+------------+------------+------------+-----------+-----+-------+----------+------------+---------------------+------------+--------------------+-----------+------------------+
|       2| 2025-03-31 08:25:18|  2025-03-31 08:31:33|              1|         0.92|         1|                 N|         164|    

And the DataFrame has a different (and much better) schema

In [18]:
df_csv.printSchema()

root
 |-- VendorID: integer (nullable = true)
 |-- tpep_pickup_datetime: timestamp (nullable = true)
 |-- tpep_dropoff_datetime: timestamp (nullable = true)
 |-- passenger_count: integer (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- RatecodeID: integer (nullable = true)
 |-- store_and_fwd_flag: string (nullable = true)
 |-- PULocationID: integer (nullable = true)
 |-- DOLocationID: integer (nullable = true)
 |-- payment_type: integer (nullable = true)
 |-- fare_amount: double (nullable = true)
 |-- extra: double (nullable = true)
 |-- mta_tax: double (nullable = true)
 |-- tip_amount: double (nullable = true)
 |-- tolls_amount: double (nullable = true)
 |-- improvement_surcharge: double (nullable = true)
 |-- total_amount: double (nullable = true)
 |-- congestion_surcharge: double (nullable = true)
 |-- Airport_fee: double (nullable = true)
 |-- cbd_congestion_fee: double (nullable = true)



The DataFrame.schema attribute holds the schema as a StructType

In [19]:
print(df_csv.schema)

StructType([StructField('VendorID', IntegerType(), True), StructField('tpep_pickup_datetime', TimestampType(), True), StructField('tpep_dropoff_datetime', TimestampType(), True), StructField('passenger_count', IntegerType(), True), StructField('trip_distance', DoubleType(), True), StructField('RatecodeID', IntegerType(), True), StructField('store_and_fwd_flag', StringType(), True), StructField('PULocationID', IntegerType(), True), StructField('DOLocationID', IntegerType(), True), StructField('payment_type', IntegerType(), True), StructField('fare_amount', DoubleType(), True), StructField('extra', DoubleType(), True), StructField('mta_tax', DoubleType(), True), StructField('tip_amount', DoubleType(), True), StructField('tolls_amount', DoubleType(), True), StructField('improvement_surcharge', DoubleType(), True), StructField('total_amount', DoubleType(), True), StructField('congestion_surcharge', DoubleType(), True), StructField('Airport_fee', DoubleType(), True), StructField('cbd_conges

We can create a schema definition with the StructType and StructField classes. The StructField constructor parameters are: StructField(name, dataType, nullable=True, metadata=None)

In [20]:
df_schema = StructType([
    StructField("VendorID", IntegerType(), True),
    StructField("tpep_pickup_datetime", TimestampType(), True),
    StructField("tpep_dropoff_datetime", TimestampType(), True),
    StructField("passenger_count", IntegerType(), True),
    StructField("trip_distance", DoubleType(), True),
    StructField("RatecodeID", IntegerType(), True),
    StructField("store_and_fwd_flag", StringType(), True),
    StructField("PULocationID", IntegerType(), True),
    StructField("DOLocationID", IntegerType(), True),
    StructField("payment_type", IntegerType(), True),
    StructField("fare_amount", DoubleType(), True),
    StructField("extra", DoubleType(), True),
    StructField("mta_tax", DoubleType(), True),
    StructField("tip_amount", DoubleType(), True),
    StructField("tolls_amount", DoubleType(), True),
    StructField("improvement_surcharge", DoubleType(), True),
    StructField("total_amount", DoubleType(), True),
    StructField("congestion_surcharge", DoubleType(), True),
    StructField("Airport_fee", DoubleType(), True),
    StructField("cbd_congestion_fee", DoubleType(), True)
])
print(f'Dataframe schema:\n{df_schema}')

Dataframe schema:
StructType([StructField('VendorID', IntegerType(), True), StructField('tpep_pickup_datetime', TimestampType(), True), StructField('tpep_dropoff_datetime', TimestampType(), True), StructField('passenger_count', IntegerType(), True), StructField('trip_distance', DoubleType(), True), StructField('RatecodeID', IntegerType(), True), StructField('store_and_fwd_flag', StringType(), True), StructField('PULocationID', IntegerType(), True), StructField('DOLocationID', IntegerType(), True), StructField('payment_type', IntegerType(), True), StructField('fare_amount', DoubleType(), True), StructField('extra', DoubleType(), True), StructField('mta_tax', DoubleType(), True), StructField('tip_amount', DoubleType(), True), StructField('tolls_amount', DoubleType(), True), StructField('improvement_surcharge', DoubleType(), True), StructField('total_amount', DoubleType(), True), StructField('congestion_surcharge', DoubleType(), True), StructField('Airport_fee', DoubleType(), True), Struc

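Read the CSV files again with the explicit schema. The files were written with a header line, so the header option is still needed; otherwise the header line of each file is read as a data row.

In [21]:
df_csv = spark.read.schema(df_schema).option('header', 'true').csv('data/yellow_taxi_data_csv')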

In [22]:
df_csv.printSchema()

root
 |-- VendorID: integer (nullable = true)
 |-- tpep_pickup_datetime: timestamp (nullable = true)
 |-- tpep_dropoff_datetime: timestamp (nullable = true)
 |-- passenger_count: integer (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- RatecodeID: integer (nullable = true)
 |-- store_and_fwd_flag: string (nullable = true)
 |-- PULocationID: integer (nullable = true)
 |-- DOLocationID: integer (nullable = true)
 |-- payment_type: integer (nullable = true)
 |-- fare_amount: double (nullable = true)
 |-- extra: double (nullable = true)
 |-- mta_tax: double (nullable = true)
 |-- tip_amount: double (nullable = true)
 |-- tolls_amount: double (nullable = true)
 |-- improvement_surcharge: double (nullable = true)
 |-- total_amount: double (nullable = true)
 |-- congestion_surcharge: double (nullable = true)
 |-- Airport_fee: double (nullable = true)
 |-- cbd_congestion_fee: double (nullable = true)



We can print only the column names of the DataFrame with the DataFrame.columns attribute

In [23]:
print(df.columns)

['VendorID', 'tpep_pickup_datetime', 'tpep_dropoff_datetime', 'passenger_count', 'trip_distance', 'RatecodeID', 'store_and_fwd_flag', 'PULocationID', 'DOLocationID', 'payment_type', 'fare_amount', 'extra', 'mta_tax', 'tip_amount', 'tolls_amount', 'improvement_surcharge', 'total_amount', 'congestion_surcharge', 'Airport_fee', 'cbd_congestion_fee']


The DataFrame.dtypes attribute is useful when we only want the columns and their types


In [24]:
print(df.dtypes)

[('VendorID', 'int'), ('tpep_pickup_datetime', 'timestamp_ntz'), ('tpep_dropoff_datetime', 'timestamp_ntz'), ('passenger_count', 'bigint'), ('trip_distance', 'double'), ('RatecodeID', 'bigint'), ('store_and_fwd_flag', 'string'), ('PULocationID', 'int'), ('DOLocationID', 'int'), ('payment_type', 'bigint'), ('fare_amount', 'double'), ('extra', 'double'), ('mta_tax', 'double'), ('tip_amount', 'double'), ('tolls_amount', 'double'), ('improvement_surcharge', 'double'), ('total_amount', 'double'), ('congestion_surcharge', 'double'), ('Airport_fee', 'double'), ('cbd_congestion_fee', 'double')]


With a correct schema, DataFrame.summary() returns a new DataFrame with descriptive statistics (count, mean, stddev, min, the 25%, 50% and 75% percentiles, and max) for the numeric and string columns. Let's use the original DataFrame.

In [25]:
df.summary().show()

25/06/28 20:44:09 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.

+-------+------------------+------------------+-----------------+------------------+------------------+------------------+-----------------+------------------+------------------+------------------+------------------+------------------+------------------+---------------------+------------------+--------------------+-------------------+------------------+
|summary|          VendorID|   passenger_count|    trip_distance|        RatecodeID|store_and_fwd_flag|      PULocationID|     DOLocationID|      payment_type|       fare_amount|             extra|           mta_tax|        tip_amount|      tolls_amount|improvement_surcharge|      total_amount|congestion_surcharge|        Airport_fee|cbd_congestion_fee|
+-------+------------------+------------------+-----------------+------------------+------------------+------------------+-----------------+------------------+------------------+------------------+------------------+------------------+------------------+---------------------+------------

                                                                                

If we only want the number of rows, DataFrame.count() is the simple way

In [26]:
print(f'Count: {df.count()}')

Count: 4145257


# Dataframe Operations

## Select operations

Select: Returns a dataframe with only the **selected** columns

In [27]:
df.select('VendorID', 'fare_amount', 'extra', 'tip_amount', 'total_amount').show()

+--------+-----------+-----+----------+------------+
|VendorID|fare_amount|extra|tip_amount|total_amount|
+--------+-----------+-----+----------+------------+
|       1|        7.9|  3.5|       2.6|        15.5|
|       1|        6.5|  3.5|       2.3|        13.8|
|       2|       14.9|  1.0|      5.16|       25.81|
|       2|        7.2|  1.0|      2.59|       15.54|
|       1|        8.6| 4.25|      2.85|        17.2|
|       1|       16.3|  1.0|       2.0|        20.8|
|       2|       17.0|  1.0|      4.55|        27.3|
|       2|        8.6|  1.0|       2.0|       16.35|
|       2|       13.5|  1.0|      3.85|        23.1|
|       2|       12.1|  1.0|      3.57|       21.42|
|       2|        8.6|  1.0|       0.0|       14.35|
|       2|        6.5|  1.0|       2.0|       14.25|
|       1|       70.0|  5.0|       0.0|       83.44|
|       2|      101.7|  1.0|       0.0|      112.89|
|       2|       17.7|  1.0|      4.69|       28.14|
|       2|        7.9|  1.0|      2.05|       

We can select columns with pandas-style indexing too

In [28]:
df['VendorID', 'total_amount'].show()

+--------+------------+
|VendorID|total_amount|
+--------+------------+
|       1|        15.5|
|       1|        13.8|
|       2|       25.81|
|       2|       15.54|
|       1|        17.2|
|       1|        20.8|
|       2|        27.3|
|       2|       16.35|
|       2|        23.1|
|       2|       21.42|
|       2|       14.35|
|       2|       14.25|
|       1|       83.44|
|       2|      112.89|
|       2|       28.14|
|       2|        15.7|
|       2|       18.15|
|       1|       16.35|
|       1|       24.75|
|       2|       12.95|
+--------+------------+
only showing top 20 rows



Or with attribute-style access

In [29]:
df.select(df.VendorID, df.total_amount).show()

+--------+------------+
|VendorID|total_amount|
+--------+------------+
|       1|        15.5|
|       1|        13.8|
|       2|       25.81|
|       2|       15.54|
|       1|        17.2|
|       1|        20.8|
|       2|        27.3|
|       2|       16.35|
|       2|        23.1|
|       2|       21.42|
|       2|       14.35|
|       2|       14.25|
|       1|       83.44|
|       2|      112.89|
|       2|       28.14|
|       2|        15.7|
|       2|       18.15|
|       1|       16.35|
|       1|       24.75|
|       2|       12.95|
+--------+------------+
only showing top 20 rows



DataFrame.limit(n) returns a new DataFrame with at most n rows

In [61]:
df.select(df.VendorID, df.total_amount).limit(10).show()

+--------+------------+
|VendorID|total_amount|
+--------+------------+
|       1|        15.5|
|       1|        13.8|
|       2|       25.81|
|       2|       15.54|
|       1|        17.2|
|       1|        20.8|
|       2|        27.3|
|       2|       16.35|
|       2|        23.1|
|       2|       21.42|
+--------+------------+



DataFrame.collect() returns all the rows of the DataFrame as a list of Row objects on the driver, so limit or aggregate large DataFrames first

In [64]:
df.select(df.VendorID, df.total_amount).limit(10).collect()

[Row(VendorID=1, total_amount=15.5),
 Row(VendorID=1, total_amount=13.8),
 Row(VendorID=2, total_amount=25.81),
 Row(VendorID=2, total_amount=15.54),
 Row(VendorID=1, total_amount=17.2),
 Row(VendorID=1, total_amount=20.8),
 Row(VendorID=2, total_amount=27.3),
 Row(VendorID=2, total_amount=16.35),
 Row(VendorID=2, total_amount=23.1),
 Row(VendorID=2, total_amount=21.42)]

We can refer to a column by name with the col() function

In [32]:
df.select(col(col='VendorID'), col('total_amount')).show()

+--------+------------+
|VendorID|total_amount|
+--------+------------+
|       1|        15.5|
|       1|        13.8|
|       2|       25.81|
|       2|       15.54|
|       1|        17.2|
|       1|        20.8|
|       2|        27.3|
|       2|       16.35|
|       2|        23.1|
|       2|       21.42|
|       2|       14.35|
|       2|       14.25|
|       1|       83.44|
|       2|      112.89|
|       2|       28.14|
|       2|        15.7|
|       2|       18.15|
|       1|       16.35|
|       1|       24.75|
|       2|       12.95|
+--------+------------+
only showing top 20 rows



DataFrame aliases (another name for the DataFrame, like a table alias in SQL) are useful for referring to a column of a specific DataFrame in joins

In [33]:
df_as = df.alias('nyc')

In [34]:
df_as.select(col('nyc.VendorID'), col('nyc.total_amount')).show()

+--------+------------+
|VendorID|total_amount|
+--------+------------+
|       1|        15.5|
|       1|        13.8|
|       2|       25.81|
|       2|       15.54|
|       1|        17.2|
|       1|        20.8|
|       2|        27.3|
|       2|       16.35|
|       2|        23.1|
|       2|       21.42|
|       2|       14.35|
|       2|       14.25|
|       1|       83.44|
|       2|      112.89|
|       2|       28.14|
|       2|        15.7|
|       2|       18.15|
|       1|       16.35|
|       1|       24.75|
|       2|       12.95|
+--------+------------+
only showing top 20 rows



Column aliases rename a column in the result

In [35]:
df.select('VendorID', col('total_amount').alias('TOTAL_ALIAS')).show()

+--------+-----------+
|VendorID|TOTAL_ALIAS|
+--------+-----------+
|       1|       15.5|
|       1|       13.8|
|       2|      25.81|
|       2|      15.54|
|       1|       17.2|
|       1|       20.8|
|       2|       27.3|
|       2|      16.35|
|       2|       23.1|
|       2|      21.42|
|       2|      14.35|
|       2|      14.25|
|       1|      83.44|
|       2|     112.89|
|       2|      28.14|
|       2|       15.7|
|       2|      18.15|
|       1|      16.35|
|       1|      24.75|
|       2|      12.95|
+--------+-----------+
only showing top 20 rows



Sorting data: we can sort with the .sort() method on any number of columns, each with its own direction

In [36]:
df.select('VendorID', col('total_amount').alias('TOTAL_ALIAS')) \
    .sort(['VendorID', 'TOTAL_ALIAS'], ascending=[True, False]).show()

+--------+-----------+
|VendorID|TOTAL_ALIAS|
+--------+-----------+
|       1|   46269.44|
|       1|     1009.0|
|       1|    1003.62|
|       1|     832.74|
|       1|     602.75|
|       1|     551.75|
|       1|      547.2|
|       1|     521.69|
|       1|     516.69|
|       1|     514.44|
|       1|      512.0|
|       1|     502.99|
|       1|      501.0|
|       1|      475.5|
|       1|     472.45|
|       1|     467.29|
|       1|     451.98|
|       1|      451.0|
|       1|      451.0|
|       1|     440.59|
+--------+-----------+
only showing top 20 rows



## Filter Operations

We can use the DataFrame.filter() or DataFrame.where() methods to filter the data; .where() is an alias for .filter(). These methods take a column expression built with filter operators (or an SQL expression string) that represents a `Filter Condition`.

Some Filter Operators:

| Operation        | Example (Expression)                    | Equivalent SQL |
| ---------------- | --------------------------------------- | ------------------- |
| Equal to         | `col("age") == 30`                      | `age = 30`          |
| Not equal        | `col("age") != 30`                      | `age <> 30`         |
| Greater than     | `col("age") > 18`                       | `age > 18`          |
| Less than        | `col("age") < 65`                       | `age < 65`          |
| Greater or equal | `col("age") >= 21`                      | `age >= 21`         |
| Less or equal    | `col("age") <= 65`                      | `age <= 65`         |
| And              | `(col("age") > 18) & (col("age") < 65)` | `age > 18 AND age < 65` |
| Or               | `(col("age") > 18) \| (col("age") < 65)`| `age > 18 OR age < 65`  |
| Not              | `~(col("active") == True)`              | `NOT (active = true)`   |


In [37]:
df.select('VendorID', 'total_amount').where('total_amount > 100').show()

+--------+------------+
|VendorID|total_amount|
+--------+------------+
|       2|      112.89|
|       2|      153.85|
|       2|      103.86|
|       2|      100.13|
|       1|      121.75|
|       2|      106.97|
|       1|      100.39|
|       1|      101.89|
|       2|      100.13|
|       2|      148.26|
|       2|      238.89|
|       2|      104.21|
|       1|      100.09|
|       2|      103.86|
|       2|      112.75|
|       2|      145.83|
|       2|      101.15|
|       2|       124.2|
|       2|      160.62|
|       2|      103.86|
+--------+------------+
only showing top 20 rows



In [38]:
df.select('VendorID', 'total_amount').filter(df.VendorID == 1).show()

+--------+------------+
|VendorID|total_amount|
+--------+------------+
|       1|        15.5|
|       1|        13.8|
|       1|        17.2|
|       1|        20.8|
|       1|       83.44|
|       1|       16.35|
|       1|       24.75|
|       1|        55.0|
|       1|       18.45|
|       1|       19.15|
|       1|       18.85|
|       1|       14.35|
|       1|        19.7|
|       1|        15.5|
|       1|        15.5|
|       1|        25.6|
|       1|        17.9|
|       1|       17.95|
|       1|       22.25|
|       1|       20.55|
+--------+------------+
only showing top 20 rows



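Multiple operators can be combined. Wrap each comparison in parentheses, because `&` and `|` bind more tightly than comparison operators in Python.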

In [39]:
df.select('VendorID', 'payment_type', 'total_amount').filter((df.VendorID == 2) & (df.total_amount > 1000)).show()

+--------+------------+------------+
|VendorID|payment_type|total_amount|
+--------+------------+------------+
|       2|           2|     1694.31|
|       2|           1|     1015.55|
|       2|           2|     1059.21|
|       2|           1|      1171.3|
|       2|           1|      1081.2|
|       2|           1|      1189.2|
|       2|           1|     1175.52|
|       2|           1|     1240.52|
|       2|           2|     1183.39|
|       2|           1|     1004.25|
|       2|           1|     1004.25|
|       2|           1|     1226.25|
|       2|           1|      1249.3|
|       2|           1|     2246.55|
+--------+------------+------------+



In [40]:
df.select('VendorID', 'payment_type', 'total_amount') \
    .filter((df.VendorID == 2) & (df.total_amount > 1000) & (df.payment_type == 1)).show()

+--------+------------+------------+
|VendorID|payment_type|total_amount|
+--------+------------+------------+
|       2|           1|     1015.55|
|       2|           1|      1171.3|
|       2|           1|      1081.2|
|       2|           1|      1189.2|
|       2|           1|     1175.52|
|       2|           1|     1240.52|
|       2|           1|     1004.25|
|       2|           1|     1004.25|
|       2|           1|     1226.25|
|       2|           1|      1249.3|
|       2|           1|     2246.55|
+--------+------------+------------+



In [41]:
df.select('VendorID', 'total_amount') \
    .filter(((df.VendorID == 2) & (df.total_amount > 1000)) | (df.total_amount > 2000)).show()

+--------+------------+
|VendorID|total_amount|
+--------+------------+
|       2|     1694.31|
|       2|     1015.55|
|       2|     1059.21|
|       1|    46269.44|
|       2|      1171.3|
|       2|      1081.2|
|       2|      1189.2|
|       2|     1175.52|
|       2|     1240.52|
|       2|     1183.39|
|       2|     1004.25|
|       2|     1004.25|
|       2|     1226.25|
|       2|      1249.3|
|       2|     2246.55|
+--------+------------+



The between operator (both bounds are inclusive)

In [42]:
df.select(df.VendorID, df.total_amount).where(df.total_amount.between(15, 16)).show()

+--------+------------+
|VendorID|total_amount|
+--------+------------+
|       1|        15.5|
|       2|       15.54|
|       2|        15.7|
|       2|       15.29|
|       2|       15.93|
|       1|        15.5|
|       1|        15.5|
|       2|       15.54|
|       1|        15.0|
|       2|       15.48|
|       2|       15.75|
|       2|       15.54|
|       2|       15.75|
|       1|        15.5|
|       2|        15.8|
|       2|       15.65|
|       1|       15.75|
|       2|       15.01|
|       1|        15.5|
|       2|       15.37|
+--------+------------+
only showing top 20 rows



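Filtering null values can be done with the .isNull() and .isNotNull() column methods. NaN values are a different case and can be checked with the isnan() function from pyspark.sql.functions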

In [43]:
df['VendorID', 'Airport_fee', 'total_amount'].where(col('Airport_fee').isNull()).show()

+--------+-----------+------------+
|VendorID|Airport_fee|total_amount|
+--------+-----------+------------+
|       2|       NULL|        4.76|
|       2|       NULL|       18.95|
|       2|       NULL|         2.0|
|       2|       NULL|        43.4|
|       2|       NULL|       39.23|
|       1|       NULL|       18.79|
|       2|       NULL|        5.29|
|       2|       NULL|       20.77|
|       2|       NULL|       50.34|
|       1|       NULL|       48.98|
|       2|       NULL|       19.38|
|       2|       NULL|        2.97|
|       2|       NULL|       15.63|
|       2|       NULL|         1.0|
|       2|       NULL|       37.81|
|       2|       NULL|       38.14|
|       1|       NULL|        30.3|
|       2|       NULL|       15.71|
|       2|       NULL|        6.31|
|       2|       NULL|       35.54|
+--------+-----------+------------+
only showing top 20 rows



Filter text in a column with .contains()

In [44]:
df['VendorID', 'Airport_fee', 'total_amount', 'tpep_pickup_datetime'] \
    .where(col('tpep_pickup_datetime').contains('2025-03-05')).show()

+--------+-----------+------------+--------------------+
|VendorID|Airport_fee|total_amount|tpep_pickup_datetime|
+--------+-----------+------------+--------------------+
|       1|       1.75|       57.25| 2025-03-05 00:17:44|
|       1|        0.0|       16.35| 2025-03-05 00:02:33|
|       2|        0.0|       14.44| 2025-03-05 00:17:56|
|       2|        0.0|       19.25| 2025-03-05 00:39:48|
|       2|        0.0|       17.16| 2025-03-05 00:20:42|
|       2|        0.0|       16.32| 2025-03-05 00:00:03|
|       2|        0.0|        30.6| 2025-03-05 00:13:04|
|       2|        0.0|        -8.0| 2025-03-05 00:45:09|
|       2|        0.0|         8.0| 2025-03-05 00:45:09|
|       2|        0.0|       23.94| 2025-03-05 00:22:02|
|       2|        0.0|       17.22| 2025-03-05 00:29:21|
|       2|       1.75|       71.74| 2025-03-05 00:09:59|
|       2|       1.75|      147.72| 2025-03-05 00:26:52|
|       1|        0.0|       26.45| 2025-03-05 00:32:22|
|       2|        0.0|        1

SQL LIKE pattern filter

In [45]:
df['VendorID', 'Airport_fee', 'total_amount', 'tpep_pickup_datetime'] \
    .where(col('tpep_pickup_datetime').like('%51%')).show()

+--------+-----------+------------+--------------------+
|VendorID|Airport_fee|total_amount|tpep_pickup_datetime|
+--------+-----------+------------+--------------------+
|       2|        0.0|       18.15| 2025-03-01 00:41:51|
|       2|        0.0|       38.99| 2025-03-01 00:51:48|
|       2|        0.0|       18.06| 2025-03-01 00:51:23|
|       2|        0.0|       48.56| 2025-03-01 00:44:51|
|       1|        0.0|       26.45| 2025-03-01 00:16:51|
|       1|        0.0|       18.55| 2025-03-01 00:56:51|
|       1|        0.0|        14.0| 2025-03-01 00:51:04|
|       2|        0.0|       21.42| 2025-03-01 00:51:30|
|       2|        0.0|       16.38| 2025-03-01 00:51:14|
|       2|        0.0|       51.09| 2025-03-01 00:02:51|
|       2|        0.0|       14.95| 2025-03-01 00:22:51|
|       2|        0.0|       25.62| 2025-03-01 00:03:51|
|       2|        0.0|       15.65| 2025-03-01 00:51:35|
|       2|        0.0|       26.46| 2025-03-01 00:51:47|
|       2|        0.0|       21

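Case-insensitive SQL LIKE pattern filter: .like('n') returns no rows because the stored flag is an uppercase N, while .ilike('n') ignores case and matches it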

In [46]:
df.select(df.VendorID, df.store_and_fwd_flag).where(col('store_and_fwd_flag').like('n')).show()
df.select(df.VendorID, df.store_and_fwd_flag).where(col('store_and_fwd_flag').ilike('n')).show()

+--------+------------------+
|VendorID|store_and_fwd_flag|
+--------+------------------+
+--------+------------------+

+--------+------------------+
|VendorID|store_and_fwd_flag|
+--------+------------------+
|       1|                 N|
|       1|                 N|
|       2|                 N|
|       2|                 N|
|       1|                 N|
|       1|                 N|
|       2|                 N|
|       2|                 N|
|       2|                 N|
|       2|                 N|
|       2|                 N|
|       2|                 N|
|       1|                 N|
|       2|                 N|
|       2|                 N|
|       2|                 N|
|       2|                 N|
|       1|                 N|
|       1|                 N|
|       2|                 N|
+--------+------------------+
only showing top 20 rows



SQL RLIKE filter: works like LIKE, but the pattern is a regular expression

In [47]:
df['VendorID', 'Airport_fee', 'total_amount', 'tpep_pickup_datetime'] \
    .where(col('tpep_pickup_datetime').rlike('.51$')).show()

+--------+-----------+------------+--------------------+
|VendorID|Airport_fee|total_amount|tpep_pickup_datetime|
+--------+-----------+------------+--------------------+
|       2|        0.0|       18.15| 2025-03-01 00:41:51|
|       2|        0.0|       48.56| 2025-03-01 00:44:51|
|       1|        0.0|       26.45| 2025-03-01 00:16:51|
|       1|        0.0|       18.55| 2025-03-01 00:56:51|
|       2|        0.0|       51.09| 2025-03-01 00:02:51|
|       2|        0.0|       14.95| 2025-03-01 00:22:51|
|       2|        0.0|       25.62| 2025-03-01 00:03:51|
|       2|        0.0|       21.75| 2025-03-01 00:50:51|
|       2|        0.0|        35.7| 2025-03-01 00:27:51|
|       2|        0.0|       16.23| 2025-03-01 00:20:51|
|       2|        0.0|       34.02| 2025-03-01 00:11:51|
|       1|        0.0|        15.3| 2025-03-01 00:51:51|
|       2|        0.0|       15.05| 2025-03-01 00:17:51|
|       2|        0.0|        18.9| 2025-03-01 00:35:51|
|       1|        0.0|       17

Startswith: keep the strings that start with the specified prefix, without using LIKE or REGEX

In [48]:
df['VendorID', 'Airport_fee', 'total_amount', 'tpep_pickup_datetime'] \
    .where(col('tpep_pickup_datetime').startswith('2025-03-30 23')).show()

+--------+-----------+------------+--------------------+
|VendorID|Airport_fee|total_amount|tpep_pickup_datetime|
+--------+-----------+------------+--------------------+
|       2|        0.0|        27.3| 2025-03-30 23:59:06|
|       2|        0.0|       20.58| 2025-03-30 23:00:05|
|       1|       1.75|        71.0| 2025-03-30 23:18:41|
|       2|       1.75|       53.05| 2025-03-30 23:55:58|
|       1|       1.75|        91.8| 2025-03-30 23:05:01|
|       1|        0.0|       12.25| 2025-03-30 23:48:17|
|       2|        0.0|       26.46| 2025-03-30 23:00:27|
|       2|        0.0|        14.7| 2025-03-30 23:48:30|
|       1|        0.0|        19.7| 2025-03-30 23:33:25|
|       1|        0.0|       14.35| 2025-03-30 23:07:44|
|       1|        0.0|       13.65| 2025-03-30 23:17:18|
|       2|        0.0|       24.15| 2025-03-30 23:22:33|
|       1|        0.0|        11.5| 2025-03-30 23:04:25|
|       2|        0.0|       29.58| 2025-03-30 23:01:48|
|       1|        0.0|       17

Filter unique rows, just like `select distinct`

In [49]:
df.select('VendorID').distinct().show()

+--------+
|VendorID|
+--------+
|       1|
|       7|
|       2|
|       6|
+--------+



## Handling columns

Create columns

Column aliases refresher

In [50]:
df.select('VendorID', col('total_amount').alias('TOTAL_ALIAS')).show()

+--------+-----------+
|VendorID|TOTAL_ALIAS|
+--------+-----------+
|       1|       15.5|
|       1|       13.8|
|       2|      25.81|
|       2|      15.54|
|       1|       17.2|
|       1|       20.8|
|       2|       27.3|
|       2|      16.35|
|       2|       23.1|
|       2|      21.42|
|       2|      14.35|
|       2|      14.25|
|       1|      83.44|
|       2|     112.89|
|       2|      28.14|
|       2|       15.7|
|       2|      18.15|
|       1|      16.35|
|       1|      24.75|
|       2|      12.95|
+--------+-----------+
only showing top 20 rows



Drop columns: the DataFrame.drop() method returns a new DataFrame without the specified columns

In [51]:
df.select('VendorID', 'total_amount').drop('total_amount').show()

+--------+
|VendorID|
+--------+
|       1|
|       1|
|       2|
|       2|
|       1|
|       1|
|       2|
|       2|
|       2|
|       2|
|       2|
|       2|
|       1|
|       2|
|       2|
|       2|
|       2|
|       1|
|       1|
|       2|
+--------+
only showing top 20 rows



## Aggregations

Group By functions: COUNT, SUM, MIN, MAX, AVG, etc

DataFrame.groupBy() returns a GroupedData object that is used to create aggregations

In [52]:
print(df.groupBy('VendorID'))

GroupedData[grouping expressions: [VendorID], value: [VendorID: int, tpep_pickup_datetime: timestamp_ntz ... 18 more fields], type: GroupBy]


Count aggregation: counts the rows in each group

In [53]:
df.groupBy('VendorID').count().show()

+--------+-------+
|VendorID|  count|
+--------+-------+
|       1| 834394|
|       7|  21481|
|       2|3289048|
|       6|    334|
+--------+-------+



Mean

In [54]:
df.groupBy('VendorID').mean().show()

+--------+-------------+--------------------+------------------+------------------+------------------+------------------+------------------+------------------+------------------+-------------------+------------------+-------------------+--------------------------+------------------+-------------------------+-------------------+-----------------------+
|VendorID|avg(VendorID)|avg(passenger_count)|avg(trip_distance)|   avg(RatecodeID)| avg(PULocationID)| avg(DOLocationID)| avg(payment_type)|  avg(fare_amount)|        avg(extra)|       avg(mta_tax)|   avg(tip_amount)|  avg(tolls_amount)|avg(improvement_surcharge)| avg(total_amount)|avg(congestion_surcharge)|   avg(Airport_fee)|avg(cbd_congestion_fee)|
+--------+-------------+--------------------+------------------+------------------+------------------+------------------+------------------+------------------+------------------+-------------------+------------------+-------------------+--------------------------+------------------+---------

                                                                                

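The last output showed the average of every numeric column because .mean() was called without arguments. Remember to select the columns first, or pass the column names to .mean()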

In [55]:
df.select('VendorID', 'total_amount').groupBy('VendorID').mean('total_amount').show()

+--------+------------------+
|VendorID| avg(total_amount)|
+--------+------------------+
|       1| 26.81411121125252|
|       7|25.361795540244803|
|       2| 26.13239850861697|
|       6|29.504341317365295|
+--------+------------------+



Grouping without .select()

In [56]:
df.groupBy('VendorID').mean('total_amount').show()

+--------+------------------+
|VendorID| avg(total_amount)|
+--------+------------------+
|       1| 26.81411121125252|
|       7|25.361795540244803|
|       2| 26.13239850861697|
|       6|29.504341317365295|
+--------+------------------+



We can compute multiple aggregations across the entire DataFrame with the .agg() method, which is shorthand for DataFrame.groupBy().agg()

In [57]:
df.agg({'total_amount': 'avg', 'trip_distance': 'min', 'extra': 'max',
        'VendorID': 'count', 'tolls_amount': 'sum'}).show()

+------------------+-----------------+---------------+----------+------------------+
|min(trip_distance)|sum(tolls_amount)|count(VendorID)|max(extra)| avg(total_amount)|
+------------------+-----------------+---------------+----------+------------------+
|               0.0|1968899.629999516|        4145257|      13.5|26.265898046844185|
+------------------+-----------------+---------------+----------+------------------+



## Join and Union operations

Spark has **join types**, just like SQL: the ways we can **join** (match rows between) DataFrames (tables). Some examples: inner, left, right, full outer, etc.

Spark also has **join strategies**: the algorithms it uses to shuffle and match data across the nodes and executors of the cluster. Examples: broadcast hash join, sort-merge join, shuffle hash join. We can read about these in the [performance tuning documentation](https://spark.apache.org/docs/latest/sql-performance-tuning.html)

We can join DataFrames using `DataFrame.join(other, on=None, how=None)`. `other` is the right-side DataFrame, and `how` defaults to `inner`

In [58]:
# Prepare dataframes
df_a = df.select(df.VendorID, df.trip_distance).limit(1)
df_b = df.select(df.VendorID, df.total_amount).limit(1)
df_join = df_a.join(df_b, 'VendorID', 'inner')
df_join.show()

+--------+-------------+------------+
|VendorID|trip_distance|total_amount|
+--------+-------------+------------+
|       1|          0.9|        15.5|
+--------+-------------+------------+



Union concatenates DataFrames with the same schema, matching columns by position, like SQL UNION. The caveat is that rows are not deduplicated automatically, which makes Spark's union equivalent to SQL UNION ALL. The SQL UNION behavior can be achieved with .union().distinct()

In this example both DataFrames have an integer column followed by a double column, which makes the union possible. The output columns are VendorID and trip_distance because the column names are taken from the DataFrame on which .union() is called (df_a)

In [71]:
df_union = df_a.union(df_b)
df_union.show()

+--------+-------------+
|VendorID|trip_distance|
+--------+-------------+
|       1|          0.9|
|       1|         15.5|
+--------+-------------+



.unionAll() is an alias for .union() in PySpark

In [75]:
print(f'Normal spark Union (SQL Union All): {df_a.unionAll(df_union).count()}')
print(f'Deduplicated spark Union (Union): {df_a.unionAll(df_union).distinct().count()}')

Normal spark Union (SQL Union All): 3
Deduplicated spark Union (Union): 2


Shut down the current SparkSession

In [60]:
# spark.stop()