In [1]:
import pathlib
import random
import pandas as pd
import numpy as np
import sys

from sklearn.linear_model import LogisticRegression
from sklearn.multioutput import MultiOutputClassifier
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split, cross_validate, cross_val_predict

from sklearn.metrics import (
    f1_score, 
    accuracy_score,
    classification_report, 
)

ROOT_DIR = pathlib.Path().absolute()
DATA_DIR = ROOT_DIR / "data"
RANDOM_SEED = 42

## Загрузка и обзор данных

In [2]:
df_trends = pd.read_csv(DATA_DIR / "trends_description.csv")
df = pd.read_csv(DATA_DIR / "train.csv")
df_test = pd.read_csv(DATA_DIR / "test.csv")

In [3]:
df.head()

Unnamed: 0.1,Unnamed: 0,index,assessment,tags,text,trend_id_res0,trend_id_res1,trend_id_res2,trend_id_res3,trend_id_res4,...,trend_id_res40,trend_id_res41,trend_id_res42,trend_id_res43,trend_id_res44,trend_id_res45,trend_id_res46,trend_id_res47,trend_id_res48,trend_id_res49
0,0,5652,6.0,"{ASSORTMENT,PROMOTIONS,DELIVERY}","Маленький выбор товаров, хотелось бы ассортиме...",0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
1,1,18092,4.0,"{ASSORTMENT,PRICE,PRODUCTS_QUALITY,DELIVERY}",Быстро,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,2,13845,6.0,"{DELIVERY,PROMOTIONS,PRICE,ASSORTMENT,SUPPORT}",Доставка постоянно задерживается,0,0,0,1,0,...,0,0,0,0,0,0,0,0,0,0
3,3,25060,6.0,"{PRICE,PROMOTIONS,ASSORTMENT}",Наценка и ассортимент расстраивают,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4,5,1428,6.0,"{PRICE,PROMOTIONS}",Можно немного скинуть минимальную сумму заказа...,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


## Обучение моделей

### Предобработка данных

In [4]:
df.head()

Unnamed: 0.1,Unnamed: 0,index,assessment,tags,text,trend_id_res0,trend_id_res1,trend_id_res2,trend_id_res3,trend_id_res4,...,trend_id_res40,trend_id_res41,trend_id_res42,trend_id_res43,trend_id_res44,trend_id_res45,trend_id_res46,trend_id_res47,trend_id_res48,trend_id_res49
0,0,5652,6.0,"{ASSORTMENT,PROMOTIONS,DELIVERY}","Маленький выбор товаров, хотелось бы ассортиме...",0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
1,1,18092,4.0,"{ASSORTMENT,PRICE,PRODUCTS_QUALITY,DELIVERY}",Быстро,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,2,13845,6.0,"{DELIVERY,PROMOTIONS,PRICE,ASSORTMENT,SUPPORT}",Доставка постоянно задерживается,0,0,0,1,0,...,0,0,0,0,0,0,0,0,0,0
3,3,25060,6.0,"{PRICE,PROMOTIONS,ASSORTMENT}",Наценка и ассортимент расстраивают,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4,5,1428,6.0,"{PRICE,PROMOTIONS}",Можно немного скинуть минимальную сумму заказа...,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [5]:
X, y = df[["text"]], df[[f"trend_id_res{i}" for i in range(50)]]
X = X.astype("str").copy()
X_train, X_val, y_train, y_val = train_test_split(X, y, train_size= 0.8, random_state = 42)
print(f"X_train.shape is {X_train.shape}")
print(f"y_train.shape is {y_train.shape}")
print(f"X_val.shape is {X_val.shape}")
print(f"y_val.shape is {y_val.shape}")
X_test = df_test[["text"]]
print(f"X_test.shape is {X_test.shape}")


X_train.shape is (3698, 1)
y_train.shape is (3698, 50)
X_val.shape is (925, 1)
y_val.shape is (925, 50)
X_test.shape is (9015, 1)


###  Проверка качества на тречнировчном датасете

In [6]:
preprocessor = ColumnTransformer(
    [
        ("vetorizer", TfidfVectorizer(analyzer="char_wb", ngram_range = (1,3)), "text")
    ],                         
    remainder = "passthrough"
)

pipeline_multiout = Pipeline(
    [
        ("preprocessor", preprocessor),
        ("clf", MultiOutputClassifier(LogisticRegression(max_iter = 10_000))),
    ]
)
display(pipeline_multiout)

In [7]:
cross_valid = cross_validate(pipeline_multiout, 
                             X_train, y_train, 
                             cv = 5, scoring = ["accuracy"], n_jobs = -1)
print("test_accuracy:", cross_valid["test_accuracy"].mean())

test_accuracy: 0.23553231174340783


In [8]:
y_pred = cross_val_predict(pipeline_multiout, X_train, y_train, cv = 2)

In [9]:
# Посмотрим на различные метрики
print(classification_report(y_train, y_pred, zero_division = 0))

              precision    recall  f1-score   support

           0       0.87      0.39      0.54       661
           1       0.89      0.09      0.16       270
           2       0.80      0.37      0.50       486
           3       0.93      0.21      0.35       289
           4       0.00      0.00      0.00       108
           5       0.00      0.00      0.00        44
           6       0.00      0.00      0.00        16
           7       0.00      0.00      0.00        27
           8       1.00      0.01      0.02       109
           9       0.00      0.00      0.00         9
          10       0.00      0.00      0.00        76
          11       0.00      0.00      0.00        87
          12       0.97      0.41      0.57       491
          13       0.00      0.00      0.00        29
          14       0.00      0.00      0.00        62
          15       0.00      0.00      0.00        66
          16       0.00      0.00      0.00       166
          17       0.00    

In [10]:
# Посмотрим на целевую метрику
accuracy_score(y_train, y_pred)

0.19334775554353706

###  Тренировка окончательной модели

In [11]:
pipeline_multiout.fit(X_train, y_train)

##  Предсказание и загрузка решения

In [71]:
pred_test = pipeline_multiout.predict(df_test[["text"]].astype("str"))

In [72]:
res = pd.DataFrame(np.hstack([df_test["index"].values.reshape(df_test.shape[0], 1), pred_test]),
                  columns = ["index"]+[f"{i}" for i in range(50)])

In [73]:
res.head()

Unnamed: 0,index,0,1,2,3,4,5,6,7,8,...,40,41,42,43,44,45,46,47,48,49
0,3135,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
1,4655,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
2,22118,0,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3,23511,1,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4,45,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [74]:
#res['non_zero_count'] = np.count_nonzero(res.iloc[:, 1:51], axis=1)
#res['non_zero_label'] = res.drop('non_zero_count',axis=1).drop('index',axis=1).apply(lambda r: r.index[r.ne(0)].to_list(), axis=1)
res["target"] = res.drop('index',axis=1).where(res == 0, 
                                  other=res.apply(lambda x: x.name), 
                                  axis=1).where(res != 0, 
                                                other="").apply(lambda row: ' '.join(row.values), axis=1).replace(' ','')
res

Unnamed: 0,index,0,1,2,3,4,5,6,7,8,...,41,42,43,44,45,46,47,48,49,target
0,3135,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,
1,4655,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,12 ...
2,22118,0,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,2 ...
3,23511,1,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0 ...
4,45,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9010,3523,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,
9011,24925,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,
9012,6327,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,
9013,530,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,


In [61]:
res.iloc[:, 1:].sum()

0                                                           750
1                                                           174
2                                                           706
3                                                           222
4                                                             0
5                                                             0
6                                                             0
7                                                             0
8                                                            23
9                                                             0
10                                                            0
11                                                           30
12                                                          743
13                                                            0
14                                                            0
15                                      

In [None]:
res["0"].value_counts()

In [70]:

res[["index"]+["target"]].to_csv(DATA_DIR / "submission.csv", index=False)