# Лаб4. Прогнозирование пола и возрастной категории — Spark Streaming
Задание: https://github.com/newprolab/sber-spark-ds-18/blob/main/labs/lab04.md

### Библиотеки и контекст

In [1]:
# Point findspark at the local Spark installation so that `pyspark` can be
# imported by the cells below.
import findspark
findspark.init('/opt/spark-3.4.3/')

In [2]:
from pyspark import SparkConf, SparkContext

from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import *

from pyspark.ml.feature import RegexTokenizer, StopWordsRemover, HashingTF, CountVectorizer, IDF

from pyspark.ml.clustering import GaussianMixture
from pyspark.ml.evaluation import ClusteringEvaluator
from pyspark.ml import Pipeline

# Spark session for this lab. SparkConf() is built without arguments, so no
# settings are set explicitly in code here.
conf = SparkConf()
spark = SparkSession.builder.config(conf=conf).appName('max_burdasov_lab4').getOrCreate()

# Handle to the underlying SparkContext
sc = spark.sparkContext

https://repos.spark-packages.org/ added as a remote repository with the name: repo-1
Ivy Default Cache set to: /data/home/maksim.burdasov/.ivy2/cache
The jars for the packages stored in: /data/home/maksim.burdasov/.ivy2/jars
org.apache.spark#spark-sql-kafka-0-10_2.12 added as a dependency
:: resolving dependencies :: org.apache.spark#spark-submit-parent-03cef50e-6d1e-410c-8ab8-ecf4088989f8;1.0
	confs: [default]


:: loading settings :: url = jar:file:/opt/spark-3.4.3/jars/ivy-2.5.1.jar!/org/apache/ivy/core/settings/ivysettings.xml


	found org.apache.spark#spark-sql-kafka-0-10_2.12;3.4.3 in central
	found org.apache.spark#spark-token-provider-kafka-0-10_2.12;3.4.3 in central
	found org.apache.kafka#kafka-clients;3.3.2 in central
	found org.lz4#lz4-java;1.8.0 in central
	found org.xerial.snappy#snappy-java;1.1.10.5 in central
	found org.slf4j#slf4j-api;2.0.6 in central
	found org.apache.hadoop#hadoop-client-runtime;3.3.4 in central
	found org.apache.hadoop#hadoop-client-api;3.3.4 in central
	found commons-logging#commons-logging;1.1.3 in central
	found com.google.code.findbugs#jsr305;3.0.0 in central
	found org.apache.commons#commons-pool2;2.11.1 in central
:: resolution report :: resolve 488ms :: artifacts dl 13ms
	:: modules in use:
	com.google.code.findbugs#jsr305;3.0.0 from central in [default]
	commons-logging#commons-logging;1.1.3 from central in [default]
	org.apache.commons#commons-pool2;2.11.1 from central in [default]
	org.apache.hadoop#hadoop-client-api;3.3.4 from central in [default]
	org.apache.hadoop#

In [14]:
# sc.stop()

### Загрузка данных

In [4]:
# List the lab input directory on HDFS (IPython shell escape, not plain Python).
!hdfs dfs -ls /labs/slaba04/

Found 1 items
-rw-r--r--   3 hdfs hdfs  655090069 2022-01-06 18:46 /labs/slaba04/gender_age_dataset.txt


In [42]:
# Read the tab-separated lab dataset from HDFS (header row, inferred column types).
df = (
    spark
    .read
    .format('csv')
    .option("delimiter", "\t")
    .option('header', 'true')
    .option('inferSchema', 'true')
    .load('/labs/slaba04/gender_age_dataset.txt')
)

# Keep only rows where BOTH targets are known ('-' marks a missing value).
# With `|` a row with just one known target would survive, and its '-' would later be
# silently encoded as a real class by the target-encoding step (e.g. `.otherwise(...)`).
df = df.filter((F.col('age') != '-') & (F.col('gender') != '-'))

df = df.cache()
df.count()

                                                                                

36138

In [12]:
# Sanity check: every uid must occur exactly once (expected result: 0 duplicated uids).
duplicated_uids = df.groupBy('uid').count().where(F.col('count') > 1)
duplicated_uids.count()

                                                                                

0

### Изучение данных

In [5]:
# Column names of the loaded dataset
df.schema.names

['gender', 'age', 'uid', 'user_json']

In [19]:
# Raw JSON payload of the first user, to see its structure
df.first()['user_json']

'{"visits": [{"url": "http://zebra-zoya.ru/200028-chehol-organayzer-dlja-macbook-11-grid-it.html?utm_campaign=397720794&utm_content=397729344&utm_medium=cpc&utm_source=begun", "timestamp": 1419688144068}, {"url": "http://news.yandex.ru/yandsearch?cl4url=chezasite.com/htc/htc-one-m9-delay-86327.html&lr=213&rpt=story", "timestamp": 1426666298001}, {"url": "http://www.sotovik.ru/news/240283-htc-one-m9-zaderzhivaetsja.html", "timestamp": 1426666298000}, {"url": "http://news.yandex.ru/yandsearch?cl4url=chezasite.com/htc/htc-one-m9-delay-86327.html&lr=213&rpt=story", "timestamp": 1426661722001}, {"url": "http://www.sotovik.ru/news/240283-htc-one-m9-zaderzhivaetsja.html", "timestamp": 1426661722000}]}'

### Генерация признаков

In [43]:
### Parse the user_json column

# Schema of the payload: {"visits": [{"url": ..., "timestamp": ...}, ...]}
visit_struct = StructType([
    StructField('url', StringType()),
    StructField('timestamp', StringType()),
])
log_schema = StructType([StructField('visits', ArrayType(visit_struct))])

df = (
    df
    # Parse the raw JSON string into a struct
    .withColumn('user_json', F.from_json(F.col('user_json'), log_schema))
    # Pull the url / timestamp arrays out into separate columns
    .withColumn('urls', F.col('user_json.visits.url'))
    .withColumn('timestamps', F.col('user_json.visits.timestamp'))
    # Number of visits per user
    .withColumn('urls_cnt', F.size('urls'))
    # The parsed struct is no longer needed
    .drop('user_json')
)

df = df.cache()
df.show(1)

[Stage 161:>                                                        (0 + 1) / 1]

+------+-----+--------------------+--------------------+--------------------+--------+
|gender|  age|                 uid|                urls|          timestamps|urls_cnt|
+------+-----+--------------------+--------------------+--------------------+--------+
|     F|18-24|d50192e5-c44e-4ae...|[http://zebra-zoy...|[1419688144068, 1...|       5|
+------+-----+--------------------+--------------------+--------------------+--------+
only showing top 1 row



                                                                                

In [44]:
### URL-list vectorisation pipeline

# Join each user's URLs into one space-separated string for the tokenizer
df = df.withColumn('urls', F.concat_ws(" ", 'urls'))

# Extract tokens: runs of 2+ word characters (\w also matches '_'), lower-cased
tokenizer = RegexTokenizer(
    inputCol="urls",
    outputCol="tokens_raw",
    pattern=r"\b[\w]{2,}\b",
    gaps=False,  # the pattern matches the tokens themselves, not the separators
    toLowercase=True
)

# Each word is listed once. One-character words are not listed: the tokenizer
# above never emits tokens shorter than 2 characters.
custom_stopwords = [
    # Generic URL service words
    'http', 'https', 'www', 'com', 'ru', 'net', 'org', 'html', 'php', 'asp', 'aspx', 'jsp',
    'utm', 'referrer', 'ref', 'source', 'click', 'id', 'page', 'index', 'feed', 'menu',
    'api', 'track', 'trackid', 'session', 'sid', 'token', 'auth', 'access', 'key', 'lang',
    'language', 'query', 'search', 'rpt', 'clid', 'utm_campaign', 'utm_medium', 'utm_source',
    'utm_content', 'utm_term', 'fbclid', 'gclid', 'mc_cid', 'mc_eid',

    # Frequent parameters and service words
    'view', 'item', 'category', 'product', 'offer', 'promo', 'redirect', 'default', 'main',

    # Generic words with little signal
    'type', 'action', 'mode', 'sessionid', 'userid', 'user',

    # Common host prefixes / technical path parts
    'www1', 'www2', 'www3', 'mobile', 'amp', 'cdn', 'static', 'cache',

    # Frequent short words
    'to', 'in', 'on', 'at', 'by', 'of', 'and', 'or', 'for', 'with', 'from'
]

# Spark's default English stop words plus the custom list; dict.fromkeys drops
# the overlap between the two while keeping first-seen order.
custom_stopwords = list(dict.fromkeys(
    StopWordsRemover.loadDefaultStopWords("english") + custom_stopwords
))
stopwords_remover = StopWordsRemover(
    inputCol="tokens_raw",
    outputCol="tokens",
    stopWords=custom_stopwords
)

# Term frequencies hashed into 1024 buckets
tf_vectorizer = HashingTF(numFeatures=1024, inputCol="tokens", outputCol="tf_features")

# IDF re-weighting of the term frequencies
idf_counter = IDF(inputCol='tf_features', outputCol='tfidf_features')

pipeline = Pipeline(stages=[tokenizer, stopwords_remover, tf_vectorizer, idf_counter])
# NOTE(review): IDF weights are fitted on the whole labelled dataset, before the
# train/test split made later; consider fitting on the training part only.
model = pipeline.fit(df)

# URL features together with the id and the raw targets
url_ftrs_df = (
    model
    .transform(df)
    .select(
        # Identifier
        'uid',
        # Feature columns
        'urls_cnt',
        'tfidf_features',
        # Target columns
        'gender',
        'age'
    )
)

url_ftrs_df = url_ftrs_df.cache()
url_ftrs_df.show(1)

[Stage 163:>                                                        (0 + 1) / 1]

+--------------------+--------+--------------------+------+-----+
|                 uid|urls_cnt|      tfidf_features|gender|  age|
+--------------------+--------+--------------------+------+-----+
|d50192e5-c44e-4ae...|       5|(1024,[17,33,37,7...|     F|18-24|
+--------------------+--------+--------------------+------+-----+
only showing top 1 row



                                                                                

In [45]:
### Time-of-day features from the visit timestamps

# One row per (uid, visit timestamp). No orderBy here: every aggregate below is
# order-independent, so a global sort would only add an expensive shuffle.
df_exploded = df.select('uid', F.explode(F.col('timestamps')).alias('ts'))

# Millisecond epoch -> 'yyyy-MM-dd HH:mm:ss' string (rendered in the session time zone)
df_exploded = df_exploded.withColumn("ts", F.from_unixtime(F.col("ts") / 1000))

# Calendar date and hour of day, used by the aggregates
df_exploded = (
    df_exploded
    .withColumn("date", F.to_date("ts"))
    .withColumn("hour", F.hour("ts"))
)

# 0/1 flags for the four 6-hour quarters of the day (both bounds inclusive)
day_quarters = [
    ("night_flg", 0, 5),
    ("morning_flg", 6, 11),
    ("afternoon_flg", 12, 17),
    ("evening_flg", 18, 23),
]
for flag_col, first_hour, last_hour in day_quarters:
    df_exploded = df_exploded.withColumn(
        flag_col,
        F.when(F.col("hour").between(first_hour, last_hour), F.lit(1)).otherwise(0)
    )

# Per-user aggregates: hour statistics, visits per active day, share of visits per quarter
visits_total = F.count('*')
time_ftrs_df = df_exploded.groupBy("uid").agg(
    F.mean(F.col('hour')).alias('mean_hour'),
    F.median(F.col('hour')).alias('median_hour'),
    (visits_total / F.count_distinct(F.col('date'))).alias('avg_urls_per_day'),
    (F.sum(F.col('night_flg')) * 100.0 / visits_total).alias('night_pct'),
    (F.sum(F.col('morning_flg')) * 100.0 / visits_total).alias('morning_pct'),
    (F.sum(F.col('afternoon_flg')) * 100.0 / visits_total).alias('afternoon_pct'),
    (F.sum(F.col('evening_flg')) * 100.0 / visits_total).alias('evening_pct'),
)

time_ftrs_df = time_ftrs_df.cache()
time_ftrs_df.show(3)



+--------------------+------------------+-----------+------------------+------------------+------------------+------------------+-----------------+
|                 uid|         mean_hour|median_hour|  avg_urls_per_day|         night_pct|       morning_pct|     afternoon_pct|      evening_pct|
+--------------------+------------------+-----------+------------------+------------------+------------------+------------------+-----------------+
|0392f398-ea7e-4a1...|15.716981132075471|       21.0| 5.170731707317073| 20.28301886792453| 9.433962264150944| 5.188679245283019|65.09433962264151|
|094b1e7e-97a6-441...| 15.96774193548387|       16.0|20.666666666666668|11.290322580645162|1.6129032258064515|53.225806451612904|33.87096774193548|
|095544a2-64f7-422...|              16.5|       17.0|               2.0|               0.0|              25.0|              50.0|             25.0|
|098a0e00-8597-475...|              18.0|       18.0|               2.0|               0.0|               0.0|  

                                                                                

In [46]:
# Attach the time features to the URL features. A left join keeps every user from
# url_ftrs_df; users without a match in time_ftrs_df get nulls in the time columns.
union_df = url_ftrs_df.join(time_ftrs_df, on='uid', how='left').cache()
union_df.show(n=1, truncate=True, vertical=True)



-RECORD 0--------------------------------
 uid              | 0392f398-ea7e-4a1... 
 urls_cnt         | 212                  
 tfidf_features   | (1024,[7,8,10,11,... 
 gender           | F                    
 age              | >=55                 
 mean_hour        | 15.716981132075471   
 median_hour      | 21.0                 
 avg_urls_per_day | 5.170731707317073    
 night_pct        | 20.28301886792453    
 morning_pct      | 9.433962264150944    
 afternoon_pct    | 5.188679245283019    
 evening_pct      | 65.09433962264151    
only showing top 1 row



                                                                                

### Подготовка таргета

#### Для обучения 2 отдельных моделей

In [19]:
# union_df.groupBy("gender").count().select("gender").rdd.flatMap(lambda x: x).collect()  # ['F', 'M']
# union_df.groupBy("age").count().select("age").rdd.flatMap(lambda x: x).collect()  # ['>=55', '45-54', '35-44', '25-34', '18-24']

In [47]:
### Encode both targets as integers

# Gender: F -> 0, M -> 1. Values are mapped explicitly; anything else (e.g. '-')
# becomes null instead of being silently counted as 'M'.
trg_prepared_df = union_df.withColumn(
    'target_gender',
    F.when(F.col('gender') == 'F', 0)
    .when(F.col('gender') == 'M', 1)
)

# Age bucket -> ordinal class 0..4. Unknown values become null instead of being
# silently counted as '>=55'.
trg_prepared_df = trg_prepared_df.withColumn(
    "target_age",
    F.when(F.col('age') == "18-24", 0)
    .when(F.col('age') == "25-34", 1)
    .when(F.col('age') == "35-44", 2)
    .when(F.col('age') == "45-54", 3)
    .when(F.col('age') == ">=55", 4)
)

trg_prepared_df.select('uid', 'gender', 'target_gender', 'age', 'target_age').show(5)

+--------------------+------+-------------+-----+----------+
|                 uid|gender|target_gender|  age|target_age|
+--------------------+------+-------------+-----+----------+
|0392f398-ea7e-4a1...|     F|            0| >=55|         4|
|094b1e7e-97a6-441...|     F|            0|18-24|         0|
|095544a2-64f7-422...|     F|            0|25-34|         1|
|098a0e00-8597-475...|     F|            0| >=55|         4|
|0a595fa1-bae0-41d...|     M|            1| >=55|         4|
+--------------------+------+-------------+-----+----------+
only showing top 5 rows



#### Для обучения 1 модели на мульти-таргете

In [31]:
# ### Трансформация таргета (для мультитаргета)

# data = data.withColumn("target_combined", F.concat(F.col("gender"), F.lit('__'), F.col("age")))

# # Замена таргета на число
# data = data.withColumn(
#     "label",  # название таргета из доки
#     # Комбинации таргетов
#     F.when(F.col('target_combined') == "F__18-24", 0)
#     .when(F.col('target_combined') == "F__25-34", 1)
#     .when(F.col('target_combined') == "F__35-44", 2)
#     .when(F.col('target_combined') == "F__45-54", 3)
#     .when(F.col('target_combined') == "F__>=55", 4)
#     .when(F.col('target_combined') == "M__18-24", 5)
#     .when(F.col('target_combined') == "M__25-34", 6)
#     .when(F.col('target_combined') == "M__35-44", 7)
#     .when(F.col('target_combined') == "M__45-54", 8)
#     .when(F.col('target_combined') == "M__>=55", 9)
#     .otherwise(None)
# )

# data.select('gender', 'age', 'target_combined', 'label').show()

In [11]:
# data.filter(F.col('label').isNull()).count()

                                                                                

0

### Подготовка итогового датасета + сохранение в HDFS

In [26]:
from pyspark.ml.feature import VectorAssembler

In [48]:
### Pack all features into a single vector column

# URL features first, then the time-of-day aggregates
url_feature_cols = ["urls_cnt", "tfidf_features"]
time_feature_cols = [
    "mean_hour",
    "median_hour",
    "avg_urls_per_day",
    "night_pct",
    "morning_pct",
    "afternoon_pct",
    "evening_pct",
]
feature_cols = url_feature_cols + time_feature_cols

assembler = VectorAssembler(outputCol="features", inputCols=feature_cols)
final_df = assembler.transform(trg_prepared_df)

final_df.show(n=1, truncate=True, vertical=True)

-RECORD 0--------------------------------
 uid              | 0392f398-ea7e-4a1... 
 urls_cnt         | 212                  
 tfidf_features   | (1024,[7,8,10,11,... 
 gender           | F                    
 age              | >=55                 
 mean_hour        | 15.716981132075471   
 median_hour      | 21.0                 
 avg_urls_per_day | 5.170731707317073    
 night_pct        | 20.28301886792453    
 morning_pct      | 9.433962264150944    
 afternoon_pct    | 5.188679245283019    
 evening_pct      | 65.09433962264151    
 target_gender    | 0                    
 target_age       | 4                    
 features         | (1032,[0,8,9,11,1... 
only showing top 1 row



In [49]:
# Keep only the feature vector and the two encoded targets
model_cols = ['features', 'target_gender', 'target_age']
final_df = final_df.select(*model_cols)

final_df.show(n=3)

+--------------------+-------------+----------+
|            features|target_gender|target_age|
+--------------------+-------------+----------+
|(1032,[0,8,9,11,1...|            0|         4|
|(1032,[0,2,5,8,14...|            0|         0|
|(1032,[0,106,424,...|            0|         1|
+--------------------+-------------+----------+
only showing top 3 rows



In [50]:
# Save the prepared dataset to HDFS as parquet, replacing any previous version
final_df.write.parquet('lab04_two_models_dataset_v1.parquet', mode='overwrite')

                                                                                

In [51]:
# List the user's HDFS home directory to confirm the parquet was written (IPython shell escape).
!hdfs dfs -ls

Found 9 items
drwx------   - maksim.burdasov maksim.burdasov          0 2025-05-05 18:28 .Trash
drwxr-xr-x   - maksim.burdasov maksim.burdasov          0 2025-05-05 17:23 .sparkStaging
drwxr-xr-x   - maksim.burdasov maksim.burdasov          0 2025-04-22 13:00 lab03.csv
drwxr-xr-x   - maksim.burdasov maksim.burdasov          0 2025-05-05 14:33 lab04_model_age_rf_v0
drwxr-xr-x   - maksim.burdasov maksim.burdasov          0 2025-05-05 17:34 lab04_model_age_rf_v1
drwxr-xr-x   - maksim.burdasov maksim.burdasov          0 2025-05-05 17:34 lab04_model_gender_rf_v1
drwxr-xr-x   - maksim.burdasov maksim.burdasov          0 2025-05-05 12:39 lab04_two_models_dataset.parquet
drwxr-xr-x   - maksim.burdasov maksim.burdasov          0 2025-05-05 18:34 lab04_two_models_dataset_v1.parquet
drwxr-xr-x   - maksim.burdasov maksim.burdasov          0 2025-04-22 11:57 models


In [16]:
# # Сохранение в файл
# final_df.toPandas().to_csv("lab04_data/clean_dataset_multitarget.csv", index=False, encoding="utf-8")

                                                                                

### Обучение моделей

In [33]:
from pyspark.ml.classification import DecisionTreeClassifier, RandomForestClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator

In [53]:
# Load the prepared dataset back from HDFS
two_models_dataset = spark.read.format('parquet').load('lab04_two_models_dataset_v1.parquet')

In [54]:
# 80/20 train/test split with a fixed seed for reproducibility
train_df, test_df = two_models_dataset.randomSplit(weights=[0.8, 0.2], seed=21)

#### Бинарный классификатор для предсказания пола

In [55]:
### Gender model (binary classification)

# v1: DecisionTreeClassifier (Accuracy: 0.5654, ROC-AUC: 0.4415) - superseded, kept for reference:
# gen_dtc = DecisionTreeClassifier(featuresCol="features", labelCol="target_gender")

# v2: RandomForestClassifier (Accuracy: 0.5877, ROC-AUC: 0.6147)
# NOTE: the name `gen_dtc` is historical (v1 was a decision tree); it now holds a random forest.
gen_dtc = RandomForestClassifier(labelCol="target_gender", featuresCol="features")
gen_model = gen_dtc.fit(train_df)

                                                                                

In [56]:
### Gender model metrics on the test split
gen_preds = gen_model.transform(test_df)

# Ranking quality computed from the raw (pre-threshold) scores
gen_bc_evaluator = BinaryClassificationEvaluator(
    metricName="areaUnderROC",
    labelCol="target_gender",
    rawPredictionCol="rawPrediction",
)
gen_roc_auc = gen_bc_evaluator.evaluate(gen_preds)
print(f"ROC-AUC: {gen_roc_auc:.4f}")

# Share of correctly predicted labels
gen_mc_evaluator = MulticlassClassificationEvaluator(
    metricName="accuracy",
    labelCol="target_gender",
    predictionCol="prediction",
)
gen_accuracy = gen_mc_evaluator.evaluate(gen_preds)
print(f"Accuracy: {gen_accuracy:.4f}")

                                                                                

ROC-AUC: 0.6147




Accuracy: 0.5877


                                                                                

In [57]:
# Keep the gender prediction under its own name and drop the model's score columns,
# so the age model can later add its own `prediction` column to this frame.
gen_preds = (
    gen_preds
    .withColumnRenamed('prediction', 'predicted_gender')
    .select('features', 'target_gender', 'target_age', 'predicted_gender')
)
gen_preds.show(n=5)

+--------------------+-------------+----------+----------------+
|            features|target_gender|target_age|predicted_gender|
+--------------------+-------------+----------+----------------+
|(1032,[0,447,619,...|            1|         2|             0.0|
|(1032,[0,1,6,13,1...|            0|         2|             1.0|
|(1032,[0,2,3,4,6,...|            0|         4|             1.0|
|(1032,[0,83,161,2...|            0|         1|             0.0|
|(1032,[0,13,24,38...|            1|         1|             0.0|
+--------------------+-------------+----------+----------------+
only showing top 5 rows



                                                                                

#### Многоклассовый классификатор для предсказания возраста (сравнение моделей)

In [43]:
# train_df, test_df = two_models_dataset.randomSplit([0.8, 0.2], seed=21)

In [66]:
### Проверка баланса классов
# train_df.groupBy('target_age').count().show()
# test_df.groupBy('target_age').count().show()

##### v1: DecisionTreeClassifier (Accuracy = 0.4280)

In [53]:
### Age model v1: decision tree over the 5 age classes

age_dtc = DecisionTreeClassifier(labelCol="target_age", featuresCol="features")
age_model = age_dtc.fit(train_df)

                                                                                

In [54]:
### Age model v1 metrics on the test split (age column, not gender)

age_preds = age_model.transform(test_df)

mc_evaluator = MulticlassClassificationEvaluator(
    metricName="accuracy",
    labelCol="target_age",
    predictionCol="prediction",
)
accuracy = mc_evaluator.evaluate(age_preds)
print(f"Accuracy = {accuracy:.4f}")



Accuracy = 0.4280


                                                                                

##### v2: RandomForestClassifier (Accuracy = 0.4359)

In [58]:
from pyspark.ml.classification import RandomForestClassifier

In [59]:
### Age model v2: random forest over the 5 age classes

age_rfc = RandomForestClassifier(labelCol="target_age", featuresCol="features")
age_model_2 = age_rfc.fit(train_df)

                                                                                

In [60]:
### Age model v2 metrics

# Score gen_preds rather than test_df: it is the test split with the first-stage
# `predicted_gender` column attached, so both predictions end up in one frame.
age_preds_2 = age_model_2.transform(gen_preds)

mc_evaluator_2 = MulticlassClassificationEvaluator(
    metricName="accuracy",
    labelCol="target_age",
    predictionCol="prediction",
)
accuracy_2 = mc_evaluator_2.evaluate(age_preds_2)
print(f"Accuracy = {accuracy_2:.4f}")



Accuracy = 0.4359


                                                                                

In [19]:
# Peek at the combined gender + age predictions
age_preds_2.show(n=3)

+--------------------+-------------+----------+----------------+--------------------+--------------------+----------+
|            features|target_gender|target_age|predicted_gender|       rawPrediction|         probability|prediction|
+--------------------+-------------+----------+----------------+--------------------+--------------------+----------+
|(1032,[0,1,2,3,4,...|            1|         3|             1.0|[1.78290559466698...|[0.08914527973334...|       1.0|
|(1032,[0,1,2,3,4,...|            1|         3|             1.0|[2.16707246261565...|[0.10835362313078...|       1.0|
|(1032,[0,1,2,3,5,...|            0|         2|             0.0|[1.49032865192912...|[0.07451643259645...|       1.0|
+--------------------+-------------+----------+----------------+--------------------+--------------------+----------+
only showing top 3 rows



In [61]:
### Share of test users for whom BOTH targets are predicted correctly

gender_ok = F.col('target_gender') == F.col('predicted_gender')
age_ok = F.col('target_age') == F.col('prediction')

true_cnt = age_preds_2.where(gender_ok & age_ok).count()
total_cnt = age_preds_2.count()

print(f'Оба таргета верны: {true_cnt} из {total_cnt} ~ {round(true_cnt * 100 / total_cnt, 2)}%')



Оба таргета верны: 1896 из 7387 ~ 25.67%


                                                                                

##### Подбор гиперпараметров
Обучение шло более 3 часов и упало с ошибкой; запуск с уменьшенной сеткой гиперпараметров прироста метрики не дал.

In [44]:
# from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
# from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# # Определим нашу модель
# rf = RandomForestClassifier(labelCol="target_age", featuresCol="features")

# # Сетка гиперпараметров
# param_grid = (
#     ParamGridBuilder()
#     .addGrid(rf.numTrees, [30, 50, 80])
#     .addGrid(rf.maxDepth, [10, 20, 30])
#     .addGrid(rf.maxBins, [32, 64])  # 32 - стандартное значение maxBins
#     .build()
# )

# # Оценка метрик
# evaluator = MulticlassClassificationEvaluator(labelCol="target_age", predictionCol="prediction", metricName="accuracy")

# # Кросс-валидация
# cross_val = CrossValidator(
#     estimator=rf,
#     estimatorParamMaps=param_grid,
#     evaluator=evaluator,
#     numFolds=3
# )

# # Обучение
# cv_model = cross_val.fit(train_df)

# # Лучшее сочетание гиперпараметров
# best_age_rfc_model = cv_model.bestModel
# print(best_age_rfc_model.extractParamMap())

# # Оценка метрик с лучшими параметрами
# predictions = best_age_rfc_model.transform(test_df)
# test_accuracy = evaluator.evaluate(predictions)
# print(f"Test Accuracy: {test_accuracy:.4f}")

25/05/05 14:52:39 WARN scheduler.DAGScheduler: Broadcasting large task binary with size 1078.1 KiB
25/05/05 14:52:41 WARN scheduler.DAGScheduler: Broadcasting large task binary with size 1460.9 KiB
25/05/05 14:52:44 WARN scheduler.DAGScheduler: Broadcasting large task binary with size 1956.3 KiB
25/05/05 14:52:47 WARN scheduler.DAGScheduler: Broadcasting large task binary with size 2.5 MiB
25/05/05 14:52:52 WARN scheduler.DAGScheduler: Broadcasting large task binary with size 1700.3 KiB
25/05/05 14:53:04 WARN scheduler.DAGScheduler: Broadcasting large task binary with size 1205.8 KiB
25/05/05 14:53:07 WARN scheduler.DAGScheduler: Broadcasting large task binary with size 1607.9 KiB
25/05/05 14:53:11 WARN scheduler.DAGScheduler: Broadcasting large task binary with size 2.1 MiB
25/05/05 14:53:16 WARN scheduler.DAGScheduler: Broadcasting large task binary with size 2.7 MiB
25/05/05 14:53:21 WARN scheduler.DAGScheduler: Broadcasting large task binary with size 1746.5 KiB
25/05/05 14:53:30 W

KeyboardInterrupt: 

25/05/05 17:22:54 WARN scheduler.DAGScheduler: Broadcasting large task binary with size 23.0 MiB
25/05/05 17:23:11 WARN scheduler.DAGScheduler: Broadcasting large task binary with size 24.5 MiB

##### Сохранение моделей в HDFS + тестовый инференс

In [64]:
### Persist both models to HDFS
# Relative paths end up in the user's HDFS home directory (as the `hdfs dfs -ls` listing shows).
# write().overwrite() lets the cell be re-run: plain .save() raises if the path already exists.
gen_model.write().overwrite().save('lab04_model_gender_rf_v2')
age_model_2.write().overwrite().save('lab04_model_age_rf_v2')

In [65]:
# List the user's HDFS home directory to confirm the models were saved (IPython shell escape).
!hdfs dfs -ls

Found 11 items
drwx------   - maksim.burdasov maksim.burdasov          0 2025-05-05 18:28 .Trash
drwxr-xr-x   - maksim.burdasov maksim.burdasov          0 2025-05-05 17:23 .sparkStaging
drwxr-xr-x   - maksim.burdasov maksim.burdasov          0 2025-04-22 13:00 lab03.csv
drwxr-xr-x   - maksim.burdasov maksim.burdasov          0 2025-05-05 14:33 lab04_model_age_rf_v0
drwxr-xr-x   - maksim.burdasov maksim.burdasov          0 2025-05-05 17:34 lab04_model_age_rf_v1
drwxr-xr-x   - maksim.burdasov maksim.burdasov          0 2025-05-05 18:39 lab04_model_age_rf_v2
drwxr-xr-x   - maksim.burdasov maksim.burdasov          0 2025-05-05 17:34 lab04_model_gender_rf_v1
drwxr-xr-x   - maksim.burdasov maksim.burdasov          0 2025-05-05 18:39 lab04_model_gender_rf_v2
drwxr-xr-x   - maksim.burdasov maksim.burdasov          0 2025-05-05 12:39 lab04_two_models_dataset.parquet
drwxr-xr-x   - maksim.burdasov maksim.burdasov          0 2025-05-05 18:34 lab04_two_models_dataset_v1.parquet
drwxr-xr-x   - maks

In [15]:
# Вывести полный путь до личной папки
# !echo $(hdfs getconf -confKey fs.defaultFS)/user/$USER
# >>> hdfs://spark-master-1.newprolab.com:8020/user/maksim.burdasov

In [66]:
### Smoke-test inference: reload the saved age model from HDFS and re-score the test split

from pyspark.ml.classification import RandomForestClassificationModel

my_folder = "hdfs://spark-master-1.newprolab.com:8020/user/maksim.burdasov/"
model_name = "lab04_model_age_rf_v2"
model_path = my_folder + model_name

loaded_model = RandomForestClassificationModel.load(model_path)
test_preds = loaded_model.transform(test_df)

# Expected to match the accuracy of the in-memory age model
test_evaluator = MulticlassClassificationEvaluator(
    metricName="accuracy",
    labelCol="target_age",
    predictionCol="prediction",
)
test_accuracy = test_evaluator.evaluate(test_preds)
print(f"Accuracy = {test_accuracy:.4f}")



Accuracy = 0.4359


                                                                                

In [43]:
# Peek at the reloaded model's predictions
test_preds.show(n=3)

+--------------------+-------------+----------+--------------------+--------------------+----------+
|            features|target_gender|target_age|       rawPrediction|         probability|prediction|
+--------------------+-------------+----------+--------------------+--------------------+----------+
|(1032,[0,1,2,3,4,...|            1|         3|[1.78290559466698...|[0.08914527973334...|       1.0|
|(1032,[0,1,2,3,4,...|            1|         3|[2.16707246261565...|[0.10835362313078...|       1.0|
|(1032,[0,1,2,3,5,...|            0|         2|[1.49032865192912...|[0.07451643259645...|       1.0|
+--------------------+-------------+----------+--------------------+--------------------+----------+
only showing top 3 rows



#### Классификатор предсказания мульти-таргета

In [20]:
### Multi-target model: one classifier over the 10 gender x age combinations

# train_df (read back from parquet) has no `label` column, so derive it here:
# target_gender (F=0, M=1) * 5 + target_age (0..4). This gives F__18-24 -> 0 ...
# F__>=55 -> 4 and M__18-24 -> 5 ... M__>=55 -> 9, the same mapping as the
# commented-out "target_combined" cell.
multi_train_df = train_df.withColumn(
    'label',
    F.col('target_gender') * 5 + F.col('target_age')
)

dtc_multi = DecisionTreeClassifier(featuresCol="features", labelCol="label")
multi_model = dtc_multi.fit(multi_train_df)

                                                                                

In [21]:
### Multi-target metrics on the test split

# test_df has no `label` column either; derive it the same way as for training:
# target_gender (F=0, M=1) * 5 + target_age (0..4) -> classes 0..9.
multi_test_df = test_df.withColumn(
    'label',
    F.col('target_gender') * 5 + F.col('target_age')
)
multi_preds = multi_model.transform(multi_test_df)

multi_mc_evaluator = MulticlassClassificationEvaluator(
    labelCol="label",
    predictionCol="prediction",
    metricName="accuracy"
)
accuracy = multi_mc_evaluator.evaluate(multi_preds)
print(f"Accuracy = {accuracy:.4f}")



Accuracy = 0.2377


                                                                                

In [26]:
# ### Подбор гиперпараметров (требует много времени + упало с ошибкой)

# # from pyspark.ml.classification import DecisionTreeClassifier
# # from pyspark.ml.evaluation import MulticlassClassificationEvaluator
# from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

# dt = DecisionTreeClassifier(labelCol="label", featuresCol="features")

# paramGrid = (
#     ParamGridBuilder()
#     .addGrid(dt.maxDepth, [3, 5, 10])
#     .addGrid(dt.maxBins, [20, 30])
#     .build()
# )

# evaluator = MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction", metricName="accuracy")

# crossval = CrossValidator(
#     estimator=dt,
#     estimatorParamMaps=paramGrid,
#     evaluator=evaluator,
#     numFolds=5
# )

# cvModel = crossval.fit(train_df)

In [None]:
# # Лучшие параметры
# best_params = cvModel.bestModel.extractParamMap()
# for param, value in best_params.items():
#     print(f"{param.name}: {value}")

In [28]:
# ### Точность
# predictions = cvModel.transform(test_df)
# print("Accuracy: ", evaluator.evaluate(predictions))

In [27]:
# # Сохранение модели

# model.write().overwrite().save("lab04_data/model_multi_dt_best")

### Остановка контекста

In [49]:
# Stop the Spark session (and its SparkContext) at the end of the notebook.
spark.stop()