# Обучение модели
В этом ноутбуке выполняется обучение модели на признаках, полученных в ноутбуках `prepare_and_save_features.ipynb` и `get_extended_posts_features.ipynb`.
Цель обучения — создать модель классификации, которая прогнозирует, понравится ли пользователю с определенными характеристиками пост с определенными характеристиками.

In [1]:
import pandas as pd
import numpy as np
import datetime as dt
from dotenv import dotenv_values
from sklearn.metrics import recall_score, precision_score, f1_score
from catboost import CatBoostClassifier


# Seed passed as random_seed to CatBoostClassifier.
RND_STATE = 2024
# Number of trailing rows of the timestamp-sorted views table held out as the test set.
TEST_ROWS_COUNT = 300000
# When True, the trained model is saved under ../models/ with the name MODEL_FILE_NAME.
SAVE_MODEL_TO_FILE = False
# When True, the merged post features are written to ../data/post_features_extended.csv.
SAVE_POSTS_DATA_TO_CSV = False
# When True, the merged post features are written to the database table named in .env.
SAVE_POSTS_DATA_TO_DB = False

In [2]:
# Settings are read from the local .env file.
env_dict = dotenv_values("../creds/.env")

# File name used when saving the model. `or` is used instead of a .get()
# default so that the timestamped fallback also applies when the variable is
# present in .env but empty or declared without a value (an empty name would
# otherwise make the save path point at the ../models/ directory itself).
# The fallback looks like "catboost_extended_model_2024-05-01_12_30_45".
MODEL_FILE_NAME = (
    env_dict.get("KC_STARTML_EXTENDED_MODEL_FILE")
    or f"catboost_extended_model_{dt.datetime.now():%Y-%m-%d_%H_%M_%S}"
)

## Загрузка подготовленных данных
Перед загрузкой данных с признаками надо сначала выполнить все действия в ноутбуке `prepare_and_save_features.ipynb`, предварительно поставив флаг `SAVE_DATA_TO_CSV = True`, а затем в ноутбуке `get_extended_posts_features.ipynb` для получения расширенного набора признаков для постов.

In [3]:
# Load the precomputed per-user features and preview the first rows.
df_user_data = pd.read_csv("../data/user_features.csv")
df_user_data.head(3)

Unnamed: 0,user_id,gender,age,country_and_city,exp_group,os_iOS,source_organic
0,200,1,34,0.000869,3,0,0
1,201,0,37,0.010972,0,0,0
2,202,1,17,0.019796,4,0,0


In [4]:
# Load the precomputed base per-post features and preview the first rows.
df_post_data = pd.read_csv("../data/post_features.csv")
df_post_data.head(3)

Unnamed: 0,post_id,topic,tfidf_mean,tfidf_max
0,1,0.04743,0.627607,0.731115
1,2,0.04743,1.492879,-0.773927
2,3,0.04743,1.700723,-0.894502


In [5]:
# Load the additional per-post features (d2v_* columns) and preview them.
# NOTE(review): presumably produced by get_extended_posts_features.ipynb — confirm.
df_post_d2v = pd.read_csv("../data/post_features_d2v.csv")
df_post_d2v.head(3)

Unnamed: 0,post_id,d2v_1,d2v_2,d2v_3,d2v_4,d2v_5,d2v_6,d2v_7,d2v_8,d2v_9,...,d2v_11,d2v_12,d2v_13,d2v_14,d2v_15,d2v_16,d2v_17,d2v_18,d2v_19,d2v_20
0,1,-5.696737,3.600629,1.888588,4.697815,-5.616477,-11.314406,4.154247,12.685128,-5.646606,...,2.846017,0.421474,-0.325882,-0.62348,-0.075401,6.569971,-3.440696,0.225233,2.80183,-3.116241
1,2,0.725621,-2.183182,-3.711747,-1.569358,-4.500835,-3.644571,4.032551,3.127399,-5.547462,...,-1.159752,0.315606,-0.859167,-0.523222,-1.368132,3.444111,2.81056,-1.729794,-0.331542,-1.127183
2,3,-4.928478,0.95042,4.643132,3.792058,-7.63181,-4.393818,4.179424,6.589443,-8.521201,...,3.551821,-3.042779,5.222722,1.212748,0.734874,5.77593,3.756821,-2.572037,5.106787,-4.602608


In [6]:
# Attach the d2v_* columns to the base post features by post_id.
# merge() defaults to an inner join, so posts missing from either frame are dropped.
# The result overwrites df_post_data, so re-running this cell merges the
# d2v_* columns a second time.
df_post_data = df_post_data.merge(df_post_d2v, on="post_id")
df_post_data.head(3)

Unnamed: 0,post_id,topic,tfidf_mean,tfidf_max,d2v_1,d2v_2,d2v_3,d2v_4,d2v_5,d2v_6,...,d2v_11,d2v_12,d2v_13,d2v_14,d2v_15,d2v_16,d2v_17,d2v_18,d2v_19,d2v_20
0,1,0.04743,0.627607,0.731115,-5.696737,3.600629,1.888588,4.697815,-5.616477,-11.314406,...,2.846017,0.421474,-0.325882,-0.62348,-0.075401,6.569971,-3.440696,0.225233,2.80183,-3.116241
1,2,0.04743,1.492879,-0.773927,0.725621,-2.183182,-3.711747,-1.569358,-4.500835,-3.644571,...,-1.159752,0.315606,-0.859167,-0.523222,-1.368132,3.444111,2.81056,-1.729794,-0.331542,-1.127183
2,3,0.04743,1.700723,-0.894502,-4.928478,0.95042,4.643132,3.792058,-7.63181,-4.393818,...,3.551821,-3.042779,5.222722,1.212748,0.734874,5.77593,3.756821,-2.572037,5.106787,-4.602608


In [7]:
# Load the view log (timestamp, user_id, post_id, target) and preview it.
df_feed_data = pd.read_csv("../data/feed_data.csv")
df_feed_data.head(3)

Unnamed: 0,timestamp,user_id,post_id,target
0,2021-10-01 06:01:40,1859,1498,1
1,2021-10-01 06:01:40,8663,3837,1
2,2021-10-01 06:01:40,15471,2810,0


## Объединение данных в один DataFrame

In [8]:
# Build the modelling table step by step: each view is enriched with the
# features of its user, then with the features of its post.
df_views_data = df_feed_data.merge(df_user_data, on="user_id")
df_views_data = df_views_data.merge(df_post_data, on="post_id")

# Order the views chronologically and renumber the rows from zero.
df_views_data = df_views_data.sort_values("timestamp")
df_views_data = df_views_data.reset_index(drop=True)

# The raw view log is no longer needed once it has been merged.
del df_feed_data

In [9]:
# Report how many rows fall into each target class and the ratio of
# positive (1) to negative (0) rows.
class_counts = df_views_data["target"].value_counts()
print(f"Количества значений целевой переменной:\n{class_counts}")

positive_to_negative = class_counts[1] / class_counts[0]
print(f"Баланс классов: {positive_to_negative:.4f}")

# Temporary names are removed so they do not linger in the notebook namespace.
del class_counts, positive_to_negative

Количества значений целевой переменной:
0    4861763
1     498618
Name: target, dtype: int64
Баланс классов: 0.1026


## Создание моделей

### Деление данных на тренировочные и тестовые

In [10]:
# Hold-out split by position: the last TEST_ROWS_COUNT rows of the views table
# form the test set, everything before them is used for training.
service_columns = ["target", "post_id", "user_id", "timestamp"]
train_rows = df_views_data.iloc[:-TEST_ROWS_COUNT]
test_rows = df_views_data.iloc[-TEST_ROWS_COUNT:]

# Features are all columns except the target and the identifying columns.
X_train, y_train = train_rows.drop(columns=service_columns), train_rows["target"]
X_test, y_test = test_rows.drop(columns=service_columns), test_rows["target"]

del train_rows, test_rows

In [11]:
def print_model_scores(model, X_train, X_test, y_train, y_test):
    """
    Print Precision, Recall and F1-score of the given model.

    The metrics are printed first for the training data (lines prefixed
    with "Train"), then, after an empty line, for the test data.
    Predictions are obtained with ``model.predict``.
    """
    scorers = (
        ("Precision", precision_score),
        ("Recall", recall_score),
        ("F1-score", f1_score),
    )

    train_predictions = model.predict(X_train)
    for title, scorer in scorers:
        print(f"Train {title}: {scorer(y_train, train_predictions):.4}")

    print("")

    test_predictions = model.predict(X_test)
    for title, scorer in scorers:
        print(f"{title}: {scorer(y_test, test_predictions):.4}")


In [12]:
def get_final_metric_value(model, test_data, k=5):
    """
    Compute and print the final HitRate@k metric of a model on test data.

    For every user present in ``test_data`` the rows are ranked by the
    predicted probability of the positive class (descending) and the user
    counts as a hit if at least one of the top ``k`` rows has ``target == 1``.
    The metric is the mean of these per-user hits.

    Args:
        model: Fitted classifier exposing ``predict_proba``; column 1 of its
            output is used as the score.
        test_data: DataFrame with the columns "target", "post_id", "user_id"
            and "timestamp" plus the feature columns the model was trained on.
            It is not modified.
        k: Number of top-scored rows considered per user.

    Returns:
        Mean HitRate@k over the users in ``test_data`` (NaN if it is empty).
    """
    # The model must see exactly the feature columns, in their original order.
    features = test_data.drop(columns=["target", "post_id", "user_id", "timestamp"])

    # Only the user, the label and the score are needed for ranking.
    scored = test_data[["user_id", "target"]].assign(
        pred=model.predict_proba(features)[:, 1]
    )

    # One global sort followed by a per-user head(k) replaces a boolean-mask
    # scan of the whole frame for every user. The stable sort keeps rows with
    # equal scores in their original order, making ties deterministic.
    top_k = (
        scored
        .sort_values("pred", ascending=False, kind="stable")
        .groupby("user_id", sort=False)
        .head(k)
    )
    users_hitrate = top_k.groupby("user_id", sort=False)["target"].max()

    result = users_hitrate.mean()
    print(f"Среднее HitRate@{k} по пользователям из теста: {result}")

    return result

### Модель CatBoost

#### Обучение финальной модели

In [13]:
%%time
model = CatBoostClassifier(random_seed=RND_STATE, verbose=False, 
                          class_weights = {0: 1, 1:10},
                          iterations=1000,
                          l2_leaf_reg=1,
                          learning_rate=0.4,
                          depth=4)
model.fit(X_train, y_train)

print_model_scores(model, X_train, X_test, y_train, y_test)

Train Precision: 0.1577
Train Recall: 0.7093
Train F1-score: 0.258

Precision: 0.1771
Recall: 0.6746
F1-score: 0.2806
CPU times: user 44min 1s, sys: 17.7 s, total: 44min 19s
Wall time: 13min 47s


#### Измерение финальной метрики

In [14]:
# Evaluate HitRate@5 on the same trailing TEST_ROWS_COUNT rows that were held
# out as the test split; the full rows are passed because the metric needs the
# user_id and target columns in addition to the features.
get_final_metric_value(model, 
                       df_views_data.iloc[-TEST_ROWS_COUNT:, :], 
                       k=5)

Среднее HitRate@5 по пользователям из теста: 0.5656455142231948


0.5656455142231948

## Сохранение модели в файл

In [15]:
# Save the trained model under ../models/ in the "cbm" format;
# skipped unless SAVE_MODEL_TO_FILE is enabled.
if SAVE_MODEL_TO_FILE:
    model.save_model(f"../models/{MODEL_FILE_NAME}", format="cbm")

## Сохранение данных

### Сохранение расширенных признаков постов.

In [16]:
# Export the merged post features (base + d2v_* columns) to CSV without the
# DataFrame index; skipped unless SAVE_POSTS_DATA_TO_CSV is enabled.
if SAVE_POSTS_DATA_TO_CSV:
    df_post_data.to_csv("../data/post_features_extended.csv", index=False)

In [17]:
# Export the merged post features to a PostgreSQL table; connection settings
# and the table name come from ../creds/.env. Skipped unless
# SAVE_POSTS_DATA_TO_DB is enabled.
if SAVE_POSTS_DATA_TO_DB:
    # Local import: only needed when the DB export is enabled.
    from urllib.parse import quote

    env_dict = dotenv_values("../creds/.env")

    # `or ""` also covers variables that are present in .env without a value.
    server = env_dict.get("KC_STARTML_POSTGRES_SERVER") or ""
    port = env_dict.get("KC_STARTML_POSTGRES_PORT")
    db_name = env_dict.get("KC_STARTML_POSTGRES_DB")
    user = env_dict.get("KC_STARTML_POSTGRES_USER") or ""
    password = env_dict.get("KC_STARTML_POSTGRES_PASSWORD")
    posts_features_ext_table = env_dict.get("KC_STARTML_POSTS_FEATURES_EXT_TABLE")

    # Fail early with a clear message instead of passing an empty table name
    # to the database.
    if not posts_features_ext_table:
        raise ValueError(
            "KC_STARTML_POSTS_FEATURES_EXT_TABLE is not set in ../creds/.env"
        )

    # Percent-encode the credentials so that characters such as "@", "/" or
    # ":" cannot break the URL structure. "%" is left untouched so values
    # that are already percent-encoded in .env keep working unchanged.
    user = quote(user, safe="%")
    password = (":" + quote(password, safe="%")) if password else ""
    port = (":" + port) if port else ""
    db_name = ("/" + db_name) if db_name else ""

    connection_string = f"postgresql://{user}{password}@{server}{port}{db_name}"

    # An existing table with this name is dropped and recreated.
    # NOTE(review): unlike the CSV export (index=False), the DataFrame index is
    # written here as an extra column (to_sql defaults to index=True). Left
    # as-is because readers of the table may rely on it — confirm.
    df_post_data.to_sql(posts_features_ext_table,
                        con=connection_string,
                        if_exists="replace")
    