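Для произвольно выбранного датасета провести обработку данных и построить предсказательную модель с использованием PySpark. В качестве датасета выбран Wine Quality (красные вина, файл `winequality-red.csv`): 11 физико-химических признаков и целевая оценка качества `quality`.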

In [1]:
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import LogisticRegression

In [2]:
# Reuse the currently active Spark session, or start a new one if none exists.
spark = (
    SparkSession
    .builder
    .getOrCreate()
)

In [3]:
# Load the red-wine quality data: a semicolon-separated CSV whose first row
# holds the column names; Spark infers each column's type from the values.
df = spark.read.csv(
    "winequality-red.csv",
    sep=";",
    header=True,
    inferSchema=True,
)

In [4]:
# Print the inferred schema to confirm every column was read as a numeric
# type (doubles for the measurements, an integer for `quality`).
df.printSchema()

root
 |-- fixed acidity: double (nullable = true)
 |-- volatile acidity: double (nullable = true)
 |-- citric acid: double (nullable = true)
 |-- residual sugar: double (nullable = true)
 |-- chlorides: double (nullable = true)
 |-- free sulfur dioxide: double (nullable = true)
 |-- total sulfur dioxide: double (nullable = true)
 |-- density: double (nullable = true)
 |-- pH: double (nullable = true)
 |-- sulphates: double (nullable = true)
 |-- alcohol: double (nullable = true)
 |-- quality: integer (nullable = true)



In [5]:
# Preview the first five rows, showing cell values in full (no truncation).
df.show(n=5, truncate=False)

+-------------+----------------+-----------+--------------+---------+-------------------+--------------------+-------+----+---------+-------+-------+
|fixed acidity|volatile acidity|citric acid|residual sugar|chlorides|free sulfur dioxide|total sulfur dioxide|density|pH  |sulphates|alcohol|quality|
+-------------+----------------+-----------+--------------+---------+-------------------+--------------------+-------+----+---------+-------+-------+
|7.4          |0.7             |0.0        |1.9           |0.076    |11.0               |34.0                |0.9978 |3.51|0.56     |9.4    |5      |
|7.8          |0.88            |0.0        |2.6           |0.098    |25.0               |67.0                |0.9968 |3.2 |0.68     |9.8    |5      |
|7.8          |0.76            |0.04       |2.3           |0.092    |15.0               |54.0                |0.997  |3.26|0.65     |9.8    |5      |
|11.2         |0.28            |0.56       |1.9           |0.075    |17.0               |60.0       

In [6]:
# Start from a copy of all column names; the label column is dropped next.
feature_names = list(df.columns)

In [7]:
# Drop the label column so that only predictor columns remain
# (raises ValueError if 'quality' is not present).
del feature_names[feature_names.index('quality')]

In [8]:
# Show the predictor columns that will be assembled into the feature vector.
print(feature_names)

['fixed acidity', 'volatile acidity', 'citric acid', 'residual sugar', 'chlorides', 'free sulfur dioxide', 'total sulfur dioxide', 'density', 'pH', 'sulphates', 'alcohol']


In [9]:
# Combine all predictor columns into a single vector column named "features",
# the input layout Spark ML estimators consume; original columns are kept.
assembler = VectorAssembler(
    outputCol="features",
    inputCols=feature_names,
)
output = assembler.transform(df)

In [10]:
# Split the data into training (~70%) and test (~30%) sets.
# A fixed seed makes the split, and therefore the fitted model and its test
# results, reproducible from run to run; without it every execution draws a
# different split.
train, test = output.randomSplit([0.7, 0.3], seed=42)

In [11]:
# Preview the first five training rows (long cell values are truncated).
train.show(n=5)

+-------------+----------------+-----------+--------------+---------+-------------------+--------------------+-------+----+---------+-------+-------+--------------------+
|fixed acidity|volatile acidity|citric acid|residual sugar|chlorides|free sulfur dioxide|total sulfur dioxide|density|  pH|sulphates|alcohol|quality|            features|
+-------------+----------------+-----------+--------------+---------+-------------------+--------------------+-------+----+---------+-------+-------+--------------------+
|          4.6|            0.52|       0.15|           2.1|    0.054|                8.0|                65.0| 0.9934| 3.9|     0.56|   13.1|      4|[4.6,0.52,0.15,2....|
|          4.7|             0.6|       0.17|           2.3|    0.058|               17.0|               106.0| 0.9932|3.85|      0.6|   12.9|      6|[4.7,0.6,0.17,2.3...|
|          5.0|            0.38|       0.01|           1.6|    0.048|               26.0|                60.0|0.99084| 3.7|     0.75|   14.0|    

Обучаем модель

In [12]:
# Logistic-regression classifier predicting the integer `quality` score from
# the assembled "features" vector; training is capped at 10 iterations.
# NOTE(review): maxIter=10 is well below Spark's default of 100 and the
# optimiser may stop before converging; confirm this cap is intentional.
lr = LogisticRegression(
    maxIter=10,
    labelCol="quality",
    featuresCol="features",
)

In [13]:
# Fit the classifier on the training split.
model = lr.fit(dataset=train)

In [14]:
# Score the held-out test split; adds probability and prediction columns.
pred = model.transform(dataset=test)

In [15]:
# Show a sample of test-set predictions next to the true label.
pred.select("features", "quality", "probability", "prediction").show(truncate=False)

# Quantify held-out performance: the share of test rows whose predicted class
# equals the true `quality` score. The sample above alone says nothing about
# how good the model is.
n_test = pred.count()
if n_test:
    accuracy = pred.filter(pred["quality"] == pred["prediction"]).count() / n_test
    print(f"Test accuracy: {accuracy:.3f} ({n_test} rows)")
else:
    print("Test split is empty; accuracy cannot be computed.")

+------------------------------------------------------------+-------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------+
|features                                                    |quality|probability                                                                                                                                                                                     |prediction|
+------------------------------------------------------------+-------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------+
|[4.9,0.42,0.0,2.1,0.048,16.0,42.0,0.99154,3.71,0.74,14.0]   |7      |[2.4226118673671896E-4,2.4226118673671896E-4,2.4226118673671896E-4,4.659870270613564E-4,0.001649112540597

In [16]:
# Stop the Spark session created at the start of the notebook.
spark.stop()