In [9]:
# Standard libraries
from pathlib import Path
import urllib.request as ureq
from os import unlink as os_unlink
from tempfile import NamedTemporaryFile

# Third-party libraries
import numpy as np


# --- PySparklibraries --- #
from pyspark.sql import SparkSession
from pyspark.sql.functions import ltrim, rtrim
from pyspark.ml.linalg import DenseMatrix, Vectors

from pyspark.ml.stat import ChiSquareTest, Correlation
from pyspark.mllib.feature import StandardScaler, VectorAssembler
from pyspark.mllib.regression import LinearRegression, LogisticRegressionWithSGD

from pyspark.sql.dataframe import DataFrame # For type assertion.

In [6]:
APP_NAME: str = "census_income"
DATA_SUBDIR: Path = Path().joinpath(r"data")
DATA_URL: str = "https://github.com/PacktPublishing/PySpark-Cookbook/raw/master/Data/census_income.csv"


In [7]:
spark = SparkSession.builder.appName(APP_NAME).getOrCreate()

In [10]:
def create_csv_dataframe(url: str) -> DataFrame:
    """Create and return pyspark DataFrame object from raw csv file.
    
    Parameters
    ----------
    url : str
        Path to remote, internet-hosted data file.
    verbose : bool
        Whether or not to send additional info to stdout.
        
    Returns
    -------
    filepath : str
        Path to local temporary file.
    DataFrame
        pyspark.sql.dataframe.DataFrame object.
    """
    if not DATA_SUBDIR.exists():
        DATA_SUBDIR.mkdir()
    
    with ureq.urlopen(url) as resp:
        tmp = resp.read().decode("utf-8")
        if tmp:
            # Get data into temp file.
            tempf = NamedTemporaryFile(mode="w", encoding="utf-8", dir = DATA_SUBDIR, delete = False)
            tempf.write(tmp)
            tempf.seek(0)            
            tempf.close()
            
            # Create dataframe object and check type.
            df_ = spark.read.csv(tempf.name, inferSchema = True, header = True)
            assert (type(df_) == DataFrame), "Data object type error."
            
            return df_

dataset = create_csv_dataframe(DATA_URL)

# https://spark.apache.org/docs/3.1.2/api/python/reference/api/pyspark.sql.DataFrame.html#pyspark.sql.DataFrame
if dataset:
    dataset.printSchema()
    print(dataset.columns)

root
 |-- age: integer (nullable = true)
 |-- workclass: string (nullable = true)
 |-- fnlwgt: integer (nullable = true)
 |-- education: string (nullable = true)
 |-- education-num: integer (nullable = true)
 |-- marital-status: string (nullable = true)
 |-- occupation: string (nullable = true)
 |-- relationship: string (nullable = true)
 |-- race: string (nullable = true)
 |-- sex: string (nullable = true)
 |-- capital-gain: integer (nullable = true)
 |-- capital-loss: integer (nullable = true)
 |-- hours-per-week: integer (nullable = true)
 |-- native-country: string (nullable = true)
 |-- label: string (nullable = true)

['age', 'workclass', 'fnlwgt', 'education', 'education-num', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'capital-gain', 'capital-loss', 'hours-per-week', 'native-country', 'label']


In [11]:
for col, _type in dataset.dtypes:
    if _type == "string":
        dataset = dataset.withColumn(col, ltrim(rtrim(dataset[col])))

# Print row count.
dataset.count()

32561

In [12]:
dataset.show(10)

+---+----------------+------+---------+-------------+--------------------+-----------------+-------------+-----+------+------------+------------+--------------+--------------+-----+
|age|       workclass|fnlwgt|education|education-num|      marital-status|       occupation| relationship| race|   sex|capital-gain|capital-loss|hours-per-week|native-country|label|
+---+----------------+------+---------+-------------+--------------------+-----------------+-------------+-----+------+------------+------------+--------------+--------------+-----+
| 39|       State-gov| 77516|Bachelors|           13|       Never-married|     Adm-clerical|Not-in-family|White|  Male|        2174|           0|            40| United-States|<=50K|
| 50|Self-emp-not-inc| 83311|Bachelors|           13|  Married-civ-spouse|  Exec-managerial|      Husband|White|  Male|           0|           0|            13| United-States|<=50K|
| 38|         Private|215646|  HS-grad|            9|            Divorced|Handlers-cleaner

In [18]:
# Distinct values
uniq_values: dict = {}
for col, _type in dataset.dtypes:
    if _type == "string":
        uniq_values[col] = []
        
        uniq_values[col] = sorted(dataset.select(col).distinct())

if len(uniq_values) > 0:
    print("\n".join([f"{k}:  {v}" for k, v in uniq_values.items()]))

workclass:  [Column<'workclass'>]
education:  [Column<'education'>]
marital-status:  [Column<'marital-status'>]
occupation:  [Column<'occupation'>]
relationship:  [Column<'relationship'>]
race:  [Column<'race'>]
sex:  [Column<'sex'>]
native-country:  [Column<'native-country'>]
label:  [Column<'label'>]


In [51]:
# abc = dataset.dropDuplicates(["sex"]).select("sex")
# [r.sex for r in abc]
tmp_table_name: str = "census"
dataset.createOrReplaceTempView(tmp_table_name)

_df = spark.sql("""
SELECT sex
FROM census
GROUP BY sex
""")

[r.sex for r in _df.collect()]

# spark.catalog.dropTempView(tmp_table_name)

['Female', 'Male']

In [None]:
# Split data
train_pct = 0.75
df_train, df_test = final_df.randomSplit([train_pct, 1-train_pct])

In [13]:
# Vectorized field
dependent_variable: str = "label"
features_col: str = "features"

independent_variables = [i for i in dataset.columns if not i in ("Email", "Address", dependent_variable)]

feature_assmblr = VectorAssembler(
    inputCols = independent_variables,
    outputCol = features_col
)

pyspark.sql.dataframe.DataFrame

In [None]:
# Clear out dataset when done.
spark.catalog.dropTempView(tmp_table_name)

for f in DATA_SUBDIR.rglob("*"):
    _path = Path(f)
    if _path.is_file():
        _path.unlink()