 Referring to Columns in PySpark
In PySpark, referencing columns is essential for filtering, selecting, transforming, and performing other DataFrame operations. Unlike SQL, PySpark provides several options for referring to columns, each suited to different tasks. Let's explore these approaches with examples across common operations, such as filtering, selecting, and applying transformations.

In [2]:
from pyspark.sql import SparkSession

# Initiate the SparkSession - you're basically summoning Spark's power!
spark = SparkSession.builder \
    .appName("PySpark 101") \
    .getOrCreate()

print("Spark's in the house! 🔥")


Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/11/07 18:30:10 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
25/11/07 18:30:11 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.
25/11/07 18:30:11 WARN Utils: Service 'SparkUI' could not bind on port 4041. Attempting port 4042.
25/11/07 18:30:11 WARN Utils: Service 'SparkUI' could not bind on port 4042. Attempting port 4043.


Spark's in the house! 🔥


In [3]:
# Read the Boston housing data
df = spark.read.csv("../data/boston.csv", header=True, inferSchema=True)

# Show the first 5 rows
df.show(5)



+-------+----+-----+----+-----+-----+----+------+---+---+-------+------+-----+-----+
|   CRIM|  ZN|INDUS|CHAS|  NOX|   RM| AGE|   DIS|RAD|TAX|PTRATIO|     B|LSTAT|Price|
+-------+----+-----+----+-----+-----+----+------+---+---+-------+------+-----+-----+
|0.00632|18.0| 2.31|   0|0.538|6.575|65.2|  4.09|  1|296|   15.3| 396.9| 4.98| 24.0|
|0.02731| 0.0| 7.07|   0|0.469|6.421|78.9|4.9671|  2|242|   17.8| 396.9| 9.14| 21.6|
|0.02729| 0.0| 7.07|   0|0.469|7.185|61.1|4.9671|  2|242|   17.8|392.83| 4.03| 34.7|
|0.03237| 0.0| 2.18|   0|0.458|6.998|45.8|6.0622|  3|222|   18.7|394.63| 2.94| 33.4|
|0.06905| 0.0| 2.18|   0|0.458|7.147|54.2|6.0622|  3|222|   18.7| 396.9| 5.33| 36.2|
+-------+----+-----+----+-----+-----+----+------+---+---+-------+------+-----+-----+
only showing top 5 rows



In [10]:
# Selecting columns by name
df.select("CRIM", "RM").show(2)

# Filtering based on a column condition
df.filter("AGE > 30").show(2)

+-------+-----+
|   CRIM|   RM|
+-------+-----+
|0.00632|6.575|
|0.02731|6.421|
+-------+-----+
only showing top 2 rows

+-------+----+-----+----+-----+-----+----+------+---+---+-------+-----+-----+-----+
|   CRIM|  ZN|INDUS|CHAS|  NOX|   RM| AGE|   DIS|RAD|TAX|PTRATIO|    B|LSTAT|Price|
+-------+----+-----+----+-----+-----+----+------+---+---+-------+-----+-----+-----+
|0.00632|18.0| 2.31|   0|0.538|6.575|65.2|  4.09|  1|296|   15.3|396.9| 4.98| 24.0|
|0.02731| 0.0| 7.07|   0|0.469|6.421|78.9|4.9671|  2|242|   17.8|396.9| 9.14| 21.6|
+-------+----+-----+----+-----+-----+----+------+---+---+-------+-----+-----+-----+
only showing top 2 rows



In [11]:
# Selecting columns using dot notation
df.select(df.CRIM, df.RM).show(2)

# Filtering rows based on column conditions
df.filter(df.AGE > 30).show(2)

+-------+-----+
|   CRIM|   RM|
+-------+-----+
|0.00632|6.575|
|0.02731|6.421|
+-------+-----+
only showing top 2 rows

+-------+----+-----+----+-----+-----+----+------+---+---+-------+-----+-----+-----+
|   CRIM|  ZN|INDUS|CHAS|  NOX|   RM| AGE|   DIS|RAD|TAX|PTRATIO|    B|LSTAT|Price|
+-------+----+-----+----+-----+-----+----+------+---+---+-------+-----+-----+-----+
|0.00632|18.0| 2.31|   0|0.538|6.575|65.2|  4.09|  1|296|   15.3|396.9| 4.98| 24.0|
|0.02731| 0.0| 7.07|   0|0.469|6.421|78.9|4.9671|  2|242|   17.8|396.9| 9.14| 21.6|
+-------+----+-----+----+-----+-----+----+------+---+---+-------+-----+-----+-----+
only showing top 2 rows



In [12]:
from pyspark.sql.functions import col

# Selecting columns using col()
df.select(col("CRIM"), col("AGE")).show(2)

# Filtering rows using col() for flexibility
age_column = "AGE"
df.filter(col(age_column) > 30).show(2)

+-------+----+
|   CRIM| AGE|
+-------+----+
|0.00632|65.2|
|0.02731|78.9|
+-------+----+
only showing top 2 rows

+-------+----+-----+----+-----+-----+----+------+---+---+-------+-----+-----+-----+
|   CRIM|  ZN|INDUS|CHAS|  NOX|   RM| AGE|   DIS|RAD|TAX|PTRATIO|    B|LSTAT|Price|
+-------+----+-----+----+-----+-----+----+------+---+---+-------+-----+-----+-----+
|0.00632|18.0| 2.31|   0|0.538|6.575|65.2|  4.09|  1|296|   15.3|396.9| 4.98| 24.0|
|0.02731| 0.0| 7.07|   0|0.469|6.421|78.9|4.9671|  2|242|   17.8|396.9| 9.14| 21.6|
+-------+----+-----+----+-----+-----+----+------+---+---+-------+-----+-----+-----+
only showing top 2 rows



In [15]:
# Selecting columns
df.select(df["CRIM"], df["AGE"]).show(2)

# Filtering with expressions
df.filter(df["AGE"] > 30).show(2)

# Applying transformations
df.select((df["AGE"] + 10).alias("age_plus_10")).show(2)


+-------+----+
|   CRIM| AGE|
+-------+----+
|0.00632|65.2|
|0.02731|78.9|
+-------+----+
only showing top 2 rows

+-------+----+-----+----+-----+-----+----+------+---+---+-------+-----+-----+-----+
|   CRIM|  ZN|INDUS|CHAS|  NOX|   RM| AGE|   DIS|RAD|TAX|PTRATIO|    B|LSTAT|Price|
+-------+----+-----+----+-----+-----+----+------+---+---+-------+-----+-----+-----+
|0.00632|18.0| 2.31|   0|0.538|6.575|65.2|  4.09|  1|296|   15.3|396.9| 4.98| 24.0|
|0.02731| 0.0| 7.07|   0|0.469|6.421|78.9|4.9671|  2|242|   17.8|396.9| 9.14| 21.6|
+-------+----+-----+----+-----+-----+----+------+---+---+-------+-----+-----+-----+
only showing top 2 rows

+-----------+
|age_plus_10|
+-----------+
|       75.2|
|       88.9|
+-----------+
only showing top 2 rows



In [17]:
from pyspark.sql.functions import lit

# Adding a constant column
df.select(df.CRIM, lit(25).alias("constant_age")).show(2)

# Filtering based on a constant
df.filter(df.AGE > lit(30)).show(2)

+-------+------------+
|   CRIM|constant_age|
+-------+------------+
|0.00632|          25|
|0.02731|          25|
+-------+------------+
only showing top 2 rows

+-------+----+-----+----+-----+-----+----+------+---+---+-------+-----+-----+-----+
|   CRIM|  ZN|INDUS|CHAS|  NOX|   RM| AGE|   DIS|RAD|TAX|PTRATIO|    B|LSTAT|Price|
+-------+----+-----+----+-----+-----+----+------+---+---+-------+-----+-----+-----+
|0.00632|18.0| 2.31|   0|0.538|6.575|65.2|  4.09|  1|296|   15.3|396.9| 4.98| 24.0|
|0.02731| 0.0| 7.07|   0|0.469|6.421|78.9|4.9671|  2|242|   17.8|396.9| 9.14| 21.6|
+-------+----+-----+----+-----+-----+----+------+---+---+-------+-----+-----+-----+
only showing top 2 rows



In [19]:
# Selecting a single column
selected_df = df.select("CRIM")

# Selecting multiple columns
selected_df = df.select("CRIM", "AGE", "NOX")
selected_df.show(2)


+-------+----+-----+
|   CRIM| AGE|  NOX|
+-------+----+-----+
|0.00632|65.2|0.538|
|0.02731|78.9|0.469|
+-------+----+-----+
only showing top 2 rows



In [21]:
from pyspark.sql.functions import col

# Selecting columns using col()
selected_df = df.select(col("CRIM"), col("AGE"), col("NOX"))
selected_df.show(2)

+-------+----+-----+
|   CRIM| AGE|  NOX|
+-------+----+-----+
|0.00632|65.2|0.538|
|0.02731|78.9|0.469|
+-------+----+-----+
only showing top 2 rows



In [22]:
from pyspark.sql.functions import col, lit

# Replacing an existing column with a static value
# (withColumn overwrites a column that already has this name)
df_with_constant = df.withColumn("CRIM", lit("active"))

# Adding a new column based on an existing one
df_with_computed = df.withColumn("age_plus_5", col("AGE") + 5)
df_with_constant.show(2)
df_with_computed.show(2)

+------+----+-----+----+-----+-----+----+------+---+---+-------+-----+-----+-----+
|  CRIM|  ZN|INDUS|CHAS|  NOX|   RM| AGE|   DIS|RAD|TAX|PTRATIO|    B|LSTAT|Price|
+------+----+-----+----+-----+-----+----+------+---+---+-------+-----+-----+-----+
|active|18.0| 2.31|   0|0.538|6.575|65.2|  4.09|  1|296|   15.3|396.9| 4.98| 24.0|
|active| 0.0| 7.07|   0|0.469|6.421|78.9|4.9671|  2|242|   17.8|396.9| 9.14| 21.6|
+------+----+-----+----+-----+-----+----+------+---+---+-------+-----+-----+-----+
only showing top 2 rows

+-------+----+-----+----+-----+-----+----+------+---+---+-------+-----+-----+-----+----------+
|   CRIM|  ZN|INDUS|CHAS|  NOX|   RM| AGE|   DIS|RAD|TAX|PTRATIO|    B|LSTAT|Price|age_plus_5|
+-------+----+-----+----+-----+-----+----+------+---+---+-------+-----+-----+-----+----------+
|0.00632|18.0| 2.31|   0|0.538|6.575|65.2|  4.09|  1|296|   15.3|396.9| 4.98| 24.0|      70.2|
|0.02731| 0.0| 7.07|   0|0.469|6.421|78.9|4.9671|  2|242|   17.8|396.9| 9.14| 21.6|      83.9|
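+-------+----+-----+----+-----+-----+----+------+---+---+-------+-----+-----+-----+----------+
only showing top 2 rows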

In [24]:
# Renaming columns
selected_df = df.select(col("CRIM").alias("location_crim"), col("AGE").alias("age_of_property"), col("NOX").alias("nitric_oxide_level"))
selected_df.show(2)

+-------------+---------------+------------------+
|location_crim|age_of_property|nitric_oxide_level|
+-------------+---------------+------------------+
|      0.00632|           65.2|             0.538|
|      0.02731|           78.9|             0.469|
+-------------+---------------+------------------+
only showing top 2 rows



In [26]:
from pyspark.sql.functions import expr

# Selecting with expressions
selected_df = df.select(expr("AGE + 1 AS next_year_age"), "Price")
selected_df.show(2)

+-------------+-----+
|next_year_age|Price|
+-------------+-----+
|         66.2| 24.0|
|         79.9| 21.6|
+-------------+-----+
only showing top 2 rows



In [27]:
# Selecting all columns except 'CRIM'
selected_df = df.select("*").drop("CRIM")
selected_df.show(2)

+----+-----+----+-----+-----+----+------+---+---+-------+-----+-----+-----+
|  ZN|INDUS|CHAS|  NOX|   RM| AGE|   DIS|RAD|TAX|PTRATIO|    B|LSTAT|Price|
+----+-----+----+-----+-----+----+------+---+---+-------+-----+-----+-----+
|18.0| 2.31|   0|0.538|6.575|65.2|  4.09|  1|296|   15.3|396.9| 4.98| 24.0|
| 0.0| 7.07|   0|0.469|6.421|78.9|4.9671|  2|242|   17.8|396.9| 9.14| 21.6|
+----+-----+----+-----+-----+----+------+---+---+-------+-----+-----+-----+
only showing top 2 rows



In [29]:
# Registering the DataFrame as a temporary view
df.createOrReplaceTempView("Boston")

# Using SQL to select columns
selected_df = spark.sql("SELECT CRIM, AGE FROM Boston")
selected_df.show(2)

+-------+----+
|   CRIM| AGE|
+-------+----+
|0.00632|65.2|
|0.02731|78.9|
+-------+----+
only showing top 2 rows



In [31]:
# List of columns to select
columns_to_select = ["CRIM", "AGE"]

# Dynamically selecting columns
selected_df = df.select(*columns_to_select)
selected_df.show(2)

+-------+----+
|   CRIM| AGE|
+-------+----+
|0.00632|65.2|
|0.02731|78.9|
+-------+----+
only showing top 2 rows



In [32]:
# Keeping all columns except 'AGE' and 'CRIM'
selected_df = df.drop("AGE", "CRIM")
selected_df.show(2)

+----+-----+----+-----+-----+------+---+---+-------+-----+-----+-----+
|  ZN|INDUS|CHAS|  NOX|   RM|   DIS|RAD|TAX|PTRATIO|    B|LSTAT|Price|
+----+-----+----+-----+-----+------+---+---+-------+-----+-----+-----+
|18.0| 2.31|   0|0.538|6.575|  4.09|  1|296|   15.3|396.9| 4.98| 24.0|
| 0.0| 7.07|   0|0.469|6.421|4.9671|  2|242|   17.8|396.9| 9.14| 21.6|
+----+-----+----+-----+-----+------+---+---+-------+-----+-----+-----+
only showing top 2 rows



In [33]:
# Example of basic filtering
filtered_df = df.filter(df["AGE"] > 30)
# or
filtered_df = df.where("AGE > 30")

filtered_df.show(2)


+-------+----+-----+----+-----+-----+----+------+---+---+-------+-----+-----+-----+
|   CRIM|  ZN|INDUS|CHAS|  NOX|   RM| AGE|   DIS|RAD|TAX|PTRATIO|    B|LSTAT|Price|
+-------+----+-----+----+-----+-----+----+------+---+---+-------+-----+-----+-----+
|0.00632|18.0| 2.31|   0|0.538|6.575|65.2|  4.09|  1|296|   15.3|396.9| 4.98| 24.0|
|0.02731| 0.0| 7.07|   0|0.469|6.421|78.9|4.9671|  2|242|   17.8|396.9| 9.14| 21.6|
+-------+----+-----+----+-----+-----+----+------+---+---+-------+-----+-----+-----+
only showing top 2 rows



In [34]:
# AND condition (&): both comparisons must hold

filtered_df = df.filter((df["AGE"] > 30) & (df["CRIM"] > 0.05))

# OR condition (|): either comparison may hold

filtered_df = df.filter((df["AGE"] > 30) | (df["CHAS"] == 1))

# NOT condition (~): negate a comparison

filtered_df = df.filter(~(df["RAD"] == 1))

In [36]:
filtered_df = df.filter(df["RAD"].isin(1, 2))
filtered_df.show(2)


+-------+----+-----+----+-----+-----+----+------+---+---+-------+-----+-----+-----+
|   CRIM|  ZN|INDUS|CHAS|  NOX|   RM| AGE|   DIS|RAD|TAX|PTRATIO|    B|LSTAT|Price|
+-------+----+-----+----+-----+-----+----+------+---+---+-------+-----+-----+-----+
|0.00632|18.0| 2.31|   0|0.538|6.575|65.2|  4.09|  1|296|   15.3|396.9| 4.98| 24.0|
|0.02731| 0.0| 7.07|   0|0.469|6.421|78.9|4.9671|  2|242|   17.8|396.9| 9.14| 21.6|
+-------+----+-----+----+-----+-----+----+------+---+---+-------+-----+-----+-----+
only showing top 2 rows



In [37]:
# AGE is numeric; Spark implicitly casts it to a string for startswith()
filtered_df = df.filter(df["AGE"].startswith("3"))
filtered_df.show(2)

+-------+----+-----+----+-----+-----+----+------+---+---+-------+------+-----+-----+
|   CRIM|  ZN|INDUS|CHAS|  NOX|   RM| AGE|   DIS|RAD|TAX|PTRATIO|     B|LSTAT|Price|
+-------+----+-----+----+-----+-----+----+------+---+---+-------+------+-----+-----+
|0.09378|12.5| 7.87|   0|0.524|5.889|39.0|5.4509|  5|311|   15.2| 390.5|15.71| 21.7|
|0.80271| 0.0| 8.14|   0|0.538|5.456|36.6|3.7965|  4|307|   21.0|288.99|11.69| 20.2|
+-------+----+-----+----+-----+-----+----+------+---+---+-------+------+-----+-----+
only showing top 2 rows



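The following cell builds a small student DataFrame and walks through eight filter patterns: a single equality condition, AND, a SQL expression string, col(), isin(), startswith() with endswith(), contains(), and NOT.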

In [38]:
# Import required modules
from pyspark.sql import SparkSession
from pyspark.sql.functions import col

# Initialize SparkSession
spark = SparkSession.builder.appName("PySparkFilteringExamples").getOrCreate()

# Sample data: list of tuples (student_ID, student_NAME, college)
data = [
    ("1", "Amit", "DU"),
    ("2", "Mohit", "DU"),
    ("3", "Rohith", "BHU"),
    ("4", "Sridevi", "LPU"),
    ("5", "Gnanesh", "IIT"),
    ("6", "Anita", "DU"),
    ("7", "Amrit", "IIT"),
]

# Define schema column names
columns = ["student_ID", "student_NAME", "college"]

# Create DataFrame
df = spark.createDataFrame(data, columns)

# Show original DataFrame
print("Original DataFrame:")
df.show()

# 1. Filter rows with single condition: college = 'DU'
print("Filter: college == 'DU'")
df.filter(df.college == "DU").show()

# 2. Filter with multiple conditions using & (AND): college='DU' AND student_ID='1'
print("Filter: college == 'DU' AND student_ID == '1'")
df.filter((df.college == "DU") & (df.student_ID == "1")).show()

# 3. Filter using SQL expression string syntax: student_NAME starts with 'A'
print("Filter: student_NAME starts with 'A' using SQL string filter")
df.filter("student_NAME LIKE 'A%'").show()

# 4. Filter using SQL col() function with multiple conditions: college='DU' AND student_NAME='Amit'
print("Filter: college == 'DU' AND student_NAME == 'Amit' using col() function")
df.filter((col("college") == "DU") & (col("student_NAME") == "Amit")).show()

# 5. Filter using isin(): student_ID in [1, 2], or college in ['DU', 'IIT']
print("Filter: student_ID in [1,2] OR college in ['DU','IIT']")
df.filter((df.student_ID.isin(["1", "2"])) | (df.college.isin(["DU", "IIT"]))).show()

# 6. Filter with startswith and endswith: student_NAME starts with 'A' AND ends with 't'
print("Filter: student_NAME startswith 'A' AND endswith 't'")
df.filter((df.student_NAME.startswith("A")) & (df.student_NAME.endswith("t"))).show()

# 7. Filter rows containing substring: student_NAME contains 'ith'
print("Filter: student_NAME contains 'ith'")
df.filter(df.student_NAME.contains("ith")).show()

# 8. Using NOT condition: exclude students from college 'DU'
print("Filter: NOT college == 'DU'")
df.filter(~(df.college == "DU")).show()

df.show()

# Stop SparkSession
spark.stop()


25/11/07 19:53:14 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.


Original DataFrame:


                                                                                

+----------+------------+-------+
|student_ID|student_NAME|college|
+----------+------------+-------+
|         1|        Amit|     DU|
|         2|       Mohit|     DU|
|         3|      Rohith|    BHU|
|         4|     Sridevi|    LPU|
|         5|     Gnanesh|    IIT|
|         6|       Anita|     DU|
|         7|       Amrit|    IIT|
+----------+------------+-------+

Filter: college == 'DU'
+----------+------------+-------+
|student_ID|student_NAME|college|
+----------+------------+-------+
|         1|        Amit|     DU|
|         2|       Mohit|     DU|
|         6|       Anita|     DU|
+----------+------------+-------+

Filter: college == 'DU' AND student_ID == '1'
+----------+------------+-------+
|student_ID|student_NAME|college|
+----------+------------+-------+
|         1|        Amit|     DU|
+----------+------------+-------+

Filter: student_NAME starts with 'A' using SQL string filter
+----------+------------+-------+
|student_ID|student_NAME|college|
+----------+----

Sample Code
The following cell groups a small employee DataFrame by department and computes counts, sums, and averages with groupBy() and agg().

In [39]:
from pyspark.sql import SparkSession
# Note: importing sum this way shadows Python's built-in sum() from here on
from pyspark.sql.functions import avg, sum, count

# Initialize Spark Session
spark = SparkSession.builder \
    .appName("PySpark Grouping Example") \
    .getOrCreate()

# Sample DataFrame
data = [
    ("Alice", "HR", 5000),
    ("Bob", "IT", 6000),
    ("Charlie", "Finance", 7000),
    ("David", "IT", 6000),
    ("Eve", "HR", 5500),
    ("Frank", "Finance", 8000),
]
columns = ["name", "department", "salary"]

df = spark.createDataFrame(data, columns)

# Show original data
print("Original Data:")
df.show()

# Group by department and calculate aggregates
print("Group by Department - Count:")
df.groupBy("department").count().show()

print("Group by Department - Sum of Salaries:")
df.groupBy("department").sum("salary").show()

print("Group by Department - Average Salary:")
df.groupBy("department").agg(avg("salary")).show()

print("Group by Department - Multiple Aggregates:")
df.groupBy("department").agg(
    count("name").alias("employee_count"),
    sum("salary").alias("total_salary"),
    avg("salary").alias("average_salary")
).show()

# Stop Spark Session
spark.stop()


25/11/07 20:12:37 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.
25/11/07 20:12:37 WARN Utils: Service 'SparkUI' could not bind on port 4041. Attempting port 4042.
25/11/07 20:12:37 WARN Utils: Service 'SparkUI' could not bind on port 4042. Attempting port 4043.


Original Data:


                                                                                

+-------+----------+------+
|   name|department|salary|
+-------+----------+------+
|  Alice|        HR|  5000|
|    Bob|        IT|  6000|
|Charlie|   Finance|  7000|
|  David|        IT|  6000|
|    Eve|        HR|  5500|
|  Frank|   Finance|  8000|
+-------+----------+------+

Group by Department - Count:
+----------+-----+
|department|count|
+----------+-----+
|        HR|    2|
|        IT|    2|
|   Finance|    2|
+----------+-----+

Group by Department - Sum of Salaries:
+----------+-----------+
|department|sum(salary)|
+----------+-----------+
|        HR|      10500|
|        IT|      12000|
|   Finance|      15000|
+----------+-----------+

Group by Department - Average Salary:
+----------+-----------+
|department|avg(salary)|
+----------+-----------+
|        HR|     5250.0|
|        IT|     6000.0|
|   Finance|     7500.0|
+----------+-----------+

Group by Department - Multiple Aggregates:
+----------+--------------+------------+--------------+
|department|employee_count|t