In [None]:
# imports
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, avg, count, sum, max, min, when, lit
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, FloatType

## Creating a SparkSession

The SparkSession is the entry point for all Spark functionality.

In [None]:
# Build the notebook's SparkSession through the builder, or pick up an
# existing one via getOrCreate(). The app name and the "local[*]" master
# are the values handed to the builder.
spark = (
    SparkSession.builder
    .appName("PySpark Introduction")
    .master("local[*]")
    .getOrCreate()
)

# Report which Spark version the session is running.
print(f"Spark version: {spark.version}")

## Loading Data

We'll work with a dataset of data science salaries.

In [None]:
# Read the salaries CSV via the reader's option API: the first line is
# treated as the header and column types are inferred from the data.
df = (
    spark.read
    .option("header", True)
    .option("inferSchema", True)
    .csv("Data/salaries.csv")
)

# Preview the first five rows.
df.show(5)

In [None]:
# check schema
# Print the schema of `df`, i.e. the columns produced by the CSV load above
# (which was run with inferSchema=True).
df.printSchema()

In [None]:
# Size of the dataset: number of rows and number of columns.
n_rows = df.count()
n_cols = len(df.columns)
print(f"Total rows: {n_rows}")
print(f"Columns: {n_cols}")

# Summary statistics table produced by describe().
df.describe().show()

## Basic Transformations

### Select and Filter

In [None]:
# Project three columns (passed as Column objects) and preview five rows.
df.select(
    col("job_title"),
    col("salary_in_usd"),
    col("experience_level"),
).show(5)

In [None]:
# Keep only the rows whose salary_in_usd is strictly above 150000.
high_salary = df.where(col("salary_in_usd") > 150000)
print(f"Jobs with salary > $150k: {high_salary.count()}")

# Preview five of the matching rows with their title, salary and location.
(
    high_salary
    .select("job_title", "salary_in_usd", "company_location")
    .show(5)
)

In [None]:
# Two conditions applied as successive filters: a row must have
# experience_level "SE" and remote_ratio 100 to be kept (the combination
# the message below reports as senior, fully remote).
senior_remote = (
    df
    .filter(col("experience_level") == "SE")
    .filter(col("remote_ratio") == 100)
)
print(f"Senior fully remote jobs: {senior_remote.count()}")

### Adding and Modifying Columns

In [None]:
# Derive a monthly figure by dividing the yearly USD salary by 12; the
# result is a new DataFrame, `df` itself is left as it was.
df_with_monthly = df.withColumn(
    "monthly_salary_usd",
    col("salary_in_usd") / 12,
)

# Show the yearly and monthly values side by side for five rows.
(
    df_with_monthly
    .select("job_title", "salary_in_usd", "monthly_salary_usd")
    .show(5)
)

In [None]:
# conditional column with when
# Bucket salary_in_usd into four bands. The branches are evaluated in order,
# so each later bound only applies to rows the earlier branches did not match.
#
# The top band is written as an explicit `>= 150000` branch and there is
# deliberately no `.otherwise(...)`: a row whose salary_in_usd is null matches
# none of the comparisons and therefore gets a null salary_category, instead
# of being reported as "Very High" by a catch-all default.
df_salary_level = df.withColumn(
    "salary_category",
    when(col("salary_in_usd") < 50000, "Low")
    .when(col("salary_in_usd") < 100000, "Medium")
    .when(col("salary_in_usd") < 150000, "High")
    .when(col("salary_in_usd") >= 150000, "Very High")
)
df_salary_level.select("job_title", "salary_in_usd", "salary_category").show(10)

## Aggregations

Grouping and summarizing data.

In [None]:
# Mean salary_in_usd per experience_level, highest average first.
(
    df.groupBy("experience_level")
    .agg(avg("salary_in_usd").alias("avg_salary"))
    .orderBy(col("avg_salary").desc())
    .show()
)

In [None]:
# Several statistics per experience_level in a single pass: row count plus
# average, minimum and maximum salary_in_usd, sorted by the average
# (descending). Note that count/min/max here are the pyspark.sql.functions
# versions imported at the top of the notebook, not the Python builtins.
(
    df.groupBy("experience_level")
    .agg(
        count("*").alias("count"),
        avg("salary_in_usd").alias("avg_salary"),
        min("salary_in_usd").alias("min_salary"),
        max("salary_in_usd").alias("max_salary"),
    )
    .sort(col("avg_salary").desc())
    .show()
)

In [None]:
# Number of rows per job_title, most frequent first; show the first ten.
# groupBy(...).count() names its result column "count", which is what the
# sort refers to.
(
    df.groupBy("job_title")
    .count()
    .sort(col("count").desc())
    .show(10)
)

## SQL Queries

Spark allows you to run SQL queries directly on DataFrames.

In [None]:
# register dataframe as a temp view
# The view name "salaries" is the table name the spark.sql() queries in the
# following cells select FROM.
df.createOrReplaceTempView("salaries")

In [None]:
# run sql query
# For rows with experience_level = 'SE': count the rows and average
# salary_in_usd (rounded to 2 decimals) per job_title, keep only titles with
# at least 5 rows, and return the 10 titles with the highest average salary.
result = spark.sql("""
    SELECT 
        job_title,
        COUNT(*) as job_count,
        ROUND(AVG(salary_in_usd), 2) as avg_salary
    FROM salaries
    WHERE experience_level = 'SE'
    GROUP BY job_title
    HAVING COUNT(*) >= 5
    ORDER BY avg_salary DESC
    LIMIT 10
""")
result.show()

In [None]:
# salary trends by year
# One output row per work_year, in ascending year order: the number of rows
# for that year and the average salary_in_usd rounded to 2 decimals.
yearly_trend = spark.sql("""
    SELECT 
        work_year,
        COUNT(*) as total_jobs,
        ROUND(AVG(salary_in_usd), 2) as avg_salary
    FROM salaries
    GROUP BY work_year
    ORDER BY work_year
""")
yearly_trend.show()

## Remote Work Analysis

In [None]:
# remote work distribution
# One output row per distinct remote_ratio value, with the row count and the
# average salary_in_usd (rounded to 2 decimals), largest group first.
#
# The CASE labels the three expected ratios. It also has an explicit branch
# for a missing ratio and an ELSE fallback, so a value other than 0/50/100
# shows up with a readable label (e.g. "25% remote") instead of a NULL
# work_type. Grouping is still by remote_ratio, so the number of groups and
# their counts/averages are the same as without the extra branches.
remote_analysis = spark.sql("""
    SELECT 
        CASE 
            WHEN remote_ratio IS NULL THEN 'Unknown'
            WHEN remote_ratio = 0 THEN 'On-site'
            WHEN remote_ratio = 50 THEN 'Hybrid'
            WHEN remote_ratio = 100 THEN 'Fully Remote'
            ELSE CONCAT(CAST(remote_ratio AS STRING), '% remote')
        END as work_type,
        COUNT(*) as count,
        ROUND(AVG(salary_in_usd), 2) as avg_salary
    FROM salaries
    GROUP BY remote_ratio
    ORDER BY count DESC
""")
remote_analysis.show()

## Saving Data

In [None]:
# Saving is only described here, not executed, so running the notebook
# creates no files. A real write of the per-title counts would look like:
# df.groupBy("job_title").count().write.csv("output/job_counts", header=True, mode="overwrite")
print(
    "To save: df.write.csv('path', header=True)",
    "Or parquet: df.write.parquet('path')",
    sep="\n",
)

## Cleanup

In [None]:
# stop spark session
# After this call the `spark` session is stopped, so cells above that use
# `spark` or `df` need the session cell re-run before they will work again.
spark.stop()
print("Spark session stopped")

## Summary

### Key Concepts Covered

- **SparkSession**: Entry point for Spark applications
- **DataFrames**: Distributed collections of data organized into named columns
- **Transformations**: select, filter, withColumn, when
- **Aggregations**: groupBy, agg, count, avg, min, max
- **SQL**: createOrReplaceTempView, spark.sql()

### When to Use Spark

- Data doesn't fit in memory on a single machine
- Need distributed processing across a cluster
- Working with big data pipelines (ETL)
- Integration with data lakes and warehouses