In [1]:
# Load the crab dataset from CSV, taking the column names from the header row.
# No schema inference is requested here, so every column comes back as a
# string (see the printSchema() output below).
DF = spark.read.csv("CrabAgePrediction.csv", sep=",", header=True)
DF.printSchema()
DF.show()

root
 |-- Sex: string (nullable = true)
 |-- Length: string (nullable = true)
 |-- Diameter: string (nullable = true)
 |-- Height: string (nullable = true)
 |-- Weight: string (nullable = true)
 |-- Shucked Weight: string (nullable = true)
 |-- Viscera Weight: string (nullable = true)
 |-- Shell Weight: string (nullable = true)
 |-- Age: string (nullable = true)

+---+------+--------+------+-----------+--------------+--------------+------------+---+
|Sex|Length|Diameter|Height|     Weight|Shucked Weight|Viscera Weight|Shell Weight|Age|
+---+------+--------+------+-----------+--------------+--------------+------------+---+
|  F|1.4375|   1.175|0.4125| 24.6357155|    12.3320325|     5.5848515|    6.747181|  9|
|  M|0.8875|    0.65|0.2125| 5.40057975|     2.2963095|    1.37495075|   1.5592225|  6|
|  I|1.0375|   0.775|  0.25| 7.95203475|      3.231843|    1.60174675|  2.76407625|  6|
|  F| 1.175|  0.8875|  0.25|13.48018725|    4.74854125|    2.28213475|   5.2446575| 10|
|  I|0.8875|  0.66

In [3]:
# Split the full dataset into training and test subsets using 70/30 weights
# (randomSplit treats the weights as approximate proportions, not exact counts).
# A fixed seed keeps the split stable across runs; the value is the same one
# used for the modelling split later in this notebook.
weights = [0.7, 0.3]
train, test = DF.randomSplit(weights, seed=12345)

# Preview the first few rows of each subset.
train.show(3)
test.show(3)

+---+------+--------+------+----------+--------------+--------------+------------+---+
|Sex|Length|Diameter|Height|    Weight|Shucked Weight|Viscera Weight|Shell Weight|Age|
+---+------+--------+------+----------+--------------+--------------+------------+---+
|  F|0.7625|  0.5625| 0.175|4.20990075|    1.65844575|    0.94970825|   1.2757275|  7|
|  F|0.8125|    0.65| 0.225|5.42892925|     2.4097075|      1.020582|    1.757669|  7|
|  F| 0.825|    0.65|   0.2|    5.6699|    1.77184375|      1.417475|    1.984465|  9|
+---+------+--------+------+----------+--------------+--------------+------------+---+
only showing top 3 rows

+---+------+--------+------+---------+--------------+--------------+------------+---+
|Sex|Length|Diameter|Height|   Weight|Shucked Weight|Viscera Weight|Shell Weight|Age|
+---+------+--------+------+---------+--------------+--------------+------------+---+
|  F|0.6875|  0.4875| 0.175|  2.26796|     0.8788345|    0.60951425|   0.7087375|  5|
|  F| 0.725|   0.525|0

In [10]:
# Print the number of rows in each subset: training first, then test.
for subset in (train, test):
    print(subset.count())

2770
1123


In [24]:
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.regression import RandomForestRegressor
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml import Pipeline
from pyspark.ml.feature import VectorAssembler, StringIndexer

In [25]:
# Create (or reuse) the Spark session for this application.
spark = SparkSession.builder.appName("CrabAgePrediction").getOrCreate()

# Load the dataset: column names come from the header row and column types
# are inferred by Spark.
data = spark.read.csv("CrabAgePrediction.csv", header=True, inferSchema=True)

# Split the rows into training and test sets with 70/30 weights and a fixed seed.
trainingData, testData = data.randomSplit([0.7, 0.3], seed=12345)

# Stage 1: convert the crab's sex (string values) into a numeric index column.
indexer = StringIndexer(inputCol="Sex", outputCol="SexIndex")

# Stage 2: pick the predictor columns and pack them into one vector column.
features = [
    "SexIndex",
    "Length",
    "Diameter",
    "Height",
    "Weight",
    "Shucked Weight",
    "Viscera Weight",
    "Shell Weight",
]
assembler = VectorAssembler(inputCols=features, outputCol="features")

# Stage 3: random forest regressor that predicts Age from the feature vector.
rf = RandomForestRegressor(featuresCol="features", labelCol="Age")

# Chain the three stages into a pipeline and fit it on the training set.
pipeline = Pipeline(stages=[indexer, assembler, rf])
model = pipeline.fit(trainingData)

# Predict the age for every row of the test set.
predictions = model.transform(testData)

# Score the predictions with root mean squared error.
evaluator = RegressionEvaluator(
    labelCol="Age",
    predictionCol="prediction",
    metricName="rmse",
)
rmse = evaluator.evaluate(predictions)
print("Root Mean Squared Error (RMSE): %s" % rmse)

# Show actual vs. predicted ages, then shut the Spark session down.
predictions.select("Age", "prediction").show()
spark.stop()

Root Mean Squared Error (RMSE): 2.2615401627561877
+---+------------------+
|Age|        prediction|
+---+------------------+
|  9| 7.988392648316643|
| 10| 7.908705148316645|
|  8| 8.045248288980488|
|  7| 9.548391803438388|
|  7| 8.000412326977736|
|  8| 8.592632589196345|
| 10| 8.720904423833051|
|  9| 9.344447221980165|
| 10| 8.911436173951397|
|  7| 9.440142734707567|
|  7| 9.337883154618444|
|  8|  9.39839418522813|
| 11| 9.443347582927702|
|  6| 9.414458930156876|
|  9|10.059943090170167|
|  9| 9.443347582927702|
| 13| 9.771095354233097|
| 13|11.179394759835631|
|  8| 9.810599704373693|
| 13| 9.386385299029001|
+---+------------------+
only showing top 20 rows

