Начнем с Ирисов Фишера!

In [67]:
from pyspark.sql import SparkSession
import os
import sys
from pyspark.sql.functions import *
from pyspark.ml import *
from pyspark.sql.types import *
import findspark
findspark.init('C:/Spark/spark-3.4.2-bin-hadoop3/')

os.environ['PYSPARK_PYTHON'] = sys.executable
os.environ['PYSPARK_DRIVER_PYTHON'] = sys.executable

spark = SparkSession.builder\
    .config("spark.driver.memory", "10g")\
    .master('local[*]')\
    .appName('HW2').getOrCreate()




table_schema  = StructType([StructField('Id', IntegerType(), True),
                            StructField('SepalLengthCm', FloatType(), True),
                            StructField('SepalWidthCm', FloatType(), False),
                            StructField('PetalLengthCm', FloatType(), False),
                            StructField('PetalWidthCm', FloatType(), False),
                            StructField('Species', StringType(), True)])

iris_dataset = spark.read.format("csv").\
                option("delimiter", ",").\
                option("header","true").\
                option("encoding", "utf-8").\
                schema(table_schema).\
                load("iris.csv")

Смотрим на датасет

In [68]:
iris_dataset.show()

+---+-------------+------------+-------------+------------+-----------+
| Id|SepalLengthCm|SepalWidthCm|PetalLengthCm|PetalWidthCm|    Species|
+---+-------------+------------+-------------+------------+-----------+
|  1|          5.1|         3.5|          1.4|         0.2|Iris-setosa|
|  2|          4.9|         3.0|          1.4|         0.2|Iris-setosa|
|  3|          4.7|         3.2|          1.3|         0.2|Iris-setosa|
|  4|          4.6|         3.1|          1.5|         0.2|Iris-setosa|
|  5|          5.0|         3.6|          1.4|         0.2|Iris-setosa|
|  6|          5.4|         3.9|          1.7|         0.4|Iris-setosa|
|  7|          4.6|         3.4|          1.4|         0.3|Iris-setosa|
|  8|          5.0|         3.4|          1.5|         0.2|Iris-setosa|
|  9|          4.4|         2.9|          1.4|         0.2|Iris-setosa|
| 10|          4.9|         3.1|          1.5|         0.1|Iris-setosa|
| 11|          5.4|         3.7|          1.5|         0.2|Iris-

Также смотрим на схему

In [69]:
iris_dataset.printSchema()

root
 |-- Id: integer (nullable = true)
 |-- SepalLengthCm: float (nullable = true)
 |-- SepalWidthCm: float (nullable = true)
 |-- PetalLengthCm: float (nullable = true)
 |-- PetalWidthCm: float (nullable = true)
 |-- Species: string (nullable = true)



Закодируем классы Ирисов

In [70]:
classes = iris_dataset.select('Species').distinct().withColumn('class_id', monotonically_increasing_id())

result = iris_dataset.join(classes,on='Species').drop('Species')

In [71]:
result.printSchema()

root
 |-- Id: integer (nullable = true)
 |-- SepalLengthCm: float (nullable = true)
 |-- SepalWidthCm: float (nullable = true)
 |-- PetalLengthCm: float (nullable = true)
 |-- PetalWidthCm: float (nullable = true)
 |-- class_id: long (nullable = false)



Определяем фичи и таргеты

In [72]:
from pyspark.ml.feature import VectorAssembler

assembler = VectorAssembler(
    inputCols=["SepalLengthCm", "SepalWidthCm", "PetalLengthCm", "PetalWidthCm"],
    outputCol="features")

output = assembler.transform(result).select('class_id','features')

Смотрим на результат векторизации

In [73]:
output.show()

+--------+--------------------+
|class_id|            features|
+--------+--------------------+
|       1|[5.09999990463256...|
|       1|[4.90000009536743...|
|       1|[4.69999980926513...|
|       1|[4.59999990463256...|
|       1|[5.0,3.5999999046...|
|       1|[5.40000009536743...|
|       1|[4.59999990463256...|
|       1|[5.0,3.4000000953...|
|       1|[4.40000009536743...|
|       1|[4.90000009536743...|
|       1|[5.40000009536743...|
|       1|[4.80000019073486...|
|       1|[4.80000019073486...|
|       1|[4.30000019073486...|
|       1|[5.80000019073486...|
|       1|[5.69999980926513...|
|       1|[5.40000009536743...|
|       1|[5.09999990463256...|
|       1|[5.69999980926513...|
|       1|[5.09999990463256...|
+--------+--------------------+
only showing top 20 rows



Разбиваем на train и test

In [74]:
train_df, test_df = output.randomSplit(weights=[0.8,0.2], seed=100)



Определяем и тренируем классификатор 


In [75]:
DT = classification.DecisionTreeClassifier(featuresCol='features',labelCol='class_id')
model = DT.fit(train_df)

Делаем предсказания на тесте

In [76]:

predictions = model.transform(test_df)
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

evaluator = MulticlassClassificationEvaluator(labelCol="class_id", predictionCol="prediction", metricName="accuracy")
accuracy = evaluator.evaluate(predictions)
print("Test Error = %g" % (1.0 - accuracy))
print(accuracy)

Test Error = 0.0645161
0.9354838709677419


Теперь кластеризация


In [77]:
from pyspark.ml.clustering import KMeans
from pyspark.ml.evaluation import ClusteringEvaluator

output1 = output.drop('class_id')
train_df1, test_df1 = output1.randomSplit(weights=[0.8,0.2], seed=100)

kmeans = KMeans(k=3)
model = kmeans.fit(train_df1)

In [78]:
predictions = model.transform(test_df1)
evaluator = ClusteringEvaluator()
silhouette = evaluator.evaluate(predictions)
print("Silhouette score = ", silhouette)

Silhouette score =  0.7464386444766574


Перейдем к другим датасетам

In [79]:
youtube_dataset = spark.read.format("csv").\
                option("delimiter", ",").\
                option("header","true").\
                option("encoding", "cp1251").\
                load("youtube_channels_1M_clean.csv")

In [80]:
youtube_dataset.show()

+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+------------------+-------------------------+---------------------------+------------------------+--------------------+
|          channel_id|        channel_link|        channel_name|    subscriber_count|         banner_link|         description|            keywords|              avatar|             country|         total_views|        total_videos|         join_date|mean_views_last_30_videos|median_views_last_30_videos|std_views_last_30_videos|     videos_per_week|
+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+------------------+-------------------------+-------------------

Смотри на схему

In [81]:
youtube_dataset.printSchema()

root
 |-- channel_id: string (nullable = true)
 |-- channel_link: string (nullable = true)
 |-- channel_name: string (nullable = true)
 |-- subscriber_count: string (nullable = true)
 |-- banner_link: string (nullable = true)
 |-- description: string (nullable = true)
 |-- keywords: string (nullable = true)
 |-- avatar: string (nullable = true)
 |-- country: string (nullable = true)
 |-- total_views: string (nullable = true)
 |-- total_videos: string (nullable = true)
 |-- join_date: string (nullable = true)
 |-- mean_views_last_30_videos: string (nullable = true)
 |-- median_views_last_30_videos: string (nullable = true)
 |-- std_views_last_30_videos: string (nullable = true)
 |-- videos_per_week: string (nullable = true)



Очищаем

In [82]:
clean_data = youtube_dataset.filter((youtube_dataset.subscriber_count > 0.0) &
                                    (youtube_dataset.total_views > 0.0) &
                                    (youtube_dataset.mean_views_last_30_videos > 0.0) &
                                    (youtube_dataset.total_videos > 0.0) &
                                    (youtube_dataset.median_views_last_30_videos > 0.0) &
                                    (youtube_dataset.std_views_last_30_videos > 0.0) &
                                    (youtube_dataset.videos_per_week > 0.0) )


clean_data.show(1000)

+--------------------+--------------------+--------------------+----------------+--------------------+--------------------+--------------------+--------------------+--------------------+------------+------------+----------+-------------------------+---------------------------+------------------------+---------------+
|          channel_id|        channel_link|        channel_name|subscriber_count|         banner_link|         description|            keywords|              avatar|             country| total_views|total_videos| join_date|mean_views_last_30_videos|median_views_last_30_videos|std_views_last_30_videos|videos_per_week|
+--------------------+--------------------+--------------------+----------------+--------------------+--------------------+--------------------+--------------------+--------------------+------------+------------+----------+-------------------------+---------------------------+------------------------+---------------+
|UCOqwGhI1AmpWwxMY...|/@KichuandYugiMag...|

Добавляем уникальные номера под каждую страну

In [83]:
countrys = clean_data.select('country').distinct().withColumn('country_id', monotonically_increasing_id())
data_encode = clean_data.join(countrys, on='country').drop('country')

In [84]:
data_encode.show()

+--------------------+--------------------+--------------------+----------------+--------------------+--------------------+--------------------+--------------------+------------+------------+----------+-------------------------+---------------------------+------------------------+---------------+----------+
|          channel_id|        channel_link|        channel_name|subscriber_count|         banner_link|         description|            keywords|              avatar| total_views|total_videos| join_date|mean_views_last_30_videos|median_views_last_30_videos|std_views_last_30_videos|videos_per_week|country_id|
+--------------------+--------------------+--------------------+----------------+--------------------+--------------------+--------------------+--------------------+------------+------------+----------+-------------------------+---------------------------+------------------------+---------------+----------+
|UCiik0mUNQJE7OP6i...|    /@-dancezumba364|Р Р°Р·СЂР°Р±Р°С‚С...|         

Если у пользователя есть ссылка в баннере, описание, ключевые слова, аватар, то мы ставим 1, в отсальном, Null заменяем на 0

In [85]:
data_encode = data_encode.withColumn("banner_link", when(data_encode.banner_link.isNull(), 0).otherwise(1)).\
                        withColumn("description", when(data_encode.description.isNull(), 0).otherwise(1)).\
                        withColumn("keywords", when(data_encode.keywords.isNull(), 0).otherwise(1)).\
                        withColumn("avatar", when(data_encode.avatar.isNull(), 0).otherwise(1)).drop('channel_id').\
                        drop('channel_link').drop('channel_name').drop('join_date').na.fill(0)

In [86]:
data_encode.show()

+----------------+-----------+-----------+--------+------+------------+------------+-------------------------+---------------------------+------------------------+---------------+----------+
|subscriber_count|banner_link|description|keywords|avatar| total_views|total_videos|mean_views_last_30_videos|median_views_last_30_videos|std_views_last_30_videos|videos_per_week|country_id|
+----------------+-----------+-----------+--------+------+------------+------------+-------------------------+---------------------------+------------------------+---------------+----------+
|          115000|          1|          1|       1|     1|  30109443.0|       189.0|                    811.9|                      649.5|         538.07207385876|            1.0|         0|
|         6630000|          1|          1|       0|     1|4383148592.0|      2854.0|         11555.4333333333|                     5981.0|        12544.7803532873|           0.25|         0|
|              48|          0|          1|   

One Hot Encode'им страны 

In [87]:
from pyspark.ml.feature import OneHotEncoder as ohe

encode = ohe()
encode.setInputCols(['country_id'])
encode.setOutputCols(['country'])

model = encode.fit(data_encode)

In [88]:
end_result = model.transform(data_encode).drop('country_id')

In [89]:
end_result.show()

+----------------+-----------+-----------+--------+------+------------+------------+-------------------------+---------------------------+------------------------+---------------+---------------+
|subscriber_count|banner_link|description|keywords|avatar| total_views|total_videos|mean_views_last_30_videos|median_views_last_30_videos|std_views_last_30_videos|videos_per_week|        country|
+----------------+-----------+-----------+--------+------+------------+------------+-------------------------+---------------------------+------------------------+---------------+---------------+
|          115000|          1|          1|       1|     1|  30109443.0|       189.0|                    811.9|                      649.5|         538.07207385876|            1.0|(170,[0],[1.0])|
|         6630000|          1|          1|       0|     1|4383148592.0|      2854.0|         11555.4333333333|                     5981.0|        12544.7803532873|           0.25|(170,[0],[1.0])|
|              48|  

Кастуем в нужный нам тип данных столбцы

In [90]:
end_result = end_result.withColumn('subscriber_count', end_result.subscriber_count.cast(IntegerType()))\
                        .withColumn('banner_link', end_result.banner_link.cast(IntegerType()))\
                        .withColumn('description', end_result.description.cast(IntegerType()))\
                        .withColumn('keywords', end_result.keywords.cast(IntegerType()))\
                        .withColumn('avatar', end_result.avatar.cast(IntegerType()))\
                        .withColumn('total_videos', end_result.total_videos.cast(IntegerType()))\
                        .withColumn('total_views', end_result.total_views.cast(IntegerType()))\
                        .withColumn('mean_views_last_30_videos', end_result.mean_views_last_30_videos.cast(FloatType()))\
                        .withColumn('median_views_last_30_videos', end_result.median_views_last_30_videos.cast(FloatType()))\
                        .withColumn('std_views_last_30_videos', end_result.std_views_last_30_videos.cast(FloatType()))\
                        .withColumn('videos_per_week', end_result.videos_per_week.cast(FloatType())).fillna(0)
                        


Можно провести следующие 3 эксперимента:
1) Число просмотров на основе наличия у него информации в баннере, аватара и тп.
2) Классификация пользователей по наличию у него ссылки на баннер
3) Кластеризации данных о подписчиках YouTube каналов на основе количества подписчиков, среднего числа просмотров последних 30 видео, медианы просмотров последних 30 видео и количества выпускаемых видео в неделю

In [91]:
assembler = VectorAssembler(
    inputCols=['country','total_videos','videos_per_week'],
    outputCol="features")

experiment1 = assembler.transform(end_result).select('total_views','features')

In [92]:
experiment1.show()

+-----------+--------------------+
|total_views|            features|
+-----------+--------------------+
|     296394|(172,[0,170,171],...|
|   12488171|(172,[0,170,171],...|
|     155045|(172,[0,170,171],...|
|   69308821|(172,[0,170,171],...|
|  461453109|(172,[0,170,171],...|
|        134|(172,[0,170,171],...|
|     479884|(172,[0,170,171],...|
|   37310085|(172,[0,170,171],...|
|      68256|(172,[0,170,171],...|
|    2944530|(172,[0,170,171],...|
|     432969|(172,[0,170,171],...|
|    2428854|(172,[0,170,171],...|
|     145866|(172,[0,170,171],...|
|     357961|(172,[0,170,171],...|
|     139638|(172,[0,170,171],...|
|      23121|(172,[0,170,171],...|
|    7746011|(172,[0,170,171],...|
|     155628|(172,[0,170,171],...|
|      86898|(172,[0,170,171],...|
|   10953670|(172,[0,170,171],...|
+-----------+--------------------+
only showing top 20 rows



Так как есть OHE фичи, то есть разреженные данные, то их лучше отскейлить через MaxAbs

In [93]:
from pyspark.ml.feature import MaxAbsScaler

maScaler = MaxAbsScaler()
maScaler.setInputCol("features")

model = maScaler.fit(experiment1)
model.setOutputCol("features_scaled")
scaled_data = model.transform(experiment1).drop('features')

Тут всё относительно просто:
достает алгоритмы регресси: рассматриваются 2 - с использованием решающих деревьев и линейной регресии

In [94]:
from pyspark.ml.regression import DecisionTreeRegressor,LinearRegression

from pyspark.ml.evaluation import RegressionEvaluator

features = ['countrys','total_videos', 'videos_per_week']
(train, test) = scaled_data.randomSplit([0.8, 0.2], seed=42)


lr = LinearRegression(featuresCol="features_scaled", labelCol="total_views", regParam=0.1, elasticNetParam=1.0,maxIter=1000)


model = lr.fit(train)

# Оценка модели на тестовой выборке
test_results = model.transform(test)
rmse = RegressionEvaluator(labelCol="total_views", predictionCol="prediction", metricName="rmse").evaluate(test_results)
r2 = RegressionEvaluator(labelCol="total_views", predictionCol="prediction", metricName="r2").evaluate(test_results)
print(f"RMSE: {rmse:.2f}")
print(f"R-squared: {r2:.2f}")


RMSE: 129517189.62
R-squared: 0.03


Ошибка очень большая, как и R-squred метрика
Вообще, нужно провести стат гипотезу, и провести эксперименты о наличии взаимосвязи между фичами и таргетом
Пока что, я бы утверждал, что взаимосвязи нет

In [95]:
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.regression import RandomForestRegressor
from pyspark.ml.evaluation import RegressionEvaluator


(train, test) = experiment1.randomSplit([0.8, 0.2], seed=42)

# Создание модели дерева решений для регрессии
rf = RandomForestRegressor(featuresCol="features", labelCol="total_views",maxDepth=15)

# Обучение модели
model = rf.fit(train)

# Оценка модели на тестовой выборке
predictions = model.transform(test)
rmse = RegressionEvaluator(labelCol="total_views", predictionCol="prediction", metricName="rmse").evaluate(predictions)
r2 = RegressionEvaluator(labelCol="total_views", predictionCol="prediction", metricName="r2").evaluate(predictions)
print("Root Mean Squared Error (RMSE) on test data = %g" % rmse)
print("R2 on test data = %g" % r2)

# Получение важности признаков
feature_importances = model.featureImportances
print("Feature importances: ", feature_importances)

Root Mean Squared Error (RMSE) on test data = 1.27862e+08
R2 on test data = 0.0570009
Feature importances:  (172,[0,1,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,20,21,22,23,24,26,27,28,29,32,33,35,36,37,39,40,41,43,44,45,46,47,48,49,51,52,53,54,56,57,60,61,62,63,65,66,67,69,70,71,72,74,75,76,78,79,80,81,82,83,84,85,86,87,88,89,91,93,94,95,96,97,98,99,100,101,103,104,105,106,108,109,110,113,114,115,116,117,138,139,156,166,170,171],[0.00033273324719486125,1.5596604806206922e-08,0.004967766854694074,0.004557281034963844,0.015005781575890536,0.0027345732754428063,0.0019945683714377264,6.682181500051693e-06,0.0018891390239406769,1.2112000795915434e-06,0.004388803863570467,7.898215506891029e-05,0.011918038239313115,0.0004431029799467579,0.00376867186545142,0.03019192617178848,0.00026815615380233846,3.1066298860573886e-05,0.00928874582269866,0.0022795664520190166,2.102811977324128e-05,1.4648091263620624e-05,0.00011254555868932447,0.0018381493952833114,0.0016014852366865399,0.011529338340015951,

Дерево решений также показывает неудволетворительный результат, чтобы делать вывод о взаимосвязи
Удовлетворительно - когда R2 > 0.5 хотя бы

In [33]:
end_result.show()

+----------------+-----------+-----------+--------+------+-----------+------------+-------------------------+---------------------------+------------------------+---------------+---------------+
|subscriber_count|banner_link|description|keywords|avatar|total_views|total_videos|mean_views_last_30_videos|median_views_last_30_videos|std_views_last_30_videos|videos_per_week|        country|
+----------------+-----------+-----------+--------+------+-----------+------------+-------------------------+---------------------------+------------------------+---------------+---------------+
|          115000|          1|          1|       1|     1|   30109443|         189|                    811.9|                      649.5|                538.0721|            1.0|(170,[0],[1.0])|
|         6630000|          1|          1|       0|     1|          0|        2854|                11555.434|                     5981.0|                12544.78|           0.25|(170,[0],[1.0])|
|              48|       

Классификация пользователей по наличию у него ссылки на баннер

In [96]:
assembler = VectorAssembler(
    inputCols=['total_views','subscriber_count',"total_views", "total_videos","mean_views_last_30_videos",
               "median_views_last_30_videos","std_views_last_30_videos","videos_per_week"],
    outputCol="features")

experiment2 = assembler.transform(end_result).select('features','banner_link')

In [97]:
experiment2.show()

+--------------------+-----------+
|            features|banner_link|
+--------------------+-----------+
|[3.0109443E7,1150...|          1|
|[0.0,6630000.0,0....|          1|
|[18404.0,48.0,184...|          0|
|[21075.0,34.0,210...|          1|
|[3.26647961E8,130...|          1|
|[543984.0,3710.0,...|          1|
|[634823.0,40300.0...|          1|
|[848.0,8.0,848.0,...|          1|
|[8337926.0,33100....|          1|
|[6451.0,210.0,645...|          1|
|[3.3902333E7,9120...|          1|
|[8.7184215E7,2570...|          1|
|[810961.0,559.0,8...|          1|
|[144922.0,334.0,1...|          1|
|[665613.0,679.0,6...|          1|
|[7683.0,213.0,768...|          1|
|[9251525.0,44900....|          1|
|[101748.0,343.0,1...|          1|
|[434957.0,1260.0,...|          1|
|[744.0,17.0,744.0...|          0|
+--------------------+-----------+
only showing top 20 rows



Скейлим и задаем модель

Попробуем использовать Robust скейлер

In [98]:
from pyspark.ml.feature import RobustScaler

scaler = RobustScaler()
scaler.setInputCol("features")
scaler.setOutputCol("features_scaled")

model = scaler.fit(experiment2)
scaled_data = model.transform(experiment2).drop('features')

In [99]:
from pyspark.ml.classification import DecisionTreeClassifier


(train, test) = scaled_data.randomSplit([0.8, 0.2], seed=42)


clsf = DecisionTreeClassifier(featuresCol="features_scaled", labelCol="banner_link",maxDepth=5)

# Обучение модели
model = clsf.fit(train)

# Оценка модели на тестовой выборке
predictions = model.transform(test)
accuracy = MulticlassClassificationEvaluator(labelCol="banner_link", predictionCol="prediction", metricName="accuracy").evaluate(predictions)
f1 = MulticlassClassificationEvaluator(labelCol="banner_link", predictionCol="prediction", metricName="f1").evaluate(predictions)
print("accuracy = %g" % accuracy)
print("f1 = %g" % f1)



accuracy = 0.864611
f1 = 0.802934


Метрики довольно высокие, особенно f1, поэтому классифицировать людей по наличию ссылке в баннере/ на баннер - можно

После 3 эксперимента, в целях экономии времени, я просто буду показывать вывод
Кроме последнего датасета в SQLITE формате

Кластеризации данных о подписчиках YouTube каналов на основе количества подписчиков, среднего числа просмотров последних 30 видео, медианы просмотров последних 30 видео и количества выпускаемых видео в неделю

In [100]:

from pyspark.ml.clustering import KMeans
from pyspark.ml.feature import VectorAssembler


# Объединяем выбранные признаки в один вектор
assembler = VectorAssembler(inputCols=["subscriber_count",
                                        "mean_views_last_30_videos",
                                        "median_views_last_30_videos", "videos_per_week"], outputCol="features")
experiment3 = assembler.transform(end_result).select('features')

# Создаем модель KMeans с количеством кластеров равным 2
kmeans = KMeans(featuresCol="features", k=2)
model = kmeans.fit(experiment3)

# Добавляем предсказанные кластеры к оригинальному DataFrame
predictions = model.transform(experiment3)



In [101]:
evaluator = ClusteringEvaluator()
silhouette = evaluator.evaluate(predictions)
print("Silhouette score = ", silhouette)

Silhouette score =  0.9999853836831163


И также, кластеризовать получается вполне отлично

Следующий датасет


In [40]:
creditCard = spark.read.format("csv").\
                option("delimiter", ",").\
                option("header","true").\
                option("encoding", "cp1251").\
                load("creditcard.csv")
creditCard.show()

+----+------------------+-------------------+------------------+-------------------+-------------------+-------------------+--------------------+-------------------+------------------+-------------------+------------------+------------------+-------------------+-------------------+-------------------+-------------------+--------------------+-------------------+-------------------+-------------------+--------------------+-------------------+-------------------+-------------------+-------------------+-------------------+--------------------+-------------------+------+-----+
|Time|                V1|                 V2|                V3|                 V4|                 V5|                 V6|                  V7|                 V8|                V9|                V10|               V11|               V12|                V13|                V14|                V15|                V16|                 V17|                V18|                V19|                V20|                 V

1) Построение классификатора для обнаружения мошеннических операций
2) Кластеризация 
3) Линейная Регрессия для предсказания суммы транзакций

In [41]:
creditCard = creditCard.select(*(col(c).cast("float").alias(c) for c in creditCard.columns))

In [42]:
features = creditCard.columns[:-1]

In [43]:
from pyspark.ml.classification import GBTClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator

# Объединим колонки 'Time' и 'Amount' в одну колонку 'features'
assembler = VectorAssembler(inputCols=features, outputCol='features')
data_prepared = assembler.transform(creditCard)

# Разделим данные на обучающий и тестовый наборы
train_data, test_data = data_prepared.randomSplit([0.7, 0.3], seed=42)

# Обучим модель GBTClassifier
gbt = GBTClassifier(labelCol='Class', featuresCol='features', maxIter=10)
model = gbt.fit(train_data)

# Сделаем предсказания на тестовом наборе
predictions = model.transform(test_data)

# Оценим качество модели
evaluator = BinaryClassificationEvaluator(labelCol='Class', metricName='areaUnderROC')
auc = evaluator.evaluate(predictions)
print("Area Under ROC: %.4f" % auc)

Area Under ROC: 0.9738


In [44]:
kmeans = KMeans(k=2, featuresCol='features')
model_kmeans = kmeans.fit(data_prepared)

# Получим предсказанные кластеры
predictions_kmeans = model_kmeans.transform(data_prepared)

# Выведем кластеры
predictions_kmeans.select('features', 'prediction').show(5)

+--------------------+----------+
|            features|prediction|
+--------------------+----------+
|[0.0,-1.359807133...|         1|
|[0.0,1.1918570995...|         1|
|[1.0,-1.358354091...|         1|
|[1.0,-0.966271698...|         1|
|[2.0,-1.158233046...|         1|
+--------------------+----------+
only showing top 5 rows



In [45]:
evaluator = ClusteringEvaluator()
silhouette = evaluator.evaluate(predictions)
print("Silhouette score = ", silhouette)

Silhouette score =  -0.024190360835891354


А вот здесь скор по кластеризации очень уж плохой

In [46]:
assembler = VectorAssembler(inputCols=['Time',"Class"] + ['V'+str(i) for i in range(1, 29)], outputCol='features')
data_prepared = assembler.transform(creditCard)

train_data_regression, test_data_regression = data_prepared.randomSplit([0.7, 0.3], seed=42)

# Обучим модель линейной регрессии
lr = LinearRegression(labelCol='Amount', featuresCol='features')
model_regression = lr.fit(train_data_regression)

# Сделаем предсказания на тестовом наборе
predictions_regression = model_regression.transform(test_data_regression)

# Оценим качество модели
rmse = RegressionEvaluator(labelCol='Amount', metricName='rmse')
rmse = rmse.evaluate(predictions_regression)

r2 = RegressionEvaluator(labelCol='Amount', metricName='r2')
r2 = r2.evaluate(predictions_regression)

print("Root Mean Squared Error: %.4f" % rmse)
print("r2: %.4f" % r2)


Root Mean Squared Error: 75.8473
r2: 0.9111


А предсказывать Amount мы умеем

"All upwork" Dataset!

In [47]:
all_upwork1 = spark.read.format("csv").\
                option("delimiter", ",").\
                option("header","true").\
                option("encoding", "cp1251").\
                load("all_upwork_jobs_2024-02-07-2024-03-24.csv")

all_upwork2 = spark.read.format("csv").\
                option("delimiter", ",").\
                option("header","true").\
                option("encoding", "cp1251").\
                load("all_upwork_jobs_2024-03-24-2024-05-21.csv")

In [48]:
all_upwork1.show()

+--------------------+--------------------+--------------------+---------+----------+-----------+------+--------------+
|               title|                link|      published_date|is_hourly|hourly_low|hourly_high|budget|       country|
+--------------------+--------------------+--------------------+---------+----------+-----------+------+--------------+
|Experienced Media...|https://www.upwor...|2024-02-17 09:09:...|    False|      null|       null| 500.0|          null|
|Full Stack Developer|https://www.upwor...|2024-02-17 09:09:...|    False|      null|       null|1100.0| United States|
|     SMMA Bubble App|https://www.upwor...|2024-02-17 09:08:...|     True|      10.0|       30.0|  null| United States|
|Talent Hunter Spe...|https://www.upwor...|2024-02-17 09:08:...|     True|      null|       null|  null| United States|
|       Data Engineer|https://www.upwor...|2024-02-17 09:07:...|    False|      null|       null| 650.0|         India|
|SEO for Portugues...|https://www.upwor.

In [49]:
all_upwork2.show()

+--------------------+--------------------+--------------------+---------+----------+-----------+------+--------------------+
|               title|                link|      published_date|is_hourly|hourly_low|hourly_high|budget|             country|
+--------------------+--------------------+--------------------+---------+----------+-----------+------+--------------------+
|Real Estate Acqui...|https://www.upwor...|2024-05-20 19:14:...|     true|         4|          6|  null|       United States|
|WordPress Pagespe...|https://www.upwor...|2024-05-20 19:14:...|    false|      null|       null|    50|       United States|
|Environmental  + ...|https://www.upwor...|2024-05-20 19:14:...|    false|      null|       null|   100|       United States|
|Unity 2D/3D Game ...|https://www.upwor...|2024-05-20 19:14:...|     null|      null|       null|  null|            Pakistan|
|Senior Ionic Angu...|https://www.upwor...|2024-05-20 19:14:...|     true|        30|         45|  null|            Sl

Добавление второй таблицы к первой

In [50]:
all_upwork = all_upwork1.union(all_upwork2)

In [51]:
all_upwork.show()

+--------------------+--------------------+--------------------+---------+----------+-----------+------+--------------+
|               title|                link|      published_date|is_hourly|hourly_low|hourly_high|budget|       country|
+--------------------+--------------------+--------------------+---------+----------+-----------+------+--------------+
|Experienced Media...|https://www.upwor...|2024-02-17 09:09:...|    False|      null|       null| 500.0|          null|
|Full Stack Developer|https://www.upwor...|2024-02-17 09:09:...|    False|      null|       null|1100.0| United States|
|     SMMA Bubble App|https://www.upwor...|2024-02-17 09:08:...|     True|      10.0|       30.0|  null| United States|
|Talent Hunter Spe...|https://www.upwor...|2024-02-17 09:08:...|     True|      null|       null|  null| United States|
|       Data Engineer|https://www.upwor...|2024-02-17 09:07:...|    False|      null|       null| 650.0|         India|
|SEO for Portugues...|https://www.upwor.

Энкодим стобоце is_hourly и кастим остальные столцбы

In [52]:
all_upwork = all_upwork.replace(['False', 'True'], ['0', '1'], 'is_hourly')
data = all_upwork.select('is_hourly','hourly_low','hourly_high','budget')
data = data.select(*[col(c).cast("double").alias(c) for c in data.columns]).withColumn('id',monotonically_increasing_id())
data.show()


+---------+----------+-----------+------+---+
|is_hourly|hourly_low|hourly_high|budget| id|
+---------+----------+-----------+------+---+
|      0.0|      null|       null| 500.0|  0|
|      0.0|      null|       null|1100.0|  1|
|      1.0|      10.0|       30.0|  null|  2|
|      1.0|      null|       null|  null|  3|
|      0.0|      null|       null| 650.0|  4|
|      1.0|      null|       null|  null|  5|
|      0.0|      null|       null|   5.0|  6|
|      1.0|       7.0|       22.0|  null|  7|
|      1.0|      null|       null|  null|  8|
|      0.0|      null|       null| 500.0|  9|
|      0.0|      null|       null|  50.0| 10|
|      0.0|      null|       null|1200.0| 11|
|      1.0|      null|       null|  null| 12|
|      1.0|      40.0|       75.0|  null| 13|
|      0.0|      null|       null| 300.0| 14|
|      1.0|      30.0|       50.0|  null| 15|
|      0.0|      null|       null|   5.0| 16|
|      0.0|      null|       null|  20.0| 17|
|      1.0|      12.0|       30.0|

Убираем все записи, где все колонки null, так как в таком случае, непонятно, какая оплата

In [53]:
exceptrion_values = data.filter(data.hourly_low.isNull() &
                                data.hourly_high.isNull() & 
                                data.budget.isNull())



data = data.exceptAll(exceptrion_values).drop('id')

In [54]:
data.show()

+---------+----------+-----------+------+
|is_hourly|hourly_low|hourly_high|budget|
+---------+----------+-----------+------+
|      1.0|      15.0|       35.0|  null|
|      0.0|      null|       null| 100.0|
|      0.0|      null|       null| 300.0|
|      0.0|      null|       null| 500.0|
|      1.0|      15.0|       30.0|  null|
|      0.0|      null|       null|  18.0|
|      0.0|      null|       null|  50.0|
|      1.0|       5.0|       10.0|  null|
|      1.0|      10.0|       25.0|  null|
|      1.0|      10.0|       40.0|  null|
|      1.0|      15.0|       30.0|  null|
|      1.0|       5.0|       25.0|  null|
|      0.0|      null|       null|2000.0|
|      1.0|       3.0|       11.0|  null|
|      0.0|      null|       null|  20.0|
|      1.0|      20.0|       30.0|  null|
|      1.0|      35.0|       50.0|  null|
|      0.0|      null|       null|  50.0|
|      1.0|       5.0|       null|  null|
|      0.0|      null|       null| 200.0|
+---------+----------+-----------+

Предсказание бюджета

In [55]:


assembler = VectorAssembler(inputCols=['is_hourly','hourly_low','hourly_high'], outputCol="features")
vec = assembler.transform(data.fillna(0)).select('budget', 'features')

maScaler = MaxAbsScaler()
maScaler.setInputCol("features")
model = maScaler.fit(vec)
model.setOutputCol("features_scaled")
scaled_data = model.transform(vec).drop('features')


# Создание модели линейной регрессии
rf = RandomForestRegressor(featuresCol="features_scaled", labelCol="budget")


# Разделение данных на обучающий и тестовый набор
train_data, test_data = scaled_data.randomSplit([0.8, 0.2])

# Обучение модели и оценка качества
model = rf.fit(train_data)
predictions = model.transform(test_data)


rmse = RegressionEvaluator(labelCol='budget', metricName='rmse').evaluate(predictions)


r2 = RegressionEvaluator(labelCol='budget', metricName='r2').evaluate(predictions)


print("Root Mean Squared Error: %.4f" % rmse)
print("r2: %.4f" % r2)

Root Mean Squared Error: 11471.6096
r2: 0.0016


In [60]:
data.show()

+---------+----------+-----------+------+
|is_hourly|hourly_low|hourly_high|budget|
+---------+----------+-----------+------+
|      1.0|      15.0|       35.0|  null|
|      0.0|      null|       null| 100.0|
|      0.0|      null|       null| 300.0|
|      0.0|      null|       null| 500.0|
|      1.0|      15.0|       30.0|  null|
|      0.0|      null|       null|  18.0|
|      0.0|      null|       null|  50.0|
|      1.0|       5.0|       10.0|  null|
|      1.0|      10.0|       25.0|  null|
|      1.0|      10.0|       40.0|  null|
|      1.0|      15.0|       30.0|  null|
|      1.0|       5.0|       25.0|  null|
|      0.0|      null|       null|2000.0|
|      1.0|       3.0|       11.0|  null|
|      0.0|      null|       null|  20.0|
|      1.0|      20.0|       30.0|  null|
|      1.0|      35.0|       50.0|  null|
|      0.0|      null|       null|  50.0|
|      1.0|       5.0|       null|  null|
|      0.0|      null|       null| 200.0|
+---------+----------+-----------+

2) Кластеризация по зарплатам

In [61]:


# Преобразование ключевых слов в числовой формат
assembler = VectorAssembler(inputCols=['is_hourly','hourly_low','hourly_high'], outputCol="features")
vec = assembler.transform(data.fillna(0))
train_data, test_data = vec.randomSplit([0.7, 0.3])

# Создание модели KMeans для кластеризации
kmeans = KMeans(k=5, featuresCol="features")


# Обучение модели кластеризации
model = kmeans.fit(train_data.fillna(0))

# Предсказание кластера для каждой вакансии
clustered_data = model.transform(test_data)

# Вывод кластера для каждой вакансии
clustered_data.show()

evaluator = ClusteringEvaluator()
silhouette = evaluator.evaluate(clustered_data)
print("Silhouette score = ", silhouette)

+---------+----------+-----------+------+---------+----------+
|is_hourly|hourly_low|hourly_high|budget| features|prediction|
+---------+----------+-----------+------+---------+----------+
|      0.0|       0.0|        0.0|   5.0|(3,[],[])|         1|
|      0.0|       0.0|        0.0|   5.0|(3,[],[])|         1|
|      0.0|       0.0|        0.0|   5.0|(3,[],[])|         1|
|      0.0|       0.0|        0.0|   5.0|(3,[],[])|         1|
|      0.0|       0.0|        0.0|   5.0|(3,[],[])|         1|
|      0.0|       0.0|        0.0|   5.0|(3,[],[])|         1|
|      0.0|       0.0|        0.0|   5.0|(3,[],[])|         1|
|      0.0|       0.0|        0.0|   5.0|(3,[],[])|         1|
|      0.0|       0.0|        0.0|   5.0|(3,[],[])|         1|
|      0.0|       0.0|        0.0|   5.0|(3,[],[])|         1|
|      0.0|       0.0|        0.0|   5.0|(3,[],[])|         1|
|      0.0|       0.0|        0.0|   5.0|(3,[],[])|         1|
|      0.0|       0.0|        0.0|   5.0|(3,[],[])|    

Классифкатор типа зарплаты по признакам: высокая оплата в час, низкая почасовая оплата,  и имеющийся бюджет компании - признаки,
Таргет - is_hourly

In [63]:

assembler = VectorAssembler(inputCols=["hourly_low", "hourly_high", "budget"], outputCol="features")
vec = assembler.transform(data.fillna(0)).select('is_hourly', 'features')

# Создание модели GBTClassifier для классификации типа зарплаты
rf = GBTClassifier(featuresCol="features", labelCol="is_hourly")

# Разделение данных на обучающий и тестовый набор
train_data, test_data = vec.randomSplit([0.8, 0.2])

# Обучение модели и проверка точности классификации
model = rf.fit(train_data)
predictions = model.transform(test_data)


auc = BinaryClassificationEvaluator(labelCol='is_hourly', metricName='areaUnderROC').evaluate(predictions)
print("Area Under ROC: %.4f" % auc)

accuracy = MulticlassClassificationEvaluator(labelCol="is_hourly", predictionCol="prediction", metricName="accuracy").evaluate(predictions)
f1 = MulticlassClassificationEvaluator(labelCol="is_hourly", predictionCol="prediction", metricName="f1").evaluate(predictions)
print("accuracy = %g" % accuracy)
print("f1 = %g" % f1)

Area Under ROC: 0.8432
accuracy = 0.775422
f1 = 0.70018


Следующий датасет

In [64]:
mibici = spark.read.format("csv").\
                option("delimiter", ",").\
                option("header","true").\
                option("encoding", "cp1251").\
                load("mibici_2014-2024.csv")

In [65]:
mibici.printSchema()

root
 |-- _c0: string (nullable = true)
 |-- Trip_Id: string (nullable = true)
 |-- User_Id: string (nullable = true)
 |-- Sex: string (nullable = true)
 |-- Birth_year: string (nullable = true)
 |-- Trip_start: string (nullable = true)
 |-- Trip_end: string (nullable = true)
 |-- Origin_Id: string (nullable = true)
 |-- Destination_Id: string (nullable = true)
 |-- Age: string (nullable = true)
 |-- Duration: string (nullable = true)



In [66]:
mibici.show()

+---+--------+-------+---+----------+-------------------+-------------------+---------+--------------+---+---------------+
|_c0| Trip_Id|User_Id|Sex|Birth_year|         Trip_start|           Trip_end|Origin_Id|Destination_Id|Age|       Duration|
+---+--------+-------+---+----------+-------------------+-------------------+---------+--------------+---+---------------+
|  0|32244893|1470734|  M|      1981|2024-01-31 23:59:33|2024-02-01 00:11:15|       24|            86| 43|0 days 00:11:42|
|  1|32244892|2731702|  M|      1994|2024-01-31 23:59:06|2024-02-01 00:10:49|       48|           279| 30|0 days 00:11:43|
|  2|32244891|1431452|  M|      2001|2024-01-31 23:58:48|2024-02-01 00:01:42|      273|           383| 23|0 days 00:02:54|
|  3|32244890|2312602|  F|      2003|2024-01-31 23:58:44|2024-02-01 00:01:58|      273|           383| 21|0 days 00:03:14|
|  4|32244889|2266427|  M|      1999|2024-01-31 23:58:44|2024-02-01 00:01:39|      273|           383| 25|0 days 00:02:55|
|  5|32244888|10

Кодируем половые признаки, и также кастим нужные столбцы 

In [67]:
mibici = mibici.replace(['M','F'],['0','1'], 'Sex')

mibici = mibici.withColumn("day_duration", col("Duration").substr(0, 1)).drop("Duration")


mibici = mibici.withColumn("Trip_start", to_timestamp("Trip_start", "yyyy-MM-dd HH:mm:ss"))\
               .withColumn("Trip_end", to_timestamp("Trip_end", "yyyy-MM-dd HH:mm:ss"))\
               .withColumn("Birth_year", mibici.Birth_year.cast('int'))\
               .withColumn("Age", mibici.Age.cast('int'))\
               .withColumn("Sex", mibici.Sex.cast('int'))\
               .withColumn("day_duration", mibici.day_duration.cast('int'))


In [None]:
mibici.show()

+---+--------+-------+---+----------+-------------------+-------------------+---------+--------------+---+------------+
|_c0| Trip_Id|User_Id|Sex|Birth_year|         Trip_start|           Trip_end|Origin_Id|Destination_Id|Age|day_duration|
+---+--------+-------+---+----------+-------------------+-------------------+---------+--------------+---+------------+
|  0|32244893|1470734|  0|      1981|2024-01-31 23:59:33|2024-02-01 00:11:15|       24|            86| 43|           0|
|  1|32244892|2731702|  0|      1994|2024-01-31 23:59:06|2024-02-01 00:10:49|       48|           279| 30|           0|
|  2|32244891|1431452|  0|      2001|2024-01-31 23:58:48|2024-02-01 00:01:42|      273|           383| 23|           0|
|  3|32244890|2312602|  1|      2003|2024-01-31 23:58:44|2024-02-01 00:01:58|      273|           383| 21|           0|
|  4|32244889|2266427|  0|      1999|2024-01-31 23:58:44|2024-02-01 00:01:39|      273|           383| 25|           0|
|  5|32244888|1071506|  0|      1964|202

In [None]:
mibici.printSchema()

root
 |-- _c0: string (nullable = true)
 |-- Trip_Id: string (nullable = true)
 |-- User_Id: string (nullable = true)
 |-- Sex: integer (nullable = true)
 |-- Birth_year: integer (nullable = true)
 |-- Trip_start: timestamp (nullable = true)
 |-- Trip_end: timestamp (nullable = true)
 |-- Origin_Id: string (nullable = true)
 |-- Destination_Id: string (nullable = true)
 |-- Age: integer (nullable = true)
 |-- day_duration: integer (nullable = true)



In [None]:
mibici = mibici.withColumn('Difference_hour',(col("Trip_end").cast("long") - col('Trip_start').cast("long"))/ 3600)
mibici = mibici.withColumn('Difference_minutes',col("Difference_hour")*60)
mibici = mibici.withColumn('Difference_seconds',col("Difference_hour")*3600)

In [None]:
mibici.show()

+---+--------+-------+---+----------+-------------------+-------------------+---------+--------------+---+------------+-------------------+------------------+------------------+
|_c0| Trip_Id|User_Id|Sex|Birth_year|         Trip_start|           Trip_end|Origin_Id|Destination_Id|Age|day_duration|    Difference_hour|Difference_minutes|Difference_seconds|
+---+--------+-------+---+----------+-------------------+-------------------+---------+--------------+---+------------+-------------------+------------------+------------------+
|  0|32244893|1470734|  0|      1981|2024-01-31 23:59:33|2024-02-01 00:11:15|       24|            86| 43|           0|              0.195|11.700000000000001|             702.0|
|  1|32244892|2731702|  0|      1994|2024-01-31 23:59:06|2024-02-01 00:10:49|       48|           279| 30|           0|0.19527777777777777|11.716666666666667|             703.0|
|  2|32244891|1431452|  0|      2001|2024-01-31 23:58:48|2024-02-01 00:01:42|      273|           383| 23|    

Классификация пользователей по полу в заивимсоти от возраста и продолжительности поездки

In [None]:


assembler = VectorAssembler(inputCols=["Age", "day_duration","Difference_hour"], outputCol="features")
data = assembler.transform(mibici)

# Создание модели логистической регрессии для классификации пола
lr = GBTClassifier(featuresCol="features", labelCol="Sex")

# Разделение данных на обучающий и тестовый набор
train_data, test_data = data.randomSplit([0.8, 0.2])

# Обучение модели и оценка качества
model = lr.fit(train_data)
predictions = model.transform(test_data)

accuracy = MulticlassClassificationEvaluator(labelCol="Sex", predictionCol="prediction", metricName="accuracy").evaluate(predictions)
f1 = MulticlassClassificationEvaluator(labelCol="Sex", predictionCol="prediction", metricName="f1").evaluate(predictions)
auc = BinaryClassificationEvaluator(labelCol='Sex', metricName='areaUnderROC').evaluate(predictions)

print("accuracy = %g" % accuracy)
print("f1 = %g" % f1)
print("Area Under ROC: %.4f" % auc)

accuracy = 0.740672
f1 = 0.630325
Area Under ROC: 0.5973


По полу и количеству времени в пути будет предсказывать пол

In [None]:
assembler = VectorAssembler(inputCols=["day_duration",'Sex','Difference_hour'], outputCol="features")
data = assembler.transform(mibici)

scaler = RobustScaler()
scaler.setInputCol("features")
model = scaler.fit(data)
model.setOutputCol("features_scaled")
scaled_data = model.transform(data).drop('features')


In [None]:
# Создание модели линейной регрессии для предсказания возраста
lr = DecisionTreeRegressor(featuresCol="features_scaled", labelCol="Age")

train_data, test_data = scaled_data.randomSplit([0.8, 0.2])

# Обучение модели
model = lr.fit(train_data)

# Предсказание возраста на основе продолжительности поездки
predictions = model.transform(test_data)

rmse = RegressionEvaluator(labelCol='Age', metricName='rmse').evaluate(predictions)
r2 = RegressionEvaluator(labelCol='Age', metricName='r2').evaluate(predictions)


print("Root Mean Squared Error: %.4f" % rmse)
print("r2: %.4f" % r2)

Root Mean Squared Error: 10.2558
r2: 0.0171


Кластеризация пользователей на основе возраста и продолжительности поездки

In [None]:


assembler = VectorAssembler(inputCols=["Age","Difference_hour"], outputCol="features")
data = assembler.transform(mibici)

train_data, test_data = data.randomSplit([0.8, 0.2])

kmeans = KMeans(k=3, featuresCol="features")

# Обучение модели кластеризации
model = kmeans.fit(train_data)




In [None]:
predictions = model.transform(test_data)
silhouette = ClusteringEvaluator().evaluate(predictions)

print("Silhouette score = ", silhouette)

Silhouette score =  0.7833634390733011


Переходим к последнему датасету

In [102]:
sqlite_file = r'database.sqlite'

country = spark.read.format("jdbc") \
           .option("url", f"jdbc:sqlite:{sqlite_file}") \
           .option("dbtable", "country") \
           .load() 
league = spark.read.format("jdbc") \
           .option("url", f"jdbc:sqlite:{sqlite_file}") \
           .option("dbtable", "League") \
           .load() 
match = spark.read.format("jdbc") \
           .option("url", f"jdbc:sqlite:{sqlite_file}") \
           .option("dbtable", "Match") \
           .load() 
player = spark.read.format("jdbc") \
           .option("url", f"jdbc:sqlite:{sqlite_file}") \
           .option("dbtable", "Player") \
           .load() 
player_attr = spark.read.format("jdbc") \
           .option("url", f"jdbc:sqlite:{sqlite_file}") \
           .option("dbtable", "Player_Attributes") \
           .load() 
 
team = spark.read.format("jdbc") \
           .option("url", f"jdbc:sqlite:{sqlite_file}") \
           .option("dbtable", "Team") \
           .load() 
 
team_attr = spark.read.format("jdbc") \
           .option("url", f"jdbc:sqlite:{sqlite_file}") \
           .option("dbtable", "Team_Attributes") \
           .load() 


In [103]:
league.show()

+-----+----------+--------------------+
|   id|country_id|                name|
+-----+----------+--------------------+
|    1|         1|Belgium Jupiler L...|
| 1729|      1729|England Premier L...|
| 4769|      4769|      France Ligue 1|
| 7809|      7809|Germany 1. Bundes...|
|10257|     10257|       Italy Serie A|
|13274|     13274|Netherlands Eredi...|
|15722|     15722|  Poland Ekstraklasa|
|17642|     17642|Portugal Liga ZON...|
|19694|     19694|Scotland Premier ...|
|21518|     21518|     Spain LIGA BBVA|
|24558|     24558|Switzerland Super...|
+-----+----------+--------------------+



In [104]:
match.show()

+---+----------+---------+---------+-----+-------------------+------------+----------------+----------------+--------------+--------------+--------------+--------------+--------------+--------------+--------------+--------------+--------------+--------------+--------------+---------------+---------------+--------------+--------------+--------------+--------------+--------------+--------------+--------------+--------------+--------------+---------------+---------------+--------------+--------------+--------------+--------------+--------------+--------------+--------------+--------------+--------------+---------------+---------------+--------------+--------------+--------------+--------------+--------------+--------------+--------------+--------------+--------------+---------------+---------------+-------------+-------------+-------------+-------------+-------------+-------------+-------------+-------------+-------------+--------------+--------------+-------------+-------------+---------

Смотрим сколько есть уникальных голов есть для home_team

In [105]:
match.select('home_team_goal').distinct().count()


11

In [106]:
country = country.withColumn('EncodeCountry', monotonically_increasing_id()).drop('name')
match = match.join(country, country.id == match.country_id)

In [107]:
match.show()

+-----+----------+---------+---------+-----+-------------------+------------+----------------+----------------+--------------+--------------+--------------+--------------+--------------+--------------+--------------+--------------+--------------+--------------+--------------+---------------+---------------+--------------+--------------+--------------+--------------+--------------+--------------+--------------+--------------+--------------+---------------+---------------+--------------+--------------+--------------+--------------+--------------+--------------+--------------+--------------+--------------+---------------+---------------+--------------+--------------+--------------+--------------+--------------+--------------+--------------+--------------+--------------+---------------+---------------+-------------+-------------+-------------+-------------+-------------+-------------+-------------+-------------+-------------+--------------+--------------+-------------+-------------+-------

In [46]:
player.show()

+---+-------------+--------------------+------------------+-------------------+------+------+
| id|player_api_id|         player_name|player_fifa_api_id|           birthday|height|weight|
+---+-------------+--------------------+------------------+-------------------+------+------+
|  1|       505942|  Aaron Appindangoye|            218353|1992-02-29 00:00:00|   182|   187|
|  2|       155782|     Aaron Cresswell|            189615|1989-12-15 00:00:00|   170|   146|
|  3|       162549|         Aaron Doran|            186170|1991-05-13 00:00:00|   170|   163|
|  4|        30572|       Aaron Galindo|            140161|1982-05-08 00:00:00|   182|   198|
|  5|        23780|        Aaron Hughes|             17725|1979-11-08 00:00:00|   182|   154|
|  6|        27316|          Aaron Hunt|            158138|1986-09-04 00:00:00|   182|   161|
|  7|       564793|          Aaron Kuhl|            221280|1996-01-30 00:00:00|   172|   146|
|  8|        30895|        Aaron Lennon|            152747|1

Уникальных имен достаточно много, чтобы их использовать в качестве фиче
Хотя, почему бы и нет? Это же большие даныне

In [47]:
player.select('player_name').distinct().count()

10848

In [48]:
player_attr.show()

+---+------------------+-------------+-------------------+--------------+---------+--------------+-------------------+-------------------+--------+---------+----------------+-------------+-------+---------+-----+------------------+------------+------------+------------+------------+-------+---------+-------+----------+-------+-------+--------+----------+----------+-------------+-----------+------+---------+-------+---------------+--------------+---------+-----------+----------+--------------+-----------+
| id|player_fifa_api_id|player_api_id|               date|overall_rating|potential|preferred_foot|attacking_work_rate|defensive_work_rate|crossing|finishing|heading_accuracy|short_passing|volleys|dribbling|curve|free_kick_accuracy|long_passing|ball_control|acceleration|sprint_speed|agility|reactions|balance|shot_power|jumping|stamina|strength|long_shots|aggression|interceptions|positioning|vision|penalties|marking|standing_tackle|sliding_tackle|gk_diving|gk_handling|gk_kicking|gk_posit

In [56]:
full_player_info = player_attr.join(player.drop('id'), on='player_api_id').drop('player_fifa_api_id')



In [57]:
full_player_info.show()

+-------------+----+-------------------+--------------+---------+--------------+-------------------+-------------------+--------+---------+----------------+-------------+-------+---------+-----+------------------+------------+------------+------------+------------+-------+---------+-------+----------+-------+-------+--------+----------+----------+-------------+-----------+------+---------+-------+---------------+--------------+---------+-----------+----------+--------------+-----------+-----------------+-------------------+------+------+
|player_api_id|  id|               date|overall_rating|potential|preferred_foot|attacking_work_rate|defensive_work_rate|crossing|finishing|heading_accuracy|short_passing|volleys|dribbling|curve|free_kick_accuracy|long_passing|ball_control|acceleration|sprint_speed|agility|reactions|balance|shot_power|jumping|stamina|strength|long_shots|aggression|interceptions|positioning|vision|penalties|marking|standing_tackle|sliding_tackle|gk_diving|gk_handling|gk_

In [58]:
full_player_info.columns

['player_api_id',
 'id',
 'date',
 'overall_rating',
 'potential',
 'preferred_foot',
 'attacking_work_rate',
 'defensive_work_rate',
 'crossing',
 'finishing',
 'heading_accuracy',
 'short_passing',
 'volleys',
 'dribbling',
 'curve',
 'free_kick_accuracy',
 'long_passing',
 'ball_control',
 'acceleration',
 'sprint_speed',
 'agility',
 'reactions',
 'balance',
 'shot_power',
 'jumping',
 'stamina',
 'strength',
 'long_shots',
 'aggression',
 'interceptions',
 'positioning',
 'vision',
 'penalties',
 'marking',
 'standing_tackle',
 'sliding_tackle',
 'gk_diving',
 'gk_handling',
 'gk_kicking',
 'gk_positioning',
 'gk_reflexes',
 'player_name',
 'birthday',
 'height',
 'weight']

In [59]:
full_player_info = full_player_info.replace(['left','right'],['0','1'], 'preferred_foot')
full_player_info = full_player_info.withColumn('preferred_foot',full_player_info.preferred_foot.cast('integer'))

In [60]:
full_player_info.show()

+-------------+----+-------------------+--------------+---------+--------------+-------------------+-------------------+--------+---------+----------------+-------------+-------+---------+-----+------------------+------------+------------+------------+------------+-------+---------+-------+----------+-------+-------+--------+----------+----------+-------------+-----------+------+---------+-------+---------------+--------------+---------+-----------+----------+--------------+-----------+-----------------+-------------------+------+------+
|player_api_id|  id|               date|overall_rating|potential|preferred_foot|attacking_work_rate|defensive_work_rate|crossing|finishing|heading_accuracy|short_passing|volleys|dribbling|curve|free_kick_accuracy|long_passing|ball_control|acceleration|sprint_speed|agility|reactions|balance|shot_power|jumping|stamina|strength|long_shots|aggression|interceptions|positioning|vision|penalties|marking|standing_tackle|sliding_tackle|gk_diving|gk_handling|gk_

In [61]:
full_player_info.printSchema()

root
 |-- player_api_id: integer (nullable = true)
 |-- id: integer (nullable = true)
 |-- date: string (nullable = true)
 |-- overall_rating: integer (nullable = true)
 |-- potential: integer (nullable = true)
 |-- preferred_foot: integer (nullable = true)
 |-- attacking_work_rate: string (nullable = true)
 |-- defensive_work_rate: string (nullable = true)
 |-- crossing: integer (nullable = true)
 |-- finishing: integer (nullable = true)
 |-- heading_accuracy: integer (nullable = true)
 |-- short_passing: integer (nullable = true)
 |-- volleys: integer (nullable = true)
 |-- dribbling: integer (nullable = true)
 |-- curve: integer (nullable = true)
 |-- free_kick_accuracy: integer (nullable = true)
 |-- long_passing: integer (nullable = true)
 |-- ball_control: integer (nullable = true)
 |-- acceleration: integer (nullable = true)
 |-- sprint_speed: integer (nullable = true)
 |-- agility: integer (nullable = true)
 |-- reactions: integer (nullable = true)
 |-- balance: integer (nullab

1) Предсказание overall_rating игрока


In [63]:
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.regression import RandomForestRegressor
from pyspark.ml.evaluation import RegressionEvaluator




feature_columns = ['potential', 'crossing', 'finishing', 'heading_accuracy', 
                   'short_passing', 'volleys', 'dribbling', 'curve',
                   'free_kick_accuracy', 'long_passing', 'ball_control', 
                   'acceleration', 'sprint_speed', 'agility', 'reactions', 
                   'balance', 'shot_power', 'jumping', 'stamina', 'strength', 
                   'long_shots', 'aggression', 'interceptions', 'positioning', 
                   'vision', 'penalties', 'marking', 'standing_tackle', 
                   'sliding_tackle', 'gk_diving', 'gk_handling', 'gk_kicking', 
                   'gk_positioning', 'gk_reflexes' 
                 ]

assembler = VectorAssembler(inputCols=feature_columns, outputCol="features")
data = assembler.transform(full_player_info.dropna().drop('id'))

train_data, test_data = data.randomSplit([0.8, 0.2], seed=123)


rf = RandomForestRegressor(featuresCol="features", labelCol="overall_rating")


model = rf.fit(train_data)


predictions = model.transform(test_data)


rmse = RegressionEvaluator(labelCol="overall_rating", predictionCol="prediction", metricName="rmse").evaluate(predictions)
r2 = RegressionEvaluator(labelCol="overall_rating", predictionCol="prediction", metricName="r2").evaluate(predictions)
print(f"Root Mean Squared Error (RMSE) on test data: {rmse}")
print(f"R2 on test data: {r2}")

predictions.select("overall_rating", "prediction").show(10)

Root Mean Squared Error (RMSE) on test data: 2.72145946383088
R2 on test data: 0.8505128176814065
+--------------+-----------------+
|overall_rating|       prediction|
+--------------+-----------------+
|            60| 59.7190923345691|
|            58|57.95269625585803|
|            61| 59.4405084456727|
|            61| 59.4405084456727|
|            68| 66.1781610086714|
|            70|69.74202367823763|
|            72| 70.0966439928736|
|            72| 70.0966439928736|
|            72|69.86123890186602|
|            72|69.86123890186602|
+--------------+-----------------+
only showing top 10 rows



2) Определение предпочтительной ноги игрока (preferred_foot)

In [66]:
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator



feature_columns = ['potential', 'crossing', 'finishing', 'heading_accuracy', 
                   'short_passing', 'volleys', 'dribbling', 'curve', 
                   'free_kick_accuracy', 'long_passing', 'ball_control', 
                   'acceleration', 'sprint_speed', 'agility', 'reactions', 
                   'balance', 'shot_power', 'jumping', 'stamina', 'strength', 
                   'long_shots', 'aggression', 'interceptions', 'positioning', 
                   'vision', 'penalties', 'marking', 'standing_tackle', 
                   'sliding_tackle', 'gk_diving', 'gk_handling', 'gk_kicking', 
                   'gk_positioning', 'gk_reflexes']

assembler = VectorAssembler(inputCols=feature_columns, outputCol="features")
data = assembler.transform(full_player_info.dropna().drop('id'))


train_data, test_data = data.randomSplit([0.8, 0.2], seed=123)


rf = RandomForestClassifier(featuresCol="features", labelCol="preferred_foot")


model = rf.fit(train_data)

predictions = model.transform(test_data)

accuracy = MulticlassClassificationEvaluator(labelCol="preferred_foot", predictionCol="prediction", metricName="accuracy").evaluate(predictions)
f1 = MulticlassClassificationEvaluator(labelCol="preferred_foot", predictionCol="prediction", metricName="f1").evaluate(predictions)
print(f"Accuracy on test data: {accuracy}")
print(f"f1 on test data: {f1}")




Accuracy on test data: 0.7569558328932373
f1 on test data: 0.6522442081057279


In [90]:
team.show()

+----+-----------+----------------+--------------------+---------------+
|  id|team_api_id|team_fifa_api_id|      team_long_name|team_short_name|
+----+-----------+----------------+--------------------+---------------+
|   1|       9987|             673|            KRC Genk|            GEN|
|   2|       9993|             675|        Beerschot AC|            BAC|
|   3|      10000|           15005|    SV Zulte-Waregem|            ZUL|
|   4|       9994|            2007|    Sporting Lokeren|            LOK|
|   5|       9984|            1750|   KSV Cercle Brugge|            CEB|
|   6|       8635|             229|      RSC Anderlecht|            AND|
|   7|       9991|             674|            KAA Gent|            GEN|
|   8|       9998|            1747|           RAEC Mons|            MON|
|   9|       7947|            null|       FCV Dender EH|            DEN|
|  10|       9985|             232|   Standard de Liège|            STL|
|  11|       8203|          110724|         KV Mech

In [91]:
team_attr.show()

+---+----------------+-----------+-------------------+----------------+---------------------+--------------------+-------------------------+------------------+-----------------------+---------------------------+---------------------+--------------------------+----------------------+---------------------------+----------------------+---------------------------+------------------------------+---------------+--------------------+-----------------+----------------------+----------------+---------------------+------------------------+
| id|team_fifa_api_id|team_api_id|               date|buildUpPlaySpeed|buildUpPlaySpeedClass|buildUpPlayDribbling|buildUpPlayDribblingClass|buildUpPlayPassing|buildUpPlayPassingClass|buildUpPlayPositioningClass|chanceCreationPassing|chanceCreationPassingClass|chanceCreationCrossing|chanceCreationCrossingClass|chanceCreationShooting|chanceCreationShootingClass|chanceCreationPositioningClass|defencePressure|defencePressureClass|defenceAggression|defenceAggressio

In [92]:
full_team_info = team_attr.join(team, on = 'team_api_id')

In [93]:
full_team_info.show()

+-----------+---+----------------+-------------------+----------------+---------------------+--------------------+-------------------------+------------------+-----------------------+---------------------------+---------------------+--------------------------+----------------------+---------------------------+----------------------+---------------------------+------------------------------+---------------+--------------------+-----------------+----------------------+----------------+---------------------+------------------------+----+----------------+--------------------+---------------+
|team_api_id| id|team_fifa_api_id|               date|buildUpPlaySpeed|buildUpPlaySpeedClass|buildUpPlayDribbling|buildUpPlayDribblingClass|buildUpPlayPassing|buildUpPlayPassingClass|buildUpPlayPositioningClass|chanceCreationPassing|chanceCreationPassingClass|chanceCreationCrossing|chanceCreationCrossingClass|chanceCreationShooting|chanceCreationShootingClass|chanceCreationPositioningClass|defencePress

3) Предсказание defencePressure

In [94]:
from pyspark.ml.feature import StringIndexer, OneHotEncoder
from pyspark.ml import Pipeline
from pyspark.ml.feature import VectorAssembler

columns =  ['buildUpPlaySpeedClass', 'buildUpPlayDribblingClass', 'buildUpPlayPassingClass',
             'buildUpPlayPositioningClass', 'chanceCreationPassingClass', 'chanceCreationCrossingClass',
             'chanceCreationShootingClass','chanceCreationPositioningClass', 'defencePressureClass', 
             'defenceAggressionClass', 'defenceTeamWidthClass', 'defenceDefenderLineClass']


indexers = [StringIndexer(inputCol=col, outputCol=col+"_index", handleInvalid="keep") for col in columns]

# Создание объектов OneHotEncoder для кодирования числовых значений
encoders = [OneHotEncoder(inputCol=col+"_index", outputCol=col+"_encoded") for col in columns]

# Создание объекта VectorAssembler для объединения признаков в один вектор
assembler = VectorAssembler(inputCols=[col+"_encoded" for col in columns], outputCol="features")

pipeline = Pipeline(stages=indexers + encoders + [assembler])

# Применение преобразований к данным
model = pipeline.fit(full_team_info)

transformed_data = model.transform(full_team_info)

In [95]:
from pyspark.ml.regression import LinearRegression


train_data, test_data = transformed_data.randomSplit([0.7, 0.3])

regr = LinearRegression(featuresCol='features', labelCol='defencePressure', predictionCol='prediction')

# Обучение модели
regr_model = regr.fit(train_data)

# Получение прогнозов
predictions = regr_model.transform(test_data)
predictions.select('defencePressure', 'prediction').show()

+---------------+------------------+
|defencePressure|        prediction|
+---------------+------------------+
|             32|31.096655945355057|
|             40| 46.68363670527772|
|             57|46.201308089748586|
|             50| 48.07847524104754|
|             47| 49.25543767082076|
|             70| 67.84079820522007|
|             48| 43.80425759247473|
|             59|47.195271313610206|
|             59|47.195271313610206|
|             47| 45.74561705564787|
|             45| 45.37034568036416|
|             36| 45.37034568036416|
|             47| 45.37034568036416|
|             42| 45.74561705564787|
|             45|   45.511659468801|
|             36| 45.74561705564787|
|             46| 45.37034568036416|
|             58| 45.74561705564787|
|             45| 45.37034568036416|
|             43| 45.37034568036416|
+---------------+------------------+
only showing top 20 rows



In [96]:
from pyspark.ml.evaluation import RegressionEvaluator
rmse = RegressionEvaluator(labelCol='defencePressure', metricName='rmse').evaluate(predictions)
r2 = RegressionEvaluator(labelCol='defencePressure', metricName='r2').evaluate(predictions)


print("Root Mean Squared Error: %.4f" % rmse)
print("r2: %.4f" % r2)

Root Mean Squared Error: 6.8905
r2: 0.5376


На этом всё, спасибо!