In [None]:
# Imports
#"Col" function allows operations on the columns of the pyspark dataframe
from pyspark.sql import SparkSession
from pyspark.sql.functions import col  
import pandas as pd
from pyspark.sql.functions import udf, max, col, expr
from pyspark.sql.window import Window

In [2]:
# Locate the local Spark installation and make it importable.
# NOTE(review): findspark.init() is normally run before the first `import pyspark`;
# in this notebook the pyspark imports sit in the cell above - confirm the intended cell order.
import findspark 
findspark.init()  

In [3]:
# Build (or reuse) the local Spark session used throughout the notebook
spark = SparkSession.builder.master("local").appName("PySpark project").getOrCreate()

In [4]:
# Print the version of the running Spark session
print(spark.version)

3.4.2


In [5]:

# Load the consumers CSV: the first row holds the column names and
# the column types are inferred from the data
consumers_csv = "../data/consumers.csv"
consumers = spark.read.csv(consumers_csv, header=True, inferSchema=True)

In [6]:
# Dataframe shape, printed as (row count, column count)
print(f"({consumers.count()}, {len(consumers.columns)})")

(138, 14)


In [7]:
# Number of partitions of the consumers dataframe, read through its underlying RDD
consumers.rdd.getNumPartitions()

1

In [8]:
# Display the first 10 rows of the consumers dataframe
consumers.show(10)

+-----------+---------------+---------------+-------+---------+-----------+------+--------------+---------------------+--------------+-----------+---+----------+------+
|Consumer_ID|           City|          State|Country| Latitude|  Longitude|Smoker|   Drink_Level|Transportation_Method|Marital_Status|   Children|Age|Occupation|Budget|
+-----------+---------------+---------------+-------+---------+-----------+------+--------------+---------------------+--------------+-----------+---+----------+------+
|      U1001|San Luis Potosi|San Luis Potosi| Mexico|22.139997|-100.978803|    No|    Abstemious|              On Foot|        Single|Independent| 23|   Student|Medium|
|      U1002|San Luis Potosi|San Luis Potosi| Mexico|22.150087|-100.983325|    No|    Abstemious|               Public|        Single|Independent| 22|   Student|   Low|
|      U1003|San Luis Potosi|San Luis Potosi| Mexico|22.119847|-100.946527|    No|Social Drinker|               Public|        Single|Independent| 23|   St

In [9]:

# Print each column's name, inferred type and nullability
consumers.printSchema()

root
 |-- Consumer_ID: string (nullable = true)
 |-- City: string (nullable = true)
 |-- State: string (nullable = true)
 |-- Country: string (nullable = true)
 |-- Latitude: double (nullable = true)
 |-- Longitude: double (nullable = true)
 |-- Smoker: string (nullable = true)
 |-- Drink_Level: string (nullable = true)
 |-- Transportation_Method: string (nullable = true)
 |-- Marital_Status: string (nullable = true)
 |-- Children: string (nullable = true)
 |-- Age: integer (nullable = true)
 |-- Occupation: string (nullable = true)
 |-- Budget: string (nullable = true)



In [10]:
# List the distinct values found in the City column
consumers.select('City').distinct().show()

+---------------+
|           City|
+---------------+
|     Cuernavaca|
|Ciudad Victoria|
|       Jiutepec|
|San Luis Potosi|
+---------------+



In [11]:
# Distinct ages, oldest first
consumers.select('Age').distinct().orderBy(col('Age').desc()).show()

+---+
|Age|
+---+
| 82|
| 72|
| 69|
| 60|
| 45|
| 43|
| 33|
| 31|
| 30|
| 29|
| 28|
| 27|
| 26|
| 25|
| 24|
| 23|
| 22|
| 21|
| 20|
| 19|
+---+
only showing top 20 rows



In [12]:
# Rename the "Smoker" column to "Smoking"; the `consumers` name is rebound to the renamed dataframe
consumers = consumers.withColumnRenamed("Smoker", "Smoking")

In [13]:
# Print the schema again to confirm the new column name
consumers.printSchema()

root
 |-- Consumer_ID: string (nullable = true)
 |-- City: string (nullable = true)
 |-- State: string (nullable = true)
 |-- Country: string (nullable = true)
 |-- Latitude: double (nullable = true)
 |-- Longitude: double (nullable = true)
 |-- Smoking: string (nullable = true)
 |-- Drink_Level: string (nullable = true)
 |-- Transportation_Method: string (nullable = true)
 |-- Marital_Status: string (nullable = true)
 |-- Children: string (nullable = true)
 |-- Age: integer (nullable = true)
 |-- Occupation: string (nullable = true)
 |-- Budget: string (nullable = true)



In [14]:
# Register the dataframe as the temporary view "consumers_temp" so that it can be queried with spark.sql()
consumers.createOrReplaceTempView("consumers_temp")

In [15]:
# Run a SQL query against the temp view: number of consumers per city
consumers_per_city = spark.sql("Select City as City, Count(*) AS Count from consumers_temp group by City")

# Show the query result
consumers_per_city.show()

+---------------+-----+
|           CITY|COUNT|
+---------------+-----+
|     Cuernavaca|   22|
|Ciudad Victoria|   25|
|       Jiutepec|    5|
|San Luis Potosi|   86|
+---------------+-----+



Sometimes you might want to use pandas methods on your dataset. In that case, you can convert your PySpark dataframe to a pandas dataframe.

In [16]:
# Convert the PySpark dataframe to a pandas dataframe with toPandas(), then preview its first 10 rows
consumers_pandas  = consumers.toPandas()
consumers_pandas.head(10)

Unnamed: 0,Consumer_ID,City,State,Country,Latitude,Longitude,Smoking,Drink_Level,Transportation_Method,Marital_Status,Children,Age,Occupation,Budget
0,U1001,San Luis Potosi,San Luis Potosi,Mexico,22.139997,-100.978803,No,Abstemious,On Foot,Single,Independent,23,Student,Medium
1,U1002,San Luis Potosi,San Luis Potosi,Mexico,22.150087,-100.983325,No,Abstemious,Public,Single,Independent,22,Student,Low
2,U1003,San Luis Potosi,San Luis Potosi,Mexico,22.119847,-100.946527,No,Social Drinker,Public,Single,Independent,23,Student,Low
3,U1004,Cuernavaca,Morelos,Mexico,18.867,-99.183,No,Abstemious,Public,Single,Independent,72,Employed,Medium
4,U1005,San Luis Potosi,San Luis Potosi,Mexico,22.183477,-100.959891,No,Abstemious,Public,Single,Independent,20,Student,Medium
5,U1006,San Luis Potosi,San Luis Potosi,Mexico,22.15,-100.983,Yes,Social Drinker,Car,Single,Independent,23,Student,Medium
6,U1007,San Luis Potosi,San Luis Potosi,Mexico,22.118464,-100.938256,No,Casual Drinker,Public,Single,Independent,23,Student,Low
7,U1008,San Luis Potosi,San Luis Potosi,Mexico,22.122989,-100.923811,No,Social Drinker,Public,Single,Independent,23,Student,Low
8,U1009,San Luis Potosi,San Luis Potosi,Mexico,22.159427,-100.990448,No,Abstemious,On Foot,Single,Kids,21,Student,Medium
9,U1010,San Luis Potosi,San Luis Potosi,Mexico,22.190889,-100.998669,No,Social Drinker,Car,Married,Kids,25,Student,Medium


The opposite is also possible:

In [17]:
# Build a PySpark dataframe from the pandas dataframe with createDataFrame(), then show its first 10 rows
consumers_pyspark = spark.createDataFrame(consumers_pandas)
consumers_pyspark.show(10)

+-----------+---------------+---------------+-------+---------+-----------+-------+--------------+---------------------+--------------+-----------+---+----------+------+
|Consumer_ID|           City|          State|Country| Latitude|  Longitude|Smoking|   Drink_Level|Transportation_Method|Marital_Status|   Children|Age|Occupation|Budget|
+-----------+---------------+---------------+-------+---------+-----------+-------+--------------+---------------------+--------------+-----------+---+----------+------+
|      U1001|San Luis Potosi|San Luis Potosi| Mexico|22.139997|-100.978803|     No|    Abstemious|              On Foot|        Single|Independent| 23|   Student|Medium|
|      U1002|San Luis Potosi|San Luis Potosi| Mexico|22.150087|-100.983325|     No|    Abstemious|               Public|        Single|Independent| 22|   Student|   Low|
|      U1003|San Luis Potosi|San Luis Potosi| Mexico|22.119847|-100.946527|     No|Social Drinker|               Public|        Single|Independent| 23

As an alternative to creating a temp view and writing "SELECT" statements, you can use the .select() method of a PySpark dataframe to keep only a subset of its columns.

In [18]:
# select() example: keep only a subset of the columns.
# show() only prints and returns None, so the selected dataframe is stored first
# and displayed in a separate step; consumers_subset stays a usable dataframe.
consumers_subset = consumers.select("Smoking", "Drink_Level")
consumers_subset.show(10)

+-------+--------------+
|Smoking|   Drink_Level|
+-------+--------------+
|     No|    Abstemious|
|     No|    Abstemious|
|     No|Social Drinker|
|     No|    Abstemious|
|     No|    Abstemious|
|    Yes|Social Drinker|
|     No|Casual Drinker|
|     No|Social Drinker|
|     No|    Abstemious|
|     No|Social Drinker|
+-------+--------------+
only showing top 10 rows



In [19]:
# Filter: keep only the consumers whose budget is 'Low' and show three of them
consumers.filter(col("Budget") == 'Low').show(3)

+-----------+---------------+---------------+-------+---------+-----------+-------+--------------+---------------------+--------------+-----------+---+----------+------+
|Consumer_ID|           City|          State|Country| Latitude|  Longitude|Smoking|   Drink_Level|Transportation_Method|Marital_Status|   Children|Age|Occupation|Budget|
+-----------+---------------+---------------+-------+---------+-----------+-------+--------------+---------------------+--------------+-----------+---+----------+------+
|      U1002|San Luis Potosi|San Luis Potosi| Mexico|22.150087|-100.983325|     No|    Abstemious|               Public|        Single|Independent| 22|   Student|   Low|
|      U1003|San Luis Potosi|San Luis Potosi| Mexico|22.119847|-100.946527|     No|Social Drinker|               Public|        Single|Independent| 23|   Student|   Low|
|      U1007|San Luis Potosi|San Luis Potosi| Mexico|22.118464|-100.938256|     No|Casual Drinker|               Public|        Single|Independent| 23

In [20]:
# Aggregations, e.g. sum, avg, count, min, max

# For example, find the youngest age among the people with a high budget
consumers.filter(consumers.Budget == 'High').agg({"Age": "min"}).show()

+--------+
|min(Age)|
+--------+
|      21|
+--------+



In [21]:
# Joins

# Load the restaurants CSV (header row, inferred column types) and preview it
restaurants_csv = "../data/restaurants.csv"
restaurants = spark.read.csv(restaurants_csv, header=True, inferSchema=True)

restaurants.show(5)

+-------------+------------------+---------------+---------------+-------+--------+----------+------------+---------------+---------------+-----+---------+------+-------+
|Restaurant_ID|              Name|           City|          State|Country|Zip_Code|  Latitude|   Longitude|Alcohol_Service|Smoking_Allowed|Price|Franchise|  Area|Parking|
+-------------+------------------+---------------+---------------+-------+--------+----------+------------+---------------+---------------+-----+---------+------+-------+
|       132560|Puesto de Gorditas|Ciudad Victoria|     Tamaulipas| Mexico|    null|23.7523041| -99.1669133|           None|            Yes|  Low|       No|  Open| Public|
|       132561|        Cafe Ambar|Ciudad Victoria|     Tamaulipas| Mexico|    null| 23.726819| -99.1265059|           None|             No|  Low|       No|Closed|   None|
|       132564|          Church's|Ciudad Victoria|     Tamaulipas| Mexico|    null|23.7309245| -99.1451848|           None|             No|  Low|

In [22]:
# Load the ratings CSV (header row, inferred column types) and preview it
ratings_csv = "../data/ratings.csv"
ratings = spark.read.csv(ratings_csv, header=True, inferSchema=True)

ratings.show(5)

+-----------+-------------+--------------+-----------+--------------+
|Consumer_ID|Restaurant_ID|Overall_Rating|Food_Rating|Service_Rating|
+-----------+-------------+--------------+-----------+--------------+
|      U1077|       135085|             2|          2|             2|
|      U1077|       135038|             2|          2|             1|
|      U1077|       132825|             2|          2|             2|
|      U1077|       135060|             1|          2|             2|
|      U1068|       135104|             1|          1|             2|
+-----------+-------------+--------------+-----------+--------------+
only showing top 5 rows



In [23]:
# Left-join consumers to their ratings, then to the rated restaurants.
# The same pattern applies to inner, right, etc. joins by changing `how`.
joined_df = (
    consumers
    .join(ratings, on=consumers['Consumer_ID'] == ratings['Consumer_ID'], how="left")
    .join(restaurants, on=ratings['Restaurant_ID'] == restaurants['Restaurant_ID'], how="left")
)

# Keep only the columns to present. Each one is referenced through its source
# dataframe so the reference is unambiguous, and Name is renamed on the spot with alias().
joined_df_selection = joined_df.select(
    restaurants["Name"].alias("Restaurant_Name"),
    ratings['Overall_Rating'],
    consumers['Consumer_ID'],
)
joined_df_selection.show(10)

+--------------------+--------------+-----------+
|     Restaurant_Name|Overall_Rating|Consumer_ID|
+--------------------+--------------+-----------+
|Restaurante Versa...|             1|      U1001|
|El Rincon De San ...|             2|      U1001|
|Restaurant El Mul...|             1|      U1001|
|Restaurante La Gr...|             1|      U1001|
|Restaurant De Mar...|             1|      U1001|
|Restaurant Los Co...|             1|      U1001|
|Tortas Locas Hipo...|             0|      U1001|
|     Puesto De Tacos|             2|      U1001|
|     Rincon Huasteco|             1|      U1001|
|Tortas Locas Hipo...|             1|      U1002|
+--------------------+--------------+-----------+
only showing top 10 rows



In [24]:
# Sort by restaurant name and print 20 rows; truncate=False shows the full column contents
joined_df_selection.orderBy("Restaurant_Name").show(20,truncate=False)

+-------------------------+--------------+-----------+
|Restaurant_Name          |Overall_Rating|Consumer_ID|
+-------------------------+--------------+-----------+
|Abondance Restaurante Bar|1             |U1008      |
|Abondance Restaurante Bar|1             |U1014      |
|Abondance Restaurante Bar|0             |U1018      |
|Abondance Restaurante Bar|0             |U1064      |
|Abondance Restaurante Bar|0             |U1069      |
|Abondance Restaurante Bar|1             |U1081      |
|Abondance Restaurante Bar|1             |U1088      |
|Abondance Restaurante Bar|0             |U1094      |
|Abondance Restaurante Bar|0             |U1105      |
|Abondance Restaurante Bar|1             |U1115      |
|Abondance Restaurante Bar|1             |U1124      |
|Abondance Restaurante Bar|0             |U1126      |
|Arrachela Grill          |0             |U1030      |
|Arrachela Grill          |1             |U1072      |
|Arrachela Grill          |2             |U1117      |
|Cabana Hu

In [25]:
# Write the selection as CSV, replacing any previous output ("overwrite" mode).
# Note that the output path is a folder of part files, not a single file.
joined_df_selection.write.mode('overwrite').csv("joined_df_selection_multiple.csv")

In [26]:
# Write the selection as CSV from a single partition - coalesce(1) - so the output folder holds one part file
joined_df_selection.coalesce(1).write.mode('overwrite').csv("joined_df_selection.csv")

In [27]:
# Pivoting: a data transformation technique that converts rows into columns.
# E.g. what if we want to pivot the above dataframe, keeping the restaurants in the rows and
# the distinct rating values in the columns, and count how many occurrences we had for each?

# Pass to pivot() the column you wish to pivot.

In [28]:
# One row per restaurant, one column per distinct Overall_Rating value; each cell holds the number of ratings
pivot_df = joined_df_selection.groupBy("Restaurant_Name").pivot("Overall_Rating").count()

In [29]:
# Show the pivoted counts ordered by restaurant name, without truncating long names
pivot_df.orderBy("Restaurant_Name").show(truncate=False)

+-------------------------------------+----+----+----+
|Restaurant_Name                      |0   |1   |2   |
+-------------------------------------+----+----+----+
|Abondance Restaurante Bar            |6   |6   |null|
|Arrachela Grill                      |1   |1   |1   |
|Cabana Huasteca                      |2   |3   |8   |
|Cafe Ambar                           |1   |3   |null|
|Cafe Chaires                         |3   |9   |3   |
|Cafe Punta Del Cielo                 |null|1   |5   |
|Cafeteria Cenidet                    |1   |4   |1   |
|Cafeteria Y Restaurant El Pacifico   |7   |9   |12  |
|Carl's Jr                            |null|4   |3   |
|Carnitas Mata                        |1   |3   |2   |
|Carnitas Mata  Calle 16 de Septiembre|3   |1   |null|
|Carnitas Mata Calle Emilio Portes Gil|1   |1   |3   |
|Carreton De Flautas Y Migadas        |3   |4   |1   |
|Cenaduria El Rincón De Tlaquepaque   |2   |2   |1   |
|Chaires                              |1   |1   |3   |
|Chilis Cu

In [30]:
# User Defined Functions (UDFs)

# Create a function to capitalize a string
def capitalize(x):
    """Return ``x`` converted to upper case.

    Despite its name, the function upper-cases the whole string (``str.upper``),
    it does not just capitalize the first letter.

    Args:
        x: The string to convert, or ``None``.

    Returns:
        The upper-cased string, or ``None`` when ``x`` is ``None``. Null column
        values reach a Python UDF as ``None``, so they are passed through
        instead of raising ``AttributeError``.
    """
    if x is None:
        return None
    return x.upper()

In [31]:
# Convert the Python function to a UDF.
# NOTE(review): no return type is passed to udf(), and because the function is wrapped in a
# lambda the resulting column is displayed as "<lambda>(Occupation)" - consider adding an
# alias when selecting it.
upper_udf = udf(lambda z: capitalize(z))

In [32]:
# Apply the UDF to the Occupation column (upper-casing it) and show the result next to Consumer_ID
consumers.select('Consumer_ID',upper_udf(col("Occupation"))).show(5)

+-----------+--------------------+
|Consumer_ID|<lambda>(Occupation)|
+-----------+--------------------+
|      U1001|             STUDENT|
|      U1002|             STUDENT|
|      U1003|             STUDENT|
|      U1004|            EMPLOYED|
|      U1005|             STUDENT|
+-----------+--------------------+
only showing top 5 rows



In [33]:
# Window functions - example with max
# Attach to every rating row the highest overall rating given by that same consumer.
# The SQL equivalent would be:
#   SELECT Consumer_ID, Overall_Rating,
#          MAX(Overall_Rating) OVER (PARTITION BY Consumer_ID) AS MaxRatingPerConsumer
#   FROM ratings
# Note: `max` here is pyspark.sql.functions.max (see the imports), not the Python builtin.
window_spec = Window.partitionBy("Consumer_ID")
window_df = ratings.withColumn(
    "MaxRatingPerConsumer",
    max("Overall_Rating").over(window_spec),
)
window_df.select("Consumer_ID", "Overall_Rating", "MaxRatingPerConsumer").show()

+-----------+--------------+--------------------+
|Consumer_ID|Overall_Rating|MaxRatingPerConsumer|
+-----------+--------------+--------------------+
|      U1001|             1|                   2|
|      U1001|             2|                   2|
|      U1001|             0|                   2|
|      U1001|             1|                   2|
|      U1001|             1|                   2|
|      U1001|             1|                   2|
|      U1001|             1|                   2|
|      U1001|             2|                   2|
|      U1001|             1|                   2|
|      U1002|             2|                   2|
|      U1002|             1|                   2|
|      U1002|             1|                   2|
|      U1002|             2|                   2|
|      U1002|             1|                   2|
|      U1002|             2|                   2|
|      U1002|             1|                   2|
|      U1002|             2|                   2|


In [None]:
# expr() lets you write SQL expressions, e.g. a "CASE WHEN" statement.
# Useful when you are most familiar with SQL.
#
# Every known budget level is mapped explicitly. A missing or unexpected Budget
# yields NULL rather than being silently reported as 'L' by a catch-all ELSE.
df = consumers.withColumn(
    "Budget_shortened",
    expr(
        "CASE WHEN Budget = 'High' THEN 'H' "
        "WHEN Budget = 'Medium' THEN 'M' "
        "WHEN Budget = 'Low' THEN 'L' END"
    ),
)

df.select('Consumer_ID','Budget', 'Budget_shortened').show(10)