In [3]:
from pyspark.sql.types import StructType, StructField, IntegerType, StringType
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType, StringType
spark = SparkSession.builder.appName("PySparkTables").getOrCreate()

# Sample data
employee_data = [
    (1, 'A', 100),
    (2, 'A', 200),
    (3, 'A', 300),
    (4, 'B', 400),
    (5, 'B', 500),
]

# Define schema
employee_schema = StructType([
    StructField("id", IntegerType(), False),
    StructField("company", StringType(), False),
    StructField("salary", IntegerType(), False),
])

# Create DataFrame
employee_df = spark.createDataFrame(employee_data, employee_schema)

# Register as temporary view
employee_df.createOrReplaceTempView("Employee")


In [15]:
spark.sql(
"""
SELECT company, 
       ROUND(AVG(salary), 0) AS median_salary
FROM (
    SELECT company, salary,
           ROW_NUMBER() OVER (PARTITION BY company ORDER BY salary) AS rn,
           COUNT(*) OVER (PARTITION BY company) AS cnt
    FROM Employee
) AS ranked
WHERE rn = (cnt + 1) / 2      -- For odd count: middle element
   OR rn = (cnt + 2) / 2      -- For even count: average of middle two
GROUP BY company;
""").show()

+-------+-------------+
|company|median_salary|
+-------+-------------+
|      A|        200.0|
|      B|        500.0|
+-------+-------------+

