In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import *

# Obtain the notebook's Spark session (reuses an active one if it exists)
spark = (
    SparkSession.builder
    .appName("EmployeeData")
    .getOrCreate()
)

In [2]:
# Employee records as (employee_id, employee_name, department, salary) tuples
data = [
    (1, "Alice", "HR", 60000),
    (2, "Bob", "HR", 50000),
    (3, "Charlie", "Finance", 70000),
    (4, "David", "Finance", 75000),
    (5, "Eve", "Engineering", 90000),
    (6, "Frank", "Engineering", 93000),
    (7, "Grace", "HR", 45000),
    (8, "Hank", "Engineering", 98000),
    (9, "Ivy", "Finance", 66000),
]

# Column names, listed in the same positional order as the tuple fields
columns = ["employee_id", "employee_name", "department", "salary"]

# Build the DataFrame; only names are supplied, so column types are inferred
df = spark.createDataFrame(data, schema=columns)

# Print the table so the input rows are visible in the notebook
df.show()

+-----------+-------------+-----------+------+
|employee_id|employee_name| department|salary|
+-----------+-------------+-----------+------+
|          1|        Alice|         HR| 60000|
|          2|          Bob|         HR| 50000|
|          3|      Charlie|    Finance| 70000|
|          4|        David|    Finance| 75000|
|          5|          Eve|Engineering| 90000|
|          6|        Frank|Engineering| 93000|
|          7|        Grace|         HR| 45000|
|          8|         Hank|Engineering| 98000|
|          9|          Ivy|    Finance| 66000|
+-----------+-------------+-----------+------+



In [3]:
from pyspark.sql.window import Window

# Unordered window per department: each row's frame is its whole department
window_spec = Window.partitionBy("department")

# Attach the department-wide mean salary to every employee row
win_df = df.withColumn("avg_salary", avg("salary").over(window_spec))

win_df.show()

+-----------+-------------+-----------+------+------------------+
|employee_id|employee_name| department|salary|        avg_salary|
+-----------+-------------+-----------+------+------------------+
|          5|          Eve|Engineering| 90000| 93666.66666666667|
|          6|        Frank|Engineering| 93000| 93666.66666666667|
|          8|         Hank|Engineering| 98000| 93666.66666666667|
|          3|      Charlie|    Finance| 70000| 70333.33333333333|
|          4|        David|    Finance| 75000| 70333.33333333333|
|          9|          Ivy|    Finance| 66000| 70333.33333333333|
|          1|        Alice|         HR| 60000|51666.666666666664|
|          2|          Bob|         HR| 50000|51666.666666666664|
|          7|        Grace|         HR| 45000|51666.666666666664|
+-----------+-------------+-----------+------+------------------+



In [4]:
# Keep only employees paid strictly more than their department's average
result_df = win_df.where(
    col("salary") > col("avg_salary")
)

result_df.show()

+-----------+-------------+-----------+------+------------------+
|employee_id|employee_name| department|salary|        avg_salary|
+-----------+-------------+-----------+------+------------------+
|          8|         Hank|Engineering| 98000| 93666.66666666667|
|          4|        David|    Finance| 75000| 70333.33333333333|
|          1|        Alice|         HR| 60000|51666.666666666664|
+-----------+-------------+-----------+------+------------------+



SQL approach: the same above-department-average filter, expressed as a Spark SQL query over a temporary view

In [5]:
# Expose the DataFrame to Spark SQL under the table name "employee"
df.createOrReplaceTempView(name="employee")

In [6]:
# Employees earning more than their department's average salary.
# The CTE adds the per-department average (avg_sal) to every row via a window
# function; the outer SELECT keeps rows whose salary exceeds it.
# No trailing ';' here: spark.sql() takes a single statement, and older
# (Spark 2.x) parsers reject a statement terminator.
query = """
    WITH dept_avg AS (
        SELECT e.*,
               AVG(salary) OVER (PARTITION BY department) AS avg_sal
        FROM employee e
    )
    SELECT *
    FROM dept_avg
    WHERE salary > avg_sal
"""

In [7]:
# Run the SQL text against the registered "employee" view
result_sql_df = spark.sql(sqlQuery=query)

# Display the rows returned by the query
result_sql_df.show()

+-----------+-------------+-----------+------+------------------+
|employee_id|employee_name| department|salary|           avg_sal|
+-----------+-------------+-----------+------+------------------+
|          8|         Hank|Engineering| 98000| 93666.66666666667|
|          4|        David|    Finance| 75000| 70333.33333333333|
|          1|        Alice|         HR| 60000|51666.666666666664|
+-----------+-------------+-----------+------+------------------+

