In [1]:
import pandas as pd
import numpy as np
import yaml

from catboost import CatBoostClassifier
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import classification_report

In [2]:
# Прочтём файл конфига с путями

CONFIG_PATH = "config.yaml"
with open(CONFIG_PATH, "r", encoding="utf-8") as config_file:
    CONFIG = yaml.load(config_file, Loader=yaml.FullLoader)

In [3]:
# Загрузим обработанный датасет

data = pd.read_csv(CONFIG['datasets_folder'] + '/processed_df.csv')

columns_order = [
    'post_id', 'topic', 'tfidf_sum', 
    'tfidf_mean', 'tfidf_max', 'user_id',
    'gender', 'age', 'country', 'city', 
    'exp_group', 'os', 'source','month', 
    'hour', 'day', 'weekday', 'timestamp', 'target'
]
data = data[columns_order]

In [4]:
data.head()

Unnamed: 0,post_id,topic,tfidf_sum,tfidf_mean,tfidf_max,user_id,gender,age,country,city,exp_group,os,source,month,hour,day,weekday,timestamp,target
0,7089,movie,6.740111,5e-06,0.189926,9253,1,18,Russia,Shchigry,3,0,0,10,6,22,4,2021-10-22 06:54:16,0
1,1374,politics,23.924812,1.8e-05,0.232533,94674,1,28,Russia,Moscow,0,0,0,10,6,22,4,2021-10-22 06:54:16,0
2,2023,tech,17.645959,1.3e-05,0.168426,36482,1,47,Russia,Rostov,3,0,0,10,6,22,4,2021-10-22 06:54:16,0
3,1830,sport,28.01625,2.1e-05,0.30791,87763,0,30,Belarus,Syanno,1,1,0,10,6,22,4,2021-10-22 06:54:16,0
4,2613,covid,5.566688,4e-06,0.196314,74045,1,25,Kazakhstan,Shymkent,1,1,0,10,6,22,4,2021-10-22 06:54:16,0


In [5]:
# Разделим датасет на трейн и тест

X = data.drop(['timestamp', 'target', 'user_id', 'post_id'], axis=1)
y = data['target']

X_train = X.iloc[:-712175].copy()
y_train = y.iloc[:-712175].copy()

X_test = X.iloc[-712175:].copy()
y_test = y.iloc[-712175:].copy()

In [6]:
# Ввиду дисбаланса классов, найдем их веса

classes = np.unique(y_train)
weights = compute_class_weight(class_weight='balanced', classes=classes, y=y_train)
class_weights = dict(zip(classes, weights))

In [7]:
class_weights

{0: 0.5710634823862322, 1: 4.017981269778513}

In [8]:
# Обучим катбуст

categorical_features = ['country', 'city', 'topic']
catboost_model = CatBoostClassifier(class_weights=class_weights, cat_features=categorical_features)

catboost_model.fit(X_train, y_train)

Learning rate set to 0.391196
0:	learn: 0.6765103	total: 1.75s	remaining: 29m 8s
1:	learn: 0.6699536	total: 3.27s	remaining: 27m 12s
2:	learn: 0.6673221	total: 4.74s	remaining: 26m 15s
3:	learn: 0.6654400	total: 5.76s	remaining: 23m 55s
4:	learn: 0.6644420	total: 6.87s	remaining: 22m 47s
5:	learn: 0.6639221	total: 8.16s	remaining: 22m 31s
6:	learn: 0.6632817	total: 9.39s	remaining: 22m 11s
7:	learn: 0.6629294	total: 10.8s	remaining: 22m 22s
8:	learn: 0.6568434	total: 12.4s	remaining: 22m 48s
9:	learn: 0.6563937	total: 13.9s	remaining: 23m
10:	learn: 0.6562075	total: 15.1s	remaining: 22m 36s
11:	learn: 0.6557992	total: 16.5s	remaining: 22m 38s
12:	learn: 0.6556431	total: 18.3s	remaining: 23m 6s
13:	learn: 0.6555298	total: 19.9s	remaining: 23m 24s
14:	learn: 0.6527359	total: 21.6s	remaining: 23m 38s
15:	learn: 0.6526287	total: 23.1s	remaining: 23m 42s
16:	learn: 0.6517529	total: 24.8s	remaining: 23m 53s
17:	learn: 0.6514457	total: 26.5s	remaining: 24m 6s
18:	learn: 0.6513054	total: 27.5s

153:	learn: 0.6422994	total: 3m 51s	remaining: 21m 12s
154:	learn: 0.6422934	total: 3m 53s	remaining: 21m 12s
155:	learn: 0.6422739	total: 3m 54s	remaining: 21m 10s
156:	learn: 0.6422417	total: 3m 56s	remaining: 21m 9s
157:	learn: 0.6422106	total: 3m 57s	remaining: 21m 7s
158:	learn: 0.6421633	total: 3m 59s	remaining: 21m 5s
159:	learn: 0.6421573	total: 4m 1s	remaining: 21m 5s
160:	learn: 0.6421449	total: 4m 2s	remaining: 21m 4s
161:	learn: 0.6420994	total: 4m 4s	remaining: 21m 3s
162:	learn: 0.6420837	total: 4m 6s	remaining: 21m 4s
163:	learn: 0.6420556	total: 4m 7s	remaining: 21m 2s
164:	learn: 0.6420230	total: 4m 9s	remaining: 21m
165:	learn: 0.6420127	total: 4m 10s	remaining: 20m 59s
166:	learn: 0.6419977	total: 4m 12s	remaining: 20m 59s
167:	learn: 0.6419574	total: 4m 14s	remaining: 20m 58s
168:	learn: 0.6419386	total: 4m 15s	remaining: 20m 58s
169:	learn: 0.6418979	total: 4m 17s	remaining: 20m 56s
170:	learn: 0.6418533	total: 4m 18s	remaining: 20m 54s
171:	learn: 0.6418167	total:

304:	learn: 0.6382553	total: 7m 48s	remaining: 17m 47s
305:	learn: 0.6382322	total: 7m 49s	remaining: 17m 45s
306:	learn: 0.6382234	total: 7m 51s	remaining: 17m 44s
307:	learn: 0.6382095	total: 7m 53s	remaining: 17m 43s
308:	learn: 0.6381925	total: 7m 54s	remaining: 17m 41s
309:	learn: 0.6381606	total: 7m 56s	remaining: 17m 40s
310:	learn: 0.6381336	total: 7m 57s	remaining: 17m 38s
311:	learn: 0.6381270	total: 7m 59s	remaining: 17m 36s
312:	learn: 0.6381132	total: 8m 1s	remaining: 17m 35s
313:	learn: 0.6381028	total: 8m 2s	remaining: 17m 34s
314:	learn: 0.6380975	total: 8m 4s	remaining: 17m 32s
315:	learn: 0.6380839	total: 8m 5s	remaining: 17m 31s
316:	learn: 0.6380642	total: 8m 7s	remaining: 17m 29s
317:	learn: 0.6380510	total: 8m 8s	remaining: 17m 27s
318:	learn: 0.6379785	total: 8m 10s	remaining: 17m 26s
319:	learn: 0.6379685	total: 8m 11s	remaining: 17m 24s
320:	learn: 0.6379377	total: 8m 12s	remaining: 17m 22s
321:	learn: 0.6379322	total: 8m 14s	remaining: 17m 21s
322:	learn: 0.63

453:	learn: 0.6354987	total: 11m 44s	remaining: 14m 7s
454:	learn: 0.6354756	total: 11m 46s	remaining: 14m 5s
455:	learn: 0.6354688	total: 11m 47s	remaining: 14m 4s
456:	learn: 0.6354620	total: 11m 49s	remaining: 14m 2s
457:	learn: 0.6354535	total: 11m 51s	remaining: 14m 1s
458:	learn: 0.6354368	total: 11m 52s	remaining: 14m
459:	learn: 0.6354072	total: 11m 54s	remaining: 13m 58s
460:	learn: 0.6354023	total: 11m 56s	remaining: 13m 57s
461:	learn: 0.6353837	total: 11m 57s	remaining: 13m 55s
462:	learn: 0.6353637	total: 11m 58s	remaining: 13m 53s
463:	learn: 0.6353577	total: 12m	remaining: 13m 52s
464:	learn: 0.6353366	total: 12m 2s	remaining: 13m 51s
465:	learn: 0.6353176	total: 12m 3s	remaining: 13m 49s
466:	learn: 0.6353129	total: 12m 5s	remaining: 13m 48s
467:	learn: 0.6353009	total: 12m 7s	remaining: 13m 46s
468:	learn: 0.6352889	total: 12m 8s	remaining: 13m 45s
469:	learn: 0.6352830	total: 12m 10s	remaining: 13m 43s
470:	learn: 0.6352652	total: 12m 11s	remaining: 13m 41s
471:	learn

601:	learn: 0.6333254	total: 15m 41s	remaining: 10m 22s
602:	learn: 0.6333182	total: 15m 43s	remaining: 10m 20s
603:	learn: 0.6333093	total: 15m 44s	remaining: 10m 19s
604:	learn: 0.6333021	total: 15m 46s	remaining: 10m 17s
605:	learn: 0.6332836	total: 15m 47s	remaining: 10m 16s
606:	learn: 0.6332662	total: 15m 49s	remaining: 10m 14s
607:	learn: 0.6332279	total: 15m 50s	remaining: 10m 12s
608:	learn: 0.6332224	total: 15m 52s	remaining: 10m 11s
609:	learn: 0.6332091	total: 15m 53s	remaining: 10m 9s
610:	learn: 0.6332025	total: 15m 55s	remaining: 10m 8s
611:	learn: 0.6331942	total: 15m 57s	remaining: 10m 6s
612:	learn: 0.6331897	total: 15m 58s	remaining: 10m 5s
613:	learn: 0.6331811	total: 16m	remaining: 10m 3s
614:	learn: 0.6331624	total: 16m 2s	remaining: 10m 2s
615:	learn: 0.6331567	total: 16m 3s	remaining: 10m
616:	learn: 0.6331488	total: 16m 5s	remaining: 9m 59s
617:	learn: 0.6331360	total: 16m 7s	remaining: 9m 57s
618:	learn: 0.6331302	total: 16m 8s	remaining: 9m 56s
619:	learn: 0.

751:	learn: 0.6314768	total: 19m 42s	remaining: 6m 30s
752:	learn: 0.6314737	total: 19m 44s	remaining: 6m 28s
753:	learn: 0.6314677	total: 19m 46s	remaining: 6m 27s
754:	learn: 0.6314479	total: 19m 47s	remaining: 6m 25s
755:	learn: 0.6314451	total: 19m 49s	remaining: 6m 23s
756:	learn: 0.6314327	total: 19m 51s	remaining: 6m 22s
757:	learn: 0.6314147	total: 19m 52s	remaining: 6m 20s
758:	learn: 0.6314125	total: 19m 54s	remaining: 6m 19s
759:	learn: 0.6314071	total: 19m 55s	remaining: 6m 17s
760:	learn: 0.6313892	total: 19m 57s	remaining: 6m 16s
761:	learn: 0.6313685	total: 19m 58s	remaining: 6m 14s
762:	learn: 0.6313613	total: 20m	remaining: 6m 12s
763:	learn: 0.6313399	total: 20m 1s	remaining: 6m 11s
764:	learn: 0.6313259	total: 20m 3s	remaining: 6m 9s
765:	learn: 0.6312941	total: 20m 4s	remaining: 6m 8s
766:	learn: 0.6312853	total: 20m 6s	remaining: 6m 6s
767:	learn: 0.6312703	total: 20m 8s	remaining: 6m 4s
768:	learn: 0.6312518	total: 20m 9s	remaining: 6m 3s
769:	learn: 0.6312293	tot

902:	learn: 0.6298618	total: 23m 46s	remaining: 2m 33s
903:	learn: 0.6298529	total: 23m 47s	remaining: 2m 31s
904:	learn: 0.6298504	total: 23m 49s	remaining: 2m 30s
905:	learn: 0.6298461	total: 23m 51s	remaining: 2m 28s
906:	learn: 0.6298378	total: 23m 53s	remaining: 2m 26s
907:	learn: 0.6298292	total: 23m 54s	remaining: 2m 25s
908:	learn: 0.6298224	total: 23m 56s	remaining: 2m 23s
909:	learn: 0.6298027	total: 23m 58s	remaining: 2m 22s
910:	learn: 0.6297971	total: 24m	remaining: 2m 20s
911:	learn: 0.6297819	total: 24m 1s	remaining: 2m 19s
912:	learn: 0.6297500	total: 24m 3s	remaining: 2m 17s
913:	learn: 0.6297467	total: 24m 4s	remaining: 2m 15s
914:	learn: 0.6297359	total: 24m 6s	remaining: 2m 14s
915:	learn: 0.6297242	total: 24m 7s	remaining: 2m 12s
916:	learn: 0.6297000	total: 24m 9s	remaining: 2m 11s
917:	learn: 0.6296763	total: 24m 10s	remaining: 2m 9s
918:	learn: 0.6296634	total: 24m 12s	remaining: 2m 8s
919:	learn: 0.6296573	total: 24m 14s	remaining: 2m 6s
920:	learn: 0.6296470	t

<catboost.core.CatBoostClassifier at 0x1426bd23790>

In [9]:
# Грубо оценим качество обученной модели

print(f"Качество на тесте: {catboost_model.score(X_test, y_test)}")
classification_report(y_test, catboost_model.predict(X_test), output_dict=True)

Качество на тесте: 0.561944746726577


{'0': {'precision': 0.92057471073557,
  'recall': 0.5401482486785543,
  'f1-score': 0.6808228960555581,
  'support': 615992},
 '1': {'precision': 0.19238127279103384,
  'recall': 0.7015376937712485,
  'f1-score': 0.30195737977821735,
  'support': 96183},
 'accuracy': 0.561944746726577,
 'macro avg': {'precision': 0.5564779917633019,
  'recall': 0.6208429712249014,
  'f1-score': 0.49139013791688774,
  'support': 712175},
 'weighted avg': {'precision': 0.822228335979619,
  'recall': 0.561944746726577,
  'f1-score': 0.6296551044985624,
  'support': 712175}}

In [10]:
catboost_model.save_model(CONFIG['data_folder'] + '/catboost_model', format="cbm")