In [62]:
from langchain.schema import HumanMessage, SystemMessage
from langchain.chat_models.gigachat import GigaChat
import os
import pandas as pd
import tqdm
import re
import ast
import numpy as np
import time
import joblib
from collections import Counter
from sklearn.metrics import confusion_matrix, roc_auc_score, top_k_accuracy_score,\
                            f1_score, precision_score, recall_score, average_precision_score

In [7]:
GIGACHAT_TOKEN = os.environ.get('GIGACHAT_TOKEN')

In [6]:
df_test = pd.read_csv('final_markup/test.csv').sample(n=5000, random_state=999)

In [9]:
df_test.head()

Unnamed: 0,text,tokens,labels
10502,"Теперь известные как Алая Ведьма и Ртуть, Ванд...","['теперь', 'известные', 'как', 'алая', 'ведьма...","['o', 'o', 'o', 'o', 'o', 'o', ',', 'o', 'o', ..."
66487,Клуб официально зарегистрирован Управлением ку...,"['клуб', 'официально', 'зарегистрирован', 'упр...","['o', 'o', 'o', 'o', 'o', 'o', 'o', '.', 'o', ..."
54312,Через село проходит автодорога Терло — Самбор....,"['через', 'село', 'проходит', 'автодорога', 'т...","['o', 'o', 'o', 'o', 'o', '.', 'o', 'o', '.']"
4481,В 1958 году была восстановлена пятиярусная тэн...,"['в', '1958', 'году', 'была', 'восстановлена',...","['o', 'o', 'o', 'o', 'o', 'o', 'o', 'o', 'o', ..."
21,Висоза-ду-Сеара (порт. Viçosa do Ceará) — муни...,"['висозадусеара', 'порт', 'viçosa', 'do', 'cea...","['o', '.', 'o', 'o', 'o', 'o', 'o', ',', 'o', ..."


In [45]:
def gigachat_pred(text):
    chat = GigaChat(model='GigaChat:latest',
                    credentials=GIGACHAT_TOKEN,
                    verify_ssl_certs=False)

    messages = [SystemMessage(content="Расставь в тексте знаки препинания."),
                HumanMessage(content='Текст: \n' + text)]
    answer = chat.invoke(messages).content
    with open('gigachat_preds_test.txt', 'a') as file:
        file.write(answer + '\n')
        file.close()
    return answer

In [46]:
punctuation_signs = ['!', ',', '.', '...', ':', ';', '?']

def prepare_pred(text):
    text = re.sub('– ', '', text)
    text = re.sub('— ', '', text)
    text = re.sub('"', '', text)
    text = text.lower()
    text = re.sub('\s+', ' ', text)
    
    for sign in punctuation_signs:
        text = text.replace(sign + ' ', ' ')
        
    if text[-1] in punctuation_signs:
        text = text[:-1]
        
    prediction = gigachat_pred(text)
    tokens = [token for token in prediction.split(' ') if token != '']
    labels = []
    
    for token in tokens:
        if (len(token) > 3) & (token[-3:] == '...'):
            labels.append('...')
        elif token[-1] in punctuation_signs:
            labels.append(token[-1])
        else:
            labels.append('o')
    
    return labels

In [54]:
preds = []
true_labels = []

# процесс падал, поэтому продолжал генерировать с момента падения
for id_text in tqdm.tqdm(range(721+547, len(df_test.text.values))):
    time.sleep(3)
    prediction = prepare_pred(df_test.text.values[id_text])   
    needed_labels = ast.literal_eval(df_test.labels.values[id_text])
    
    # если модель выдала что-то не то по размеру
    if len(prediction) != len(needed_labels):        
        not_empty_token_idxs = ~(np.array(ast.literal_eval(df_test.tokens.values[id_text])) == '')
        needed_labels = np.array(needed_labels)[not_empty_token_idxs].tolist()
        
    if len(needed_labels) == len(prediction):
        true_labels += needed_labels
        preds += prediction

  1%|▌                                      | 50/3732 [03:50<4:31:04,  4.42s/it]Giga generation stopped with reason: blacklist
  1%|▌                                      | 54/3732 [04:07<4:22:53,  4.29s/it]Giga generation stopped with reason: blacklist
  2%|▋                                      | 71/3732 [05:25<4:29:03,  4.41s/it]Giga generation stopped with reason: blacklist
  2%|▉                                      | 92/3732 [06:58<4:17:38,  4.25s/it]Giga generation stopped with reason: blacklist
  3%|█                                     | 105/3732 [07:56<4:51:30,  4.82s/it]Giga generation stopped with reason: blacklist
  3%|█▏                                    | 121/3732 [09:09<4:19:52,  4.32s/it]Giga generation stopped with reason: blacklist
  4%|█▌                                    | 156/3732 [11:57<4:43:13,  4.75s/it]Giga generation stopped with reason: blacklist
  5%|█▊                                    | 183/3732 [14:11<4:39:19,  4.72s/it]Giga generation stopped with re

 39%|█████████████▌                     | 1448/3732 [1:52:08<2:48:12,  4.42s/it]Giga generation stopped with reason: blacklist
 39%|█████████████▋                     | 1458/3732 [1:52:59<3:26:20,  5.44s/it]Giga generation stopped with reason: blacklist
 39%|█████████████▋                     | 1459/3732 [1:53:03<3:07:16,  4.94s/it]Giga generation stopped with reason: blacklist
 39%|█████████████▊                     | 1474/3732 [1:54:08<2:53:50,  4.62s/it]Giga generation stopped with reason: blacklist
 40%|█████████████▉                     | 1488/3732 [1:55:18<3:31:55,  5.67s/it]Giga generation stopped with reason: blacklist
 40%|██████████████                     | 1497/3732 [1:56:05<2:55:58,  4.72s/it]Giga generation stopped with reason: blacklist
 41%|██████████████▎                    | 1526/3732 [1:58:25<2:51:04,  4.65s/it]Giga generation stopped with reason: blacklist
 42%|██████████████▋                    | 1560/3732 [2:00:55<2:33:00,  4.23s/it]Giga generation stopped with re

 78%|███████████████████████████▎       | 2906/3732 [3:45:11<1:03:00,  4.58s/it]Giga generation stopped with reason: blacklist
 78%|████████████████████████████▉        | 2922/3732 [3:46:22<59:42,  4.42s/it]Giga generation stopped with reason: blacklist
 79%|███████████████████████████▋       | 2950/3732 [3:48:32<1:00:17,  4.63s/it]Giga generation stopped with reason: blacklist
 79%|█████████████████████████████▎       | 2957/3732 [3:49:04<58:44,  4.55s/it]Giga generation stopped with reason: blacklist
 80%|█████████████████████████████▍       | 2971/3732 [3:50:09<57:07,  4.50s/it]Giga generation stopped with reason: blacklist
 82%|██████████████████████████████▎      | 3058/3732 [3:56:39<48:35,  4.33s/it]Giga generation stopped with reason: blacklist
 83%|██████████████████████████████▊      | 3103/3732 [4:00:11<49:49,  4.75s/it]Giga generation stopped with reason: blacklist
 84%|██████████████████████████████▉      | 3126/3732 [4:02:05<51:07,  5.06s/it]Giga generation stopped with re

In [55]:
len(true_labels), len(preds)

(93111, 93111)

In [56]:
def calc_metrics_no_proba(y_true, y_pred):
    print('Доля пробелов:', (y_true == 7).mean())
#     print('Accuracy:', top_k_accuracy_score(y_true, y_pred_proba, k=1))
#     print('Top-2 Accuracy:', top_k_accuracy_score(y_true, y_pred_proba, k=2))
#     rint('ROC-AUC (OVR):',roc_auc_score(y_true, y_pred_proba, multi_class='ovr'))
#     print('AUC-PR:',average_precision_score(y_true, y_pred_proba, average='weighted'))
    
    metrics = []
    metrics.append(list(dict(sorted(Counter(y_true).items())).values()))
    metrics.append(f1_score(y_true, y_pred, average=None))
    metrics.append(precision_score(y_true, y_pred, average=None, zero_division=0))
    metrics.append(recall_score(y_true, y_pred, average=None, zero_division=0))
#     metrics.append(roc_auc_score(y_true, y_pred_proba, multi_class='ovr', average=None))
#     metrics.append(average_precision_score(y_true, y_pred_proba, average=None))
    metrics_index = ['Count', 'F1-Score', 'Precision', 'Recall']
#                      'ROC-AUC', 'AUC-PR']
    df_metrics = pd.DataFrame(metrics, columns=le.classes_, index=metrics_index)
    
    return df_metrics

In [63]:
le = joblib.load('le.joblib')

y_pred = le.transform(preds)
y_true = le.transform(true_labels)

calc_metrics_no_proba(y_true, y_pred)

Доля пробелов: 0.8539592529346693


Unnamed: 0,!,",",.,...,:,;,?,o
Count,29.0,6836.0,6314.0,9.0,202.0,179.0,29.0,79513.0
F1-Score,0.060606,0.760332,0.804143,0.0,0.361963,0.0,0.490566,0.980632
Precision,0.25,0.726075,0.892733,0.0,0.475806,0.0,0.541667,0.976232
Recall,0.034483,0.797981,0.731549,0.0,0.292079,0.0,0.448276,0.985072
