# PySpark Koans - Solutions Notebook

This notebook contains complete solutions for the core PySpark koans. Use it to check your answers after attempting the practice notebook.

**Note**: These koans are designed to run in the browser-based pandas shim, where `spark` is predefined. To run them with real PySpark, you need a Spark environment with an active `SparkSession` named `spark`.

## Categories:
- **Koans 1-30**: PySpark Basics and Operations
- **Koans 101-110**: Delta Lake
- **Koans 201-210**: Unity Catalog
- **Koans 301-310**: Pandas API on Spark

In [None]:
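from pyspark.sql import SparkSession
# The wildcard import deliberately shadows Python built-ins such as sum, min,
# max and round with their Spark column versions; the koans below rely on that.
from pyspark.sql.functions import *
from pyspark.sql.window import Window
import pyspark.pandas as ps

# In the browser-based version, `spark` is already initialized.
# With real PySpark, create (or reuse) a local session instead.
try:
    spark
except NameError:
    spark = SparkSession.builder.appName("pyspark-koans").getOrCreate()

print("✓ Environment ready")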

## Basics

In [None]:
# Koan 1: Creating a DataFrame
# Solution

data = [("Alice", 34), ("Bob", 45), ("Charlie", 29)]
columns = ["name", "age"]

df = spark.createDataFrame(data, columns)

assert df.count() == 3
assert len(df.columns) == 2
print("✓ Koan 1 complete: DataFrame creation")

In [None]:
# Koan 2: Selecting Columns
# Solution

data = [("Alice", 34, "NYC"), ("Bob", 45, "LA"), ("Charlie", 29, "Chicago")]
df = spark.createDataFrame(data, ["name", "age", "city"])

result = df.select("name", "city")

assert len(result.columns) == 2
assert "name" in result.columns
assert "city" in result.columns
print("✓ Koan 2 complete: Column selection")

In [None]:
# Koan 3: Filtering Rows
# Solution

data = [("Alice", 34), ("Bob", 45), ("Charlie", 29), ("Diana", 52)]
df = spark.createDataFrame(data, ["name", "age"])

result = df.filter(col("age") > 35)

assert result.count() == 2
rows = result.collect()
ages = [row["age"] for row in rows]
assert all(age > 35 for age in ages)
print("✓ Koan 3 complete: Row filtering")

In [None]:
# Koan 4: Adding Columns
# Solution

data = [("Alice", 34), ("Bob", 45), ("Charlie", 29)]
df = spark.createDataFrame(data, ["name", "age"])

result = df.withColumn("age_in_months", col("age") * 12)

assert result.count() == 3
assert len(result.columns) == 3
first_row = result.filter(col("name") == "Alice").collect()[0]
assert first_row["age_in_months"] == 408
print("✓ Koan 4 complete: Adding columns")

In [None]:
# Koan 5: Grouping and Aggregating
# Solution

data = [
    ("Sales", "Alice", 5000),
    ("Sales", "Bob", 4500),
    ("Engineering", "Charlie", 6000),
    ("Engineering", "Diana", 6500),
    ("Engineering", "Eve", 5500)
]
df = spark.createDataFrame(data, ["department", "name", "salary"])

result = df.groupBy("department").agg(
    round(avg("salary"), 2).alias("avg_salary")
)

assert result.count() == 2
eng_row = result.filter(col("department") == "Engineering").collect()[0]
assert eng_row["avg_salary"] == 6000.0
print("✓ Koan 5 complete: Grouping and aggregating")

In [None]:
# Koan 6: Dropping Columns
# Solution

data = [("Alice", 34, "NYC", "F"), ("Bob", 45, "LA", "M")]
df = spark.createDataFrame(data, ["name", "age", "city", "gender"])

result = df.drop("gender")
result2 = df.drop("city", "gender")

assert "gender" not in result.columns
assert len(result.columns) == 3
assert len(result2.columns) == 2
print("✓ Koan 6 complete: Dropping columns")

In [None]:
# Koan 7: Distinct Values
# Solution

data = [("Alice", "NYC"), ("Bob", "LA"), ("Alice", "NYC"), ("Charlie", "NYC")]
df = spark.createDataFrame(data, ["name", "city"])

result = df.distinct()
cities = df.select("city").distinct()

assert result.count() == 3
assert cities.count() == 2
print("✓ Koan 7 complete: Distinct values")

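`distinct()` compares entire rows. `dropDuplicates` with a column list instead keeps one row per value of only those columns. The plain-Python model below mirrors both operations on this koan's data without needing Spark. The `py_` names are hypothetical helpers, not library functions. Keeping the *first* row per key is an assumption of the sketch, because Spark does not guarantee which duplicate survives.

```python
# Plain-Python model of distinct() vs dropDuplicates(subset); not Spark itself.
rows = [("Alice", "NYC"), ("Bob", "LA"), ("Alice", "NYC"), ("Charlie", "NYC")]
columns = ["name", "city"]


def py_distinct(rows):
    """Whole-row deduplication, like df.distinct(); keeps first-seen order."""
    return list(dict.fromkeys(rows))


def py_drop_duplicates(rows, columns, subset):
    """One row per key in `subset`, like df.dropDuplicates(subset).

    Assumption: the first row seen is kept; Spark makes no such guarantee.
    """
    idx = [columns.index(c) for c in subset]
    seen = {}
    for row in rows:
        key = tuple(row[i] for i in idx)
        seen.setdefault(key, row)
    return list(seen.values())


print(len(py_distinct(rows)))                            # 3
print(py_drop_duplicates(rows, columns, ["city"]))       # [('Alice', 'NYC'), ('Bob', 'LA')]
```

In PySpark, the keyed version is `df.dropDuplicates(["city"])`.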
## Column Operations

In [None]:
# Koan 9: Renaming Columns
# Solution

data = [("Alice", 34), ("Bob", 45)]
df = spark.createDataFrame(data, ["name", "age"])

result = df.withColumnRenamed("name", "employee_name")
result2 = df.select(col("name").alias("full_name"), col("age"))

assert "employee_name" in result.columns
assert "name" not in result.columns
assert "full_name" in result2.columns
print("✓ Koan 9 complete: Renaming columns")

In [None]:
# Koan 10: Literal Values
# Solution

data = [("Alice", 34), ("Bob", 45)]
df = spark.createDataFrame(data, ["name", "age"])

result = df.withColumn("country", lit("USA"))
result2 = df.withColumn("bonus", lit(1000))

rows = result.collect()
assert all(row["country"] == "USA" for row in rows)
assert result2.collect()[0]["bonus"] == 1000
print("✓ Koan 10 complete: Literal values")

In [None]:
# Koan 11: Conditional Logic with when/otherwise
# Solution

data = [("Alice", 34), ("Bob", 45), ("Charlie", 17), ("Diana", 65)]
df = spark.createDataFrame(data, ["name", "age"])

result = df.withColumn(
    "age_group",
    when(col("age") < 18, "minor")
    .when(col("age") < 65, "adult")
    .otherwise("senior")
)

rows = result.collect()
groups = {row["name"]: row["age_group"] for row in rows}

assert groups["Charlie"] == "minor"
assert groups["Alice"] == "adult"
assert groups["Diana"] == "senior"
print("✓ Koan 11 complete: Conditional logic")

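`when` branches are tested in order, and the first condition that is true wins. That is why the `< 65` branch only ever sees ages of 18 and above.

Nulls need extra care. A comparison against a null age yields null rather than true, so a null age skips every `when` branch and falls through to `otherwise`. The plain-Python model below sketches this behavior, using `None` to stand in for SQL null. It illustrates the semantics and is not Spark's evaluator.

```python
# Plain-Python model of the when/otherwise chain from Koan 11.
# None stands in for SQL NULL; a comparison involving None yields None ("unknown").

def less_than(value, bound):
    """SQL-style '<': unknown (None) if the input is NULL."""
    return None if value is None else value < bound


def age_group(age):
    # Branches are checked in order; only a condition that is exactly True fires.
    if less_than(age, 18) is True:
        return "minor"
    if less_than(age, 65) is True:
        return "adult"
    return "senior"  # otherwise(...) also catches NULL ages


print([age_group(a) for a in [34, 45, 17, 65, None]])
# ['adult', 'adult', 'minor', 'senior', 'senior']
```

If null ages should get their own label, test `col("age").isNull()` in the first `when` branch.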
In [None]:
# Koan 12: Type Casting
# Solution

data = [("Alice", "34"), ("Bob", "45")]
df = spark.createDataFrame(data, ["name", "age_str"])

result = df.withColumn("age", col("age_str").cast("integer"))
result = result.withColumn("age_plus_10", col("age") + 10)

result2 = df.withColumn("age_float", col("age_str").cast("double"))

rows = result.collect()
assert rows[0]["age_plus_10"] == 44
assert result2.collect()[0]["age_float"] == 34.0
print("✓ Koan 12 complete: Type casting")

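A cast that cannot parse its input does not necessarily fail. With ANSI mode off (`spark.sql.ansi.enabled=false`, the default before Spark 4.0), an unparseable string such as `"n/a"` becomes null. Any arithmetic on that value, like `age_plus_10`, is then null too. With ANSI mode on, the cast raises an error instead.

The sketch below models only the lenient behavior, in plain Python. It simplifies parsing to Python's `int()`, which is an assumption and not Spark's exact parsing rules.

```python
# Plain-Python model of a lenient (non-ANSI) string -> integer cast.
# Assumption: parsing is simplified to Python's int(); Spark's rules differ in edge cases.

def lenient_cast_int(text):
    if text is None:
        return None
    try:
        return int(text)
    except ValueError:
        return None  # non-ANSI Spark yields NULL here instead of raising


ages = [lenient_cast_int(s) for s in ["34", "45", "n/a", None]]
plus_ten = [None if a is None else a + 10 for a in ages]  # NULL propagates through +
print(ages)      # [34, 45, None, None]
print(plus_ten)  # [44, 55, None, None]
```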
## String Functions

In [None]:
# Koan 13: String Functions - Case
# Solution

data = [("alice smith",), ("BOB JONES",), ("Charlie Brown",)]
df = spark.createDataFrame(data, ["name"])

result = df.withColumn("upper_name", upper(col("name")))
assert result.collect()[0]["upper_name"] == "ALICE SMITH"

result = df.withColumn("lower_name", lower(col("name")))
assert result.collect()[1]["lower_name"] == "bob jones"

result = df.withColumn("title_name", initcap(col("name")))
assert result.collect()[0]["title_name"] == "Alice Smith"

print("✓ Koan 13 complete: String case functions")

In [None]:
# Koan 14: String Functions - Concatenation
# Solution

data = [("Alice", "Smith"), ("Bob", "Jones")]
df = spark.createDataFrame(data, ["first", "last"])

result = df.withColumn("full_name", concat(col("first"), lit(" "), col("last")))
assert result.collect()[0]["full_name"] == "Alice Smith"

result2 = df.withColumn("full_name", concat_ws(" ", col("first"), col("last")))
assert result2.collect()[0]["full_name"] == "Alice Smith"

print("✓ Koan 14 complete: String concatenation")

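`concat` and `concat_ws` behave the same here, but they differ as soon as a null appears. `concat` returns null if any input is null, while `concat_ws` skips null inputs and joins the rest with the separator. The plain-Python models below sketch both rules, with `None` standing in for null. The `py_` helpers are hypothetical and named so they do not shadow the Spark functions imported at the top of the notebook.

```python
# Plain-Python models of Spark's concat and concat_ws null handling.

def py_concat(*parts):
    """NULL if any part is NULL, like pyspark.sql.functions.concat."""
    if None in parts:
        return None
    return "".join(parts)


def py_concat_ws(sep, *parts):
    """NULL parts are skipped, like pyspark.sql.functions.concat_ws."""
    return sep.join(p for p in parts if p is not None)


print(py_concat("Alice", " ", "Smith"))  # Alice Smith
print(py_concat("Bob", " ", None))       # None
print(py_concat_ws(" ", "Bob", None))    # Bob
```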
In [None]:
# Koan 15: String Functions - Substring and Length
# Solution

data = [("Alice",), ("Bob",), ("Charlotte",)]
df = spark.createDataFrame(data, ["name"])

result = df.withColumn("name_length", length(col("name")))
lengths = [row["name_length"] for row in result.collect()]
assert lengths == [5, 3, 9]

result2 = df.withColumn("first_three", substring(col("name"), 1, 3))
firsts = [row["first_three"] for row in result2.collect()]
assert firsts == ["Ali", "Bob", "Cha"]

print("✓ Koan 15 complete: Substring and length")

In [None]:
# Koan 16: String Functions - Trim and Pad
# Solution

data = [("  Alice  ",), ("Bob",), (" Charlie ",)]
df = spark.createDataFrame(data, ["name"])

result = df.withColumn("trimmed", trim(col("name")))
trimmed = [row["trimmed"] for row in result.collect()]
assert trimmed == ["Alice", "Bob", "Charlie"]

result2 = result.withColumn("padded", lpad(col("trimmed"), 10, "*"))
assert result2.collect()[1]["padded"] == "*******Bob"

print("✓ Koan 16 complete: Trim and pad functions")

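`lpad` and `rpad` pad to an exact width, so they also truncate: a string longer than the target length is cut down to that length. Python's `str.rjust` behaves differently and leaves longer strings untouched. The model below sketches Spark's rule in plain Python. To keep it short, it assumes a single-character pad (Spark also accepts longer pad strings), and `py_lpad` is a hypothetical helper.

```python
# Plain-Python model of Spark's lpad, restricted to one-character pads.

def py_lpad(s, length, pad):
    if len(pad) != 1:
        raise ValueError("this sketch only handles single-character pads")
    if len(s) >= length:
        return s[:length]  # longer inputs are truncated to `length`
    return pad * (length - len(s)) + s


print(py_lpad("Bob", 10, "*"))        # *******Bob
print(py_lpad("Charlotte", 5, "*"))   # Charl
print("Charlotte".rjust(5, "*"))      # Charlotte (Python's rjust never truncates)
```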
## Aggregations

In [None]:
# Koan 17: Grouping and Aggregating
# Solution

data = [
    ("Sales", "Alice", 5000),
    ("Sales", "Bob", 4500),
    ("Engineering", "Charlie", 6000),
    ("Engineering", "Diana", 6500),
    ("Engineering", "Eve", 5500)
]
df = spark.createDataFrame(data, ["department", "name", "salary"])

result = df.groupBy("department").agg(
    round(avg("salary"), 2).alias("avg_salary")
)

assert result.count() == 2
eng_row = result.filter(col("department") == "Engineering").collect()[0]
assert eng_row["avg_salary"] == 6000.0
print("✓ Koan 17 complete: Grouping and aggregating")

In [None]:
# Koan 18: Multiple Aggregations
# Solution

data = [
    ("Sales", 5000), ("Sales", 4500), ("Sales", 6000),
    ("Engineering", 6000), ("Engineering", 6500)
]
df = spark.createDataFrame(data, ["department", "salary"])

result = df.groupBy("department").agg(
    min("salary").alias("min_salary"),
    max("salary").alias("max_salary"),
    avg("salary").alias("avg_salary"),
    count("salary").alias("emp_count")
)

sales = result.filter(col("department") == "Sales").collect()[0]
assert sales["min_salary"] == 4500
assert sales["max_salary"] == 6000
assert sales["emp_count"] == 3
print("✓ Koan 18 complete: Multiple aggregations")

In [None]:
# Koan 19: Aggregate Without Grouping
# Solution

data = [(100,), (200,), (300,), (400,), (500,)]
df = spark.createDataFrame(data, ["value"])

result = df.agg(sum(col("value")).alias("total"))
total = result.collect()[0]["total"]
assert total == 1500

result2 = df.agg(
    sum(col("value")).alias("total"),
    avg("value").alias("average"),
    count("value").alias("num_rows")
)

row = result2.collect()[0]
assert row["average"] == 300.0
assert row["num_rows"] == 5
print("✓ Koan 19 complete: Global aggregations")

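Koans 18 and 19 use columns without nulls, which hides an important difference. `count("*")` counts rows, whereas `count("value")`, `sum`, and `avg` all ignore null inputs. As a result, `avg` is not the same as dividing the total by the row count. The plain-Python model below sketches these rules on hypothetical values. It calls `builtins.sum` because the notebook's wildcard import shadows the built-in `sum`.

```python
import builtins  # the notebook's wildcard import shadows the built-in sum

# Plain-Python model of how Spark aggregates treat NULL (None) inputs.
values = [100, None, 300]

non_null = [v for v in values if v is not None]
n_rows = len(values)                # count("*") counts every row
n_values = len(non_null)            # count("value") skips NULLs
total = builtins.sum(non_null)      # sum ignores NULLs
average = total / n_values          # avg divides by the non-NULL count

print(n_rows, n_values, total, average)  # 3 2 400 200.0
```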
## Joins

In [None]:
# Koan 20: Inner Join
# Solution

employees = spark.createDataFrame([
    (1, "Alice", 101),
    (2, "Bob", 102),
    (3, "Charlie", 101)
], ["emp_id", "name", "dept_id"])

departments = spark.createDataFrame([
    (101, "Engineering"),
    (102, "Sales"),
    (103, "Marketing")
], ["dept_id", "dept_name"])

result = employees.join(departments, "dept_id", "inner")

assert result.count() == 3
assert "name" in result.columns
assert "dept_name" in result.columns
alice = result.filter(col("name") == "Alice").collect()[0]
assert alice["dept_name"] == "Engineering"
print("✓ Koan 20 complete: Inner joins")

In [None]:
# Koan 21: Left Outer Join
# Solution

employees = spark.createDataFrame([
    (1, "Alice", 101),
    (2, "Bob", 102),
    (3, "Charlie", 999)
], ["emp_id", "name", "dept_id"])

departments = spark.createDataFrame([
    (101, "Engineering"),
    (102, "Sales")
], ["dept_id", "dept_name"])

result = employees.join(departments, "dept_id", "left")

assert result.count() == 3
charlie = result.filter(col("name") == "Charlie").collect()[0]
assert charlie["dept_name"] is None
print("✓ Koan 21 complete: Left outer joins")

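A left join emits every left row at least once. A left row with no match gets nulls in the right-side columns, and a key that matches several right rows produces one output row per match. The plain-Python model below sketches both behaviors. It uses Koan 21's data plus a hypothetical second department row for `dept_id` 101, added only to show the one-to-many fan-out. The column order in the output tuples is the sketch's choice; with `on="dept_id"`, Spark places the join key first.

```python
# Plain-Python model of a left outer join on a single key column.
employees = [(1, "Alice", 101), (2, "Bob", 102), (3, "Charlie", 999)]
# Hypothetical extra row for dept 101, added to show one-to-many fan-out.
departments = [(101, "Engineering"), (102, "Sales"), (101, "Platform")]


def left_join(left, right):
    """Return (emp_id, name, dept_id, dept_name); dept_name is None when unmatched."""
    by_key = {}
    for dept_id, dept_name in right:
        by_key.setdefault(dept_id, []).append(dept_name)
    out = []
    for emp_id, name, dept_id in left:
        for dept_name in by_key.get(dept_id, [None]):
            out.append((emp_id, name, dept_id, dept_name))
    return out


for row in left_join(employees, departments):
    print(row)
```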
In [None]:
# Koan 22: Join on Multiple Columns
# Solution

orders = spark.createDataFrame([
    ("2024", "Q1", "Alice", 100),
    ("2024", "Q2", "Alice", 150),
    ("2024", "Q1", "Bob", 200)
], ["year", "quarter", "rep", "amount"])

targets = spark.createDataFrame([
    ("2024", "Q1", 120),
    ("2024", "Q2", 140)
], ["year", "quarter", "target"])

result = orders.join(targets, ["year", "quarter"], "inner")

assert result.count() == 3
alice_q1 = result.filter((col("rep") == "Alice") & (col("quarter") == "Q1")).collect()[0]
assert alice_q1["target"] == 120
print("✓ Koan 22 complete: Multi-column joins")

## Window Functions

In [None]:
# Koan 23: Window Functions - Running Total
# Solution

data = [
    ("2024-01-01", 100),
    ("2024-01-02", 150),
    ("2024-01-03", 200),
    ("2024-01-04", 175)
]
df = spark.createDataFrame(data, ["date", "sales"])

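# No partitionBy here, so Spark moves every row into a single partition
# (and logs a warning); fine for a small demo, costly on large data.
window_spec = Window.orderBy("date").rowsBetween(Window.unboundedPreceding, Window.currentRow)
result = df.withColumn("running_total", sum(col("sales")).over(window_spec))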

rows = result.orderBy("date").collect()
assert rows[0]["running_total"] == 100
assert rows[1]["running_total"] == 250
assert rows[3]["running_total"] == 625
print("✓ Koan 23 complete: Window running totals")

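The explicit `rowsBetween` frame matters when the ordering column has ties. If a window has `orderBy` but no explicit frame, Spark uses a range frame, `rangeBetween(Window.unboundedPreceding, Window.currentRow)`. Under a range frame, rows that share an ordering value are peers and all receive the same running total. Under a ROWS frame, each tied row gets its own cumulative value, although the order among tied rows is not guaranteed.

The model below shows the difference using hypothetical data with a repeated date. It mirrors the semantics only, not Spark's execution.

```python
import builtins  # the notebook's wildcard import shadows sum

# Hypothetical data: two sales on 2024-01-02 (a tie in the ordering column).
sales = [("2024-01-01", 100), ("2024-01-02", 150), ("2024-01-02", 50), ("2024-01-03", 200)]
sales.sort(key=lambda r: r[0])  # stable sort keeps tied rows in input order


def running_total_rows(rows):
    """ROWS frame: each row adds only itself and the rows before it."""
    out, acc = [], 0
    for _, amount in rows:
        acc += amount
        out.append(acc)
    return out


def running_total_range(rows):
    """RANGE frame: each row includes every peer with the same ordering value."""
    return [builtins.sum(a for d, a in rows if d <= date) for date, _ in rows]


print(running_total_rows(sales))   # [100, 250, 300, 500]
print(running_total_range(sales))  # [100, 300, 300, 500]
```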
In [None]:
# Koan 24: Window Functions - Row Number
# Solution

data = [
    ("Sales", "Alice", 5000),
    ("Sales", "Bob", 5500),
    ("Engineering", "Charlie", 6000),
    ("Engineering", "Diana", 6500),
    ("Engineering", "Eve", 5500)
]
df = spark.createDataFrame(data, ["dept", "name", "salary"])

window_spec = Window.partitionBy("dept").orderBy(col("salary").desc())
result = df.withColumn("rank", row_number().over(window_spec))

eng = result.filter(col("dept") == "Engineering").orderBy("rank").collect()
assert eng[0]["name"] == "Diana"
assert eng[0]["rank"] == 1
assert eng[1]["name"] == "Charlie"
print("✓ Koan 24 complete: Row number window functions")

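The three ranking functions treat ties differently:

- `row_number` always assigns distinct numbers, so when two rows tie on the ordering column, the tie is broken arbitrarily.
- `rank` gives tied rows the same number and leaves a gap after them.
- `dense_rank` gives tied rows the same number with no gap.

The plain-Python model below computes all three for one partition, using hypothetical data in which Diana and a made-up Frank tie at 6500.

```python
# Plain-Python model of row_number, rank and dense_rank over one partition,
# ordered by salary descending. Hypothetical data: Diana and Frank tie at 6500.
salaries = [("Diana", 6500), ("Frank", 6500), ("Charlie", 6000), ("Eve", 5500)]
ordered = sorted(salaries, key=lambda r: r[1], reverse=True)  # stable: ties keep input order

ranked = []
rank_value = dense_value = 0
previous = object()  # sentinel that never equals a salary
for position, (name, salary) in enumerate(ordered, start=1):
    if salary != previous:
        rank_value = position   # rank jumps to the row's position after a tie
        dense_value += 1        # dense_rank only ever advances by one
        previous = salary
    ranked.append((name, position, rank_value, dense_value))

for row in ranked:
    print(row)  # (name, row_number, rank, dense_rank)
```

In PySpark these are `row_number()`, `rank()` and `dense_rank()`, each applied with `.over(window_spec)`.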
In [None]:
# Koan 25: Window Functions - Lag and Lead
# Solution

data = [
    ("2024-01-01", 100),
    ("2024-01-02", 150),
    ("2024-01-03", 120),
    ("2024-01-04", 200)
]
df = spark.createDataFrame(data, ["date", "price"])

window_spec = Window.orderBy("date")

result = df.withColumn("prev_price", lag("price", 1).over(window_spec))
result = result.withColumn("change", col("price") - col("prev_price"))

rows = result.orderBy("date").collect()
assert rows[0]["prev_price"] is None
assert rows[1]["prev_price"] == 100
assert rows[1]["change"] == 50

result2 = df.withColumn("next_price", lead("price", 1).over(window_spec))
rows2 = result2.orderBy("date").collect()
assert rows2[0]["next_price"] == 150
print("✓ Koan 25 complete: Lag and lead window functions")

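`lag` and `lead` return null when the offset points outside the window, which is why the first `prev_price` is `None`. Both functions also accept a third `default` argument that is returned instead of null. The plain-Python model below sketches `lag` with that argument; `lead` is the mirror image. `py_lag` is a hypothetical helper.

```python
# Plain-Python model of lag(col, offset, default) over an already-ordered list.

def py_lag(values, offset=1, default=None):
    """Value `offset` rows earlier, or `default` when that row does not exist."""
    return [values[i - offset] if i - offset >= 0 else default for i in range(len(values))]


prices = [100, 150, 120, 200]  # Koan 25's prices, ordered by date
print(py_lag(prices))              # [None, 100, 150, 120]
print(py_lag(prices, default=0))   # [0, 100, 150, 120]
```

The PySpark counterpart of the second call is `lag("price", 1, 0).over(window_spec)`.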
## Null Handling

In [None]:
# Koan 26: Handling Nulls - Detection
# Solution

data = [("Alice", 34), ("Bob", None), ("Charlie", 29), (None, 45)]
df = spark.createDataFrame(data, ["name", "age"])

result = df.filter(col("age").isNotNull())
assert result.count() == 3

nulls = df.filter(col("age").isNull())
assert nulls.count() == 1

null_names = df.filter(col("name").isNull())
assert null_names.count() == 1
print("✓ Koan 26 complete: Null detection")

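Nulls need `isNull()` and `isNotNull()` because SQL equality with null is neither true nor false. `col("age") == None` evaluates to null for every row, and `filter` keeps only rows whose condition is true, so that filter matches nothing. `eqNullSafe` (SQL `<=>`) instead treats two nulls as equal. The plain-Python model below sketches both comparisons, with `None` standing in for null.

```python
# Plain-Python model of SQL equality (=) versus null-safe equality (<=>).
# None stands in for NULL; a result of None means "unknown".

def sql_equals(a, b):
    return None if a is None or b is None else a == b


def null_safe_equals(a, b):
    if a is None or b is None:
        return a is None and b is None
    return a == b


ages = [34, None, 29, 45]  # Koan 26's ages
# filter() keeps a row only when its condition is exactly True.
print([a for a in ages if sql_equals(a, None) is True])        # []
print([a for a in ages if null_safe_equals(a, None) is True])  # [None]
```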
In [None]:
# Koan 27: Handling Nulls - Fill and Drop
# Solution

data = [("Alice", 34), ("Bob", None), (None, 29), ("Diana", None)]
df = spark.createDataFrame(data, ["name", "age"])

result = df.fillna(0, subset=["age"])
ages = [row["age"] for row in result.collect()]
assert None not in ages
assert ages.count(0) == 2

result2 = df.fillna("Unknown", subset=["name"])
names = [row["name"] for row in result2.collect()]
assert "Unknown" in names

result3 = df.dropna()
assert result3.count() == 1
print("✓ Koan 27 complete: Null handling")

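`dropna()` defaults to `how="any"`, which drops a row if any column is null. That is why only Alice survives. The other options work as follows:

- `how="all"` drops a row only when every column is null.
- `thresh=n` keeps rows with at least `n` non-null values, and it takes precedence over `how`.

The plain-Python model below sketches all three options, using Koan 27's data plus a hypothetical all-null row. `py_dropna` is an illustrative helper, not a library function.

```python
# Plain-Python model of DataFrame.dropna(how=..., thresh=...) on tuples.
# Koan 27's rows plus a hypothetical all-NULL row to make how="all" visible.
rows = [("Alice", 34), ("Bob", None), (None, 29), ("Diana", None), (None, None)]


def py_dropna(rows, how="any", thresh=None):
    kept = []
    for row in rows:
        non_null = len([v for v in row if v is not None])
        if thresh is not None:
            keep = non_null >= thresh        # thresh overrides how
        elif how == "any":
            keep = non_null == len(row)      # drop if any value is NULL
        elif how == "all":
            keep = non_null > 0              # drop only rows that are entirely NULL
        else:
            raise ValueError("how must be 'any' or 'all'")
        if keep:
            kept.append(row)
    return kept


print(len(py_dropna(rows)))             # 1
print(len(py_dropna(rows, how="all")))  # 4
print(len(py_dropna(rows, thresh=2)))   # 1
```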
## Advanced

In [None]:
# Koan 28: Union DataFrames
# Solution

df1 = spark.createDataFrame([("Alice", 34), ("Bob", 45)], ["name", "age"])
df2 = spark.createDataFrame([("Charlie", 29), ("Diana", 52)], ["name", "age"])

result = df1.union(df2)

assert result.count() == 4
names = [row["name"] for row in result.collect()]
assert "Alice" in names and "Charlie" in names
print("✓ Koan 28 complete: Union DataFrames")

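`union` matches columns by position, not by name. If two DataFrames list their columns in a different order, rows from the second one end up under the wrong headings, or the union fails if the types cannot be reconciled. `unionByName` aligns columns by name instead. The plain-Python model below shows the difference, using a hypothetical second source whose columns are reversed.

```python
# Plain-Python model of union (by position) versus unionByName (by name).
left_cols, left_rows = ["name", "age"], [("Alice", 34)]
# Hypothetical second source with the same columns in a different order.
right_cols, right_rows = ["age", "name"], [(29, "Charlie")]


def union_by_position(rows_a, rows_b):
    return rows_a + rows_b  # output keeps rows_a's column names


def union_by_name(cols_a, rows_a, cols_b, rows_b):
    order = [cols_b.index(c) for c in cols_a]  # reorder rows_b to match cols_a
    return rows_a + [tuple(row[i] for i in order) for row in rows_b]


print(union_by_position(left_rows, right_rows))                     # [('Alice', 34), (29, 'Charlie')]
print(union_by_name(left_cols, left_rows, right_cols, right_rows))  # [('Alice', 34), ('Charlie', 29)]
```

In PySpark, the by-name version is `df1.unionByName(df2)`. Since Spark 3.1, it also accepts `allowMissingColumns=True`, which fills absent columns with nulls.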
In [None]:
# Koan 29: Explode Arrays
# Solution

data = [("Alice", "python,sql,spark"), ("Bob", "java,scala")]
df = spark.createDataFrame(data, ["name", "skills_str"])

df = df.withColumn("skills", split(col("skills_str"), ","))
result = df.select("name", explode(col("skills")).alias("skill"))

assert result.count() == 5
alice_skills = [row["skill"] for row in result.filter(col("name") == "Alice").collect()]
assert len(alice_skills) == 3
assert "spark" in alice_skills
print("✓ Koan 29 complete: Explode arrays")

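`explode` emits one row per array element, so a row whose array is empty or null disappears from the output. `explode_outer` keeps such rows and gives them a null element. The plain-Python model below sketches both behaviors, adding a hypothetical third person, Carol, who has no skills. `py_explode` is an illustrative helper.

```python
# Plain-Python model of explode vs explode_outer on (name, list) pairs.
people = [("Alice", ["python", "sql", "spark"]), ("Bob", ["java", "scala"]), ("Carol", [])]


def py_explode(rows, outer=False):
    out = []
    for name, items in rows:
        if items:  # non-empty, non-NULL array: one row per element
            out.extend((name, item) for item in items)
        elif outer:  # explode_outer keeps the row with a NULL element
            out.append((name, None))
    return out


print(len(py_explode(people)))             # 5
print(py_explode(people, outer=True)[-1])  # ('Carol', None)
```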
In [None]:
# Koan 30: Pivot Tables
# Solution

data = [
    ("Alice", "Q1", 100), ("Alice", "Q2", 150),
    ("Bob", "Q1", 200), ("Bob", "Q2", 180)
]
df = spark.createDataFrame(data, ["name", "quarter", "sales"])

result = df.groupBy("name").pivot("quarter").agg(sum(col("sales")))

assert "Q1" in result.columns
assert "Q2" in result.columns
alice = result.filter(col("name") == "Alice").collect()[0]
assert alice["Q1"] == 100
assert alice["Q2"] == 150
print("✓ Koan 30 complete: Pivot tables")

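`pivot` turns each distinct value of the pivot column into an output column. Any (row, column) pair with no input rows becomes null.

You can also pass the expected values, as in `pivot("quarter", ["Q1", "Q2"])`. This fixes the column list and spares Spark the extra pass it would otherwise need to discover the distinct values.

The plain-Python model below sketches the aggregation, using a hypothetical Carol who has Q1 sales only. It calls `builtins.sum` because the notebook's wildcard import shadows `sum`.

```python
import builtins  # the notebook's wildcard import shadows sum

# Plain-Python model of groupBy("name").pivot("quarter").agg(sum("sales")).
sales = [("Alice", "Q1", 100), ("Alice", "Q2", 150), ("Bob", "Q1", 200),
         ("Bob", "Q2", 180), ("Carol", "Q1", 90)]  # Carol is hypothetical


def py_pivot_sum(rows, values=None):
    # Without explicit values, discover the distinct pivot values first.
    values = values or sorted({quarter for _, quarter, _ in rows})
    table = {}
    for name, quarter, amount in rows:
        cell = table.setdefault(name, {})
        cell.setdefault(quarter, []).append(amount)
    return {
        name: {v: builtins.sum(cells[v]) if v in cells else None for v in values}
        for name, cells in table.items()
    }


print(py_pivot_sum(sales)["Carol"])  # {'Q1': 90, 'Q2': None}
```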
## Summary

Congratulations! You've completed the core PySpark koans, covering:
- DataFrame basics (create, select, filter, add/drop columns)
- Aggregations and grouping
- String functions and type casting
- Joins and window functions
- Null handling
- Advanced operations (union, explode, pivot)

Additional koans for Delta Lake (101-110), Unity Catalog (201-210), and Pandas API on Spark (301-310) require specialized environments and are documented in the full solutions notebook.