In [None]:
import os
import sys

# Point PySpark at the cluster's Spark 2 client and the Python used by executors.
os.environ["PYSPARK_SUBMIT_ARGS"]='--conf spark.sql.catalogImplementation=in-memory pyspark-shell'
os.environ["PYSPARK_PYTHON"]='/opt/anaconda/envs/bd9/bin/python'
os.environ["SPARK_HOME"]='/usr/hdp/current/spark2-client'

spark_home = os.environ.get('SPARK_HOME', None)
if not spark_home:
    raise ValueError('SPARK_HOME environment variable is not set')
# Make the bundled pyspark and py4j importable from this kernel.
sys.path.insert(0, os.path.join(spark_home, 'python'))
sys.path.insert(0, os.path.join(spark_home, 'python/lib/py4j-0.10.7-src.zip'))

# Run PySpark's interactive shell bootstrap in this namespace. It provides the
# `spark` session used by later cells (it is not defined anywhere else here).
# `with` closes the file handle; compile() with the real path gives readable
# tracebacks if the bootstrap fails.
shell_path = os.path.join(spark_home, 'python/pyspark/shell.py')
with open(shell_path) as shell_file:
    exec(compile(shell_file.read(), shell_path, 'exec'))

In [None]:
from pyspark.ml import Estimator

In [None]:
from sklearn.linear_model import LogisticRegression

In [None]:
from sklearn.datasets import make_classification

In [None]:
from sklearn.model_selection import train_test_split

In [None]:
import pandas as pd

In [None]:
# Synthetic dataset with make_classification defaults; the fixed seed makes runs reproducible.
X, y = make_classification(random_state=5757)

In [None]:
y.shape

In [None]:
# Hold out 20% of the samples as a test set (same seed as above).
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=5757)

In [None]:
X_train.shape

In [None]:
# Plain scikit-learn model, trained in the notebook process (outside Spark).
est = LogisticRegression(random_state=5757)

In [None]:
est.fit(X_train, y_train)

In [None]:
from pyspark.ml.linalg import DenseVector

In [None]:
from pyspark.sql.types import *
from pyspark.ml.linalg import VectorUDT

## Способ №1, правильный, но не работает :(

In [None]:
# Test-set schema: an ML vector column with the features plus an integer label
# (both nullable, as with StructField's defaults).
schema = (
    StructType()
    .add("features", VectorUDT())
    .add("label", IntegerType())
)

In [None]:
# One row per test sample: the feature row wrapped in a DenseVector and the label as a Python int.
df_test = spark.createDataFrame(
    [(DenseVector(features), int(label)) for features, label in zip(X_test, y_test)],
    schema=schema,
)

In [None]:
df_test.show()

In [None]:
# Disabled attempt to cast the vector column to an array; presumably fails the
# same way as the explicit cast attempt further below.
#df_test.selectExpr("CAST(features AS ARRAY<DOUBLE>)").show()

In [None]:
# Broadcast the fitted sklearn model; the pandas UDFs below read it via est_broadcast.value.
est_broadcast = spark.sparkContext.broadcast(est)

In [None]:
import pyspark.sql.functions as F

In [None]:
@F.pandas_udf(FloatType())
def predict(series):
    """Scalar pandas UDF: predict one label per row with the broadcast sklearn model.

    Method #1: this is applied to the VectorUDT ``features`` column, which a
    pandas UDF cannot receive, so the call in the next cell fails.
    """
    model = est_broadcast.value
    return pd.Series(model.predict(series))

In [None]:
# Expected to fail: a VectorUDT column cannot be passed to a pandas UDF.
# The error is caught and shown so that "Restart & Run All" continues past
# this demonstration instead of aborting the notebook here.
try:
    df_test.withColumn("prediction", predict("features")).show()
except Exception as exc:
    # JVM errors carry long stack traces; the head is enough to see the cause.
    print("Expected failure:", type(exc).__name__, str(exc)[:500])

In [None]:
# Expected to fail: a VectorUDT column cannot be cast to an array type.
# The error is caught and shown so "Restart & Run All" does not stop here.
try:
    df_test.select(df_test.features.cast(ArrayType(FloatType()))).show()
except Exception as exc:
    print("Expected failure:", type(exc).__name__, str(exc)[:500])

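Обе ошибки выше возникают потому, что колонку типа `VectorUDT` нельзя ни передать в `pandas_udf`, ни привести к массиву через `cast`. Подробнее: https://issues.apache.org/jira/browse/SPARK-19653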

## Способ №2, неправильный, но работает :/

In [None]:
@F.udf(ArrayType(DoubleType()))
def vectorToArray(row):
    """Convert one ML vector (``row``) into a plain array of doubles.

    Array columns can be fed to pandas UDFs, unlike VectorUDT columns.
    ``toArray()`` works for both DenseVector and SparseVector (``tolist`` alone
    only exists on DenseVector). DoubleType keeps the float64 precision of the
    features; FloatType would silently round them to float32 before the model
    sees them. A null vector yields a null array instead of raising.
    """
    if row is None:
        return None
    return row.toArray().tolist()

In [None]:
# Add an array copy of the vector column so it can be passed to a pandas UDF.
df_test = df_test.withColumn("features_array", vectorToArray("features"))

In [None]:
df_test.show()

In [None]:
@F.pandas_udf(FloatType())
def predict(series):
    """Scalar pandas UDF (method #2): predict with the broadcast sklearn model.

    Redefines ``predict`` from method #1 and expects an array column such as
    ``features_array``.
    """
    # The input arrives as a pd.Series whose elements are per-row feature
    # arrays, so turn it into a list of rows before handing it to sklearn.
    rows = series.tolist()
    return pd.Series(est_broadcast.value.predict(rows))

In [None]:
# Method #2: apply the pandas UDF to the array column instead of the vector column.
df_test.withColumn("prediction", predict("features_array")).show()

In [None]:
import pickle

In [None]:
# Persist the fitted sklearn model so the Spark ML wrapper can load it from disk.
with open("logistic_model.pk", "wb") as model_out:
    model_out.write(pickle.dumps(est))

In [None]:
from pyspark import keyword_only
from pyspark.ml import Model
from pyspark.ml.param import Param, Params, TypeConverters
from pyspark.ml.param.shared import HasFeaturesCol, HasLabelCol, HasPredictionCol

In [None]:
class SKLogisticRegreesionModel(Model, HasFeaturesCol, HasLabelCol, HasPredictionCol):
    """Spark ML ``Model`` that applies a pickled scikit-learn classifier.

    The estimator is unpickled from ``model_file`` on the driver at
    construction time. ``transform`` applies it to ``featuresCol`` through a
    scalar pandas UDF and writes the result to ``predictionCol``.
    ``featuresCol`` must be an array column (e.g. ``features_array``); a
    VectorUDT column cannot be passed to a pandas UDF.
    """

    # Path to the pickled estimator, exposed as a Spark ML Param (set in __init__).
    model_file = Param(Params._dummy(), "model_file",
                       "path to pickled scikit-learn logistic regression model",
                       typeConverter=TypeConverters.toString)

    @keyword_only
    def __init__(self, model_file=None, featuresCol="features", labelCol="label", predictionCol="prediction"):
        """Load the estimator from ``model_file`` and set the column Params.

        :raises ValueError: if ``model_file`` is not given.
        """
        super(SKLogisticRegreesionModel, self).__init__()
        if model_file is None:
            raise ValueError("model_file must be specified!")
        # SECURITY: pickle.load can execute arbitrary code; only load trusted files.
        with open(model_file, "rb") as f:
            self.estimator = pickle.load(f)
        kwargs = self._input_kwargs
        self._set(**kwargs)

    def _transform(self, dataset):
        """Return ``dataset`` with ``predictionCol`` holding this model's predictions."""
        # Use this instance's own estimator, not a global UDF/broadcast, so a
        # model produced by an Estimator's fit actually predicts with the
        # fitted weights. Binding it to a local keeps ``self`` out of the
        # closure that gets serialized to the executors.
        estimator = self.estimator

        @F.pandas_udf(FloatType())
        def _predict(features):
            # Each element is one row's feature array; sklearn takes the list as 2-D input.
            return pd.Series(estimator.predict(features.tolist()))

        return dataset.withColumn(self.getPredictionCol(), _predict(self.getFeaturesCol()))

In [None]:
spark_est = SKLogisticRegreesionModel(model_file="logistic_model.pk", featuresCol="features_array")

In [None]:
spark_est.transform(df).show()

In [None]:
from pyspark.ml import Pipeline

In [None]:
pipeline = Pipeline(stages=[
    spark_est
])

In [None]:
pipeline_model = pipeline.fit(dff)

In [None]:
pipeline_model.transform(dff).show()

In [None]:
class SKLogisticRegression(Estimator, HasFeaturesCol, HasPredictionCol, HasLabelCol):
    """Spark ML ``Estimator`` that fits a scikit-learn ``LogisticRegression`` on the driver.

    ``fit`` collects the feature and label columns with ``toPandas()``, trains
    the sklearn model locally, pickles it to ``logistic_regression.pk`` and
    returns a ``SKLogisticRegreesionModel`` loaded from that file. The fitted
    sklearn model and the file path are also kept on ``self.est`` and
    ``self.model_file``.
    """

    @keyword_only
    def __init__(self, featuresCol="features", predictionCol="prediction", labelCol="label"):
        super(SKLogisticRegression, self).__init__()
        self._set(**self._input_kwargs)

    def _fit(self, dataset):
        features_col = self.getFeaturesCol()
        label_col = self.getLabelCol()
        # The selected columns are collected to the driver in full.
        local_frame = dataset.select(features_col, label_col).toPandas()

        self.est = LogisticRegression()
        # Each feature cell is one row's array; a list of them forms the 2-D X.
        self.est.fit(local_frame[features_col].tolist(), local_frame[label_col])

        # Persist the model so the returned Model can reload it from disk.
        self.model_file = "logistic_regression.pk"
        with open(self.model_file, "wb") as handle:
            pickle.dump(self.est, handle)

        return SKLogisticRegreesionModel(
            model_file=self.model_file,
            predictionCol=self.getPredictionCol(),
            featuresCol=features_col,
            labelCol=label_col,
        )

In [None]:
# Estimator version: trains its own sklearn model from a Spark DataFrame.
spark_est = SKLogisticRegression(featuresCol="features_array")

In [None]:
# NOTE(review): this fits on df_test, i.e. the held-out split. That is fine for
# demonstrating the Estimator API, but the predictions below are not a valid evaluation.
spark_est_model = spark_est.fit(df_test)

In [None]:
spark_est_model.transform(df_test).show()

In [None]:
# Pipeline with the (unfitted) Estimator as its only stage; pipeline.fit trains it.
pipeline = Pipeline(stages=[
    spark_est
])

In [None]:
pipeline_model = pipeline.fit(df_test)

In [None]:
pipeline_model.transform(df_test).show()

## И, к слову об Arrow

In [None]:
from pyspark.sql.functions import rand

In [None]:
# 2**22 (~4.2M) rows with a random column "x", used to time toPandas() with and
# without Arrow. Seeded like the rest of the notebook so re-runs see the same data.
df = spark.range(1 << 22).toDF("id").withColumn("x", rand(seed=5757))

In [None]:
df.printSchema()

In [None]:
# Forces evaluation and shows the row count before timing the conversions.
df.count()

In [None]:
# Baseline: toPandas() with Arrow-based conversion disabled.
spark.conf.set("spark.sql.execution.arrow.enabled", "false")

In [None]:
%time pdf = df.toPandas()

In [None]:
# Same conversion with Arrow enabled.
spark.conf.set("spark.sql.execution.arrow.enabled", "true")

In [None]:
%time pdf = df.toPandas()

In [None]:
# Current number of records per Arrow batch.
spark.conf.get("spark.sql.execution.arrow.maxRecordsPerBatch")

In [None]:
# Shrink batches to 100 records to see how batch size affects the Arrow timing.
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "100")

In [None]:
%time pdf = df.toPandas()

In [None]:
# Shut down the Spark session started by the bootstrap cell.
spark.stop()