In [61]:
from StatsAggregation.analysis_helpers import *
import pandas as pd
import numpy as np
from sklearn.base import TransformerMixin
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
import scipy.sparse as sp
from sklearn.ensemble import StackingClassifier
from itertools import combinations
from catboost import CatBoostClassifier
from sklearn.metrics import f1_score
from sklearn.base import ClassifierMixin
from sklearn.model_selection import cross_validate
from sklearn.base import BaseEstimator
from sklearn.metrics import confusion_matrix

In [62]:
df = load_df('aggregated_df.csv')

In [63]:
df.head()

Unnamed: 0,author,book,text,text_no_punkt,lemmas,tags,tokens,counts,probs
0,Аверченко,Averchenko_A-T-Averchenko-Sobranie-sochineniy-...,Мы за пять лет. Материалы [к биографии]\nКак б...,Мы за пять лет Материалы к биографии Как будто...,пять год материал биография кроваво-красный ра...,4_NUMR 3_NOUN 9_NOUN 9_NOUN 15_ADJF 6_NOUN 8_V...,Мы за пять лет . Материалы [ к биографии ] Как...,243,4.5e-05
1,Аверченко,Averchenko_A-T-Averchenko-Sobranie-sochineniy-...,Сидел он за большим письменным столом перед де...,Сидел он за большим письменным столом перед де...,сидеть больший письменный стол деревянный доск...,5_VERB 7_ADJF 10_ADJF 6_NOUN 10_ADJF 6_NOUN 6_...,Сидел он за большим письменным столом перед де...,243,4.5e-05
2,Аверченко,Averchenko_A-T-Averchenko-Sobranie-sochineniy-...,"— Да уж, — качал головой сдержанный Ре-ми. — Н...",Да уж качал головой сдержанный Ре-ми Нехорошо ...,качать голова сдержать ре-ми нехорошо нехорошо...,5_VERB 7_NOUN 10_PRTF 5_None 8_ADVB 8_ADVB 4_A...,"— Да уж , — качал головой сдержанный Ре-ми . —...",243,4.5e-05
3,Аверченко,Averchenko_A-T-Averchenko-Sobranie-sochineniy-...,"Поверит ли кто-нибудь, что нами за эти пять ле...",Поверит ли кто-нибудь что нами за эти пять лет...,поверить кто-нибудь пять год совместно м.г кор...,7_VERB 10_NPRO 4_NUMR 3_NOUN 9_ADVB 3_None 11_...,"Поверит ли кто-нибудь , что нами за эти пять л...",243,4.5e-05
4,Аверченко,Averchenko_A-T-Averchenko-Sobranie-sochineniy-...,"8\nАверченко А. Избранные рассказы. М., 1985. ...",8 Аверченко А Избранные рассказы М 1985 С 7 9 ...,8 избранный рассказ м 1985 7 9 ежегодник 156 1...,1_None 9_ADJF 8_NOUN 1_NOUN 4_None 1_None 1_No...,"8 Аверченко А. Избранные рассказы . М. , 1985 ...",243,4.5e-05


### Генератор фолдов для кросс-валидации

In [64]:
def books_cross_val(df, k=5, seed=10):
    df_remain = df
    while k > 0:
        if k == 1:
            train_idx = df.index.difference(df_remain.index)
            test_idx = df_remain.index
        else:
            share = (k - 1) / k
            df_remain, fold, _, _ = train_test_split(df_remain, share=share, seed=seed, cross_val=True)
            train_idx = df.index.difference(fold.index)
            test_idx = fold.index
        yield train_idx, test_idx
        k -= 1

### TfIdf для нескольких колонок

In [87]:
class MultiTfidf(TransformerMixin):
    def __init__(self, cols=None, tfidf_type='classic'):
        self.cols = cols
        self.tfidf_type = tfidf_type
    
    def fit(self, X, y=None):
        if self.tfidf_type == 'classic':
            get_vectorizer = get_document_vectorizer
        elif self.tfidf_type == 'class_based':
            get_vectorizer = get_author_vectorizer
        else:
            raise ValueError("Unknown vectorizer")
        self.vectorizers_ = [get_vectorizer(X, column=col) for col in self.cols]
        return self
    
    def transform(self, X, y=None):
        results = []
        for vec, col in zip(self.vectorizers_, self.cols):
            results.append(vec.transform(X[col]))
        return sp.hstack(results)
    
    def get_feature_names_out(self, input_features=None):
        features = np.array([])
        for v in self.vectorizers_:
            features = np.append(features, v.get_feature_names_out())
        return features

In [66]:
class MyStack(ClassifierMixin, BaseEstimator):
    def __init__(self, estimators=None, final_estimator=None):
        self.estimators = estimators
        self.final_estimator = final_estimator
    
    def fit(self, X, y):
        X = X.reset_index(drop=True)
        cv = books_cross_val(X)
        self.model_ = StackingClassifier(self.estimators, self.final_estimator, cv=cv)
        self.model_.fit(X, y)
        return self
    
    def predict_proba(self, X):
        return self.model_.predict_proba(X)
    
    def predict(self, X):
        return self.model_.predict(X)
    
    def get_n_important_features(self, n=100):
        estimators = self.model_.estimators_
        features = []
        for est in estimators:
            f = est[0].get_feature_names_out()
            max_ind = np.argpartition(np.abs(est[1].coef_), -n)[:, -n:]
            features.append(f[max_ind])
        return features
    

### Базовый пайплайн

In [91]:
def get_base_estimator(cols, vec_type='classic'):
    pipe = Pipeline([
        ('vectorizer', MultiTfidf(cols=cols, tfidf_type=vec_type)),
        ('model', LogisticRegression(class_weight='balanced', max_iter=500, C=1000))
    ])
    return pipe

### Самое интересное

In [68]:
encoder = get_encoder(df)

In [92]:
estimators = []
for cols in combinations(['text_no_punkt', 'lemmas', 'tags', 'tokens'], 2):
    estimators.append((';'.join(cols), get_base_estimator(cols)))

In [93]:
model = MyStack(estimators, CatBoostClassifier())

In [71]:
scoring = 'f1_macro'
cv = books_cross_val(df, k=2, seed=20)
y = encoder.transform(df['author'])

In [72]:
cvs = cross_validate(model, df, y, scoring=scoring, cv=cv, return_estimator=True)

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


Learning rate set to 0.088705
0:	learn: 2.1561664	total: 111ms	remaining: 1m 51s
1:	learn: 1.9313676	total: 221ms	remaining: 1m 50s
2:	learn: 1.7757361	total: 335ms	remaining: 1m 51s
3:	learn: 1.6483345	total: 448ms	remaining: 1m 51s
4:	learn: 1.5530209	total: 557ms	remaining: 1m 50s
5:	learn: 1.4687019	total: 667ms	remaining: 1m 50s
6:	learn: 1.4038989	total: 779ms	remaining: 1m 50s
7:	learn: 1.3486696	total: 891ms	remaining: 1m 50s
8:	learn: 1.2901185	total: 1s	remaining: 1m 50s
9:	learn: 1.2361448	total: 1.11s	remaining: 1m 50s
10:	learn: 1.1981936	total: 1.23s	remaining: 1m 50s
11:	learn: 1.1569178	total: 1.34s	remaining: 1m 50s
12:	learn: 1.1163189	total: 1.45s	remaining: 1m 50s
13:	learn: 1.0793234	total: 1.56s	remaining: 1m 50s
14:	learn: 1.0473478	total: 1.68s	remaining: 1m 49s
15:	learn: 1.0176255	total: 1.79s	remaining: 1m 50s
16:	learn: 0.9934708	total: 1.91s	remaining: 1m 50s
17:	learn: 0.9682882	total: 2.01s	remaining: 1m 49s
18:	learn: 0.9406228	total: 2.12s	remaining: 1m

157:	learn: 0.3310232	total: 17.9s	remaining: 1m 35s
158:	learn: 0.3301452	total: 18s	remaining: 1m 35s
159:	learn: 0.3290518	total: 18.2s	remaining: 1m 35s
160:	learn: 0.3277165	total: 18.3s	remaining: 1m 35s
161:	learn: 0.3266558	total: 18.4s	remaining: 1m 35s
162:	learn: 0.3258993	total: 18.5s	remaining: 1m 35s
163:	learn: 0.3247148	total: 18.6s	remaining: 1m 34s
164:	learn: 0.3238978	total: 18.7s	remaining: 1m 34s
165:	learn: 0.3230939	total: 18.8s	remaining: 1m 34s
166:	learn: 0.3222078	total: 19s	remaining: 1m 34s
167:	learn: 0.3216520	total: 19.1s	remaining: 1m 34s
168:	learn: 0.3204800	total: 19.2s	remaining: 1m 34s
169:	learn: 0.3193999	total: 19.3s	remaining: 1m 34s
170:	learn: 0.3186339	total: 19.4s	remaining: 1m 34s
171:	learn: 0.3174314	total: 19.5s	remaining: 1m 34s
172:	learn: 0.3160436	total: 19.7s	remaining: 1m 33s
173:	learn: 0.3149695	total: 19.8s	remaining: 1m 33s
174:	learn: 0.3143045	total: 19.9s	remaining: 1m 33s
175:	learn: 0.3133336	total: 20s	remaining: 1m 33s

313:	learn: 0.2222829	total: 36s	remaining: 1m 18s
314:	learn: 0.2217874	total: 36.1s	remaining: 1m 18s
315:	learn: 0.2215349	total: 36.2s	remaining: 1m 18s
316:	learn: 0.2212581	total: 36.3s	remaining: 1m 18s
317:	learn: 0.2203796	total: 36.4s	remaining: 1m 18s
318:	learn: 0.2200199	total: 36.5s	remaining: 1m 18s
319:	learn: 0.2195271	total: 36.7s	remaining: 1m 17s
320:	learn: 0.2187449	total: 36.8s	remaining: 1m 17s
321:	learn: 0.2182284	total: 36.9s	remaining: 1m 17s
322:	learn: 0.2176860	total: 37s	remaining: 1m 17s
323:	learn: 0.2170981	total: 37.1s	remaining: 1m 17s
324:	learn: 0.2168398	total: 37.3s	remaining: 1m 17s
325:	learn: 0.2161518	total: 37.4s	remaining: 1m 17s
326:	learn: 0.2156535	total: 37.5s	remaining: 1m 17s
327:	learn: 0.2151177	total: 37.6s	remaining: 1m 17s
328:	learn: 0.2146947	total: 37.7s	remaining: 1m 16s
329:	learn: 0.2142548	total: 37.8s	remaining: 1m 16s
330:	learn: 0.2136567	total: 37.9s	remaining: 1m 16s
331:	learn: 0.2132603	total: 38.1s	remaining: 1m 1

471:	learn: 0.1588731	total: 54.4s	remaining: 1m
472:	learn: 0.1586453	total: 54.5s	remaining: 1m
473:	learn: 0.1583944	total: 54.6s	remaining: 1m
474:	learn: 0.1580910	total: 54.7s	remaining: 1m
475:	learn: 0.1578056	total: 54.8s	remaining: 1m
476:	learn: 0.1575062	total: 54.9s	remaining: 1m
477:	learn: 0.1572469	total: 55.1s	remaining: 1m
478:	learn: 0.1568300	total: 55.2s	remaining: 1m
479:	learn: 0.1565899	total: 55.3s	remaining: 59.9s
480:	learn: 0.1562050	total: 55.4s	remaining: 59.8s
481:	learn: 0.1560739	total: 55.5s	remaining: 59.7s
482:	learn: 0.1558642	total: 55.6s	remaining: 59.5s
483:	learn: 0.1556440	total: 55.7s	remaining: 59.4s
484:	learn: 0.1552802	total: 55.9s	remaining: 59.3s
485:	learn: 0.1550159	total: 56s	remaining: 59.2s
486:	learn: 0.1547479	total: 56.1s	remaining: 59.1s
487:	learn: 0.1544701	total: 56.2s	remaining: 59s
488:	learn: 0.1539325	total: 56.3s	remaining: 58.9s
489:	learn: 0.1537233	total: 56.4s	remaining: 58.7s
490:	learn: 0.1533992	total: 56.5s	remai

631:	learn: 0.1190140	total: 1m 12s	remaining: 42.4s
632:	learn: 0.1188047	total: 1m 12s	remaining: 42.3s
633:	learn: 0.1185200	total: 1m 13s	remaining: 42.1s
634:	learn: 0.1183388	total: 1m 13s	remaining: 42s
635:	learn: 0.1181731	total: 1m 13s	remaining: 41.9s
636:	learn: 0.1179784	total: 1m 13s	remaining: 41.8s
637:	learn: 0.1177925	total: 1m 13s	remaining: 41.7s
638:	learn: 0.1175040	total: 1m 13s	remaining: 41.6s
639:	learn: 0.1174026	total: 1m 13s	remaining: 41.5s
640:	learn: 0.1171239	total: 1m 13s	remaining: 41.3s
641:	learn: 0.1168812	total: 1m 13s	remaining: 41.2s
642:	learn: 0.1165602	total: 1m 14s	remaining: 41.1s
643:	learn: 0.1162365	total: 1m 14s	remaining: 41s
644:	learn: 0.1160790	total: 1m 14s	remaining: 40.9s
645:	learn: 0.1157612	total: 1m 14s	remaining: 40.8s
646:	learn: 0.1156017	total: 1m 14s	remaining: 40.6s
647:	learn: 0.1154635	total: 1m 14s	remaining: 40.5s
648:	learn: 0.1152278	total: 1m 14s	remaining: 40.4s
649:	learn: 0.1151171	total: 1m 14s	remaining: 40.

787:	learn: 0.0927671	total: 1m 30s	remaining: 24.4s
788:	learn: 0.0925846	total: 1m 30s	remaining: 24.3s
789:	learn: 0.0924763	total: 1m 31s	remaining: 24.2s
790:	learn: 0.0923765	total: 1m 31s	remaining: 24.1s
791:	learn: 0.0922618	total: 1m 31s	remaining: 24s
792:	learn: 0.0920610	total: 1m 31s	remaining: 23.9s
793:	learn: 0.0919736	total: 1m 31s	remaining: 23.7s
794:	learn: 0.0918333	total: 1m 31s	remaining: 23.6s
795:	learn: 0.0917276	total: 1m 31s	remaining: 23.5s
796:	learn: 0.0916301	total: 1m 31s	remaining: 23.4s
797:	learn: 0.0915244	total: 1m 31s	remaining: 23.3s
798:	learn: 0.0913250	total: 1m 32s	remaining: 23.2s
799:	learn: 0.0911623	total: 1m 32s	remaining: 23s
800:	learn: 0.0910977	total: 1m 32s	remaining: 22.9s
801:	learn: 0.0909112	total: 1m 32s	remaining: 22.8s
802:	learn: 0.0907270	total: 1m 32s	remaining: 22.7s
803:	learn: 0.0906254	total: 1m 32s	remaining: 22.6s
804:	learn: 0.0905650	total: 1m 32s	remaining: 22.5s
805:	learn: 0.0904194	total: 1m 32s	remaining: 22.

943:	learn: 0.0743961	total: 1m 48s	remaining: 6.45s
944:	learn: 0.0743511	total: 1m 48s	remaining: 6.33s
945:	learn: 0.0742569	total: 1m 48s	remaining: 6.22s
946:	learn: 0.0741860	total: 1m 49s	remaining: 6.1s
947:	learn: 0.0741226	total: 1m 49s	remaining: 5.99s
948:	learn: 0.0740730	total: 1m 49s	remaining: 5.87s
949:	learn: 0.0739962	total: 1m 49s	remaining: 5.76s
950:	learn: 0.0738428	total: 1m 49s	remaining: 5.64s
951:	learn: 0.0737688	total: 1m 49s	remaining: 5.53s
952:	learn: 0.0736290	total: 1m 49s	remaining: 5.41s
953:	learn: 0.0735645	total: 1m 49s	remaining: 5.3s
954:	learn: 0.0735220	total: 1m 49s	remaining: 5.18s
955:	learn: 0.0734156	total: 1m 50s	remaining: 5.07s
956:	learn: 0.0733685	total: 1m 50s	remaining: 4.95s
957:	learn: 0.0733102	total: 1m 50s	remaining: 4.84s
958:	learn: 0.0732442	total: 1m 50s	remaining: 4.72s
959:	learn: 0.0731349	total: 1m 50s	remaining: 4.61s
960:	learn: 0.0730903	total: 1m 50s	remaining: 4.49s
961:	learn: 0.0730089	total: 1m 50s	remaining: 4

  y = column_or_1d(y, warn=True)


Learning rate set to 0.088165
0:	learn: 2.1344079	total: 110ms	remaining: 1m 49s
1:	learn: 1.8886355	total: 216ms	remaining: 1m 47s
2:	learn: 1.7381971	total: 323ms	remaining: 1m 47s
3:	learn: 1.6023935	total: 429ms	remaining: 1m 46s
4:	learn: 1.4848400	total: 538ms	remaining: 1m 46s
5:	learn: 1.4034083	total: 643ms	remaining: 1m 46s
6:	learn: 1.3264242	total: 748ms	remaining: 1m 46s
7:	learn: 1.2574093	total: 855ms	remaining: 1m 46s
8:	learn: 1.1981328	total: 961ms	remaining: 1m 45s
9:	learn: 1.1495556	total: 1.07s	remaining: 1m 46s
10:	learn: 1.1064154	total: 1.18s	remaining: 1m 45s
11:	learn: 1.0672069	total: 1.28s	remaining: 1m 45s
12:	learn: 1.0308695	total: 1.39s	remaining: 1m 45s
13:	learn: 0.9906094	total: 1.5s	remaining: 1m 45s
14:	learn: 0.9576588	total: 1.6s	remaining: 1m 45s
15:	learn: 0.9259190	total: 1.71s	remaining: 1m 45s
16:	learn: 0.9005312	total: 1.82s	remaining: 1m 45s
17:	learn: 0.8761228	total: 1.92s	remaining: 1m 44s
18:	learn: 0.8504564	total: 2.03s	remaining: 1

157:	learn: 0.2482031	total: 17.1s	remaining: 1m 31s
158:	learn: 0.2471387	total: 17.2s	remaining: 1m 30s
159:	learn: 0.2466816	total: 17.3s	remaining: 1m 30s
160:	learn: 0.2457550	total: 17.4s	remaining: 1m 30s
161:	learn: 0.2454104	total: 17.5s	remaining: 1m 30s
162:	learn: 0.2443798	total: 17.6s	remaining: 1m 30s
163:	learn: 0.2432683	total: 17.7s	remaining: 1m 30s
164:	learn: 0.2423707	total: 17.9s	remaining: 1m 30s
165:	learn: 0.2416721	total: 18s	remaining: 1m 30s
166:	learn: 0.2410563	total: 18.1s	remaining: 1m 30s
167:	learn: 0.2399104	total: 18.2s	remaining: 1m 30s
168:	learn: 0.2391411	total: 18.3s	remaining: 1m 29s
169:	learn: 0.2383213	total: 18.4s	remaining: 1m 29s
170:	learn: 0.2375212	total: 18.5s	remaining: 1m 29s
171:	learn: 0.2363918	total: 18.6s	remaining: 1m 29s
172:	learn: 0.2357328	total: 18.7s	remaining: 1m 29s
173:	learn: 0.2346877	total: 18.8s	remaining: 1m 29s
174:	learn: 0.2339384	total: 18.9s	remaining: 1m 29s
175:	learn: 0.2334540	total: 19.1s	remaining: 1m

313:	learn: 0.1559801	total: 34.2s	remaining: 1m 14s
314:	learn: 0.1558157	total: 34.3s	remaining: 1m 14s
315:	learn: 0.1556230	total: 34.4s	remaining: 1m 14s
316:	learn: 0.1553934	total: 34.5s	remaining: 1m 14s
317:	learn: 0.1548224	total: 34.6s	remaining: 1m 14s
318:	learn: 0.1544078	total: 34.7s	remaining: 1m 14s
319:	learn: 0.1538630	total: 34.8s	remaining: 1m 14s
320:	learn: 0.1534878	total: 34.9s	remaining: 1m 13s
321:	learn: 0.1529091	total: 35.1s	remaining: 1m 13s
322:	learn: 0.1524009	total: 35.2s	remaining: 1m 13s
323:	learn: 0.1520579	total: 35.3s	remaining: 1m 13s
324:	learn: 0.1517139	total: 35.4s	remaining: 1m 13s
325:	learn: 0.1509265	total: 35.5s	remaining: 1m 13s
326:	learn: 0.1505054	total: 35.6s	remaining: 1m 13s
327:	learn: 0.1501150	total: 35.7s	remaining: 1m 13s
328:	learn: 0.1496996	total: 35.8s	remaining: 1m 13s
329:	learn: 0.1494673	total: 36s	remaining: 1m 13s
330:	learn: 0.1491887	total: 36.1s	remaining: 1m 12s
331:	learn: 0.1489092	total: 36.2s	remaining: 1m

471:	learn: 0.1099770	total: 51.7s	remaining: 57.9s
472:	learn: 0.1098008	total: 51.8s	remaining: 57.7s
473:	learn: 0.1096164	total: 51.9s	remaining: 57.6s
474:	learn: 0.1095185	total: 52s	remaining: 57.5s
475:	learn: 0.1093123	total: 52.2s	remaining: 57.4s
476:	learn: 0.1090201	total: 52.3s	remaining: 57.3s
477:	learn: 0.1086390	total: 52.4s	remaining: 57.2s
478:	learn: 0.1084610	total: 52.5s	remaining: 57.1s
479:	learn: 0.1083205	total: 52.6s	remaining: 57s
480:	learn: 0.1080854	total: 52.7s	remaining: 56.9s
481:	learn: 0.1078272	total: 52.8s	remaining: 56.8s
482:	learn: 0.1074650	total: 52.9s	remaining: 56.7s
483:	learn: 0.1073902	total: 53s	remaining: 56.5s
484:	learn: 0.1071455	total: 53.2s	remaining: 56.4s
485:	learn: 0.1070364	total: 53.3s	remaining: 56.3s
486:	learn: 0.1068862	total: 53.4s	remaining: 56.2s
487:	learn: 0.1066317	total: 53.5s	remaining: 56.1s
488:	learn: 0.1063009	total: 53.6s	remaining: 56s
489:	learn: 0.1061432	total: 53.7s	remaining: 55.9s
490:	learn: 0.105896

631:	learn: 0.0795141	total: 1m 9s	remaining: 40.4s
632:	learn: 0.0793621	total: 1m 9s	remaining: 40.3s
633:	learn: 0.0792381	total: 1m 9s	remaining: 40.2s
634:	learn: 0.0791020	total: 1m 9s	remaining: 40.1s
635:	learn: 0.0789384	total: 1m 9s	remaining: 40s
636:	learn: 0.0788986	total: 1m 9s	remaining: 39.9s
637:	learn: 0.0787189	total: 1m 10s	remaining: 39.8s
638:	learn: 0.0786367	total: 1m 10s	remaining: 39.7s
639:	learn: 0.0785305	total: 1m 10s	remaining: 39.5s
640:	learn: 0.0784028	total: 1m 10s	remaining: 39.4s
641:	learn: 0.0782571	total: 1m 10s	remaining: 39.3s
642:	learn: 0.0779811	total: 1m 10s	remaining: 39.2s
643:	learn: 0.0778404	total: 1m 10s	remaining: 39.1s
644:	learn: 0.0777409	total: 1m 10s	remaining: 39s
645:	learn: 0.0776338	total: 1m 10s	remaining: 38.9s
646:	learn: 0.0773869	total: 1m 11s	remaining: 38.8s
647:	learn: 0.0772042	total: 1m 11s	remaining: 38.7s
648:	learn: 0.0770666	total: 1m 11s	remaining: 38.6s
649:	learn: 0.0769139	total: 1m 11s	remaining: 38.5s
650

787:	learn: 0.0608624	total: 1m 26s	remaining: 23.4s
788:	learn: 0.0608019	total: 1m 26s	remaining: 23.2s
789:	learn: 0.0606871	total: 1m 27s	remaining: 23.1s
790:	learn: 0.0606356	total: 1m 27s	remaining: 23s
791:	learn: 0.0605870	total: 1m 27s	remaining: 22.9s
792:	learn: 0.0605013	total: 1m 27s	remaining: 22.8s
793:	learn: 0.0604137	total: 1m 27s	remaining: 22.7s
794:	learn: 0.0602527	total: 1m 27s	remaining: 22.6s
795:	learn: 0.0600351	total: 1m 27s	remaining: 22.5s
796:	learn: 0.0599182	total: 1m 27s	remaining: 22.4s
797:	learn: 0.0598781	total: 1m 27s	remaining: 22.3s
798:	learn: 0.0597124	total: 1m 28s	remaining: 22.1s
799:	learn: 0.0596330	total: 1m 28s	remaining: 22s
800:	learn: 0.0594679	total: 1m 28s	remaining: 21.9s
801:	learn: 0.0593527	total: 1m 28s	remaining: 21.8s
802:	learn: 0.0592494	total: 1m 28s	remaining: 21.7s
803:	learn: 0.0591279	total: 1m 28s	remaining: 21.6s
804:	learn: 0.0590408	total: 1m 28s	remaining: 21.5s
805:	learn: 0.0589103	total: 1m 28s	remaining: 21.

943:	learn: 0.0481518	total: 1m 44s	remaining: 6.2s
944:	learn: 0.0481054	total: 1m 44s	remaining: 6.09s
945:	learn: 0.0480551	total: 1m 44s	remaining: 5.98s
946:	learn: 0.0479782	total: 1m 44s	remaining: 5.87s
947:	learn: 0.0479298	total: 1m 45s	remaining: 5.76s
948:	learn: 0.0478218	total: 1m 45s	remaining: 5.65s
949:	learn: 0.0477876	total: 1m 45s	remaining: 5.54s
950:	learn: 0.0477528	total: 1m 45s	remaining: 5.43s
951:	learn: 0.0476708	total: 1m 45s	remaining: 5.32s
952:	learn: 0.0476547	total: 1m 45s	remaining: 5.21s
953:	learn: 0.0475949	total: 1m 45s	remaining: 5.1s
954:	learn: 0.0475168	total: 1m 45s	remaining: 4.99s
955:	learn: 0.0474567	total: 1m 45s	remaining: 4.88s
956:	learn: 0.0474115	total: 1m 46s	remaining: 4.76s
957:	learn: 0.0473962	total: 1m 46s	remaining: 4.65s
958:	learn: 0.0473366	total: 1m 46s	remaining: 4.54s
959:	learn: 0.0472406	total: 1m 46s	remaining: 4.43s
960:	learn: 0.0471923	total: 1m 46s	remaining: 4.32s
961:	learn: 0.0471121	total: 1m 46s	remaining: 4

  y = column_or_1d(y, warn=True)


In [73]:
np.mean(cvs['test_score'])

0.6693075611396591

In [74]:
est = cvs['estimator'][0]

In [75]:
_, df_test, _, y_test = train_test_split(df, seed=20)

In [76]:
x = confusion_matrix(y_test, encoder.classes_[est.predict(df_test)], labels=df['author'].unique())

  y = column_or_1d(y, warn=True)


In [77]:
precision = np.diag(x) / x.sum(axis=1)
recall = np.diag(x) / x.sum(axis=0)
f1 = 2 * precision * recall / (precision + recall)

  recall = np.diag(x) / x.sum(axis=0)


In [78]:
f1_df = pd.DataFrame({'f1': f1, 'author': df['author'].unique()})

In [79]:
f1_df

Unnamed: 0,f1,author
0,0.921776,Аверченко
1,0.773687,Булгаков
2,0.842541,Бунин
3,0.807531,Гоголь
4,0.705757,Горький
5,0.949662,Достоевский
6,0.142857,Карамзин
7,0.8967,Куприн
8,,Лермонтов
9,0.796137,Мордовцев


In [80]:
f1_df['f1'].mean()

0.6933494138398023

### То же самое, но без редко встречающихся авторов

In [98]:
df_less_authors = df[~df['author'].isin(['Чехов', 'Тургенев', 'Пушкин', 'Лермонтов', 'Карамзин'])].reset_index(drop=True)

In [99]:
encoder_la = get_encoder(df_less_authors)

In [100]:
cv_la = books_cross_val(df_less_authors, k=2, seed=20)
y_la = encoder.transform(df_less_authors['author'])

In [101]:
cvs = cross_validate(model, df_less_authors, y_la, scoring=scoring, cv=cv_la, return_estimator=True)

Learning rate set to 0.087588
0:	learn: 1.9264644	total: 50.4ms	remaining: 50.4s
1:	learn: 1.7109396	total: 97.6ms	remaining: 48.7s
2:	learn: 1.5509313	total: 147ms	remaining: 48.7s
3:	learn: 1.4342107	total: 193ms	remaining: 48.1s
4:	learn: 1.3353693	total: 243ms	remaining: 48.3s
5:	learn: 1.2589331	total: 299ms	remaining: 49.5s
6:	learn: 1.1920911	total: 346ms	remaining: 49.1s
7:	learn: 1.1252044	total: 393ms	remaining: 48.7s
8:	learn: 1.0752168	total: 441ms	remaining: 48.5s
9:	learn: 1.0306481	total: 494ms	remaining: 48.9s
10:	learn: 0.9883236	total: 544ms	remaining: 48.9s
11:	learn: 0.9400948	total: 591ms	remaining: 48.7s
12:	learn: 0.9062343	total: 640ms	remaining: 48.6s
13:	learn: 0.8719482	total: 688ms	remaining: 48.4s
14:	learn: 0.8405288	total: 740ms	remaining: 48.6s
15:	learn: 0.8132591	total: 789ms	remaining: 48.5s
16:	learn: 0.7860490	total: 837ms	remaining: 48.4s
17:	learn: 0.7632195	total: 884ms	remaining: 48.2s
18:	learn: 0.7427912	total: 933ms	remaining: 48.2s
19:	learn

164:	learn: 0.2792039	total: 8.01s	remaining: 40.5s
165:	learn: 0.2783639	total: 8.05s	remaining: 40.5s
166:	learn: 0.2776594	total: 8.1s	remaining: 40.4s
167:	learn: 0.2770751	total: 8.15s	remaining: 40.4s
168:	learn: 0.2760603	total: 8.2s	remaining: 40.3s
169:	learn: 0.2751302	total: 8.25s	remaining: 40.3s
170:	learn: 0.2741296	total: 8.29s	remaining: 40.2s
171:	learn: 0.2731402	total: 8.34s	remaining: 40.2s
172:	learn: 0.2720397	total: 8.39s	remaining: 40.1s
173:	learn: 0.2713098	total: 8.44s	remaining: 40.1s
174:	learn: 0.2708289	total: 8.48s	remaining: 40s
175:	learn: 0.2702651	total: 8.53s	remaining: 40s
176:	learn: 0.2699526	total: 8.58s	remaining: 39.9s
177:	learn: 0.2689383	total: 8.63s	remaining: 39.8s
178:	learn: 0.2680445	total: 8.67s	remaining: 39.8s
179:	learn: 0.2672342	total: 8.72s	remaining: 39.7s
180:	learn: 0.2664527	total: 8.77s	remaining: 39.7s
181:	learn: 0.2656964	total: 8.81s	remaining: 39.6s
182:	learn: 0.2650271	total: 8.86s	remaining: 39.6s
183:	learn: 0.2642

324:	learn: 0.1881447	total: 15.6s	remaining: 32.5s
325:	learn: 0.1877664	total: 15.7s	remaining: 32.4s
326:	learn: 0.1874844	total: 15.7s	remaining: 32.4s
327:	learn: 0.1869682	total: 15.8s	remaining: 32.3s
328:	learn: 0.1864751	total: 15.8s	remaining: 32.3s
329:	learn: 0.1860701	total: 15.9s	remaining: 32.2s
330:	learn: 0.1857088	total: 15.9s	remaining: 32.2s
331:	learn: 0.1852424	total: 16s	remaining: 32.1s
332:	learn: 0.1849293	total: 16s	remaining: 32.1s
333:	learn: 0.1844840	total: 16.1s	remaining: 32s
334:	learn: 0.1834810	total: 16.1s	remaining: 32s
335:	learn: 0.1829375	total: 16.2s	remaining: 31.9s
336:	learn: 0.1826096	total: 16.2s	remaining: 31.9s
337:	learn: 0.1820155	total: 16.3s	remaining: 31.8s
338:	learn: 0.1816792	total: 16.3s	remaining: 31.8s
339:	learn: 0.1812159	total: 16.4s	remaining: 31.8s
340:	learn: 0.1806624	total: 16.4s	remaining: 31.7s
341:	learn: 0.1804482	total: 16.5s	remaining: 31.7s
342:	learn: 0.1799276	total: 16.5s	remaining: 31.6s
343:	learn: 0.179354

484:	learn: 0.1364915	total: 23.4s	remaining: 24.9s
485:	learn: 0.1362001	total: 23.5s	remaining: 24.8s
486:	learn: 0.1357236	total: 23.5s	remaining: 24.8s
487:	learn: 0.1354992	total: 23.6s	remaining: 24.7s
488:	learn: 0.1352969	total: 23.6s	remaining: 24.7s
489:	learn: 0.1350546	total: 23.7s	remaining: 24.6s
490:	learn: 0.1346493	total: 23.7s	remaining: 24.6s
491:	learn: 0.1342773	total: 23.8s	remaining: 24.5s
492:	learn: 0.1340188	total: 23.8s	remaining: 24.5s
493:	learn: 0.1338391	total: 23.9s	remaining: 24.4s
494:	learn: 0.1336640	total: 23.9s	remaining: 24.4s
495:	learn: 0.1333748	total: 24s	remaining: 24.3s
496:	learn: 0.1330111	total: 24s	remaining: 24.3s
497:	learn: 0.1326483	total: 24.1s	remaining: 24.2s
498:	learn: 0.1324633	total: 24.1s	remaining: 24.2s
499:	learn: 0.1322127	total: 24.2s	remaining: 24.2s
500:	learn: 0.1319879	total: 24.2s	remaining: 24.1s
501:	learn: 0.1317231	total: 24.2s	remaining: 24.1s
502:	learn: 0.1315454	total: 24.3s	remaining: 24s
503:	learn: 0.1313

643:	learn: 0.1015749	total: 31.2s	remaining: 17.2s
644:	learn: 0.1014035	total: 31.2s	remaining: 17.2s
645:	learn: 0.1012596	total: 31.3s	remaining: 17.1s
646:	learn: 0.1011937	total: 31.3s	remaining: 17.1s
647:	learn: 0.1010893	total: 31.4s	remaining: 17s
648:	learn: 0.1008667	total: 31.4s	remaining: 17s
649:	learn: 0.1006722	total: 31.5s	remaining: 16.9s
650:	learn: 0.1004508	total: 31.5s	remaining: 16.9s
651:	learn: 0.1002934	total: 31.6s	remaining: 16.8s
652:	learn: 0.1001275	total: 31.6s	remaining: 16.8s
653:	learn: 0.0999598	total: 31.7s	remaining: 16.8s
654:	learn: 0.0998126	total: 31.7s	remaining: 16.7s
655:	learn: 0.0996494	total: 31.8s	remaining: 16.7s
656:	learn: 0.0994292	total: 31.8s	remaining: 16.6s
657:	learn: 0.0992642	total: 31.9s	remaining: 16.6s
658:	learn: 0.0991515	total: 31.9s	remaining: 16.5s
659:	learn: 0.0989603	total: 32s	remaining: 16.5s
660:	learn: 0.0987657	total: 32s	remaining: 16.4s
661:	learn: 0.0985279	total: 32.1s	remaining: 16.4s
662:	learn: 0.098349

803:	learn: 0.0782298	total: 39s	remaining: 9.52s
804:	learn: 0.0781810	total: 39.1s	remaining: 9.47s
805:	learn: 0.0781063	total: 39.1s	remaining: 9.42s
806:	learn: 0.0780194	total: 39.2s	remaining: 9.37s
807:	learn: 0.0778738	total: 39.2s	remaining: 9.32s
808:	learn: 0.0777844	total: 39.3s	remaining: 9.27s
809:	learn: 0.0776229	total: 39.3s	remaining: 9.23s
810:	learn: 0.0774429	total: 39.4s	remaining: 9.18s
811:	learn: 0.0773740	total: 39.4s	remaining: 9.13s
812:	learn: 0.0772484	total: 39.5s	remaining: 9.08s
813:	learn: 0.0771069	total: 39.5s	remaining: 9.03s
814:	learn: 0.0770046	total: 39.6s	remaining: 8.98s
815:	learn: 0.0768935	total: 39.6s	remaining: 8.94s
816:	learn: 0.0767446	total: 39.7s	remaining: 8.89s
817:	learn: 0.0766279	total: 39.7s	remaining: 8.84s
818:	learn: 0.0765260	total: 39.8s	remaining: 8.79s
819:	learn: 0.0763922	total: 39.8s	remaining: 8.74s
820:	learn: 0.0762799	total: 39.9s	remaining: 8.7s
821:	learn: 0.0761761	total: 39.9s	remaining: 8.65s
822:	learn: 0.0

966:	learn: 0.0618991	total: 47.1s	remaining: 1.61s
967:	learn: 0.0618160	total: 47.2s	remaining: 1.56s
968:	learn: 0.0616674	total: 47.2s	remaining: 1.51s
969:	learn: 0.0615788	total: 47.3s	remaining: 1.46s
970:	learn: 0.0615223	total: 47.3s	remaining: 1.41s
971:	learn: 0.0614507	total: 47.4s	remaining: 1.36s
972:	learn: 0.0613717	total: 47.4s	remaining: 1.31s
973:	learn: 0.0613218	total: 47.5s	remaining: 1.27s
974:	learn: 0.0612608	total: 47.5s	remaining: 1.22s
975:	learn: 0.0612067	total: 47.6s	remaining: 1.17s
976:	learn: 0.0611137	total: 47.6s	remaining: 1.12s
977:	learn: 0.0610922	total: 47.7s	remaining: 1.07s
978:	learn: 0.0609680	total: 47.7s	remaining: 1.02s
979:	learn: 0.0608872	total: 47.8s	remaining: 975ms
980:	learn: 0.0607894	total: 47.8s	remaining: 926ms
981:	learn: 0.0607323	total: 47.9s	remaining: 877ms
982:	learn: 0.0606572	total: 47.9s	remaining: 829ms
983:	learn: 0.0605862	total: 48s	remaining: 780ms
984:	learn: 0.0604746	total: 48s	remaining: 731ms
985:	learn: 0.06

  y = column_or_1d(y, warn=True)


Learning rate set to 0.088521
0:	learn: 1.9198472	total: 63.8ms	remaining: 1m 3s
1:	learn: 1.6963144	total: 114ms	remaining: 57.1s
2:	learn: 1.5464693	total: 165ms	remaining: 54.9s
3:	learn: 1.4265322	total: 216ms	remaining: 53.8s
4:	learn: 1.3244395	total: 269ms	remaining: 53.5s
5:	learn: 1.2407817	total: 319ms	remaining: 52.9s
6:	learn: 1.1683259	total: 368ms	remaining: 52.2s
7:	learn: 1.1053827	total: 419ms	remaining: 52s
8:	learn: 1.0518745	total: 470ms	remaining: 51.8s
9:	learn: 1.0043022	total: 521ms	remaining: 51.6s
10:	learn: 0.9644061	total: 571ms	remaining: 51.4s
11:	learn: 0.9242360	total: 623ms	remaining: 51.3s
12:	learn: 0.8930467	total: 675ms	remaining: 51.2s
13:	learn: 0.8615197	total: 727ms	remaining: 51.2s
14:	learn: 0.8302306	total: 776ms	remaining: 51s
15:	learn: 0.8044634	total: 827ms	remaining: 50.9s
16:	learn: 0.7809975	total: 878ms	remaining: 50.8s
17:	learn: 0.7590362	total: 927ms	remaining: 50.6s
18:	learn: 0.7380982	total: 977ms	remaining: 50.5s
19:	learn: 0.7

162:	learn: 0.2815906	total: 8.09s	remaining: 41.5s
163:	learn: 0.2803308	total: 8.14s	remaining: 41.5s
164:	learn: 0.2798324	total: 8.19s	remaining: 41.4s
165:	learn: 0.2792099	total: 8.24s	remaining: 41.4s
166:	learn: 0.2780629	total: 8.29s	remaining: 41.3s
167:	learn: 0.2776730	total: 8.33s	remaining: 41.3s
168:	learn: 0.2767999	total: 8.38s	remaining: 41.2s
169:	learn: 0.2757750	total: 8.44s	remaining: 41.2s
170:	learn: 0.2749625	total: 8.48s	remaining: 41.1s
171:	learn: 0.2741037	total: 8.53s	remaining: 41.1s
172:	learn: 0.2735050	total: 8.58s	remaining: 41s
173:	learn: 0.2729628	total: 8.63s	remaining: 41s
174:	learn: 0.2725274	total: 8.68s	remaining: 40.9s
175:	learn: 0.2716308	total: 8.73s	remaining: 40.9s
176:	learn: 0.2707164	total: 8.78s	remaining: 40.8s
177:	learn: 0.2696638	total: 8.83s	remaining: 40.8s
178:	learn: 0.2686504	total: 8.88s	remaining: 40.7s
179:	learn: 0.2679171	total: 8.93s	remaining: 40.7s
180:	learn: 0.2673061	total: 8.97s	remaining: 40.6s
181:	learn: 0.26

324:	learn: 0.1924771	total: 16.1s	remaining: 33.5s
325:	learn: 0.1918354	total: 16.2s	remaining: 33.5s
326:	learn: 0.1915474	total: 16.2s	remaining: 33.4s
327:	learn: 0.1910453	total: 16.3s	remaining: 33.4s
328:	learn: 0.1908018	total: 16.3s	remaining: 33.3s
329:	learn: 0.1904358	total: 16.4s	remaining: 33.3s
330:	learn: 0.1900987	total: 16.4s	remaining: 33.2s
331:	learn: 0.1897113	total: 16.5s	remaining: 33.2s
332:	learn: 0.1891885	total: 16.5s	remaining: 33.1s
333:	learn: 0.1889435	total: 16.6s	remaining: 33.1s
334:	learn: 0.1884494	total: 16.6s	remaining: 33s
335:	learn: 0.1881770	total: 16.7s	remaining: 33s
336:	learn: 0.1878842	total: 16.7s	remaining: 32.9s
337:	learn: 0.1873695	total: 16.8s	remaining: 32.9s
338:	learn: 0.1868573	total: 16.8s	remaining: 32.8s
339:	learn: 0.1866080	total: 16.9s	remaining: 32.8s
340:	learn: 0.1861722	total: 16.9s	remaining: 32.7s
341:	learn: 0.1857843	total: 17s	remaining: 32.7s
342:	learn: 0.1854386	total: 17s	remaining: 32.6s
343:	learn: 0.185206

485:	learn: 0.1414306	total: 24.5s	remaining: 25.9s
486:	learn: 0.1411782	total: 24.5s	remaining: 25.8s
487:	learn: 0.1410302	total: 24.6s	remaining: 25.8s
488:	learn: 0.1408996	total: 24.6s	remaining: 25.8s
489:	learn: 0.1406317	total: 24.7s	remaining: 25.7s
490:	learn: 0.1404061	total: 24.7s	remaining: 25.7s
491:	learn: 0.1401455	total: 24.8s	remaining: 25.6s
492:	learn: 0.1397159	total: 24.9s	remaining: 25.6s
493:	learn: 0.1395282	total: 24.9s	remaining: 25.5s
494:	learn: 0.1393810	total: 25s	remaining: 25.5s
495:	learn: 0.1392097	total: 25s	remaining: 25.4s
496:	learn: 0.1390042	total: 25.1s	remaining: 25.4s
497:	learn: 0.1388248	total: 25.1s	remaining: 25.3s
498:	learn: 0.1386694	total: 25.2s	remaining: 25.3s
499:	learn: 0.1384394	total: 25.2s	remaining: 25.2s
500:	learn: 0.1380360	total: 25.3s	remaining: 25.2s
501:	learn: 0.1377282	total: 25.3s	remaining: 25.1s
502:	learn: 0.1375480	total: 25.4s	remaining: 25.1s
503:	learn: 0.1371570	total: 25.4s	remaining: 25s
504:	learn: 0.1369

648:	learn: 0.1103800	total: 33.1s	remaining: 17.9s
649:	learn: 0.1102253	total: 33.1s	remaining: 17.8s
650:	learn: 0.1101524	total: 33.2s	remaining: 17.8s
651:	learn: 0.1099868	total: 33.2s	remaining: 17.7s
652:	learn: 0.1098544	total: 33.3s	remaining: 17.7s
653:	learn: 0.1096275	total: 33.3s	remaining: 17.6s
654:	learn: 0.1094921	total: 33.4s	remaining: 17.6s
655:	learn: 0.1093029	total: 33.4s	remaining: 17.5s
656:	learn: 0.1090663	total: 33.5s	remaining: 17.5s
657:	learn: 0.1089727	total: 33.5s	remaining: 17.4s
658:	learn: 0.1087439	total: 33.6s	remaining: 17.4s
659:	learn: 0.1086107	total: 33.7s	remaining: 17.3s
660:	learn: 0.1084304	total: 33.7s	remaining: 17.3s
661:	learn: 0.1082827	total: 33.8s	remaining: 17.2s
662:	learn: 0.1079611	total: 33.8s	remaining: 17.2s
663:	learn: 0.1078124	total: 33.9s	remaining: 17.1s
664:	learn: 0.1076124	total: 33.9s	remaining: 17.1s
665:	learn: 0.1074984	total: 34s	remaining: 17s
666:	learn: 0.1072959	total: 34s	remaining: 17s
667:	learn: 0.107219

809:	learn: 0.0878918	total: 41.6s	remaining: 9.77s
810:	learn: 0.0877380	total: 41.7s	remaining: 9.71s
811:	learn: 0.0876797	total: 41.7s	remaining: 9.66s
812:	learn: 0.0875007	total: 41.8s	remaining: 9.61s
813:	learn: 0.0874033	total: 41.8s	remaining: 9.56s
814:	learn: 0.0873454	total: 41.9s	remaining: 9.51s
815:	learn: 0.0872709	total: 42s	remaining: 9.46s
816:	learn: 0.0870758	total: 42s	remaining: 9.41s
817:	learn: 0.0870464	total: 42.1s	remaining: 9.36s
818:	learn: 0.0869745	total: 42.1s	remaining: 9.31s
819:	learn: 0.0868291	total: 42.2s	remaining: 9.26s
820:	learn: 0.0867912	total: 42.2s	remaining: 9.2s
821:	learn: 0.0866816	total: 42.3s	remaining: 9.15s
822:	learn: 0.0866146	total: 42.3s	remaining: 9.1s
823:	learn: 0.0864393	total: 42.4s	remaining: 9.05s
824:	learn: 0.0863364	total: 42.4s	remaining: 9s
825:	learn: 0.0862774	total: 42.5s	remaining: 8.95s
826:	learn: 0.0860967	total: 42.5s	remaining: 8.9s
827:	learn: 0.0858997	total: 42.6s	remaining: 8.85s
828:	learn: 0.0857903	

968:	learn: 0.0720946	total: 50.3s	remaining: 1.61s
969:	learn: 0.0719780	total: 50.3s	remaining: 1.56s
970:	learn: 0.0719069	total: 50.4s	remaining: 1.5s
971:	learn: 0.0717964	total: 50.4s	remaining: 1.45s
972:	learn: 0.0717484	total: 50.5s	remaining: 1.4s
973:	learn: 0.0716550	total: 50.6s	remaining: 1.35s
974:	learn: 0.0715659	total: 50.6s	remaining: 1.3s
975:	learn: 0.0715008	total: 50.7s	remaining: 1.25s
976:	learn: 0.0714378	total: 50.7s	remaining: 1.19s
977:	learn: 0.0713390	total: 50.8s	remaining: 1.14s
978:	learn: 0.0712856	total: 50.8s	remaining: 1.09s
979:	learn: 0.0711645	total: 50.9s	remaining: 1.04s
980:	learn: 0.0710546	total: 50.9s	remaining: 987ms
981:	learn: 0.0709724	total: 51s	remaining: 935ms
982:	learn: 0.0709278	total: 51s	remaining: 883ms
983:	learn: 0.0708185	total: 51.1s	remaining: 831ms
984:	learn: 0.0707869	total: 51.2s	remaining: 779ms
985:	learn: 0.0706122	total: 51.2s	remaining: 727ms
986:	learn: 0.0705072	total: 51.3s	remaining: 675ms
987:	learn: 0.07041

  y = column_or_1d(y, warn=True)


In [103]:
np.mean(cvs['test_score'])

0.8397827164854054

In [104]:
est = cvs['estimator'][0]

In [107]:
est.get_n_important_features(10)

[array([['почему', 'точно', 'крысаков', 'точно', 'вы', 'мы', 'шпаргалка',
         'если', 'сейчас', 'про'],
        ['елена', 'маргарита', 'ответить', 'прокуратор', 'ответил',
         'турбина', 'булгакова', 'булгаков', 'маргарита', 'николка'],
        ['всё', 'мол', 'соколович', 'мол', 'елагин', 'возле',
         'сосновский', 'опять', 'возле', 'сад'],
        ['чтобы', 'голова', 'григориевич', 'чичиков', 'чичиков', 'главе',
         'глава', 'книга', 'ковалёв', 'парубок'],
        ['культура', 'рабочий', 'смолин', 'её', 'фома', 'её', 'фома',
         'всё', 'ещё', 'человек'],
        ['вдруг', 'сказал', 'достоевского', 'князь', 'все', 'алеш', 'всё',
         'достоевский', 'митя', 'чтоб'],
        ['олеся', 'точно', 'башкирцев', 'всё', 'ромашов', 'точно',
         'ромашов', 'что', 'чудин', 'буланин'],
        ['цезарь', 'лариса', 'девушка', 'де', 'девушка', 'москва',
         'левин', 'левин', 'царевич', 'хлыновый'],
        ['аннинька', 'ни', 'ежели', 'градоначальник', 'владимиры

In [108]:
_, df_test_la, _, y_test_la = train_test_split(df_less_authors, seed=20)

In [109]:
x_la = confusion_matrix(y_test_la, encoder.classes_[est.predict(df_test_la)], labels=df_less_authors['author'].unique())

  y = column_or_1d(y, warn=True)


In [110]:
precision = np.diag(x_la) / x_la.sum(axis=1)
recall = np.diag(x_la) / x_la.sum(axis=0)
f1 = 2 * precision * recall / (precision + recall)

In [112]:
f1_df = pd.DataFrame({'f1': f1, 'author': df_less_authors['author'].unique()})

In [113]:
f1_df

Unnamed: 0,f1,author
0,0.927536,Аверченко
1,0.869182,Булгаков
2,0.823994,Бунин
3,0.768856,Гоголь
4,0.926058,Горький
5,0.895755,Достоевский
6,0.874728,Куприн
7,0.833648,Мордовцев
8,0.53406,Салтыков-Щедрин
9,0.855031,Толстой


In [114]:
f1_df['f1'].mean()

0.8308848601574941