# **PYSPARK INTERVIEW QUESTIONS**

In [0]:
from pyspark.sql.functions import * 
from pyspark.sql.types import *

from pyspark.sql import SparkSession

# Build (or reuse, via getOrCreate) the SparkSession that every cell below relies on.
spark = (
    SparkSession.builder
    .master("local[*]")
    .appName("SparkSessionExample")
    .getOrCreate()
)

**1. While ingesting sales data from an external source, you notice duplicate entries. How would you remove duplicates and retain only the latest entry per key based on a date/timestamp column?**

In [0]:
# Each product appears once per day; the later date is the entry to keep.
data = [
    ("101", "2023-12-01", 100),
    ("101", "2023-12-02", 150),
    ("102", "2023-12-01", 200),
    ("102", "2023-12-02", 250),
]
columns = ["product_id", "date", "sales"]

df = spark.createDataFrame(data, schema=columns)
df.display()

**Solution**

In [0]:
# Turn the ISO 'yyyy-MM-dd' strings into a proper DATE column so ordering is chronological.
df = df.withColumn('date', col('date').cast('date'))

In [0]:
from pyspark.sql.window import Window
from pyspark.sql.functions import row_number

# dropDuplicates keeps an arbitrary row per key and does NOT honour a preceding orderBy,
# so rank rows explicitly (newest first) and keep exactly one per product.
newest_first = Window.partitionBy('product_id').orderBy(col('date').desc())

(df.withColumn('_rn', row_number().over(newest_first))
   .filter(col('_rn') == 1)
   .drop('_rn')
   .display())

**2. While processing data from multiple files with inconsistent schemas, you need to merge them into a single DataFrame. How would you handle this inconsistency in PySpark?**

**Solution**

In [0]:
# `mergeSchema` makes the Parquet reader union the column sets of all files instead of
# taking the schema from a single file; columns absent from a file come back as null.
# (The reader hangs off the SparkSession `spark`, not off a DataFrame.)
df = (spark.read
      .format('parquet')
      .option('mergeSchema', True)
      .load('sample_path'))

**4. You are working with a real-time data pipeline, and you notice missing values in the `Category` column of your streaming data. How would you handle null or missing values in such a scenario?**

Example stream: `df_stream = spark.readStream.schema("id INT, value STRING, Category STRING").csv("path/to/stream")`

In [0]:
# Replace nulls in the Category column with the sentinel string 'NA'.
# NOTE(review): assumes `df` has a `Category` column (e.g. the stream described above).
df = df.na.fill({'Category': 'NA'})

**5. You need to calculate the total number of actions performed by users in a system. How would you calculate the top 5 most active users based on this information?**

In [0]:
# user2 appears twice, so its actions must be summed before ranking.
data = [
    ("user1", 5),
    ("user2", 8),
    ("user3", 2),
    ("user4", 10),
    ("user2", 3),
]
columns = ["user_id", "actions"]

df = spark.createDataFrame(data, schema=columns)
df.display()

In [0]:
# Total actions per user, then the five users with the largest totals.
# `sum` is pyspark.sql.functions.sum here (it shadows the builtin via the star import).
df = (
    df.groupBy('user_id')
      .agg(sum('actions').alias('total_actions'))
      .orderBy(desc('total_actions'))
      .limit(5)
)

df.display()

**6. While processing sales transaction data, you need to identify the most recent transaction for each customer. How would you approach this task?**

In [0]:
# Two transactions per customer; the later transaction_date is the "most recent".
data = [
    ("cust1", "2023-12-01", 100),
    ("cust2", "2023-12-02", 150),
    ("cust1", "2023-12-03", 200),
    ("cust2", "2023-12-04", 250),
]
columns = ["customer_id", "transaction_date", "sales"]
df = spark.createDataFrame(data, schema=columns)
df.display()

In [0]:
from pyspark.sql.window import Window
from pyspark.sql.functions import row_number

# Newest transaction first within each customer. transaction_date is an ISO 'yyyy-MM-dd'
# string, so lexical order equals chronological order.
# row_number (not dense_rank) guarantees exactly one row per customer even when two
# transactions share the same date.
newest_first = Window.partitionBy('customer_id').orderBy(col('transaction_date').desc())

(df.withColumn('latest', row_number().over(newest_first))
   .filter(col('latest') == 1)
   .drop('latest')
   .display())

**7. You need to identify customers who haven’t made any purchases in the last 30 days. How would you filter such customers?**

In [0]:
# Last purchase date per customer (ISO strings; converted to DATE in the next cell).
data = [
    ("cust1", "2025-12-01"),
    ("cust2", "2024-11-20"),
    ("cust3", "2024-11-25"),
]
columns = ["customer_id", "last_purchase_date"]

df = spark.createDataFrame(data, schema=columns)

df.display()

In [0]:
df = df.withColumn('last_purchase_date', to_date('last_purchase_date'))

# `gap` = days elapsed since the last purchase; more than 30 means inactive.
# datediff is the long-standing name of date_diff (same function, older PySpark too).
(df.withColumn('gap', datediff(current_date(), 'last_purchase_date'))
   .filter(col('gap') > 30)
   .display())

**8. While analyzing customer reviews, you need to identify the most frequently used words in the feedback. How would you implement this?**

In [0]:
# Free-text feedback, including punctuation and mixed case.
data = [
    ("customer1", "The product is great"),
    ("customer2", "Great product, fast delivery"),
    ("customer3", "Not bad, could be better"),
]
columns = ["customer_id", "feedback"]

df = spark.createDataFrame(data, schema=columns)

df.display()

In [0]:
# Lower-case, then split on any run of characters that are not letters or digits. This keeps
# punctuation ("product," / "better.") and repeated spaces from producing distinct words.
# Leading punctuation can still yield an empty token, so those are filtered out.
df = (df.withColumn('feedback', explode(split(lower(col('feedback')), r'[^\p{L}\p{N}]+')))
        .filter(col('feedback') != ''))


df_grp = (df.groupBy('feedback')
            .agg(count('feedback').alias('wordcount'))
            .orderBy(col('wordcount').desc()))
df_grp.display()

**9. You need to calculate the cumulative sum of sales over time for each product. How would you approach this?**

In [0]:
# Daily sales for two products, interleaved by date.
data = [
    ("product1", "2023-12-01", 100),
    ("product2", "2023-12-02", 200),
    ("product1", "2023-12-03", 150),
    ("product2", "2023-12-04", 250),
]
columns = ["product_id", "date", "sales"]
df = spark.createDataFrame(data, schema=columns)
df.display()

In [0]:
# Running total per product in date order. The frame is spelled out explicitly; it is the
# frame Spark uses by default for an ordered window (RANGE UNBOUNDED PRECEDING .. CURRENT ROW),
# so rows that share a date receive the same running total.
running = (Window.partitionBy('product_id')
                 .orderBy('date')
                 .rangeBetween(Window.unboundedPreceding, Window.currentRow))
df = df.withColumn('cumsum', sum('sales').over(running))

df.display()

**10. While preparing a data pipeline, you notice some duplicate rows in a dataset. How would you remove the duplicates without affecting the original order?**

In [0]:
# ("John", 25) appears twice — an exact duplicate row.
data = [
    ("John", 25),
    ("Jane", 30),
    ("John", 25),
    ("Alice", 22),
]
columns = ["name", "age"]
df = spark.createDataFrame(data, schema=columns)
df.display()

In [0]:
from pyspark.sql.window import Window
from pyspark.sql.functions import monotonically_increasing_id, row_number

# Record each row's position before the window shuffles the data.
# monotonically_increasing_id increases with partition order.
# NOTE(review): "original order" is only as stable as the input's partition order.
with_pos = df.withColumn('_pos', monotonically_increasing_id())

# A duplicate is a row equal on *every* column; keep its first occurrence.
first_seen = Window.partitionBy(*df.columns).orderBy('_pos')

# Keep the DataFrame in `df` — `.display()` returns None, so it is called separately.
df = (with_pos
      .withColumn('row_num', row_number().over(first_seen))
      .filter(col('row_num') == 1)
      .orderBy('_pos')          # restore the original order after the shuffle
      .drop('row_num', '_pos'))
df.display()

**11. You are working with user activity data and need to calculate the average session duration per user. How would you implement this?**

In [0]:
# Session durations, two sessions per user.
data = [
    ("user1", "2023-12-01", 50),
    ("user1", "2023-12-02", 60),
    ("user2", "2023-12-01", 45),
    ("user2", "2023-12-03", 75),
]
columns = ["user_id", "session_date", "duration"]
df = spark.createDataFrame(data, schema=columns)

df.display()

In [0]:
# Mean session duration per user (`mean` is an alias of `avg`).
avg_duration = mean('duration').alias('avg_duration')
df.groupBy('user_id').agg(avg_duration).display()

**12. While analyzing sales data, you need to find the product with the highest sales for each month. How would you accomplish this?**

In [0]:
# Two products, two days, all in the same month.
data = [
    ("product1", "2023-12-01", 100),
    ("product2", "2023-12-01", 150),
    ("product1", "2023-12-02", 200),
    ("product2", "2023-12-02", 250),
]
columns = ["product_id", "date", "sales"]
df = spark.createDataFrame(data, schema=columns)
df.display()

In [0]:
from pyspark.sql.window import Window
from pyspark.sql.functions import date_format

df = df.withColumn('date', to_date('date'))

# Bucket by calendar month *including the year* ("2023-12"); month() alone would merge
# e.g. December 2023 with December 2024.
df = (df.withColumn('month', date_format('date', 'yyyy-MM'))
        .groupBy('month', 'product_id')
        .agg(sum('sales').alias('sum_sales')))

# dense_rank keeps every product tied for the top spot within a month.
df = (df.withColumn('rank', dense_rank().over(Window.partitionBy('month').orderBy(desc('sum_sales'))))
        .filter(col('rank') == 1))

In [0]:
# Show the best-selling product(s) per month.
display(df)

**13. You are working with a large Delta table that is frequently updated by multiple users. The data is stored in partitions, and sometimes updates can cause inconsistent reads due to concurrent transactions. How would you ensure ACID compliance and avoid data corruption in PySpark?**

In [0]:
# 1. Delta's transaction log gives the table ACID commits; concurrent writers are checked with
#    optimistic concurrency control, so a conflicting commit fails instead of corrupting data.
# 2. MERGE (upsert) applies the updates and inserts as one atomic commit.
# NOTE(review): 'path' and the join key `id` are placeholders — substitute the real ones.

df = spark.read.format('parquet').load('path')

from delta.tables import DeltaTable

delta_tbl = DeltaTable.forPath(spark, 'path')

(delta_tbl.alias('trg').merge(df.alias('src'), 'trg.id = src.id')
 .whenMatchedUpdateAll()
 .whenNotMatchedInsertAll()
 .execute())


**14. You need to process a large dataset stored in Parquet format and ensure that all columns have the expected schema. How would you do this?**

In [0]:
# Parquet files embed their own schema, so `inferSchema` (a CSV reader option) has no effect here.
# To *enforce* the expected types, set `expected_schema` to a StructType (or DDL string).
# Leave it as None to use the schema stored in the files.
expected_schema = None

reader = spark.read.format('parquet')
if expected_schema is not None:
    reader = reader.schema(expected_schema)
df = reader.load('path')

**15. You are reading a CSV file and need to handle corrupt records gracefully by skipping them. How would you configure this in PySpark?**

In [0]:
# DROPMALFORMED discards rows that cannot be parsed instead of failing the read.
df = spark.read.option("mode", "DROPMALFORMED").csv('path')

**22. You have a dataset containing the names of employees and their departments. You need to find the department with the most employees.**

In [0]:
# Employee -> department mapping (HR and Finance each have two employees).
data = [
    ("Alice", "HR"),
    ("Bob", "Finance"),
    ("Charlie", "HR"),
    ("David", "Engineering"),
    ("Eve", "Finance"),
]
columns = ["employee_name", "department"]

df = spark.createDataFrame(data, schema=columns)
df.display()

In [0]:
from pyspark.sql import functions as F

# Employees per department.
dept_counts = df.groupBy('department').agg(count('employee_name').alias('count'))

# Keep every department that reaches the maximum. The sample data has a tie (HR and Finance both
# have 2), and orderBy(...).limit(1) would arbitrarily drop one of them.
# `.show()` returns None, so the DataFrame is kept in `result_df` and shown separately.
top_count = dept_counts.agg(F.max('count')).first()[0]
result_df = dept_counts.filter(col('count') == top_count)
result_df.show()

**23. While processing sales data, you need to classify each transaction as either 'High' or 'Low' based on its amount. How would you achieve this using a `when` condition?**

In [0]:
# Sales amount per product.
data = [
    ("product1", 100),
    ("product2", 300),
    ("product3", 50),
]
columns = ["product_id", "sales"]

df = spark.createDataFrame(data, schema=columns)
df.display()

In [0]:
# Strictly above 50 is 'High'; everything else (including exactly 50 and nulls) is 'Low'.
price_cat = when(col('sales') > 50, 'High').otherwise('Low')
df.withColumn('price_cat', price_cat).display()

**24. While analyzing a large dataset, you need to create a new column that holds a timestamp of when the record was processed. How would you implement this, and what is a typical use case?**

In [0]:
# Typical use case: audit / load-time columns, e.g. effective timestamps in SCD tables.
data = [
    ("product1", 100),
    ("product2", 200),
    ("product3", 300),
]
columns = ["product_id", "sales"]

df = spark.createDataFrame(data, schema=columns)
df.display()

In [0]:
# Append a processing timestamp; current_timestamp() is evaluated once per query,
# so every row of a given run gets the same value.
df = df.select('*', current_timestamp().alias('processed_time'))
df.display()

**25. You need to register this PySpark DataFrame as a temporary SQL object and run a query on it. How would you achieve this?**

In [0]:
# Small product/sales frame to expose through SQL.
data = [
    ("product1", 100),
    ("product2", 200),
    ("product3", 300),
]
columns = ["product_id", "sales"]

df = spark.createDataFrame(data, schema=columns)
df.display()

In [0]:
# Session-scoped view: queryable via spark.sql within this SparkSession only.
df.createOrReplaceTempView('tempsql')

tempsql_df = spark.sql("select * from tempsql")
tempsql_df.display()

**26. You need to register this PySpark DataFrame as a temporary SQL object that can also be queried from other notebooks. How would you achieve this?**

In [0]:
# Global temp views are registered in the system database `global_temp` and live as long as the
# Spark application, so other sessions (presumably other notebooks on the same cluster) can query them.
df.createOrReplaceGlobalTempView(name='globaldf')

In [0]:
%sql
-- Global temp views must be qualified with the `global_temp` database.
select * from global_temp.globaldf

**27. You need to query data from a PySpark DataFrame using SQL, but the data includes a nested structure. How would you flatten the data for easier querying?**

In [0]:
# product_info is a Python dict; by default PySpark infers it as a MapType column
# (string keys), which still supports `product_info.price`-style key access.
data = [
    ("product1", {"price": 100, "quantity": 2}),
    ("product2", {"price": 200, "quantity": 3}),
]
columns = ["product_id", "product_info"]

df = spark.createDataFrame(data, schema=columns)
df.display()

In [0]:
# Pull the nested values up to top-level columns.
flat_cols = [col('product_id'), col('product_info.price'), col('product_info.quantity')]
df.select(*flat_cols).display()

**28. You are ingesting data from an external API in JSON format where the schema is inconsistent. How would you handle this situation to ensure a robust pipeline?**

In [0]:
# `mergeSchema` is a Parquet/ORC option and is ignored by the JSON reader; JSON schema
# inference already unions the fields seen across all records.
# What keeps the pipeline robust is not failing on bad records: PERMISSIVE mode nulls the
# unparsable fields and keeps the raw text in `_corrupt_record` for later inspection.
# With an inferred schema, that column only appears when corrupt records exist.
df = (spark.read.format("json")
      .option("mode", "PERMISSIVE")
      .option("columnNameOfCorruptRecord", "_corrupt_record")
      .load('path'))

**29. To optimize later reads of Parquet data, you need to partition the data on disk by a column. How would you implement this?**

In [0]:
# One sub-directory per `category` value, so readers filtering on it can skip whole partitions.
df.write.mode('append').partitionBy('category').parquet('location')

**30. You are working with a large dataset in Parquet format and need to ensure that the data is written in an optimized manner with proper compression. How would you accomplish this?**

In [0]:
# Snappy: fast, splittable-friendly Parquet compression (also Spark's default codec for Parquet).
df.write.parquet('location', compression='snappy')

**31. Your company uses a large-scale data pipeline that reads from Delta tables and processes data using complex aggregations. However, performance is becoming an issue due to the growing dataset size. How would you optimize the performance of the pipeline?**

In [0]:
%sql

-- Co-locate rows with similar order_date values so data skipping can prune files.
-- ZORDER BY takes column identifiers: a double-quoted "order_date" is a string literal in Spark SQL.
-- Each statement ends with ';' so the cell runs them as separate statements.
OPTIMIZE tabledelta ZORDER BY (order_date);

-- delta lake time travel: inspect the commit history, then roll the table back to version 2

DESCRIBE HISTORY tbl;
RESTORE TABLE tbl TO VERSION AS OF 2;

**43. You are processing sales data. Group by product categories and create a list of all product names in each category.**

In [0]:
# Two products in each of two categories.
data = [
    ("Electronics", "Laptop"),
    ("Electronics", "Smartphone"),
    ("Furniture", "Chair"),
    ("Furniture", "Table"),
]
columns = ["category", "product"]
df = spark.createDataFrame(data, schema=columns)
df.display()

In [0]:
# One row per category holding every product name (duplicates kept, order not guaranteed).
products_per_category = collect_list('product').alias('products')
df = df.groupBy('category').agg(products_per_category)
df.display()

**44. You are analyzing orders. Group by customer IDs and list all unique product IDs each customer purchased.**

In [0]:
# Customer 101 bought P001 twice.
data = [
    (101, "P001"),
    (101, "P002"),
    (102, "P001"),
    (101, "P001"),
]
columns = ["customer_id", "product_id"]
df = spark.createDataFrame(data, schema=columns)
df.display()

In [0]:
# collect_set de-duplicates; element order within the array is not guaranteed.
unique_products = collect_set('product_id').alias('unique_products')
df.groupBy('customer_id').agg(unique_products).display()

**45. For customer records, combine first and last names only if the email address exists.**

In [0]:
# Jane has no email address.
data = [
    ("John", "Doe", "john.doe@example.com"),
    ("Jane", "Smith", None),
]
columns = ["first_name", "last_name", "email"]
df = spark.createDataFrame(data, schema=columns)
df.display()

In [0]:
# Full name only when an email exists; `when` without `otherwise` already yields null otherwise.
full_name = concat_ws(' ', col('first_name'), col('last_name'))
df.withColumn('name', when(col('email').isNotNull(), full_name)).display()

In [0]:
# Alternative: drop rows without an email first, then build the name for the rest.
with_email = df.where(col('email').isNotNull())
with_email.withColumn('name', concat_ws(' ', 'first_name', 'last_name')).display()

**46. You have a DataFrame containing customer IDs and a list of their purchased product IDs. Calculate the number of products each customer has purchased.**

In [0]:
# Each customer carries an array of purchased product IDs; the DDL schema fixes the array type.
data = [
    (1, ["prod1", "prod2", "prod3"]),
    (2, ["prod4"]),
    (3, ["prod5", "prod6"]),
]
myschema = "customer_id INT ,product_ids array<STRING>"

df = spark.createDataFrame(data, schema=myschema)
df.display()

In [0]:
# Number of products = length of the product_ids array.
df.select('*', size('product_ids').alias('no_of_prod')).display()

**47. You have employee IDs of varying lengths. Ensure all IDs are 6 characters long by padding with leading zeroes.**

In [0]:
# Employee IDs of different lengths, stored as strings.
data = [("1",), ("123",), ("4567",)]
schema = ["employee_id"]

df = spark.createDataFrame(data, schema=schema)
df.display()

In [0]:
# lpad also *truncates* values longer than the target width ("1234567" -> "123456"), which would
# corrupt IDs. Only pad IDs shorter than 6 characters and leave longer ones untouched.
# Null IDs stay null.
ID_WIDTH = 6
df.withColumn(
    'employee_id',
    when(length(col('employee_id')) < ID_WIDTH, lpad(col('employee_id'), ID_WIDTH, "0"))
    .otherwise(col('employee_id')),
).display()

# use rpad for appending 0 at the end

**48. You need to validate phone numbers by checking if they start with "91"**

In [0]:
# Two numbers start with "91", one with "81".
data = [("911234567890",), ("811234567890",), ("912345678901",)]
schema = ["phone_number"]

df = spark.createDataFrame(data, schema=schema)
df.display()

In [0]:
# Keep only numbers with the "91" prefix.
df.where(col('phone_number').startswith('91')).display()

**49. You have a dataset with courses taken by students. Calculate the average number of courses per student.**

In [0]:
# Courses per student as arrays of different lengths (2, 1, 3).
data = [
    (1, ["Math", "Science"]),
    (2, ["History"]),
    (3, ["Art", "PE", "Biology"]),
]
schema = ["student_id", "courses"]

df = spark.createDataFrame(data, schema=schema)
df.display()

In [0]:
# A global aggregate (no grouping key) over the per-student array sizes gives a single average row.
df.select(avg(size('courses')).alias('avg(course_size)')).display()

**50. You have a dataset with primary and secondary contact numbers. Use the primary number if available; otherwise, use the secondary number.**

In [0]:
# Rows cover: primary missing, secondary missing, both present.
data = [
    (None, "1234567890"),
    ("9876543210", None),
    ("7894561230", "4567891230"),
]
schema = ["primary_contact", "secondary_contact"]

df = spark.createDataFrame(data, schema=schema)
df.display()

In [0]:
# Option 1: explicit when/otherwise — fall back to the secondary number only when primary is null.
df.withColumn(
    'phone_no',
    when(col('primary_contact').isNull(), col('secondary_contact')).otherwise(col('primary_contact')),
).display()

# Option 2: coalesce returns the first non-null argument — same result, less code.
df.withColumn('phone_no', coalesce('primary_contact', 'secondary_contact')).display()

**51. You are categorizing product codes based on their lengths. If the length is 5, label it as "Standard"; otherwise, label it as "Custom".**

In [0]:
# One 5-character code and two 6-character codes.
data = [("prod1",), ("prd234",), ("pr9876",)]
schema = ["product_code"]

df = spark.createDataFrame(data, schema=schema)
df.display()

In [0]:
# Labels match the requirement exactly: "Standard" for 5-character codes, "Custom" otherwise
# (null codes fall through to "Custom").
df.withColumn('category', when(length(col('product_code')) == 5, 'Standard').otherwise('Custom')).display()

**52. Flatten nested JSON data**

In [None]:
# NOTE(review): `df` is assumed to hold the nested JSON (items array, customer struct with an
# address struct, payment struct, metadata) — it is not created in this notebook.
# Lift customer/payment fields to top level; `items` and `metadata` stay nested for now.
flat_fields = [
    'items',
    'customer.customer_id', 'customer.name', 'customer.email',
    'customer.address.city', 'customer.address.country', 'customer.address.postal_code',
    'payment.method', 'payment.transaction_id',
    'metadata',
]
df = df.select(*flat_fields)

# One row per item; explode_outer keeps records whose items array is null or empty.
df = df.withColumn('items', explode_outer('items'))


In [None]:
# Columns already flat after step 1, reused in both projections below.
base_cols = ['customer_id', 'name', 'email', 'city', 'country', 'postal_code', 'method', 'transaction_id']

# Unpack the per-item struct.
item_cols = ['items.item_id', 'items.price', 'items.product_name', 'items.quantity']
df = df.select(*item_cols, *base_cols, 'metadata')

# One row per metadata entry (presumably an array of key/value structs), then unpack it.
df = df.withColumn('metadata', explode_outer('metadata'))
df = df.select('item_id', 'price', 'product_name', 'quantity', *base_cols, 'metadata.key', 'metadata.value')


**53. Streaming window aggregations (e.g., over a Kafka source): sliding windows, watermarks, and window offsets**

In [None]:
# Sliding window: 15-minute windows starting every 5 minutes, so each event falls into
# three overlapping windows. `user_df` is assumed to be defined elsewhere.
from pyspark.sql.functions import window, col, countDistinct

sliding = window("event_time", "15 minutes", "5 minutes")  # window size, slide interval
result = (
    user_df
    .groupBy(sliding, "country")
    .agg(countDistinct("user_id").alias("distinct_user_count"))
)
result.orderBy("window").show(truncate=False)


In [None]:
# watermarking: accept events up to 10 minutes late; older state can then be dropped.
# Exact distinct aggregations (countDistinct) are rejected on streaming DataFrames, so use the
# approximate approx_count_distinct instead.

from pyspark.sql.functions import window, col, approx_count_distinct

result = (
    user_df
    .withWatermark("event_time", "10 minutes")
    .groupBy(window("event_time", "15 minutes"), "country")
    .agg(approx_count_distinct("user_id").alias("distinct_user_count"))
)


In [None]:
# Window start alignment control: tumbling 15-minute windows (slide == size) shifted by a
# 5-minute start offset, so windows begin at 10:05, 10:20, 10:35, ...
aligned = window(
    "event_time",
    windowDuration="15 minutes",
    slideDuration="15 minutes",
    startTime="5 minutes",
)
result = (
    user_df
    .groupBy(aligned, "country")
    .agg(countDistinct("user_id").alias("distinct_user_count"))
)


In [None]:
# combining all these in one: sliding window + start offset + watermark.
# Spark requires |startTime| < slideDuration. A 5-minute offset on a 5-minute slide is rejected,
# and would be a no-op anyway since every 5-minute boundary is already a window start.
# A 2-minute offset gives windows starting at :02, :07, :12, ...
# countDistinct is not supported on streaming DataFrames, hence approx_count_distinct.

from pyspark.sql.functions import window, col, approx_count_distinct

result = (
    user_df
    .withWatermark("event_time", "10 minutes")
    .groupBy(
        window("event_time", "15 minutes", "5 minutes", "2 minutes"),  # size, slide, start offset
        "country",
    )
    .agg(approx_count_distinct("user_id").alias("distinct_user_count"))
)


**54. PySpark optimization code**

In [None]:
from pyspark.sql.functions import broadcast

# Read efficiently: project only the needed columns so Parquet column pruning applies.
df = spark.read.format("parquet").load("/mnt/data/large_table").select("user_id", "event_time")

# Filter early so the predicate can be pushed down to the scan.
df = df.filter(df["event_time"] >= "2025-07-01")

# Broadcast the small reference table for the join.
# Do NOT .limit() it: truncating the lookup table silently drops matches from the join.
ref_df = spark.read.format("parquet").load("/mnt/data/reference")
df_joined = df.join(broadcast(ref_df), "user_id")

# Cache only pays off when the result feeds several actions. Unpersist in `finally`
# so a failed write does not leave the cached blocks behind.
df_joined.cache()
try:
    # Repartition before writing as Delta to control the number of output files.
    df_joined.repartition(10).write.format("delta").mode("overwrite").save("/mnt/delta/output")
finally:
    df_joined.unpersist()


**55. SCD2 implementation**

In [None]:
from pyspark.sql.functions import lit, current_timestamp
from delta.tables import DeltaTable

# Sample initial SCD table: both customers start as their current version (end_date open).
initial_data = [
    (1, 'CUST001', 'Alice', True, "2024-01-01", None),
    (2, 'CUST002', 'Bob', True, "2024-01-01", None)
]

columns = ['surrogate_key', 'customer_id', 'name', 'is_current', 'effective_date', 'end_date']

# An explicit schema is required: end_date is None in every row, so type inference cannot
# determine its type and createDataFrame raises. Declaring it DATE also matches the end_date
# values written later by the SCD2 merge.
initial_schema = (
    "surrogate_key BIGINT, customer_id STRING, name STRING, "
    "is_current BOOLEAN, effective_date STRING, end_date DATE"
)

df = spark.createDataFrame(initial_data, initial_schema) \
          .withColumn("effective_date", lit("2024-01-01").cast("date"))

# Write to Delta table
df.write.format("delta").mode("overwrite").save("/tmp/delta/customer_dim")


In [None]:
from pyspark.sql.functions import current_date

# Incoming snapshot: CUST002 changed name (Bob -> Robert); CUST003 is a brand-new customer.
source_data = [
    ("CUST002", "Robert"),
    ("CUST003", "Dave"),
]

source_df = (
    spark.createDataFrame(source_data, ["customer_id", "name"])
    .withColumn("effective_date", current_date())
)


In [None]:
from pyspark.sql.window import Window
from pyspark.sql.functions import row_number

# Load the target table as a DeltaTable
dim_table = DeltaTable.forPath(spark, "/tmp/delta/customer_dim")

# Temp view for ad-hoc SQL inspection (the MERGE below uses the DataFrame API directly).
source_df.createOrReplaceTempView("staged_updates")

# Highest surrogate key already issued (0 for an empty table); new keys continue from here.
max_sk = spark.read.format("delta").load("/tmp/delta/customer_dim").agg({"surrogate_key": "max"}).collect()[0][0] or 0

# Step 1: expire the current version of every customer whose tracked attribute CHANGED.
# `<=>` is null-safe equality, so a change to or from a NULL name is still detected.
dim_table.alias("tgt").merge(
    source_df.alias("src"),
    "tgt.customer_id = src.customer_id AND tgt.is_current = true AND NOT (tgt.name <=> src.name)"
).whenMatchedUpdate(set={
    "is_current": lit(False),
    "end_date": current_date()
}).execute()

# Step 2: insert a new current version for new customers and for the customers expired in step 1,
# i.e. every source row with no *current* row left in the target. The table is re-read so the
# anti-join is evaluated against the version committed by step 1.
current_rows = spark.read.format("delta").load("/tmp/delta/customer_dim").filter("is_current = true")

# Each inserted row needs its own key: max_sk+1, max_sk+2, ...
# (a single lit(max_sk + 1) would give every new row the same surrogate key).
new_key_order = Window.orderBy("customer_id")

updates_to_insert = source_df.alias("src") \
    .join(
        current_rows.alias("tgt"),
        on="customer_id",
        how="left_anti"
    ) \
    .withColumn("surrogate_key", (lit(max_sk) + row_number().over(new_key_order)).cast("bigint")) \
    .withColumn("is_current", lit(True)) \
    .withColumn("end_date", lit(None).cast("date"))

# Append new records to Delta table
updates_to_insert.select("surrogate_key", "customer_id", "name", "is_current", "effective_date", "end_date") \
    .write.format("delta").mode("append").save("/tmp/delta/customer_dim")


In [None]:
# Inspect the dimension: each customer's versions in chronological order.
customer_dim = spark.read.format("delta").load("/tmp/delta/customer_dim")
display(customer_dim.orderBy("customer_id", "effective_date"))
