In [5]:
"""
Let’s write a Spark program that reads a file with over 100,000 entries (where each
row or line has a <state, mnm_color, count>) and computes and aggregates the
counts for each color and state. These aggregated counts tell us the colors of M&Ms
favored by students in each state.
"""

'\nLet’s write a Spark program that reads a file with over 100,000 entries (where each\nrow or line has a <state, mnm_color, count>) and computes and aggregates the\ncounts for each color and state. These aggregated counts tell us the colors of M&Ms\nfavored by students in each state.\n'

In [1]:
from pyspark.sql import SparkSession

In [2]:
spark = SparkSession.builder.appName('mnm_count').master('local[*]').getOrCreate()

In [6]:
#read the csv data into dataframe
raw_df = spark.read.csv("./data/mnm_dataset_for_count/mnm_dataset.csv", inferSchema= True, header= True)

#display the first 20 rows of the df
raw_df.show(20, truncate=False)

+-----+------+-----+
|State|Color |Count|
+-----+------+-----+
|TX   |Red   |20   |
|NV   |Blue  |66   |
|CO   |Blue  |79   |
|OR   |Blue  |71   |
|WA   |Yellow|93   |
|WY   |Blue  |16   |
|CA   |Yellow|53   |
|WA   |Green |60   |
|OR   |Green |71   |
|TX   |Green |68   |
|NV   |Green |59   |
|AZ   |Brown |95   |
|WA   |Yellow|20   |
|AZ   |Blue  |75   |
|OR   |Brown |72   |
|NV   |Red   |98   |
|WY   |Orange|45   |
|CO   |Blue  |52   |
|TX   |Brown |94   |
|CO   |Red   |82   |
+-----+------+-----+
only showing top 20 rows



In [18]:
# raw_df grouped by Color
grouped_df = raw_df.select("State", "Color", "Count").groupBy("State", "Color")

#workaround for seeing the grouped data since show is not supported by groupedData directly
grouped_df.pivot("State").count().show(20)

+-----+------+----+----+----+----+----+----+----+----+----+----+
|State| Color|  AZ|  CA|  CO|  NM|  NV|  OR|  TX|  UT|  WA|  WY|
+-----+------+----+----+----+----+----+----+----+----+----+----+
|   WY| Green|null|null|null|null|null|null|null|null|null|1695|
|   NV|   Red|null|null|null|null|1610|null|null|null|null|null|
|   UT|  Blue|null|null|null|null|null|null|null|1655|null|null|
|   WA|Orange|null|null|null|null|null|null|null|null|1658|null|
|   NM| Green|null|null|null|1682|null|null|null|null|null|null|
|   CA|  Blue|null|1603|null|null|null|null|null|null|null|null|
|   WA|   Red|null|null|null|null|null|null|null|null|1671|null|
|   NV| Brown|null|null|null|null|1657|null|null|null|null|null|
|   AZ| Green|1676|null|null|null|null|null|null|null|null|null|
|   CA|   Red|null|1656|null|null|null|null|null|null|null|null|
|   AZ|Orange|1689|null|null|null|null|null|null|null|null|null|
|   CO|  Blue|null|null|1695|null|null|null|null|null|null|null|
|   NM|Orange|null|null|n

In [25]:
from pyspark.sql.functions import count

grouped_df_with_count = grouped_df.agg(count("Count").alias("Total")).orderBy("Total", ascending= False)
grouped_df_with_count.show(100)

print("Total Rows : "+str(raw_df.count()))

+-----+------+-----+
|State| Color|Total|
+-----+------+-----+
|   CA|Yellow| 1807|
|   WA| Green| 1779|
|   OR|Orange| 1743|
|   TX| Green| 1737|
|   TX|   Red| 1725|
|   CA| Green| 1723|
|   CO|Yellow| 1721|
|   CA| Brown| 1718|
|   CO| Green| 1713|
|   NV|Orange| 1712|
|   TX|Yellow| 1703|
|   NV| Green| 1698|
|   AZ| Brown| 1698|
|   CO|  Blue| 1695|
|   WY| Green| 1695|
|   NM|   Red| 1690|
|   AZ|Orange| 1689|
|   NM|Yellow| 1688|
|   NM| Brown| 1687|
|   UT|Orange| 1684|
|   NM| Green| 1682|
|   UT|   Red| 1680|
|   AZ| Green| 1676|
|   NV|Yellow| 1675|
|   NV|  Blue| 1673|
|   WA|   Red| 1671|
|   WY|   Red| 1670|
|   WA| Brown| 1669|
|   NM|Orange| 1665|
|   WY|  Blue| 1664|
|   WA|Yellow| 1663|
|   WA|Orange| 1658|
|   CA|Orange| 1657|
|   NV| Brown| 1657|
|   CO| Brown| 1656|
|   CA|   Red| 1656|
|   UT|  Blue| 1655|
|   AZ|Yellow| 1654|
|   TX|Orange| 1652|
|   AZ|   Red| 1648|
|   OR|  Blue| 1646|
|   OR|   Red| 1645|
|   UT|Yellow| 1645|
|   CO|Orange| 1642|
|   TX| Brown

In [26]:
# While the above code aggregated and counted for all
# the states, what if we just want to see the data for
# a single state, e.g., CA?
# 1. Select from all rows in the DataFrame
# 2. Filter only CA state
# 3. groupBy() State and Color as we did above
# 4. Aggregate the counts for each color
# 5. orderBy() in descending order
# Find the aggregate count for California by filtering

In [32]:
ca_data = raw_df.select("State", "Color", "Count").where(raw_df.State== "CA").groupBy("State", "Color").agg(count("Count").alias("Total")).orderBy("Total", ascending = False).show(10, truncate = False)

+-----+------+-----+
|State|Color |Total|
+-----+------+-----+
|CA   |Yellow|1807 |
|CA   |Green |1723 |
|CA   |Brown |1718 |
|CA   |Orange|1657 |
|CA   |Red   |1656 |
|CA   |Blue  |1603 |
+-----+------+-----+



In [33]:
spark.stop()