# PySpark Playground



**Import libraries, create spark session**

In [5]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, avg, count, when

# create spark session
spark = SparkSession.builder.appName("pyspark_playground").getOrCreate()

**Load data, check schema and first few rows**

In [6]:
# load csv file with headers and schema inferred
df = spark.read.csv("data/data.csv", header=True, inferSchema=True)

# check schema like pandas df.info()
df.printSchema()

# show first 5 rows like pandas df.head()
df.show(5)

root
 |-- InvoiceNo: string (nullable = true)
 |-- StockCode: string (nullable = true)
 |-- Description: string (nullable = true)
 |-- Quantity: integer (nullable = true)
 |-- InvoiceDate: string (nullable = true)
 |-- UnitPrice: double (nullable = true)
 |-- CustomerID: integer (nullable = true)
 |-- Country: string (nullable = true)

+---------+---------+--------------------+--------+--------------+---------+----------+--------------+
|InvoiceNo|StockCode|         Description|Quantity|   InvoiceDate|UnitPrice|CustomerID|       Country|
+---------+---------+--------------------+--------+--------------+---------+----------+--------------+
|   536365|   85123A|WHITE HANGING HEA...|       6|12/1/2010 8:26|     2.55|     17850|United Kingdom|
|   536365|    71053| WHITE METAL LANTERN|       6|12/1/2010 8:26|     3.39|     17850|United Kingdom|
|   536365|   84406B|CREAM CUPID HEART...|       8|12/1/2010 8:26|     2.75|     17850|United Kingdom|
|   536365|   84029G|KNITTED UNION FLA...|  

**Performing operations**

In [7]:
# select specific columns
df.select("InvoiceNo", "Description").show(5)

+---------+--------------------+
|InvoiceNo|         Description|
+---------+--------------------+
|   536365|WHITE HANGING HEA...|
|   536365| WHITE METAL LANTERN|
|   536365|CREAM CUPID HEART...|
|   536365|KNITTED UNION FLA...|
|   536365|RED WOOLLY HOTTIE...|
+---------+--------------------+
only showing top 5 rows



In [8]:
# filter rows where Quantity > 100
filtered_df = df.filter(col("Quantity") > 100)
filtered_df.show(5)

+---------+---------+--------------------+--------+--------------+---------+----------+--------------+
|InvoiceNo|StockCode|         Description|Quantity|   InvoiceDate|UnitPrice|CustomerID|       Country|
+---------+---------+--------------------+--------+--------------+---------+----------+--------------+
|   536378|    21212|PACK OF 72 RETROS...|     120|12/1/2010 9:37|     0.42|     14688|United Kingdom|
|   536387|    79321|       CHILLI LIGHTS|     192|12/1/2010 9:58|     3.82|     16029|United Kingdom|
|   536387|    22780|LIGHT GARLAND BUT...|     192|12/1/2010 9:58|     3.37|     16029|United Kingdom|
|   536387|    22779|WOODEN OWLS LIGHT...|     192|12/1/2010 9:58|     3.37|     16029|United Kingdom|
|   536387|    22466|FAIRY TALE COTTAG...|     432|12/1/2010 9:58|     1.45|     16029|United Kingdom|
+---------+---------+--------------------+--------+--------------+---------+----------+--------------+
only showing top 5 rows



In [9]:
# group by Country and calculate avg for UnitPrice
grouped_df = df.groupBy("Country").agg(avg("UnitPrice").alias("average_unit_price"))
grouped_df.show()

+--------------------+------------------+
|             Country|average_unit_price|
+--------------------+------------------+
|              Sweden|3.9108874458874445|
|           Singapore|109.64580786026201|
|             Germany|3.9669299631384733|
|              France| 5.028864087881271|
|              Greece| 4.885547945205479|
|             Belgium|3.6443354277428774|
|             Finland| 5.448705035971225|
|               Italy| 4.831120797011212|
|                EIRE| 5.911077354807258|
|           Lithuania|2.8411428571428576|
|              Norway|  6.01202578268877|
|               Spain| 4.987544413738649|
|             Denmark|  3.25694087403599|
|           Hong Kong| 42.50520833333331|
|             Iceland| 2.644010989010989|
|              Israel| 3.633131313131316|
|     Channel Islands| 4.932124010554093|
|              Cyprus| 6.302363344051452|
|         Switzerland| 3.403441558441565|
|United Arab Emirates|3.3807352941176467|
+--------------------+------------

**Handle missing data**

In [None]:
# fill missing values in CustomerID with 0
df_filled = df.fillna({"CustomerID": 0})
df_filled.show(5)

**Adding new columns**

In [None]:
# create new column based on condition (is UnitPrice > 2.5)
df_with_new_col = df.withColumn("is_expensive", when(col("UnitPrice") > 3, True).otherwise(False))
df_with_new_col.show(5)