In [14]:
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler, StringIndexer
from pyspark.ml.pipeline import Pipeline
from pyspark.ml.classification import RandomForestClassifier, OneVsRest
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [2]:
# Reuse the active SparkSession if there is one; otherwise start a new one named 'One vs Rest'.
spark = (
    SparkSession.builder
    .appName('One vs Rest')
    .getOrCreate()
)

In [3]:
# Display the active SparkSession object as the cell output.
spark

In [4]:
# Input CSV, given as a path relative to the notebook's working directory.
file_path = 'data/bank-full.csv'

In [5]:
# Load the raw data; the printed schema below shows which columns were inferred as integer.
df = spark.read.csv(
    file_path,
    sep=';',           # fields are separated by semicolons, not commas
    header=True,       # first row holds the column names
    inferSchema=True,  # let Spark type the columns instead of reading everything as string
)

In [6]:
# Inspect column names and inferred types of the raw frame.
df.printSchema()

root
 |-- age: integer (nullable = true)
 |-- job: string (nullable = true)
 |-- marital: string (nullable = true)
 |-- education: string (nullable = true)
 |-- default: string (nullable = true)
 |-- balance: integer (nullable = true)
 |-- housing: string (nullable = true)
 |-- loan: string (nullable = true)
 |-- contact: string (nullable = true)
 |-- day: integer (nullable = true)
 |-- month: string (nullable = true)
 |-- duration: integer (nullable = true)
 |-- campaign: integer (nullable = true)
 |-- pdays: integer (nullable = true)
 |-- previous: integer (nullable = true)
 |-- poutcome: string (nullable = true)
 |-- y: string (nullable = true)



In [7]:
def vector_assemble(df, features_list, target, features_col='features', label_col='label'):
    """Pack feature columns into one vector column and index the target column.

    Fits a two-stage Pipeline (VectorAssembler, then StringIndexer) on ``df``
    and returns the transformed frame restricted to the indexed label column,
    the assembled vector column and the original feature columns.

    Parameters
    ----------
    df : pyspark.sql.DataFrame
        Input data; must contain every column in ``features_list`` and ``target``.
    features_list : sequence of str
        Columns to pack into the vector column. Any sequence (list, tuple, ...)
        is accepted; it is copied, never modified.
    target : str
        Column whose values are indexed into numeric class labels.
    features_col : str, default 'features'
        Name of the assembled vector column.
    label_col : str, default 'label'
        Name of the indexed target column.

    Returns
    -------
    pyspark.sql.DataFrame
        Columns ``[label_col, features_col, *features_list]``, in that order.
    """
    # Copy to a list so tuples etc. work with the list concatenation below.
    feature_cols = list(features_list)

    assembler = VectorAssembler(inputCols=feature_cols, outputCol=features_col)
    # NOTE(review): label index order follows StringIndexer's default
    # stringOrderType — confirm in the Spark docs if specific ids matter.
    string_indexer = StringIndexer(inputCol=target, outputCol=label_col)

    pipeline_model = Pipeline(stages=[assembler, string_indexer]).fit(df)
    selected_cols = [label_col, features_col] + feature_cols
    return pipeline_model.transform(df).select(selected_cols)

In [11]:
target = 'education'
# Numeric predictors plus the target. `target` is the single source of truth for
# the label column, so changing it above changes the split below as well.
multiclass_df = df.select(['age', 'balance', 'day', 'duration',
                           'campaign', 'pdays', 'previous', target])

# Every selected column except the target is a feature.
features_list = [col for col in multiclass_df.columns if col != target]

In [12]:
# Assemble the feature vector and index the target on the reduced frame,
# which already holds exactly the feature columns plus the target.
df_final = vector_assemble(multiclass_df, features_list, target)

In [13]:
# Confirm the model-ready layout: indexed label, feature vector, then the raw feature columns.
df_final.printSchema()

root
 |-- label: double (nullable = false)
 |-- features: vector (nullable = true)
 |-- age: integer (nullable = true)
 |-- balance: integer (nullable = true)
 |-- day: integer (nullable = true)
 |-- duration: integer (nullable = true)
 |-- campaign: integer (nullable = true)
 |-- pdays: integer (nullable = true)
 |-- previous: integer (nullable = true)



In [15]:
# Generate an 80/20 train/test split. The fixed seed makes the split, and every
# metric computed from it, reproducible across re-runs (assuming the input
# partitioning is unchanged).

train, test = df_final.randomSplit([0.8, 0.2], seed=42)

In [16]:
# Instantiate the base binary classifier; OneVsRest fits one copy of it per class.
# The explicit seed pins the forest's random sampling instead of relying on the library default.
clf = RandomForestClassifier(featuresCol='features', labelCol='label', seed=42)
ovr = OneVsRest(classifier=clf, featuresCol='features', labelCol='label')

In [17]:
# Train the multiclass model: OneVsRest fits one binary classifier per label
# value on the training split only (the test split is held out).

ovr_model = ovr.fit(train)

In [31]:
# The fitted per-class binary models (each reports numClasses=2 in the output).
ovr_model.models

[RandomForestClassificationModel: uid=RandomForestClassifier_6932c979352a, numTrees=20, numClasses=2, numFeatures=7,
 RandomForestClassificationModel: uid=RandomForestClassifier_6932c979352a, numTrees=20, numClasses=2, numFeatures=7,
 RandomForestClassificationModel: uid=RandomForestClassifier_6932c979352a, numTrees=20, numClasses=2, numFeatures=7,
 RandomForestClassificationModel: uid=RandomForestClassifier_6932c979352a, numTrees=20, numClasses=2, numFeatures=7]