In [1]:
import ast

import numpy as np
import pandas as pd
import torch
from datasets import load_dataset, Dataset, DatasetDict
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report, f1_score
from transformers import pipeline

In [2]:
# Let displayed frames show up to 500 characters per cell instead of truncating long texts.
pd.set_option('display.max_colwidth', 500)

Загрузим модель с Hugging Face Hub.

In [3]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Text-classification pipeline around the published checkpoint.
# `return_all_scores=True` is kept on purpose: the helper cells below iterate over
# one {'label', 'score'} dict per label for every input text.
classifier = pipeline(
    task="text-classification",
    model='Maldopast/bge-ecom-trends-classifier',
    device=device,
    batch_size=16,
    return_all_scores=True,
)



In [5]:
# Train/validation split that was previously saved to disk with `datasets`.
data = DatasetDict.load_from_disk('data/train_valid_split_prepared')

# Trend names are read from the column named '0'; the position of a name in this
# list is used as its integer label id.
trend_names = pd.read_csv('data/trend_names.csv')['0'].to_list()
id2label = dict(enumerate(trend_names))
label2id = {name: idx for idx, name in id2label.items()}

valid_df = data['validation'].to_pandas()

# Test set: missing values in 'text' become empty strings before the frame is
# converted to a `Dataset`.
test_df = pd.read_csv('data/prepared_test.csv')
test_df['text'] = test_df['text'].fillna('')
test_data = Dataset.from_pandas(test_df)
sample_sub = pd.read_csv('data/sample_submission.csv')

# Lookup: text -> sorted list of label ids.
# `id_labels` is stored in the CSV as a string. It is parsed with
# `ast.literal_eval` instead of `eval`, so a cell can only be read as a Python
# literal and can never execute code.
# NOTE(review): assumes every `id_labels` cell is a plain literal such as "[18, 19]" - confirm.
train_common = {
    text: sorted(ast.literal_eval(raw_ids))
    for text, raw_ids in (
        pd.read_csv('data/train_common.csv').set_index('text')['id_labels'].to_dict().items()
    )
}

In [7]:
def get_valid_preds(classifier, data, text_col='text', trend_names=trend_names):
    """Score the validation split and return the per-label scores as a DataFrame.

    Args:
        classifier: Callable applied to the validation text column; for each text it
            must yield an iterable of dicts that carry a 'score' key.
        data: Mapping with a 'validation' entry that can be indexed by column name.
        text_col: Name of the column holding the texts to classify.
        trend_names: Column names of the resulting frame, one per score.

    Returns:
        pd.DataFrame with one row per classified text and the columns `trend_names`;
        scores are placed in the order in which the classifier emitted them.
    """
    raw_preds = classifier(data['validation'][text_col])
    score_rows = [[entry['score'] for entry in text_pred] for text_pred in raw_preds]
    return pd.DataFrame(score_rows, columns=trend_names)

In [8]:
def get_ts_preds_general(preds, ts=0.5):
    """Convert per-label scores into lists of label ids using one global threshold.

    Args:
        preds: Iterable with one entry per text; each entry is an iterable of dicts
            with 'label' and 'score' keys.
        ts: A label is selected when its score is strictly greater than this value.

    Returns:
        List with one list of label ids per text (ids come from the global
        `label2id`). When no score exceeds `ts`, the list holds the id of the label
        with the highest positive score, or 0 when no score is positive.
    """
    selected_preds = []
    for label_scores in preds:
        chosen = [label2id[item['label']] for item in label_scores if item['score'] > ts]
        if not chosen:
            # Nothing cleared the threshold: fall back to the single best label.
            best_score, best_id = 0, 0
            for item in label_scores:
                if item['score'] > best_score:
                    best_score = item['score']
                    best_id = label2id[item['label']]
            chosen.append(best_id)
        selected_preds.append(chosen)
    return selected_preds

In [14]:
# Probe the model with a nonsense string: take the label scores for that single
# input and flag which of them exceed the 0.55 threshold.
p = classifier('ghbd')[0]
probs = np.array([entry['score'] for entry in p]) > 0.55
probs

array([False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False, False, False, False, False,
       False, False, False, False, False])

In [36]:
# Map every label of the probe prediction to its score.
probs_dict = {entry['label']: entry['score'] for entry in p}

In [37]:
# Two-column table (trend, probability) ordered from the most to the least probable trend.
df = pd.DataFrame(
    {
        'Тренд': list(probs_dict.keys()),
        'Вероятность': list(probs_dict.values()),
    }
).sort_values('Вероятность', ascending=False)

In [43]:
# Five highest probabilities of the probe prediction.
df['Вероятность'].iloc[:5]

19    0.109557
18    0.015978
0     0.009375
20    0.005738
2     0.005079
Name: Вероятность, dtype: float64

Сделаем предсказания

In [10]:
# Score the validation split (as a DataFrame of per-label scores) and the test set
# (raw pipeline output), both on the 'text_mark' column.
valid_preds_df = get_valid_preds(classifier=classifier, data=data, text_col='text_mark')
test_preds = classifier(test_data['text_mark'])

In [11]:
def check_common(preds):
    """Overwrite predictions for test texts that also occur in the training data.

    For every row of the global `test_df` whose 'text' is a key of the global
    `train_common`, the prediction at that position is compared with the stored
    training labels. When they differ, the row number and both label lists are
    printed and the prediction is replaced by the training labels.

    Args:
        preds: Predictions aligned by position with `test_df`; modified in place.

    Returns:
        The same `preds` object after the replacements.
    """
    for row_no, text in enumerate(test_df['text'].values):
        if text not in train_common:
            continue
        known_labels = train_common[text]
        if preds[row_no] != known_labels:
            print(row_no, preds[row_no], known_labels)
            preds[row_no] = known_labels
            print('*' * 30)
    return preds

In [6]:
def get_pred_df(preds, sub, ts=0.5, ind_dict=None, common=False):
    """Build a submission frame from raw classifier output.

    Args:
        preds: Raw classifier output, one entry per submission row, in the form
            accepted by `get_ts_preds_general`.
        sub: Submission template; it is copied, never modified.
        ts: Global score threshold, used when `ind_dict` is not given.
        ind_dict: Optional argument forwarded to `get_ts_preds`; when truthy it is
            used instead of the global threshold.
        common: When True, the predictions are post-processed by `check_common`.

    Returns:
        Tuple `(sub_, preds)`: the copy of `sub` with a 'target' column holding the
        space-separated label ids of each row, and the list of label-id lists.

    Raises:
        ValueError: If the number of predictions differs from the number of rows
            in `sub`.
    """
    sub_ = sub.copy()

    if ind_dict:
        # NOTE(review): `get_ts_preds` is not defined in the visible part of this
        # notebook; it has to be defined before this branch can run.
        preds = get_ts_preds(preds, ind_dict)
    else:
        preds = get_ts_preds_general(preds, ts)

    if common:
        preds = check_common(preds)

    if len(preds) != len(sub_):
        raise ValueError(
            f'got {len(preds)} predictions for {len(sub_)} submission rows'
        )

    # Assign a plain list so values are matched to rows by position. Going through
    # pd.Series would align on the index and silently produce NaN for any template
    # whose index is not 0..n-1.
    sub_['target'] = [' '.join(str(n) for n in row) for row in preds]
    # Show how many rows received 1, 2, 3, ... labels.
    display(sub_['target'].apply(lambda x: len(x.split(' '))).value_counts())
    return sub_, preds

Выберем порог 0.55: он строже стандартного 0.5 (значение подобрано эвристически).

In [13]:
# Build the submission with the 0.55 threshold and overwrite predictions for texts
# that also occur in the training data.
gen_df, prs = get_pred_df(preds=test_preds, sub=sample_sub, ts=0.55, common=True)

627 [20] [19]
******************************
718 [20] [19]
******************************
751 [20] [19]
******************************
885 [19] [18]
******************************
1008 [20] [19]
******************************
1172 [20] [19]
******************************
1957 [20] [19]
******************************
3096 [20] [19]
******************************
3142 [20] [19]
******************************
3461 [20] [19]
******************************
3776 [18] [19]
******************************
3944 [19] [18]
******************************
4207 [20] [19]
******************************
4385 [20] [19]
******************************
4844 [20] [19]
******************************
5080 [20] [19]
******************************
5244 [19] [18]
******************************
5610 [20] [19]
******************************
7219 [20] [19]
******************************
7253 [19] [18]
******************************
7329 [20] [19]
******************************
7904 [20] [19]
***********************

target
1    5956
2    2354
3     598
4     100
5       6
6       1
Name: count, dtype: int64

In [5]:
# Load the submission that gave the best score and count how many labels each row has.
best_df = pd.read_csv('data/deepvk_stw_text_mark_gen_055_fulltr_common.csv')
best_df['target'].apply(lambda target: len(target.split())).value_counts()

1    5956
2    2354
3     598
4     100
5       6
6       1
Name: target, dtype: int64

Проверим, что предсказания, полученные с помощью модели, совпадают с предсказаниями лучшего сабмита.

In [16]:
# Element-wise comparison with the best submission; the sum counts the matching
# cells in each column.
(gen_df == best_df).sum()

index     9015
target    9015
dtype: int64

In [19]:
# Save the reproduced submission; the DataFrame index is not written to the file.
gen_df.to_csv('data/best_submit.csv', index=False)