In [1]:
import sys
import os

# Spark launch configuration: these environment variables are read when Spark
# starts, so this cell must run before the SparkSession is created.
os.environ['JAVA_HOME'] = '/usr/lib/jvm/java-1.8.0-openjdk-amd64'
# Driver and executors must run the same Python as this kernel. Deriving the
# path from sys.executable avoids hard-coding one user's conda env location.
os.environ['PYSPARK_PYTHON'] = sys.executable
os.environ['PYSPARK_DRIVER_PYTHON'] = sys.executable

In [15]:
from pyspark.ml import feature
import pyspark 
from pyspark.sql import SparkSession
from sklearn.base import BaseEstimator, TransformerMixin

class ModelNotTrainedException(Exception):
    """Raised when predictions are requested before the model has been fitted."""

    def __init__(self, message):
        super().__init__(message)
        # Also exposed as an attribute for callers that read ``exc.message``.
        self.message = message

class SparkXGBClassifier(BaseEstimator, TransformerMixin):
    """scikit-learn-style wrapper around ``xgboost.spark.SparkXGBClassifier``.

    Every column of the training frame except ``keys`` and ``target`` is packed
    into a single ``features`` vector, which the Spark estimator consumes.

    Parameters
    ----------
    target : str
        Name of the label column.
    keys : list of str, optional
        Identifier columns excluded from the feature vector.
    **sparkxgbparams
        Forwarded unchanged to ``xgboost.spark.SparkXGBClassifier``.
    """

    def __init__(self, target: str, keys=None, **sparkxgbparams) -> None:
        # This class shares its name with xgboost's estimator, so the module-level
        # name can point at either class depending on cell execution order. A local,
        # aliased import guarantees we wrap xgboost's class instead of recursing.
        from xgboost.spark import SparkXGBClassifier as _XGBSparkClassifier

        self.target = target
        # Stored as passed (None instead of a shared mutable default list).
        self.keys = keys
        self.sdf = None            # training frame, set by fit()
        self.feature_cols = None   # feature column order, fixed at fit() time
        self.model = None          # fitted Spark model, set by fit()
        self.clf = _XGBSparkClassifier(
            **sparkxgbparams,
            label_col=target,
            features_col='features',
        )

    def _excluded_cols(self):
        """Columns never used as features: the key columns plus the target."""
        return [*(self.keys or []), self.target]

    def _vectorise(self, sdf, input_cols=None):
        """Return ``sdf`` with an added ``features`` vector column.

        ``input_cols`` defaults to every column except keys and target. At
        prediction time the training columns are passed so the vector layout
        matches what the model was fitted on. ``handleInvalid='keep'`` emits NaN
        for null/NaN inputs instead of failing the job.
        """
        if input_cols is None:
            input_cols = sdf.drop(*self._excluded_cols()).columns
        assembler = feature.VectorAssembler(
            inputCols=list(input_cols),
            outputCol='features',
            handleInvalid='keep',
        )
        return assembler.transform(sdf)

    def fit(self, sdf, y=None):
        """Train on ``sdf``, which must contain the ``target`` column.

        ``y`` is ignored; it is accepted only for scikit-learn API compatibility.

        Returns the fitted Spark model (also stored on ``self.model``).
        """
        self.sdf = sdf
        self.feature_cols = sdf.drop(*self._excluded_cols()).columns
        self.model = self.clf.fit(self._vectorise(sdf, self.feature_cols))
        return self.model

    def predict(self, sdf):
        """Score ``sdf`` with the fitted model.

        Returns ``sdf``'s own columns plus ``rawPrediction`` and ``probability``.

        Raises
        ------
        ModelNotTrainedException
            If called before ``fit()``.
        """
        if self.model is None:
            raise ModelNotTrainedException('fit() must be run before calling predict()')
        # Use the *fitted* model (not the unfitted estimator) on a frame vectorised
        # with the same feature columns, in the same order, as during training.
        scored = self.model.transform(self._vectorise(sdf, self.feature_cols))
        return scored.select(*sdf.columns, 'rawPrediction', 'probability')


# Reuse the active SparkSession if there is one, otherwise start a new one.
spark = (
    SparkSession.builder
    .getOrCreate()
)

In [16]:
from numpy import random
from pandas import DataFrame

# Synthetic demo data: three Gaussian numeric features, two binary features and
# an imbalanced binary target (~10% positives). Every column is drawn
# independently, so the target carries no signal from the features.
# A fixed seed makes this frame, and everything downstream, reproducible.
SEED = 42
N = 1000
rng = random.default_rng(SEED)
df = DataFrame({
    'num1': rng.normal(100, 10, size=N),
    'num2': rng.normal(100, 10, size=N),
    'num3': rng.normal(100, 10, size=N),
    'cat1': rng.choice([0, 1], size=N),
    'cat2': rng.choice([0, 1], size=N),
    'target': rng.choice([0, 1], p=[0.9, 0.1], size=N),
})

# Convert the pandas frame to a Spark DataFrame (schema inferred since none is given).
sdf = spark.createDataFrame(data=df)

  for column, series in pdf.iteritems():


In [4]:
# Preview the first 20 rows (PySpark's default row count for show()).
sdf.show(n=20)

+------------------+------------------+------------------+----+----+------+
|              num1|              num2|              num3|cat1|cat2|target|
+------------------+------------------+------------------+----+----+------+
|  98.2435509680153|104.93204127892011|121.68163854911032|   1|   1|     0|
| 90.38698745904239|111.62499023848352| 97.42047476171886|   1|   0|     0|
|102.06766565063468|105.48408648246144| 95.53154938287331|   1|   1|     0|
| 79.99136197520531|102.72658512411729| 85.72984330685046|   0|   0|     0|
| 94.52144416926258|106.63120413180721| 92.19556052650287|   0|   0|     0|
| 96.75326657976005|100.92777760247088| 86.78045752245558|   1|   0|     0|
| 97.89071928358211| 87.28374292509204| 88.98380826949743|   1|   0|     0|
|103.84709353154184|143.18551705249897|105.65693954097746|   0|   0|     0|
| 96.19520221212967| 83.08510925545953| 96.12402175702199|   1|   1|     0|
|107.65187942754865|  87.2800246590321| 93.04929914522802|   0|   1|     0|
| 92.1106008

In [5]:
from pyspark.ml import feature

# Pack every non-target column into one vector column, the input format Spark ML
# estimators expect; null/NaN inputs become NaN instead of failing the job.
vs = feature.VectorAssembler(
    outputCol='features',
    handleInvalid='keep',
    inputCols=sdf.drop('target').columns,
)
vec = vs.transform(sdf)

In [23]:
import xgboost
# Alias the import so it does not overwrite the SparkXGBClassifier wrapper class
# defined earlier in this notebook.
from xgboost.spark import SparkXGBClassifier as XGBSparkClassifier

param = {
  'max_depth': 8,
  'learning_rate': 0.3,
  'tree_method': 'hist',
  'num_parallel_tree': 8,
  'eval_metric': 'auc',
}

# NOTE(review): `use_gpu` is deprecated in newer xgboost releases in favour of
# `device`; confirm against the installed version.
xgb_estimator = XGBSparkClassifier(
    **param,
    features_col='features',
    label_col='target',
    num_workers=4,
    use_gpu=False,
    verbose=3
)
# `clf` is bound to the *fitted* model (later cells call `clf.transform`); the
# unfitted estimator keeps its own name so the two are not confused.
clf = xgb_estimator.fit(vec)

[08:22:08] task 3 got new rank 0                                    (0 + 4) / 4]
[08:22:08] task 2 got new rank 1
[08:22:08] task 1 got new rank 2
[08:22:08] task 0 got new rank 3


In [14]:
# NOTE: this scores the same frame the model was fitted on (in-sample), so the
# probabilities shown say nothing about how the model generalises.
predictions = clf.transform(vec)
predictions.show()

+------------------+------------------+------------------+----+----+------+--------------------+--------------------+----------+--------------------+
|              num1|              num2|              num3|cat1|cat2|target|            features|       rawPrediction|prediction|         probability|
+------------------+------------------+------------------+----+----+------+--------------------+--------------------+----------+--------------------+
|  98.2435509680153|104.93204127892011|121.68163854911032|   1|   1|     0|[98.2435509680153...|[4.24246597290039...|       0.0|[0.98583149909973...|
| 90.38698745904239|111.62499023848352| 97.42047476171886|   1|   0|     0|[90.3869874590423...|[3.14202928543090...|       0.0|[0.95859348773956...|
|102.06766565063468|105.48408648246144| 95.53154938287331|   1|   1|     0|[102.067665650634...|[5.42074489593505...|       0.0|[0.99559563398361...|
| 79.99136197520531|102.72658512411729| 85.72984330685046|   0|   0|     0|[79.9913619752053...|[4.0