In [12]:
# Collab config. (Pyspark needs JAVA)
!apt-get install openjdk-8-jdk -y
!pip install pyspark

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
openjdk-8-jdk is already the newest version (8u442-b06~us1-0ubuntu1~22.04).
0 upgraded, 0 newly installed, 0 to remove and 29 not upgraded.


In [25]:
import pandas as pd
from sklearn import datasets
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf

# Init spark session.
spark = SparkSession.builder \
    .appName("Wine dataset with Spark") \
    .getOrCreate()

In [14]:
ds = datasets.load_wine(as_frame=True)
# Parse dataset into spark.
df = spark.createDataFrame(ds.frame)

# Check total number of rows and columns in the dataset.
print(f"Total number of samples is '{df.count()}' and there are '{len(df.columns)}' columns.", end="\n\n")

# Obtain a brief look of the dataset.
df.show()
df.printSchema()

# (Optional) See dataset description.
# print(f"Dataset contextual information:\n {ds.DESCR}")

Total number of samples is '178' and there are '14' columns.

+-------+----------+----+-----------------+---------+-------------+----------+--------------------+---------------+---------------+----+----------------------------+-------+------+
|alcohol|malic_acid| ash|alcalinity_of_ash|magnesium|total_phenols|flavanoids|nonflavanoid_phenols|proanthocyanins|color_intensity| hue|od280/od315_of_diluted_wines|proline|target|
+-------+----------+----+-----------------+---------+-------------+----------+--------------------+---------------+---------------+----+----------------------------+-------+------+
|  14.23|      1.71|2.43|             15.6|    127.0|          2.8|      3.06|                0.28|           2.29|           5.64|1.04|                        3.92| 1065.0|     0|
|   13.2|      1.78|2.14|             11.2|    100.0|         2.65|      2.76|                0.26|           1.28|           4.38|1.05|                         3.4| 1050.0|     0|
|  13.16|      2.36|2.67|        

In [15]:
# Clean loaded data (if needed).
print(f"There are {df.drop_duplicates().count()-df.count()} columns duplicated and therefore removed from the dataframe.")
print(f"There were {df.dropna().count()-df.count()} columns empty and therefore removed from the dataframe.")

There are 0 columns duplicated and therefore removed from the dataframe.
There were 0 columns empty and therefore removed from the dataframe.


In [16]:
# General details of dataframe and its columns.
df.select("*").describe().show(vertical=True)

-RECORD 0-------------------------------------------
 summary                      | count               
 alcohol                      | 178                 
 malic_acid                   | 178                 
 ash                          | 178                 
 alcalinity_of_ash            | 178                 
 magnesium                    | 178                 
 total_phenols                | 178                 
 flavanoids                   | 178                 
 nonflavanoid_phenols         | 178                 
 proanthocyanins              | 178                 
 color_intensity              | 178                 
 hue                          | 178                 
 od280/od315_of_diluted_wines | 178                 
 proline                      | 178                 
 target                       | 178                 
-RECORD 1-------------------------------------------
 summary                      | mean                
 alcohol                      | 13.00061797752

In [29]:
# Obtain first column.
print(df.take(1))

# Obtain all ocurrences.
# df.collect()

[Row(alcohol=14.23, malic_acid=1.71, ash=2.43, alcalinity_of_ash=15.6, magnesium=127.0, total_phenols=2.8, flavanoids=3.06, nonflavanoid_phenols=0.28, proanthocyanins=2.29, color_intensity=5.64, hue=1.04, od280/od315_of_diluted_wines=3.92, proline=1065.0, target=0)]


In [18]:
# Print a determined column.
df.select(df.alcohol).show()

+-------+
|alcohol|
+-------+
|  14.23|
|   13.2|
|  13.16|
|  14.37|
|  13.24|
|   14.2|
|  14.39|
|  14.06|
|  14.83|
|  13.86|
|   14.1|
|  14.12|
|  13.75|
|  14.75|
|  14.38|
|  13.63|
|   14.3|
|  13.83|
|  14.19|
|  13.64|
+-------+
only showing top 20 rows



In [19]:
# Print table with column condition.
df.filter(df.alcohol > 13).show()

+-------+----------+----+-----------------+---------+-------------+----------+--------------------+---------------+---------------+----+----------------------------+-------+------+
|alcohol|malic_acid| ash|alcalinity_of_ash|magnesium|total_phenols|flavanoids|nonflavanoid_phenols|proanthocyanins|color_intensity| hue|od280/od315_of_diluted_wines|proline|target|
+-------+----------+----+-----------------+---------+-------------+----------+--------------------+---------------+---------------+----+----------------------------+-------+------+
|  14.23|      1.71|2.43|             15.6|    127.0|          2.8|      3.06|                0.28|           2.29|           5.64|1.04|                        3.92| 1065.0|     0|
|   13.2|      1.78|2.14|             11.2|    100.0|         2.65|      2.76|                0.26|           1.28|           4.38|1.05|                         3.4| 1050.0|     0|
|  13.16|      2.36|2.67|             18.6|    101.0|          2.8|      3.24|                 

In [20]:
# Collapse columns from its unique values.
df.groupby("target").avg().show(vertical=True)

-RECORD 0------------------------------------------------
 target                            | 0                   
 avg(alcohol)                      | 13.744745762711865  
 avg(malic_acid)                   | 2.0106779661016954  
 avg(ash)                          | 2.455593220338984   
 avg(alcalinity_of_ash)            | 17.037288135593222  
 avg(magnesium)                    | 106.33898305084746  
 avg(total_phenols)                | 2.8401694915254234  
 avg(flavanoids)                   | 2.982372881355932   
 avg(nonflavanoid_phenols)         | 0.29                
 avg(proanthocyanins)              | 1.8993220338983055  
 avg(color_intensity)              | 5.528305084745763   
 avg(hue)                          | 1.0620338983050848  
 avg(od280/od315_of_diluted_wines) | 3.1577966101694916  
 avg(proline)                      | 1115.7118644067796  
 avg(target)                       | 0.0                 
-RECORD 1------------------------------------------------
 target       

In [21]:
# Define a temporal view from current Spark Session.
df.createOrReplaceTempView("tableA")
spark.sql("SELECT count(*) from tableA").show()

+--------+
|count(1)|
+--------+
|     178|
+--------+



In [31]:
# Define custom pandas functions.
@pandas_udf("integer")
def add_one(s: pd.Series) -> pd.Series:
    return s + 1

# Make them accesible from Spark.
spark.udf.register("add_one", add_one)
spark.sql("SELECT DISTINCT add_one(target) FROM tableA").show()

+---------------+
|add_one(target)|
+---------------+
|              1|
|              2|
|              3|
+---------------+

