# README
This is a collection of PySpark exercises, small problems that helped me on my journey as an ML engineer and data scientist. We used SQL and PySpark intensively, so knowing these little bits is important. I also include the accompanying datasets for you to play with.

## Datasets

In [3]:
import os 
os.makedirs('data', exist_ok=True)

# create transactions.csv
csv_content = """id,ts,user_id,item_id,qty,price
1,2025-09-13T09:15:00,101,1,2,19.99
2,2025-09-13T09:50:00,102,2,1,5.50
3,2025-09-13T10:05:00,101,1,3,19.99
4,2025-09-13T11:20:00,103,3,1,12.00
5,2025-09-13T12:05:00,102,4,5,3.00
6,2025-09-13T13:10:00,101,5,2,45.00
7,2025-09-13T16:45:00,103,2,0,5.50
8,2025-09-13T18:00:00,103,4,2,
9,2025-09-14T09:40:00,101,3,1,11.50
10,2025-09-14T10:15:00,102,1,1,19.99
11,2025-09-14T11:25:00,102,5,1,44.00
12,2025-09-14T12:30:00,103,1,2,20.49
13,2025-09-14T13:55:00,101,2,4,5.99
14,2025-09-14T15:05:00,102,4,3,3.50
15,2025-09-14T16:35:00,103,5,1,46.00
16,2025-09-13T07:05:00,104,2,2,5.60
17,2025-09-14T08:25:00,104,3,2,12.50
18,2025-09-14T19:10:00,104,1,1,21.00
19,2025-09-13T20:20:00,101,4,1,3.25
20,2025-09-14T21:45:00,102,2,2,5.40
"""

with open('data/transactions.csv', 'w') as f:
    f.write(csv_content)

# create users.csv 
csv_content = """user_id,city,signup_date,updated_at
101,Lima,2025-09-01,2025-09-10
102,Arequipa,2025-08-15,2025-09-09
103,Trujillo,2025-08-01,2025-09-11
104,Lima,2025-09-12,2025-09-12
102,Cusco,2025-08-15,2025-09-12
"""
with open('data/users.csv', 'w') as f:
    f.write(csv_content)

csv_content = """item_id,name,category,cost
1,Widget A,widgets,10.00
2,Gadget B,gadgets,3.00
3,Thing C,things,6.00
4,Accessory D,accessories,1.50
5,Premium E,premium,22.00
"""

with open('data/items.csv', 'w') as f:
    f.write(csv_content)

csv_content = """user_id,ts,search_query
101,2025-09-13T08:30:00,"Widget A discount"
102,2025-09-13T09:00:00," gadget  b  "
101,2025-09-13T10:00:00,"THing c blue"
103,2025-09-13T10:10:00,"premium e"
102,2025-09-14T09:00:00,"widget a bundle"
104,2025-09-14T10:00:00,"new accessories!"
103,2025-09-14T11:00:00,"gadget-b pro"
102,2025-09-14T12:00:00,"widget a"
"""

with open('data/searches.csv', 'w') as f:
    f.write(csv_content)

csv_content = """id,ts,user_id,item_id,qty,price
7,2025-09-13T16:45:00,103,2,1,5.50
8,2025-09-13T18:00:00,103,4,2,3.10
11,2025-09-14T11:25:00,102,5,1,45.00
"""

with open('data/corrections.csv', 'w') as f:
    f.write(csv_content)

## What if I want to practice pure SQL? 
After reading a CSV into a DataFrame `df`, you can run
```python
# register as a SQL view
df.createOrReplaceTempView("transactions")
```
where `transactions` is an example name for the temporary view. Temporary means it exists only in memory, for the current Spark session. You can now query it like any SQL table, e.g.:
```python
spark.sql("""
    SELECT user_id, SUM(qty * price) AS total_revenue
    FROM transactions
    GROUP BY user_id
    ORDER BY total_revenue DESC
    LIMIT 5
""").show()
```

## Setup before running the PySpark env config cell on macOS

```bash
# 1. Create and activate a Conda environment
conda create -n sparkenv python=3.10
conda activate sparkenv

# 2. Install Java (JDK) inside the environment
conda install -c conda-forge openjdk

# 3. Install PySpark
conda install -c conda-forge pyspark
# or
pip install pyspark

# 4. Install Jupyter and register the kernel
conda install jupyter ipykernel
python -m ipykernel install --user --name=sparkenv

# 5. Launch Jupyter Notebook/Lab and select the `sparkenv` kernel
jupyter notebook
```

In [None]:
# import os, sys, subprocess

# # Point Spark to the JDK that conda installed in this very env
# os.environ["JAVA_HOME"] = sys.prefix
# os.environ["PATH"] = f"{sys.prefix}/bin:" + os.environ.get("PATH", "")

# # Make Spark use this exact Python (the interpreter of the kernel you are running)
# os.environ["PYSPARK_PYTHON"] = sys.executable
# os.environ["PYSPARK_DRIVER_PYTHON"] = sys.executable

# # (Optional) Helps on some macOS setups
# os.environ.setdefault("SPARK_LOCAL_IP", "127.0.0.1")

# # Sanity check
# print("JAVA_HOME =", os.environ["JAVA_HOME"])
# print(subprocess.check_output([f"{os.environ['JAVA_HOME']}/bin/java","-version"], stderr=subprocess.STDOUT).decode())


JAVA_HOME = /home/kenyi/.pyenv/versions/3.11.8/envs/general


FileNotFoundError: [Errno 2] No such file or directory: '/home/kenyi/.pyenv/versions/3.11.8/envs/general/bin/java'

## Setup for Windows 

Install JDK in WSL
```
sudo apt update
sudo apt install openjdk-17-jdk -y
```

In [4]:
import os
import subprocess

# Point Spark to the JDK that apt installed in WSL, then sanity-check it
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-17-openjdk-amd64"
os.environ["PATH"] = f"{os.environ['JAVA_HOME']}/bin:" + os.environ.get("PATH", "")
print(subprocess.check_output([f"{os.environ['JAVA_HOME']}/bin/java", "-version"], stderr=subprocess.STDOUT).decode())

openjdk version "17.0.16" 2025-07-15
OpenJDK Runtime Environment (build 17.0.16+8-Ubuntu-0ubuntu124.04.1)
OpenJDK 64-Bit Server VM (build 17.0.16+8-Ubuntu-0ubuntu124.04.1, mixed mode, sharing)



The pip installation below takes about 5 minutes (the PySpark download is roughly 434 MB).

In [None]:
# %pip install pyspark


Collecting pyspark
  Downloading pyspark-4.0.1.tar.gz (434.2 MB)
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Collecting py4j==0.10.9.9 (from pyspark)
  Downloading py4j-0.10.9.9-py2.py3-none-any.whl.metadata (1.3 kB)
Downloading py4j-0.10.9.9-py2.py3-none-any.whl (203 kB)
Building wheels for collected packages: pyspark
  Building wheel for pyspark (pyproject.toml) ... done
  Created wheel for pyspark: filename=pyspark-4.0.1-py2.py3-none-any.whl size=434813860 sha256=4a269a0f2753afc282cedf98b59d43c9130fba1847c3d94f2305ef02824904e6
  Stored in directory: /home/kenyi/.

## Create PySpark Client 


In [7]:
from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    .appName("MySession")
    .master("local[*]")   # local mode
    .getOrCreate()
)

Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/09/17 19:49:14 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [8]:
from IPython.display import display as ipy_display

def display(sdf, n=20):
    """Databricks-like display for Spark DataFrames in Jupyter."""
    return ipy_display(sdf.limit(n).toPandas())


# Exercises

### 1) 
Read & Inspect DataFrames

Task: Load transactions.csv with header + inferred schema; print schema; show first 5; report row/column counts.

Dataset: transactions.csv


In [9]:
trx_path = 'data/transactions.csv'
items_path = 'data/items.csv'
users_path = 'data/users.csv'
corr_path = 'data/corrections.csv'
searches_path = 'data/searches.csv'

# 1 read the csv (no inferSchema here, so every column comes back as a string)
df = spark.read.csv(trx_path, header=True)
display(df)

                                                                                

Unnamed: 0,id,ts,user_id,item_id,qty,price
0,1,2025-09-13T09:15:00,101,1,2,19.99
1,2,2025-09-13T09:50:00,102,2,1,5.5
2,3,2025-09-13T10:05:00,101,1,3,19.99
3,4,2025-09-13T11:20:00,103,3,1,12.0
4,5,2025-09-13T12:05:00,102,4,5,3.0
5,6,2025-09-13T13:10:00,101,5,2,45.0
6,7,2025-09-13T16:45:00,103,2,0,5.5
7,8,2025-09-13T18:00:00,103,4,2,
8,9,2025-09-14T09:40:00,101,3,1,11.5
9,10,2025-09-14T10:15:00,102,1,1,19.99


In [10]:
df.printSchema()

root
 |-- id: string (nullable = true)
 |-- ts: string (nullable = true)
 |-- user_id: string (nullable = true)
 |-- item_id: string (nullable = true)
 |-- qty: string (nullable = true)
 |-- price: string (nullable = true)

In [11]:
df.show(5)

+---+-------------------+-------+-------+---+-----+
| id|                 ts|user_id|item_id|qty|price|
+---+-------------------+-------+-------+---+-----+
|  1|2025-09-13T09:15:00|    101|      1|  2|19.99|
|  2|2025-09-13T09:50:00|    102|      2|  1| 5.50|
|  3|2025-09-13T10:05:00|    101|      1|  3|19.99|
|  4|2025-09-13T11:20:00|    103|      3|  1|12.00|
|  5|2025-09-13T12:05:00|    102|      4|  5| 3.00|
+---+-------------------+-------+-------+---+-----+
only showing top 5 rows


In [12]:
print('number of rows is = ',df.count())
print('number of columns is = ',len(df.columns))

number of rows is =  20
number of columns is =  6


### 2)
Basic Column Ops

Task: Create amount = qty * price, rounded to 2 decimals; select (id, user_id, item_id, amount).

Dataset: transactions.csv

APIs: withColumn, expr/col, round, select

In [13]:
display(df,5)
df.printSchema()

Unnamed: 0,id,ts,user_id,item_id,qty,price
0,1,2025-09-13T09:15:00,101,1,2,19.99
1,2,2025-09-13T09:50:00,102,2,1,5.5
2,3,2025-09-13T10:05:00,101,1,3,19.99
3,4,2025-09-13T11:20:00,103,3,1,12.0
4,5,2025-09-13T12:05:00,102,4,5,3.0


root
 |-- id: string (nullable = true)
 |-- ts: string (nullable = true)
 |-- user_id: string (nullable = true)
 |-- item_id: string (nullable = true)
 |-- qty: string (nullable = true)
 |-- price: string (nullable = true)


In [14]:
from pyspark.sql import functions as F
# The columns are still strings here, so cast before multiplying.
# F.round keeps amount numeric; F.format_number would return a formatted string instead.
df_2 = (
    df.withColumn('amount', F.round(F.col('qty').cast('double') * F.col('price').cast('double'), 2))
      .select('id', 'user_id', 'item_id', 'amount')
)
display(df_2,5)

Unnamed: 0,id,user_id,item_id,amount
0,1,101,1,39.98
1,2,102,2,5.5
2,3,101,1,59.97
3,4,103,3,12.0
4,5,102,4,15.0


### 3) 
Filtering & Null Handling

Task: Remove qty <= 0; compute median price; fill null price with median; report how many rows changed.

Dataset: transactions.csv

APIs: filter/where, percentile_approx, na.fill

In [15]:
display(df,10)

Unnamed: 0,id,ts,user_id,item_id,qty,price
0,1,2025-09-13T09:15:00,101,1,2,19.99
1,2,2025-09-13T09:50:00,102,2,1,5.5
2,3,2025-09-13T10:05:00,101,1,3,19.99
3,4,2025-09-13T11:20:00,103,3,1,12.0
4,5,2025-09-13T12:05:00,102,4,5,3.0
5,6,2025-09-13T13:10:00,101,5,2,45.0
6,7,2025-09-13T16:45:00,103,2,0,5.5
7,8,2025-09-13T18:00:00,103,4,2,
8,9,2025-09-14T09:40:00,101,3,1,11.5
9,10,2025-09-14T10:15:00,102,1,1,19.99


In [16]:
# filter out rows with qty <= 0
df_x = df.filter('qty>0')
display(df_x)
# compute median
median = df.select(F.median('price'))
display(median,10)
print(type(median))
# fill null prices with median
df_x = df_x.fillna({'price':median.collect()[0][0]})
display(df_x,10)

Unnamed: 0,id,ts,user_id,item_id,qty,price
0,1,2025-09-13T09:15:00,101,1,2,19.99
1,2,2025-09-13T09:50:00,102,2,1,5.5
2,3,2025-09-13T10:05:00,101,1,3,19.99
3,4,2025-09-13T11:20:00,103,3,1,12.0
4,5,2025-09-13T12:05:00,102,4,5,3.0
5,6,2025-09-13T13:10:00,101,5,2,45.0
6,8,2025-09-13T18:00:00,103,4,2,
7,9,2025-09-14T09:40:00,101,3,1,11.5
8,10,2025-09-14T10:15:00,102,1,1,19.99
9,11,2025-09-14T11:25:00,102,5,1,44.0


Unnamed: 0,median(price)
0,12.0


<class 'pyspark.sql.classic.dataframe.DataFrame'>


Unnamed: 0,id,ts,user_id,item_id,qty,price
0,1,2025-09-13T09:15:00,101,1,2,19.99
1,2,2025-09-13T09:50:00,102,2,1,5.5
2,3,2025-09-13T10:05:00,101,1,3,19.99
3,4,2025-09-13T11:20:00,103,3,1,12.0
4,5,2025-09-13T12:05:00,102,4,5,3.0
5,6,2025-09-13T13:10:00,101,5,2,45.0
6,8,2025-09-13T18:00:00,103,4,2,12.0
7,9,2025-09-14T09:40:00,101,3,1,11.5
8,10,2025-09-14T10:15:00,102,1,1,19.99
9,11,2025-09-14T11:25:00,102,5,1,44.0


In [17]:
# rows changed by the fill = rows that survive the qty filter and had a null price
print('number of changed rows is ',df.filter('qty > 0 AND price IS NULL').count())

number of changed rows is  1


### 4) 
GroupBy Aggregations

Task: Compute total revenue = sum(qty*price) and total qty per item_id; sort by revenue desc; top 10.

Dataset: transactions.csv

APIs: groupBy, agg, sum, orderBy, limit

In [18]:
df = spark.read.csv(trx_path,header=True, inferSchema=True)
total_revenue = df.agg(F.sum(F.col('price')*F.col('qty')).alias('total_revenue'))
display(total_revenue)

Unnamed: 0,total_revenue
0,490.63


In [19]:
display(df,20)

Unnamed: 0,id,ts,user_id,item_id,qty,price
0,1,2025-09-13 09:15:00,101,1,2,19.99
1,2,2025-09-13 09:50:00,102,2,1,5.5
2,3,2025-09-13 10:05:00,101,1,3,19.99
3,4,2025-09-13 11:20:00,103,3,1,12.0
4,5,2025-09-13 12:05:00,102,4,5,3.0
5,6,2025-09-13 13:10:00,101,5,2,45.0
6,7,2025-09-13 16:45:00,103,2,0,5.5
7,8,2025-09-13 18:00:00,103,4,2,
8,9,2025-09-14 09:40:00,101,3,1,11.5
9,10,2025-09-14 10:15:00,102,1,1,19.99


In [20]:
# total revenue and total qty per item_id
df_x = df.groupBy('item_id').agg(
    F.sum(F.col('qty')*F.col('price')).alias('total_revenue'),
    F.sum('qty').alias('total_qty'),
).orderBy('total_revenue',ascending=False).limit(10)
display(df_x)

                                                                                

Unnamed: 0,item_id,total_revenue,total_qty
0,1,181.92,9
1,5,180.0,4
2,2,51.46,9
3,3,48.5,4
4,4,28.75,11


### 5) 

Window Functions (Ranking)

Task: Within each user_id, rank purchases by amount desc; keep top 3 per user.

Dataset: transactions.csv

APIs: Window.partitionBy, row_number/dense_rank, orderBy, filter

In [21]:
display(df,5)

Unnamed: 0,id,ts,user_id,item_id,qty,price
0,1,2025-09-13 09:15:00,101,1,2,19.99
1,2,2025-09-13 09:50:00,102,2,1,5.5
2,3,2025-09-13 10:05:00,101,1,3,19.99
3,4,2025-09-13 11:20:00,103,3,1,12.0
4,5,2025-09-13 12:05:00,102,4,5,3.0


In [22]:
from pyspark.sql import Window
from pyspark.sql import functions as F

# rank by amount = qty * price (not by qty alone), highest first
df_amt = df.withColumn('amount', F.round(F.col('qty') * F.col('price'), 2))
win = Window.partitionBy('user_id').orderBy(F.col('amount').desc())

df_x = df_amt.withColumn('rank_purchase_by_amount',F.row_number().over(win)).filter(F.col('rank_purchase_by_amount')<=3)

display(df_x)

                                                                                

Unnamed: 0,id,ts,user_id,item_id,qty,price,amount,rank_purchase_by_amount
0,6,2025-09-13 13:10:00,101,5,2,45.0,90.0,1
1,3,2025-09-13 10:05:00,101,1,3,19.99,59.97,2
2,1,2025-09-13 09:15:00,101,1,2,19.99,39.98,3
3,11,2025-09-14 11:25:00,102,5,1,44.0,44.0,1
4,10,2025-09-14 10:15:00,102,1,1,19.99,19.99,2
5,5,2025-09-13 12:05:00,102,4,5,3.0,15.0,3
6,15,2025-09-14 16:35:00,103,5,1,46.0,46.0,1
7,12,2025-09-14 12:30:00,103,1,2,20.49,40.98,2
8,4,2025-09-13 11:20:00,103,3,1,12.0,12.0,3
9,17,2025-09-14 08:25:00,104,3,2,12.5,25.0,1


### 6)
Time Bucketing

Task: Cast ts to timestamp; create date and hour; compute hourly revenue per day.

Dataset: transactions.csv

APIs: to_timestamp, date_trunc, hour, groupBy, agg

### 7)

Joins & Deduplication

Task: On users.csv, keep the latest row per user_id by updated_at; join to transactions and report rows lost/gained vs inner/left join.

Datasets: transactions.csv, users.csv

APIs: Window + row_number, dropDuplicates (alt), join

### 8) 

Complex Aggregations (Rollups)

Task: Revenue by (city, date) plus city totals and grand total; label rollup levels using grouping_id.

Datasets: transactions.csv + deduped users.csv

APIs: rollup/cube, grouping_id, agg, orderBy

### 9) 

String Ops & Regex

Task: Normalize search_query (lower, trim, collapse spaces); split tokens; keep rows matching regex for “gadget” variants (gadget[- ]?b).

Dataset: searches.csv

APIs: lower, trim, regexp_replace, split, rlike

### 10) 

UDF vs Built-ins

Task: Bucketize amount into ["low","mid","high"] via a Python UDF; then reimplement with when/otherwise; compare execution time.

Dataset: transactions.csv

APIs: udf, when/otherwise, cache, simple timing (time or spark.time)

### 11)

Pandas UDF (Vectorized)

Task: Per user_id, compute count, mean amount, and coefficient of variation using a Pandas grouped apply; return a Spark DF.

Dataset: transactions.csv

APIs: pandas_udf or applyInPandas, schema

### 12)

Reading/Writing Efficiently

Task: Add date column; write partitioned Parquet by date; read just one date and prove partition pruning with explain.

Dataset: cleaned transactions

APIs: write.partitionBy, parquet, spark.read.parquet, explain

### 13) 

Skew & Broadcast

Task: Join large transactions to small items; show plan and timing before/after broadcast(items).

Datasets: transactions.csv, items.csv

APIs: broadcast, join, spark.conf.set("spark.sql.autoBroadcastJoinThreshold", ...)

### 14)

Caching & Checkpointing

Task: Build an expensive aggregation; time it uncached vs cached; set checkpoint dir, checkpoint the DF, and show lineage reduction in explain.

Dataset: transactions.csv

APIs: cache/persist, count (materialize), setCheckpointDir, checkpoint, explain

### 15)

SQL + Temp Views

Task: Create temp views; in SQL compute per-date: DAU (distinct user_id), conversion proxy (% users with amount>0), and ARPU (revenue/DAU).

Datasets: transactions.csv (+ optionally users.csv)

APIs: createOrReplaceTempView, spark.sql